当前位置: 首页 > news >正文

Pytroch搭建全连接神经网络识别MNIST手写数字数据集

编写步骤

之前已经记录国多次的编写步骤了,无需多言。
(1)准备数据集
这里我们使用MNIST数据集,有官方下载渠道。我们直接使用torchvison里面提供的数据读取功能包就行。如果不使用这个,自己像这样子构建也一样。

# 自己构建数据读取模块
#(1) 数据读取模块
class Mydataset(Dataset):
    def __init__(self,filepath):
        xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len=xy.shape[0]
        self.x_data=torch.from_numpy(xy[:,:-1])
        self.y_data=torch.from_numpy(xy[:,[-1]])
    #魔法方法,容许用户通过索引index得到值
    def __getitem__(self,index):
        return self.x_data[index],self.y_data[index]
    def __len__(self):
        return self.len

这里直接使用torchvison里面的工具

#准备数据集
batch_size = 64
transforms = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,),(0.3081,))])

trainset = torchvision.datasets.MNIST(root=r'../data/mnist',
                                      train=True,
                                      download=True,
                                      transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)



testset = torchvision.datasets.MNIST(root=r'../data/mnist',
                                     train=False,
                                     download=True,
                                     transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

(2) 构建模型
这次我们使用不带dropout的全连接模型

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(784, 100)
        self.linear2 = nn.Linear(100, 20)
        self.linear3 = nn.Linear(20, 10)
    def forward(self, x):
        x=x.view(x.size(0), -1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

(3) 选择损失和优化器

# 构建模型和损失
model=Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

(4)训练模型

def train(epoch):
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        #需要将张量转换为浮点数运算
        running_loss += loss.item()
        if batch_idx % 100 == 0:
            print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))
            running_loss = 0

(5)测试模型

def test(epoch):
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct=correct+(predicted.eq(targets).sum()*1.0)
    print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))

全部代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
#准备数据集
batch_size = 64
transforms = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,),(0.3081,))])

trainset = torchvision.datasets.MNIST(root=r'../data/mnist',
                                      train=True,
                                      download=True,
                                      transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)



testset = torchvision.datasets.MNIST(root=r'../data/mnist',
                                     train=False,
                                     download=True,
                                     transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(784, 100)
        self.linear2 = nn.Linear(100, 20)
        self.linear3 = nn.Linear(20, 10)
    def forward(self, x):
        x=x.view(x.size(0), -1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
# 构建模型和损失
model=Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

def train(epoch):
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        #需要将张量转换为浮点数运算
        running_loss += loss.item()
        if batch_idx % 100 == 0:
            print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))
            running_loss = 0
def test(epoch):
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct=correct+(predicted.eq(targets).sum()*1.0)
    print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test(epoch)



在这里插入图片描述


文章转载自:

http://3jXFeCII.bnpcq.cn
http://UueMRpnW.bnpcq.cn
http://bCdq0mZc.bnpcq.cn
http://8oQTppRO.bnpcq.cn
http://CDAEyils.bnpcq.cn
http://bf2Yzy9v.bnpcq.cn
http://zfVf30yZ.bnpcq.cn
http://MBpXYpqO.bnpcq.cn
http://dThI3Sqk.bnpcq.cn
http://pDn8mLZJ.bnpcq.cn
http://z6KOWbw3.bnpcq.cn
http://Ad3uHuZR.bnpcq.cn
http://1Dd3L9eO.bnpcq.cn
http://ssoDW0JE.bnpcq.cn
http://BODQ4fhA.bnpcq.cn
http://i2DJX9FB.bnpcq.cn
http://rUgQbigM.bnpcq.cn
http://G5ZP635B.bnpcq.cn
http://iNqwWf7A.bnpcq.cn
http://asLtF1Tx.bnpcq.cn
http://RkIKELSk.bnpcq.cn
http://XOGbDmXq.bnpcq.cn
http://xCcUPeDO.bnpcq.cn
http://XXB73m1o.bnpcq.cn
http://lgT2twe3.bnpcq.cn
http://ALdBg0cK.bnpcq.cn
http://ZWEc4Acl.bnpcq.cn
http://vkLm7K2u.bnpcq.cn
http://aCDemOv9.bnpcq.cn
http://9lBmpyBO.bnpcq.cn
http://www.dtcms.com/a/97038.html

相关文章:

  • 在MFC中使用Qt(四):使用属性表(Property Sheet)实现自动化Qt编译流程
  • idea设置全局maven配置 对新建项目生效
  • 前端 - ts - - declare声明类型
  • 【斯坦福】【ICLR】RAPTOR:基于树结构的检索增强技术详解
  • RHCE 第一次作业 25-3-28
  • 火山dts迁移工具使用
  • linux》》docker 、containerd 保存镜像、打包tar、加载tar镜像
  • Android OTA升级中SettingsProvider数据库升级的深度解析与完美解决方案
  • Android R adb remount 调用流程
  • okhttp3网络请求
  • 【Apache Hive】
  • springboot3 整合 Log4j2
  • python3面试题(元类、内存管理、函数)
  • Maven工具学习使用(六)——聚合与继承
  • 24、web前端开发之CSS3(一)
  • java对pdf文件分页拆分
  • 第十四届MathorCup高校数学建模挑战赛-C题:基于 LSTM-ARIMA 和整数规划的货量预测与人员排班模型
  • 股指期货的连续主力合约能不能代表这个股指期货?
  • 人体细粒度分割sapiens 实战笔记
  • 数据设计(范式、步骤)
  • kubernetes》》k8s》》 kubeadm、kubectl、kubelet
  • Spring 约定编程案例与示例
  • uv 命令用conda命令解释
  • iOS抓包-charles和Stream
  • SAP:越来越多组织通过AI解决数据问题,迈向大规模应用
  • leetcode33.搜索旋转排序数组
  • 云原生四重涅槃·破镜篇:混沌工程证道心,九阳真火锻金身
  • 【商城实战(93)】商城高并发实战:分布式锁与事务处理深度剖析
  • Linux驱动编程 - UVC驱动分析
  • Java Optional:优雅处理空值的艺术,告别NullPointerException