
1. Introduction

        Model distillation, also known as knowledge distillation, is a technique for transferring the knowledge of a large, complex model (the teacher model) to a small, simple model (the student model). The following covers what model distillation is, why it emerged, and what it accomplishes:

(1) What model distillation is

  1. Basic concepts

    • Teacher model: a large model that has already been trained to strong performance.
    • Student model: a smaller, simpler model whose goal is to learn the teacher model's behavior and knowledge.
    • Soft labels: the probability distributions output by the teacher model, rather than plain class labels; these distributions carry rich information about how the teacher interprets the input data.
  2. Training process

    • Train the teacher model until it reaches high accuracy.
    • Use the teacher model's outputs (soft labels) to train the student model.
    • The student model learns from both the hard labels (the actual class labels) and the soft labels, and thereby imitates the teacher model's behavior.
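The interplay of hard and soft labels described above can be sketched in a few lines of framework-free Python. All probability values here are invented purely for illustration:

```python
import math

def cross_entropy(probs, label):
    """Hard-label loss for one sample: -log(probability of the true class)."""
    return -math.log(probs[label])

def kl_div(p_teacher, q_student):
    """Soft-label loss for one sample: KL divergence from teacher to student."""
    return sum(p * (math.log(p) - math.log(q))
               for p, q in zip(p_teacher, q_student) if p > 0)

student_probs = [0.5, 0.3, 0.2]   # hypothetical student output (already softmax-ed)
teacher_probs = [0.6, 0.3, 0.1]   # hypothetical teacher soft labels
hard_label = 0                    # ground-truth class index

hard_loss = cross_entropy(student_probs, hard_label)
soft_loss = kl_div(teacher_probs, student_probs)
total_loss = 0.5 * hard_loss + 0.5 * soft_loss  # equal weighting of the two signals
```

The student is pulled toward the true label by `hard_loss` and toward the teacher's full distribution by `soft_loss`; the training code later in this article combines them with exactly this 0.5/0.5 weighting.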

(2) Why model distillation emerged

        Model distillation mainly addresses the following problems:

  1. Model deployment: large models are difficult to run on mobile or embedded devices, where compute resources are limited.
  2. Computational efficiency: large models require substantial compute for training and inference, making them slow and expensive.
  3. Energy consumption: large models consume significant power when run in data centers, which conflicts with energy-saving goals.

(3) What model distillation accomplishes

  1. Model compression: distillation compresses a large model into a small one, reducing the parameter count and lowering storage and compute requirements.
  2. Performance retention: the student model stays small while approaching the teacher model's performance as closely as possible.
  3. Faster inference: the small model runs inference faster, which suits latency-sensitive applications.
  4. Lower energy use: the small model consumes fewer compute resources at run time, helping reduce energy consumption.
  5. Cross-model transfer: distillation can also transfer knowledge from a model in one domain to a model in another, enabling cross-domain learning.

2. Preparing the training code

(1) Defining the model structure

import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True,
                 groups=1, width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64
        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample,
                            stride=stride, groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel, groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

(2) Training code

temperature

  • This parameter controls how much the teacher's and student's output logits are softened. In the code, temperature is set to 5.0.
  • During distillation, both the teacher's and the student's logits are divided by the temperature before the softmax, which helps the student model capture the teacher's probability distribution more faithfully.
  • A higher temperature produces a smoother distribution, which helps the student learn; a lower temperature produces a sharper distribution, closer to hard labels.
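The softening effect can be seen without any framework. This sketch recomputes the softmax at two temperatures for a made-up logit vector:

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax with temperature scaling; higher T yields a flatter distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [6.0, 2.0, 1.0]                 # hypothetical teacher logits
sharp = softmax(logits, temperature=1.0)
soft = softmax(logits, temperature=5.0)  # the value used in the training code
```

At T=1 the top class dominates almost completely; at T=5 the probability mass spreads out, exposing the teacher's relative preferences among the non-top classes, which is exactly the extra signal the student learns from.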

loss_function

  • This computes the distillation loss. The code uses nn.KLDivLoss, the Kullback-Leibler divergence loss, which measures the difference between two probability distributions.
  • reduction='batchmean' means the summed divergence over all elements is divided by the batch size, which matches the mathematical definition of KL divergence per sample.
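The 'batchmean' semantics can be illustrated in plain Python: sum the per-sample KL divergence, then divide by the batch size. The distributions below are invented for illustration:

```python
import math

def kl_div(p, q):
    """KL(P || Q) = sum_i p_i * log(p_i / q_i) for one sample."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical teacher/student probability distributions for a batch of two samples
teacher = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
student = [[0.6, 0.3, 0.1], [0.2, 0.6, 0.2]]

# reduction='batchmean': total divergence divided by the batch size
batchmean = sum(kl_div(t, s) for t, s in zip(teacher, student)) / len(teacher)
```

Note that PyTorch's nn.KLDivLoss expects the first argument in log space, which is why the training code passes log_softmax for the student and softmax for the teacher.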

student_loss_function

  • This computes the student model's classification loss against the ground-truth labels. The code uses nn.CrossEntropyLoss, the standard loss function for multi-class classification.
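What nn.CrossEntropyLoss does internally (take raw logits, apply log-softmax, then the negative log-likelihood of the true class) can also be shown in framework-free Python with made-up numbers:

```python
import math

def cross_entropy_from_logits(logits, label):
    """One-sample equivalent of nn.CrossEntropyLoss: log-softmax + NLL."""
    m = max(logits)  # stabilizing shift; does not change the result
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_prob = logits[label] - log_sum_exp  # log-softmax of the true class
    return -log_prob

logits = [2.0, 0.5, -1.0]  # hypothetical student logits for one image
ce = cross_entropy_from_logits(logits, label=0)
```

Because the loss already includes the softmax, the training code passes raw logits to it rather than probabilities.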

loss and student_loss

  • loss is the distillation loss, computed by comparing the softened student logits with the softened teacher logits.
  • student_loss is the student model's classification loss on the ground-truth labels.
  • The two losses are combined by weighted averaging into the final training loss, with the distillation loss and classification loss each weighted 0.5.

optimizer

  • This optimizes the student model's parameters. The code uses optim.Adam, an optimization algorithm with adaptive learning rates.
  • params is the list of student-model parameters that require gradients.
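Adam's adaptive update rule can be sketched for a single scalar parameter. This is a simplified illustration of what optim.Adam does per step, not the library implementation; the learning rate mirrors the lr=0.0001 used in the training code:

```python
import math

def adam_step(param, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter."""
    m = beta1 * m + (1 - beta1) * grad       # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2  # second-moment (variance) estimate
    m_hat = m / (1 - beta1 ** t)             # bias correction for step t
    v_hat = v / (1 - beta2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Minimize f(p) = p^2 from p = 1.0 for a few steps; gradient is 2p
p, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    p, m, v = adam_step(p, grad=2.0 * p, m=m, v=v, t=t)
```

The per-parameter variance estimate is what makes Adam's effective step size adaptive, which is convenient when fine-tuning a student whose layers receive gradients of very different scales.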
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from tqdm import tqdm

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5).to(device)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # the teacher only provides targets; keep it in eval mode

    # Load student model
    student_net = models.resnet18(pretrained=False)
    student_model_weight_path = "resnet18-f37072fd.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.fc = nn.Linear(student_net.fc.in_features, 5)
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()

            # No gradients through the frozen teacher
            with torch.no_grad():
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Soften the logits with the temperature
            soft_teacher = torch.nn.functional.softmax(teacher_logits / temperature, dim=1)
            soft_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=1)

            # Distillation loss, scaled by T^2 to keep gradient magnitudes comparable
            loss = loss_function(soft_student, soft_teacher) * (temperature ** 2)

            # Classification loss on the unscaled student logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine losses
            loss = 0.5 * loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

(3) Download links for the models and dataset

        Includes the resnet18 and resnet34 models, class_indices.json, the images, and related data

https://pan.baidu.com/s/1ZDCbichDcdaiAH6kxYNsIA

Extraction code: svv5

3. Training a custom model with distillation

(1) Model structure - model_10.py

import torch
from torch import nn


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # Convolutional feature extractor (five conv/pool blocks)
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 3 input channels, 32 output channels
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive average pooling
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 1 * 1, 1024),  # dimension follows from the pooling output size
            nn.ReLU(),
            nn.Linear(1024, 5)             # output layer, 5 classes
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.adaptive_pool(x)  # apply adaptive pooling
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

(2) Training the custom model - train-10.py

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    net = ConvNet()
    weights_path = "ConvNet.pth"
    assert os.path.exists(weights_path), f"File '{weights_path}' does not exist."
    # model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    state_dict = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(state_dict, strict=False)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './ConvNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

(3) Training results

        Results after 60 epochs of training; the model's val_accuracy of 0.780 was the highest it reached.

train epoch[1/30] loss:0.971: 100%|██████████| 207/207 [00:08<00:00, 24.01it/s]
valid epoch[1/30]: 100%|██████████| 23/23 [00:00<00:00, 31.44it/s]
[epoch 1] train_loss: 0.623  val_accuracy: 0.742
train epoch[2/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[2/30]: 100%|██████████| 23/23 [00:00<00:00, 33.18it/s]
[epoch 2] train_loss: 0.604  val_accuracy: 0.736
train epoch[3/30] loss:0.661: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[3/30]: 100%|██████████| 23/23 [00:00<00:00, 32.38it/s]
[epoch 3] train_loss: 0.614  val_accuracy: 0.723
train epoch[4/30] loss:0.797: 100%|██████████| 207/207 [00:07<00:00, 26.66it/s]
valid epoch[4/30]: 100%|██████████| 23/23 [00:00<00:00, 31.70it/s]
[epoch 4] train_loss: 0.619  val_accuracy: 0.725
train epoch[5/30] loss:0.809: 100%|██████████| 207/207 [00:07<00:00, 26.87it/s]
valid epoch[5/30]: 100%|██████████| 23/23 [00:00<00:00, 32.26it/s]
[epoch 5] train_loss: 0.594  val_accuracy: 0.698
train epoch[6/30] loss:0.302: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[6/30]: 100%|██████████| 23/23 [00:00<00:00, 32.49it/s]
[epoch 6] train_loss: 0.591  val_accuracy: 0.728
train epoch[7/30] loss:0.708: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[7/30]: 100%|██████████| 23/23 [00:00<00:00, 33.09it/s]
[epoch 7] train_loss: 0.589  val_accuracy: 0.720
train epoch[8/30] loss:0.709: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[8/30]: 100%|██████████| 23/23 [00:00<00:00, 32.55it/s]
[epoch 8] train_loss: 0.575  val_accuracy: 0.734
train epoch[9/30] loss:0.691: 100%|██████████| 207/207 [00:07<00:00, 26.61it/s]
valid epoch[9/30]: 100%|██████████| 23/23 [00:00<00:00, 34.43it/s]
[epoch 9] train_loss: 0.555  val_accuracy: 0.734
train epoch[10/30] loss:0.442: 100%|██████████| 207/207 [00:07<00:00, 26.81it/s]
valid epoch[10/30]: 100%|██████████| 23/23 [00:00<00:00, 32.91it/s]
[epoch 10] train_loss: 0.548  val_accuracy: 0.703
train epoch[11/30] loss:0.363: 100%|██████████| 207/207 [00:07<00:00, 26.46it/s]
valid epoch[11/30]: 100%|██████████| 23/23 [00:00<00:00, 30.53it/s]
[epoch 11] train_loss: 0.550  val_accuracy: 0.728
train epoch[12/30] loss:0.519: 100%|██████████| 207/207 [00:07<00:00, 26.19it/s]
valid epoch[12/30]: 100%|██████████| 23/23 [00:00<00:00, 33.14it/s]
[epoch 12] train_loss: 0.545  val_accuracy: 0.734
train epoch[13/30] loss:0.478: 100%|██████████| 207/207 [00:07<00:00, 26.48it/s]
valid epoch[13/30]: 100%|██████████| 23/23 [00:00<00:00, 32.75it/s]
[epoch 13] train_loss: 0.532  val_accuracy: 0.755
train epoch[14/30] loss:0.573: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[14/30]: 100%|██████████| 23/23 [00:00<00:00, 33.40it/s]
[epoch 14] train_loss: 0.542  val_accuracy: 0.747
train epoch[15/30] loss:0.595: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[15/30]: 100%|██████████| 23/23 [00:00<00:00, 34.54it/s]
[epoch 15] train_loss: 0.542  val_accuracy: 0.758
train epoch[16/30] loss:0.191: 100%|██████████| 207/207 [00:07<00:00, 26.83it/s]
valid epoch[16/30]: 100%|██████████| 23/23 [00:00<00:00, 32.04it/s]
[epoch 16] train_loss: 0.532  val_accuracy: 0.761
train epoch[17/30] loss:0.566: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[17/30]: 100%|██████████| 23/23 [00:00<00:00, 33.56it/s]
[epoch 17] train_loss: 0.523  val_accuracy: 0.739
train epoch[18/30] loss:0.509: 100%|██████████| 207/207 [00:07<00:00, 26.79it/s]
valid epoch[18/30]: 100%|██████████| 23/23 [00:00<00:00, 30.35it/s]
[epoch 18] train_loss: 0.526  val_accuracy: 0.742
train epoch[19/30] loss:0.781: 100%|██████████| 207/207 [00:07<00:00, 26.60it/s]
valid epoch[19/30]: 100%|██████████| 23/23 [00:00<00:00, 31.58it/s]
[epoch 19] train_loss: 0.506  val_accuracy: 0.764
train epoch[20/30] loss:0.336: 100%|██████████| 207/207 [00:07<00:00, 26.64it/s]
valid epoch[20/30]: 100%|██████████| 23/23 [00:00<00:00, 33.95it/s]
[epoch 20] train_loss: 0.537  val_accuracy: 0.764
train epoch[21/30] loss:0.475: 100%|██████████| 207/207 [00:07<00:00, 26.65it/s]
valid epoch[21/30]: 100%|██████████| 23/23 [00:00<00:00, 33.27it/s]
[epoch 21] train_loss: 0.511  val_accuracy: 0.764
train epoch[22/30] loss:0.513: 100%|██████████| 207/207 [00:07<00:00, 26.53it/s]
valid epoch[22/30]: 100%|██████████| 23/23 [00:00<00:00, 32.16it/s]
[epoch 22] train_loss: 0.482  val_accuracy: 0.761
train epoch[23/30] loss:0.172: 100%|██████████| 207/207 [00:07<00:00, 26.62it/s]
valid epoch[23/30]: 100%|██████████| 23/23 [00:00<00:00, 33.02it/s]
[epoch 23] train_loss: 0.501  val_accuracy: 0.761
train epoch[24/30] loss:1.127: 100%|██████████| 207/207 [00:07<00:00, 26.54it/s]
valid epoch[24/30]: 100%|██████████| 23/23 [00:00<00:00, 34.24it/s]
[epoch 24] train_loss: 0.492  val_accuracy: 0.755
train epoch[25/30] loss:0.905: 100%|██████████| 207/207 [00:07<00:00, 26.76it/s]
valid epoch[25/30]: 100%|██████████| 23/23 [00:00<00:00, 30.22it/s]
[epoch 25] train_loss: 0.492  val_accuracy: 0.758
train epoch[26/30] loss:1.044: 100%|██████████| 207/207 [00:07<00:00, 26.75it/s]
valid epoch[26/30]: 100%|██████████| 23/23 [00:00<00:00, 33.86it/s]
[epoch 26] train_loss: 0.476  val_accuracy: 0.777
train epoch[27/30] loss:0.552: 100%|██████████| 207/207 [00:07<00:00, 26.73it/s]
valid epoch[27/30]: 100%|██████████| 23/23 [00:00<00:00, 31.55it/s]
[epoch 27] train_loss: 0.465  val_accuracy: 0.745
train epoch[28/30] loss:0.387: 100%|██████████| 207/207 [00:07<00:00, 26.68it/s]
valid epoch[28/30]: 100%|██████████| 23/23 [00:00<00:00, 32.30it/s]
[epoch 28] train_loss: 0.482  val_accuracy: 0.769
train epoch[29/30] loss:0.251: 100%|██████████| 207/207 [00:07<00:00, 26.69it/s]
valid epoch[29/30]: 100%|██████████| 23/23 [00:00<00:00, 32.98it/s]
[epoch 29] train_loss: 0.466  val_accuracy: 0.777
train epoch[30/30] loss:0.368: 100%|██████████| 207/207 [00:07<00:00, 26.57it/s]
valid epoch[30/30]: 100%|██████████| 23/23 [00:00<00:00, 31.95it/s]
[epoch 30] train_loss: 0.467  val_accuracy: 0.780
Finished Training

(4) Distillation training

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import resnet34
from model_10 import ConvNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # image_path = os.path.join(data_root, "data_set", "flower_data")
    image_path = "/home/trq/data/Test5_resnet/flower_data"
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Load teacher model
    teacher_net = resnet34(num_classes=5).to(device)
    teacher_model_weight_path = "resNet34.pth"
    assert os.path.exists(teacher_model_weight_path), f"File '{teacher_model_weight_path}' does not exist."
    teacher_net.load_state_dict(torch.load(teacher_model_weight_path, map_location="cpu"), strict=False)
    teacher_net.to(device)
    teacher_net.eval()  # the teacher only provides targets; keep it in eval mode

    # Load student model
    student_net = ConvNet()
    student_model_weight_path = "ConvNet.pth"
    assert os.path.exists(student_model_weight_path), "file {} does not exist.".format(student_model_weight_path)
    student_net.load_state_dict(torch.load(student_model_weight_path, map_location="cpu"))
    student_net.to(device)

    # Distillation loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')
    student_loss_function = nn.CrossEntropyLoss()

    # Optimizer for the student model
    params = [p for p in student_net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './distilled_ConvNet.pth'
    train_steps = len(train_loader)
    temperature = 5.0  # Temperature for distillation

    for epoch in range(epochs):
        student_net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()

            # No gradients through the frozen teacher
            with torch.no_grad():
                teacher_logits = teacher_net(images.to(device))
            student_logits = student_net(images.to(device))

            # Soften the logits with the temperature
            soft_teacher = torch.nn.functional.softmax(teacher_logits / temperature, dim=1)
            soft_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=1)

            # Distillation loss, scaled by T^2 to keep gradient magnitudes comparable
            loss = loss_function(soft_student, soft_teacher) * (temperature ** 2)

            # Classification loss on the unscaled student logits
            student_loss = student_loss_function(student_logits, labels.to(device))

            # Combine losses
            loss = 0.5 * loss + 0.5 * student_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        student_net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = student_net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(student_net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

       No screenshot was taken, so try it yourself. In my test, training the custom model for 30 epochs and then continuing with 30 epochs of distillation training raised val_accuracy to 0.81.

(5) Model files

https://pan.baidu.com/s/1gVTJPvAQ3oDEZcGYoJvuLw

Extraction code: ddk5

4. Summary

        If the model structure is simple, distillation training can be used to improve its accuracy; of course, a teacher model must be trained first.

