模型训练——使用预训练权重、冻结训练
文章目录
- 前言
- 一、如何使用预训练权重
- 二、如何冻结训练
前言
一般来讲,从0开始训练效果会很差,因为权值太过随机,特征提取效果不明显,所以训练时不好收敛,尤其是针对数据较少的情况。
一、如何使用预训练权重
PyTorch提供了 state_dict() 和 load_state_dict() 两个方法用来保存和加载模型参数,前者将模型参数保存为字典形式,后者将字典形式的模型参数载入到模型当中。
使用预训练权重的步骤如下:
(1)加载预训练模型权重、读取当前模型的字典结构
(2)使用预训练模型权重的参数 更新 当前模型的参数
(3)加载更新后的当前模型参数
下面的方式就是在训练模型时的常用两种方式:一个是经验性的通用权重初始化,另一种就是使用上述的预训练权重来进行初始化。 这里值得注意的是 load_state_dict()函数中有一个strict参数,该参数决定网络在恢复过程中是严格恢复(默认是严格恢复),还是非严格恢复,如果严格恢复,则会严格匹配所有的字典,所以当前模型与预训练模型的结构必须完全相同,否则就会报错。所以大多数情况下,都是设置 strict=False ,来使其只有相同的网络层进行初始化。这一步也是迁移学习中常用的backbone初始化。
代码如下(示例)