当前位置: 首页 > news >正文

深度学习学习笔记-模型的修改和CRUD

目录

  • 1.打印模型,理解模型结构
  • 2.模型保存与加载
  • 3.模型的模块CRUD和模块的层的CRUD

1.打印模型,理解模型结构

import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.Linear(4, 3),
        )
        self.layer2 = torch.nn.Linear(3, 6)

        self.layer3 = torch.nn.Sequential(
            torch.nn.Linear(6, 7),
            torch.nn.Linear(7, 5),
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x


net = MyModel()
print(net)

在这里插入图片描述

2.模型保存与加载

本节介绍如何保存模型,如何保存模型参数

import torchvision.models as models
from torchsummary import summary
import torch


# https://pytorch.org/vision/stable/models.html
# alexnet = models.alexnet(weights=None)
# resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
# print(resnet50)


# -----------------------------------------------------------
# 保存模型 / 保存模型+参数
# -----------------------------------------------------------

# resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# 1、仅保存模型的参数
# state_dict是存储模型参数的
# torch.save(resnet50.state_dict(), 'resnet50_weight.pth')

# 2、保存模型 + 参数
# torch.save(resnet50, 'resnet50.pth')


# -----------------------------------------------------------
# 加载模型 / 加载模型+参数
# -----------------------------------------------------------

# 1、加载模型+参数
net = torch.load("resnet50.pth")
print(net)

# 2、已有模型,加载预训练参数
# resnet50 = models.resnet50(weights=None)

# resnet50.load_state_dict(torch.load('resnet50_weight.pth'))

3.模型的模块CRUD和模块的层的CRUD

本节介绍模型的层layer的CRUD

import torch.nn as nn
import torchvision.models as models


alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
print(alexnet)

# 1、----- 删除网络的最后一层 -----
# 代码解释del alexnet.classifier是直接删除AlexNet中名称为classifier的模块
# 代码解释del alexnet.classifier[6]是删除classifier模块里面的第六层,也就是括号(6)
# del alexnet.classifier
# del alexnet.classifier[6]
# print(alexnet)


# 2、----- 删除网络的最后多层 -----
# 代码解释: 列表切片
# alexnet.classifier = alexnet.classifier[:-2]
# print(alexnet)


# 3、----- 修改网络的某一层 -----
# alexnet.classifier[6] = nn.Linear(in_features=4096, out_features=1024)
# print(alexnet)


# 4、----- 网络添加层, 每次添加一层 -----
# alexnet.classifier.add_module('7', nn.ReLU(inplace=True))
# alexnet.classifier.add_module('8', nn.Linear(in_features=1024, out_features=20))
# print(alexnet)


# 4、----- 网络添加层,一次添加多层 -----
# block = nn.Sequential(nn.ReLU(inplace=True),
#                       nn.Linear(in_features=1024, out_features=20))
# 模型中添加名称为block的模块
# alexnet.add_module('block', block)
# print(alexnet)

结合代码注释和下图理解即可
在这里插入图片描述

相关文章:

  • Spring IOC之@ComponentScan
  • LAXCUS分布式操作系统是怎么实现的?
  • 【广州华锐互动】利用AR进行野外地质调查学习,培养学生实践能力
  • 【算法教程】排列与组合的实现
  • 华为OD 绘图机器(100分)【java】A卷+B卷
  • 项目经理之识别项目干系人
  • 百分点科技受邀参加“一带一路”国际合作高峰论坛
  • Android 特权应用 privapp-permissions 权限解读
  • 华为数通方向HCIP-DataCom H12-831题库(多选题:1-20)
  • Ansible 的脚本 --- playbook 剧本
  • SSD算法学习(单步多框目标检测)
  • 美格智能出席无锡智能网联汽车生态大会,共话数字座舱新势力
  • 【数据结构】模拟实现无头单向非循环链表
  • JOSEF约瑟 JHOK-ZBM1;JHOK-ZBL1多档切换式漏电(剩余)继电器 面板导轨安装
  • Flink之常用处理函数
  • 使用cxf将wsdl文件转换成java文件 webservice
  • 中文编程开发语言工具开发的实际软件案例:称重管理系统软件
  • 2023年最新版CorelDraw(cdr)软件下载安装教程
  • 【广州华锐互动】VR营销心理学情景模拟培训系统介绍
  • Mysql数据库表操作--存储
  • 申活观察|咖香涌动北外滩,带来哪些消费新想象?
  • 韩国代总统、国务总理韩德洙宣布辞职,将择期宣布参选总统
  • 向左繁华都市,向右和美乡村,嘉兴如何打造城乡融合发展样本
  • 铁路迎来节前出行高峰,今日全国铁路预计发送旅客1870万人次
  • 国泰海通合并后首份业绩报告出炉:一季度净利润增逾391%
  • 新一届中国女排亮相,奥运冠军龚翔宇担任队长