lerobot [Evaluate Policy, Train Policy]
This post is a follow-up to lerobot [Deployment, Dataset Metadata, Loading Datasets].
Table of Contents
- Evaluate Policy
  - Load the Model
  - Load the Environment
  - Interaction
- Train Policy
  - Load the Dataset
  - Load the Policy
  - Load the Optimizer
  - Forward Pass, Backward Pass, Parameter Update
- References
- Next Up
Evaluate Policy
Evaluating a policy is essentially a formulaic three-step routine: load the model, load the environment, then loop { read the current state, feed the state into the policy to get an action, step the environment with that action }.
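The snippets in this section assume roughly the following imports and setup. This is a sketch: the lerobot import path, the output directory, and the device choice are assumptions and may differ across versions.

from pathlib import Path

import gym_pusht  # noqa: F401  # registers the gym_pusht/PushT-v0 environment
import gymnasium as gym
import imageio
import numpy
import torch

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Where the rollout video will be written (example path)
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Run on GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")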
Load the Model
TODO: analysis of the diffusion model
If pretrained_policy_path can be resolved locally, the local weights are used; otherwise they are downloaded from the Hugging Face Hub.
pretrained_policy_path = "lerobot/diffusion_pusht"
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.to(device)  # keep the policy on the same device as the observation tensors below
Load the Environment
The env used here must have been configured earlier, when the environment packages were downloaded (i.e. gym_pusht needs to be installed).
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)
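As an optional sanity check (a sketch, not required by the rollout), the spaces can be printed; with obs_type="pixels_agent_pos" the observation is a dict with "pixels" and "agent_pos" entries, which is exactly what the loop below reads.

# Optional: inspect the observation and action spaces
print(env.observation_space)
print(env.action_space)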
Interaction
A loop of: read the state, call policy(state) to get an action, then step the environment with env.step(action).
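The loop below reads numpy_observation and updates done, step, rewards and frames, so the episode has to be initialized first. A minimal sketch of that setup (policy.reset(), env.reset() and the render_fps metadata follow the lerobot/gymnasium conventions; the seed is an example value):

# Reset the policy's internal state (e.g. its action queue) and the environment
policy.reset()
numpy_observation, info = env.reset(seed=42)

# Bookkeeping used inside the loop and for the video at the end
rewards = []
frames = [env.render()]
fps = env.metadata["render_fps"]

step = 0
done = False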
while not done:
    # Convert the numpy observation to torch tensors
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])

    # Cast to float32 and scale pixel values from [0, 255] to [0, 1]
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)  # HWC -> CHW

    # Move to the policy's device and add a batch dimension
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)

    observation = {
        "observation.state": state,
        "observation.image": image,
    }

    # Predict the next action for the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Step the environment with the predicted action
    numpy_action = action.squeeze(0).to("cpu").numpy()
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of rewards and rendered frames
    rewards.append(reward)
    frames.append(env.render())

    # Stop when the task succeeds (terminated) or the time limit is reached (truncated)
    done = terminated | truncated | done
    step += 1
if terminated:
    print("Success!")
else:
    print("Failure!")
Finally, there is a visualization step that stitches the per-step frames into a video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
TODO: insert the video
Train Policy
The training procedure is also fairly fixed, and since this is offline training it stays simple: load the dataset, load the policy, load the optimizer, then run forward pass, backward pass, and parameter update in a loop.
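As in the evaluation section, the training snippets assume roughly these imports plus a couple of hyperparameters. This is a sketch: the module paths may differ across lerobot versions, and training_steps / log_freq are example values.

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.types import FeatureType

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
training_steps = 5000  # example: total number of optimization steps
log_freq = 100         # example: print the loss every 100 steps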
Load the Dataset
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

# pusht is recorded at 10 fps, so -0.1 s is the previous frame and 0.0 the current one;
# the 16 action timestamps span the future action horizon the diffusion policy predicts.
delta_timestamps = {
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)
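To check what the delta_timestamps stacking produces, one can peek at a single batch (an optional sketch); each observation key should carry an extra time dimension of length 2 and "action" one of length 16.

# Inspect one batch: tensors have shape (batch, time, ...), where the time
# dimension comes from the delta_timestamps configured above.
batch = next(iter(dataloader))
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(key, tuple(value.shape))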
Load the Policy
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)  # stats are used to normalize inputs/outputs
policy.train()
policy.to(device)
Load the Optimizer
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
Forward Pass, Backward Pass, Parameter Update
step = 0
done = False
while not done:
    for batch in dataloader:
        # Move every tensor in the batch to the training device
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

        # Forward pass computes the loss; then backpropagate and update the parameters
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break
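Finally, the trained weights can be saved so that the evaluation section above can reload them with from_pretrained. A minimal sketch following the Hugging Face save_pretrained convention (the output path is an example):

from pathlib import Path

# Save config and weights; DiffusionPolicy.from_pretrained can later reload this directory.
output_directory = Path("outputs/train/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(output_directory)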
References
https://huggingface.co/lerobot
Next Up
Analysis of the Diffusion Policy and ACT algorithms.