用Python玩转人工智能——手搓图像分类模型
目录
一、预训练模型原理与状态字典
(一)预训练模型原理
(二)状态字典(state_dict)
(三)模型保存与加载示例
二、加载 ImageNet 预训练模型
三、数据准备与可视化
(一)加载数据集
(二)使用 Matplotlib 可视化数据
四、模型训练
五、使用 torchvision 进行模型微调
六、观察模型预测结果
七、固定模型参数
八、使用 TensorBoard 可视化训练结果
九、课程总结
(一)技术要点
(二)难点
十、随堂练习题
一、预训练模型原理与状态字典
(一)预训练模型原理
预训练模型是深度学习中的重要技术,其核心思想是利用大规模数据集(如 ImageNet)训练好的模型,将其学习到的通用特征表示迁移到其他相关任务中。这种方法可以显著减少训练时间和数据需求,并提高模型在小数据集上的性能。
预训练模型的优势:
- 特征复用:底层卷积层学习到的边缘、纹理等特征对多种任务都有帮助
- 训练加速:无需从头开始训练,只需微调少数层
- 小数据友好:在数据有限的场景下也能取得良好效果
(二)状态字典(state_dict)
在 PyTorch 中,state_dict是一个 Python 字典对象,它存储了模型中每个层的可学习参数(权重和偏置)。只有具有可学习参数的层(如卷积层、全连接层)才有对应的state_dict条目。
关键函数:
- model.state_dict():返回模型的状态字典
- torch.save(state_dict, path):保存状态字典到文件
- model.load_state_dict(torch.load(path)):从文件加载状态字典
(三)模型保存与加载示例
import torch
import torch.nn as nn# 定义一个简单模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3) # 输入3通道,输出16通道,卷积核3x3self.fc1 = nn.Linear(16*10*10, 10) # 全连接层,输入维度1600,输出10def forward(self, x):x = self.conv1(x)x = x.view(-1, 16*10*10) # 展平为一维向量x = self.fc1(x)return x# 创建模型实例并保存
model = SimpleModel()
torch.save(model.state_dict(), 'model_weights.pth') # 保存为.pth文件# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('model_weights.pth', map_location=device)) # 加载时指定设备
loaded_model.to(device) # 将模型移至指定设备
loaded_model.eval() # 设置为评估模式(关闭Dropout等训练专用层)
注意事项:
- 保存 / 加载时应使用state_dict而非整个模型,这样更灵活且节省空间
- 加载模型时需先创建相同结构的模型实例
- 使用map_location参数可指定加载到 CPU 还是 GPU
- 推理前需调用model.eval()将模型设置为评估模式
二、加载 ImageNet 预训练模型
import torchvision.models as models# 加载预训练的ResNet18模型
resnet18 = models.resnet18(pretrained=True) # pretrained=True表示加载预训练权重# 查看模型结构
print(resnet18)# 修改最后一层以适应我们的分类任务(假设是10类分类)
num_ftrs = resnet18.fc.in_features # 获取最后一层的输入特征数
resnet18.fc = nn.Linear(num_ftrs, 10) # 替换最后一层为新的全连接层,输出10类
使用的库:
torchvision.models:提供了常用的预训练模型(如 ResNet、VGG、AlexNet 等)
关键函数:
models.resnet18(pretrained=True):加载预训练的 ResNet18 模型
nn.Linear(in_features, out_features):创建全连接层
注意事项:
不同模型的最后一层名称可能不同(如 ResNet 是fc,VGG 是classifier[6])
修改最后一层时需保持输入特征数不变
第一次加载预训练模型时会自动下载权重文件,需确保网络连接