当前位置: 首页 > news >正文

PyTorch_张量拼接

张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如:残差网络,注意力机制中都使用张量拼接。


torch.cat 函数的使用

可以将两个张量根据指定的维度拼接起来。

import torch 
import numpy as np def test01():data1 = torch.randint(0, 10, [3, 4, 5])data2 = torch.randint(0, 10, [3, 4, 5])print(data1.shape)print(data2.shape)# dim 对应的值可以是负数,可以通过list来思考# 按照第 0 维度进行拼接new_data = torch.cat([data1, data2], dim = 0)  # 是列表print(new_data.shape)# 按照第 1 维度进行拼接new_data = torch.cat([data1, data2], dim = 1)print(new_data.shape)# 按照第 2 维度进行拼接new_data = torch.cat([data1, data2], dim = 2)print(new_data.shape)if __name__ == "__main__":test01() 

torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来,或者组合成新的元素。叠加的意思:当两个元素叠在一起,我们就将这两个元素当作一个元素。

import torch 
import numpy as np def test01():data1 = torch.randint(0, 10, [2, 3])data2 = torch.randint(0, 10, [2, 3])print(data1)print(data2)# 将两个张量 stack 叠加起来,像 cat 一样指定维度# 1. 按照第0维度进行叠加new_data = torch.stack([data1, data2], dim=0)print(new_data.shape)# 2. 按照第1维度进行叠加new_data = torch.stack([data1, data2], dim=1)print(new_data)# 3. 按照第2维度进行叠加new_data = torch.stack([data1, data2], dim=2)print(new_data)if __name__ == "__main__":test01() 

相关文章:

  • 多语言笔记系列:Polyglot Notebooks 多种使用方式
  • 升级 CUDA Toolkit 12.9 与 cuDNN 9.9.0 后验证指南:功能与虚拟环境检测
  • 基于大模型的隐睾(睾丸可触及)预测及临床干预策略研究报告
  • 机器学习+多目标优化的算法如何设计?
  • Fortran语言,do-end do循环,相互包含测试,自动性能优化
  • 《新手学看盘》速读笔记
  • 【浅尝Java】变量与数据类型(含隐式类型转换、强制类型转换、整型与字符串互相转换等)
  • 百度系列产品学习
  • Linux环境下的进程创建、退出和进程等待
  • C++专业面试题
  • comfyui错误记录:Text_Translation :No module named ‘translators‘
  • Linux文件权限管理:chmod修改权限 与 chown修改所有者
  • LeetCode 热题 100 48. 旋转图像
  • shell编程补充内容(Linux课程实验3)
  • 胶合目录解释
  • 如何提升个人情商?
  • TF-IDF算法详解
  • 【Godot】使用 Shader 实现可配置圆角效果
  • 缓存与数据库的高效读写流程解析
  • C++动态内存分配:从基础到最佳实践
  • 民族音乐还能这样玩!这场音乐会由AI作曲
  • 高速变道致连环车祸,白车“骑”隔离栏压住另一车,交警回应
  • 马上评|扩大高速免费救援范围,打消出行后顾之忧
  • 看见“看得见的手”,看见住房与土地——读《央地之间》
  • 出行注意防晒补水,上海五一假期以多云天气为主最高33℃
  • 跟着京剧电影游运河,京杭大运河沿线六城举行京剧电影展映