当前位置: 首页 > news >正文

PyTorch深度学习进阶(四)(数据增广)

数据增广

对图片做不同处理,如去掉部分像素,对颜色变换,对亮度变换

一般是将不同的生成方法随机的用在数据上

总结

代码

基础操作

读取图片

img = d2l.Image.open('01_Data/02_cat.jpg')

显示图片

d2l.plt.imshow(img)

传入aug图片增广方法

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5)

用aug方法对图片作用八次

Y = [aug(img) for _ in range(num_rows * num_cols)]

生成结果用num_cols行,num_cols列展示 

d2l.show_images(Y, num_rows, num_cols, scale=scale) 

水平随机翻转

apply(img, torchvision.transforms.RandomHorizontalFlip())

上下随机翻转

apply(img, torchvision.transforms.RandomVerticalFlip())

随机剪裁,剪裁后的大小为(200,200)

(0.1,1)使得随即剪裁原始图片的10%到100%区域里的大小,ratio=(0.5,2)使得高宽比为2:1,下面是显示时显示的1:1

shape_aug = torchvision.transforms.RandomResizedCrop((200,200),scale=(0.1,1),ratio=(0.5,2))     
apply(img,shape_aug)

随即更改图像的亮度

apply(img,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0,saturation=0,hue=0))

随即改变色调

apply(img,torchvision.transforms.ColorJitter(brightness=0,contrast=0,saturation=0,hue=0.5))

随机更改图像的亮度(brightness)、对比度(constrast)、饱和度(saturation)和色调(hue)

color_aug = torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
apply(img,color_aug)

结合多种图像增广方法
先随即水平翻转,再做颜色增广,再做形状增广

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),color_aug,shape_aug])   
apply(img,augs)

训练

下载图片,并显示部分图片

all_images = torchvision.datasets.CIFAR10(train=True, root='01_Data/03_CIFAR10', download=True)    
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8)

只使用最简单的随机左右翻转

train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()])test_augs = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  

定义一个辅助函数,以便于读取图像和应用图像增广

def load_cifar10(is_train, augs, batch_size):dataset = torchvision.datasets.CIFAR10(root='01_Data/03_CIFAR10',train=is_train,transform=augs, download=True)dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=is_train,num_workers = 0)   return dataloader

定义一个函数,使用多GPU模式进行训练和评估

def train_batch_ch13(net, X, y, loss, trainer, devices):

如果X是一个list,则把数据一个接一个都挪到devices[0]上

if isinstance(X, list):X = [x.to(devices[0]) for x in X]

训练一个batch

如果X不是一个list,则把X挪到devices[0]上

else:X = X.to(devices[0])
    y = y.to(devices[0])net.train()trainer.zero_grad()pred = net(X)l = loss(pred, y)l.sum().backward()trainer.step()train_loss_sum = l.sum()train_acc_sum = d2l.accuracy(pred, y)return train_loss_sum, train_acc_sum
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices=d2l.try_all_gpus()):timer, num_batches = d2l.Timer(), len(train_iter)animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0,1],legend=['train loss', 'train acc', 'test acc'])# nn.DataParallel使用多GPUnet = nn.DataParallel(net, device_ids=devices).to(devices[0])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_iter):timer.start()l, acc = train_batch_ch13(net,features,labels,loss,trainer,devices)   metric.add(l,acc,labels.shape[0],labels.numel())timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches -1:animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))              test_acc = d2l.evaluate_accuracy_gpu(net,test_iter)animator.add(epoch+1,(None,None,test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc'f' {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f' {metric[2] * num_epochs / timer.sum():.1f} examples/sec on 'f' {str(devices)}') 

定义train_with_data_aug函数,使用图像增广来训练模型

batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10,3)def init_weights(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)def train_with_data_aug(train_augs, test_augs, net, lr=0.001):train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)loss = nn.CrossEntropyLoss(reduction="none")# Adam优化器算是一个比较平滑的SGD,它对学习率调参不是很敏感trainer = torch.optim.Adam(net.parameters(),lr=lr)train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)train_with_data_aug(train_augs, test_augs, net)

结果

http://www.dtcms.com/a/610068.html

相关文章:

  • 股指期货豁免开通条件是什么?
  • 上传模型/数据集到huggingface的三种方法
  • 33_FastMCP 2.x 中文文档之FastMCP客户端核心业务:提示模板详解
  • wordpress插件访客亚马逊seo推广
  • Juc篇-线程安全问题引入(从i++问题的底层出发)
  • Arbess V2.1.7版本发布,新增任务AliYun OSS上传、下载功能,新增流水线评审功能
  • 算法基础篇:(八)贪心算法之简单贪心:从直觉到逻辑的实战指南
  • 昊源建设监理有限公司网站外贸网站代码
  • 大专生就业:学历限制的现实考量与能力突围路径
  • Node.js 与 Docker 深度整合:轻松部署与管理 Node.js 应用
  • 中国企业500强榜单2021廊坊seo排名优化
  • 基于高光谱成像和偏最小二乘法(PLS)的苹果糖度检测MATLAB实现
  • 随访系统如何支持临床研究和数据分析?
  • idea 刷新maven,提示java.lang.RuntimeException: java.lang.OutOfMemoryError
  • 邢台本地网站vue做的pc线上网站
  • Arang Briket木炭块检测与识别:基于Mask R-CNN的精确识别方案详解
  • 怎么在百度建设一个网站工业设计大学排名前50
  • 【C++:封装红黑树】C++红黑树封装实战:从零实现MyMap与MySet
  • 构建AI智能体:九十四、Hugging Face 与 Transformers 完全指南:解锁现代 NLP 的强大力量
  • 保定网站排名哪家公司好有没一些网站只做临床药学
  • 目前做网站流行的语言网站策划书市场分析2000字
  • 18.HTTP协议(一)
  • 【每天一个AI小知识】:什么是逻辑回归?
  • Moe框架分析
  • Windows下nacos开机自启动
  • C++ 11 中的move赋值运算符
  • Java:startsWith()
  • 【Linux】进程间通信(四)消息队列、信号量与内核管理ipc资源机制
  • php整站最新版本下载html5 网站开发工具
  • wordpress更换网站数据库中国网络公司排名前十