
19. Data Augmentation Techniques

19.1 Horizontal and Vertical Image Flips

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Apply `aug` to `img` num_rows * num_cols times and show the results in a grid."""
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    figsize = (num_cols * scale, num_rows * scale)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for ax, y in zip(axes, Y):
        ax.imshow(y)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

image = Image.open('cat.jpg')  # example path -- replace with any RGB image file
apply(image, torchvision.transforms.RandomHorizontalFlip())  # random left-right flip
apply(image, torchvision.transforms.RandomVerticalFlip())    # random top-bottom flip
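Both flip transforms take a probability argument p (0.5 by default), so roughly half of the images in each grid come out unchanged. The short sketch below reuses the image object loaded above and shows the deterministic counterparts in torchvision.transforms.functional for comparison; the p=1.0 value is an illustrative choice, not part of the original code.

import torchvision
import torchvision.transforms.functional as TF

# Deterministic flips -- handy for checking exactly what the random transforms do.
flipped_lr = TF.hflip(image)  # always flips left-right
flipped_ud = TF.vflip(image)  # always flips top-bottom

# The random versions flip with probability p; p=1.0 makes them deterministic as well.
always_flip = torchvision.transforms.RandomHorizontalFlip(p=1.0)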


19.2 Random Cropping and Color Jitter

# Reuses the imports and the apply() helper defined in 19.1.
apply(image, torchvision.transforms.RandomResizedCrop(
    size=(200, 200), scale=(0.2, 1), ratio=(1, 2)))           # random crop, then resize to 200x200
apply(image, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))   # jitter brightness/contrast/saturation/hue
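For RandomResizedCrop, scale bounds the area of the crop as a fraction of the original image and ratio bounds its aspect ratio before the crop is resized to size; for ColorJitter, each factor f samples a multiplier from [max(0, 1 - f), 1 + f], while hue (at most 0.5) shifts the hue within [-hue, hue]. A minimal sketch that isolates single factors, reusing the apply helper and image from 19.1 (the parameter values are illustrative, not from the code above):

# Jitter only brightness, keeping contrast/saturation/hue fixed,
# to see the effect of one factor in isolation.
apply(image, torchvision.transforms.ColorJitter(brightness=0.5))

# A gentler crop: keep at least half of the original area with a near-square aspect ratio.
apply(image, torchvision.transforms.RandomResizedCrop(
    size=(200, 200), scale=(0.5, 1.0), ratio=(0.9, 1.1)))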


19.3 Combining Multiple Augmentations

# Reuses the imports and the apply() helper defined in 19.1.
loc_aug = torchvision.transforms.RandomHorizontalFlip()
shape_aug = torchvision.transforms.RandomResizedCrop((200, 200), scale=(0.1, 1), ratio=(0.5, 2))
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
augs = torchvision.transforms.Compose([loc_aug, color_aug, shape_aug])  # applied in list order
apply(image, augs)
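To feed augmented images to a model, the composed transform is usually chained with ToTensor so the pipeline outputs tensors rather than PIL images. A minimal sketch of that wiring for CIFAR-10, with crop and scale values chosen for its 32x32 images (illustrative choices, not taken from the code above):

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    transforms.RandomResizedCrop((32, 32), scale=(0.5, 1.0), ratio=(0.5, 2.0)),
    transforms.ToTensor(),  # convert the augmented PIL image to a CxHxW tensor
])

train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, transform=train_transform, download=True)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)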


19.4 Comparing Augmentation Effects on the CIFAR-10 Dataset


################################################################################################################
#ResNet
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.models as models
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):  # not used in this ResNet script
    def forward(self, x):
        return x.view(-1, 1, 28, 28)  # [batch_size, 1, 28, 28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc', linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model, train_data, test_data, num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        # Training loop
        model.train()
        total_loss = 0
        total_acc_sample = 0
        total_samples = 0
        loop1 = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X, y in loop1:
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = CEloss(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate loss and accuracy, weighted by batch size
            total_loss += loss.item() * X.shape[0]
            y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true = y.detach().cpu().numpy()
            total_acc_sample += accuracy_score(y_true, y_pred) * X.shape[0]
            total_samples += X.shape[0]
        # Evaluation loop
        model.eval()
        test_acc_samples = 0
        test_samples = 0
        loop2 = tqdm(test_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        with torch.no_grad():
            for X, y in loop2:
                X = X.to(device)
                y = y.to(device)
                y_hat = model(X)
                y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
                y_true = y.detach().cpu().numpy()
                test_acc_samples += accuracy_score(y_true, y_pred) * X.shape[0]
                test_samples += X.shape[0]
        avg_train_loss = total_loss / total_samples
        avg_train_acc = total_acc_sample / total_samples
        avg_test_acc = test_acc_samples / test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_acc:.4f}, Test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
################################################################################################################
# Pipeline: random horizontal flip, brightness jitter, tensor conversion, and normalization.
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Note: the same random transforms are applied to the test set here; a plain
# ToTensor + Normalize pipeline is the more common choice for evaluation.
transforms_test = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_img=torchvision.datasets.CIFAR10(root="./data",train=True,transform=transforms_train,download=True)
test_img=torchvision.datasets.CIFAR10(root="./data",train=False,transform=transforms_test,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = models.resnet50(pretrained=True)  # load an ImageNet-pretrained ResNet-50
model.fc = nn.Linear(model.fc.in_features, 10)  # replace the head with a 10-class layer for CIFAR-10
model.to(device)
model.fc.apply(init_weights)  # initialize only the new head; re-initializing the whole model would discard the pretrained weights
optimizer=torch.optim.SGD(model.parameters(),lr=0.05,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=20)
################################################################################################################
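To complete the comparison, the same model and train_model call can be rerun with a plain pipeline that keeps only ToTensor and Normalize; the sketch below is one way to set up that baseline (an assumption about the experiment, not shown in the original code):

from torchvision.transforms import transforms

# Baseline pipeline without random transforms: tensor conversion and
# normalization only, so any accuracy gap against the run above can be
# attributed to the augmentation.
transforms_plain = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Rebuild the CIFAR-10 datasets and DataLoaders with transforms_plain and
# call train_model again to obtain the "no augmentation" curves.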
