day46 python预训练模型补充
目录
一、预训练模型的背景知识
二、实验过程
(一)实验环境与数据准备
(二)预训练模型的选择与适配
(三)训练策略
三、实验结果与分析
四、学习总结与展望
一、预训练模型的背景知识
在传统的神经网络训练中,模型的参数是随机初始化的,这可能导致训练初期的不稳定,并且容易陷入局部最优解。而预训练模型的出现,为这一问题提供了有效的解决方案。预训练模型是在大规模数据集(如 ImageNet)上预先训练好的模型,它已经学习到了丰富的通用特征。当我们面临一个新的图像分类任务时,可以直接利用这些预训练好的模型参数来初始化我们的模型,这样模型在初始阶段就具备了一定的特征提取能力,能够更快地收敛,并且在一定程度上避免了局部最优解的问题。
预训练模型的选择至关重要。首先,预训练任务与目标任务的相似性是关键因素。如果两个任务在特征层面具有相似性,那么预训练模型提取的特征将对目标任务更有帮助。其次,预训练数据集的规模也非常重要。大规模的数据集能够支撑模型学习到更通用的特征,从而在不同的任务中具有更好的泛化能力。例如,ImageNet 数据集拥有 1000 个类别,1.2 亿张图像,尺寸为 224x224,是一个非常适合用于预训练的大规模图像数据集。
二、实验过程
(一)实验环境与数据准备
本次实验使用的是 PyTorch 深度学习框架,借助其丰富的预训练模型库和便捷的数据处理工具,能够高效地完成模型的加载、训练和测试。实验中使用的 CIFAR-10 数据集是一个经典的图像分类数据集,包含 10 个类别,共 60000 张 32x32 的彩色图像,其中训练集有 50000 张图像,测试集有 10000 张图像。由于 CIFAR-10 图像的尺寸较小,且类别相对较少,因此直接在该数据集上训练模型可能会面临过拟合等问题,而预训练模型的引入则有望缓解这一问题。
在数据预处理阶段,为了增强模型的泛化能力,对训练集进行了多种数据增强操作,包括随机裁剪、随机水平翻转、颜色抖动以及随机旋转等。这些操作能够在训练过程中为模型提供更多的“干扰”或变形,使模型能够学习到更加鲁棒的特征。具体的数据预处理代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
对于测试集,则仅进行了标准化处理,以确保其与训练集在数据分布上具有一致性。
(二)预训练模型的选择与适配
在本次实验中,选择了 ResNet18 作为预训练模型。ResNet18 是一种经典的卷积神经网络架构,具有 18 层深度,并且通过残差连接解决了深层网络训练中的梯度消失问题。它在 ImageNet 数据集上预训练后,能够提取出具有较强表达能力的特征。
由于 ResNet18 在 ImageNet 数据集上预训练时,其输入图像尺寸为 224x224,而 CIFAR-10 图像的尺寸为 32x32,因此需要对模型进行适当的调整以适配 CIFAR-10 数据集。具体来说,需要修改模型的最后一层全连接层,将其输出类别数从 1000 改为 10,以匹配 CIFAR-10 的类别数量。此外,由于输入图像尺寸的变化,还需要调整模型中的一些层的参数,例如池化层的步长等。以下是 ResNet18 模型的适配代码:
from torchvision.models import resnet18# 定义 ResNet18 模型(支持预训练权重加载)
def create_resnet18(pretrained=True, num_classes=10):model = resnet18(pretrained=pretrained)in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return model.to(device)# 创建 ResNet18 模型(加载 ImageNet 预训练权重,不进行微调)
model = create_resnet18(pretrained=True, num_classes=10)
model.eval() # 设置为推理模式
(三)训练策略
在训练过程中,采用了阶段式训练策略。首先,冻结模型的卷积层参数,仅训练全连接层,这样可以在不破坏预训练模型特征提取能力的前提下,快速调整模型的输出层以适应 CIFAR-10 数据集。经过一定轮次的训练后,解冻模型的所有参数,进行整体训练,以进一步提升模型的性能。这种策略能够在训练初期快速降低损失,并在后续训练中充分利用预训练模型的特征提取能力,实现更好的收敛效果。
具体来说,实验中设置了前 5 轮冻结卷积层参数,之后解冻所有参数进行训练。在解冻后,为了防止过拟合,还适当降低了学习率。以下是训练函数的关键代码:
# 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""for name, param in model.named_parameters():if 'fc' not in name:param.requires_grad = not freezefrozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):model = freeze_model(model, freeze=True)for epoch in range(epochs):if epoch == freeze_epochs:model = freeze_model(model, freeze=False)optimizer.param_groups[0]['lr'] = 1e-4 # 解冻后调整学习率model.train()running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_trainmodel.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testif scheduler is not None:scheduler.step(epoch_test_loss)print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")
三、实验结果与分析
经过 40 轮的训练,最终测试准确率达到了 86.30%。从训练过程的输出中可以看出,在解冻卷积层参数后,模型的训练损失迅速下降,训练准确率和测试准确率都得到了显著提升。这充分证明了预训练模型的强大优势,即使在 CIFAR-10 这种相对较小的数据集上,也能够通过微调取得优异的性能。
此外,由于训练集采用了数据增强操作,模型在训练初期可能会出现训练准确率暂时低于测试准确率的情况。这是因为数据增强增加了模型训练的难度,而测试集是标准的、未增强的图像,模型在测试集上预测相对轻松。随着训练的推进,模型逐渐适应了数据增强带来的变化,训练准确率和测试准确率之间的差距逐渐缩小。
以下是完整的训练代码:
# 主函数:训练模型
def main():# 参数设置epochs = 40 # 总训练轮次freeze_epochs = 5 # 冻结卷积层的轮次learning_rate = 1e-3 # 初始学习率weight_decay = 1e-4 # 权重衰减# 创建 ResNet18 模型(加载预训练权重)model = create_resnet18(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 开始训练(前 5 轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")if __name__ == "__main__":main()
@浙大疏锦行
补充-60日计划day44,pynote中day55