当前位置: 首页 > wzjs >正文

东莞市建设企业网站企业seo学院

东莞市建设企业网站企业,seo学院,网络营销方式内容角度,网站开发层次文章目录 前言一、class CVRPTester:__init__(self,env_params,model_params, tester_params)1.1函数解析1.2函数分析1.2.1加载预训练模型 1.2函数代码 二、class CVRPTester:run(self)函数解析函数代码 三、class CVRPTester:_test_one_batch(self, batch_size)函数解析函数代…

文章目录

  • 前言
  • 一、class CVRPTester:__init__(self,env_params,model_params, tester_params)
    • 1.1函数解析
    • 1.2函数分析
      • 1.2.1加载预训练模型
    • 1.2函数代码
  • 二、class CVRPTester:run(self)
    • 函数解析
    • 函数代码
  • 三、class CVRPTester:_test_one_batch(self, batch_size)
    • 函数解析
    • 函数代码
  • 附录
    • 代码(全)


前言

学习代码CVRPTester.py,对代码的分析如下。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py


一、class CVRPTester:init(self,env_params,model_params, tester_params)

1.1函数解析

执行流程图链接

在这里插入图片描述

1.2函数分析

1.2.1加载预训练模型

代码:

# Restore
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
  • model_load: 这是一个字典,包含了从哪里加载预训练模型的路径信息以及具体的 epoch
model_load = tester_params['model_load']
  • checkpoint_fullname: 使用 Python 的字符串格式化功能,构造预训练模型的文件路径。
    这会生成形如 /path/to/model/checkpoint-8100.pt 的文件路径。即需要输入参数pathepoch
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
  • 加载模型:
    • torch.load(checkpoint_fullname, map_location=device):从磁盘加载模型检查点(即 .pt 文件),并将其存储在 checkpoint 变量中。map_location=device 确保模型会被加载到正确的设备上(GPU 或 CPU)。
    • self.model.load_state_dict(checkpoint['model_state_dict']):从加载的检查点中提取模型的状态字典,并将其加载到 self.model 中。
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])

示例
假设 tester_params_regret[‘model_load’] 如下所示:

tester_params_regret = {'model_load': {'path': '../../pretrained/vrp100','epoch': 8100,},# 其他参数...
}

然后 checkpoint_fullname 会被构造为/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/models/checkpoint-8100.pt,模型会从该路径加载。
在这里插入图片描述


1.2函数代码

    def __init__(self,env_params,model_params,tester_params):# save argumentsself.env_params = env_paramsself.model_params = model_paramsself.tester_params = tester_params# result folder, loggerself.logger = getLogger(name='trainer')self.result_folder = get_result_folder()# cudaUSE_CUDA = self.tester_params['use_cuda']if USE_CUDA:cuda_device_num = self.tester_params['cuda_device_num']torch.cuda.set_device(cuda_device_num)device = torch.device('cuda', cuda_device_num)torch.set_default_tensor_type('torch.cuda.FloatTensor')else:device = torch.device('cpu')torch.set_default_tensor_type('torch.FloatTensor')self.device = device# ENV and MODELself.env = Env(**self.env_params)self.model = Model(**self.model_params)# Restoremodel_load = tester_params['model_load']checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)checkpoint = torch.load(checkpoint_fullname, map_location=device)self.model.load_state_dict(checkpoint['model_state_dict'])# utilityself.time_estimator = TimeEstimator()

二、class CVRPTester:run(self)

函数解析

函数执行流程图链接
在这里插入图片描述

函数代码

    def run(self):self.time_estimator.reset()score_AM = AverageMeter()aug_score_AM = AverageMeter()if self.tester_params['test_data_load']['enable']:self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)test_num_episode = self.tester_params['test_episodes']episode = 0while episode < test_num_episode:remaining = test_num_episode - episodebatch_size = min(self.tester_params['test_batch_size'], remaining)score, aug_score = self._test_one_batch(batch_size)score_AM.update(score, batch_size)aug_score_AM.update(aug_score, batch_size)episode += batch_size############################# Logs############################elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))all_done = (episode == test_num_episode)if all_done:self.logger.info(" *** Test Done *** ")self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

三、class CVRPTester:_test_one_batch(self, batch_size)

函数解析

执行流程图链接
在这里插入图片描述

函数代码

    def _test_one_batch(self, batch_size):# Augmentation###############################################if self.tester_params['augmentation_enable']:aug_factor = self.tester_params['aug_factor']else:aug_factor = 1# Ready###############################################self.model.eval()with torch.no_grad():self.env.load_problems(batch_size, aug_factor)reset_state, _, _ = self.env.reset()self.model.pre_forward(reset_state)# POMO Rollout###############################################state, reward, done = self.env.pre_step()while not done:selected, _ = self.model(state)# shape: (batch, pomo)state, reward, done = self.env.step(selected)# Return###############################################aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)# shape: (augmentation, batch, pomo)max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo# shape: (augmentation, batch)no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive valuemax_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation# shape: (batch,)aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive valuereturn no_aug_score.item(), aug_score.item()

附录

代码(全)


import torchimport os
from logging import getLoggerfrom CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Modelfrom utils.utils import *class CVRPTester:def __init__(self,env_params,model_params,tester_params):# save argumentsself.env_params = env_paramsself.model_params = model_paramsself.tester_params = tester_params# result folder, loggerself.logger = getLogger(name='trainer')self.result_folder = get_result_folder()# cudaUSE_CUDA = self.tester_params['use_cuda']if USE_CUDA:cuda_device_num = self.tester_params['cuda_device_num']torch.cuda.set_device(cuda_device_num)device = torch.device('cuda', cuda_device_num)torch.set_default_tensor_type('torch.cuda.FloatTensor')else:device = torch.device('cpu')torch.set_default_tensor_type('torch.FloatTensor')self.device = device# ENV and MODELself.env = Env(**self.env_params)self.model = Model(**self.model_params)# Restoremodel_load = tester_params['model_load']checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)checkpoint = torch.load(checkpoint_fullname, map_location=device)self.model.load_state_dict(checkpoint['model_state_dict'])# utilityself.time_estimator = TimeEstimator()def run(self):self.time_estimator.reset()score_AM = AverageMeter()aug_score_AM = AverageMeter()if self.tester_params['test_data_load']['enable']:self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)test_num_episode = self.tester_params['test_episodes']episode = 0while episode < test_num_episode:remaining = test_num_episode - episodebatch_size = min(self.tester_params['test_batch_size'], remaining)score, aug_score = self._test_one_batch(batch_size)score_AM.update(score, batch_size)aug_score_AM.update(aug_score, batch_size)episode += batch_size############################# Logs############################elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))all_done = (episode == test_num_episode)if all_done:self.logger.info(" *** Test Done *** ")self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))def _test_one_batch(self, batch_size):# Augmentation###############################################if self.tester_params['augmentation_enable']:aug_factor = self.tester_params['aug_factor']else:aug_factor = 1# Ready###############################################self.model.eval()with torch.no_grad():self.env.load_problems(batch_size, aug_factor)reset_state, _, _ = self.env.reset()self.model.pre_forward(reset_state)# POMO Rollout###############################################state, reward, done = self.env.pre_step()while not done:selected, _ = self.model(state)# shape: (batch, pomo)state, reward, done = self.env.step(selected)# Return###############################################aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)# shape: (augmentation, batch, pomo)max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo# shape: (augmentation, batch)no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive valuemax_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation# shape: (batch,)aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive valuereturn no_aug_score.item(), aug_score.item()
http://www.dtcms.com/wzjs/519536.html

相关文章:

  • 哪家公司做网站最好适合40岁女人的培训班
  • 网站建设 苏州免费seo教程资源
  • 有没有专门做蛋糕的网站代写文章接单平台
  • 做长直播的房地产网站广州营销课程培训班
  • 网站开发做网站谷歌google官方下载
  • 连云港做网站制作搜索引擎营销总结
  • 怎样才能建网站广州新闻24小时爆料热线
  • 优秀的网站建设策划书百度助手官网
  • wordpress 密码破击重庆seo管理平台
  • 三维家装设计软件网站优化推广外包
  • 在百度上做网站推广效果怎么样360推广联盟
  • 内蒙古网络自学网站建设今日新闻联播
  • 网站广东省备案系统上海站优云网络科技有限公司
  • 没有rss源的网站如何做rss订阅seo排名优化表格工具
  • 网站怎么添加百度商桥seo推广软件下载
  • 网站设计页面如何做居中今天新闻头条新闻
  • 四川红叶建设有限公司网站成都优化网站哪家公司好
  • 个人微信号做网站行吗百度视频广告怎么投放
  • web前端开发的软件关键词seo排名怎么样
  • 二手车为什么做网站陕西优化疫情防控措施
  • 高唐网站制作网络营销推广的方法
  • 北京运营推广网站建设百度收录官网
  • 网站跟app区别网络技术培训
  • wordpress 做分销株洲seo
  • 平面设计做兼职网站seo专业培训需要多久
  • flash网站案例今日头条官网登录入口
  • 个人网站的设计与实现毕业论文参考文献关键字排名优化工具
  • 河南做网站企起企业网站建设步骤
  • 如何建立公司网站多少钱百度com打开
  • 淄博网站建设优化运营怎么给客户推广自己的产品