OpenPI源码分析
🎯 OpenPI 基础模型详解
一、模型整体架构
OpenPI 的基础模型是一个多模态 Transformer 架构,称为 π₀(Pi Zero),由三个核心组件组成:
┌─────────────────────────────────────────────────────┐
│ π₀ 基础模型架构 │
├─────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ 图像输入 │ │ 语言指令 │ │
│ │ (多相机视角) │ │ ("拿起叉子") │ │
│ └──────┬───────┘ └──────┬───────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ SigLIP │ │ Tokenizer │ │
│ │ 视觉编码器 │ │ 语言分词器 │ │
│ └──────┬───────┘ └──────┬───────┘ │
│ │ │ │
│ └─────────┬───────────┘ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ PaliGemma LLM │ │
│ │ (Gemma 2B) │ ← 机器人状态 │
│ │ 视觉-语言融合 │ │
│ └─────────┬───────────┘ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ Action Expert │ │
│ │ (Gemma 300M) │ ← 时间步嵌入 │
│ │ 动作专家网络 │ │
│ └─────────┬───────────┘ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ 动作输出投影层 │ │
│ └─────────┬───────────┘ │
│ ▼ │
│ 动作序列输出 │
│ (action_horizon × action_dim) │
│ │
└─────────────────────────────────────────────────────┘
二、三大核心组件详解
1️⃣ 视觉编码器 - SigLIP (So400m/14)
来源: Google 的 SigLIP (Sigmoid Loss for Language Image Pre-training)
规格:
- 变体:
So400m/14
(约 4 亿参数) - 补丁大小:14×14
- 输入分辨率:224×224
- 架构:Vision Transformer (ViT)
功能:
# 代码位置:src/openpi/models/siglip.py
img = _siglip.Module(num_classes=paligemma_config.width, # 输出维度 = 2048variant="So400m/14",pool_type="none", # 不做池化,保留所有图像 tokenscan=True,dtype_mm=config.dtype,
)
输出: 每张 224×224 图像被转换成 256 个图像 token(16×16 网格),每个 token 的维度为 2048
2️⃣ 语言模型主干 - PaliGemma (基于 Gemma 2B)
来源: Google 的 Gemma 模型(Gemini 的开源版本)
规格:
# 配置:src/openpi/models/gemma.py
paligemma_config = Config(width=2048, # 隐藏层维度depth=18, # Transformer 层数mlp_dim=16_384, # FFN 中间层维度num_heads=8, # 注意力头数num_kv_heads=1, # KV 头数(使用 GQA)head_dim=256, # 每个头的维度
)
# 总参数量:约 2B (20亿)
关键特性:
- Grouped Query Attention (GQA):
num_kv_heads=1
减少 KV 缓存内存 - RMSNorm: 替代 LayerNorm,更高效
- SwiGLU 激活函数: 性能更好的 FFN 激活
- 旋转位置编码 (RoPE): 更好的长度外推能力
功能: 融合视觉和语言信息,生成统一的多模态表示
3️⃣ 动作专家网络 - Action Expert (Gemma 300M)
规格:
# 配置:src/openpi/models/gemma.py
action_expert_config = Config(width=1024, # 隐藏层维度depth=18, # Transformer 层数mlp_dim=4096, # FFN 中间层维度num_heads=8,num_kv_heads=1,head_dim=256,
)
# 总参数量:约 311M (3.11亿)
功能: 专门用于生成机器人动作的轻量级 Transformer
π₀ vs π₀.₅ 的关键差异:
特性 | π₀ | π₀.₅ |
---|---|---|
状态输入方式 | 连续状态 token | 离散化状态(加入语言 token) |
时间步注入 | MLP 混合 | AdaRMSNorm(自适应归一化) |
训练技术 | 标准训练 | Knowledge Insulation(知识隔离) |
AdaRMSNorm (π₀.₅独有):
# 时间步条件注入方式
time_emb = posemb_sincos(timestep, ...) # 正弦余弦位置编码
time_emb = self.time_mlp_in(time_emb)
time_emb = swish(time_emb)
time_emb = self.time_mlp_out(time_emb)
# 这个 time_emb 作为 adarms_cond 注入到 Transformer 的归一化层
三、输入输出规格
模型输入 (Observation
):
{"images": {"base_0_rgb": [B, 224, 224, 3], # 外部相机"left_wrist_0_rgb": [B, 224, 224, 3], # 左手腕相机"right_wrist_0_rgb": [B, 224, 224, 3], # 右手腕相机},"image_masks": {"base_0_rgb": [B], # bool,指示图像是否有效...},"state": [B, 32], # 机器人关节位置、速度等"tokenized_prompt": [B, 48/200], # 分词后的语言指令"tokenized_prompt_mask": [B, 48/200], # token mask
}
模型输出 (Actions
):
{"actions": [B, 50, 32] # (批次, 时间步, 动作维度)
}
# B = batch size
# 50 = action_horizon (一次预测50步未来动作)
# 32 = action_dim (动作维度,取决于机器人自由度)
四、动作生成机制
π₀ / π₀.₅ 使用流匹配(Flow Matching):
# 训练时:src/openpi/models/pi0.py compute_loss()
# 1. 添加噪声到真实动作
time = random.beta(1.5, 1) # 时间步 t ∈ [0, 1]
x_t = t * noise + (1 - t) * actions # 插值# 2. 预测速度场(velocity field)
v_t = model(observation, x_t, time)# 3. 计算损失
u_t = noise - actions # 真实速度场
loss = mean_square(v_t - u_t)# 推理时:sample_actions()
# 通过 ODE 求解器从噪声迭代到干净动作
x_0 = x_1 # 从纯噪声开始
for step in range(num_steps): # 默认 10 步v_t = model(observation, x_t, time)x_t = x_t + dt * v_t # 欧拉法求解 ODE
优势:
- 比扩散模型更快(10步 vs 50-100步)
- 训练更稳定
- 生成质量高
π₀-FAST 使用自回归 + FAST 分词器:
# 动作离散化为 token
actions → FAST Tokenizer → action_tokens [B, L]
# 然后像语言模型一样自回归生成
for i in range(action_horizon):token_i = model(observation, tokens[:i])actions_i = detokenize(token_i)
五、模型参数量总结
组件 | 参数量 | 说明 |
---|---|---|
SigLIP 视觉编码器 | ~400M | 预训练权重(冻结或微调) |
PaliGemma LLM | ~2B | 视觉-语言融合主干 |
Action Expert | ~311M | 动作生成专家网络 |
投影层 | ~1M | 输入输出投影 |
总计 | ~2.7B | 约 27 亿参数 |
六、预训练数据
基础模型在 10,000+ 小时的机器人数据上预训练,包括:
- 内部机器人数据(Physical Intelligence 公司收集)
- 多种机器人平台(双臂、单臂、移动机器人)
- 多样化任务(抓取、放置、操作、导航等)
- 多模态数据(视觉 + 语言 + 动作)
七、为什么叫"基础模型"?
与语言领域的 GPT、LLaMA 类似,π₀ 是通用的机器人基础模型:
✅ 预训练在大规模数据上 → 学到丰富的机器人先验知识
✅ 可迁移到下游任务 → 通过微调适配特定机器人/任务
✅ 多模态理解能力 → 融合视觉、语言、动作
✅ 零样本泛化 → 在新场景中直接使用(有限程度)
这就是 OpenPI 提供的"基础模型"——一个预训练好的、可以快速适配各种机器人任务的强大起点!
关键模型源代码分析
📝 gemma.py 的来源
gemma.py
是从 Google 的 Big Vision 项目改编而来的,并非完全自己实现。
证据:
1. 文件头部版权声明(第1行):
# Copyright 2024 Big Vision Authors.
2. 文件注释(第15行):
"""Gemma adaptation for Pi, taken from big_vision.
明确说明:“为 Pi 改编的 Gemma 实现,取自 big_vision”
Big Vision 是什么?
Big Vision 是 Google Research 开发的一个大规模视觉模型训练框架,使用 JAX/Flax 实现。它包含了许多 Google 的视觉和多模态模型的实现,包括:
- PaliGemma (视觉-语言模型)
- Gemma (语言模型)
- SigLIP (视觉编码器)
- ViT (Vision Transformer)
OpenPI 做了哪些改编?
虽然源自 Big Vision,但 OpenPI 团队对代码进行了重要的适配和扩展:
1️⃣ 支持多专家架构(Multi-Expert)
# 原始 Gemma 只有一个模型
# OpenPI 改为支持多个 configs(PaliGemma + Action Expert)
class Module(nn.Module):configs: Sequence[Config] # 支持多个专家配置
2️⃣ 添加 AdaRMSNorm(自适应归一化)
# 第115-131行:为 π₀.₅ 添加的新功能
class RMSNorm(nn.Module):def __call__(self, x, cond):# ...if cond is None:# 标准 RMSNorm (原始实现)else:# 自适应 RMSNorm (OpenPI 新增)# 用于注入流匹配的时间步信息
3️⃣ 添加 LoRA 支持
# 集成了低秩适配(LoRA)用于高效微调
lora_configs: dict[str, lora.LoRAConfig] = ...
4️⃣ 特定的模型配置
# 为 π₀ 项目定制的 Gemma 变体
"gemma_2b", "gemma_300m", "gemma_2b_lora", "gemma_300m_lora"
📚 完整的来源总结
文件 | 来源 | 说明 |
---|---|---|
gemma.py | 🔵 Google Big Vision | “taken from big_vision” |
siglip.py | 🔵 Google Big Vision | “taken from big_vision” |
vit.py | 🟢 Google Vision Transformer | 改编自 google-research/vision_transformer |
总结
✅ OpenPI 并非从零实现这些模型,而是:
-
复用 Google 开源的高质量实现(Big Vision、Vision Transformer)
-
进行关键的架构改编以适应机器人任务:
- 多专家架构(PaliGemma + Action Expert)
- AdaRMSNorm(时间条件注入)
- LoRA 高效微调
- 流匹配动作生成
-
添加机器人特定的组件(完全自己实现):
pi0.py
- 整体模型架构pi0_fast.py
- FAST 变体tokenizer.py
- 动作分词器- 各种策略类(
aloha_policy.py
,droid_policy.py
等)
这是一种聪明且务实的工程实践 ✨:
- 站在巨人的肩膀上(复用 Google 的模型实现)
- 专注于创新点(机器人控制的架构设计)
- 快速验证想法(不需要重复造轮子)
这也符合 Apache 2.0 开源协议的要求,保留了原始版权声明和许可证信息。
让我系统地梳理 OpenPI 项目中的模型定义代码:
[3 tools called]
让我再查看几个关键文件来确认其作用:
[3 tools called]
完美!现在我可以为您系统地梳理 OpenPI 中所有模型定义相关的代码:
📁 OpenPI 模型定义代码全景图
🎯 核心分类
src/openpi/
├── models/ # JAX/Flax 实现(主要版本)
└── models_pytorch/ # PyTorch 实现(新增支持)
一、JAX/Flax 模型定义 (src/openpi/models/
)
🏗️ 1. 基础架构层
文件 | 作用 | 关键内容 |
---|---|---|
model.py | 🔴 模型基类和接口定义 | - BaseModel 抽象类- Observation 数据结构- Actions 类型定义- ModelType 枚举- 模型加载/保存逻辑 |
# model.py - 定义所有模型的基础接口
class BaseModel:@abc.abstractmethoddef compute_loss(self, rng, observation, actions, *, train=False):"""计算训练损失"""@abc.abstractmethoddef sample_actions(self, rng, observation, **kwargs):"""推理时采样动作"""
🧠 2. 完整模型实现层
这些是完整的端到端模型,继承自 BaseModel
:
文件 | 模型名称 | 说明 |
---|---|---|
pi0.py | 🟢 π₀ / π₀.₅ | - 流匹配(Flow Matching)VLA - 支持 π₀ 和 π₀.₅ 两种变体 - 包含完整的前向传播、损失计算、动作采样 |
pi0_fast.py | 🟡 π₀-FAST | - 自回归 VLA - 使用 FAST 动作分词器 - 像语言模型一样生成动作 token |
pi0_config.py | ⚙️ 模型配置 | - Pi0Config 数据类- 定义模型超参数(维度、层数等) |
代码示例 - pi0.py
的核心结构:
class Pi0(BaseModel):def __init__(self, config: Pi0Config, rngs: nnx.Rngs):# 1. 初始化 PaliGemma (2B Transformer)# 2. 初始化 Action Expert (300M Transformer)# 3. 初始化 SigLIP 视觉编码器# 4. 初始化投影层def compute_loss(self, ...):# 1. 嵌入图像和语言# 2. 流匹配训练# 3. 返回损失def sample_actions(self, ...):# 1. 编码观察# 2. ODE 求解生成动作# 3. 返回动作序列
🔧 3. 组件模块层
这些是可复用的模型组件,被完整模型调用:
📝 Transformer 骨干
文件 | 来源 | 作用 |
---|---|---|
gemma.py | Google Big Vision | - Gemma Transformer 实现 - 支持多专家架构 - RMSNorm / AdaRMSNorm - RoPE 位置编码 - GQA 注意力 |
gemma_fast.py | Big Vision + 改编 | - π₀-FAST 专用的 Gemma 变体 |
🖼️ 视觉编码器
文件 | 来源 | 作用 |
---|---|---|
siglip.py | Google Big Vision | - SigLIP 视觉 Transformer - 将图像编码为 token 序列 |
vit.py | Google Vision Transformer | - Vision Transformer 基础实现 - 被 SigLIP 使用 |
🎯 特定功能组件
文件 | 作用 |
---|---|
tokenizer.py | - PaligemmaTokenizer : 语言分词- FASTTokenizer : 动作分词(π₀-FAST)- FSQTokenizer : 有限标量量化分词 |
lora.py | - LoRA(低秩适配)实现 - 用于高效微调 |
utils/fsq_tokenizer.py | - 有限标量量化(FSQ) - VQ-VAE 的替代方案 |
🧪 4. 测试文件
文件 | 说明 |
---|---|
model_test.py | 模型基类测试 |
pi0_test.py | π₀ 模型测试 |
lora_test.py | LoRA 功能测试 |
tokenizer_test.py | 分词器测试 |
二、PyTorch 模型定义 (src/openpi/models_pytorch/
)
🔥 PyTorch 版本的模型
文件 | 对应 JAX 版本 | 说明 |
---|---|---|
pi0_pytorch.py | pi0.py | - PyTorch 实现的 π₀/π₀.₅ - 使用 HuggingFace Transformers |
gemma_pytorch.py | gemma.py | - PyTorch 实现的 Gemma - 基于 HF Transformers |
preprocessing_pytorch.py | - | - PyTorch 数据预处理工具 |
🔌 Transformers 库扩展
transformers_replace/models/
├── gemma/
│ ├── configuration_gemma.py # Gemma 配置
│ └── modeling_gemma.py # Gemma 模型(支持 AdaRMS)
├── paligemma/
│ └── modeling_paligemma.py # PaliGemma 模型
└── siglip/└── modeling_siglip.py # SigLIP 模型
这些文件替换 HuggingFace Transformers 库中的对应文件,添加了:
- ✅ AdaRMSNorm 支持
- ✅ 正确的精度控制
- ✅ KV 缓存控制
三、模型层次结构图
┌─────────────────────────────────────────────────────────────┐
│ BaseModel (抽象基类) │
│ model.py │
│ - compute_loss() - sample_actions() │
└────────────┬────────────────────────────────────────────────┘│┌───────┴───────┬──────────────────┐▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────────┐
│ Pi0 │ │Pi0FAST │ │ Pi0PyTorch │
│ pi0.py │ │pi0_fast.│ │pi0_pytorch.│
└────┬────┘ └────┬────┘ └──────┬──────┘│ │ ││ ┌──────────┴───────────────────┴─────────┐│ │ │▼ ▼ ▼
┌──────────────────────┐ ┌──────────────────┐
│ Gemma Transformer │ │ HF Transformers │
│ gemma.py │ │ (PyTorch版) │
│ - RMSNorm │ │ transformers_ │
│ - Attention (GQA) │ │ replace/ │
│ - FeedForward │ └──────────────────┘
│ - Multi-Expert │
└──────────────────────┘│▼┌──────────────┐│ SigLIP ViT ││ siglip.py ││ vit.py │└──────────────┘
四、模型定义代码的责任划分
🎨 架构设计 (model.py
, pi0_config.py
)
- 定义数据格式(
Observation
,Actions
) - 定义模型接口(
BaseModel
) - 定义配置系统(
BaseModelConfig
)
🏗️ 完整模型 (pi0.py
, pi0_fast.py
, pi0_pytorch.py
)
- 组装所有组件
- 实现训练逻辑(损失计算)
- 实现推理逻辑(动作采样)
- 定义前向传播流程
🧩 可复用组件 (gemma.py
, siglip.py
, vit.py
)
- Transformer 层实现
- 注意力机制
- 归一化层
- 位置编码
🎯 专用工具 (tokenizer.py
, lora.py
, fsq_tokenizer.py
)
- 文本/动作分词
- 低秩适配
- 量化编码
五、哪些是"核心"模型定义?
如果你想理解或修改模型架构,重点关注这些文件:
🔴 必读:
model.py
- 理解整体接口和数据流pi0.py
- π₀ 的完整实现(最重要!)gemma.py
- Transformer 骨干siglip.py
- 视觉编码器
🟡 重要:
pi0_config.py
- 模型配置tokenizer.py
- 分词逻辑pi0_fast.py
- 自回归变体
🟢 辅助:
lora.py
- 微调技术vit.py
- ViT 基础实现- PyTorch 版本 - 如果你更熟悉 PyTorch
六、代码量统计
# JAX 模型核心代码
pi0.py: ~280 行 ⭐⭐⭐⭐⭐
gemma.py: ~460 行 ⭐⭐⭐⭐
siglip.py: ~374 行 ⭐⭐⭐
pi0_fast.py: ~314 行 ⭐⭐⭐
model.py: ~333 行 ⭐⭐⭐⭐
tokenizer.py: ~372 行 ⭐⭐总计:约 2,133 行核心模型代码
总结
OpenPI 的模型定义代码具有清晰的层次结构:
✅ 基础接口层 → 定义规范(model.py
)
✅ 完整模型层 → 端到端实现(pi0.py
, pi0_fast.py
)
✅ 组件模块层 → 可复用模块(gemma.py
, siglip.py
)
✅ 工具支持层 → 辅助功能(tokenizer.py
, lora.py
)
这种设计使得:
- 📦 模块化:组件可独立复用
- 🔧 可扩展:易于添加新模型变体
- 🎯 职责清晰:每个文件有明确的作用
- 🔄 双框架支持:JAX 和 PyTorch 并存