当前位置: 首页 > wzjs >正文

医院网站开发公司上海搜索引擎优化公司排名

医院网站开发公司,上海搜索引擎优化公司排名,网络营销课程总结1000字,金融集团网站模板本文是lerobot[部署,元数据集,加载数据集]的后续 目录 评估策略加载模型加载环境交互 训练策略加载数据集加载策略加载优化器前向传播,反向传播,更新参数 参考资料后续 评估策略 策略评估这块基本就是套路式的三步走: 加载模型&…

本文是lerobot[部署,元数据集,加载数据集]的后续

目录

  • 评估策略
      • 加载模型
      • 加载环境
      • 交互
  • 训练策略
      • 加载数据集
      • 加载策略
      • 加载优化器
      • 前向传播,反向传播,更新参数
  • 参考资料
  • 后续

评估策略

策略评估这块基本就是套路式的三步走: 加载模型,加载环境,循环{ 获取状态,将策略输入到状态中获得动作,与环境交互}

加载模型

TODO: diffusion model解析
如果能够再本地找到pretrained_policy_path就用本地参数,不能就从hf上下载

pretrained_policy_path = "lerobot/diffusion_pusht"
policy=DiffusionPolicy.from_pretrained(pretrained_policy_path)

加载环境

这里env 需要在前面下载环境的时候配置

env = gym.make("gym_pusht/PushT-v0",obs_type="pixels_agent_pos",max_episode_steps=300,
)

交互

采集state,policy(state) 获得action,env(action) 的循环

while not done:state = torch.from_numpy(numpy_observation["agent_pos"])image = torch.from_numpy(numpy_observation["pixels"])state = state.to(torch.float32)image = image.to(torch.float32) / 255image = image.permute(2, 0, 1)state = state.to(device, non_blocking=True)image = image.to(device, non_blocking=True)state = state.unsqueeze(0)image = image.unsqueeze(0)observation = {"observation.state": state,"observation.image": image,}with torch.inference_mode():action = policy.select_action(observation)numpy_action = action.squeeze(0).to("cpu").numpy()numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)print(f"{step=} {reward=} {terminated=}")rewards.append(reward)frames.append(env.render())done = terminated | truncated | donestep += 1
if terminated:print("Success!")
else:print("Failure!")

最后还有一个可视化,将每帧图片连起来组成视频

video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

TODO插入视频

训练策略

训练策略的步骤也是比较固定,由于是offline training也是比较简单的:加载策略,加载数据集,加载优化器,前向传播,反向传播,更新参数

加载数据集

dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features=dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}delta_timestamps = {"observation.image": [-0.1, 0.0],"observation.state": [-0.1, 0.0],"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],}
dataset=LeRobotDataset("lerobot/pusht",delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(dataset,num_workers=4,batch_size=64,shuffle=True,pin_memory=device.type != "cpu",drop_last=True,)

加载策略

cfg=DiffusionConfig(input_features=input_features,output_features=output_features)
policy=DiffusionPolicy(cfg,dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)

加载优化器

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

前向传播,反向传播,更新参数

    step = 0done = Falsewhile not done:for batch in dataloader:batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}loss, _ = policy.forward(batch)loss.backward()optimizer.step()optimizer.zero_grad()if step % log_freq == 0:print(f"step: {step} loss: {loss.item():.3f}")step += 1if step >= training_steps:done = Truebreak

参考资料

https://huggingface.co/lerobot

后续

diffusion policy,act算法解析

http://www.dtcms.com/wzjs/142965.html

相关文章:

  • 利用jquery做音乐网站长沙推广引流
  • 欧洲美国韩国中国优化网络的软件下载
  • 武功县住房和城乡建设局网站系统优化方法
  • 唐山市政建设总公司网站2022年五月份热点事件
  • 深圳网站建设世纪前线买友情链接
  • 寮步仿做网站深圳英文网站推广
  • dede模板网站教程餐饮最有效的营销方案
  • 黄冈网站优化公司哪家好新东方英语线下培训学校
  • 济南网站建设泉诺一键优化下载
  • 找一家秦皇岛市做网站的公司西安百度关键词排名服务
  • 网站在哪里设置关键字腾讯广告平台
  • 楚雄网站开发rewlkj站长工具查询网站信息
  • 网站怎么做动态切图网络营销是指什么
  • 爱做的小说网站吗沈阳seo合作
  • 河南安阳市有几个县搜索引擎优化服务
  • newszone wordpress magazineseo快速优化文章排名
  • 卡地亚官方网站制作需要多少钱惠州seo推广优化
  • 主机网站建设制作seo关键词排名公司
  • 网站开发需要懂多少代码免费引流在线推广
  • 如何建立官网吉林网络seo
  • 自己做的产品在哪个网站上可从卖免费广告推广软件
  • 手机网站 需求模板培训平台
  • 企业网站运营问题域名注册新网
  • 创建网页用什么软件seo营销优化软件
  • iis 网站 优化自动seo优化
  • 上海免费网站建设品牌网络营销是干嘛的
  • 网站用户建设的设计与实现海口网站关键词优化
  • 海南的房产网站建设山东seo首页关键词优化
  • 重庆怎么推广企业网站百度搜索排名与点击有关吗
  • b2b网站排名前十简述如何优化网站的方法