CVRPTester.py 代码解析(Regret-POMO / CVRP)
文章目录
- 前言
- 一、class CVRPTester:__init__(self,env_params,model_params, tester_params)
- 1.1函数解析
- 1.2函数分析
- 1.2.1加载预训练模型
- 1.3函数代码
- 二、class CVRPTester:run(self)
- 函数解析
- 函数代码
- 三、class CVRPTester:_test_one_batch(self, batch_size)
- 函数解析
- 函数代码
- 附录
- 代码(全)
前言
学习代码CVRPTester.py,对代码的分析如下。
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py
一、class CVRPTester:__init__(self, env_params, model_params, tester_params)
1.1函数解析
执行流程图链接
1.2函数分析
1.2.1加载预训练模型
代码:
# Restore: build the checkpoint path from tester_params['model_load'] and
# load the saved weights into self.model.
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
`model_load`:一个字典,给出预训练模型所在的目录(`path`)以及要加载的检查点对应的轮次(`epoch`):
model_load = tester_params['model_load']  # dict that must provide 'path' and 'epoch'
`checkpoint_fullname`:使用 Python 的字符串格式化(`str.format(**model_load)`)把 `model_load` 中的 `path` 和 `epoch` 填入模板,构造预训练模型的文件路径,形如 `/path/to/model/checkpoint-8100.pt`。因此 `model_load` 必须同时包含 `path` 和 `epoch` 两个键。
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)  # -> '<path>/checkpoint-<epoch>.pt'
- 加载模型:
`torch.load(checkpoint_fullname, map_location=device)`:从磁盘读取检查点文件(`.pt`),结果存入变量 `checkpoint`。`map_location=device` 会把检查点中保存的张量映射到指定设备(GPU 或 CPU)上,因此在 GPU 上保存的检查点也能在 CPU 上加载。
`self.model.load_state_dict(checkpoint['model_state_dict'])`:从检查点中取出模型的状态字典(键 `model_state_dict`),并将其加载到 `self.model` 中。
# Read the checkpoint file; map_location remaps its tensors onto `device`.
checkpoint = torch.load(checkpoint_fullname, map_location=device)
# Copy the saved weights (stored under 'model_state_dict') into the model.
self.model.load_state_dict(checkpoint['model_state_dict'])
示例
假设 `tester_params_regret['model_load']` 如下所示:
tester_params_regret = {
    'model_load': {
        'path': '../../pretrained/vrp100',
        'epoch': 8100,
    },
    # other parameters...
}
那么 `checkpoint_fullname` 会被构造为 `../../pretrained/vrp100/checkpoint-8100.pt`。这是一个相对路径,按运行脚本时的当前工作目录解析:若在 `CVRP/POMO` 目录下运行,它对应 `/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/vrp100/checkpoint-8100.pt`,模型会从该路径加载。
1.3函数代码
def __init__(self,
             env_params,
             model_params,
             tester_params):
    """Set up the tester: device, environment, model and restored weights.

    Args:
        env_params: keyword arguments forwarded to ``Env(**env_params)``.
        model_params: keyword arguments forwarded to ``Model(**model_params)``.
        tester_params: test configuration. This method reads ``'use_cuda'``,
            ``'cuda_device_num'`` (only when ``'use_cuda'`` is truthy) and
            ``'model_load'`` (a dict with ``'path'`` and ``'epoch'`` keys).
    """
    # save arguments
    self.env_params = env_params
    self.model_params = model_params
    self.tester_params = tester_params

    # result folder, logger
    # NOTE(review): the logger is named 'trainer' although this is the tester;
    # presumably to share the trainer's logging setup — confirm.
    self.logger = getLogger(name='trainer')
    self.result_folder = get_result_folder()

    # cuda
    USE_CUDA = self.tester_params['use_cuda']
    if USE_CUDA:
        cuda_device_num = self.tester_params['cuda_device_num']
        torch.cuda.set_device(cuda_device_num)
        device = torch.device('cuda', cuda_device_num)
        # Float tensors created without an explicit type become CUDA tensors.
        # NOTE: set_default_tensor_type is deprecated since PyTorch 2.1.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device('cpu')
        torch.set_default_tensor_type('torch.FloatTensor')
    self.device = device

    # ENV and MODEL
    self.env = Env(**self.env_params)
    self.model = Model(**self.model_params)

    # Restore: load '{path}/checkpoint-{epoch}.pt' and copy its weights in.
    model_load = tester_params['model_load']
    checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
    # map_location remaps tensors saved on another device onto `device`.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_fullname, map_location=device)
    self.model.load_state_dict(checkpoint['model_state_dict'])

    # utility
    self.time_estimator = TimeEstimator()
二、class CVRPTester:run(self)
函数解析
函数执行流程图链接
函数代码
def run(self):
    """Evaluate the model on ``tester_params['test_episodes']`` instances.

    Instances are processed in batches of at most
    ``tester_params['test_batch_size']``. Each batch's scores are logged, and
    once every episode is done the accumulated averages (without and with
    augmentation) are logged.

    Raises:
        ValueError: if ``tester_params['test_batch_size']`` is smaller than 1,
            which would otherwise make the batching loop never terminate.
    """
    # Validate first: a batch size of 0 would never advance `episode`.
    test_batch_size = self.tester_params['test_batch_size']
    if test_batch_size < 1:
        raise ValueError(
            "tester_params['test_batch_size'] must be >= 1, got {}".format(test_batch_size))

    self.time_estimator.reset()

    score_AM = AverageMeter()
    aug_score_AM = AverageMeter()

    # Optionally make the env use problems stored in a file.
    if self.tester_params['test_data_load']['enable']:
        self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

    test_num_episode = self.tester_params['test_episodes']
    episode = 0

    while episode < test_num_episode:

        # The last batch may be smaller than test_batch_size.
        remaining = test_num_episode - episode
        batch_size = min(test_batch_size, remaining)

        score, aug_score = self._test_one_batch(batch_size)

        score_AM.update(score, batch_size)
        aug_score_AM.update(aug_score, batch_size)

        episode += batch_size

        ############################
        # Logs
        ############################
        elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
        self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
            episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

        all_done = (episode == test_num_episode)

        if all_done:
            self.logger.info(" *** Test Done *** ")
            self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
            self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))
三、class CVRPTester:_test_one_batch(self, batch_size)
函数解析
执行流程图链接
函数代码
def _test_one_batch(self, batch_size):
    """Roll out the model on one batch and return its (no-aug, aug) scores.

    Args:
        batch_size: number of problem instances in this batch.

    Returns:
        Tuple ``(no_aug_score, aug_score)`` of Python floats:
        ``no_aug_score`` is the batch mean of the negated best POMO reward
        using only augmentation 0; ``aug_score`` is the batch mean of the
        negated best reward over all POMO rollouts and all augmentations.
    """
    # Augmentation
    ###############################################
    if self.tester_params['augmentation_enable']:
        aug_factor = self.tester_params['aug_factor']
    else:
        aug_factor = 1

    self.model.eval()
    # Evaluation needs no autograd graph, so the whole rollout (not only the
    # setup) runs under no_grad; this saves memory without changing results.
    with torch.no_grad():
        # Ready
        ###############################################
        self.env.load_problems(batch_size, aug_factor)
        reset_state, _, _ = self.env.reset()
        self.model.pre_forward(reset_state)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

    # Return
    ###############################################
    # NOTE(review): assumes the env lays out rewards augmentation-major,
    # i.e. row index = aug * batch_size + instance — confirm in CVRPEnv.
    aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
    # shape: (augmentation, batch, pomo)

    max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
    # shape: (augmentation, batch)
    no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

    max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
    # shape: (batch,)
    aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

    return no_aug_score.item(), aug_score.item()
附录
代码(全)
import torch
import os

from logging import getLogger

from CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Model

from utils.utils import *


class CVRPTester:
    """Test a CVRP model whose weights are restored from a checkpoint."""

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):
        """Set up device, environment, model and restored weights.

        Args:
            env_params: keyword arguments forwarded to ``Env(**env_params)``.
            model_params: keyword arguments forwarded to ``Model(**model_params)``.
            tester_params: test configuration. This method reads
                ``'use_cuda'``, ``'cuda_device_num'`` (only when ``'use_cuda'``
                is truthy) and ``'model_load'`` (dict with ``'path'``/``'epoch'``).
        """
        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        # NOTE(review): logger is named 'trainer' in the tester; presumably to
        # share the trainer's logging setup — confirm.
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()

        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            # Float tensors created without an explicit type become CUDA tensors.
            # NOTE: set_default_tensor_type is deprecated since PyTorch 2.1.
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore: load '{path}/checkpoint-{epoch}.pt' and copy its weights in.
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        # map_location remaps tensors saved on another device onto `device`.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        """Evaluate on ``tester_params['test_episodes']`` instances in batches.

        Each batch's scores are logged; when all episodes are done the
        accumulated averages (without and with augmentation) are logged.

        Raises:
            ValueError: if ``tester_params['test_batch_size']`` is smaller
                than 1, which would otherwise loop forever.
        """
        # Validate first: a batch size of 0 would never advance `episode`.
        test_batch_size = self.tester_params['test_batch_size']
        if test_batch_size < 1:
            raise ValueError(
                "tester_params['test_batch_size'] must be >= 1, got {}".format(test_batch_size))

        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        # Optionally make the env use problems stored in a file.
        if self.tester_params['test_data_load']['enable']:
            self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:

            # The last batch may be smaller than test_batch_size.
            remaining = test_num_episode - episode
            batch_size = min(test_batch_size, remaining)

            score, aug_score = self._test_one_batch(batch_size)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

    def _test_one_batch(self, batch_size):
        """Roll out the model on one batch; return ``(no_aug_score, aug_score)``.

        ``no_aug_score`` is the batch mean of the negated best POMO reward for
        augmentation 0; ``aug_score`` additionally takes the best over all
        augmentations. Both are returned as Python floats.
        """
        # Augmentation
        ###############################################
        if self.tester_params['augmentation_enable']:
            aug_factor = self.tester_params['aug_factor']
        else:
            aug_factor = 1

        self.model.eval()
        # The whole rollout (not only the setup) runs under no_grad: evaluation
        # needs no autograd graph, so this saves memory without changing results.
        with torch.no_grad():
            # Ready
            ###############################################
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            self.model.pre_forward(reset_state)

            # POMO Rollout
            ###############################################
            state, reward, done = self.env.pre_step()
            while not done:
                selected, _ = self.model(state)
                # shape: (batch, pomo)
                state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        # NOTE(review): assumes rewards are laid out augmentation-major
        # (row = aug * batch_size + instance) — confirm in CVRPEnv.
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()