当前位置: 首页 > news >正文

16. 使用ResNet网络进行Fashion-MNIST分类

16.1 ResNet网络结构设计

在这里插入图片描述
在这里插入图片描述

################################################################################################################
#ResNet
################################################################################################################
import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
class Residual(nn.Module):
    """Basic ResNet residual unit: two 3x3 conv + BatchNorm layers plus a skip path.

    Args:
        input_channels: channel count of the incoming tensor.
        num_channels: channel count produced by both 3x3 convolutions.
        use_1x1conv: if True, the skip path goes through a 1x1 convolution so its
            channel count (and, with ``strides`` > 1, its spatial size) matches
            the main path; otherwise the input is added unchanged.
        strides: stride of the first 3x3 convolution and of the 1x1 projection.
    """

    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: Module truthiness is not a reliable "is set" test
        # (e.g. an empty nn.Sequential is falsy).
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)  # f(x) = x + g(x)
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    """Build the list of Residual units that make up one ResNet stage.

    For every stage except the first, the first unit halves the spatial size
    (stride 2) and projects its skip path with a 1x1 conv to ``num_channels``.
    The first stage (``first_block=True``) keeps the resolution. If its
    ``input_channels`` differ from ``num_channels``, its first unit still gets
    a stride-1 1x1 projection so the channel counts line up.
    All remaining units map ``num_channels`` -> ``num_channels``.

    Returns:
        A list of Residual modules, meant to be unpacked into ``nn.Sequential``.
    """
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # Entry unit of a later stage: downsample and change channel count.
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        elif i == 0 and input_channels != num_channels:
            # First stage with a channel change: project without downsampling.
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk
# Stem: 7x7 stride-2 convolution, BatchNorm, ReLU, then 3x3 stride-2 max-pooling.
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# Four residual stages of two units each. Only the first stage keeps the
# resolution; each later stage halves it and doubles the channel count.
b2, b3, b4, b5 = (
    nn.Sequential(*resnet_block(c_in, c_out, 2, first_block=(idx == 0)))
    for idx, (c_in, c_out) in enumerate([(64, 64), (64, 128), (128, 256), (256, 512)])
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Head: global average pooling -> flatten -> 10-way linear classifier.
model = nn.Sequential(
    b1, b2, b3, b4, b5,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(512, 10),
).to(device)
# Per-layer summary for a batch of 64 single-channel 224x224 inputs.
summary(model, input_size=(1, 224, 224), batch_size=64)

16.2 使用ResNet网络实现Fashion-MNIST分类

################################################################################################################
#ResNet
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):
    """Reshape a batch into single-channel images.

    By default every sample is reshaped to ``[1, 28, 28]``, giving an output
    of ``[batch, 1, 28, 28]`` (one Fashion-MNIST image per sample). Pass
    ``shape`` to target a different per-sample shape.
    """

    def __init__(self, shape=(1, 28, 28)):
        super().__init__()
        self.shape = tuple(shape)

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs, e.g.
        # after transpose/permute, where view would raise a RuntimeError.
        return x.reshape(-1, *self.shape)  # [bs, *shape]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    """Draw per-epoch training loss, training accuracy and test accuracy on one axis.

    The three lists are indexed by epoch; the x-axis runs from 1 to
    ``len(train_loss_list)``. Accuracy curves are dashed. The figure is shown
    immediately via ``plt.show()``.
    """
    xs = range(1, len(train_loss_list) + 1)
    curves = (
        (train_loss_list, {'label': 'Train Loss'}),
        (train_acc_list, {'label': 'Train Acc', 'linestyle': '--'}),
        (test_acc_list, {'label': 'Test Acc', 'linestyle': '--'}),
    )
    plt.figure(figsize=(4, 3))
    for values, style in curves:
        plt.plot(xs, values, **style)
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def _evaluate(model, data, dev):
    """Return the top-1 accuracy of ``model`` over every ``(X, y)`` batch in ``data``.

    The model is put in eval mode, so BatchNorm uses its running statistics
    instead of test-batch statistics and does not update them. Autograd is
    disabled because no gradients are needed. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for X, y in data:
            X = X.to(dev)
            y = y.to(dev)
            correct += (model(X).argmax(dim=1) == y).sum().item()
            seen += X.shape[0]
    return correct / seen if seen else 0.0


def train_model(model, train_data, test_data, num_epochs, *, loss_fn=None, opt=None, dev=None, plot=True):
    """Train ``model`` for ``num_epochs`` epochs and report per-epoch metrics.

    Each epoch makes one optimisation pass over ``train_data`` in train mode,
    then measures accuracy on ``test_data`` in eval mode without autograd.
    Each epoch prints the average training loss, training accuracy and test
    accuracy. If ``plot`` is true, the curves are drawn with ``plot_metrics``
    at the end.

    Args:
        model: network to train; it must already be on the device that batches
            are moved to.
        train_data: iterable of ``(X, y)`` training batches (e.g. a DataLoader).
        test_data: iterable of ``(X, y)`` evaluation batches.
        num_epochs: number of passes over ``train_data``.
        loss_fn: callable ``loss_fn(logits, targets)``; defaults to the
            module-level ``CEloss``.
        opt: optimizer over ``model``'s parameters; defaults to the module-level
            ``optimizer``.
        dev: device batches are moved to; defaults to the module-level ``device``.
        plot: whether to plot the training curves after the last epoch.

    Returns:
        The same ``model`` object, trained and left in eval mode.
    """
    # Fall back to the module-level objects the script defines before calling us.
    loss_fn = CEloss if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        # _evaluate leaves the model in eval mode, so switch back every epoch.
        model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        loop = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X, y in loop:
            X = X.to(dev)
            y = y.to(dev)
            y_hat = model(X)
            loss = loss_fn(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            n = X.shape[0]
            # Assuming a mean-reduced loss (CrossEntropyLoss's default), weight
            # it by batch size so the epoch average is exact.
            total_loss += loss.item() * n
            total_correct += (y_hat.argmax(dim=1) == y).sum().item()
            total_samples += n

        avg_train_loss = total_loss / total_samples if total_samples else 0.0
        avg_train_acc = total_correct / total_samples if total_samples else 0.0
        avg_test_acc = _evaluate(model, test_data, dev)
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Train Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    if plot:
        plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    """Xavier-uniform initialise the weight of Linear and Conv2d layers.

    Intended for ``model.apply(init_weights)``. Other module types (BatchNorm,
    pooling, containers) are left untouched, as are all biases.
    """
    # isinstance (unlike an exact type() comparison) also covers subclasses.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
################################################################################################################
#ResNet
################################################################################################################
class Residual(nn.Module):
    """Basic ResNet residual unit: two 3x3 conv + BatchNorm layers plus a skip path.

    Args:
        input_channels: channel count of the incoming tensor.
        num_channels: channel count produced by both 3x3 convolutions.
        use_1x1conv: if True, the skip path goes through a 1x1 convolution so its
            channel count (and, with ``strides`` > 1, its spatial size) matches
            the main path; otherwise the input is added unchanged.
        strides: stride of the first 3x3 convolution and of the 1x1 projection.
    """

    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: Module truthiness is not a reliable "is set" test
        # (e.g. an empty nn.Sequential is falsy).
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)  # f(x) = x + g(x)
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    """Build the list of Residual units that make up one ResNet stage.

    For every stage except the first, the first unit halves the spatial size
    (stride 2) and projects its skip path with a 1x1 conv to ``num_channels``.
    The first stage (``first_block=True``) keeps the resolution. If its
    ``input_channels`` differ from ``num_channels``, its first unit still gets
    a stride-1 1x1 projection so the channel counts line up.
    All remaining units map ``num_channels`` -> ``num_channels``.

    Returns:
        A list of Residual modules, meant to be unpacked into ``nn.Sequential``.
    """
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # Entry unit of a later stage: downsample and change channel count.
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        elif i == 0 and input_channels != num_channels:
            # First stage with a channel change: project without downsampling.
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk
# Stem: 7x7 stride-2 conv -> BatchNorm -> ReLU -> 3x3 stride-2 max-pool.
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# Residual stages, two units each. The first keeps 64 channels and the
# resolution; each later stage doubles the channels and halves the resolution.
b2 = nn.Sequential(*resnet_block(input_channels=64, num_channels=64, num_residuals=2, first_block=True))
b3 = nn.Sequential(*resnet_block(input_channels=64, num_channels=128, num_residuals=2))
b4 = nn.Sequential(*resnet_block(input_channels=128, num_channels=256, num_residuals=2))
b5 = nn.Sequential(*resnet_block(input_channels=256, num_channels=512, num_residuals=2))
################################################################################################################
# Upsample the 28x28 images to 96x96 (the network downsamples 32x overall,
# leaving a 3x3 map before global pooling), convert to tensors, and map
# pixel values from [0, 1] to [-1, 1].
# Named data_transform so the imported `transforms` module is not shadowed.
data_transform = transforms.Compose([
    transforms.Resize(96),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # (mean,), (std,)
])
train_img = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=data_transform, download=True)
test_img = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=data_transform, download=True)
train_data = DataLoader(train_img, batch_size=128, num_workers=4, shuffle=True)
test_data = DataLoader(test_img, batch_size=128, num_workers=4, shuffle=False)
################################################################################################################
# Keep the original preference for the second GPU, but fall back to the default
# GPU on single-GPU machines, where moving tensors to "cuda:1" fails.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# The DataLoaders use num_workers=4. Under the "spawn" start method (Windows,
# macOS), worker processes re-import this module. The guard stops each worker
# from rebuilding the model and re-running training.
if __name__ == "__main__":
    model = nn.Sequential(
        b1, b2, b3, b4, b5,
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(512, 10),
    ).to(device)
    model.apply(init_weights)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    CEloss = nn.CrossEntropyLoss()
    model = train_model(model, train_data, test_data, num_epochs=15)
################################################################################################################

在这里插入图片描述

http://www.dtcms.com/a/278065.html

相关文章:

  • css如何同时给元素设置背景和背景图?
  • 每日算法刷题Day47:7.13:leetcode 复习完滑动窗口一章,用时2h30min
  • 说实话,统计分析用Python这5个第三方库就够了
  • AutoLabor-ROS-Python 学习记录——第一章 ROS概述与环境搭建
  • PortsSwiggerLab: SSRF with blacklist-based input filter
  • JS进阶-day1 作用域解构箭头函数
  • Spring AI 项目实战(十六):Spring Boot + AI + 通义万相图像生成工具全栈项目实战(附完整源码)
  • NO.5数据结构串和KMP算法|字符串匹配|主串与模式串|KMP|失配分析|next表
  • pthread_mutex_unlock函数的概念和用法
  • 大规模电商系统分库分表实战经验分享
  • NFSV4锁机制(三)
  • 编程技术杂谈2.0
  • DVWA靶场通关笔记-XSS DOM(High级别)
  • 垃圾收集器-Serial Old
  • CVE-2022-0609
  • vue2入门(1)vue核心语法详解复习笔记
  • 【开源项目】网络诊断告别命令行!NetSonar:开源多协议网络诊断利器
  • 1.1.1+1.1.3 操作系统的概念、功能
  • c++无锁队列moodycamel::ConcurrentQueue测试结果
  • 在高并发场景下,仅依赖数据库机制(如行锁、版本控制)无法完全避免数据异常的问题
  • Sping AI Alibaba
  • 第11章 AB实验评估指标体系
  • Soul方程式:Z世代背景下兴趣社交平台的商业模式解析
  • Java行业前景如何?零基础又该如何去学Java?
  • 深入理解 RocketMQ:生产者详解
  • 并行并发丨C++ 协程、现场池 学习笔记
  • 闲庭信步使用图像验证平台加速FPGA的开发:第十三课——图像浮雕效果的FPGA实现
  • 语言模型常用的激活函数(Sigmoid ,GeLU ,SwiGLU,GLU,SiLU,Swish)
  • 算法-汽水瓶兑换
  • Spring AI 项目实战(十七):Spring Boot + AI + 通义千问星辰航空智能机票预订系统(附完整源码)