当前位置: 首页 > news >正文

数据增强和微调

这几天看小说去了😋摆了两天😭

codes

文章目录

  • 📊 9. 数据增强和微调
    • 🎯 前言
    • 🌈 数据增强技术
    • 9.1 随机裁剪 ✂️
      • `transforms.RandomCrop`实战
    • 9.2 水平翻转 🔄
      • `transforms.RandomHorizontalFlip`应用
    • 9.3 色彩调整 🎨
      • `transforms.ColorJitter`综合应用
    • 🧩 实战:构建增强型数据集
      • 完整数据处理流程
      • 🚀 双变换策略
      • 创建dataloader
    • 🧠 迁移学习与微调技术
      • VGG16预训练模型初始化
      • ⚙️ 微调核心操作
    • 📈 性能可视化
      • 损失曲线对比
      • 准确率变化趋势
    • 💎 核心经验总结
    • 可能遇到的问题

📊 9. 数据增强和微调

🎯 前言

数据增强(Data augmentation)是指对训练数据进行变换以增加训练数据量,提高模型的泛化能力。微调(Fine-tuning)则是迁移学习中的关键技术,通过解冻预训练模型的部分层进行二次训练,显著提升模型在特定任务上的性能

🌈 数据增强技术

数据增强的核心方法

  1. 🔄 翻转:水平或垂直翻转图像
  2. ✂️ 裁剪:随机裁剪图像区域
  3. 🔄 旋转:随机旋转图像角度
  4. 🔍 缩放:随机改变图像大小
  5. ↔️ 平移:随机移动图像位置
  6. 📶 噪声:添加随机噪点
  7. 🎨 颜色抖动:改变图像色彩属性
  8. ☀️ 亮度/对比度/饱和度调整
  9. ⚙️ 其他方法:模糊、锐化等

💡 重要原则:仅对训练集进行数据增强,测试集和验证集保持原始状态,确保评估结果客观准确


9.1 随机裁剪 ✂️

transforms.RandomCrop实战

# 读取一张图片
pil_img=Image.open("/root/autodl-tmp/dataset2/cloudy9.jpg"   ) # 读取图片
# 设置裁剪变换
transform = transforms.Compose([transforms.Resize((256, 256)),      # 调整图片尺寸transforms.RandomCrop((224, 224))    # 随机裁剪
])# 应用变换并显示结果
plt.figure(figsize=(12, 8))
for i in range(6):img = transform(pil_img)plt.subplot(2, 3, i + 1)plt.imshow(img)
plt.show()

效果说明:每次执行都会生成不同位置的裁剪结果,有效增加数据多样性


9.2 水平翻转 🔄

transforms.RandomHorizontalFlip应用

# 100%概率执行水平翻转
trans_img = transforms.RandomHorizontalFlip(p=1)(pil_img)# 对比显示原图与翻转图
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(pil_img)  # 原图
plt.subplot(1, 2, 2)
plt.imshow(trans_img)  # 翻转图
plt.show()

关键参数p值控制翻转概率,推荐设置为0.5实现随机翻转效果
在这里插入图片描述


9.3 色彩调整 🎨

transforms.ColorJitter综合应用

明暗、颜色、对比度、饱和度

# 设置色彩抖动参数
trandform = transforms.ColorJitter(brightness=(0.7, 1.3),  # 亮度调整范围contrast=(0.7, 1.3),    # 对比度调整范围saturation=(0.7, 1.3),  # 饱和度调整范围hue=(-0.05, 0.05)       # 色调调整范围
)# 生成多样本展示
for i in range(6):img = trandform(pil_img)plt.subplot(2, 3, i + 1)plt.imshow(img)
plt.show()

参数详解

  • brightness:亮度调整系数
  • contrast:对比度调整系数
  • saturation:饱和度调整系数
  • hue:色调偏移值(范围-0.5~0.5)
    在这里插入图片描述

🧩 实战:构建增强型数据集

完整数据处理流程

# 定义数据集类
class WT_Dataset(data.Dataset):def __init__(self, imgs_path, labels, transform):self.imgs_path = imgs_pathself.labels = labelsself.transform = transformdef __len__(self):return len(self.imgs_path)def __getitem__(self, index):img_path = self.imgs_path[index]label = self.labels[index]pil_img = Image.open(img_path).convert('RGB')return self.transform(pil_img), label# 划分训练集/测试集(8:2比例)
imgs=glob.glob(r'/root/autodl-tmp/dataset2/*.jpg') 
# 上面是读取图片的路径,/*.jpg表示读取所有jpg格式的图片,imgs是一个列表,里面包含了所有图片的路径。species=['cloudy','rain','shine','sunrise'] #4 classes
# 字典推导式获取类别到编号的映射关系
species_to_idx=dict((c,i) for i,c in enumerate(species))
# 字典推导式获取编号到类别的映射关系
idx_to_species=dict((i,c) for i,c in enumerate(species))# 下面提取图片路径列表对应的标签列表
labels=[]
for img in imgs:for i,c in enumerate(species):if c in img: # 判断图片路径是否包含某个种类的名称labels.append(i)# 为了随即划分训练数据和测试数据,我们先设置一个乱序的index
# 同时对图片和路径使用index进行乱序,这样保证了乱序之后图片和路径是一一对应的
np.random.seed(2025) # 设置随机种子,保证每次运行结果一致
index=np.random.permutation(len(imgs)) # 随机生成index
'''
permutation函数返回一个乱序的index,这里我们将其赋值给index变量,index是一个一维数组,包含了0到len(imgs)-1的所有整数,顺序是随机的。
'''
imgs=np.array(imgs)[index] # 将图片路径按照index进行乱序labels=np.array(labels)[index] # 将标签按照index进行乱序# 对乱序后的数据,直接切片选取前80%的数据作为训练集,后20%的数据作为测试集。
seq=int(len(imgs)*0.8) # 计算80%的数据量
train_imgs=imgs[:seq] # 切片选取前80%的数据作为训练集
train_labels=labels[:seq] # 切片选取前80%的数据的标签作为训练集的标签
test_imgs=imgs[seq:] # 切片选取后20%的数据作为测试集
test_labels=labels[seq:] # 切片选取后20%的数据的标签作为测试集的标签

🚀 双变换策略

训练集增强变换

train_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomCrop((224, 224)),transforms.RandomHorizontalFlip(p=0.5),transforms.ColorJitter(brightness=(0.7, 1.3),contrast=(0.7, 1.3),saturation=(0.7, 1.3),hue=(-0.05, 0.05)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

测试集基础变换

test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

创建dataloader

# 使用train_transform初始化训练数据的dataset
train_ds=WT_Dataset(train_imgs, train_labels, train_transform)
# 使用test_transform初始化验证数据的dataset
test_ds=WT_Dataset(test_imgs, test_labels, test_transform)train_dl=data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl=data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

🧠 迁移学习与微调技术

这一部分只讲核心代码,具体的模型训练代码可以去我的gitee上看

VGG16预训练模型初始化

# 加载预训练模型(新式写法)
from torchvision.models import VGG16_Weights
model = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)# 冻结卷积层参数 🔒
for param in model.features.parameters():param.requires_grad = False# 修改输出层为四分类
model.classifier[-1].out_features = 4

⚙️ 微调核心操作

# 解冻所有层参数 🔓
for param in model.parameters():param.requires_grad = True# 使用更小的学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)# 继续训练
extend_epochs = 15
train_loss_, train_acc_, test_loss_, test_acc_ = fit(extend_epochs, ...)

微调效果:通过解冻参数+继续训练,模型准确率平均提升10-15%


📈 性能可视化

损失曲线对比

plt.plot(range(1, epochs+1), train_loss, label='train_loss')
plt.plot(range(1, epochs+1), test_loss, label='test_loss')
plt.legend()
plt.show()

准确率变化趋势

plt.plot(range(1, epochs+extend_epochs+1), train_acc+train_acc_, label='train_acc')
plt.plot(range(1, epochs+extend_epochs+1), test_acc+test_acc_, label='test_acc')
plt.legend()
plt.show()

💎 核心经验总结

  1. 数据增强黄金组合:裁剪+翻转+色彩抖动
  2. 微调最佳实践
    • 先冻结训练分类器
    • 解冻后使用更小学习率(1e-5量级)
    • 适当增加训练轮次(10-20 epochs)
  3. 效果验证:微调后测试准确率普遍提升10%以上
  4. 注意事项
    • 务必设置随机种子保证可复现性
    • 验证集/测试集禁用数据增强
    • 监控过拟合现象(训练/验证损失分化)

可能遇到的问题

在使用某个深度学习库(如 torchvision)时,pretrained参数已被弃用,需要改用weights参数。
例如

# 旧代码(会触发警告)
import torchvision
model = torchvision.models.resnet50(pretrained=True)# 新代码(推荐)
import torchvision
from torchvision.models import ResNet50_Weight
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)# 旧代码(会触发警告)
import torchvision
model = torchvision.models.vgg16(pretrained=True)# 新代码(推荐)
import torchvision
from torchvision.models import VGG16_Weightsmodel = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
http://www.dtcms.com/a/288107.html

相关文章:

  • Codeforces Round 1037 (Div. 3)
  • windows docker-02-docker 最常用的命令汇总
  • uniapp props、$ref、$emit、$parent、$child、$on
  • 【数据结构】栈(stack)
  • xss-labs1-8题
  • ubuntu24 ros2 jazzy
  • OpenVINO使用教程--图像增强算法DarkIR
  • 华为擎云L420安装LocalSend
  • Oracle为什么需要临时表?——让数据处理更灵活
  • LeetCode 322. 零钱兑换 LeetCode 279.完全平方数 LeetCode 139.单词拆分 多重背包基础 56. 携带矿石资源
  • 【补题】Codeforces Round 958 (Div. 2) D. The Omnipotent Monster Killer
  • 窗口(6)-QMessageBox
  • ctf.show-web习题-web4-flag获取详解、总结
  • 动态规划——状压DP经典题目
  • Weavefox 图片 1 比 1 生成前端源代码
  • 计算机网络:(十)虚拟专用网 VPN 和网络地址转换 NAT
  • 详细阐述 TCP、UDP、ICMPv4 和 ICMPv6 协议-以及防火墙端口原理优雅草卓伊凡
  • 【王树森推荐系统】推荐系统涨指标的方法04:多样性
  • sql练习二
  • 模型自信度提升:增强输出技巧
  • 《Spring Boot 插件化架构实战:从 SPI 到热插拔的三级跳》
  • 6. 装饰器模式
  • 教育科技内容平台的破局之路:从组织困境到 UGC 生态的构建
  • 我是怎么设计一个订单号生成策略的(库存系统)
  • 带root权限_新魔百和cm311-5_gk6323不分代工通刷优盘强刷及线刷
  • Openlayers 面试题及答案180道(141-160)
  • JavaScript 中的继承
  • MySQL——约束类型
  • 【RK3576】【Android14】分区划分
  • Java行为型模式---中介者模式