当前位置: 首页 > news >正文

PyTorch教程:如何读写张量与模型参数

本文演示了PyTorch中张量(Tensor)和模型参数的保存与加载方法,并提供完整的代码示例及输出结果,帮助读者快速掌握数据持久化的核心操作。


1. 保存和加载单个张量

通过torch.savetorch.load可以直接保存和读取张量。

import torch

# 创建并保存张量
x = torch.arange(4)
torch.save(x, 'x-file')

# 加载张量
x2 = torch.load('x-file')
print(x2)  # 输出:tensor([0, 1, 2, 3])

输出结果

tensor([0, 1, 2, 3])

2. 保存和加载张量列表

可以将多个张量存储为列表,并一次性加载。

# 创建两个张量并保存为列表
y = torch.zeros(4)
torch.save([x, y], 'x-files')

# 加载列表
x2, y2 = torch.load('x-files')
print((x2, y2))

输出结果

(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

3. 保存和加载字典

通过字典可以更灵活地管理多个张量。

# 创建字典并保存
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')

# 加载字典
mydict2 = torch.load('mydict')
print(mydict2)

输出结果

{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

4. 定义神经网络模型

以下是一个简单的全连接神经网络示例:

from torch import nn
from torch.nn import functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # 隐藏层
        self.output = nn.Linear(256, 10)   # 输出层
    
    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

# 实例化模型并进行前向传播
net = Model()
x = torch.rand(size=(2, 20))
y = net(x)
print(y)

输出结果(因随机初始化可能不同):

tensor([[-0.0711, 0.1161, -0.1113, ..., 0.0787],
        [-0.0151, 0.0275, -0.1652, ..., 0.0109]], grad_fn=<AddmmBackward0>)

5. 保存模型参数

使用state_dict保存模型参数:

torch.save(net.state_dict(), 'net.params')

6. 加载模型参数并验证

加载参数到新模型实例,并验证一致性:

# 创建新模型并加载参数
clone = Model()
clone.load_state_dict(torch.load('net.params'))
clone.eval()  # 设置为评估模式(关闭Dropout/BatchNorm等)

# 比较输出结果
Y_clone = clone(x)
print(Y_clone == y)

输出结果

tensor([[True, True, ..., True],
        [True, True, ..., True]])

总结

  1. 张量读写:直接使用torch.savetorch.load,支持列表和字典。

  2. 模型参数保存:通过state_dict保存模型状态,加载时需重新实例化模型。

  3. 验证一致性:加载参数后,输出与原模型一致表明操作成功。

通过本文的代码示例,读者可以快速掌握PyTorch中数据和模型参数的持久化方法,为模型训练和部署提供便利。

http://www.dtcms.com/a/116965.html

相关文章:

  • 【科普】 探秘图像评价指标的奇妙世界
  • 可发1区的超级创新思路(python 实现):一种轻量化的动态稀疏门控网络
  • forms实现地铁跑酷小游戏
  • Spring的简单介绍
  • C++ std::shared_mutex
  • 汽车与航空航天领域软件维护:深度剖析与未来展望
  • SSH远程工具
  • C语言传参寄存器压栈流程总结
  • 洛谷 P1330 封锁阳光大学
  • C++11:lambda表达式
  • 说一下java的探针agent的应用场景
  • 如何用开源工具,把“定制动漫面具”做成柔性制造?
  • Github最新AI工具汇总2025年4月份第2周
  • 【SPSS/EXCEl】主成分分析构建__综合评价指数
  • 遥感卫星概述#卫星工程系列
  • Linux基本操作指令5(查看IP)
  • 【合新通信】光纤延迟线(ODL)的原理
  • 一周学会Pandas2 Python数据处理与分析-NumPy算术运算和统计计算
  • qml信号与槽函数
  • 《命理学》专项探究与研习
  • css2学习总结之尚品汇静态页面
  • ragflow本地部署(WSL下Ubuntu)
  • Python Cookbook-5.7 在增加元素时保持序列的顺序
  • 人工智能通识速览(Part3. 强化学习)
  • OpenNMT 部署和集成指南
  • Dify 的介绍
  • Diffusion Policy Visuomotor Policy Learning via Action Diffusion官方项目解读(二)(4)
  • C++动态内存管理完全指南:从基础到现代最佳实践
  • Windows系统本地化部署DeepSeek+Open-WebUi
  • OpenBMC:BmcWeb 处理http请求4 处理路由对象