当前位置: 首页 > news >正文

T30_Keras示例(MINST)

MNIST 手写数字识别

加载本地图片

网上下载的资源是train-images-idx3-ubyte数据,需要解析成图片

import numpy as np
import cv2
import os
import structdef save_mnist_to_jpg(mnist_image_file, mnist_label_file, save_dir):if 'train' in os.path.basename(mnist_image_file):prefix = 'train'else:prefix = 'test'labelIndex = 0imageIndex = 0i = 0lbdata = open(mnist_label_file, 'rb').read()magic, nums = struct.unpack_from(">II", lbdata, labelIndex)labelIndex += struct.calcsize('>II')imgdata = open(mnist_image_file, "rb").read()magic, nums, numRows, numColumns = struct.unpack_from('>IIII', imgdata, imageIndex)imageIndex += struct.calcsize('>IIII')for i in range(nums):label = struct.unpack_from('>B', lbdata, labelIndex)[0]labelIndex += struct.calcsize('>B')im = struct.unpack_from('>784B', imgdata, imageIndex)imageIndex += struct.calcsize('>784B')im = np.array(im, dtype='uint8')img = im.reshape(28, 28)save_name = os.path.join(save_dir, '{}_{}_{}.jpg'.format(prefix, i, label))cv2.imwrite(save_name, img)train_images = r'D:\MyCode\MNIST\train-images-idx3-ubyte'  # 训练集图像的文件名
train_labels = r'D:\MyCode\MNIST\train-labels-idx1-ubyte'  # 训练集label的文件名
test_images = r'D:\MyCode\MNIST\t10k-images-idx3-ubyte'  # 测试集图像的文件名
test_labels = r'D:\MyCode\MNIST\t10k-labels-idx1-ubyte'  # 测试集label的文件名save_train_dir = r'D:\MyCode\MNIST\raw\data\mk\train_images'
save_test_dir = r'D:\MyCode\MNIST\raw\data\mk\test_images/'if not os.path.exists(save_train_dir):os.makedirs(save_train_dir)
if not os.path.exists(save_test_dir):os.makedirs(save_test_dir)save_mnist_to_jpg(test_images, test_labels, save_test_dir)
save_mnist_to_jpg(train_images, train_labels, save_train_dir)

加载图片,制作dataset,由于图片的名称是按照一定格式命名的(train_index_no.jpg),直接从文件名中获取label

import pathlibdata_path = pathlib.Path(r'D:\MyCode\MNIST\raw\data\mk\train_images')
all_image_paths = list(data_path.glob('*.jpg'))  
all_image_paths = [str(path) for path in all_image_paths]  # 所有图片路径的列表
print(len(all_image_paths))# 获取数据标签
all_image_labels=[]
for img_path in all_image_paths:file_name = pathlib.Path(img_path).nameall_image_labels.append(file_name[0:len(file_name)-4].split('_')[2])test_data_path = pathlib.Path(r'D:\MyCode\MNIST\raw\data\mk\test_images')
test_image_paths = list(test_data_path.glob('*.jpg'))  
test_image_paths = [str(path) for path in test_image_paths]  # 所有图片路径的列表
print(len(test_image_paths))# 获取数据标签
test_image_labels=[]
for img_path in test_image_paths:file_name = pathlib.Path(img_path).nametest_image_labels.append(file_name[0:len(file_name)-4].split('_')[2])

可以简单验证一下

for img,label,i in zip(all_image_paths,all_image_labels,range(0,len(all_image_paths))):print(img,'---->',label)if(i>10):break

创建dataset

import tensorflow as tfdef load_image_label(path,label):img = tf.io.read_file(path)img = tf.image.decode_jpeg(img,channels=3)img = tf.image.resize(img,[28,28])img /=255.0 #归一化,在sklearn的学习中,明确指出神经网络算法对数据缩放敏感return  tf.reshape(img, [28*28]),tf.one_hot(int(label), depth=10) # 为什要reshape一下,看自定以网络就明白了img_path_ds = tf.data.Dataset.from_tensor_slices((all_image_paths,all_image_labels))
img_lab_ds = img_path_ds.map(load_image_label)img_test_path_ds =  tf.data.Dataset.from_tensor_slices((test_image_paths,test_image_labels))
img_test_lab_ds = img_test_path_ds.map(load_image_label)

自定义网络

定义网络层

class MyDense(keras.layers.Layer):def __init__(self, inp_dim, outp_dim):super(MyDense, self).__init__()self.kernel = self.add_weight('w', [inp_dim, outp_dim])self.bias = self.add_weight('b', [outp_dim])def call(self, inputs, training=None):out = inputs @ self.kernel + self.biasreturn out 

定义model

class MyModel(keras.Model):def __init__(self):super(MyModel, self).__init__()self.fc1 = MyDense(28*28, 256)self.fc2 = MyDense(256, 128)self.fc3 = MyDense(128, 64)self.fc4 = MyDense(64, 32)self.fc5 = MyDense(32, 10)def call(self, inputs, training=None):x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x) return x

模型训练与保存

model = MyModel()model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy']
)history = model.fit(img_lab_ds.batch(128),epochs=20,
)tf.saved_model.save(model,"my_model") # 自定义的模型不能保存为H5格式文件

测试

x,y = next(iter(img_test_lab_ds.batch(1)))
# print('predict x:', x)
print('predict y:', y)
out = model.predict(x)
print(np.argmax(out)) # 输出预测值的索引(可以与y对比一下)
model.evaluate(img_test_lab_ds.batch(128)) # 模型测试,测试在 img_test_lab_ds 上的性能表现

可视化

模型训练时,将训练的参数保存到日志中

import datetimelog_dir = "logs/tfit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)history = model.fit(img_lab_ds.batch(128),epochs=20,callbacks=[tensorboard_callback]
)

启动tensorboard,指定路径

tensorboard --logdir D:\MyCode\jupyter\demo\tensorflow\logs\tfit

根据提示,在浏览器中输入地址打开即可

http://www.dtcms.com/a/308230.html

相关文章:

  • 逻辑回归——银行贷款案例分析
  • 内存网格、KV存储和Redis的概念、使用场景及异同
  • 企业签名的多种形式
  • 【AI落地应用实战】基于 Amazon Bedrock + DeepSeek构建 GraphRAG 应用程序
  • 30. background-size 有哪些属性
  • IO流专题
  • socket编程-UDP(1)-设计echo server进行接口使用
  • FPGA实现AD9361采集转SRIO与DSP交互,FPGA+DSP多核异构信号处理架构,提供2套工程源码和技术支持
  • 【12】大恒相机SDK C#开发 ——多相机开发,枚举所有相机,并按配置文件中的相机顺序 将所有相机加入设备列表,以便于对每个指定的相机操作
  • 存储学习笔记
  • CSS选择器常用语法
  • day24作业
  • 《Linux自动化运维三例:磁盘告警、服务守护与网络检测》​
  • Mysql超详细安装配置教程(详细图文,保姆级)
  • 掩码语言模型(MLM)技术解析:理论基础、演进脉络与应用创新
  • 【Prompt集合】一个学习英文单词更好的提示词
  • 从姑苏区人工智能大模型基础设施招标|学习服务器、AI处理器、GPU
  • 数据结构 ArrayList与顺序表
  • 机器学习——互信息(超详细)
  • 【物联网】基于树莓派的物联网开发【19】——树莓派搭建MQTT客户端及MQTTX使用
  • Vision Transformer(ViT)模型实例化PyTorch逐行实现
  • 从 MySQL 迁移到 TiDB:使用 SQL-Replay 工具进行真实线上流量回放测试 SOP
  • SpringBoot3.x入门到精通系列:1.2 开发环境搭建
  • 25-vue-photo-preview的使用及使用过程中的问题解决方案
  • 实战教程 ---- Nginx结合Lua实现WAF拦截并可视化配置教程框架
  • 走进computed,了解computed的前世今生
  • 【云故事探索】NO.16:阿里云弹性计算加速精准学 AI 教育普惠落地
  • 谁在托举Agent?阿里云抢滩Agent Infra新赛道
  • 安装 docker compose v2版 笔记250731
  • 对接八大应用渠道