当前位置: 首页 > news >正文

人工智能开发框架 10. MNIST手写数字识别任务(三)

目录

步骤六、样本可视化

步骤七、定义网络

步骤八、定义损失函数及优化器


步骤六、样本可视化

读取前10个样本,然后进行样本可视化,以此来确定样本是不是真实数据集。

#显示前10张图片以及对应标签,检查图片是否是正确的数据集
data = DS.create_dict_iterator().__next__()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()
plt.figure(figsize=(15,5))
for i in range(1,11):plt.subplot(2, 5, i)plt.imshow(np.squeeze(images[i]))plt.title('Number: %s' % labels[i])plt.xticks([])
plt.show()

图2-1 样本可视化

步骤七、定义网络

我们通过定义一个简单的全连接网络来完成图像识别,网络只有3层:

第一层全连接层,形状为784*512;

第二层全连接层,形状为512*128;

最后一层输出层,形状为128*10。

使用MindSpore定义神经网络需要继承mindspore.nn.Cell。Cell是所有神经网络(Conv2d等)的基类。

神经网络的各层需要预先在__init__方法中定义,然后通过定义construct方法来完成神经网络的前向构造。定义网络各层如下:

#创建模型。模型包括3个全连接层,最后输出层使用softmax进行多分类,共分成(0-9)10类
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsnet = Network()

步骤八、定义损失函数及优化器

损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。深度学习通过不停地迭代来缩小损失值。定义一个好的损失函数,可以有效提高模型的性能。

优化器:用于最小化损失函数,从而在训练过程中改进模型。

定义了损失函数后,可以得到损失函数关于权重的梯度。梯度用于指示优化器优化权重的方向,以提高模型性能。MindSpore支持的损失函数有SoftmaxCrossEntropyWithLogits、L1Loss、MSELoss等。这里使用SoftmaxCrossEntropyWithLogits损失函数。

MindSpore提供了callback机制,可以在训练过程中执行自定义逻辑,这里使用框架提供的ModelCheckpoint为例。 ModelCheckpoint可以保存网络模型和参数,以便进行后续的fine-tuning(微调)操作。

#创建网络,损失函数,评估指标  优化器,设定相关超参数
lr = 0.001
num_epoch = 10
momentum = 0.9
loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
metrics={"Accuracy": Accuracy()}
opt = nn.Adam(net.trainable_params(), lr) 

http://www.dtcms.com/a/314661.html

相关文章:

  • 补:《每日AI-人工智能-编程日报》--2025年7月27日
  • STM32 串口收发HEX数据包
  • 汇川PLC通过ModbusTCP转Profinet网关连接西门子PLC配置案例
  • Linux Epool的作用
  • el-image图片预览下标错乱--解决:initial-index
  • 体验Java接入langchain4j运用大模型OpenAi
  • [激光原理与应用-134]:光学器件 - 图解透镜原理和元件
  • stm32/gd32驱动DAC8830
  • 川翔云电脑:引领开启算力无边界时代
  • 【云馨AI-大模型】2025年8月第一周AI浪潮席卷全球:创新与政策双轮驱动
  • Spring核心之面向切面编程(AOP)
  • 专题:2025生命科学与生物制药全景报告:产业图谱、投资方向及策略洞察|附130+份报告PDF、原数据表汇总下载
  • mysql远程登陆失败
  • 昇思学习营-模型推理和性能优化学习心得
  • 北京手机基站数据分享:9.3万点位+双格式,解锁城市通信「基础设施地图」
  • FreeRTOS学习(一)
  • VUE+SPRINGBOOT从0-1打造前后端-前后台系统-注册实现
  • 网络安全 | 从 0 到 1 了解 WAF:Web 应用防火墙到底是什么?
  • 【Unity3D】Ctrl+Shift+P暂停快捷键(Unity键盘快捷键)用不了问题快捷键无法使用问题
  • 规则方法关系抽取-笔记总结
  • 《Leetcode》-面试题-hot100-子串
  • 数据结构(2)
  • AI开发框架与工具:构建智能应用的技术基石
  • 从感知到创造:无穿戴动捕技术构建中小学人工智能实验教学场景
  • go学习笔记:panic是什么含义
  • AI鉴伪技术鉴赏:“看不见”的伪造痕迹如何被AI识破
  • 每日任务day0804:小小勇者成长记之药剂师的小咪
  • Design Compiler:高层次优化与数据通路优化
  • openeuler离线安装软件
  • 段落注入(Passage Injection):让RAG系统在噪声中保持清醒的推理能力