当前位置: 首页 > news >正文

20250226-代码笔记05-class CVRP_Decoder

文章目录

  • 前言
  • 一、class CVRP_Decoder(nn.Module):__init__(self, **model_params)
    • 函数功能
    • 函数代码
  • 二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)
    • 函数功能
    • 函数代码
  • 三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)
    • 函数功能
    • 函数代码
  • 四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)
    • 函数功能
    • 函数代码
  • 五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)
    • 函数功能
    • 函数代码
  • 附录
    • class CVRP_Decoder代码(全)


前言

class CVRP_DecoderCVRP_Model.py里的类。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRP_Decoder(nn.Module):init(self, **model_params)

函数功能

init 方法是 CVRP_Decoder 类中的构造函数,主要功能是初始化该类所需的所有网络层、权重矩阵和参数。
该方法设置了用于多头注意力机制的权重、一个用于表示"遗憾"的参数、以及其他必要的操作用于计算注意力权重。

执行流程图链接
在这里插入图片描述

函数代码

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention


二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)

函数功能

set_kv 方法的功能是将 encoded_nodes 中的节点嵌入转换为多头注意力机制所需的 键(K)值(V),并将它们分别保存为类的属性。
这个方法将输入的节点嵌入通过权重矩阵进行线性变换,得到键和值的表示,并为后续的多头注意力计算做好准备。
执行流程图链接
在这里插入图片描述

函数代码

    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)


三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)

函数功能

set_q1 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。
该方法接受输入的查询张量 encoded_q1,通过线性层 self.Wq_1 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力机制计算的形状。计算出的查询会被保存为类的属性 q1,供后续使用。

执行流程图链接
在这里插入图片描述

函数代码

    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)


四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)

函数功能

set_q2 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。
该方法接收输入的查询张量 encoded_q2,通过线性层 self.Wq_2 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力计算的形状。
执行流程图链接
在这里插入图片描述

函数代码

    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)


五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)

函数功能

forward 方法是 CVRP_Decoder 类中的前向传播函数,主要功能是执行 多头自注意力机制 和 单头注意力计算,并最终输出每个可能节点的选择概率(probs)。
该方法通过多头注意力计算、前馈神经网络处理,以及概率计算来进行节点选择。

执行流程图链接
在这里插入图片描述

函数代码

    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs


附录

class CVRP_Decoder代码(全)

class CVRP_Decoder(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention

    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)

    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape = (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs


http://www.dtcms.com/a/45473.html

相关文章:

  • 【无人机】无人机通信模块,无人机图数传模块的介绍,数传,图传,图传数传一体电台,
  • 什么是HA
  • keil主题(vscode风格)
  • ClickHouse
  • P1123 取数游戏
  • 实战案例:排查 Java 应用 CPU 飙高问题
  • 自由学习记录(40)
  • HFSS 仿真学习1 K波段定向耦合器
  • JAVA面试_进阶部分_netty面试题
  • 【Java】多线程篇 —— 多线程的基本使用
  • 58、深度学习-自学之路-自己搭建深度学习框架-19、RNN神经网络梯度消失和爆炸的原因(从公式推导方向来说明),通过RNN的前向传播和反向传播公式来理解。
  • 商城源码的框架
  • JAVA学习笔记038——bean的概念和常见注解标注
  • 计算机毕业设计SpringBoot+Vue.js体育馆使用预约平台(源码+文档+PPT+讲解)
  • Pytest之fixture的常见用法
  • AI人工智能机器学习之监督线性模型
  • 【广度优先搜索】图像渲染 岛屿数量
  • 7-1JVMCG垃圾回收
  • 【文献阅读】A Survey Of Resource-Efficient LLM And Multimodal Foundation Models
  • 如何保证 Redis 缓存和数据库的一致性?
  • 在编译Linux的内核镜像和模块时,必须先编译内核镜像,再编译模块,顺序不可随意调整的原因
  • 备战蓝桥杯Day11 DFS
  • React 常见面试题及答案
  • Mysql系统表
  • 【考试大纲】中级信息安全工程师考试大纲
  • HTMLS基本结构及标签
  • 神经网络之CNN图像识别(torch api 调用)
  • 建易WordPress
  • 算法-二叉树篇23-二叉搜索树中的插入操作
  • 夜天之书 #106 Apache 软件基金会如何投票选举?