当前位置: 首页 > news >正文

kron积计算mask类别矩阵

文章目录

  • 1. 生成类别矩阵如下
  • 2. pytorch 代码
  • 3. 循环移动矩阵

1. 生成类别矩阵如下

在这里插入图片描述

2. pytorch 代码

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    a_matrix = torch.arange(4).reshape(2, 2) + 1
    b_matrix = torch.ones((2, 2))
    print(f"a_matrix=\n{a_matrix}")
    print(f"b_matrix=\n{b_matrix}")
    c_matrix = torch.kron(input=a_matrix, other=b_matrix)
    print(f"c_matrix=\n{c_matrix}")
    d_matrix = torch.arange(9).reshape(3, 3) + 1
    e_matrix = torch.ones((2, 2))
    f_matrix = torch.kron(input=d_matrix, other=e_matrix)
    print(f"d_matrix=\n{d_matrix}")
    print(f"e_matrix=\n{e_matrix}")
    print(f"f_matrix=\n{f_matrix}")
    g_matrix = f_matrix[1:-1, 1:-1]
    print(f"g_matrix=\n{g_matrix}")
  • 结果:
a_matrix=
tensor([[1, 2],
        [3, 4]])
b_matrix=
tensor([[1., 1.],
        [1., 1.]])
c_matrix=
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
d_matrix=
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
e_matrix=
tensor([[1., 1.],
        [1., 1.]])
f_matrix=
tensor([[1., 1., 2., 2., 3., 3.],
        [1., 1., 2., 2., 3., 3.],
        [4., 4., 5., 5., 6., 6.],
        [4., 4., 5., 5., 6., 6.],
        [7., 7., 8., 8., 9., 9.],
        [7., 7., 8., 8., 9., 9.]])
g_matrix=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])

3. 循环移动矩阵

  • excel 表示
    在这里插入图片描述
  • pytorch 源码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

torch.set_printoptions(precision=3, sci_mode=False)


class WindowMatrix(object):
    def __init__(self, num_patch=4, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        self._result = torch.kron(input=a_matrix, other=b_matrix)
        return self._result


class ShiftedWindowMatrix(object):
    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        self._result = my_result[1:-1, 1:-1]
        return self._result


class RollShiftedWindowMatrix(object):
    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        my_result = my_result[1:-1, 1:-1]
        roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
        self._result = roll_result
        return self._result


class BackRollShiftedWindowMatrix(object):
    def __init__(self, num_patch=9, size=2):
        self.num_patch = num_patch
        self.size = size
        self.width = self.num_patch
        self.height = self.size * self.size
        self._result = torch.zeros((self.width, self.height))

    @property
    def result(self):
        a_size = int(math.sqrt(self.num_patch))
        a_matrix = torch.arange(self.num_patch).reshape(a_size, a_size) + 1
        b_matrix = torch.ones(self.size, self.size)
        my_result = torch.kron(input=a_matrix, other=b_matrix)
        my_result = my_result[1:-1, 1:-1]
        roll_result = torch.roll(input=my_result, shifts=(-1, -1), dims=(-1, -2))
        print(f"roll_result=\n{roll_result}")
        roll_result = torch.roll(input=roll_result, shifts=(1, 1), dims=(-1, -2))
        self._result = roll_result
        return self._result


if __name__ == "__main__":
    run_code = 0
    my_window_matrix = WindowMatrix()
    my_window_matrix_result = my_window_matrix.result
    print(f"my_window_matrix_result=\n{my_window_matrix_result}")
    shifted_window_matrix = ShiftedWindowMatrix()
    shifed_window_matrix_result = shifted_window_matrix.result
    print(f"shifed_window_matrix_result=\n{shifed_window_matrix_result}")
    roll_shifted_window_matrix = RollShiftedWindowMatrix()
    roll_shifed_window_matrix_result = roll_shifted_window_matrix.result
    print(f"roll_shifed_window_matrix_result=\n{roll_shifed_window_matrix_result}")
    Back_roll_shifted_window_matrix = BackRollShiftedWindowMatrix()
    back_roll_shifed_window_matrix_result = Back_roll_shifted_window_matrix.result
    print(f"back_roll_shifed_window_matrix_result=\n{back_roll_shifed_window_matrix_result}")
  • 结果:
my_window_matrix_result=
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])
roll_shifed_window_matrix_result=
tensor([[5., 5., 6., 4.],
        [5., 5., 6., 4.],
        [8., 8., 9., 7.],
        [2., 2., 3., 1.]])
roll_result=
tensor([[5., 5., 6., 4.],
        [5., 5., 6., 4.],
        [8., 8., 9., 7.],
        [2., 2., 3., 1.]])
back_roll_shifed_window_matrix_result=
tensor([[1., 2., 2., 3.],
        [4., 5., 5., 6.],
        [4., 5., 5., 6.],
        [7., 8., 8., 9.]])
http://www.dtcms.com/a/14204.html

相关文章:

  • python连点器
  • 【STM32】舵机SG90
  • 部署 DeepSeek R1各个版本所需硬件配置清单
  • 网络分析工具—WireShark的安装及使用
  • 【自然语言处理】TextRank 算法提取关键词、短语、句(Python源码实现)
  • 【学习笔记】for、forEach会不会被await阻塞
  • 【2024~2025年备受关注的AI大模型】
  • 杂记:STM32 调试信息打印实现方式
  • 关于 IoT DC3 中驱动(Driver)的理解
  • SolidWorks C# How
  • go语言获取机器的进程和进程运行参数 获取当前进程的jmx端口 go调用/jstat获取当前Java进程gc情况
  • 【前端】几种常见的跨域解决方案代理的概念
  • SQLMesh系列教程-2:SQLMesh入门项目实战(上篇)
  • SQL布尔盲注、时间盲注
  • [SQL Server]从数据类型 varchar 转换为 numeric 时出错
  • 排序--四种算法
  • STM32、GD32驱动TM1640原理图、源码分享
  • HCIA项目实践--RIP相关原理知识面试问题总结回答
  • 服务器,交换机和路由器的一些笔记
  • 机器学习(李宏毅)——self-Attention
  • 常见的排序算法:插入排序、选择排序、冒泡排序、快速排序
  • 利用Java爬虫按图搜索1688商品(拍立淘):实战案例指南
  • 集成学习(一):从理论到实战(附代码)
  • sqli-lab靶场学习(六)——Less18-22(User-Agent、Referer、Cookie注入)
  • 网络工程师 (35)以太网通道
  • iptables网络安全服务详细使用
  • ES节点配置的最佳实践
  • 开发指南098-logback-spring.xml说明
  • 六西格玛设计培训如何破解风电设备制造质量与成本困局
  • 错误报告:WebSocket 设备连接断开处理问题