robot_lab train的整体逻辑
Go2机器人推理(Play)流程详细分析
概述
本文档详细分析了使用命令 python scripts/rsl_rl/base/play.py --task RobotLab-Isaac-Velocity-Rough-Unitree-Go2-v0
进行Go2机器人推理的完整流程,基于实际的代码实现,包括模型加载、环境配置调整、推理循环以及可视化等关键组件。
1. 推理启动流程
1.1 命令行参数解析
# 在 play.py 中
parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")# 推理专用参数
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations.")
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--use_pretrained_checkpoint", action="store_true",help="Use the pre-trained checkpoint from Nucleus.")
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
parser.add_argument("--keyboard", action="store_true", default=False, help="Whether to use keyboard.")# 添加RSL-RL和AppLauncher参数
cli_args.add_rsl_rl_args(parser)
AppLauncher.add_app_launcher_args(parser)
关键推理参数:
--task
: 指定任务环境(与训练时相同)--video
: 启用视频录制功能--video_length
: 视频录制长度(默认200步)--num_envs
: 推理环境数量(默认会被调整为50)--real-time
: 实时运行模式--keyboard
: 启用键盘控制--use_pretrained_checkpoint
: 使用预训练检查点--checkpoint
: 指定特定检查点路径
1.2 相机和视频配置
# 如果启用视频录制,自动启用相机
if args_cli.video:args_cli.enable_cameras = True# 启动Isaac Sim应用
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
2. 环境配置调整(推理专用优化)
2.1 基础环境配置加载
def main():# 解析环境配置env_cfg = parse_env_cfg(args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric)# 解析智能体配置agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
2.2 推理专用环境调整
# 1. 减少环境数量以提高推理效率
env_cfg.scene.num_envs = 50 # 从训练时的4096减少到50# 2. 地形配置调整
env_cfg.scene.terrain.max_init_terrain_level = None # 随机生成机器人位置
if env_cfg.scene.terrain.terrain_generator is not None:env_cfg.scene.terrain.terrain_generator.num_rows = 5 # 减少地形行数env_cfg.scene.terrain.terrain_generator.num_cols = 5 # 减少地形列数env_cfg.scene.terrain.terrain_generator.curriculum = False # 禁用课程学习# 3. 禁用观测噪声以获得一致性能
env_cfg.observations.policy.enable_corruption = False# 4. 移除训练时的随机扰动
env_cfg.events.randomize_apply_external_force_torque = None # 禁用随机外力
env_cfg.events.push_robot = None # 禁用机器人推力扰动
推理环境优化说明:
- 环境数量:从4096减少到50,降低计算负载
- 地形简化:减少地形复杂度,专注于展示性能
- 禁用随机化:去除所有随机干扰因素
- 禁用课程学习:使用固定难度级别
2.3 键盘控制模式配置
if args_cli.keyboard:# 单环境模式env_cfg.scene.num_envs = 1# 禁用超时终止env_cfg.terminations.time_out = None# 关闭命令可视化env_cfg.commands.base_velocity.debug_vis = False# 创建键盘控制器controller = Se2Keyboard(v_x_sensitivity=env_cfg.commands.base_velocity.ranges.lin_vel_x[1], # 前进后退敏感度v_y_sensitivity=env_cfg.commands.base_velocity.ranges.lin_vel_y[1], # 左右移动敏感度omega_z_sensitivity=env_cfg.commands.base_velocity.ranges.ang_vel_z[1], # 转向敏感度)# 替换速度命令观测项为键盘输入env_cfg.observations.policy.velocity_commands = ObsTerm(func=lambda env: torch.tensor(controller.advance(), dtype=torch.float32).unsqueeze(0).to(env.device),)
键盘控制特点:
- 实时交互:用户可以通过键盘实时控制机器人
- 单环境:专注于一个机器人的控制
- 无超时:持续运行直到用户停止
- 相机跟随:相机会自动跟随机器人移动
3. 模型加载与恢复
3.1 检查点路径解析
# 设置日志根路径
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")# 确定检查点路径的优先级
if args_cli.use_pretrained_checkpoint:# 使用预训练检查点resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task)if not resume_path:print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")return
elif args_cli.checkpoint:# 使用指定的检查点路径resume_path = retrieve_file_path(args_cli.checkpoint)
else:# 自动查找最新检查点resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)log_dir = os.path.dirname(resume_path)
检查点查找优先级:
- 预训练检查点:来自官方发布的预训练模型
- 指定检查点:用户明确指定的检查点文件
- 自动检查点:从实验目录自动查找最新检查点
3.2 环境创建和包装
# 创建Isaac环境
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)# 处理多智能体环境转换
if isinstance(env.unwrapped, DirectMARLEnv):env = multi_agent_to_single_agent(env)# 视频录制包装器
if args_cli.video:video_kwargs = {"video_folder": os.path.join(log_dir, "videos", "play"), # 视频保存目录"step_trigger": lambda step: step == 0, # 立即开始录制"video_length": args_cli.video_length, # 录制长度"disable_logger": True, # 禁用额外日志}print("[INFO] Recording videos during training.")print_dict(video_kwargs, nesting=4)env = gym.wrappers.RecordVideo(env, **video_kwargs)# RSL-RL环境包装器
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)
3.3 模型加载和初始化
print(f"[INFO]: Loading model checkpoint from: {resume_path}")# 创建PPO运行器并加载模型
ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
ppo_runner.load(resume_path)# 获取推理策略
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)# 提取神经网络模块(兼容不同版本)
try:# RSL-RL 2.3及以上版本policy_nn = ppo_runner.alg.policy
except AttributeError:# RSL-RL 2.2及以下版本policy_nn = ppo_runner.alg.actor_critic
模型加载特点:
- 版本兼容:支持不同版本的RSL-RL库
- 设备管理:自动处理GPU/CPU设备转换
- 推理优化:使用专门的推理策略而非训练策略
4. 模型导出功能
4.1 导出为ONNX和JIT格式
# 设置导出目录
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")# 导出ONNX格式(跨平台部署)
export_policy_as_onnx(policy=policy_nn,normalizer=ppo_runner.obs_normalizer, # 包含观测归一化信息path=export_model_dir,filename="policy.onnx",
)# 导出JIT格式(PyTorch部署)
export_policy_as_jit(policy=policy_nn,normalizer=ppo_runner.obs_normalizer,path=export_model_dir,filename="policy.pt",
)
导出格式说明:
- ONNX格式:跨平台部署,支持多种推理引擎
- JIT格式:PyTorch原生格式,高性能推理
- 包含归一化:导出模型包含观测归一化参数
5. 推理主循环
5.1 推理循环实现
dt = env.unwrapped.step_dt # 获取仿真时间步长# 重置环境
obs, _ = env.get_observations()
timestep = 0# 主推理循环
while simulation_app.is_running():start_time = time.time()# 推理模式执行with torch.inference_mode():# 策略推理actions = policy(obs)# actions = torch.zeros_like(actions) # 可选:零动作测试# 环境步进obs, _, _, _ = env.step(actions)# 视频录制控制if args_cli.video:timestep += 1# 录制完成后退出if timestep == args_cli.video_length:break# 键盘模式下的相机跟随if args_cli.keyboard:rsl_rl_utils.camera_follow(env)# 实时模式的时间控制sleep_time = dt - (time.time() - start_time)if args_cli.real_time and sleep_time > 0:time.sleep(sleep_time)# 关闭环境
env.close()
5.2 推理循环特点
性能优化:
- torch.inference_mode():禁用梯度计算,提高推理速度
- 实时控制:可选的实时运行模式
- 内存优化:减少不必要的张量操作
交互功能:
- 视频录制:自动录制指定长度的演示视频
- 键盘控制:实时人机交互控制
- 相机跟随:动态视角跟踪
循环控制:
- 条件退出:视频录制完成或用户中断
- 时间同步:保持与仿真时间步长一致
6. 实时控制和可视化
6.1 实时模式实现
def real_time_control(dt, start_time):"""实时模式时间控制"""sleep_time = dt - (time.time() - start_time)if args_cli.real_time and sleep_time > 0:time.sleep(sleep_time)
实时特性:
- 时间同步:确保推理循环与物理时间步长同步
- 帧率控制:维持稳定的可视化帧率
- 响应性:保证键盘输入的实时响应
6.2 键盘控制实现
class Se2Keyboard:"""2D移动键盘控制器"""def __init__(self, v_x_sensitivity, v_y_sensitivity, omega_z_sensitivity):self.v_x_sensitivity = v_x_sensitivity # 前进后退敏感度self.v_y_sensitivity = v_y_sensitivity # 左右移动敏感度 self.omega_z_sensitivity = omega_z_sensitivity # 转向敏感度def advance(self):"""获取当前键盘输入状态"""# 返回 [v_x, v_y, omega_z] 速度命令return self.get_keyboard_input()
键盘映射:
- WASD:基本移动控制
- 方向键:精确移动控制
- 鼠标:视角控制
- ESC:退出程序
6.3 相机跟随系统
def camera_follow(env):"""相机自动跟随机器人"""# 获取机器人当前位置robot_pos = env.unwrapped.scene.robot.data.root_pos_w[0]# 更新相机位置和朝向camera_offset = torch.tensor([2.0, 0.0, 1.5]) # 相机偏移量camera_pos = robot_pos + camera_offset# 设置相机朝向机器人camera_target = robot_pos# 应用相机变换set_camera_view(camera_pos, camera_target)
7. 视频录制系统
7.1 录制配置
if args_cli.video:video_kwargs = {"video_folder": os.path.join(log_dir, "videos", "play"),"step_trigger": lambda step: step == 0, # 立即开始录制"video_length": args_cli.video_length, # 录制步数"disable_logger": True, # 禁用日志输出}env = gym.wrappers.RecordVideo(env, **video_kwargs)
7.2 录制流程
def video_recording_flow():"""视频录制流程"""timestep = 0while simulation_app.is_running():# 正常推理步骤with torch.inference_mode():actions = policy(obs)obs, _, _, _ = env.step(actions)# 录制控制if args_cli.video:timestep += 1# 达到指定长度后停止录制if timestep == args_cli.video_length:print(f"[INFO] Video recording completed: {timestep} steps")break# 其他循环逻辑...
录制特点:
- 自动触发:程序启动即开始录制
- 固定长度:录制指定步数后自动停止
- 高质量:使用rgb_array渲染模式
- 自动保存:录制完成后自动保存到指定目录
8. 关键差异对比(训练 vs 推理)
8.1 环境配置差异
配置项 | 训练模式 | 推理模式 | 说明 |
---|---|---|---|
环境数量 | 4096 | 50 | 推理时大幅减少以提高效率 |
观测噪声 | 启用 | 禁用 | 推理时需要一致性能 |
域随机化 | 启用 | 禁用 | 推理时使用固定参数 |
课程学习 | 启用 | 禁用 | 推理时使用固定难度 |
外力扰动 | 启用 | 禁用 | 推理时避免干扰 |
地形复杂度 | 高 | 简化 | 推理时减少计算负载 |
8.2 计算模式差异
方面 | 训练模式 | 推理模式 |
---|---|---|
梯度计算 | 启用 | 禁用(torch.inference_mode) |
内存使用 | 高(经验缓冲区) | 低(仅当前状态) |
策略采样 | 随机采样 | 确定性输出 |
网络模式 | 训练模式 | 评估模式 |
批处理 | 大批量 | 小批量或单步 |
8.3 功能差异
功能 | 训练模式 | 推理模式 |
---|---|---|
模型更新 | ✓ | ✗ |
视频录制 | 可选 | ✓ |
键盘控制 | ✗ | ✓ |
实时运行 | ✗ | ✓ |
模型导出 | ✗ | ✓ |
相机跟随 | ✗ | ✓ |
9. 使用示例
9.1 基本推理
# 基本推理模式
python scripts/rsl_rl/base/play.py \--task RobotLab-Isaac-Velocity-Rough-Unitree-Go2-v0
9.2 视频录制
# 录制200步演示视频
python scripts/rsl_rl/base/play.py \--task RobotLab-Isaac-Velocity-Rough-Unitree-Go2-v0 \--video \--video_length 200
9.3 键盘控制
# 启用键盘实时控制
python scripts/rsl_rl/base/play.py \--task RobotLab-Isaac-Velocity-Rough-Unitree-Go2-v0 \--keyboard \--real-time
9.4 指定检查点
# 使用特定检查点
python scripts/rsl_rl/base/play.py \--task RobotLab-Isaac-Velocity-Rough-Unitree-Go2-v0 \--checkpoint /path/to/model.pt \--video
9.5 预训练模型
# 使用官方预训练模型
python scripts/rsl_rl/base/play.py \--task RobotLab-Isaac-Velocity-Rough-Unitree-Go2-v0 \--use_pretrained_checkpoint \--video
10. 故障排除
10.1 常见问题
问题1:找不到检查点文件
解决方案:
1. 检查训练是否完成并保存了检查点
2. 确认路径是否正确
3. 使用--checkpoint参数指定具体路径
问题2:推理速度太慢
解决方案:
1. 减少环境数量 --num_envs 16
2. 禁用视频录制
3. 使用--headless模式
问题3:视频录制失败
解决方案:
1. 确保启用了相机 --enable_cameras
2. 检查输出目录权限
3. 确认ffmpeg已安装
10.2 性能优化建议
GPU内存优化:
- 减少并行环境数量
- 使用较小的观测向量
- 启用混合精度推理
CPU优化:
- 使用JIT编译的模型
- 减少Python开销
- 启用多线程
11. 总结
Go2机器人的推理(play)流程相比训练流程具有以下特点:
11.1 核心特点
- 轻量化配置:大幅减少计算资源需求
- 用户友好:支持实时交互和可视化
- 部署导向:提供多种模型导出格式
- 稳定可靠:移除随机化确保一致性能
11.2 主要用途
- 性能验证:测试训练结果
- 演示展示:录制演示视频
- 交互控制:实时人机交互
- 模型部署:导出部署格式
11.3 技术优势
- 高效推理:torch.inference_mode()优化
- 实时响应:支持实时控制和反馈
- 多模式支持:自动/手动/录制模式
- 易于使用:简单的命令行接口
这个推理系统为Go2机器人提供了完整的部署和测试解决方案,是从训练到实际应用的重要桥梁。