pytorch学习笔记-argparse的使用(加更版)
写这篇文章是因为很多例程执行的语法并不是简单的python train.py,更多时候给了一些可选的参数,所以加更一篇关于参数解析的,当然关于argparse的使用有更多更详细的教程,我这里只记录一下基本的、在网络中的argparse的用法。
教程以最简单的设置epoch为例
参数解析需要经过以下过程(不写成函数也ok,但是还是建议写成函数):
import argparse#参数解析
def parse_args():#创建解析器对象parse = argparse.ArgumentParser(description="训练模型时的参数设置")#设置参数parse.add_argument("--epoch",default=10,type=int,help='训练次数')#参数解析args = parse.parse_args()return args
然后在外部需要调用的地方:
#解析命令行参数
args = parse_args() #这里调用的是上面定义的func
#通过args.name来访问对应参数
epoch = args.epoch
print(epoch)
此时执行python train_argparse.py --epoch 5,打印epoch,可以发现epoch由默认的10更新为了5
下面给出完整例程:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn
from model import *
import argparse#参数解析
def parse_args():#创建解析器对象parse = argparse.ArgumentParser(description="训练模型时的参数设置")#设置参数parse.add_argument("--epoch",default=10,type=int,help='训练次数')#参数解析args = parse.parse_args()return argsfrom torch.utils.tensorboard import SummaryWriterdata_transforms = transforms.Compose([transforms.ToTensor()
])#引入数据集
train_data = datasets.CIFAR10("./dataset",train=True,transform=data_transforms,download=True)test_data = datasets.CIFAR10("./dataset",train=False,transform=data_transforms,download=True)#加载数据
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)#确定设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)#建立模型
my_module = MyModule()
my_module.to(device)#设置损失函数
cross_loss = nn.CrossEntropyLoss()
cross_loss.to(device)#设置优化器
#设置学习率
learning_rate = 1e-2
optimizer = torch.optim.SGD(my_module.parameters(),lr=learning_rate)#进行训练
#设置迭代次数
# epoch = 10#解析命令行参数
args = parse_args()
#通过args.name来访问对应参数
epoch = args.epoch
print(epoch)total_train_steps = 0writer = SummaryWriter("train_logs")for i in range(epoch):print("第{}轮训练".format(i+1))#训练my_module.train() #只对某些层起作用for data in train_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = my_module(imgs)#计算损失loss = cross_loss(outputs, targets)#优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_steps +=1if total_train_steps % 100 ==0:print("训练次数:{},Loss:{}".format(total_train_steps,loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_steps)#测试,不再梯度下降my_module.eval() #同样只对某些层起作用 total_test_loss = 0# total_test_steps = 0total_accuracy = 0test_data_size = len(test_data)with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs, targets = imgs.to(device), targets.to(device)outputs = my_module(imgs)loss = cross_loss(outputs,targets)total_test_loss += loss.item()##对于分类任务可以求一下准确的个数,非必须#argmax(1)按行取最大的下标 argmax(0)按列取最大的下标accuracy = (outputs.argmax(1)==targets).sum()total_accuracy += accuracyprint("第{}轮的测试集Loss:{}".format(i+1,total_test_loss))print("测试集准确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,i)#存储模型if i % 5 == 0:torch.save(my_module.state_dict(),"my_module_{}.pth".format(i))print("模型存储成功")writer.close()