当前位置: 首页 > news >正文

pytorch学习笔记-argparse的使用(加更版)

写这篇文章是因为很多例程执行的语法并不是简单的python train.py,更多时候给了一些可选的参数,所以加更一篇关于参数解析的,当然关于argparse的使用有更多更详细的教程,我这里只记录一下基本的、在网络中的argparse的用法。

教程以最简单的设置epoch为例

参数解析需要经过以下过程(不写成函数也ok,但是还是建议写成函数):

import argparse#参数解析
def parse_args():#创建解析器对象parse = argparse.ArgumentParser(description="训练模型时的参数设置")#设置参数parse.add_argument("--epoch",default=10,type=int,help='训练次数')#参数解析args = parse.parse_args()return args

然后在外部需要调用的地方:

#解析命令行参数
args = parse_args() #这里调用的是上面定义的func
#通过args.name来访问对应参数
epoch = args.epoch
print(epoch)

此时执行python train_argparse.py --epoch 5,打印epoch,可以发现epoch由默认的10更新为了5

下面给出完整例程:

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from model import *
import argparse#参数解析
def parse_args():#创建解析器对象parse = argparse.ArgumentParser(description="训练模型时的参数设置")#设置参数parse.add_argument("--epoch",default=10,type=int,help='训练次数')#参数解析args = parse.parse_args()return argsfrom torch.utils.tensorboard import SummaryWriterdata_transforms = transforms.Compose([transforms.ToTensor()
])#引入数据集
train_data = datasets.CIFAR10("./dataset",train=True,transform=data_transforms,download=True)test_data = datasets.CIFAR10("./dataset",train=False,transform=data_transforms,download=True)#加载数据
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)#确定设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)#建立模型
my_module = MyModule()
my_module.to(device)#设置损失函数
cross_loss = nn.CrossEntropyLoss()
cross_loss.to(device)#设置优化器
#设置学习率
learning_rate = 1e-2
optimizer = torch.optim.SGD(my_module.parameters(),lr=learning_rate)#进行训练
#设置迭代次数
# epoch = 10#解析命令行参数
args = parse_args()
#通过args.name来访问对应参数
epoch = args.epoch
print(epoch)total_train_steps = 0writer = SummaryWriter("train_logs")for i in range(epoch):print("第{}轮训练".format(i+1))#训练my_module.train() #只对某些层起作用for data in train_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = my_module(imgs)#计算损失loss = cross_loss(outputs, targets)#优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_steps +=1if total_train_steps % 100 ==0:print("训练次数:{},Loss:{}".format(total_train_steps,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_steps)#测试,不再梯度下降my_module.eval() #同样只对某些层起作用 total_test_loss = 0# total_test_steps = 0total_accuracy = 0test_data_size = len(test_data)with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = my_module(imgs)loss = cross_loss(outputs,targets)total_test_loss += loss.item()##对于分类任务可以求一下准确的个数,非必须#argmax(1)按行取最大的下标 argmax(0)按列取最大的下标accuracy = (outputs.argmax(1)==targets).sum()total_accuracy += accuracyprint("第{}轮的测试集Loss:{}".format(i+1,total_test_loss))print("测试集准确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,i)#存储模型if i % 5 == 0:torch.save(my_module.state_dict(),"my_module_{}.pth".format(i))print("模型存储成功")writer.close()
http://www.dtcms.com/a/339482.html

相关文章:

  • 基于SpringBoot+Vue的写真馆预约管理系统(邮箱通知、WebSocket及时通讯、协同过滤算法)
  • 哪些仪器适合对接电子实验记录本,哪些不适合?
  • Java 11中的Collections类详解
  • Web安全攻防基础
  • 什么是IP隔离?一文讲清跨境电商/海外社媒的IP隔离逻辑
  • JVM对象创建和内存分配
  • 2025年12大AI测试自动化工具
  • 基礎複分析習題6.級數與乘積展開
  • 广东省省考备考(第八十一天8.19)——资料分析、数量(强化训练)
  • MVC、MVP、MVCC 和 MVI 架构的介绍及区别对比
  • 面试题储备-MQ篇 2-说说你对RocketMQ的理解
  • 基于WebSocket和SpringBoot聊天项目ChatterBox测试报告
  • 怎样平衡NLP技术发展中数据质量和隐私保护的关系?
  • 中科米堆CASAIM自动化三维测量设备测量汽车壳体直径尺寸
  • 多模态大模型应用落地:从图文生成到音视频交互的技术选型与实践
  • 5.1Pina介绍
  • 进程间的通信(管道,信号)
  • 知行社:以爱之名,共筑公益梦想
  • Podman:Mysql(使用卷)
  • 【Goland】:面向对象编程
  • Day 29 类的装饰器
  • 如何将任意文件一键转为PDF?
  • 【PHP】模拟斗地主后端编写
  • Matplotlib数据可视化实战:Matplotlib图表美化与进阶教程
  • 软件系统运维常见问题
  • idea中如何设置文件的编码格式
  • Python Day31 JavaScript 基础核心知识点详解 及 例题分析
  • 【完整源码+数据集+部署教程】太阳能板表面损伤检测图像分割系统源码和数据集:改进yolo11-DynamicHGNetV2
  • 服务器Linux防火墙怎样实现访问控制
  • Nginx前后端分离反代(VUE+FastAPI)