当前位置: 首页 > news >正文

模仿学习模型diffusion_policy部署

首先下载diffusion_policy代码:https://github.com/real-stanford/diffusion_policy/tree/main

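修改 diffusion_policy/policy/diffusion_unet_lowdim_policy.py:将 conditional_sample 中的布尔索引赋值(trajectory[condition_mask] = ...)改为 torch.where,并新增 forward(obs, obs_mask) 方法,使策略模型可以直接用 torch.onnx.export 导出。修改后的文件如下: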

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator


class DiffusionUnetLowdimPolicy(BaseLowdimPolicy):
    """Diffusion policy over low-dimensional observations with a 1D U-Net.

    Modified from upstream diffusion_policy so it can be exported to ONNX:

    * conditioning in ``conditional_sample`` is applied with ``torch.where``
      instead of in-place boolean-index assignment;
    * ``forward(obs, obs_mask)`` runs the full sampling loop and returns a
      tuple of tensors, so the module can be handed to ``torch.onnx.export``.

    Observations reach the denoiser in one of three ways, chosen at
    construction time: as a local (per-timestep) condition, as a global
    condition, or (default) by inpainting them into the sampled trajectory.
    """

    def __init__(self,
            model: ConditionalUnet1D,
            noise_scheduler: DDPMScheduler,
            horizon,
            obs_dim,
            action_dim,
            n_action_steps,
            n_obs_steps,
            num_inference_steps=None,
            obs_as_local_cond=False,
            obs_as_global_cond=False,
            pred_action_steps_only=False,
            oa_step_convention=False,
            # parameters passed to step
            **kwargs):
        super().__init__()
        assert not (obs_as_local_cond and obs_as_global_cond)
        if pred_action_steps_only:
            # Predicting only the executed action steps leaves no room to
            # inpaint observations, so they must come in as a global condition.
            assert obs_as_global_cond
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if (obs_as_local_cond or obs_as_global_cond) else obs_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_local_cond = obs_as_local_cond
        self.obs_as_global_cond = obs_as_global_cond
        self.pred_action_steps_only = pred_action_steps_only
        self.oa_step_convention = oa_step_convention
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    # ========= inference  ============
    def conditional_sample(self,
            condition_data, condition_mask,
            local_cond=None, global_cond=None,
            generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Run the reverse diffusion loop starting from Gaussian noise.

        Entries where ``condition_mask`` is set are overwritten with
        ``condition_data`` before every denoising step and once more at the
        end, so the returned trajectory always agrees with the conditioning.
        """
        model = self.model
        scheduler = self.noise_scheduler

        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning. torch.where replaces the upstream
            # `trajectory[condition_mask] = condition_data[condition_mask]`
            # for ONNX export.
            trajectory = torch.where(condition_mask == 1, condition_data, trajectory)

            # 2. predict model output
            model_output = model(trajectory, t,
                local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory,
                generator=generator,
                **kwargs
                ).prev_sample

        # finally make sure conditioning is enforced
        trajectory = torch.where(condition_mask == 1, condition_data, trajectory)

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        obs_dict: must include "obs" key
        result: must include "action" key
        """

        assert 'obs' in obs_dict
        assert 'past_action' not in obs_dict # not implemented yet
        nobs = self.normalizer['obs'].normalize(obs_dict['obs'])
        B, _, Do = nobs.shape
        To = self.n_obs_steps
        assert Do == self.obs_dim
        T = self.horizon
        Da = self.action_dim

        # build input
        device = self.device
        dtype = self.dtype

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        if self.obs_as_local_cond:
            # condition through local feature
            # all zero except first To timesteps
            local_cond = torch.zeros(size=(B,T,Do), device=device, dtype=dtype)
            local_cond[:,:To] = nobs[:,:To]
            shape = (B, T, Da)
            cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        elif self.obs_as_global_cond:
            # condition throught global feature
            global_cond = nobs[:,:To].reshape(nobs.shape[0], -1)
            shape = (B, T, Da)
            if self.pred_action_steps_only:
                shape = (B, self.n_action_steps, Da)
            cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through impainting
            shape = (B, T, Da+Do)
            cond_data = torch.zeros(size=shape, device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:,:To,Da:] = nobs[:,:To]
            cond_mask[:,:To,Da:] = True

        # run sampling
        nsample = self.conditional_sample(
            cond_data,
            cond_mask,
            local_cond=local_cond,
            global_cond=global_cond,
            **self.kwargs)

        # unnormalize prediction
        naction_pred = nsample[...,:Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # get action
        if self.pred_action_steps_only:
            action = action_pred
        else:
            start = To
            if self.oa_step_convention:
                start = To - 1
            end = start + self.n_action_steps
            action = action_pred[:,start:end]

        result = {
            'action': action,
            'action_pred': action_pred
        }
        if not (self.obs_as_local_cond or self.obs_as_global_cond):
            # Only the inpainting mode carries observation channels in the
            # sampled trajectory (the last Do channels).
            nobs_pred = nsample[...,Da:]
            obs_pred = self.normalizer['obs'].unnormalize(nobs_pred)
            action_obs_pred = obs_pred[:,start:end]
            result['action_obs_pred'] = action_obs_pred
            result['obs_pred'] = obs_pred
        return result

    def forward(self, obs, obs_mask=None):
        """Export-friendly inference entry point (tensor in, tensors out).

        Args:
            obs: raw (unnormalized) observation tensor. ``predict_action``
                unpacks it as ``(B, _, Do)`` and requires ``Do == obs_dim``.
            obs_mask: accepted so callers can pass ``(obs, obs_mask)``;
                it is not used.

        Returns:
            ``(action, action_pred, action_obs_pred, obs_pred)`` when
            observations are inpainted into the trajectory; otherwise
            ``(action, action_pred)``. With local/global conditioning the
            sampled trajectory has no observation channels to return.
        """
        result = self.predict_action({'obs': obs})
        if 'obs_pred' in result:
            return (result['action'], result['action_pred'],
                    result['action_obs_pred'], result['obs_pred'])
        return result['action'], result['action_pred']

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the given normalizer's parameters into this policy's normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def compute_loss(self, batch):
        """Return the masked denoising MSE loss for one training batch."""
        # normalize input
        assert 'valid_mask' not in batch
        nbatch = self.normalizer.normalize(batch)
        obs = nbatch['obs']
        action = nbatch['action']

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = action
        if self.obs_as_local_cond:
            # zero out observations after n_obs_steps
            local_cond = obs
            local_cond[:,self.n_obs_steps:,:] = 0
        elif self.obs_as_global_cond:
            global_cond = obs[:,:self.n_obs_steps,:].reshape(
                obs.shape[0], -1)
            if self.pred_action_steps_only:
                To = self.n_obs_steps
                start = To
                if self.oa_step_convention:
                    start = To - 1
                end = start + self.n_action_steps
                trajectory = action[:,start:end]
        else:
            trajectory = torch.cat([action, obs], dim=-1)

        # generate impainting mask
        if self.pred_action_steps_only:
            condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        else:
            condition_mask = self.mask_generator(trajectory.shape)

        # Sample noise that we'll add to the images
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        bsz = trajectory.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (bsz,), device=trajectory.device
        ).long()
        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # compute loss mask
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = trajectory[condition_mask]

        # Predict the noise residual
        pred = self.model(noisy_trajectory, timesteps,
            local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        # conditioned (known) entries do not contribute to the loss
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss

编写脚本导出onnx模型:

import torch
import hydra
import dill
from diffusion_policy.workspace.base_workspace import BaseWorkspace

checkpoint = "data/0550-test_mean_score=0.969.ckpt"
output_dir = "data/pusht_eval_output"


def main():
    """Load a diffusion_policy checkpoint and export its policy to model.onnx."""
    # The checkpoint is a dill pickle that also carries the Hydra config, so it
    # must be loaded with pickle_module=dill. Only load checkpoints you trust.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects this payload; pass weights_only=False on those versions.
    with open(checkpoint, 'rb') as f:
        payload = torch.load(f, pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace = cls(cfg, output_dir=output_dir)
    workspace: BaseWorkspace
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    policy = workspace.model
    policy = policy.to("cuda")
    policy.eval()

    # Dummy input shaped (batch, n_obs_steps, obs_dim). The 2 and 20 presumably
    # match the PushT low-dim config; adjust them for other tasks.
    obs = torch.randn(56, 2, 20).to("cuda")
    # forward() ignores obs_mask; it is passed only to satisfy the signature.
    obs_mask = torch.ones_like(obs)

    # Map aten::lift_fresh (emitted for tensor constants while tracing) to an
    # identity so the exporter does not fail on it.
    torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, 13)
    with torch.no_grad():
        torch.onnx.export(policy, (obs, obs_mask), "model.onnx",
                          opset_version=13, input_names=["obs"])


if __name__ == "__main__":
    main()

onnxruntime推理脚本:

import numpy as np
import onnxruntime


def main():
    """Run model.onnx once with onnxruntime on a random observation batch."""
    onnx_session = onnxruntime.InferenceSession(
        "model.onnx",
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    input_name = [node.name for node in onnx_session.get_inputs()]
    output_name = [node.name for node in onnx_session.get_outputs()]

    # Same shape as the dummy observation used at export time.
    obs = np.random.randn(56, 2, 20).astype(np.float32)

    # Use the graph's actual input names rather than a hard-coded 'x.1', which
    # depends on the exporter version. If the unused obs_mask argument was kept
    # as a graph input, feed it a same-shape placeholder.
    inputs = {input_name[0]: obs}
    for name in input_name[1:]:
        inputs[name] = np.ones_like(obs)

    outputs = onnx_session.run(None, inputs)
    for name, value in zip(output_name, outputs):
        print(name, value.shape)
    print(outputs)


if __name__ == "__main__":
    main()

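TensorRT 推理脚本(需先把 ONNX 模型转换为引擎,例如 `trtexec --onnx=model.onnx --saveEngine=model.engine`;脚本中的 common 模块是 TensorRT 官方 samples/python 目录下的 common.py):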

import cv2  # NOTE(review): unused by this script; can be dropped if OpenCV is not installed
import numpy as np
import tensorrt as trt
import common


def main():
    """Deserialize model.engine and run one inference on a random observation batch."""
    logger = trt.Logger(trt.Logger.WARNING)
    # Keep all engine work inside the `with` so the runtime outlives the engine.
    with open("model.engine", "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            # deserialize_cuda_engine returns None on failure (e.g. a version mismatch).
            raise RuntimeError("failed to deserialize model.engine")
        context = engine.create_execution_context()
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)

        # Same shape as the observation used at ONNX export time.
        obs = np.random.randn(56, 2, 20).astype(np.float32)
        np.copyto(inputs[0].host, obs.ravel())
        output = common.do_inference(
            context,
            engine=engine,
            bindings=bindings,
            inputs=inputs,
            outputs=outputs,
            stream=stream,
        )
        print(output)


if __name__ == '__main__':
    main()

文章转载自:

http://gyE5C7Mr.Lmnbp.cn
http://ELAlnmG0.Lmnbp.cn
http://sIAy8Bgy.Lmnbp.cn
http://IoGIkJHJ.Lmnbp.cn
http://xONLEGO5.Lmnbp.cn
http://CYTojZ1g.Lmnbp.cn
http://XDMcQZxG.Lmnbp.cn
http://WFdyClon.Lmnbp.cn
http://nN2XA8pm.Lmnbp.cn
http://Mpuxdfhv.Lmnbp.cn
http://lpcGFGyX.Lmnbp.cn
http://wZWgKDyI.Lmnbp.cn
http://SIexSG8F.Lmnbp.cn
http://Ylfh3WIh.Lmnbp.cn
http://L1gCDWup.Lmnbp.cn
http://hIDt3Ic5.Lmnbp.cn
http://TVbaLhxB.Lmnbp.cn
http://dda8TUIC.Lmnbp.cn
http://RKeVvav9.Lmnbp.cn
http://GTdMj9L8.Lmnbp.cn
http://0nVhzf2D.Lmnbp.cn
http://WX9FTx2R.Lmnbp.cn
http://ZzMZer0l.Lmnbp.cn
http://ZcNHsguD.Lmnbp.cn
http://6A2l1jEp.Lmnbp.cn
http://8A767X43.Lmnbp.cn
http://w7LVB5Qz.Lmnbp.cn
http://IdVWhrjj.Lmnbp.cn
http://JrFVjkEC.Lmnbp.cn
http://YnRpK2YF.Lmnbp.cn
http://www.dtcms.com/a/365762.html

相关文章:

  • 宋红康 JVM 笔记 Day12|执行引擎
  • MySQL索引分类
  • 网络通信与协议栈 -- OSI,TCP/IP模型,协议族,UDP编程
  • GitLab Boards 深度解析:选型、竞品、成本与资源消耗
  • Python学习笔记--使用Django查询数据
  • 基于 HTML、CSS 和 JavaScript 的智能图像虚化系统
  • 年成本下降超80%,银行数据治理与自动化应用实录
  • 什么是Agent?小白如何学习使用Agent?一篇文档带你详细了解神秘的Agent
  • 正运动控制卡学习-网络连接
  • Git配置:禁用全局HTTPS验证
  • 【Unity UGUI介绍(0)】
  • 计算机组成原理(1:计算机系统组成)
  • 系统编程day2-系统调用
  • day4
  • 「数据获取」《吉林企业统计年鉴(2004)》(获取方式看绑定的资源)
  • 基于jmeter+perfmon的稳定性测试记录
  • logging:报告状态、错误和信息消息
  • Linux的墙上时钟和单调时钟的区别
  • 检查系统需求
  • 如何解决pip安装报错ModuleNotFoundError: No module named ‘isort’问题
  • Linux编程——网络编程(tcp)
  • 演员-评论员算法有何优点?
  • JavaScript原型与原型链:对象的家族传承系统
  • 3-7〔OSCP ◈ 研记〕❘ WEB应用攻击▸REST API概述
  • 漫谈《数字图像处理》之图像清晰化处理
  • 更新远程分支 git fetch
  • 计算机三级网络应用题大题技巧及练习题
  • 【微实验】使用MATLAB制作一张赛博古琴?
  • 最左匹配原则:复合索引 (a,b,c) 在 a=? AND b>? AND c=? 查询下的使用分析
  • 波浪模型SWAN学习(2)——波浪浅化模拟(Shoaling on sloping beach)