当前位置: 首页 > news >正文

12. 使用VGG网络进行Fashion-MNIST分类

12.1 VGG网络结构设计

在这里插入图片描述

import torch
from torch import nn
import matplotlib.pyplot as plt
from torchsummary import summary
# VGG block implementation
def vgg_block(num_convs, in_channels, out_channels):
    """Build one VGG block: ``num_convs`` x (3x3 conv + ReLU), then a 2x2 max-pool.

    Args:
        num_convs: Number of convolution + ReLU pairs in the block.
        in_channels: Channel count of the tensor fed to the first convolution.
        out_channels: Channel count produced by every convolution in the block.

    Returns:
        An ``nn.Sequential`` holding the convolutions, activations and the
        final max-pool. The 3x3 / padding=1 convolutions keep the spatial size;
        the 2x2 / stride-2 pool halves it.
    """
    # Only the first convolution sees `in_channels`; the rest are
    # out_channels -> out_channels.
    widths = [in_channels] + [out_channels] * num_convs
    layers = []
    for conv_in, conv_out in zip(widths, widths[1:]):
        layers.append(nn.Conv2d(conv_in, conv_out, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    # `layers` describes a single block only; the caller stacks several of them.
    return nn.Sequential(*layers)
# VGG network: stacked VGG blocks followed by a three-layer classifier head.
def vgg(conv_arch, in_channels=1, num_classes=10, input_size=224):
    """Build a VGG-style classifier from a block specification.

    Args:
        conv_arch: Iterable of ``(num_convs, out_channels)`` pairs, one per
            VGG block, in order.
        in_channels: Channel count of the input images (default 1, grayscale).
        num_classes: Size of the final output layer (default 10).
        input_size: Height/width in pixels of the (square) input images
            (default 224). Used only to size the first fully connected layer.

    Returns:
        An ``nn.Sequential`` of the convolutional blocks, a flatten layer and
        the 4096-4096-``num_classes`` fully connected head with dropout.

    Raises:
        ValueError: If ``input_size`` is too small to survive the pooling of
            all blocks.
    """
    conv_bls = []
    for num_convs, out_channels in conv_arch:
        conv_bls.append(vgg_block(num_convs, in_channels, out_channels))
        # The next block consumes what this block emits.
        in_channels = out_channels

    # Each block ends in a 2x2 / stride-2 max-pool (floor halving) and its
    # convolutions keep the spatial size, so N blocks shrink the side by 2**N.
    # With the defaults (224 px, 5 blocks) this is the classic 7x7 feature map.
    feature_size = input_size // 2 ** len(conv_bls)
    if feature_size < 1:
        raise ValueError(
            f"input_size={input_size} is too small for {len(conv_bls)} VGG blocks"
        )

    net = nn.Sequential(
        *conv_bls,
        nn.Flatten(),
        nn.Linear(in_channels * feature_size * feature_size, 4096),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, num_classes),
    )
    return net
# (num_convs, out_channels) per block. Counting the three fully connected
# layers added by vgg(), these give 11, 16 and 19 weight layers respectively.
conv_arch_11 = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
conv_arch_16 = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))
conv_arch_19 = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512))

# Sample single-channel 224x224 input; summary() below builds its own input
# and does not consume this tensor.
X = torch.randn(size=(1, 1, 224, 224))
model = vgg(conv_arch_19)
# The model is never moved off the CPU here, but torchsummary defaults to
# device="cuda" and then feeds a CUDA tensor whenever a GPU is present, which
# fails with a device mismatch. Pin the summary to the CPU explicitly.
summary(model, input_size=(1, 224, 224), device="cpu")

在这里插入图片描述

12.2 VGG网络实现Fashion-MNIST分类

import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
# Render every figure text element in Times New Roman.
plt.rcParams.update({'font.family': ['Times New Roman']})
def vgg_block(num_convs, in_channels, out_channels):
    """Build one VGG block: ``num_convs`` x (3x3 conv + ReLU), then a 2x2 max-pool.

    Args:
        num_convs: Number of convolution + ReLU pairs in the block.
        in_channels: Channel count of the tensor fed to the first convolution.
        out_channels: Channel count produced by every convolution in the block.

    Returns:
        An ``nn.Sequential`` holding the convolutions, activations and the
        final max-pool. The 3x3 / padding=1 convolutions keep the spatial size;
        the 2x2 / stride-2 pool halves it.
    """
    # Only the first convolution sees `in_channels`; the rest are
    # out_channels -> out_channels.
    widths = [in_channels] + [out_channels] * num_convs
    layers = []
    for conv_in, conv_out in zip(widths, widths[1:]):
        layers.append(nn.Conv2d(conv_in, conv_out, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    # `layers` describes a single block only; the caller stacks several of them.
    return nn.Sequential(*layers)
# VGG network: stacked VGG blocks followed by a three-layer classifier head.
def vgg(conv_arch, in_channels=1, num_classes=10, input_size=224):
    """Build a VGG-style classifier from a block specification.

    Args:
        conv_arch: Iterable of ``(num_convs, out_channels)`` pairs, one per
            VGG block, in order.
        in_channels: Channel count of the input images (default 1, grayscale).
        num_classes: Size of the final output layer (default 10).
        input_size: Height/width in pixels of the (square) input images
            (default 224). Used only to size the first fully connected layer.

    Returns:
        An ``nn.Sequential`` of the convolutional blocks, a flatten layer and
        the 4096-4096-``num_classes`` fully connected head with dropout.

    Raises:
        ValueError: If ``input_size`` is too small to survive the pooling of
            all blocks.
    """
    conv_bls = []
    for num_convs, out_channels in conv_arch:
        conv_bls.append(vgg_block(num_convs, in_channels, out_channels))
        # The next block consumes what this block emits.
        in_channels = out_channels

    # Each block ends in a 2x2 / stride-2 max-pool (floor halving) and its
    # convolutions keep the spatial size, so N blocks shrink the side by 2**N.
    # With the defaults (224 px, 5 blocks) this is the classic 7x7 feature map.
    feature_size = input_size // 2 ** len(conv_bls)
    if feature_size < 1:
        raise ValueError(
            f"input_size={input_size} is too small for {len(conv_bls)} VGG blocks"
        )

    net = nn.Sequential(
        *conv_bls,
        nn.Flatten(),
        nn.Linear(in_channels * feature_size * feature_size, 4096),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(4096, num_classes),
    )
    return net
class Reshape(torch.nn.Module):
    """Layer that reshapes its input into single-channel 28x28 images."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(-1, 1, 28, 28)``, i.e. [batch, 1, 28, 28].

        The batch dimension is inferred from the number of elements in ``x``.
        """
        # reshape() instead of view(): view() raises on non-contiguous inputs
        # whose dimensions must be merged (e.g. after a transpose), whereas
        # reshape() returns a view when it can and copies only when it must.
        return x.reshape(-1, 1, 28, 28)
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    """Draw per-epoch training loss, training accuracy and test accuracy.

    All three series share one set of axes; the x axis counts epochs from 1
    up to ``len(train_loss_list)``. The figure is shown with ``plt.show()``.

    Args:
        train_loss_list: Training loss, one value per epoch.
        train_acc_list: Training accuracy, one value per epoch.
        test_acc_list: Test accuracy, one value per epoch.
        title: Title placed above the plot.
    """
    epochs = range(1, len(train_loss_list) + 1)
    # (values, keyword arguments for plt.plot) in drawing order.
    curves = (
        (train_loss_list, {'label': 'Train Loss'}),
        (train_acc_list, {'label': 'Train Acc', 'linestyle': '--'}),
        (test_acc_list, {'label': 'Test Acc', 'linestyle': '--'}),
    )

    plt.figure(figsize=(4, 3))
    for values, line_style in curves:
        plt.plot(epochs, values, **line_style)
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def _evaluate_accuracy(model, data, desc):
    """Return the fraction of samples in ``data`` that ``model`` classifies correctly.

    Switches the model to eval mode (so Dropout is disabled) and leaves it
    there; the caller is responsible for switching back. Runs without gradient
    tracking. Batches are moved to the module-level ``device``.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for X, y in tqdm(data, desc=desc):
            X = X.to(device)
            y = y.to(device)
            correct += (model(X).argmax(dim=1) == y).sum().item()
            seen += X.shape[0]
    return correct / seen


def train_model(model, train_data, test_data, num_epochs):
    """Train ``model``, evaluate it after every epoch, then plot the curves.

    Relies on the module-level ``device``, ``optimizer`` and ``CEloss`` objects
    and on ``plot_metrics`` for the final figure.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_data: Iterable of ``(X, y)`` training batches.
        test_data: Iterable of ``(X, y)`` evaluation batches.
        num_epochs: Number of passes over ``train_data``.

    Returns:
        The same ``model`` object, in the train/eval mode it had on entry.
    """
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    # Remember the caller's mode so toggling train/eval below is invisible.
    was_training = model.training

    for epoch in range(num_epochs):
        desc = f"EPOCHS[{epoch+1}/{num_epochs}]"

        # Re-enable Dropout: the evaluation pass of the previous epoch left
        # the model in eval mode.
        model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        for X, y in tqdm(train_data, desc=desc):
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = CEloss(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = X.shape[0]
            # Weight the batch loss by the batch size so that dividing by the
            # sample count below yields a per-sample average (assumes CEloss
            # returns a batch mean, the CrossEntropyLoss default).
            total_loss += loss.item() * batch_size
            total_correct += (y_hat.argmax(dim=1) == y).sum().item()
            total_samples += batch_size

        # The test progress bar is created only now, after the training bar
        # has finished, instead of up front.
        avg_test_acc = _evaluate_accuracy(model, test_data, desc)

        avg_train_loss = total_loss / total_samples
        avg_train_acc = total_correct / total_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")

    model.train(was_training)
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    """Apply Xavier-uniform initialisation to the weights of Linear/Conv2d layers.

    Intended for ``model.apply(init_weights)``: any other module type is left
    untouched, and biases are not modified.

    Args:
        m: A single ``nn.Module`` visited by ``apply``.
    """
    # isinstance() is the idiomatic type test and, unlike `type(m) == ...`,
    # also matches subclasses of nn.Linear / nn.Conv2d.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(m.weight)
################################################################################################################
# (num_convs, out_channels) per block. Counting the three fully connected
# layers added by vgg(), these give 11, 16 and 19 weight layers respectively.
conv_arch_11 = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
conv_arch_16 = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))
conv_arch_19 = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512))
################################################################################################################
# The DataLoaders below start worker processes (num_workers=4). On platforms
# that spawn workers by re-importing this module, that must happen behind a
# __main__ guard, so the whole run lives inside it.
if __name__ == "__main__":
    # NOTE: images are upscaled from 28x28 to 224x224 here, the input size the
    # VGG classifier head is dimensioned for.
    # Named `data_transform` so the imported `transforms` module is not shadowed.
    data_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        # First tuple is the per-channel mean, second the per-channel std.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_img = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=data_transform, download=True)
    test_img = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=data_transform, download=True)
    train_data = DataLoader(train_img, batch_size=256, num_workers=4, shuffle=True)
    test_data = DataLoader(test_img, batch_size=256, num_workers=4, shuffle=False)

    model = vgg(conv_arch_11)
    model.apply(init_weights)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # train_model() reads `device`, `optimizer` and `CEloss` as module globals.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    CEloss = nn.CrossEntropyLoss()
    num_epochs = 1
    model = train_model(model, train_data, test_data, num_epochs)
http://www.dtcms.com/a/275870.html

相关文章:

  • Jenkins+Gitee+Docker容器化部署
  • 三步定位 Git Push 403:从日志到解决
  • 【深度剖析】致力“四个最”的君乐宝数字化转型(下篇:转型成效5-打造数字化生存能力探索可持续发展路径)
  • 【Datawhale AI夏令营】mcp-server
  • LeetCode 每日一题 2025/7/7-2025/7/13
  • 1. 好的设计原则
  • XCTF-Mary_Morton双漏洞交响曲:格式化字符串漏洞泄露Canary与栈溢出劫持的完美配合
  • 【2024CSP-J初赛】阅读程序(2)试题详解
  • 剑指offer57_和为S的两个数字
  • 深入详解:决策树在医学影像脑部疾病诊断中的应用与实现
  • Java 属性配置文件读取方法详解
  • 《Java HashMap底层原理全解析(源码+性能+面试)》
  • LangChain 的链(Chain)
  • Java 接口与抽象类:深入解析两者的区别及应用场景
  • 【深度学习】常见评估指标Params、FLOPs、MACs
  • 牛客:HJ19 简单错误记录[华为机考][字符串]
  • 多表查询-4-外连接
  • EMC接地
  • 试用了10款翻译软件后,我只推荐这一款!完全免费还超好用
  • 6.isaac sim4.2 教程-Core API-多机器人,多任务
  • 单细胞入门(1)——介绍
  • C语言中整数编码方式(原码、反码、补码)
  • C++ 模板工厂、支持任意参数代理、模板元编程
  • 如何使用postman做接口测试?
  • dify 用postman调试参数注意
  • MOSFET驱动电路设计时,为什么“慢”开,“快”关?
  • 《Java Web程序设计》实验报告二 学习使用HTML标签、表格、表单
  • 零基础搭建监控系统:Grafana+InfluxDB 保姆级教程,5分钟可视化服务器性能!​
  • elementuiPlus+vue3手脚架后台管理系统,上生产环境之后,如何隐藏vite.config.ts的target地址
  • 游戏开发日记7.12