当前位置: 首页 > news >正文

用Python玩转人工智能——手搓图像分类模型

目录

  一、预训练模型原理与状态字典

  (一)预训练模型原理

  (二)状态字典(state_dict)

  (三)模型保存与加载示例

  二、加载 ImageNet 预训练模型

  三、数据准备与可视化

  (一)加载数据集

  (二)使用 Matplotlib 可视化数据

  四、模型训练

  五、使用 torchvision 进行模型微调

  六、观察模型预测结果

  七、固定模型参数

  八、使用 TensorBoard 可视化训练结果

  九、课程总结

  (一)技术要点

  (二)难点

  十、随堂练习题


  一、预训练模型原理与状态字典

  (一)预训练模型原理

  预训练模型是深度学习中的重要技术,其核心思想是利用大规模数据集(如 ImageNet)训练好的模型,将其学习到的通用特征表示迁移到其他相关任务中。这种方法可以显著减少训练时间和数据需求,并提高模型在小数据集上的性能。

  预训练模型的优势:

  • 特征复用:底层卷积层学习到的边缘、纹理等特征对多种任务都有帮助
  • 训练加速:无需从头开始训练,只需微调少数层
  • 小数据友好:在数据有限的场景下也能取得良好效果

  (二)状态字典(state_dict)

  在 PyTorch 中,state_dict是一个 Python 字典对象,它存储了模型中每个层的可学习参数(权重和偏置)。只有具有可学习参数的层(如卷积层、全连接层)才有对应的state_dict条目。

  关键函数

  • model.state_dict():返回模型的状态字典
  • torch.save(state_dict, path):保存状态字典到文件
  • model.load_state_dict(torch.load(path)):从文件加载状态字典

  (三)模型保存与加载示例

import torch
import torch.nn as nn# 定义一个简单模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)  # 输入3通道,输出16通道,卷积核3x3self.fc1 = nn.Linear(16*10*10, 10)  # 全连接层,输入维度1600,输出10def forward(self, x):x = self.conv1(x)x = x.view(-1, 16*10*10)  # 展平为一维向量x = self.fc1(x)return x# 创建模型实例并保存
model = SimpleModel()
torch.save(model.state_dict(), 'model_weights.pth')  # 保存为.pth文件# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('model_weights.pth', map_location=device))  # 加载时指定设备
loaded_model.to(device)  # 将模型移至指定设备
loaded_model.eval()  # 设置为评估模式(关闭Dropout等训练专用层)

  注意事项

  1. 保存 / 加载时应使用state_dict而非整个模型,这样更灵活且节省空间
  2. 加载模型时需先创建相同结构的模型实例
  3. 使用map_location参数可指定加载到 CPU 还是 GPU
  4. 推理前需调用model.eval()将模型设置为评估模式

  二、加载 ImageNet 预训练模型

import torchvision.models as models# 加载预训练的ResNet18模型
resnet18 = models.resnet18(pretrained=True)  # pretrained=True表示加载预训练权重# 查看模型结构
print(resnet18)# 修改最后一层以适应我们的分类任务(假设是10类分类)
num_ftrs = resnet18.fc.in_features  # 获取最后一层的输入特征数
resnet18.fc = nn.Linear(num_ftrs, 10)  # 替换最后一层为新的全连接层,输出10类

  使用的库

  torchvision.models:提供了常用的预训练模型(如 ResNet、VGG、AlexNet 等)

  关键函数

  models.resnet18(pretrained=True):加载预训练的 ResNet18 模型

  nn.Linear(in_features, out_features):创建全连接层

  注意事项

  不同模型的最后一层名称可能不同(如 ResNet 是fc,VGG 是classifier[6])

  修改最后一层时需保持输入特征数不变

  第一次加载预训练模型时会自动下载权重文件,需确保网络连接

  三、数据准备与可视化

相关文章:

  • 【PhysUnits】13 改进减法(sub.rs)
  • 【加密算法】
  • 从“被动养老”到“主动健康管理”:平台如何重构代际关系?
  • Odoo 条码功能全面深度解析(VIP15万字版)
  • LiveNVR :实现非国标流转国标流的全方位解决方案
  • 勾股数的性质和应用
  • 接地气的方式认识JVM(一)
  • 通过teamcity cloud创建你的一个build
  • 【C语言】详解 指针
  • Java开发之定时器学习
  • 欧拉角转为旋转矩阵
  • 二叉树的锯齿形层序遍历——灵活跳跃的层次结构解析
  • w~视觉~合集6
  • 自我觉察是成长的第一步,如何构建内心的平静
  • 【线程与进程区别】
  • Spring AI框架快速入门
  • 华为OD机试真题——最佳的出牌方法(2025A卷:200分)Java/python/JavaScript/C/C++/GO最佳实现
  • SAR ADC 比较器的offset 校正
  • 加密协议知多少
  • STP(生成树协议)原理与配置
  • 移动端网站开发公司/抖音广告怎么投放
  • 日本女做受网站/吴忠seo
  • 网站建设的软件介绍/网站查询关键词排名软件
  • 移动端网页设计图片/seo顾问阿亮
  • 新乡高端网站建设/网上怎么做广告
  • 做传销网站的程序员犯法吗/seo怎么才能做好