当前位置: 首页 > news >正文

残差网络 迁移学习对食物分类案例的改进

目录

一.直接修改最后一层

1.导入模型

2.冻结模型所有参数

3.修改最后一层

4.获取需要更新的参数

5.创建模型

6.优化器和调整学习率

7.完整代码

二.增加一层

1.模型获取和冻结参数

2.定义新的网络类

3.创建模型

4.获取需要更新的参数

5.优化器和调整学习率

6.完整代码


本篇我们直接使用ImageNet竞赛第一名何凯明训练好的经典的ResNet模型来改进我们的食物分类案例

而何凯明训练的18层残差网络模型输出是用来分类1000种物体而不是我们食物分类案例的20所以我们是需要再最后的输出层进行调整,这里我们有两种方法:①直接修改最后一层网络的输出②再添加一层网络

一.直接修改最后一层

1.导入模型

torchvision库中有大量的优秀的视觉方面的已经训练好的模型,torch.nn则倾向于自己搭建网络

import torchvision.models as models

我们直接导入何凯明的resnet18模型,注意这里的模型是需要网络下载的

resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

2.冻结模型所有参数

然后我们需要冻结模型的所有参数,避免我们后续训练改变它们

require_grad属性设置为False即可

for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=False

3.修改最后一层

fc.in_features属性可以获得最后一层全连接层的输入个数

然后我们直接把这个fc层赋值成一个新的全连接层并把输出设置为20

in_feature=resnet_model.fc.in_features
resnet_model.fc=nn.Linear(in_feature,20)

4.获取需要更新的参数

遍历模型所有的参数,由于我们之前已经冻结了所有参数,所以只有我们上一步重新赋值的fc层是可以修改参数的,将fc的所有权重参数放入一个列表中,方便我们后续优化器进行反向传播优化

params_to_update=[]
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)

5.创建模型

由于我们这里是迁移学习,所以我们不用向像之前再创建类的的对象,我们已经有了模型的对象

model=resnet_model.to(device)

6.优化器和调整学习率

优化器我们要注意传入的参数是我们需要更新的最后一层的参数

调整学习率我们采用有序的等间隔调整

optimizer=torch.optim.Adam(params_to_update,lr=0.001)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)

其余代码并未做任何的修改完整代码如下:

7.完整代码

import torch
from torch import  nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import os
dire={}
def train_test_file(root,dir):f_out=open(dir+'.txt','w')path=os.path.join(root,dir)for root,directories,files in os.walk(path):if len(directories)!=0:dirs=directorieselse:now_dir=root.split('\\')for file in files:path=os.path.join(root,file)f_out.write(path+' '+str(dirs.index(now_dir[-1]))+'\n')dire[dirs.index(now_dir[-1])]=now_dir[-1]f_out.close()
root=r'.\food_dataset2'
train_dir='train'
test_dir='test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=Falsein_feature=resnet_model.fc.in_features
resnet_model.fc=nn.Linear(in_feature,20)params_to_update=[]
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)data_transforms={'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),#随机旋转transforms.CenterCrop(224),#保持与迁移的模型一致transforms.RandomHorizontalFlip(p=0.5),#随机水平旋转,概率p=0.5transforms.RandomVerticalFlip(p=0.5),#随机垂直旋转,概率p=0.5transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),#brightness(亮度)contrast(对比度)saturation(饱和度)hue(色调)transforms.RandomGrayscale(p=0.1),#p的概率转换为灰度图,但任然是三个通道,不过三个通道相同transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#标准化,设置通用的均值和标准差]),'valid':#测试集只用标准化transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}
class food_dataset(Dataset):#能通过索引的方式返回图片数据和标签结果def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs_paths=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs_paths.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs_paths)def __getitem__(self, idx):image=Image.open(self.imgs_paths[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))#label也转化为tensorreturn image,labeltrain_data=food_dataset(file_path='./train.txt',transform=data_transforms['train'])
test_data=food_dataset(file_path='./test.txt',transform=data_transforms['valid'])train_loader=DataLoader(train_data,batch_size=32,shuffle=True)
test_loader=DataLoader(test_data,batch_size=32,shuffle=True)
device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')model=resnet_model.to(device)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1
best_acc=0
def test(dataloader,model,loss_fn):global best_accmodel.eval()len_data=len(dataloader.dataset)correct,loss_sum=0,0num_batch=0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss_sum += loss_fn(pred, y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()num_batch+=1loss_avg=loss_sum/num_batchaccuracy=correct/len_dataif accuracy>best_acc:best_acc=accuracyprint(f'Accuracy:{100 * accuracy}%\nLoss Avg:{loss_avg}')return pred.argmax(1)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(params_to_update,lr=0.001)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)epochs=10
for i in range(epochs):print(f'==========第{i + 1}轮训练==============')train(train_loader, model, loss_fn, optimizer)test(test_loader, model, loss_fn)scheduler.step()print(f'第{i + 1}轮训练结束')
print('best_acc=',best_acc)

二.增加一层


我们如何在最后添加一层是通过定义一个类来实现的原理如下:

我们是通过类将迁移学习的模块和新的fc层组合在一起实现在残差网络中添加一层网络

1.模型获取和冻结参数

resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=False

2.定义新的网络类

class ResNet(nn.Module):def __init__(self):super().__init__()self.resnet_=resnet_modelself.fc=nn.Linear(1000,20)def forward(self,x):out=self.resnet_(x)out=self.fc(out)return out

3.创建模型

这里创建模型我们又用到了类的创建

model=ResNet().to(device)

4.获取需要更新的参数

params_to_update=[]
for param in model.parameters():if param.requires_grad==True:params_to_update.append(param)

5.优化器和调整学习率

这里的调整学习率我们采用自适应调整

optimizer=torch.optim.Adam(params_to_update,lr=0.001)
# scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode='min',factor=0.5,patience=2,verbose=False,threshold=0.0001,threshold_mode='rel',cooldown=0,min_lr=0,eps=1e-08)

6.完整代码

其余代码并无修改,只是train()方法做了一个返回平均损失值的改进用来作为学习率调整的根据指标,完整代码如下:

import torch
from torch import  nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.models as modelsresnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in resnet_model.parameters():#冻结模型所有参数param.requires_grad=False
class ResNet(nn.Module):def __init__(self):super().__init__()self.resnet_=resnet_modelself.fc=nn.Linear(1000,20)def forward(self,x):out=self.resnet_(x)out=self.fc(out)return outdata_transforms={'train':transforms.Compose([transforms.Resize([300,300]),transforms.RandomRotation(45),#随机旋转transforms.CenterCrop(224),#保持与迁移的模型一致transforms.RandomHorizontalFlip(p=0.5),#随机水平旋转,概率p=0.5transforms.RandomVerticalFlip(p=0.5),#随机垂直旋转,概率p=0.5transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),#brightness(亮度)contrast(对比度)saturation(饱和度)hue(色调)transforms.RandomGrayscale(p=0.1),#p的概率转换为灰度图,但任然是三个通道,不过三个通道相同transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#标准化,设置通用的均值和标准差]),'valid':#测试集只用标准化transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
}
class food_dataset(Dataset):#能通过索引的方式返回图片数据和标签结果def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs_paths=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs_paths.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs_paths)def __getitem__(self, idx):image=Image.open(self.imgs_paths[idx])if self.transform:image=self.transform(image)label=self.labels[idx]label=torch.from_numpy(np.array(label,dtype=np.int64))#label也转化为tensorreturn image,labeltrain_data=food_dataset(file_path='./train.txt',transform=data_transforms['train'])
test_data=food_dataset(file_path='./test.txt',transform=data_transforms['valid'])train_loader=DataLoader(train_data,batch_size=32,shuffle=True)
test_loader=DataLoader(test_data,batch_size=32,shuffle=True)
device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')model=ResNet().to(device)
params_to_update=[]
for param in model.parameters():if param.requires_grad==True:params_to_update.append(param)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1loss_sum=0total_batches = len(dataloader)  # 获取总批次数量for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()loss_sum+=loss_valueif batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1return loss_sum/total_batches
best_acc=0
def test(dataloader,model,loss_fn):global best_accmodel.eval()len_data=len(dataloader.dataset)correct,loss_sum=0,0num_batch=0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss_sum += loss_fn(pred, y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()num_batch+=1loss_avg=loss_sum/num_batchaccuracy=correct/len_dataif accuracy>best_acc:best_acc=accuracyprint(f'Accuracy:{100 * accuracy}%\nLoss Avg:{loss_avg}')return pred.argmax(1)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(params_to_update,lr=0.001)
# scheduler=torch.optim.lr_scheduler.StepLR(optimizer,5,0.5)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode='min',factor=0.5,patience=2,verbose=False,threshold=0.0001,threshold_mode='rel',cooldown=0,min_lr=0,eps=1e-08)epochs=10
for i in range(epochs):print(f'==========第{i + 1}轮训练==============')loss=train(train_loader, model, loss_fn, optimizer)scheduler.step(loss)test(test_loader, model, loss_fn)print(f'第{i + 1}轮训练结束')
print('best_acc=',best_acc)


文章转载自:

http://IEWgO80E.tqrxm.cn
http://Sva2tGh7.tqrxm.cn
http://fos4hMyN.tqrxm.cn
http://K1nE5NFB.tqrxm.cn
http://cYqfqfuF.tqrxm.cn
http://cP669pyR.tqrxm.cn
http://Y0xxWUko.tqrxm.cn
http://mWcNwT6x.tqrxm.cn
http://cTrdd8qI.tqrxm.cn
http://8IgaahIE.tqrxm.cn
http://PCQpSPwo.tqrxm.cn
http://GWBAE2o4.tqrxm.cn
http://wSMg0mRt.tqrxm.cn
http://vt1BjfcW.tqrxm.cn
http://kdrLTsne.tqrxm.cn
http://CKrA1seQ.tqrxm.cn
http://6xfg7YFC.tqrxm.cn
http://a8mfl3E2.tqrxm.cn
http://OlQUDYjn.tqrxm.cn
http://kUDJ9icD.tqrxm.cn
http://AcDQ83uC.tqrxm.cn
http://Du3PFNdq.tqrxm.cn
http://M7eSKImq.tqrxm.cn
http://gwcSsN2w.tqrxm.cn
http://X6tCExl4.tqrxm.cn
http://Jc4DlTv8.tqrxm.cn
http://yYnMwBEs.tqrxm.cn
http://zemzYJtE.tqrxm.cn
http://yXQEzV0H.tqrxm.cn
http://Zx7gOi2o.tqrxm.cn
http://www.dtcms.com/a/370216.html

相关文章:

  • VBA之Excel应用第四章第七节:单元格区域的整行或整列扩展
  • 【Flask】测试平台开发,数据看板开发-第二十一篇
  • [光学原理与应用-433]:晶体光学 - 晶体光学是研究光在单晶体中传播规律及其伴随现象的分支学科,聚焦于各向异性光学媒质的光学特性
  • C++面试10——构造函数、拷贝构造函数和赋值运算符
  • PID控制技术深度剖析:从基础原理到高级应用(六)
  • 登录优化(双JWT+Redis)
  • 【基础-单选】在下面哪个文件中可以设置页面的路径配置信息?
  • C++ 内存模型:用生活中的例子理解并发编程
  • 【3D图像算法技术】如何在Blender中对复杂物体进行有效减面?
  • 电脑音频录制 | 系统麦克混录 / 系统声卡直录 | 方法汇总 / 常见问题
  • 论文阅读:VGGT Visual Geometry Grounded Transformer
  • 用 PHP 玩向量数据库:一个从小说网站开始的小尝试
  • [光学原理与应用-432]:非线性光学 - 既然光也是电磁波,为什么不能直接通过电生成特定频率的光波?
  • python调用mysql
  • redis-----事务
  • 集成学习(随机森林算法、Adaboost算法)
  • 形式化方法与安全模型
  • Python两种顺序生成组合
  • 【Python自动化】 21 Pandas Excel 操作完整指南
  • Unity与硬件交互终极指南:从Arduino到自定义USB设备
  • Codeforces Round 1046 (Div. 2) vp补题
  • 【LeetCode热题100道笔记】二叉树的右视图
  • Day22_【机器学习—集成学习(1)—基本思想、分类】
  • 自动化运维,ansible综合测试练习题
  • 【面试题】领域模型持续预训练数据选取方法
  • OpenHarmony之USB Manager 架构深度解析
  • 新服务器初始化:Git全局配置与SSH密钥生成
  • 主流分布式数据库集群选型指南
  • 【Proteus仿真】定时器控制系列仿真——秒表计数/数码管显示时间
  • python advance -----object-oriented