当前位置: 首页 > news >正文

lerobot[评估策略,训练策略]

本文是lerobot[部署,元数据集,加载数据集]的后续

目录

  • 评估策略
      • 加载模型
      • 加载环境
      • 交互
  • 训练策略
      • 加载数据集
      • 加载策略
      • 加载优化器
      • 前向传播,反向传播,更新参数
  • 参考资料
  • 后续

评估策略

策略评估这块基本就是套路式的三步走: 加载模型,加载环境,循环{ 获取状态,将策略输入到状态中获得动作,与环境交互}

加载模型

TODO: diffusion model解析
如果能够再本地找到pretrained_policy_path就用本地参数,不能就从hf上下载

pretrained_policy_path = "lerobot/diffusion_pusht"
policy=DiffusionPolicy.from_pretrained(pretrained_policy_path)

加载环境

这里env 需要在前面下载环境的时候配置

env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)

交互

采集state,policy(state) 获得action,env(action) 的循环

while not done:
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)
    observation = {
        "observation.state": state,
        "observation.image": image,
    }
    with torch.inference_mode():
        action = policy.select_action(observation)
    numpy_action = action.squeeze(0).to("cpu").numpy()

    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)

    print(f"{step=} {reward=} {terminated=}")
    rewards.append(reward)
    frames.append(env.render())
    done = terminated | truncated | done
    step += 1
if terminated:
    print("Success!")
else:
    print("Failure!")

最后还有一个可视化,将每帧图片连起来组成视频

video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

TODO插入视频

训练策略

训练策略的步骤也是比较固定,由于是offline training也是比较简单的:加载策略,加载数据集,加载优化器,前向传播,反向传播,更新参数

加载数据集

dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features=dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

delta_timestamps = {
        "observation.image": [-0.1, 0.0],
        "observation.state": [-0.1, 0.0],
        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
    }
dataset=LeRobotDataset("lerobot/pusht",delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=64,
        shuffle=True,
        pin_memory=device.type != "cpu",
        drop_last=True,
    )

加载策略

cfg=DiffusionConfig(input_features=input_features,output_features=output_features)
policy=DiffusionPolicy(cfg,dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)

加载优化器

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

前向传播,反向传播,更新参数

    step = 0
    done = False
    while not done:
        for batch in dataloader:
            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.item():.3f}")
            step += 1
            if step >= training_steps:
                done = True
                break

参考资料

https://huggingface.co/lerobot

后续

diffusion policy,act算法解析

相关文章:

  • C++ 标准库 vector(三十七)
  • (51单片机)独立按键控制流水灯LED流向(独立按键教程)(LED使用教程)
  • day40——种花问题(LeetCode-605)
  • chromadb 安装和使用
  • Lecture 44: NVIDIA Profiling (未完)
  • 10种电阻综合对比——《器件手册--电阻》
  • CNN-SE-Attention-ITCN多特征输入回归预测(Matlab完整源码和数据)
  • DeepSeek推动办公智能向“人机共智”阶段跃迁
  • centos7 yum install docker 安装错误
  • java面试篇 并发编程篇
  • 低代码开发:重塑软件开发的未来
  • MCP server的stdio和SSE分别是什么?
  • 网络初识 - Java
  • C# Winform 入门(11)之制作酷炫灯光效果
  • DeepSeek 教我 C++ (8) :C++ 静态类型不安全的情况
  • 内网渗透(杂项集合) --- 中的多协议与漏洞利用技术(杂项知识点 重点) 持续更新
  • Three.js 系列专题 3:光照与阴影
  • Spring Data JPA中的List底层:深入解析ArrayList的奥秘!!!
  • linux Gitkraken 破解
  • 基于springboot协同过滤算法的农产品销售推荐系统(源码+lw+部署文档+讲解),源码可白嫖!
  • 律师网站建设建议/百度推广非企代理
  • 搜索引擎找不到的网站/适合发朋友圈的营销广告
  • 域名备案码/春哥seo博客
  • 怎么用自己主机做网站/app下载推广
  • 哪个网站可以做批发/营销公关
  • 访问自己做的网站/百度投诉中心24人工