当前位置: 首页 > news >正文

kron积计算mask类别矩阵

文章目录

  • 1. 生成类别矩阵如下
  • 2. pytorch 代码
  • 3. 循环移动矩阵

1. 生成类别矩阵如下

在这里插入图片描述

2. pytorch 代码

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    a_matrix = torch.arange(4).reshape(2, 2) + 1
    b_matrix = torch.ones((2, 2))
    print(f"a_matrix=\n{a_matrix}")
    print(f"b_matrix=\n{b_matrix}")
    c_matrix = torch.kron(input=a_matrix, other=b_matrix)
    print(f"c_matrix=\n{c_matrix}")
    d_matrix = torch.arange(9).reshape(3, 3) + 1
    e_matrix = torch.ones((2, 2))
    f_matrix = torch.kron(input=d_matrix, other=e_matrix)
    print(f"d_matrix=\n{d_matrix}")
    print(f"e_matrix=\n{e_matrix}")
    print(f"f_matrix=\n{f_matrix}")
    g_matrix = f_matrix[1:-1, 1:-1]
    print(f"g_matrix=\n{g_matrix}")
  • 结果:
a_matrix=
tensor([[1, 2],
        [3, 4]])
b_matrix=
tensor([[1., 1.],
        [1., 1.]])
c_matrix=
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
d_matrix=
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
e_matrix=
tensor([[1., 1.],
        [1., 1.]])
f_matrix=
tensor([[1., 1., 2., 2., 3., 3.],
        [1., 1., 2., 2., 3., 3.],
        [4., 4., 5., 5., 6., 6.],
        [4., 4., 5., 5., 6., 6.],
        [7., 7., 8., 8., 9., 9.],
        [7., 7., 8., 8., 9., 9.]])
g_matrix=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])

3. 循环移动矩阵

  • excel 表示
    在这里插入图片描述
  • pytorch 源码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.set_printoptions(precision=3, sci_mode=False)


class WindowMatrix(object):
    def __init__(self, num_patch=4, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        self._result = torch.kron(input=a_matrix, other=b_matrix)
        return self._result


class ShiftedWindowMatrix(object):
    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        self._result = my_result[1:-1, 1:-1]
        return self._result


class RollShiftedWindowMatrix(object):
    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        my_result = my_result[1:-1, 1:-1]
        roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
        self._result = roll_result
        return self._result


class BackRollShiftedWindowMatrix(object):
    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        my_result = my_result[1:-1, 1:-1]
        roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
        print(f"roll_result=\n{roll_result}")
        roll_result = torch.roll(input=roll_result, shifts=(1, 1), dims=(-1, -2))
        self._result = roll_result
        return self._result


if __name__ == "__main__":
    run_code = 0
    my_window_matrix = WindowMatrix()
    my_window_matrix_result = my_window_matrix.result
    print(f"my_window_matrix_result=\n{my_window_matrix_result}")
    shifted_window_matrix = ShiftedWindowMatrix()
    shifed_window_matrix_result = shifted_window_matrix.result
    print(f"shifed_window_matrix_result=\n{shifed_window_matrix_result}")
    roll_shifted_window_matrix = RollShiftedWindowMatrix()
    roll_shifed_window_matrix_result = roll_shifted_window_matrix.result
    print(f"roll_shifed_window_matrix_result=\n{roll_shifed_window_matrix_result}")
    Back_roll_shifted_window_matrix = BackRollShiftedWindowMatrix()
    back_roll_shifed_window_matrix_result = Back_roll_shifted_window_matrix.result
    print(f"back_roll_shifed_window_matrix_result=\n{back_roll_shifed_window_matrix_result}")
  • 结果:
my_window_matrix_result=
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])
roll_shifed_window_matrix_result=
tensor([[5., 5., 6., 4.],
        [5., 5., 6., 4.],
        [8., 8., 9., 7.],
        [2., 2., 3., 1.]])
roll_result=
tensor([[5., 5., 6., 4.],
        [5., 5., 6., 4.],
        [8., 8., 9., 7.],
        [2., 2., 3., 1.]])
back_roll_shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])

相关文章:

  • python连点器
  • 【STM32】舵机SG90
  • 部署 DeepSeek R1各个版本所需硬件配置清单
  • 网络分析工具—WireShark的安装及使用
  • 【自然语言处理】TextRank 算法提取关键词、短语、句(Python源码实现)
  • 【学习笔记】for、forEach会不会被await阻塞
  • 【2024~2025年备受关注的AI大模型】
  • 杂记:STM32 调试信息打印实现方式
  • 关于 IoT DC3 中驱动(Driver)的理解
  • SolidWorks C# How
  • go语言获取机器的进程和进程运行参数 获取当前进程的jmx端口 go调用/jstat获取当前Java进程gc情况
  • 【前端】几种常见的跨域解决方案代理的概念
  • SQLMesh系列教程-2:SQLMesh入门项目实战(上篇)
  • SQL布尔盲注、时间盲注
  • [SQL Server]从数据类型 varchar 转换为 numeric 时出错
  • 排序--四种算法
  • STM32、GD32驱动TM1640原理图、源码分享
  • HCIA项目实践--RIP相关原理知识面试问题总结回答
  • 服务器,交换机和路由器的一些笔记
  • 机器学习(李宏毅)——self-Attention
  • 图集︱“中国排面”威武亮相
  • 黄玮接替周继红出任国家体育总局游泳运动管理中心主任
  • 湖南省邵阳市副市长仇珂静主动向组织交代问题,接受审查调查
  • “用鲜血和生命凝结的深厚情谊”——习近平主席署名文章中的中俄友好故事
  • 昆廷·斯金纳:作为“独立自主”的自由
  • 又一日军“慰安妇”制度受害者去世,大陆登记在册幸存者仅剩7人