当前位置: 首页 > news >正文

einsum函数

理解专家并行,需要了解einsum函数

import torch

# 设置输入张量的维度:s = 3 tokens, e = 2 experts, c = 2 capacity, m = 4 embedding dim
s, e, c, m = 3, 2, 2, 4

# 1. 输入 token 的嵌入向量 (s, m)
reshaped_input = torch.tensor([
    [1.0, 1.0, 1.0, 1.0],  # token 0
    [2.0, 2.0, 2.0, 2.0],  # token 1
    [3.0, 3.0, 3.0, 3.0],  # token 2
])

# 2. dispatch_mask: (s, e, c)
# 表示每个 token 被分配到哪个 expert 的哪个槽位(slot)
dispatch_mask = torch.tensor([
    # token 0
    [[1, 0],   # expert 0: slot 0
     [0, 0]],  # expert 1: no slot

    # token 1
    [[0, 0],
     [1, 0]],  # expert 1: slot 0

    # token 2
    [[0, 1],   # expert 0: slot 1
     [0, 0]],  # expert 1: no slot
])
dispatch_mask = dispatch_mask.float()
# 3. 应用 einsum 进行 token 分发到专家
dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)

# 4. 打印结果
print("Dispatched Input shape:", dispatched_input.shape)
print("\nDispatched Input Tensor:")
print(dispatched_input)

结果如下

Dispatched Input shape: torch.Size([2, 2, 4])

Dispatched Input Tensor:
tensor([[[1., 1., 1., 1.],
         [3., 3., 3., 3.]],

        [[2., 2., 2., 2.],
         [0., 0., 0., 0.]]])

增加画图功能

import torch
import matplotlib.pyplot as plt
# 输入数据
reshaped_input = torch.tensor([
    [1.0, 1.0, 1.0, 1.0],
    [2.0, 2.0, 2.0, 2.0],
    [3.0, 3.0, 3.0, 3.0],
])  # float32

dispatch_mask = torch.tensor([
    [[1, 0], [0, 0]],
    [[0, 0], [1, 0]],
    [[0, 1], [0, 0]],
])  # int64 → 不兼容

# 修复:转换为 float 类型
dispatch_mask = dispatch_mask.float()

# Einsum 分发
dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)

# 输出
print("Dispatched input shape:", dispatched_input.shape)
print(dispatched_input)

def visualize_dispatch(dispatch_mask):
    s, e, c = dispatch_mask.shape  # tokens, experts, capacity
    plt.figure(figsize=(6, 4))

    for token in range(s):
        for expert in range(e):
            for slot in range(c):
                if dispatch_mask[token, expert, slot] > 0:
                    # token 位置 (左边)
                    x_token, y_token = 0, s - token
                    # expert-slot 位置 (右边)
                    x_expert, y_expert = 4, e * c - (expert * c + slot)

                    # 画连接线
                    plt.plot([x_token, x_expert], [y_token, y_expert], 'k-', lw=1)

                    # 标记 token
                    plt.text(x_token - 0.2, y_token, f"T{token}", va='center', ha='right', fontsize=10)

                    # 标记 expert-slot
                    plt.text(x_expert + 0.2, y_expert, f"E{expert}-S{slot}", va='center', ha='left', fontsize=10)

    # 设置图形样式
    plt.xlim(-1, 6)
    plt.ylim(0, max(s, e*c) + 1)
    plt.axis('off')
    plt.title("Token → Expert-Slot Routing")
    plt.show()

visualize_dispatch(dispatch_mask)


http://www.dtcms.com/a/106790.html

相关文章:

  • 技术回顾day3
  • 大语言模型在端到端智驾中的应用
  • 【Ragflow】9.问答为什么比搜索响应慢?从源码角度深入分析
  • 社交类 APP 设计:打造高用户粘性的界面
  • LE AUDIO CIS连接建立失败问题分析
  • 6.git项目实现变更拉取与上传
  • C++虚函数与抽象类
  • 使用 libevent 处理 TCP 粘包问题(基于 Content-Length 或双 \r\n)
  • 操作系统高频(七)虚拟地址与页表
  • ADASH VA5 Pro中的route功能
  • electron 的 appData 和 userData 有什么区别
  • SPI高级特性分析
  • JavaScript instanceof 运算符全解析
  • 「DeepSeek-V3 技术解析」:无辅助损失函数的负载均衡
  • 双模多态驱动:DeepSeek-V3-0324与DeepSeek-R1医疗领域应用比较分析与混合应用讨论
  • 移动通信网络中漫游机制深度解析:归属网络与拜访网络的协同逻辑
  • PHP的相关配置和优化
  • openstack 查看所有项目配额的命令
  • SU CTF 2025 web 复现
  • tcp的粘包拆包问题,如何解决?
  • 【深度学习量化交易21】行情数据获取方式比测(2)——基于miniQMT的量化交易回测系统开发实记
  • 常见电源模块设计
  • ColPali:基于视觉语言模型的高效文档检索
  • 探索鸿蒙操作系统:迎接万物互联新时代
  • 【IOS webview】源代码映射错误,页面卡住不动
  • STM32单片机入门学习——第7节: [3-3] GPIO输入
  • 树莓派超全系列教程文档--(22)使用外部存储设备的相关操作
  • Spring Boot 集成Redis中 RedisTemplate 及相关操作接口对比与方法说明
  • #Linux内存管理# 假设设备上安装了一块2G的物理内存,在系统启动时,ARM Linux内核是如何映射的?
  • RAG 和 RAGFlow 学习笔记