数据增强和微调
这几天看小说去了😋摆了两天😭
codes
文章目录
- 📊 9. 数据增强和微调
- 🎯 前言
- 🌈 数据增强技术
- 9.1 随机裁剪 ✂️
- `transforms.RandomCrop`实战
- 9.2 水平翻转 🔄
- `transforms.RandomHorizontalFlip`应用
- 9.3 色彩调整 🎨
- `transforms.ColorJitter`综合应用
- 🧩 实战:构建增强型数据集
- 完整数据处理流程
- 🚀 双变换策略
- 创建dataloader
- 🧠 迁移学习与微调技术
- VGG16预训练模型初始化
- ⚙️ 微调核心操作
- 📈 性能可视化
- 损失曲线对比
- 准确率变化趋势
- 💎 核心经验总结
- 可能遇到的问题
📊 9. 数据增强和微调
🎯 前言
数据增强(Data augmentation)是指对训练数据进行变换以增加训练数据量,提高模型的泛化能力。微调(Fine-tuning)则是迁移学习中的关键技术,通过解冻预训练模型的部分层进行二次训练,显著提升模型在特定任务上的性能。
🌈 数据增强技术
数据增强的核心方法:
- 🔄 翻转:水平或垂直翻转图像
- ✂️ 裁剪:随机裁剪图像区域
- 🔄 旋转:随机旋转图像角度
- 🔍 缩放:随机改变图像大小
- ↔️ 平移:随机移动图像位置
- 📶 噪声:添加随机噪点
- 🎨 颜色抖动:改变图像色彩属性
- ☀️ 亮度/对比度/饱和度调整
- ⚙️ 其他方法:模糊、锐化等
💡 重要原则:仅对训练集进行数据增强,测试集和验证集保持原始状态,确保评估结果客观准确
9.1 随机裁剪 ✂️
transforms.RandomCrop
实战
# 读取一张图片
pil_img=Image.open("/root/autodl-tmp/dataset2/cloudy9.jpg" ) # 读取图片
# 设置裁剪变换
transform = transforms.Compose([transforms.Resize((256, 256)), # 调整图片尺寸transforms.RandomCrop((224, 224)) # 随机裁剪
])# 应用变换并显示结果
plt.figure(figsize=(12, 8))
for i in range(6):img = transform(pil_img)plt.subplot(2, 3, i + 1)plt.imshow(img)
plt.show()
效果说明:每次执行都会生成不同位置的裁剪结果,有效增加数据多样性
9.2 水平翻转 🔄
transforms.RandomHorizontalFlip
应用
# 100%概率执行水平翻转
trans_img = transforms.RandomHorizontalFlip(p=1)(pil_img)# 对比显示原图与翻转图
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(pil_img) # 原图
plt.subplot(1, 2, 2)
plt.imshow(trans_img) # 翻转图
plt.show()
关键参数:p
值控制翻转概率,推荐设置为0.5实现随机翻转效果
9.3 色彩调整 🎨
transforms.ColorJitter
综合应用
明暗、颜色、对比度、饱和度
# 设置色彩抖动参数
trandform = transforms.ColorJitter(brightness=(0.7, 1.3), # 亮度调整范围contrast=(0.7, 1.3), # 对比度调整范围saturation=(0.7, 1.3), # 饱和度调整范围hue=(-0.05, 0.05) # 色调调整范围
)# 生成多样本展示
for i in range(6):img = trandform(pil_img)plt.subplot(2, 3, i + 1)plt.imshow(img)
plt.show()
参数详解:
brightness
:亮度调整系数contrast
:对比度调整系数saturation
:饱和度调整系数hue
:色调偏移值(范围-0.5~0.5)
🧩 实战:构建增强型数据集
完整数据处理流程
# 定义数据集类
class WT_Dataset(data.Dataset):def __init__(self, imgs_path, labels, transform):self.imgs_path = imgs_pathself.labels = labelsself.transform = transformdef __len__(self):return len(self.imgs_path)def __getitem__(self, index):img_path = self.imgs_path[index]label = self.labels[index]pil_img = Image.open(img_path).convert('RGB')return self.transform(pil_img), label# 划分训练集/测试集(8:2比例)
imgs=glob.glob(r'/root/autodl-tmp/dataset2/*.jpg')
# 上面是读取图片的路径,/*.jpg表示读取所有jpg格式的图片,imgs是一个列表,里面包含了所有图片的路径。species=['cloudy','rain','shine','sunrise'] #4 classes
# 字典推导式获取类别到编号的映射关系
species_to_idx=dict((c,i) for i,c in enumerate(species))
# 字典推导式获取编号到类别的映射关系
idx_to_species=dict((i,c) for i,c in enumerate(species))# 下面提取图片路径列表对应的标签列表
labels=[]
for img in imgs:for i,c in enumerate(species):if c in img: # 判断图片路径是否包含某个种类的名称labels.append(i)# 为了随即划分训练数据和测试数据,我们先设置一个乱序的index
# 同时对图片和路径使用index进行乱序,这样保证了乱序之后图片和路径是一一对应的
np.random.seed(2025) # 设置随机种子,保证每次运行结果一致
index=np.random.permutation(len(imgs)) # 随机生成index
'''
permutation函数返回一个乱序的index,这里我们将其赋值给index变量,index是一个一维数组,包含了0到len(imgs)-1的所有整数,顺序是随机的。
'''
imgs=np.array(imgs)[index] # 将图片路径按照index进行乱序labels=np.array(labels)[index] # 将标签按照index进行乱序# 对乱序后的数据,直接切片选取前80%的数据作为训练集,后20%的数据作为测试集。
seq=int(len(imgs)*0.8) # 计算80%的数据量
train_imgs=imgs[:seq] # 切片选取前80%的数据作为训练集
train_labels=labels[:seq] # 切片选取前80%的数据的标签作为训练集的标签
test_imgs=imgs[seq:] # 切片选取后20%的数据作为测试集
test_labels=labels[seq:] # 切片选取后20%的数据的标签作为测试集的标签
🚀 双变换策略
训练集增强变换:
train_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomCrop((224, 224)),transforms.RandomHorizontalFlip(p=0.5),transforms.ColorJitter(brightness=(0.7, 1.3),contrast=(0.7, 1.3),saturation=(0.7, 1.3),hue=(-0.05, 0.05)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
测试集基础变换:
test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
创建dataloader
# 使用train_transform初始化训练数据的dataset
train_ds=WT_Dataset(train_imgs, train_labels, train_transform)
# 使用test_transform初始化验证数据的dataset
test_ds=WT_Dataset(test_imgs, test_labels, test_transform)train_dl=data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl=data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
🧠 迁移学习与微调技术
这一部分只讲核心代码,具体的模型训练代码可以去我的gitee上看
VGG16预训练模型初始化
# 加载预训练模型(新式写法)
from torchvision.models import VGG16_Weights
model = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)# 冻结卷积层参数 🔒
for param in model.features.parameters():param.requires_grad = False# 修改输出层为四分类
model.classifier[-1].out_features = 4
⚙️ 微调核心操作
# 解冻所有层参数 🔓
for param in model.parameters():param.requires_grad = True# 使用更小的学习率
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)# 继续训练
extend_epochs = 15
train_loss_, train_acc_, test_loss_, test_acc_ = fit(extend_epochs, ...)
微调效果:通过解冻参数+继续训练,模型准确率平均提升10-15%
📈 性能可视化
损失曲线对比
plt.plot(range(1, epochs+1), train_loss, label='train_loss')
plt.plot(range(1, epochs+1), test_loss, label='test_loss')
plt.legend()
plt.show()
准确率变化趋势
plt.plot(range(1, epochs+extend_epochs+1), train_acc+train_acc_, label='train_acc')
plt.plot(range(1, epochs+extend_epochs+1), test_acc+test_acc_, label='test_acc')
plt.legend()
plt.show()
💎 核心经验总结
- 数据增强黄金组合:裁剪+翻转+色彩抖动
- 微调最佳实践:
- 先冻结训练分类器
- 解冻后使用更小学习率(1e-5量级)
- 适当增加训练轮次(10-20 epochs)
- 效果验证:微调后测试准确率普遍提升10%以上
- 注意事项:
- 务必设置随机种子保证可复现性
- 验证集/测试集禁用数据增强
- 监控过拟合现象(训练/验证损失分化)
可能遇到的问题
在使用某个深度学习库(如 torchvision)时,pretrained参数已被弃用,需要改用weights参数。
例如
# 旧代码(会触发警告)
import torchvision
model = torchvision.models.resnet50(pretrained=True)# 新代码(推荐)
import torchvision
from torchvision.models import ResNet50_Weight
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)# 旧代码(会触发警告)
import torchvision
model = torchvision.models.vgg16(pretrained=True)# 新代码(推荐)
import torchvision
from torchvision.models import VGG16_Weightsmodel = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)