当前位置: 首页 > news >正文

大模型微调步骤整理


在对深度学习模型进行微调时,我通常会遵循以下几个通用步骤。

第一步是选择一个合适的预训练模型。PyTorch 的 torchvision.models 模块提供了很多经典的预训练模型,比如 ResNet、VGG、EfficientNet 等。我们可以直接使用它们作为模型的基础结构。例如,加载一个预训练的 ResNet50 可以这样写:

import torchvision.models as models
model = models.resnet50(pretrained=True)

第二步是准备数据集。我会使用 torchvision.datasets 来加载常见的图像分类数据集,比如 ImageFolder 用于自定义文件夹结构的数据集。然后通过 DataLoader 将其转换为可迭代的批次数据,方便训练:

from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])dataset = datasets.ImageFolder(root='data_path', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

第三步是根据具体任务调整模型结构。比如我们的人脸识别任务有新的类别数量,就需要替换掉原来的全连接层(fc)。比如把 ResNet 中的最后输出改为我们自己的类别数:

num_classes = 100  # 假设我们有100个类别
model.fc = nn.Linear(512, num_classes)

有时候我们还会冻结前面的一些层,只训练最后几层,尤其是在数据量较小的情况下,可以防止过拟合:

for param in model.parameters():param.requires_grad = False  # 冻结所有层
for param in model.fc.parameters():param.requires_grad = True   # 只训练最后一层

第四步是定义损失函数和优化器。常用的损失函数是交叉熵损失,优化器可以选择 Adam 或 SGD,同时还可以配合学习率调度器一起使用:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.para

相关文章:

  • 第9章 组件及事件处理
  • Mac 在恢复模式下出现 旋转地球图标 但进度非常缓慢
  • Oracle 内存优化
  • java中的Servlet3.x详解
  • sparkSQL读入csv文件写入mysql
  • 10.8 LangChain三大模块深度实战:从模型交互到企业级Agent工具链全解析
  • 多模态大语言模型arxiv论文略读(八十一)
  • SuperYOLO:多模态遥感图像中的超分辨率辅助目标检测之论文阅读
  • 贪心算法应用:最大匹配问题详解
  • 算法岗实习八股整理——深度学习篇(不断更新中)
  • 软件工程各种图总结
  • MySQL开发规范
  • 互联网大厂Java面试:从Spring到微服务的深度探讨
  • 大模型deepseek与知识图谱的实践
  • 【数据结构】2-3-3单链表的查找
  • 离散文本表示
  • spark数据处理练习题详解【下】
  • [论文品鉴] DeepSeek V3 最新论文 之 MHA、MQA、GQA、MLA
  • Linux编译rpm包与deb包
  • 用 UniApp 开发 TilePuzzle:一个由 CodeBuddy 主动驱动的拼图小游戏
  • 巴基斯坦副总理兼外长达尔将访华
  • 新华社千笔楼:地方文旅宣传应走出“魔性尬舞”的流量焦虑
  • 广西隆林突发山洪,致3人遇难1人失联
  • 工人日报:应对“职场肥胖”,健康与减重同受关注
  • 终于,俄罗斯和乌克兰谈上了
  • 时隔三年,俄乌直接谈判重启