当前位置: 首页 > news >正文

Python 基于卷积神经网络手写数字识别

Ubuntu系统:22.04

python版本:3.9

安装依赖库:

pip install tensorflow==2.13 matplotlib numpy -i https://mirrors.aliyun.com/pypi/simple

代码实现:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') / 255
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32') / 255# 构建CNN模型
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
history = model.fit(train_images, train_labels,batch_size=128,epochs=5,verbose=1,validation_data=(test_images, test_labels))# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f"\n测试准确率: {test_acc:.4f}")# 保存模型
model.save('mnist_cnn_model.keras')
print("模型已保存为 mnist_cnn_model.keras")# 可视化训练过程
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.ylabel('准确率')
plt.xlabel('训练轮次')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.ylabel('损失')
plt.xlabel('训练轮次')
plt.legend()plt.tight_layout()
plt.savefig('training_history.png')
print("训练过程图表已保存为 training_history.png")# 测试预测
sample_idx = np.random.randint(0, len(test_images))
sample_image = test_images[sample_idx].reshape(1, 28, 28, 1)
prediction = model.predict(sample_image, verbose=0)plt.figure(figsize=(5, 3))
plt.imshow(test_images[sample_idx].reshape(28, 28), cmap='gray')
plt.title(f"真实标签: {test_labels[sample_idx]}\n预测结果: {np.argmax(prediction)}")
plt.axis('off')
plt.savefig('sample_prediction.png')
print(f"样本预测图已保存为 sample_prediction.png\n真实标签: {test_labels[sample_idx]},预测结果: {np.argmax(prediction)}")

相关文章:

  • 基于ELK的分布式日志实时分析与可视化系统设计
  • PHP序列化和反序列化
  • 分布式数据库备份实践
  • word文档格式规范(论文格式规范、word格式、论文格式、文章格式、格式prompt)
  • python中使用高并发分布式队列库celery的那些坑
  • 基于Web的分布式图集管理系统架构设计与实践
  • ICASSP2025丨融合语音停顿信息与语言模型的阿尔兹海默病检测
  • 分布式不同数据的一致性模型
  • 从零实现基于BERT的中文文本情感分析的任务
  • 分布式CAP理论
  • 【STIP】安全Transformer推理协议
  • 云原生时代 Kafka 深度实践:02快速上手与环境搭建
  • pcie gen3 phy tx
  • t009-线上代驾管理系统
  • StarRocks x Iceberg:云原生湖仓分析技术揭秘与最佳实践
  • Apache Kafka 实现原理深度解析:生产、存储与消费全流程
  • 如何在 Ubuntu 24.04 服务器上安装 Apache Solr
  • 高密爆炸警钟长鸣:AI为化工安全戴上“智能护盾”
  • QuickBASIC QB64 支持 64 位系统和跨平台Linux/MAC OS
  • 【深度学习新浪潮】什么是混合精度分解?
  • wordpress商城支付宝/网站内部优化有哪些内容
  • web网站代做/网页制作的步骤
  • 电子商务网站规划原则/线上平台推广方式
  • 北京建设网站方舟爸爸/百度下载官方下载安装
  • 比较好看的网站设计/目前最好的引流推广方法
  • 网站优化关键词公司/web成品网站源码免费