深度学习:自定义数据集处理、数据增强与最优模型管理
目录
一、整体流程概述
二、自定义数据集处理
1. 数据集结构设计
2. 数据集索引文件生成
三、自定义数据集加载类
四、数据增强策略
五、卷积神经网络模型构建
六、模型训练与最优模型管理
1. 训练与测试函数
2. 训练主流程
七、模型加载与推理
一、整体流程概述
深度学习项目通常遵循 "数据→模型→训练→部署" 的闭环流程,本文以食物识别任务为例,详解从自定义数据集处理到模型训练、优化及调用的完整流程。核心环节包括:
- 自定义数据集构建与预处理
- 数据增强策略设计
- 卷积神经网络 (CNN) 模型构建
- 模型训练与最优模型保存
- 训练后模型调用与推理
二、自定义数据集处理
1. 数据集结构设计
通常采用 "类别目录 + 图片文件" 的组织结构:
food_dataset/
├── train/
│ ├── 苹果/
│ │ ├── apple1.jpg
│ │ └── apple2.jpg
│ └── 香蕉/
│ └── ...
└── test/└── ...
2. 数据集索引文件生成
需要将图片路径与标签映射关系保存为文本文件(如train.txt
、test.txt
),方便模型加载。
核心代码(食物识别 1 - 预处理.py):
import osdef train_test_file(root, dir):file_txt = open(dir + '.txt', 'w+', encoding='utf-8')file_canzhao = open('canzhao.txt', 'w+', encoding='utf-8')path = os.path.join(root, dir)dirs = [] # 存储类别目录名recorded_classes = set() # 去重集合for roots, directories, files in os.walk(path):# 获取一级子目录作为类别if not dirs and roots == path:dirs = directoriesif files: # 处理图片文件current_class = roots.split(os.sep)[-1] # 提取类别名class_index = dirs.index(current_class) # 分配类别索引# 写入图片路径和标签for file in files:img_path = os.path.join(roots, file)file_txt.write(f"{img_path} {class_index}\n")# 写入类别-索引映射(去重)if current_class not in recorded_classes:file_canzhao.write(f"{current_class} {class_index}\n")recorded_classes.add(current_class)file_txt.close()file_canzhao.close()# 生成训练集和测试集索引文件
root = r'..\\food_dataset'
train_test_file(root, 'train')
train_test_file(root, 'test')
..\\food_dataset\train\八宝粥\img_八宝粥罐_22.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_29.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_65.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_68.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_84.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_98.jpeg 0
..\\food_dataset\train\哈密瓜\img_水果_103.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_13.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_136.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_142.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_163.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_174.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_191.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_209.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_238.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_30.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_34.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_42.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_44.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_57.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_81.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_92.jpeg 1
..\\food_dataset\train\圣女果\img_圣女果_104.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_105.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_116.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_12.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_13.jpeg 2
略
知识点解析:
os.walk()
:递归遍历目录结构,获取所有文件路径os.sep
:适配不同操作系统的路径分隔符(Windows 用\
,Linux 用/
)- 类别索引映射:通过
dirs.index(current_class)
建立类别与数字索引的映射,便于模型处理
三、自定义数据集加载类
使用 PyTorch 的Dataset
类封装数据集,实现数据的按需加载。
核心代码:
from torch.utils.data import Dataset
from PIL import Image
import torch
import numpy as npclass food_dataset(Dataset):def __init__(self, file_path, transform=None):self.imgs = []self.labels = []self.transform = transform# 从索引文件读取数据with open(file_path, 'r', encoding='utf-8') as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs) # 返回样本总数def __getitem__(self, idx):# 读取图片并应用转换image = Image.open(self.imgs[idx]).convert('RGB') # 确保RGB格式if self.transform:image = self.transform(image)# 标签转换为Tensorlabel = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))return image, label
知识点解析:
Dataset
抽象类:必须实现__len__
(返回样本数)和__getitem__
(获取单个样本)方法- 延迟加载:仅在需要时才读取图片,节省内存
- 数据转换接口:通过
transform
参数灵活接入数据增强 pipeline
四、数据增强策略
数据增强是提升模型泛化能力的关键手段,通过对训练数据进行随机变换,增加数据多样性。
核心代码(食物识别 1 - 数据增强.py):
from torchvision import transformsdata_transforms = {'train': transforms.Compose([transforms.Resize((300, 300)), # 缩放至300x300transforms.RandomRotation(45), # 随机旋转(-45°~45°)transforms.CenterCrop(256), # 中心裁剪至256x256transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻转transforms.ColorJitter( # 颜色抖动brightness=0.2, # 亮度contrast=0.1, # 对比度saturation=0.1, # 饱和度hue=0.1 # 色调),transforms.RandomGrayscale(p=0.1), # 10%概率转为灰度图transforms.ToTensor(), # 转为Tensor(0-1范围)transforms.Normalize( # 标准化[0.485, 0.456, 0.406], # 均值[0.229, 0.224, 0.225] # 标准差)]),'valid': transforms.Compose([ # 验证集仅做必要转换transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
知识点解析:
- 训练集增强原则:保留语义信息的同时增加多样性(旋转、翻转、颜色变化等)
- 验证集处理:仅做必要的尺寸调整和标准化,保证评估一致性
- 标准化(Normalize):使用 ImageNet 数据集的均值和标准差,使输入分布更稳定
Compose
:将多个变换组合成一个 pipeline,按顺序执行
五、卷积神经网络模型构建
设计适用于食物识别的 CNN 模型,通过卷积层提取视觉特征,全连接层完成分类。
核心代码:
import torch
from torch import nnclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积块1:特征提取+降维self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, # 输入通道数(RGB)out_channels=16, # 卷积核数量kernel_size=5, # 卷积核大小5x5stride=1, # 步长padding=2 # 填充,保持尺寸),nn.ReLU(), # 激活函数nn.MaxPool2d(kernel_size=2) # 2x2池化,尺寸减半)# 卷积块2:加深特征提取self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),nn.ReLU())# 卷积块3:进一步提取+降维self.conv3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 卷积块4:高级特征提取self.conv4 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU())# 全连接层:分类输出self.line = nn.Linear(128 * 64 * 64, 20) # 20类食物def forward(self, x):# 前向传播路径x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = x.view(x.size(0), -1) # 展平特征图out = self.line(x)return out
知识点解析:
- 卷积层作用:通过滑动窗口提取局部特征(边缘、纹理、形状等)
- 池化层作用:降低特征图尺寸,减少参数数量,增强平移不变性
- 激活函数(ReLU):引入非线性,使模型能拟合复杂特征关系
- 全连接层:将卷积提取的高维特征映射到类别空间
六、模型训练与最优模型管理
1. 训练与测试函数
# 训练函数
def train(dataloader, model, loss_fn, optimizer):model.train() # 训练模式(启用dropout等)batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device) # 数据移至计算设备pred = model(X) # 前向传播loss = loss_fn(pred, y) # 计算损失# 反向传播与参数更新optimizer.zero_grad() # 清空梯度loss.backward() # 计算梯度optimizer.step() # 更新参数# 打印训练进度if batch_size_num % 100 == 1:print(f"Loss: {loss.item():.4f} [batch:{batch_size_num}]")batch_size_num += 1# 测试函数(含最优模型保存)
best_acc = 0
def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)correct = 0loss_sum = 0model.eval() # 评估模式(关闭dropout等)with torch.no_grad(): # 关闭梯度计算for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)loss_sum += loss.item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()acc = correct / sizelatest_loss = loss_sum / num_batchesprint(f"Accuracy: {(acc * 100)}%, Loss: {latest_loss:.4f}")# 保存最优模型if acc > best_acc:best_acc = acctorch.save(model, 'best1.pt') # 保存完整模型print(f"保存最优模型,准确率:{best_acc*100:.2f}%")
2. 训练主流程
# 设备选择
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f'Using device: {device}')# 加载数据集
train_data = food_dataset('train.txt', data_transforms['train'])
test_data = food_dataset('test.txt', data_transforms['valid'])
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss() # 多分类损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # Adam优化器# 执行训练
epochs = 100
print('训练开始')
for epoch in range(epochs):print(f"Epoch {epoch + 1}/{epochs}")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print(f'训练结束, 最佳准确率:{best_acc*100:.2f}%')
知识点解析:
- 设备加速:自动选择 GPU(CUDA/MPS)或 CPU 进行计算
- 训练模式与评估模式:
model.train()
和model.eval()
控制 dropout、BN 层等行为 - 梯度管理:
optimizer.zero_grad()
避免梯度累积,with torch.no_grad()
节省评估时内存 - 最优模型保存:通过跟踪验证集准确率,只保存表现最好的模型,避免过拟合模型
七、模型加载与推理
训练完成后,加载最优模型进行实际预测。
核心代码(食物识别 1 - 调用最优模型.py):
# 加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load('best1.pt', map_location=device) # 加载完整模型
model.eval() # 切换到评估模式# 加载类别映射
def load_class_mapping(canzhao_path='canzhao.txt'):index_to_name = {}with open(canzhao_path, 'r', encoding='utf-8') as f:for line in f.readlines():food_name, index = line.strip().split(' ')index_to_name[int(index)] = food_namereturn index_to_name# 单张图片预测
def predict_image(image_path, model, index_to_name):try:# 图片预处理(与验证集一致)transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0) # 增加批次维度image = image.to(device)# 推理with torch.no_grad():pred = model(image)pred_index = pred.argmax(1).item() # 获取最高概率类别return index_to_name.get(pred_index, "未知类别")except Exception as e:return f"预测失败:{str(e)}"# 交互预测
index_to_name = load_class_mapping()
while True:img_path = input("请输入图片路径(q退出):")if img_path.lower() == 'q':breakprint("预测结果:", predict_image(img_path, model, index_to_name))
知识点解析:
- 模型加载:
torch.load()
加载保存的模型,map_location
适配不同计算设备 - 推理预处理:必须与训练时的验证集处理完全一致,否则会导致分布不匹配
- 批次维度:模型输入需为
(batch_size, channel, height, width)
,通过unsqueeze(0)
添加批次维度 - 类别映射:将模型输出的数字索引转换为实际类别名称