当前位置: 首页 > news >正文

PyTorch入门之【AlexNet】

参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

目录

  • 大纲
  • dataloader
  • model
  • train
  • test

大纲

(原文此处为项目目录结构的截图)
各个文件的作用:

  • data:存放数据集的目录
  • dataloader.py:负责加载数据集并创建 DataLoader 实例
  • model.py:定义 AlexNet 模型
  • train.py:训练模型并保存权重
  • test.py:加载已保存的权重并测试模型

dataloader

import torch
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np


# Preprocessing pipeline shared by the train and test splits: resize to 224,
# convert to a tensor, then normalise every channel with mean 0.5 / std 0.5.
_norm_mean = (0.5, 0.5, 0.5)
_norm_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Number of samples per mini-batch, used by both loaders.
batch_size = 16

# Training split of CIFAR-10 (stored under ./data), reshuffled by the loader.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True,
)

# Test split of CIFAR-10, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False,
)

# Human-readable class names, indexed by the integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


if __name__ == '__main__':
    # Preview one training batch: print its class names, then display the
    # images as a single grid.
    dataiter = iter(train_loader)
    images, labels = next(dataiter)

    # print labels, each right-aligned to width 5; iterate over the labels
    # actually returned so a batch shorter than batch_size cannot raise
    # an IndexError
    print(' '.join(f'{classes[label]:>5}' for label in labels.tolist()))

    # show images
    img_grid = torchvision.utils.make_grid(images)
    # undo the Normalize(mean=0.5, std=0.5) step applied by `transform`
    img_grid = img_grid / 2 + 0.5
    npimg = img_grid.numpy()
    # move the channel axis last, the layout plt.imshow expects
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

model

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    """AlexNet-style image classifier with batch normalisation.

    Five convolutional stages (each conv followed by BatchNorm2d and ReLU,
    with max-pooling after stages 1, 2 and 5) feed a three-layer fully
    connected head. The first two linear layers are preceded by dropout
    (p=0.5).

    The feature map is adaptively average-pooled to 6x6 before the head, so
    the classifier always receives 256 * 6 * 6 = 9216 features. For a
    224x224 input the feature map is already 6x6 and the pooling is an
    identity; other input sizes of at least 63x63 are accepted as well.
    The pooling layer has no parameters, so checkpoints saved without it
    load unchanged.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.conv_2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.conv_3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.conv_4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.conv_5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        # Parameter-free: fixes the spatial size at 6x6 so fc_1 always sees
        # 9216 features regardless of the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.fc_1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU())
        self.fc_2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc_3 = nn.Sequential(
            nn.Linear(4096, num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Image batch of shape (batch, 3, height, width).
        """
        out = self.conv_1(x)
        out = self.conv_2(out)
        out = self.conv_3(out)
        out = self.conv_4(out)
        out = self.conv_5(out)
        out = self.avgpool(out)
        # Collapse (channels, height, width) into one feature axis.
        out = torch.flatten(out, 1)
        out = self.fc_1(out)
        out = self.fc_2(out)
        out = self.fc_3(out)
        return out

if __name__ == '__main__':
    # Smoke test: print the architecture, then push one random 224x224 RGB
    # image through the network and print the size of the output.
    net = AlexNet()
    print(net)
    dummy_batch = torch.randn(1, 3, 224, 224)
    logits = net(dummy_batch)
    print(logits.size())

train

import torch
import torch.nn as nn

from dataloader import train_loader, test_loader
from model import AlexNet


# define the hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3


# load the model
model = AlexNet(num_classes).to(device)


# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# train the model
total_len = len(train_loader)  # number of mini-batches per epoch

for epoch in range(num_epochs):
    # make sure the model is in training mode at the start of every epoch,
    # since validation below switches it to eval mode
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        # move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)

        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # report the loss of the current mini-batch every 100 steps
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], '
                  f'Step [{i + 1}/{total_len}], Loss: {loss.item():.4f}')

    # Validation (NOTE: this uses test_loader, i.e. the test split doubles as
    # the validation set)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # index of the largest logit per sample is the predicted class
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # report the number of images actually evaluated instead of a fixed count
    print(f'Accuracy of the network on the {total} validation images: '
          f'{100 * correct / total} %')

# save the model checkpoint
torch.save(model.state_dict(), 'alexnet.pth')

test

import torch

from dataloader import test_loader, classes
from model import AlexNet


# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source. map_location places the stored tensors on `device`.
model.load_state_dict(torch.load('alexnet.pth', map_location=device))

# test the pretrained model on CIFAR-10 test data
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # index of the largest logit per sample is the predicted class
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
# report the number of images actually evaluated instead of a fixed count
print(f'Accuracy of the network on the {total} validation images: '
      f'{100 * correct / total} %')

相关文章:

  • [React源码解析] React的设计理念和源码架构 (一)
  • JMeter性能测试
  • [架构之路-228]:目标系统 - 纵向分层 - 计算机硬件与体系结构 - 硬盘存储结构原理:如何表征0和1,即如何存储0和1,如何读数据,如何写数据(修改数据)
  • 【17】c++设计模式——>原型模式
  • Raid10--Raid01介绍
  • 华为云云耀云服务器L实例评测|基于canal缓存自动更新流程 SpringBoot项目应用案例和源码
  • 关于Go语言的底层,Slice,map
  • 博弈论——伯特兰德寡头模型(Bertrand Model)
  • 利用fitnesse实现api接口自动化测试
  • Waves 14混音特效插件合集mac/win
  • 目前制造企业生产计划现状是什么?有没有自动化排产系统?
  • minio分布式文件存储
  • SCROLLINFO scrollInfo; 2023/10/5 下午3:38:53
  • 【C语言】善于利用指针(二)
  • SpringCloud(二)Docker、Spring AMQP、ElasticSearch
  • 逐步解决Could not find artifact com:ojdbc8:jar:12
  • Ubuntu安装samba服务器
  • 讲讲springboot的@Async
  • 王杰国庆作业day6
  • 打开MySQL数据库
  • 国台办:实现祖国完全统一是大势所趋、大义所在、民心所向
  • 远如《月球背面》,近似你我内心
  • 日月谭天丨这轮中美关税会谈让台湾社会看清了什么?
  • 李强会见巴西总统卢拉
  • 这个“超强致癌细菌”,宝宝感染率高达40%,预防却很简单
  • 从普通人经历中发现历史,王笛解读《线索与痕迹》