当前位置: 首页 > news >正文

如何使用PyTorch搭建一个基础的神经网络并进行训练?

神经网络作为人工智能的重要组成部分,在图像处理、自然语言处理、语音识别、机器翻译等领域具有广泛的应用。今天的分享就来详细的介绍如何搭建一个简单的神经网络框架并进行训练。

首先定义一个名为NeuralNetwork的类,它继承了PyTorch框架的nn.Module类,用于创建神经网络。接下来部分是类的构造函数__init__(),用于初始化神经网络的各个层。在这个类中,初始化了以下层:

  • flatten:一个将输入展平的层。

  • hidden1:第一个隐藏层,输入大小为28x28(图像大小),输出大小为128。

  • hidden2:第二个隐藏层,输入大小为128,输出大小为128。

  • hidden3:第三个隐藏层,输入大小为128,输出大小为64。

  • out:输出层,输入大小为64,输出大小为10(类别数)。

# 定义一个名为NeuralNetwork的类,它继承了PyTorch框架的nn.Module类,用于创建神经网络。classNeuralNetwork(nn.Module):def__init__(self):# 继承父类nn.Module的方法和属性super(NeuralNetwork, self).__init__()# 数据进行展平操作self.flatten=nn.Flatten()# 定义一层线性神经网络self.hidden1=nn.Linear(28*28,128)self.hidden2=nn.Linear(128,128)self.hidden3=nn.Linear(128,64)self.out=nn.Linear(64,10)

另外想自学机器学习深度学习,不知道如何入门的同学,还为大家整理了一份入门路线图(更新迭代不下10次),包含基础、理论、代码、实战项目、必读论文等等,希望可以帮到大家,大家可以添加小助手无偿自取

这部分定义了前向传播方法forward(),通过前向传播计算输入数据x的输出。首先输入的数据先通过flatten层展平,然后依次经过隐藏层和激活函数进行线性变换和非线性处理。最后经过输出层输出预测结果x。(注意这个函数也是定义在NeuralNetwork之中的。)

defforward(self,x):        # 对输入数据进行展平操作x=self.flatten(x)        # 将数据传入第一层线性神经网络x=self.hidden1(x)        # 对神经网络的输出应用ReLU激活函数x=torch.relu(x)x=self.hidden2(x)        # 对神经网络的输出应用sigmoid激活函数x=torch.sigmoid(x)x=self.hidden3(x)x=torch.relu(x)        # 将数据传入输出层x=self.out(x)        # 最后经过输出层输出预测结果xreturnx

这行代码创建了一个model的神经网络模型实例,并将其移动到特定的设备(我这里使用的是GPU)上进行计算。

model=NeuralNetwork().to(device)

这行代码创建了一个CrossEntropyLoss(交叉熵损失函数)的实例,用作损失函数。​​​​​​​

#交叉熵损失函数loss_fn=nn.CrossEntropyLoss()这行代码创建了一个Adam优化器的实例,将模型参数和学习率作为参数传入。​​​​​​​
# model.parameters()用于获取模型的所有参数。这些参数包括权重和偏差等# 它们都是Tensor类型的,是神经网络的重要组成部分optimizer=torch.optim.Adam(model.parameters(),lr=0.005)

这行代码定义了一个名为train的函数,用于进行训练。函数接受训练数据、模型、损失函数和优化器作为输入参数。​​​​​​​

'''定义训练函数'''deftrain(dataloader,model,loss_fn,optimizer):# 设置模型为训练模式    model.train()# 记录优化次数    num=1# 遍历数据加载器中的每一个数据批次。for X,y in dataloader:        X,y=X.to(device),y.to(device)# 自动初始化权值w        pred=model.forward(X)        loss=loss_fn(pred,y) # 计算损失值# 将优化器的梯度缓存清零        optimizer.zero_grad()# 执行反向传播计算梯度        loss.backward()# 并通过优化器更新模型参数        optimizer.step()# 将损失值转换为标量        loss_value=loss.item()#将此次损失值打印出来        print(f'loss:{loss_value},[numbes]:{num}')#增加计数器num        num+=1

这行代码调用train函数来开始模型的训练。它传入训练数据、模型、损失函数和优化器,并执行训练过程。(注意训练前要先获取dataloader!)​​​​​​​

#调用函数train(),传入训练数据,神经网络模型,损失函数和优化算法train(train_dataloader,model,loss_fn,optimizer)

完整代码展示:​​​​​​​

'''定义神经网络'''classNeuralNetwork(nn.Module):def__init__(self):        super(NeuralNetwork, self).__init__()        self.flatten=nn.Flatten()        self.hidden1=nn.Linear(28*28,128)        self.hidden2=nn.Linear(128,128)        self.hidden3=nn.Linear(128,64)        self.out=nn.Linear(64,10)defforward(self,x):        x=self.flatten(x)        x=self.hidden1(x)        x=torch.relu(x)        x=self.hidden2(x)        x=torch.sigmoid(x)        x=self.hidden3(x)        x=torch.relu(x)        x=self.out(x)return xmodel=NeuralNetwork().to(device)print(model)'''建立损失函数和优化算法'''#交叉熵损失函数loss_fn=nn.CrossEntropyLoss()# 优化算法为随机梯度算法/Adam优化算法optimizer=torch.optim.Adam(model.parameters(),lr=0.005)'''定义训练函数'''deftrain(dataloader,model,loss_fn,optimizer):    model.train()# 记录优化次数    num=1for X,y in dataloader:        X,y=X.to(device),y.to(device)# 自动初始化权值w        pred=model.forward(X)        loss=loss_fn(pred,y) # 计算损失值        optimizer.zero_grad()        loss.backward()        optimizer.step()        loss_value=loss.item()        print(f'loss:{loss_value},[numbes]:{num}')        num+=1train(train_dataloader,model,loss_fn,optimizer)

http://www.dtcms.com/a/352797.html

相关文章:

  • skywalking 原理
  • H20 性能表现之 gpt-oss-120b
  • 软考-系统架构设计师 管理信息系统(MIS)详细讲解
  • React内网开发代理配置详解
  • C++ 力扣 704.二分查找 基础二分查找 题解 每日一题
  • Https之(四)国密GMTLS
  • 【Redis#8】Redis 数据结构 -- Zset 类型
  • 改造thinkphp6的命令行工具和分批次导出大量数据
  • GTCB:引领金融革命,打造数字经济时代标杆
  • 【js】加密库sha.js 严重漏洞速查
  • UTXO 模型及扩展模型
  • 香港数字资产交易市场蓬勃发展,监管与创新并驾齐驱
  • 完整实验命令解析:从集群搭建到负载均衡配置(2)
  • 记录使用ruoyi-flowable开发部署中出现的问题以及解决方法(二)
  • 电脑开机显示器不亮
  • 智能安防:以AI重塑安全新边界
  • 欧盟《人工智能法案》生效一年主要实施进展概览(一)
  • docker-runc not installed on system
  • 【科研绘图系列】R语言在海洋生态学数据可视化中的应用:以浮游植物叶绿素和初级生产力为例
  • Kafka 4.0 兼容性矩阵解读、升级顺序与降级边界
  • [特殊字符]论一个 bug 如何经过千难万险占领线上
  • 大数据毕业设计选题推荐-基于大数据的城镇居民食品消费量数据分析与可视化系统-Hadoop-Spark-数据可视化-BigData
  • electron应用开发:命令npm install electron的执行逻辑
  • 搜狗拼音输入法的一个bug
  • 解锁Java分布式魔法:CAP与BASE的奇幻冒险
  • 如何安装 mysql-installer-community-8.0.21.0.tar.gz(Linux 详细教程附安装包下载)​
  • 配置ipv6
  • UE5蓝图接口的创建和使用方法
  • 【C语言强化训练16天】--从基础到进阶的蜕变之旅:Day14
  • 在 Ubuntu 系统上安装 MySQL