当前位置: 首页 > news >正文

pytorch小记(十):pytorch中torch.tril 和 torch.triu 详解

pytorch小记(十):pytorch中torch.tril 和 torch.triu 详解

  • PyTorch `torch.tril` 和 `torch.triu` 详解
    • 1. `torch.tril`(计算下三角矩阵)
      • 📌 作用
      • 🔍 语法
      • 🔹 参数
      • 📌 示例
      • 🔍 `diagonal` 参数
      • 🔍 `torch.tril` 的应用
    • 2. `torch.triu`(计算上三角矩阵)
      • 📌 作用
      • 🔍 语法
      • 🔹 参数
      • 📌 示例
      • 🔍 `diagonal` 参数
    • 3. `torch.tril` vs `torch.triu` 对比
    • 总结


PyTorch torch.triltorch.triu 详解

在数值计算和深度学习中,下三角矩阵(Lower Triangular Matrix)上三角矩阵(Upper Triangular Matrix) 是非常常见的矩阵操作。PyTorch 提供了 torch.tril()torch.triu() 这两个函数,分别用于计算下三角矩阵和上三角矩阵。


1. torch.tril(计算下三角矩阵)

📌 作用

torch.tril 返回输入张量的 下三角部分,即:

  • 保留 主对角线及其以下的元素
  • 主对角线以上的元素全部变为 0

🔍 语法

torch.tril(input, diagonal=0)

🔹 参数

参数说明
input输入张量
diagonal控制对角线位置(默认 0
diagonal=0保留主对角线 及其以下的元素
diagonal>0向上偏移,保留主对角线以上 diagonal
diagonal<0向下偏移,移除 -diagonal 行的主对角线元素

📌 示例

import torch

# 创建一个 4×4 的矩阵
A = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
])

print("原始矩阵 A:")
print(A)

# 计算 A 的下三角矩阵
L = torch.tril(A)
print("\nA 的下三角矩阵(diagonal=0):")
print(L)

输出:

原始矩阵 A:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

A 的下三角矩阵(diagonal=0):
tensor([[ 1,  0,  0,  0],
        [ 5,  6,  0,  0],
        [ 9, 10, 11,  0],
        [13, 14, 15, 16]])

💡 说明:主对角线上的元素保留,其上的元素变为 0


🔍 diagonal 参数

print(torch.tril(A, diagonal=1))  # 保留主对角线以上 1 行
print(torch.tril(A, diagonal=-1)) # 移除主对角线

输出:

A 的下三角矩阵(diagonal=1):
tensor([[ 1,  2,  0,  0],
        [ 5,  6,  7,  0],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

A 的下三角矩阵(diagonal=-1):
tensor([[ 0,  0,  0,  0],
        [ 5,  0,  0,  0],
        [ 9, 10,  0,  0],
        [13, 14, 15,  0]])

🔺 diagonal=1向上偏移,保留 1 行主对角线以上的元素。
🔻 diagonal=-1向下偏移,移除主对角线。


🔍 torch.tril 的应用

📌 用于 Masking(掩码)

seq_length = 5
mask = torch.tril(torch.ones(seq_length, seq_length))  # 创建一个下三角 Mask
print(mask)

输出:

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

💡 Transformer 中,这种 Mask 用于防止模型在训练时提前看到未来的信息。


2. torch.triu(计算上三角矩阵)

📌 作用

torch.triu 返回输入张量的 上三角部分,即:

  • 保留 主对角线及其以上的元素
  • 主对角线以下的元素全部变为 0

🔍 语法

torch.triu(input, diagonal=0)

🔹 参数

参数说明
input输入张量
diagonal=0保留主对角线及其以上的元素
diagonal>0移除 diagonal 行的主对角线元素
diagonal<0保留主对角线以下 -diagonal

📌 示例

U = torch.triu(A)
print("A 的上三角矩阵(diagonal=0):")
print(U)

输出:

A 的上三角矩阵(diagonal=0):
tensor([[ 1,  2,  3,  4],
        [ 0,  6,  7,  8],
        [ 0,  0, 11, 12],
        [ 0,  0,  0, 16]])

💡 说明:主对角线上的元素及其上的元素保留,下面的元素变为 0


🔍 diagonal 参数

print(torch.triu(A, diagonal=1))  # 移除主对角线元素
print(torch.triu(A, diagonal=-1)) # 保留主对角线以下 1 行

输出:

A 的上三角矩阵(diagonal=1):
tensor([[ 0,  2,  3,  4],
        [ 0,  0,  7,  8],
        [ 0,  0,  0, 12],
        [ 0,  0,  0,  0]])

A 的上三角矩阵(diagonal=-1):
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 0, 10, 11, 12],
        [ 0,  0, 15, 16]])

🔺 diagonal=1:移除主对角线的元素,仅保留主对角线以上的元素。
🔻 diagonal=-1:允许保留主对角线以下 1 行的元素。


3. torch.tril vs torch.triu 对比

作用torch.tril(A)torch.triu(A)
计算结果取下三角部分取上三角部分
对角线控制diagonal=0 保留主对角线diagonal=0 保留主对角线
diagonal>0保留主对角线以上元素移除主对角线部分元素
diagonal<0移除主对角线部分元素保留主对角线以下部分

总结

  • torch.tril()下三角矩阵,可以用于 Cholesky 分解Transformer Masking
  • torch.triu()上三角矩阵,常用于 线性代数计算矩阵变换

🚀 你可以根据不同的需求选择合适的函数,在 PyTorch 中高效处理矩阵运算!

相关文章:

  • 一场由 ES 分片 routing 引发的问题
  • 【含文档+PPT+源码】基于小程序的智能停车管理系统设计与开发
  • 【数据分享】1999—2023年地级市固定资产投资和对外经济贸易数据(Shp/Excel格式)
  • 咖啡点单小程序毕业设计(JAVA+SpringBoot+微信小程序+完整源码+论文)
  • 卷积神经网络(CNN)与反向传播
  • 威联通 NAS 的 Docker 镜像与安装 logseq
  • 案例驱动的 IT 团队管理:创新与突破之路:第三章 项目攻坚:从流程优化到敏捷破局-3.2.3技术债务的可视化管理方案
  • 永磁同步电机模型第二篇之两相电机实时模型
  • 使用 ESP8266 和 Android 应用程序实现基于 IOT 的语音控制家庭自动化
  • Apache DolphinScheduler:一个可视化大数据工作流调度平台
  • VSTO(C#)Excel开发13:实现定时器
  • Search after解决ES深度分页问题
  • Modbus通信协议基础知识总结
  • 003-掌控命令行-CLI11-C++开源库108杰
  • 音频大语言模型可作为描述性语音质量评价器
  • java学习笔记4
  • Java动态代理模式深度解析
  • Git 分支删除操作指南(含本地与远程)
  • 如何将MediaPipe编译成Android中Chaquopy插件可用的 .whl 文件
  • 鸿蒙NEXT开发问题大全(不断更新中.....)
  • 吴清:基金业绩差的必须少收管理费,督促基金公司从“重规模”向“重回报”转变
  • 李云泽:支持设立新的金融资产投资公司,今天即将批复一家
  • 李云泽:对受关税影响较大、经营暂时困难的市场主体,一企一策提供精准服务
  • 世界哮喘日|专家:哮喘无法根治,“临床治愈”已成治疗新目标
  • 十大券商看后市|A股风险偏好有回升空间,把握做多窗口
  • 贵州省委省政府迅速组织开展黔西市游船倾覆事故救援工作