当前位置: 首页 > news >正文

einops测试

文章目录

  • 1. einops
  • 2. code
  • 3. pytorch

1. einops

einops 主要是通过爱因斯坦标记法来处理张量矩阵的库,让矩阵处理上非常简单。

  • conda :
conda install conda-forge::einops
  • python:

2. code

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    x = torch.arange(96).reshape((2, 3, 4, 4)).to(torch.float32)
    print(f"x.shape={x.shape}")
    print(f"x=\n{x}")

    # 1. 转置
    x_torch_trans = x.transpose(1, 2)
    x_einops_trans = rearrange(x, 'b i w h -> b w i h')
    x_check_trans = torch.allclose(x_torch_trans, x_einops_trans)
    print(f"x_torch_trans is {x_check_trans} same with x_einops_trans")

    # 2. 变形
    x_torch_reshape = x.reshape(6, 4, 4)
    x_einops_reshape = rearrange(x, 'b i w h -> (b i) w h')
    x_check_reshape = torch.allclose(x_torch_reshape, x_einops_reshape)
    print(f"x_einops_reshape is {x_check_reshape} same with x_check_reshape")

    # 3. image2patch
    image2patch = rearrange(x, 'b i (h1 p1) (w1 p2) -> b i (h1 w1) p1 p2', p1=2, p2=2)
    print(f"image2patch.shape={image2patch.shape}")
    print(f"image2patch=\n{image2patch}")
    image2patch2 = rearrange(image2patch, 'b i j h w -> b (i j) h w')
    print(f"image2patch2.shape={image2patch2.shape}")
    print(f"image2patch2=\n{image2patch2}")
    y = torch.arange(24).reshape((2, 3, 4)).to(torch.float32)
    y_einops_mean = reduce(y, 'b h w -> b h', 'mean')
    print(f"y=\n{y}")
    print(f"y_einops_mean=\n{y_einops_mean}")
    y_tensor = torch.arange(24).reshape(2, 2, 2, 3)
    y_list = [y_tensor, y_tensor, y_tensor]
    y_output = rearrange(y_list, 'n b i h w -> n b i h w')
    print(f"y_tensor=\n{y_tensor}")
    print(f"y_output=\n{y_output}")
    z_tensor = torch.arange(12).reshape(2, 2, 3).to(torch.float32)
    z_tensor_1 = rearrange(z_tensor, 'b h w -> b h w 1')
    print(f"z_tensor=\n{z_tensor}")
    print(f"z_tensor_1=\n{z_tensor_1}")
    z_tensor_2 = repeat(z_tensor_1, 'b h w 1 -> b h w 2')
    print(f"z_tensor_2=\n{z_tensor_2}")
    z_tensor_repeat = repeat(z_tensor, 'b h w -> b (2 h) (2 w)')
    print(f"z_tensor_repeat=\n{z_tensor_repeat}")
  • python:
x.shape=torch.Size([2, 3, 4, 4])
x=
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]],

         [[16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [24., 25., 26., 27.],
          [28., 29., 30., 31.]],

         [[32., 33., 34., 35.],
          [36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]],


        [[[48., 49., 50., 51.],
          [52., 53., 54., 55.],
          [56., 57., 58., 59.],
          [60., 61., 62., 63.]],

         [[64., 65., 66., 67.],
          [68., 69., 70., 71.],
          [72., 73., 74., 75.],
          [76., 77., 78., 79.]],

         [[80., 81., 82., 83.],
          [84., 85., 86., 87.],
          [88., 89., 90., 91.],
          [92., 93., 94., 95.]]]])
x_torch_trans is True same with x_einops_trans
x_einops_reshape is True same with x_check_reshape
image2patch.shape=torch.Size([2, 3, 4, 2, 2])
image2patch=
tensor([[[[[ 0.,  1.],
           [ 4.,  5.]],

          [[ 2.,  3.],
           [ 6.,  7.]],

          [[ 8.,  9.],
           [12., 13.]],

          [[10., 11.],
           [14., 15.]]],


         [[[16., 17.],
           [20., 21.]],

          [[18., 19.],
           [22., 23.]],

          [[24., 25.],
           [28., 29.]],

          [[26., 27.],
           [30., 31.]]],


         [[[32., 33.],
           [36., 37.]],

          [[34., 35.],
           [38., 39.]],

          [[40., 41.],
           [44., 45.]],

          [[42., 43.],
           [46., 47.]]]],



        [[[[48., 49.],
           [52., 53.]],

          [[50., 51.],
           [54., 55.]],

          [[56., 57.],
           [60., 61.]],

          [[58., 59.],
           [62., 63.]]],


         [[[64., 65.],
           [68., 69.]],

          [[66., 67.],
           [70., 71.]],

          [[72., 73.],
           [76., 77.]],

          [[74., 75.],
           [78., 79.]]],


         [[[80., 81.],
           [84., 85.]],

          [[82., 83.],
           [86., 87.]],

          [[88., 89.],
           [92., 93.]],

          [[90., 91.],
           [94., 95.]]]]])
image2patch2.shape=torch.Size([2, 12, 2, 2])
image2patch2=
tensor([[[[ 0.,  1.],
          [ 4.,  5.]],

         [[ 2.,  3.],
          [ 6.,  7.]],

         [[ 8.,  9.],
          [12., 13.]],

         [[10., 11.],
          [14., 15.]],

         [[16., 17.],
          [20., 21.]],

         [[18., 19.],
          [22., 23.]],

         [[24., 25.],
          [28., 29.]],

         [[26., 27.],
          [30., 31.]],

         [[32., 33.],
          [36., 37.]],

         [[34., 35.],
          [38., 39.]],

         [[40., 41.],
          [44., 45.]],

         [[42., 43.],
          [46., 47.]]],


        [[[48., 49.],
          [52., 53.]],

         [[50., 51.],
          [54., 55.]],

         [[56., 57.],
          [60., 61.]],

         [[58., 59.],
          [62., 63.]],

         [[64., 65.],
          [68., 69.]],

         [[66., 67.],
          [70., 71.]],

         [[72., 73.],
          [76., 77.]],

         [[74., 75.],
          [78., 79.]],

         [[80., 81.],
          [84., 85.]],

         [[82., 83.],
          [86., 87.]],

         [[88., 89.],
          [92., 93.]],

         [[90., 91.],
          [94., 95.]]]])
y=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
y_einops_mean=
tensor([[ 1.500,  5.500,  9.500],
        [13.500, 17.500, 21.500]])
y_tensor=
tensor([[[[ 0,  1,  2],
          [ 3,  4,  5]],

         [[ 6,  7,  8],
          [ 9, 10, 11]]],


        [[[12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23]]]])
y_output=
tensor([[[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]],



        [[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]],



        [[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]]])
z_tensor=
tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
z_tensor_1=
tensor([[[[ 0.],
          [ 1.],
          [ 2.]],

         [[ 3.],
          [ 4.],
          [ 5.]]],


        [[[ 6.],
          [ 7.],
          [ 8.]],

         [[ 9.],
          [10.],
          [11.]]]])
z_tensor_2=
tensor([[[[ 0.,  0.],
          [ 1.,  1.],
          [ 2.,  2.]],

         [[ 3.,  3.],
          [ 4.,  4.],
          [ 5.,  5.]]],


        [[[ 6.,  6.],
          [ 7.,  7.],
          [ 8.,  8.]],

         [[ 9.,  9.],
          [10., 10.],
          [11., 11.]]]])
z_tensor_repeat=
tensor([[[ 0.,  1.,  2.,  0.,  1.,  2.],
         [ 3.,  4.,  5.,  3.,  4.,  5.],
         [ 0.,  1.,  2.,  0.,  1.,  2.],
         [ 3.,  4.,  5.,  3.,  4.,  5.]],

        [[ 6.,  7.,  8.,  6.,  7.,  8.],
         [ 9., 10., 11.,  9., 10., 11.],
         [ 6.,  7.,  8.,  6.,  7.,  8.],
         [ 9., 10., 11.,  9., 10., 11.]]])

3. pytorch

在这里插入图片描述

相关文章:

  • C#导出dataGridView数据
  • 【Node.js】express框架
  • 【论文带读(1)】《End-to-End Object Detection with Transformers》论文超详细带读 + 翻译
  • 人工智能(AI)的不同维度分类
  • 【知识】Nginx反向代理路径到指定端口,很全面
  • 3D模型在线转换工具:轻松实现3DM转OBJ
  • 深度学习的集装箱箱号OCR识别技术,识别率99.9%
  • mysql之B+ 树索引 (InnoDB 存储引擎)机制
  • Eclipse2024中文汉化教程(图文版)
  • Kafka客户端连接服务端异常 Can‘t resolve address: VM-12-16-centos:9092
  • 深入理解设计模式之外观模式
  • 【Java】Java 常用核心类篇 —— 时间-日期API(上)
  • 个人环境配置--安装记录
  • 怎麼利用靜態ISP住宅代理在指紋流覽器中管理社媒帳號?
  • uniapp微信小程序PC端选择文件(无法使用wx.chooseMessageFile问题)
  • Linux 常用命令最全总结大全【推荐收藏】
  • 安当ASP:中小企业低成本Radius认证服务器解决方案
  • C++核心编程之引用
  • python装饰器的详解使用
  • 深入理解 Java 接口的回调机制 【学术会议-2025年人工智能与计算智能(AICI 2025)】
  • 万玲、胡春平调任江西省鹰潭市副市长
  • 5月12日至13日北京禁飞“低慢小”航空器
  • 光大华夏:近代中国私立大学遥不可及的梦想
  • 城管给商户培训英语、政银企合作纾困,上海街镇这样优化营商环境
  • 公积金利率降至历史最低!多项房地产利好政策落地,购房者置业成本又降了
  • 潘功胜发布会答问五大要点:除了降准降息,这些政策“含金量”也很高