
20250303 - Code Notes: class CVRPTester

Contents

  • Preface
  • 1. class CVRPTester: __init__(self, env_params, model_params, tester_params)
    • 1.1 Function Walkthrough
    • 1.2 Function Analysis
      • 1.2.1 Loading the Pretrained Model
    • 1.3 Function Code
  • 2. class CVRPTester: run(self)
    • Function Walkthrough
    • Function Code
  • 3. class CVRPTester: _test_one_batch(self, batch_size)
    • Function Walkthrough
    • Function Code
  • Appendix
    • Full Code


Preface

These notes study CVRPTester.py; the analysis of the code follows. The file lives at:

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py


1. class CVRPTester: __init__(self, env_params, model_params, tester_params)

1.1 Function Walkthrough

Execution flowchart (linked image)

1.2 Function Analysis

1.2.1 Loading the Pretrained Model

Code:

# Restore
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
  • model_load: a dictionary that specifies where to load the pretrained model from (the path) and which epoch to load.
model_load = tester_params['model_load']
  • checkpoint_fullname: built with Python string formatting; it expands to a file path such as /path/to/model/checkpoint-8100.pt. The format string requires the keys path and epoch.
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
  • Loading the model:
    • torch.load(checkpoint_fullname, map_location=device): loads the checkpoint (the .pt file) from disk into the checkpoint variable. map_location=device ensures the tensors are mapped onto the correct device (GPU or CPU).
    • self.model.load_state_dict(checkpoint['model_state_dict']): extracts the model's state dict from the checkpoint and loads it into self.model.
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
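
Before loading, it can help to inspect what a checkpoint actually stores. A small sketch continuing from the lines above; only the 'model_state_dict' key is confirmed by this code, so any other keys a trainer may have saved (optimizer state, epoch number) are an assumption:

# Peek at the checkpoint's top-level keys (map to CPU to avoid touching the GPU).
checkpoint = torch.load(checkpoint_fullname, map_location='cpu')
print(list(checkpoint.keys()))  # expect at least ['model_state_dict', ...]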

Example
Suppose tester_params_regret['model_load'] looks like this:

tester_params_regret = {
    'model_load': {
        'path': '../../pretrained/vrp100',
        'epoch': 8100,
    },
    # other parameters...
}

checkpoint_fullname is then built as ../../pretrained/vrp100/checkpoint-8100.pt, which resolves (relative to the POMO working directory) to /home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/vrp100/checkpoint-8100.pt, and the model is loaded from that file.
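
A quick check of the path construction, using the example values above (illustrative values, not from a real run):

model_load = {'path': '../../pretrained/vrp100', 'epoch': 8100}
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
print(checkpoint_fullname)  # ../../pretrained/vrp100/checkpoint-8100.pt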


1.3 Function Code

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()


        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()
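
For context, here is a minimal sketch of how the tester is typically driven. The tester_params keys mirror exactly what __init__, run(), and _test_one_batch() read; the env_params and model_params values are illustrative assumptions, not the repository's actual config:

env_params = {'problem_size': 100, 'pomo_size': 100}           # assumed CVRPEnv kwargs
model_params = {'embedding_dim': 128, 'encoder_layer_num': 6}  # assumed CVRPModel kwargs
tester_params = {
    'use_cuda': False,
    'cuda_device_num': 0,
    'model_load': {'path': '../../pretrained/vrp100', 'epoch': 8100},
    'test_data_load': {'enable': False, 'filename': ''},
    'test_episodes': 1000,
    'test_batch_size': 100,
    'augmentation_enable': True,
    'aug_factor': 8,
}

tester = CVRPTester(env_params=env_params,
                    model_params=model_params,
                    tester_params=tester_params)
tester.run()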


2. class CVRPTester: run(self)

Function Walkthrough

Execution flowchart (linked image)

Function Code

    def run(self):
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        if self.tester_params['test_data_load']['enable']:
            self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:

            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)

            score, aug_score = self._test_one_batch(batch_size)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))


3. class CVRPTester: _test_one_batch(self, batch_size)

Function Walkthrough

Execution flowchart (linked image)

Function Code

    def _test_one_batch(self, batch_size):

        # Augmentation
        ###############################################
        if self.tester_params['augmentation_enable']:
            aug_factor = self.tester_params['aug_factor']
        else:
            aug_factor = 1

        # Ready
        ###############################################
        self.model.eval()
        with torch.no_grad():
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            self.model.pre_forward(reset_state)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()
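
The scoring at the end is worth unpacking: rewards are negative tour lengths, so taking a max picks the shortest tour, and negating turns it back into a positive length. In POMO, aug_factor is typically 8 (the eight symmetries of the unit square). A toy demonstration with made-up rewards, not real rollout output:

import torch

aug_factor, batch_size, pomo_size = 2, 3, 4
# Fake rewards: negative tour lengths, shape (augmentation, batch, pomo).
aug_reward = -1.0 - torch.rand(aug_factor, batch_size, pomo_size)

max_pomo_reward, _ = aug_reward.max(dim=2)           # best POMO rollout per instance
no_aug_score = -max_pomo_reward[0, :].mean()         # mean tour length, no augmentation

max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # best result across augmentations
aug_score = -max_aug_pomo_reward.mean()              # mean tour length with augmentation

assert aug_score <= no_aug_score  # augmentation can only match or improve the score
print(no_aug_score.item(), aug_score.item())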


Appendix

Full Code


import torch

import os
from logging import getLogger

from CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Model

from utils.utils import *


class CVRPTester:
    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()


        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        if self.tester_params['test_data_load']['enable']:
            self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:

            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)

            score, aug_score = self._test_one_batch(batch_size)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

    def _test_one_batch(self, batch_size):

        # Augmentation
        ###############################################
        if self.tester_params['augmentation_enable']:
            aug_factor = self.tester_params['aug_factor']
        else:
            aug_factor = 1

        # Ready
        ###############################################
        self.model.eval()
        with torch.no_grad():
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            self.model.pre_forward(reset_state)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()
