当前位置: 首页 > news >正文

【AI学习从零至壹】pytorch基础

pytorch基础

  • pytorch基础
    • 张量(Tensor)
    • 张量的属性
      • 张量的索引和切⽚:
      • 张量的拼接
      • 张量的算数运算
      • 单元素张量
      • In-place操作
    • 与numpy之间的转换
      • 张量到numpy数组
  • 计算图
    • 静态计算图
    • 动态计算图
    • pytorch计算图可视化

pytorch基础

PyTorch 是⼀个开源的深度学习框架,由 Facebook 的⼈⼯智能研究团队开发和维护,在学术界和⼯业界都得到了⼴泛应⽤。

张量(Tensor)

张量(Tensor)是pytorch中的基本单位,也是深度学习框架构成的重要组成。
我们可以先把张量看做是⼀个容器,⾥⾯承载了需要运算的数据。
张量可以通过多种⽅式初始化

  • 直接从数据
import torch
data = [[1,2],[3,4]]
x_data = torch.tensor(data)
x_data
  • 从 NumPy 数组
    张量可以从 NumPy 数组创建,反之亦然
a = np.array([1,2,3])
b = torch.from_numpy(a)
b
  • 从另⼀个张量:
    除⾮明确覆盖,否则新张量保留参数张量的属性(形状、数据类型)。
c = torch.tensor([1,2,3])
d = torch.ones_like(c)# # 保留of x_data的属性,但里面的值全为1
print("ones_like:",d)
e = torch.rand_like(c,dtype=float)#覆盖 x_data的数据类型,里面的值为0-1的随机值
print("rand_like",e)

使⽤随机值或常量值:shape 是张量维度的元组

import torch
shape = (2,3)#两行三列
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)
print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")
#输出
Random Tensor: 
 tensor([[0.5640, 0.9760, 0.1646],
        [0.9231, 0.4425, 0.8974]]) 
Ones Tensor: 
 tensor([[1., 1., 1.],
        [1., 1., 1.]]) 
Zeros Tensor: 
 tensor([[0., 0., 0.],
        [0., 0., 0.]])
  • 其他一些创建方法
m = torch.tensor([2,3],dtype=torch.double)
n = torch.ones(5,3,dtype=torch.double)
a = torch.rand_like(n,dtype=torch.float)
print(a.size())
# 均匀分布

print(torch.rand(5,3))
# 标准正态分布

print(torch.randn(5,3))
# 离散正态分布

print(torch.normal(mean=.0,std=1.0,size=(5,3)))
# 线性间隔向量(返回⼀个1维张量,包含在区间start和end上均匀间隔的steps个点)
torch.linspace(start=1,end=10,steps=20)

张量的属性

张量的属性描述了张量的形状、数据类型和存储它们的设备。以对象的⻆度来判断,张量可以看做是具有特征和⽅法的对象。

tensor = torch.rand(3,4)
print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")
#输出
Shape of tensor: torch.Size([3, 4])
Datatype of tensor: torch.float32
Device tensor is stored on: cpu

张量的索引和切⽚:

tensor = torch.ones(5,2)
print("tensor原数组",tensor)
print("tensor first row:",tensor[0])
print("tensor first column:",tensor[:,0])
print("tensor last column",tensor[:,-1])
#输出
tensor原数组 tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
tensor first row: tensor([1., 1.])
tensor first column: tensor([1., 1., 1., 1., 1.])
tensor last column tensor([1., 1., 1., 1., 1.])

张量的拼接

可以使⽤ torch.cat ⽤来连接指定维度的⼀系列张量。另⼀个和 torch.cat 功能类似的函数是torch.stack
在 PyTorch 中,torch.cat() 是一种用于在指定维度上连接张量的操作。它能够将多个张量沿某个轴拼接成一个新的张量。

torch.cat(tensors, dim=0)

tensors:一个包含多个待拼接张量的列表或元组。这些张量在指定的 dim 维度以外的所有维度上必须具有相同的形状。
dim:指定在哪个维度上进行拼接操作。
使用规则
1.在指定维度上,张量的形状可以不同(因为会拼接)。
2.在其他维度上,张量的形状必须相同。

tensor1 = torch.ones(2,2)
tensor2 = torch.zeros(2,2)
tensor3 =torch.cat((tensor1,tensor2),dim = 0)#在行上进行拼接
print(tensor3)
tensor4=torch.cat((tensor1,tensor2),dim = 1)#在列上进行拼接
print(tensor4)
#输出
tensor([[1., 1.],
        [1., 1.],
        [0., 0.],
        [0., 0.]])
tensor([[1., 1., 0., 0.],
        [1., 1., 0., 0.]])

张量的算数运算

tensor = torch.arange(1,10,dtype=float).reshape(3,3)
tensor
y1 = tensor @ tensor.T
y1
y2 = tensor.matmul(tensor.T)
y2
y3 = torch.rand_like(tensor)
torch.matmul(tensor,tensor.T,out=y3)
print(y3)
#输出
# tensor([[ 14.,  32.,  50.],
#         [ 32.,  77., 122.],
#         [ 50., 122., 194.]], dtype=torch.float64)
#上面是向量的内积
#下面是向量的逐一相乘
z1 = tensor * tensor
z2 = tensor.mul(tensor)

z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)
tensor = torch.arange(1,10,dtype=float).reshape(3,3)
tensor
y1 = tensor @ tensor.T
y1
y2 = tensor.matmul(tensor.T)
y2
y3 = torch.rand_like(tensor)
torch.matmul(tensor,tensor.T,out=y3)
print(y3)
#输出
#tensor([[ 14.,  32.,  50.],
#[ 32.,  77., 122.],
#[ 50., 122., 194.]], dtype=torch.float64)
#上面是向量的内积
#下面是向量的逐一相乘
z1 = tensor * tensor
z2 = tensor.mul(tensor)
z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)
#输出
tensor([[ 1.,  4.,  9.],
        [16., 25., 36.],
        [49., 64., 81.]], dtype=torch.float64)

单元素张量

如果⼀个单元素张量,例如将张量的值聚合计算,可以使⽤ item() ⽅法将其转换为 Python 数值
item理解:

  • 取出张量具体位置的元素元素值,并且返回的是该位置元素值的高精度值,保持原元素类型不变;必须指定位置,即:原张量元素为整形,则返回整形,原张量元素为浮点型则返回浮点型
agg = tensor.sum()
print(type(agg))
agg = agg.item()
print(type(agg))
#输出
<class 'torch.Tensor'>
<class 'float'>

In-place操作

把计算结果存储到当前操作数中的操作就称为就地操作。含义和pandas中inPlace参数的含义⼀样。pytorch中,这些操作是由带有下划线 _ 后缀的函数表⽰。例如:x.copy_(y) , x.t_() , 将改变 x ⾃⾝的值。

print(tensor)
tensor.add_(5)
print(tensor)
#输出
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]], dtype=torch.float64)
tensor([[ 6.,  7.,  8.],
        [ 9., 10., 11.],
        [12., 13., 14.]], dtype=torch.float64)

In-place操作虽然节省了⼀部分内存,但在计算导数时可能会出现问题,因为它会⽴即丢失历史记录。因此,不⿎励使⽤它们。

与numpy之间的转换

CPU 和 NumPy 数组上的张量共享底层内存位置,所以改变⼀个另⼀个也会变。

张量到numpy数组

t = torch.ones(5)
print(f"t: {t}")
n = t.numpy()
print(f"n: {n}")

张量值的变更也反映在关联的NumPy 数组中

t.add_(1)
print(f"t: {t}")
print(f"n: {n}")

计算图

在进⼀步学习pytorch之前,先要了解⼀个概念 —— 计算图( Computation graph)所有的深度学习框架都依赖于计算图来完成梯度下降、优化梯度值等计算。⽽计算图的创建和应⽤,通常包含如下两个部分:

  • ⽤⼾构建前向传播图
  • 框架处理后向传播(梯度更新)

模型从简单到复杂,pytorch和tensorflow都使⽤计算图来完成⼯作。
但是,这两个框架所使⽤的计算图也却有所不同:
tensorflow1.x 使⽤的是静态计算图,tensorflow2.x和pytorch使⽤的是动态计算图。

静态计算图

通常包括以下两个阶段。
阶段1:定义⼀个架构(可以使⽤⼀些基本的流控制⽅法,⽐如循环和条件指令)
阶段2:运⾏⼀组数据来训练模型,进⾏推理。
优点:允许对图进⾏强⼤的离线优化/调度,所以速度相对较快。
缺点:难以调试,对代码中处理结构化或者可变⼤⼩的数据处理⽐较复杂。

动态计算图

在执⾏正向计算时,隐式地定义图(动态构建)。
优点:灵活,侵⼊性⼩,允许动态构建和评估。
缺点:难以优化。
两种计算图⽐较起来,可以看出:动态图是对调试友好的(对程序员友好)。它允许逐⾏执⾏代码,并可以访问所有张量。这样更便于发现和找到我们计算或逻辑中的问题。

pytorch计算图可视化

import torch
from torchviz import make_dot
# 定义矩阵 A,向量 b 和常数 c
A = torch.randn(10, 10,requires_grad=True)
b = torch.randn(10,requires_grad=True)
c = torch.randn(1,requires_grad=True)
x = torch.randn(10, requires_grad=True)
# 计算 x^T * A  b * x + c
result = torch.matmul(A, x.T+ torch.matmul(b, x) + c)
# ⽣成计算图节点

dot = make_dot(result, params={'A': A, 'b': b, 'c': c, 'x': x})
# 绘制计算图

dot.render('expression', format='png', cleanup=True, view=False)

相关文章:

  • Linux安装Apache2.4.54操作步骤
  • 前端js搭建(搭建后包含cookie,弹窗,禁用f12)
  • onerror事件的理解与用法
  • 【人工智能】GPT-4 vs DeepSeek-R1:谁主导了2025年的AI技术竞争?
  • 对大模型输出的 logits 进行处理,从而控制文本的生成
  • Java---入门基础篇(下)---方法与数组
  • C++类和对象:匿名对象及连续构造拷贝编译器的优化
  • Windows下git疑难:有文件无法被跟踪
  • FPGA开发,使用Deepseek V3还是R1(1):应用场景
  • openssl下aes128算法CFB模式加解密运算实例
  • 【自学笔记】大数据基础知识点总览-持续更新
  • 机器视觉3D偏光法原理解析
  • Oracle 数据库基础入门(四):分组与联表查询的深度探索(上)
  • 8. Nginx 配合 + Keepalived 搭建高可用集群
  • DeepSeek 助力 Vue3 开发:打造丝滑的密码输入框(Password Input)
  • 模拟退火算法浅尝
  • Java 大视界 -- 基于 Java 的大数据分布式缓存一致性维护策略解析(109)
  • 阿里管理三板斧课程和管理工具包(视频精讲+工具文档).zip
  • Excel 豆知识 - XLOOKUP 为啥会出 #N/A 错误
  • git的恢复命令
  • 怎么看出是模板网站/推广自己产品的文案
  • 租号网站咋做/地推拉新app推广接单平台
  • 宜兴网站设计/百度指数里的资讯指数是什么
  • 网站怎么做用户体验/张雷明任河南省委常委
  • 如何做网站的301重定向/微信公众号营销
  • 医学网站建设方案/友链目录网