当前位置: 首页 > news >正文

TensorFlow2 Python深度学习 - 卷积神经网络示例2-使用Fashion MNIST识别时装示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 卷积神经网络示例2-使用Fashion MNIST识别时装示例

Fashion MNIST数据集介绍

Fashion MNIST 是一个由 Zalando 公司发布的图像分类数据集,通常用于机器学习和计算机视觉任务,特别是图像分类的研究。它可以作为 MNIST 数据集(手写数字图像分类)的替代,因为它包含了更多的实际应用场景,并且数据类型更加复杂。Fashion MNIST 被广泛应用于算法验证和基准测试。

数据集特点

  • 图像尺寸:每张图像是 28x28 像素的灰度图。

  • 类别:该数据集包含 10 类不同的服饰产品。具体类别包括:

    1. T恤/上衣(T-shirt/top)

    2. 裙子(Trouser)

    3. 套头衫(Pullover)

    4. 连衣裙(Dress)

    5. 外套(Coat)

    6. 凉鞋(Sandal)

    7. 衬衫(Shirt)

    8. 运动鞋(Sneaker)

    9. 包(Bag)

    10. 踝靴(Ankle boot)

    这些图像是灰度图像,意味着每个像素的值在 0 到 255 之间,表示不同的灰度强度。

数据集构成

  • 训练集:包含 60,000 张图像,均匀分布在 10 个类别中。

  • 测试集:包含 10,000 张图像,用于评估模型的性能。

使用场景

Fashion MNIST 被用作初学者和研究者训练和测试图像分类模型的标准数据集。它的难度适中,适合用于以下应用:

  • 深度学习入门:用于神经网络(尤其是卷积神经网络)的训练。

  • 算法对比:不同机器学习算法的性能比较。

  • 特征提取与学习:用于测试特征工程和表示学习方法。

  • 图像分类基础:了解和练习分类任务。

卷积神经网络示例2-使用Fashion MNIST识别时装示例

import tensorflow as tf
from keras import Input, layers
from matplotlib import pyplot as plt
​
# 1,加载Fashion MINIST数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
print(x_train[0], x_train[0].shape)
print(y_train, y_train.shape)
​
# 2,数据预处理
x_train = x_train / 255.0  # 归一化
x_test = x_test / 255.0  # 归一化
print(x_train[0], x_train[0].shape)
# 将数据重塑为 (样本数, 高, 宽, 通道数) 的形状
print(x_train, x_train.shape)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
print(x_train, x_train.shape)
​
# 3,构建CNN模型
model = tf.keras.models.Sequential([Input(shape=(28, 28, 1)),layers.Conv2D(32, (3, 3), activation='relu'),  # 第一卷积层,卷积核大小3x3,滤波器数为32,ReLU激活函数layers.MaxPooling2D((2, 2)),  # 第一池化层,2x2最大池化layers.Conv2D(64, (3, 3), activation='relu'),  # 第二卷积层,卷积核大小3x3,滤波器数为64,ReLU激活函数layers.MaxPooling2D((2, 2)),  # 第二池化层,2x2最大池化layers.Conv2D(128, (3, 3), activation='relu'),  # 第三卷积层,卷积核大小3x3,滤波器数为64,ReLU激活函数layers.MaxPooling2D((2, 2)),  # 第三池化层,2x2最大池化layers.Flatten(),  # 展平层 将二维特征图展平为一维layers.Dense(512, activation='relu'),  # 全连接层,512个神经元,ReLU激活函数layers.Dense(10, activation='softmax')  # 输出层,10个神经元(对应数字0-9),softmax激活函数
])
​
# 4,模型编译
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
​
# 5,模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6,模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

项目运行结果:

可视化训练过程:

# 7,可视化训练过程
# 设置matplotlib使用黑体显示中文
plt.rcParams['font.family'] = 'Microsoft YaHei'
​
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('轮次')
plt.ylabel('准确率')
plt.legend()
plt.show()

预测结果:

# 预测测试集中的一张图片
predictions = model.predict(x_test)
​
# 显示第一个预测结果
print(f"Predicted label: {predictions[0].argmax()}")
print(f"True label: {y_test[0]}")
​
# 显示第一张图片
plt.imshow(x_test[0].reshape(28, 28), cmap='gray')
plt.show()

http://www.dtcms.com/a/490684.html

相关文章:

  • Eureka: Human-Level Reward Design via Coding Large Language Models 译读笔记
  • 随时随地看监控:我的UptimeKuma远程访问改造记
  • 关于网站篡改应急演练剧本编写(模拟真实场景)
  • 河北省企业网站建设公司企业管理系统软件有哪些
  • JVM的classpath
  • RVO优化
  • ethercat 环型拓扑(Ring Topology)
  • 颠覆PD快充、工业控制与智能家电等领域高CTR,高隔离电压高可靠性光电耦合器OCT1018/OCT1019
  • 【机器学习入门】8.1 降维的概念和意义:一文读懂降维的概念与意义 —— 从 “维度灾难” 到低维嵌入
  • 黄骅市旅游景点有哪些盐城网站关键词优化
  • 对于网站建设的调查问卷爱南宁app官网下载
  • 一文读懂 YOLOv1 与 YOLOv2:目标检测领域的早期里程碑
  • 在 Windows 10/11 LTSC等精简系统中安装Winget和微软应用商店,Windows Server安装Microsoft Store的应用
  • A2A架构详解
  • 基础 - SQL命令速查
  • logo图片素材大全sem和seo都包括什么
  • 把 AI“缝”进布里:生成式编织神经网络让布料自带摄像头
  • 岳阳建网站长沙网站优化价格
  • [Sora] 分布式训练 | 并行化策略 | `plugin_type` | `booster.boost()`
  • Linux系统函数link、unlink与dentry的关系及使用注意事项
  • 安卓手机 IP 切换指南:告别卡顿,轻松换 IP
  • 微服务拆分:领域驱动设计,单体应用如何平滑迁移?
  • 企业网站推广的形式有哪些福州网站推广排名
  • 关键词优化网站排名群英云服务器
  • nano-GPT:最小可复现的GPT实操
  • 网站建设公众号wordpress中文模板下载地址
  • 菜单及库(Num28)
  • super()核心作用是调用父类的属性/方法
  • 【Win32 多线程程序设计基础第三章笔记】
  • CentOS 7 FTP安装与配置详细介绍