知识蒸馏实战:用PyTorch和预训练模型提升小模型性能
在深度学习的浪潮中,我们常常追求更大、更深、更复杂的模型以达到最先进的性能。然而,这些“庞然大物”般的模型往往伴随着高昂的计算成本和缓慢的推理速度,使得它们难以部署在资源受限的环境中,如移动设备或边缘计算平台。知识蒸馏(Knowledge Distillation)技术为此提供了一个优雅的解决方案:将一个大型、高性能的“教师模型”所学习到的“知识”迁移到一个小巧、高效的“学生模型”中。
本篇将一步步使用 PyTorch 实现一个知识蒸馏的案例,其中教师模型将采用预训练模型。
什么是知识蒸馏?
知识蒸馏的核心思想是,训练一个小型学生模型 (Student Model) 来模仿一个大型教师模型 (Teacher Model) 的行为。这种模仿不仅仅是学习教师模型对“硬标签”(即真实标签)的预测,更重要的是学习教师模型输出的“软标签”(Soft Targets)。
- 教师模型 (Teacher Model): 通常是一个已经训练好的、性能优越的大型模型。例如,在计算机视觉领域,可以是 ImageNet 上预训练的 ResNet、VGG 等。
- 学生模型 (Student Model): 一个参数量较小、计算更高效的轻量级模型,我们希望它能达到接近教师模型的性能。
- 软标签 (Soft Targets): 教师模型在输出层(softmax之前,即logits)经过一个较高的“温度”(Temperature, T)调整后的概率分布。高温会使概率分布更平滑,从而揭示类别间的相似性信息,这些被称为“暗知识”(Dark Knowledge)。
- 硬标签 (Hard Targets): 数据集的真实标签。
- 蒸馏损失 (Distillation Loss): 通常由两部分组成:
- 学生模型在真实标签上的损失(例如交叉熵损失)。
- 学生模型与教师模型软标签之间的损失(例如KL散度或均方误差)。
这两部分损失通过一个超参数 a l p h a \\alpha alpha 来加权平衡。
PyTorch 实现步骤
接下来,我们将通过一个图像分类的例子来演示如何实现知识蒸馏。假设我们的任务是对一个包含10个类别的图像数据集进行分类。
1. 准备工作:导入库和设置设备
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms # 用于数据预处理# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
2. 定义教师模型 (Pre-trained ResNet18)
我们将使用 torchvision.models
中预训练的 ResNet18 作为教师模型。为了适应我们自定义的分类任务(例如10分类),我们需要替换其原始的1000类全连接层。
class PretrainedTeacherModel(nn.Module):def __init__(self, num_classes, pretrained=True):super(PretrainedTeacherModel, self).__init__()# 加载预训练的 ResNet18 模型# PyTorch 1.9+ 推荐使用 weights 参数if pretrained:self.resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)else:self.resnet = models.resnet18(weights=None) # 或者 models.resnet18(pretrained=False) for older versions# 获取 ResNet18 原本的输出特征数num_ftrs = self.resnet.fc.in_features# 替换最后的全连接层以适应我们的任务类别数self.resnet.fc = nn.Linear(num_ftrs, num_classes)def forward(self, x):return self.resnet(x)
在蒸馏过程中,教师模型的参数通常是固定的,不参与训练。
3. 定义学生模型
学生模型应该是一个比教师模型更小、更轻量的网络。这里我们定义一个简单的卷积神经网络 (CNN)。
class StudentCNNModel(nn.Module):def __init__(self, num_classes):super(StudentCNNModel, self).__init__()# 输入通道数为3 (RGB图像), 假设输入图像大小为 32x32self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 32x32 -> 16x16self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 16x16 -> 8x8# 展平后的特征数: 32 channels * 8 * 8self.fc = nn.Linear(32 * 8 * 8, num_classes)def forward(self, x):out = self.pool1(self.relu1(self.conv1(x)))out = self.pool2(self.relu2(self.conv2(x)))out = out.view(out.size(0), -1) # 展平out = self.fc(out)return out
4. 定义蒸馏损失函数
这是知识蒸馏的核心。损失函数结合了学生模型在硬标签上的性能和与教师模型软标签的匹配程度。
- L _ h a r d L\_{hard} L_hard: 学生模型输出与真实标签之间的交叉熵损失。
- L _ s o f t L\_{soft} L_soft: 学生模型的软化输出与教师模型的软化输出之间的KL散度。
- 总损失 L = a l p h a c d o t L _ h a r d + ( 1 − a l p h a ) c d o t L _ s o f t c d o t T 2 L = \\alpha \\cdot L\_{hard} + (1 - \\alpha) \\cdot L\_{soft} \\cdot T^2 L=alphacdotL_hard+(1−alpha)cdotL_softcdotT2
- T T T 是温度参数。较高的 T T T 会使概率分布更平滑。
- a l p h a \\alpha alpha 是平衡两个损失项的权重。
- L _ s o f t L\_{soft} L_soft 乘以 T 2 T^2 T2 是为了确保软标签损失的梯度与硬标签损失的梯度在量级上大致相当。
class DistillationLoss(nn.Module):def __init__(self, alpha, temperature):super(DistillationLoss, self).__init__()self.alpha = alphaself.temperature = temperatureself.criterion_hard = nn.CrossEntropyLoss() # 硬标签损失# reduction='batchmean' 会将KL散度在batch维度上取平均,这在很多实现中是常见的self.criterion_soft = nn.KLDivLoss(reduction='batchmean') # 软标签损失def forward(self, student_logits, teacher_logits, labels):# 硬标签损失loss_hard = self.criterion_hard(student_logits, labels)# 软标签损失# 使用 softmax 和 temperature 来计算软标签和软预测# 注意:KLDivLoss期望的输入是 (log_probs, probs)soft_teacher_probs = F.softmax(teacher_logits / self.temperature, dim=1)soft_student_log_probs = F.log_softmax(student_logits / self.temperature, dim=1)# 计算KL散度损失loss_soft = self.criterion_soft(soft_student_log_probs, soft_teacher_probs) * (self.temperature ** 2)# 总损失loss = self.alpha * loss_hard + (1 - self.alpha) * loss_softreturn loss
5. 训练流程
现在我们将所有部分组合起来进行训练。
# --- 示例参数 ---
num_classes = 10 # 假设我们的任务是10分类
img_channels = 3
img_height = 32
img_width = 32learning_rate = 0.001
num_epochs = 20 # 实际应用中需要更多 epochs 和真实数据
batch_size = 32
temperature = 4.0 # 蒸馏温度
alpha = 0.3 # 硬标签损失的权重# --- 实例化模型 ---
teacher_model = PretrainedTeacherModel(num_classes=num_classes, pretrained=True).to(device)
teacher_model.eval() # 教师模型设为评估模式,不更新其权重student_model = StudentCNNModel(num_classes=num_classes).to(device)# --- 准备优化器和损失函数 ---
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate) # 只优化学生模型的参数
distillation_criterion = DistillationLoss(alpha=alpha, temperature=temperature).to(device)# --- 生成一些虚拟图像数据进行演示 ---
# !!! 警告: 实际应用中必须使用真实数据加载器 (DataLoader) 和正确的预处理 !!!
# 预训练模型通常对输入有特定的归一化要求。
# 例如,ImageNet预训练模型通常使用:
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# 并且输入尺寸也需要匹配,或进行适当调整。
# 本例中学生模型接收 32x32 输入,教师模型(ResNet)通常处理更大图像如 224x224。
# 为简化,我们假设教师模型能处理学生模型的输入尺寸,或者在教师模型前对输入进行适配。
dummy_inputs = torch.randn(batch_size, img_channels, img_height, img_width).to(device)
dummy_labels = torch.randint(0, num_classes, (batch_size,)).to(device)print("开始训练学生模型...")
# --- 训练学生模型 ---
for epoch in range(num_epochs):student_model.train() # 学生模型设为训练模式# 获取教师模型的输出 (logits)with torch.no_grad(): # 教师模型的权重不更新# 如果教师模型和学生模型期望的输入尺寸不同,需要适配# teacher_input_adjusted = F.interpolate(dummy_inputs, size=(224, 224), mode='bilinear', align_corners=False) # 示例调整# teacher_logits = teacher_model(teacher_input_adjusted)teacher_logits = teacher_model(dummy_inputs) # 假设教师模型可以处理此尺寸或已适配# 前向传播 - 学生模型student_logits = student_model(dummy_inputs)# 计算蒸馏损失loss = distillation_criterion(student_logits, teacher_logits, dummy_labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 5 == 0 or epoch == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')print("学生模型训练完成!")# (可选) 保存学生模型
# torch.save(student_model.state_dict(), 'student_cnn_distilled.pth')
# print("蒸馏后的学生CNN模型已保存。")
关键点与最佳实践
- 数据预处理: 对于预训练的教师模型,其输入数据必须经过与预训练时相同的预处理(如归一化、尺寸调整)。这是确保教师模型发挥其最佳性能并传递有效知识的关键。
- 输入兼容性: 确保教师模型和学生模型接收的输入在语义上是一致的。如果它们的网络结构原生接受不同尺寸的输入,你可能需要调整输入数据(例如,通过插值
F.interpolate
)以适应教师模型,或者确保两个模型都能处理相同的输入。 - 超参数调优:
alpha
,temperature
,learning_rate
等超参数对蒸馏效果至关重要。通常需要通过实验来找到最佳组合。较高的temperature
可以让学生学习到更多类别间的细微差别,但过高可能会导致信息模糊。 - 教师模型的选择: 教师模型越强大,通常能传递的知识越多。但也要考虑其推理成本(即使只在训练时)。
- 学生模型的设计: 学生模型不应过于简单,以至于无法吸收教师的知识;也不应过于复杂,从而失去蒸馏的意义。
- 训练时长: 知识蒸馏通常需要足够的训练轮次才能让学生模型充分学习。
- 不仅仅是 Logits: 本文介绍的是最常见的基于 Logits 的蒸馏。还有其他蒸馏方法,例如匹配教师模型和学生模型中间层的特征表示(Feature Distillation),这有时能带来更好的效果。