利用迁移学习实现食物分类:基于PyTorch与ResNet18的实战案例
利用迁移学习实现食物分类:基于PyTorch与ResNet18的实战案例
在深度学习领域,训练一个高性能的模型往往需要大量的数据和计算资源。然而,通过迁移学习,我们能够巧妙地利用在大规模数据集上预训练好的模型,将其知识迁移到我们特定的任务中,不仅可以大幅减少训练时间和数据需求,还能取得出色的效果。本文将以食物分类为例,详细介绍如何使用PyTorch和ResNet18进行迁移学习。
一、迁移学习概述
迁移学习的核心思想是将在一个任务(源任务)中学习到的知识,应用到另一个相关任务(目标任务)中。在计算机视觉领域,许多预训练模型,如ResNet、VGG等,已经在大规模图像数据集(如ImageNet)上进行了充分训练,学习到了丰富的图像特征表示。这些预训练模型的底层网络结构能够提取通用的图像特征,如边缘、纹理等,而顶层网络结构则与源任务的类别紧密相关。因此,在目标任务中,我们可以保留预训练模型的底层结构,仅对顶层进行微调,使其适应目标任务的分类需求。
二、食物分类项目实现
1. 环境与库导入
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import numpy as np
上述代码导入了项目所需的核心库。torch
是PyTorch的核心库,用于构建和训练深度学习模型;DataLoader
和Dataset
用于数据的加载和管理;transforms
用于对图像进行预处理;nn
是PyTorch的神经网络模块;models
包含了各种预训练模型;Image
用于处理图像;numpy
用于数值计算。
2. 加载预训练模型并调整结构
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():param.requires_grad = False
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)
params_to_update = []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)
首先,通过models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
加载在ImageNet数据集上预训练好的ResNet18模型。然后,将模型的所有参数的requires_grad
属性设置为False
,冻结模型的参数,避免在训练过程中对其进行更新。接着,获取原模型全连接层的输入特征个数in_features
,并将原全连接层替换为一个新的全连接层,输出维度为20,对应食物分类任务的20个类别。最后,筛选出需要更新的参数,即新添加的全连接层的参数。
3. 数据准备与预处理
food_type = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝", 5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连",9: "瓜子", 10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅", 15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"}
data_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, labeltraining_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
定义了食物类别字典food_type
,以及训练集和验证集的图像预处理操作。训练集的预处理包括调整图像大小、随机旋转、中心裁剪、随机水平和垂直翻转、随机灰度化、转换为张量以及标准化;验证集的预处理相对简单,仅进行调整大小、转换为张量和标准化。
创建自定义的数据集类food_dataset
,继承自Dataset
类,实现了__init__
、__len__
和__getitem__
方法,用于读取数据文件、获取数据集大小以及加载和预处理图像。最后,使用DataLoader
将训练集和测试集封装为可迭代的数据加载器,方便在训练和测试过程中按批次获取数据。
4. 模型训练与测试
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
model = resnet_model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)def train(dataloader, model, loss_fn, optimizer):model.train()for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()best_acc = 0
def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeresult = zip(pred.argmax(1).tolist(), y.tolist())for i in result:print(f"当前测试的结果为:{food_type[i[0]]},当前真实的结果为:{food_type[i[1]]}")print(f"Test result:\n Accurracy:{(100 * correct)}%,AVG loss:{test_loss}")test_loss /= num_batchescorrect /= sizeif correct > best_acc:best_acc = correctepoch = 10
acc_s = []
loss_s = []
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最终训练结果:', best_acc)
首先,根据当前设备是否支持GPU或苹果M系列芯片的GPU,选择合适的计算设备,并将模型移动到该设备上。定义交叉熵损失函数loss_fn
、Adam优化器optimizer
以及学习率调整策略scheduler
。
train
函数用于模型的训练,在训练过程中,将数据传入设备,进行前向传播计算预测值,计算损失,通过反向传播计算梯度并更新模型参数。test
函数用于模型的测试,在测试过程中,将模型设置为评估模式,关闭梯度计算,计算测试集上的损失和准确率,并输出每个样本的预测结果和真实结果。
最后,通过循环进行多个epoch的训练和测试,在每个epoch结束后调整学习率,并记录最佳准确率。
三、总结
通过本次食物分类项目,我们成功地运用迁移学习技术,基于预训练的ResNet18模型完成了特定任务。这种方法不仅减少了训练时间和数据需求,还展示了迁移学习在实际应用中的强大能力。在未来的深度学习项目中,迁移学习将继续发挥重要作用,帮助我们更高效地解决各种复杂的问题。同时,我们还可以进一步探索不同的预训练模型、调整超参数以及优化数据预处理方法,以提升模型的性能。