当前位置: 首页 > news >正文

TensorFlow2 Python深度学习 - 卷积神经网络示例-使用MNIST识别数字示例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 卷积神经网络示例-使用MNIST识别数字示例

MNIST数据集介绍

MNIST(Modified National Institute of Standards and Technology)数据集是一个常用于机器学习和深度学习领域的经典数据集,特别是在图像识别任务中。它由美国国家标准与技术研究院(NIST)提供,广泛用于手写数字识别的研究和算法测试。

主要特点:

  1. 数据内容:

    • MNIST数据集包含了28x28像素的灰度图像,表示从0到9的手写数字。每个图像展示了一个单一的手写数字(0到9之一)。

    • 数据集分为两个部分:

      • 训练集:包含60,000个样本,用于训练模型。

      • 测试集:包含10,000个样本,用于测试和评估模型的性能。

  2. 标签信息:

    • 每个图像都有一个对应的标签,表示图像中手写数字的真实值(即0到9之间的某个数字)。

  3. 数据预处理:

    • 图像的大小是28x28像素,灰度级别为0到255,其中0表示白色,255表示黑色。图像通常在输入神经网络之前会被标准化或者归一化。

  4. 应用领域:

    • 手写数字识别:这是MINIST数据集的经典应用,用于测试各种机器学习算法的性能。

    • 分类问题:可以用于对比不同模型(如支持向量机、神经网络、决策树等)的分类准确性。

卷积神经网络示例-使用MNIST识别数字示例

import tensorflow as tf
from keras import Input, layers
from matplotlib import pyplot as plt
​
# 1,加载MNIST数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
print(x_train[0], x_train[0].shape)
print(y_train, y_train.shape)
​
# 2,数据预处理
x_train = x_train / 255.0  # 归一化
x_test = x_test / 255.0  # 归一化
print(x_train[0], x_train[0].shape)
# 将数据重塑为 (样本数, 高, 宽, 通道数) 的形状
print(x_train, x_train.shape)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
print(x_train, x_train.shape)
​
# 3,构建CNN模型
model = tf.keras.models.Sequential([Input(shape=(28, 28, 1)),layers.Conv2D(32, (3, 3), activation='relu'),  # 第一卷积层,卷积核大小3x3,滤波器数为32,ReLU激活函数layers.MaxPooling2D((2, 2)),  # 第一池化层,2x2最大池化layers.Conv2D(64, (3, 3), activation='relu'),  # 第二卷积层,卷积核大小3x3,滤波器数为64,ReLU激活函数layers.MaxPooling2D((2, 2)),  # 第二池化层,2x2最大池化layers.Flatten(),  # 展平层 将二维特征图展平为一维layers.Dense(64, activation='relu'),  # 全连接层,64个神经元,ReLU激活函数layers.Dense(10, activation='softmax')  # 输出层,10个神经元(对应数字0-9),softmax激活函数
])
​
# 4,模型编译
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
​
# 5,模型训练
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), verbose=1)
​
# 6,模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")

运行结果:

可视化训练过程:

# 设置matplotlib使用黑体显示中文
plt.rcParams['font.family'] = 'Microsoft YaHei'
​
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('轮次')
plt.ylabel('准确率')
plt.legend()
plt.show()

预测结果:

# 预测测试集中的一张图片
predictions = model.predict(x_test)
​
# 显示第一个预测结果
print(f"Predicted label: {predictions[0].argmax()}")
print(f"True label: {y_test[0]}")
​
# 显示第一张图片
plt.imshow(x_test[0].reshape(28, 28), cmap='gray')
plt.show()

http://www.dtcms.com/a/486888.html

相关文章:

  • LKT4305GM多功能安全芯片
  • 大连网站建设蛇皮果服装设计公司排行
  • 淄博网站建设公司乐达站长工具综合查询官网
  • 7. 从0到上线:.NET 8 + ML.NET LTR 智能类目匹配实战--反馈存储与数据治理:MongoDB 设计与运维
  • C语言基础知识回顾
  • 未来之窗昭和仙君(二十)订单通知提醒——东方仙盟筑基期
  • 网址链接在桌面上创建快捷方式步骤
  • UVa 10766 Organising the Organisation
  • FastDFS 可观测性最佳实践
  • 网站推广在哪些平台做外链广州建工集团有限公司官网
  • Linux中字符串拷贝函数strlcpy的实现
  • PostgreSQL 18 发布
  • DrissionPage下载文件
  • 观澜做网站公司百度seo网站在线诊断
  • 电子商务网站建设题目男女直接做网站
  • 前端 Web 开发工具全流程指南,打造高效开发与调试体系
  • html网站中文模板下载seo营销型网站
  • 【编号220】中国国内生产总值历史数据汇编1952-2021合订本(PDF扫描版)
  • 百度多久收录一次网站北京企业网站建设飞沐
  • 特斯拉前AI总监开源的一款“小型本地版ChatGPT”,普通家用电脑就能运行!
  • 鸿蒙:创建公共事件、订阅公共事件和退订公共事件
  • 鸿蒙NEXT Function Flow Runtime开发指南:掌握下一代并发编程
  • 遥控器外壳设计网站推荐哈尔滨建设信息网官网
  • 哈夫曼树 红黑树 B树 B+树 WTF!M3?(树形查找)
  • 【Linux内核】DMABUF 与文件描述符(fd)的绑定过程
  • AngularJS 模型
  • 网页设计与网站建设毕业设计成全看免费观看
  • MySQL数据库操作全指南(一)
  • 【项目】年会抽奖系统
  • 烟台建站程序如何用电脑主机做网站