当前位置: 首页 > news >正文

PyTorch中diag_embed和transpose函数使用详解

torch.diag_embed 是 PyTorch 中用于将一个向量(或批量向量)**嵌入为对角矩阵(或批量对角矩阵)**的函数。它常用于图神经网络(GNN)或线性代数中生成对角矩阵。


函数原型

torch.diag_embed(input, offset=0, dim1=-2, dim2=-1)
参数解释:
  • input:形状为 (..., n) 的张量,表示一个或多个长度为 n 的向量;
  • offset:对角偏移量(默认是 0,即主对角线);
  • dim1, dim2:在哪两个维度上插入对角矩阵(通常保持默认即可)。

示例

示例 1:单个向量生成对角矩阵
x = torch.tensor([1, 2, 3])
out = torch.diag_embed(x)
# 输出:
# tensor([[1, 0, 0],
#         [0, 2, 0],
#         [0, 0, 3]])
示例 2:批量嵌入
x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
out = torch.diag_embed(x)
# 输出 shape: (2, 3, 3)
# 第一个矩阵是 [1,2,3] 的对角形式,第二个是 [4,5,6] 的对角形式

应用场景(

degree_signal = torch.sum(corr_graph, dim=-1)           # shape: (1, N)
D = torch.diag_embed(degree_signal)                     # shape: (1, N, N)
corr_laplacian = (D - corr_graph).squeeze(0)            # shape: (N, N)

这个操作是为了构造图拉普拉斯矩阵(Laplacian):

L = D − A L = D - A L=DA

其中:

  • A A A 是图的邻接矩阵(corr_graph);
  • D D D 是度矩阵(对角矩阵,diag_embed(degree_signal))。

在 PyTorch 中,transpose() 是用于交换张量中两个指定维度的函数,常用于调整张量维度顺序,特别是在矩阵运算或图神经网络等场景中。


函数格式:

torch.transpose(input, dim0, dim1)
# 或者张量对象方法形式:
input.transpose(dim0, dim1)

参数说明:

  • input:输入的张量(Tensor)。
  • dim0:要交换的第一个维度索引。
  • dim1:要交换的第二个维度索引。

示例 1:二维张量(矩阵)

x = torch.tensor([[1, 2], [3, 4]])  # shape: (2, 2)
print(x.shape)  # torch.Size([2, 2])y = x.transpose(0, 1)  # 转置矩阵
print(y)
# tensor([[1, 3],
#         [2, 4]])

示例 2:三维张量

x = torch.randn(2, 3, 4)  # shape: (batch=2, height=3, width=4)# 交换第1维(height)和第2维(width)
y = x.transpose(1, 2)  # shape: (2, 4, 3)
print(y.shape)

注意事项:

  • transpose()交换两个维度,如果要重新排列多个维度,请使用 permute()
  • transpose() 返回的是一个视图(view),不复制数据。

.T 的区别:

  • tensor.T 只适用于 二维张量,是 transpose(0, 1) 的简写。
  • 多维张量请使用 transpose(dim0, dim1)permute()

示例:配合 .permute()

x = torch.randn(2, 3, 4)
# 等价于 transpose(1, 2)
x.transpose(1, 2) == x.permute(0, 2, 1)  # True

相关文章:

  • 工商业预付费系统组成架构及系统特点介绍
  • 01-jenkins学习之旅-window-下载-安装-安装后设置向导
  • Spring IoC 和 AOP -- 核心原理与高频面试题解析
  • 设计双向链表--LeetCode
  • MinerU教程第二弹丨MinerU 本地部署保姆级“喂饭”教程
  • BGE-M3 文本情感分类实战:预训练模型微调,导出ONNX并测试
  • OpenCv高阶(十七)——dlib库安装、dlib人脸检测
  • Jeecg漏洞总结及tscan poc分享
  • Mujoco 学习系列(四)官方模型仓库 mujoco_menagerie
  • LangChain文档加载器实战:构建高效RAG数据流水线
  • 第八天的尝试
  • js中encodeURIComponent函数使用场景
  • 3.9/Q1,GBD数据库最新文章解读
  • FinalShell 密码在线解析方法(含完整源码与运行平台)
  • SQLServer与MySQL数据迁移案例解析
  • mysql日志文件binlog分析记录
  • 软考 系统架构设计师系列知识点之杂项集萃(69)
  • [Usaco2007 Dec]队列变换 题解
  • Python之web错误处理与异常捕获
  • LeRobot的机器人控制系统(下)
  • 高校两学一做专题网站/网络推广员要怎么做
  • wordpress 占用资源/郑州seo优化哪家好
  • 业绩显示屏 东莞网站建设技术支持/有链接的网站
  • 温州网站建设方案报价/百度竞价推广点击软件奔奔
  • 如何自己做时时彩网站/最新域名ip地址
  • 中级经济师考试报名/优化技术基础