
Model Compression and Transfer: A Hands-On Tutorial on Knowledge Distillation

1. Introduction

        Model distillation, also known as knowledge distillation, is a technique for transferring the knowledge of a large, complex model (the teacher model) to a small, simple model (the student model). The sections below introduce model distillation, explain why it emerged, and describe what it is used for:

(1) What model distillation is

  1. Basic concepts

    • Teacher model: a large model that has already been trained and performs well.
    • Student model: a smaller, simpler model whose goal is to learn the teacher model's behavior and knowledge.
    • Soft labels: the probability distribution output by the teacher model, rather than plain class labels; this distribution carries rich information about how the teacher interprets the input data.
  2. Training process

    • Train the teacher model until it reaches high accuracy.
    • Use the teacher model's outputs (soft labels) to train the student model.
    • The student model learns from both the hard labels (the true class labels) and the soft labels, thereby imitating the teacher model's behavior.
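The training loop later in this tutorial implements exactly this combination. As a standalone sketch (the function name and the alpha/T defaults are illustrative choices, following the common Hinton-style formulation), the two-label objective looks like:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=5.0, alpha=0.5):
    """Combine the soft-label (KL) loss with the hard-label (CE) loss."""
    soft_targets = F.softmax(teacher_logits / T, dim=1)    # teacher's softened distribution
    soft_preds = F.log_softmax(student_logits / T, dim=1)  # student's softened log-probs
    # The T**2 rescaling keeps soft-label gradients comparable across temperatures
    kd = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * (T ** 2)
    ce = F.cross_entropy(student_logits, labels)           # hard-label loss on raw logits
    return alpha * kd + (1 - alpha) * ce
```

When the teacher and student agree exactly, the KL term vanishes and only the hard-label cross-entropy remains, which is a quick way to sanity-check the function.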

(2) Why model distillation emerged

        Model distillation emerged mainly to solve the following problems:

  1. Model deployment: large models are hard to run on mobile or embedded devices, where compute resources are limited.
  2. Computational efficiency: large models require substantial compute for training and inference, making them slow and expensive.
  3. Energy consumption: large models consume a lot of electricity when run in data centers, which conflicts with energy-saving goals.

(3) What model distillation does

  1. Model compression: distillation compresses a large model into a small one, reducing the parameter count and lowering storage and compute requirements.
  2. Performance retention: the student model stays small while getting as close as possible to the teacher model's performance.
  3. Faster inference: a small model runs faster at inference time, suiting applications that need quick responses.
  4. Lower energy use: a small model consumes fewer compute resources at run time, helping reduce energy consumption.
  5. Cross-model transfer: distillation can also transfer knowledge from a model in one domain to a model in another, enabling cross-domain learning.

2. Preparing the Training Code

(1) Defining the model structure

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

(2) Training code

temperature

  • This parameter controls how much the teacher's and the student's output logits are softened. In the code, temperature is set to 5.0.
  • During distillation, the teacher's and the student's logits are each divided by the temperature; this softening helps the student capture the teacher's full probability distribution during training.
  • A higher temperature produces a smoother distribution, which helps the student learn; a lower temperature produces a sharper distribution, closer to hard labels.
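The effect of temperature is easy to see with plain Python (a small illustration; the logits are made up):

```python
import math

def softmax_T(logits, T=1.0):
    """Softmax with temperature: higher T flattens the distribution."""
    exps = [math.exp(z / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

logits = [6.0, 2.0, 1.0]
print(softmax_T(logits, T=1.0))  # sharp, near one-hot: ~[0.976, 0.018, 0.007]
print(softmax_T(logits, T=5.0))  # softened: ~[0.550, 0.247, 0.202]
```

Note that the softened distribution keeps the same ranking of classes; it just exposes the relative similarities ("dark knowledge") that a hard label would throw away.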

loss_function

  • This function computes the distillation loss. The code uses nn.KLDivLoss, the Kullback-Leibler divergence loss, which measures the difference between two probability distributions.
  • reduction='batchmean' means the summed loss is divided by the batch size, which matches the mathematical definition of KL divergence.
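A small demonstration of the KLDivLoss API (toy logits; the key points are that the input must be log-probabilities, the target probabilities, and that 'batchmean' equals the summed loss divided by the batch size):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

kl_batchmean = nn.KLDivLoss(reduction='batchmean')
kl_sum = nn.KLDivLoss(reduction='sum')

# KLDivLoss expects log-probabilities as input and probabilities as target
student = F.log_softmax(torch.tensor([[1.0, 0.0], [0.5, 0.5]]), dim=1)
teacher = F.softmax(torch.tensor([[2.0, 0.0], [0.0, 1.0]]), dim=1)

# 'batchmean' sums the pointwise KL terms, then divides by the batch size (2 here)
print(kl_batchmean(student, teacher).item())
print(kl_sum(student, teacher).item() / 2)  # same value
```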

student_loss_function

  • This function computes the student model's classification loss against the true labels. The code uses nn.CrossEntropyLoss, the standard loss function for multi-class classification.

loss and student_loss

  • loss is the distillation loss, computed by comparing the softened student logits with the softened teacher logits.
  • student_loss is the student model's classification loss on the true labels.
  • The two losses are combined as a weighted average to form the final training loss; here the distillation loss and the classification loss are each weighted 0.5.

optimizer

  • This optimizer updates the student model's parameters. The code uses optim.Adam, an adaptive-learning-rate optimization algorithm.
  • params is the list of the student model's parameters that require gradients.
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from torchvision import models
from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Load the teacher model (already trained); keep it frozen in eval mode
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()


    # Load student model
    student_net = models.resnet18(pretrained=False)
    student_model_weight_path = "resnet18-f37072fd.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.fc = nn.Linear(student_net.fc.in_features, 5)
    student_net.to(device)
    

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            with torch.no_grad():  # the frozen teacher needs no gradients
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Soften the logits with the temperature
            soft_teacher = teacher_logits / temperature
            soft_student = student_logits / temperature

            # Distillation loss: KL divergence between the softened distributions,
            # scaled by temperature**2 to keep gradient magnitudes comparable
            loss = loss_function(torch.nn.functional.log_softmax(soft_student, dim=1),
                                 torch.nn.functional.softmax(soft_teacher, dim=1)) * (temperature ** 2)

            # Classification loss on the original (unscaled) student logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine the two losses with equal weights
            loss = 0.5 * loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

(3) Download links for the models and dataset

        Includes the resnet18 and resnet34 model weights, class_indices.json, the images, and related data

https://pan.baidu.com/s/1ZDCbichDcdaiAH6kxYNsIA

提取码: svv5 

3. Training a Custom Model with Distillation

(1) Model structure - model_10.py

import torch
from torch import nn


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        
        # Convolutional feature extractor: five Conv-ReLU-MaxPool blocks
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 3 input channels, 32 output channels
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive average pooling to a 1x1 feature map
        # Fully connected classifier
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 1 * 1, 1024),  # adaptive pooling guarantees a 512x1x1 feature map for any input size
            nn.ReLU(),
            nn.Linear(1024, 5)  # output layer: 5 classes
        )
    
    def forward(self, x):
        x = self.conv_layers(x)
        x = self.adaptive_pool(x)  # apply adaptive pooling
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x
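To put a number on the compression, ConvNet's parameter count can be tallied by hand (a back-of-the-envelope sketch with the channel sizes copied from the class above; for comparison, ResNet-34 has roughly 21M parameters):

```python
# Per-layer parameter formulas: conv = c_in*c_out*k*k + c_out (bias term),
# linear = n_in*n_out + n_out (bias term)
def conv_params(c_in, c_out, k=3):
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

total = (conv_params(3, 32) + conv_params(32, 64) + conv_params(64, 128)
         + conv_params(128, 256) + conv_params(256, 512)
         + linear_params(512, 1024) + linear_params(1024, 5))
print(total)  # 2099013, i.e. about 2.1M parameters
```

This roughly tenfold size gap is exactly the setting where distillation pays off: the small student keeps the runtime footprint, while the large teacher supplies the training signal.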

(2) Training the custom model - train-10.py

import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model_10 import ConvNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    
    
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    
    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    
    net = ConvNet()
    weights_path = "ConvNet.pth"
    assert os.path.exists(weights_path), f"File '{weights_path}' does not exist."
    # Resume from a previously saved checkpoint (strict=False tolerates missing/extra keys)
    state_dict = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state_dict, strict=False)
    net.to(device)
    
    # define loss function
    loss_function = nn.CrossEntropyLoss()
    
    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)
    
    epochs = 30
    best_acc = 0.0
    save_path = './ConvNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()
            
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)
        
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)

        
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))
        
        
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    
    print('Finished Training')


if __name__ == '__main__':
    main()

  (3) Training results

        After 60 epochs in total (two 30-epoch runs, the second resuming from the saved checkpoint), the best val_accuracy reached was 0.780.

train epoch[1/30] loss:0.971: 100%|██████████| 207/207 [00:08<00:00, 24.01it/s]
valid epoch[1/30]: 100%|██████████| 23/23 [00:00<00:00, 31.44it/s]
[epoch 1] train_loss: 0.623  val_accuracy: 0.742
train epoch[2/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[2/30]: 100%|██████████| 23/23 [00:00<00:00, 33.18it/s]
[epoch 2] train_loss: 0.604  val_accuracy: 0.736
train epoch[3/30] loss:0.661: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[3/30]: 100%|██████████| 23/23 [00:00<00:00, 32.38it/s]
[epoch 3] train_loss: 0.614  val_accuracy: 0.723
train epoch[4/30] loss:0.797: 100%|██████████| 207/207 [00:07<00:00, 26.66it/s]
valid epoch[4/30]: 100%|██████████| 23/23 [00:00<00:00, 31.70it/s]
[epoch 4] train_loss: 0.619  val_accuracy: 0.725
train epoch[5/30] loss:0.809: 100%|██████████| 207/207 [00:07<00:00, 26.87it/s]
valid epoch[5/30]: 100%|██████████| 23/23 [00:00<00:00, 32.26it/s]
[epoch 5] train_loss: 0.594  val_accuracy: 0.698
train epoch[6/30] loss:0.302: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[6/30]: 100%|██████████| 23/23 [00:00<00:00, 32.49it/s]
[epoch 6] train_loss: 0.591  val_accuracy: 0.728
train epoch[7/30] loss:0.708: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[7/30]: 100%|██████████| 23/23 [00:00<00:00, 33.09it/s]
[epoch 7] train_loss: 0.589  val_accuracy: 0.720
train epoch[8/30] loss:0.709: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[8/30]: 100%|██████████| 23/23 [00:00<00:00, 32.55it/s]
[epoch 8] train_loss: 0.575  val_accuracy: 0.734
train epoch[9/30] loss:0.691: 100%|██████████| 207/207 [00:07<00:00, 26.61it/s]
valid epoch[9/30]: 100%|██████████| 23/23 [00:00<00:00, 34.43it/s]
[epoch 9] train_loss: 0.555  val_accuracy: 0.734
train epoch[10/30] loss:0.442: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[10/30]: 100%|██████████| 23/23 [00:00<00:00, 32.91it/s]
[epoch 10] train_loss: 0.548  val_accuracy: 0.703
train epoch[11/30] loss:0.363: 100%|██████████| 207/207 [00:07<00:00, 26.46it/s]
valid epoch[11/30]: 100%|██████████| 23/23 [00:00<00:00, 30.53it/s]
[epoch 11] train_loss: 0.550  val_accuracy: 0.728
train epoch[12/30] loss:0.519: 100%|██████████| 207/207 [00:07<00:00, 26.19it/s]
valid epoch[12/30]: 100%|██████████| 23/23 [00:00<00:00, 33.14it/s]
[epoch 12] train_loss: 0.545  val_accuracy: 0.734
train epoch[13/30] loss:0.478: 100%|██████████| 207/207 [00:07<00:00, 26.48it/s]
valid epoch[13/30]: 100%|██████████| 23/23 [00:00<00:00, 32.75it/s]
[epoch 13] train_loss: 0.532  val_accuracy: 0.755
train epoch[14/30] loss:0.573: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[14/30]: 100%|██████████| 23/23 [00:00<00:00, 33.40it/s]
[epoch 14] train_loss: 0.542  val_accuracy: 0.747
train epoch[15/30] loss:0.595: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[15/30]: 100%|██████████| 23/23 [00:00<00:00, 34.54it/s]
[epoch 15] train_loss: 0.542  val_accuracy: 0.758
train epoch[16/30] loss:0.191: 100%|██████████| 207/207 [00:07<00:00, 26.83it/s]
valid epoch[16/30]: 100%|██████████| 23/23 [00:00<00:00, 32.04it/s]
[epoch 16] train_loss: 0.532  val_accuracy: 0.761
train epoch[17/30] loss:0.566: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[17/30]: 100%|██████████| 23/23 [00:00<00:00, 33.56it/s]
[epoch 17] train_loss: 0.523  val_accuracy: 0.739
train epoch[18/30] loss:0.509: 100%|██████████| 207/207 [00:07<00:00, 26.79it/s]
valid epoch[18/30]: 100%|██████████| 23/23 [00:00<00:00, 30.35it/s]
[epoch 18] train_loss: 0.526  val_accuracy: 0.742
train epoch[19/30] loss:0.781: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[19/30]: 100%|██████████| 23/23 [00:00<00:00, 31.58it/s]
[epoch 19] train_loss: 0.506  val_accuracy: 0.764
train epoch[20/30] loss:0.336: 100%|██████████| 207/207 [00:07<00:00, 26.64it/s]
valid epoch[20/30]: 100%|██████████| 23/23 [00:00<00:00, 33.95it/s]
[epoch 20] train_loss: 0.537  val_accuracy: 0.764
train epoch[21/30] loss:0.475: 100%|██████████| 207/207 [00:07<00:00, 26.65it/s]
valid epoch[21/30]: 100%|██████████| 23/23 [00:00<00:00, 33.27it/s]
[epoch 21] train_loss: 0.511  val_accuracy: 0.764
train epoch[22/30] loss:0.513: 100%|██████████| 207/207 [00:07<00:00, 26.53it/s]
valid epoch[22/30]: 100%|██████████| 23/23 [00:00<00:00, 32.16it/s]
[epoch 22] train_loss: 0.482  val_accuracy: 0.761
train epoch[23/30] loss:0.172: 100%|██████████| 207/207 [00:07<00:00, 26.62it/s]
valid epoch[23/30]: 100%|██████████| 23/23 [00:00<00:00, 33.02it/s]
[epoch 23] train_loss: 0.501  val_accuracy: 0.761
train epoch[24/30] loss:1.127: 100%|██████████| 207/207 [00:07<00:00, 26.54it/s]
valid epoch[24/30]: 100%|██████████| 23/23 [00:00<00:00, 34.24it/s]
[epoch 24] train_loss: 0.492  val_accuracy: 0.755
train epoch[25/30] loss:0.905: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[25/30]: 100%|██████████| 23/23 [00:00<00:00, 30.22it/s]
[epoch 25] train_loss: 0.492  val_accuracy: 0.758
train epoch[26/30] loss:1.044: 100%|██████████| 207/207 [00:07<00:00, 26.75it/s]
valid epoch[26/30]: 100%|██████████| 23/23 [00:00<00:00, 33.86it/s]
[epoch 26] train_loss: 0.476  val_accuracy: 0.777
train epoch[27/30] loss:0.552: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[27/30]: 100%|██████████| 23/23 [00:00<00:00, 31.55it/s]
[epoch 27] train_loss: 0.465  val_accuracy: 0.745
train epoch[28/30] loss:0.387: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[28/30]: 100%|██████████| 23/23 [00:00<00:00, 32.30it/s]
[epoch 28] train_loss: 0.482  val_accuracy: 0.769
train epoch[29/30] loss:0.251: 100%|██████████| 207/207 [00:07<00:00, 26.69it/s]
valid epoch[29/30]: 100%|██████████| 23/23 [00:00<00:00, 32.98it/s]
[epoch 29] train_loss: 0.466  val_accuracy: 0.777
train epoch[30/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.57it/s]
valid epoch[30/30]: 100%|██████████| 23/23 [00:00<00:00, 31.95it/s]
[epoch 30] train_loss: 0.467  val_accuracy: 0.780
Finished Training

(4) Distillation training

import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))


    # Load the teacher model (already trained); keep it frozen in eval mode
    teacher_net = resnet34(num_classes=5)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()

    
    # Load student model
    student_net = ConvNet()
    student_model_weight_path = "ConvNet.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            with torch.no_grad():  # the frozen teacher needs no gradients
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Soften the logits with the temperature
            soft_teacher = teacher_logits / temperature
            soft_student = student_logits / temperature

            # Distillation loss: KL divergence between the softened distributions,
            # scaled by temperature**2 to keep gradient magnitudes comparable
            loss = loss_function(torch.nn.functional.log_softmax(soft_student, dim=1),
                                 torch.nn.functional.softmax(soft_teacher, dim=1)) * (temperature ** 2)

            # Classification loss on the original (unscaled) student logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine the two losses with equal weights
            loss = 0.5 * loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

        I didn't capture a screenshot, so try it yourself: in my test, training the custom model for 30 epochs and then distillation-training it for another 30 epochs brought val_accuracy up to 0.81.

(5) Model files

https://pan.baidu.com/s/1gVTJPvAQ3oDEZcGYoJvuLw

提取码: ddk5 

4. Summary

        If the model structure is simple, distillation training can be used to improve its accuracy; of course, a teacher model has to be trained first.
