当前位置: 首页 > news >正文

20250408-报错:pre_state = state同更新现象

项目场景:

基于强化学习解决组合优化问题
在这里插入图片描述


问题描述

        # POMO Rollout
        state, reward, done = self.env.pre_step()
        # next_state = state
        with autocast():
            while not done:
                # 执行动作并获取新状态和奖励
                
                selected, prob = self.model(state)
                
                # 更新状态:因为·self.env.step(selected)之后会自动更新state
                # 为了解决克隆state问题构建一个新的函数方法(.clone())专门克隆state所有向量
                pre_state = state
                state, reward, done = self.env.step(selected)

                # 存储当前的动作概率(用于后续计算比率)
                prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)
                old_prob_list = torch.cat((old_prob_list, prob[:, :, None]), dim=2)

                print(f"Encoded nodes in forward: {reward.shape}")

                # 计算当前状态的值函数 V(s)
                value = self.model_critic(pre_state)
                # 计算下一个状态的值函数 V(s')
                next_value = self.model_critic(state)

                # 计算优势函数
                advantage = reward + (1 - done.to(torch.int)) * next_value - value             
pre_state = state

出现与目标相反的现象:pre_state = state目的是保存state的前一步数据,因为state, reward, done = self.env.step(selected)会更新state的值。但随着state的更新pre_state也会更新。


原因分析:

state是类似全局变量的参数,可能pre_statestate指着同一个地址,这导致state更新pre_state也更新。


解决方案:

使用.clone()操作将他们分开到两个地址中。
但要注意的是state = Step_State()。其中 Step_State()为多向量数据。

@dataclass
class Step_State:
    BATCH_IDX: torch.Tensor = None      #表示批次的索引 # shape: (batch, pomo)
    POMO_IDX: torch.Tensor = None       #表示 POMO 算法中的多智能体索引 # shape: (batch, pomo)
    selected_count: int = None          #表示当前已经选中的节点数量 # shape: (batch, pomo)
    load: torch.Tensor = None           #表示当前负载状态 # shape: (batch, pomo)
    current_node: torch.Tensor = None   #表示当前正在访问的节点编号 # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None      #表示负无穷掩码    # shape: (batch, pomo, problem+1)
    

这里构建一个函数cloneStep_State()所有向量。

@dataclass
class Step_State:
    BATCH_IDX: torch.Tensor = None      #表示批次的索引 # shape: (batch, pomo)
    POMO_IDX: torch.Tensor = None       #表示 POMO 算法中的多智能体索引 # shape: (batch, pomo)
    selected_count: int = None          #表示当前已经选中的节点数量 # shape: (batch, pomo)
    load: torch.Tensor = None           #表示当前负载状态 # shape: (batch, pomo)
    current_node: torch.Tensor = None   #表示当前正在访问的节点编号 # shape: (batch, pomo)
    ninf_mask: torch.Tensor = None      #表示负无穷掩码    # shape: (batch, pomo, problem+1)
    
    def clone(self):
        # 克隆每个张量属性
        return Step_State(
            BATCH_IDX=self.BATCH_IDX.clone() if self.BATCH_IDX is not None else None,
            POMO_IDX=self.POMO_IDX.clone() if self.POMO_IDX is not None else None,
            selected_count=self.selected_count,  # 如果是整数,直接赋值即可
            load=self.load.clone() if self.load is not None else None,
            current_node=self.current_node.clone() if self.current_node is not None else None,
            ninf_mask=self.ninf_mask.clone() if self.ninf_mask is not None else None
        )

相关文章:

  • PortswiggerLab:Exploiting a mass assignment vulnerability
  • 【C#知识点详解】List<T>储存结构详解
  • Python 字典和集合(字典推导)
  • 美国NAB展会次日实况
  • C++ 基类的虚析构函数与派生的析构函数关系
  • Reflexion 框架 | 提示词工程(4)
  • HOW - React 组件渲染受其他无关数据影响和优化方案(含memo和props.children)
  • equals() 和 hashCode()
  • 泛目录站群,无极多功能泛目录站群程序:AI驱动的SEO增长引擎
  • java设计模式-单例模式
  • 【unity游戏开发入门到精通——动画篇】Animator2D序列帧动画
  • 解锁健康养生密码,拥抱活力人生
  • 手写数字识别实战教程:从零实现MNIST分类器(完整代码示例)
  • 算法篇(八)【递归】
  • 【代码随想录 字符串6.实现strstr】 KMP算法。
  • 1区6.6分CHARLS最新文章解读
  • 【学习笔记】文件上传漏洞--二次渲染、.htaccess、变异免杀
  • 2025年客运从业资格证备考刷题题库
  • 7-11 分段计算居民水费
  • 告别循环!用Stream优雅处理集合
  • flow创意博客wordpress主题/宁波seo公司网站推广
  • cms建站模板app/windows优化大师有哪些功能
  • 甘特图模板关于网站建设/怎么建个网站
  • 关于h5的网站模板/旧版优化大师
  • 安徽动漫公司 网站制作 安徽网新/aso优化什么意思
  • 网站域名备案密码/品牌运营策略