当前位置: 首页 > wzjs >正文

建设移动网站河南网络推广公司

建设移动网站,河南网络推广公司,建设网银官网,公司就我一个设计文章目录 前言一、class CVRPTester:__init__(self,env_params,model_params, tester_params)1.1函数解析1.2函数分析1.2.1加载预训练模型 1.2函数代码 二、class CVRPTester:run(self)函数解析函数代码 三、class CVRPTester:_test_one_batch(self, batch_size)函数解析函数代…

文章目录

  • 前言
  • 一、class CVRPTester:__init__(self,env_params,model_params, tester_params)
    • 1.1函数解析
    • 1.2函数分析
      • 1.2.1加载预训练模型
    • 1.2函数代码
  • 二、class CVRPTester:run(self)
    • 函数解析
    • 函数代码
  • 三、class CVRPTester:_test_one_batch(self, batch_size)
    • 函数解析
    • 函数代码
  • 附录
    • 代码(全)


前言

学习代码CVRPTester.py,对代码的分析如下。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py


一、class CVRPTester:init(self,env_params,model_params, tester_params)

1.1函数解析

执行流程图链接

在这里插入图片描述

1.2函数分析

1.2.1加载预训练模型

代码:

# Restore
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
  • model_load: 这是一个字典,包含了从哪里加载预训练模型的路径信息以及具体的 epoch
model_load = tester_params['model_load']
  • checkpoint_fullname: 使用 Python 的字符串格式化功能,构造预训练模型的文件路径。
    这会生成形如 /path/to/model/checkpoint-8100.pt 的文件路径。即需要输入参数pathepoch
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
  • 加载模型:
    • torch.load(checkpoint_fullname, map_location=device):从磁盘加载模型检查点(即 .pt 文件),并将其存储在 checkpoint 变量中。map_location=device 确保模型会被加载到正确的设备上(GPU 或 CPU)。
    • self.model.load_state_dict(checkpoint['model_state_dict']):从加载的检查点中提取模型的状态字典,并将其加载到 self.model 中。
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])

示例
假设 tester_params_regret[‘model_load’] 如下所示:

tester_params_regret = {'model_load': {'path': '../../pretrained/vrp100','epoch': 8100,},# 其他参数...
}

然后 checkpoint_fullname 会被构造为/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/models/checkpoint-8100.pt,模型会从该路径加载。
在这里插入图片描述


1.2函数代码

    def __init__(self,env_params,model_params,tester_params):# save argumentsself.env_params = env_paramsself.model_params = model_paramsself.tester_params = tester_params# result folder, loggerself.logger = getLogger(name='trainer')self.result_folder = get_result_folder()# cudaUSE_CUDA = self.tester_params['use_cuda']if USE_CUDA:cuda_device_num = self.tester_params['cuda_device_num']torch.cuda.set_device(cuda_device_num)device = torch.device('cuda', cuda_device_num)torch.set_default_tensor_type('torch.cuda.FloatTensor')else:device = torch.device('cpu')torch.set_default_tensor_type('torch.FloatTensor')self.device = device# ENV and MODELself.env = Env(**self.env_params)self.model = Model(**self.model_params)# Restoremodel_load = tester_params['model_load']checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)checkpoint = torch.load(checkpoint_fullname, map_location=device)self.model.load_state_dict(checkpoint['model_state_dict'])# utilityself.time_estimator = TimeEstimator()

二、class CVRPTester:run(self)

函数解析

函数执行流程图链接
在这里插入图片描述

函数代码

    def run(self):self.time_estimator.reset()score_AM = AverageMeter()aug_score_AM = AverageMeter()if self.tester_params['test_data_load']['enable']:self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)test_num_episode = self.tester_params['test_episodes']episode = 0while episode < test_num_episode:remaining = test_num_episode - episodebatch_size = min(self.tester_params['test_batch_size'], remaining)score, aug_score = self._test_one_batch(batch_size)score_AM.update(score, batch_size)aug_score_AM.update(aug_score, batch_size)episode += batch_size############################# Logs############################elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))all_done = (episode == test_num_episode)if all_done:self.logger.info(" *** Test Done *** ")self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

三、class CVRPTester:_test_one_batch(self, batch_size)

函数解析

执行流程图链接
在这里插入图片描述

函数代码

    def _test_one_batch(self, batch_size):# Augmentation###############################################if self.tester_params['augmentation_enable']:aug_factor = self.tester_params['aug_factor']else:aug_factor = 1# Ready###############################################self.model.eval()with torch.no_grad():self.env.load_problems(batch_size, aug_factor)reset_state, _, _ = self.env.reset()self.model.pre_forward(reset_state)# POMO Rollout###############################################state, reward, done = self.env.pre_step()while not done:selected, _ = self.model(state)# shape: (batch, pomo)state, reward, done = self.env.step(selected)# Return###############################################aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)# shape: (augmentation, batch, pomo)max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo# shape: (augmentation, batch)no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive valuemax_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation# shape: (batch,)aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive valuereturn no_aug_score.item(), aug_score.item()

附录

代码(全)


import torchimport os
from logging import getLoggerfrom CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Modelfrom utils.utils import *class CVRPTester:def __init__(self,env_params,model_params,tester_params):# save argumentsself.env_params = env_paramsself.model_params = model_paramsself.tester_params = tester_params# result folder, loggerself.logger = getLogger(name='trainer')self.result_folder = get_result_folder()# cudaUSE_CUDA = self.tester_params['use_cuda']if USE_CUDA:cuda_device_num = self.tester_params['cuda_device_num']torch.cuda.set_device(cuda_device_num)device = torch.device('cuda', cuda_device_num)torch.set_default_tensor_type('torch.cuda.FloatTensor')else:device = torch.device('cpu')torch.set_default_tensor_type('torch.FloatTensor')self.device = device# ENV and MODELself.env = Env(**self.env_params)self.model = Model(**self.model_params)# Restoremodel_load = tester_params['model_load']checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)checkpoint = torch.load(checkpoint_fullname, map_location=device)self.model.load_state_dict(checkpoint['model_state_dict'])# utilityself.time_estimator = TimeEstimator()def run(self):self.time_estimator.reset()score_AM = AverageMeter()aug_score_AM = AverageMeter()if self.tester_params['test_data_load']['enable']:self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)test_num_episode = self.tester_params['test_episodes']episode = 0while episode < test_num_episode:remaining = test_num_episode - episodebatch_size = min(self.tester_params['test_batch_size'], remaining)score, aug_score = self._test_one_batch(batch_size)score_AM.update(score, batch_size)aug_score_AM.update(aug_score, batch_size)episode += batch_size############################# Logs############################elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))all_done = (episode == test_num_episode)if all_done:self.logger.info(" *** Test Done *** ")self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))def _test_one_batch(self, batch_size):# Augmentation###############################################if self.tester_params['augmentation_enable']:aug_factor = self.tester_params['aug_factor']else:aug_factor = 1# Ready###############################################self.model.eval()with torch.no_grad():self.env.load_problems(batch_size, aug_factor)reset_state, _, _ = self.env.reset()self.model.pre_forward(reset_state)# POMO Rollout###############################################state, reward, done = self.env.pre_step()while not done:selected, _ = self.model(state)# shape: (batch, pomo)state, reward, done = self.env.step(selected)# Return###############################################aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)# shape: (augmentation, batch, pomo)max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo# shape: (augmentation, batch)no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive valuemax_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation# shape: (batch,)aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive valuereturn no_aug_score.item(), aug_score.item()
http://www.dtcms.com/wzjs/840190.html

相关文章:

  • 建设银行企业网站进不去php手机wap网站源码
  • 郑州汉狮做网站网络公司中国公共招聘网
  • 动态速写网站北京建设网坡屋顶改造工程指标
  • cms网站开发个人网站不能备案
  • 彩票站自己做网站昆明网站服务器
  • 高明网站建设产品设计留学哪个国家好
  • 广州市手机网站建设平台网站建设购买模板
  • 千岛湖建设集团有限公司网站网站开发过程的需求分析
  • 自己做app的网站安卓优化大师手机版下载
  • 做租房网站可信网站认证费用
  • 儿童网站模板网站托管服务商查询
  • wpf做网站如何做个小程序自己卖货
  • wordpress sns上海营销seo
  • 西安做网站公司哪家好怎么在凡科上做网站
  • 最好企业网站千岛湖建设集团办公网站
  • 商业设计网站有哪些如何做线下推广
  • 婚宴网站源码龙岩正规招聘网
  • 北京监理建设协会网站网站设计的就业和发展前景
  • 有没有专门做ppt的网站湖南省水运建设投资集团网站
  • 广东网站设计费用一个网站建设的组成
  • 用网站做邮箱wordpress页面编辑乱码
  • 昭通网站建设网站规划管理系统
  • 能打开各种网站的浏览器下载合集如何把网站放到空间别人可以访问
  • 网站建设大概价格国外的响应式网站模板
  • 遵义做网站的网络公司wordpress下载的插件怎么用
  • 游戏开发与网站开发哪个难娄底网站建设企业
  • 闸北微信网站建设网站设计模版免费下载
  • 河南网站建设多少钱网站建设公司挣钱吗
  • 公司网站开发实施方案外贸网站改版公司哪家好
  • 织梦手机wap网站标签调用垂直网站