
EGPO: KeyError: "replay_sequence_length" when running train_egpo training

Running train_egpo from the EGPO codebase fails while execution_plan builds the local replay buffer, raising KeyError: "replay_sequence_length". The fix that worked here is to comment out the line that reads config["replay_sequence_length"], apparently because the trainer config under Ray 1.4.1 does not contain that key.
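Since the traceback is a KeyError rather than a TypeError about an unexpected keyword argument, the failure most likely happens at the config lookup itself: the trainer config simply has no "replay_sequence_length" entry, so LocalReplayBuffer is never even constructed. A minimal sketch of the failure mode (illustrative values only):

# Sketch of the failure mode: the trainer config is an ordinary dict,
# so a missing key raises at the lookup site.
config = {"buffer_size": 50000}       # no "replay_sequence_length" entry
config["replay_sequence_length"]      # KeyError: 'replay_sequence_length'
config.get("replay_sequence_length")  # None -- .get() would not raise

The patched execution_plan, with the offending line commented out, follows.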

# Imports as laid out in Ray 1.4.1; UpdateSaverPenalty comes from the EGPO
# codebase itself.
from ray.rllib.agents.dqn.dqn import calculate_rr_weights
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator


def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}

    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        # The next line must stay commented out; with it, the whole script
        # fails to run, probably because this key does not exist in the
        # Ray 1.4.1 config.
        # replay_sequence_length=config["replay_sequence_length"],
        **prio_args)

    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Update penalty.
    rollouts = rollouts.for_each(UpdateSaverPenalty(workers))

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer.
    # Calling next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get(
                    "td_error", info[LEARNER_STATS_KEY].get("td_error"))
                prio_dict[policy_id] = (
                    samples.policy_batches[policy_id].data.get(
                        "batch_indexes"), td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take an SGD step, and then we decide whether to update the target
    # network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2). Only return the output
    # of (2) since training metrics are not available until (2) runs.
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
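The Concurrently step at the end is what interleaves (1) and (2). A toy sketch (illustrative stand-in values, no rollouts involved) shows the round_robin and output_indexes behavior with plain ray.util.iter iterators:

import ray
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.util.iter import from_items

ray.init()

# Two trivial stand-ins for store_op and replay_op.
store = from_items(["store"] * 4, num_shards=1).gather_sync()
train = from_items(["train"] * 4, num_shards=1).gather_sync()

# Alternate deterministically between the two ops; output_indexes=[1] means
# only the second op's items are yielded, mirroring how execution_plan
# returns training metrics only.
op = Concurrently([store, train], mode="round_robin", output_indexes=[1])
print(op.take(4))  # expected: ['train', 'train', 'train', 'train']

round_robin_weights, computed by calculate_rr_weights from the config, skews this alternation so that storing and training can proceed at different rates.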
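Rather than remembering which line to comment out on each Ray version, a more defensive variant is possible. The sketch below (make_local_replay_buffer is a hypothetical helper, not part of EGPO or RLlib) forwards replay_sequence_length only when the config defines it and the installed LocalReplayBuffer accepts it:

import inspect

from ray.rllib.execution.replay_buffer import LocalReplayBuffer

def make_local_replay_buffer(config, prio_args):
    # The same kwargs the execution_plan above already passes.
    kwargs = dict(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        **prio_args)
    # Hypothetical guard: pass replay_sequence_length only if the config
    # has the key and this Ray version's constructor takes the argument.
    params = inspect.signature(LocalReplayBuffer.__init__).parameters
    if "replay_sequence_length" in config and \
            "replay_sequence_length" in params:
        kwargs["replay_sequence_length"] = config["replay_sequence_length"]
    return LocalReplayBuffer(**kwargs)

Dropping this helper into execution_plan in place of the direct constructor call lets the same file run whether or not the trainer config defines replay_sequence_length.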
