当前位置: 首页 > news >正文

20250412_代码笔记_CVRProblemDef

文章目录

  • 前言
  • 一、get_random_problems 函数分析
  • 二、augment_xy_data_by_8_fold 函数分析
  • 代码


前言

该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。

对该代码进行改进
20250412-代码改进-拟蒙特卡洛


一、get_random_problems 函数分析

depot_xy = torch.rand(size=(batch_size, 1, 2))
  • 生成仓库坐标:
    • 生成形状为(batch_size, 1, 2) 的随机张量,表示每个批次中仓库的二维坐标(范围 [0,1))。
node_xy = torch.rand(size=(batch_size, problem_size, 2))
  • 生成节点坐标:
    • 生成形状为 (batch_size, problem_size, 2) 的随机张量,表示每个批次中所有节点的二维坐标。
if problem_size == 20:
    demand_scaler = 30
elif problem_size == 50:
    demand_scaler = 40
elif problem_size == 100:
    demand_scaler = 50
node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler
  • 生成节点需求:
    • 根据 problem_size 选择缩放因子 demand_scaler
    • 生成 1~9 的整数需求,并缩放到 [1/50, 9/50] 等区间,确保需求值为浮点数。

二、augment_xy_data_by_8_fold 函数分析

功能:通过8种几何变换对坐标数据进行增强,扩充数据集。

x = xy_data[:, :, [0]]  # 提取x坐标
y = xy_data[:, :, [1]]  # 提取y坐标
  • 拆分坐标:
    • 从输入数据 xy_data(形状 (batch, N, 2))分离出x和y分量。
dat1 = torch.cat((x, y), dim=2)          # 原始坐标
dat2 = torch.cat((1 - x, y), dim=2)      # x轴镜像
dat3 = torch.cat((x, 1 - y), dim=2)      # y轴镜像
dat4 = torch.cat((1 - x, 1 - y), dim=2)  # x+y轴镜像
dat5 = torch.cat((y, x), dim=2)          # 转置坐标
dat6 = torch.cat((1 - y, x), dim=2)      # 转置后x轴镜像
dat7 = torch.cat((y, 1 - x), dim=2)      # 转置后y轴镜像
dat8 = torch.cat((1 - y, 1 - x), dim=2)  # 转置后x+y轴镜像
  • 生成8种变换:
    • 对坐标进行镜像翻转和转置操作,生成8种变体。
aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)
  • 合并增强数据:
  • 将8种变换后的数据沿批次维度拼接,最终形状为 (8*batch, N, 2)

代码


import torch
import numpy as np


def get_random_problems(batch_size, problem_size):

    depot_xy = torch.rand(size=(batch_size, 1, 2))
    # shape: (batch, 1, 2)

    node_xy = torch.rand(size=(batch_size, problem_size, 2))
    # shape: (batch, problem, 2)

    if problem_size == 20:
        demand_scaler = 30
    elif problem_size == 50:
        demand_scaler = 40
    elif problem_size == 100:
        demand_scaler = 50
    else:
        raise NotImplementedError

    node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler)
    # shape: (batch, problem)

    return depot_xy, node_xy, node_demand


def augment_xy_data_by_8_fold(xy_data):
    # xy_data.shape: (batch, N, 2)

    x = xy_data[:, :, [0]]
    y = xy_data[:, :, [1]]
    # x,y shape: (batch, N, 1)

    dat1 = torch.cat((x, y), dim=2)
    dat2 = torch.cat((1 - x, y), dim=2)
    dat3 = torch.cat((x, 1 - y), dim=2)
    dat4 = torch.cat((1 - x, 1 - y), dim=2)
    dat5 = torch.cat((y, x), dim=2)
    dat6 = torch.cat((1 - y, x), dim=2)
    dat7 = torch.cat((y, 1 - x), dim=2)
    dat8 = torch.cat((1 - y, 1 - x), dim=2)

    aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0)
    # shape: (8*batch, N, 2)

    return aug_xy_data

相关文章:

  • js 颜色转换分析
  • 【Flink运行时架构】核心组件
  • 优化方法介绍(一)
  • PCIe 5.0光学SSD原型问世!
  • 2025-4-11 情绪周期视角复盘(mini)
  • java -jar与java -cp的区别
  • 操作系统 ------ 五种IO模型
  • 前端工程化-包管理NPM-package.json 和 package-lock.json 详解
  • 小甲鱼第004讲:变量和字符串(下)| 课后测试题及答案
  • Git基础知识
  • 蓝桥杯单片机刷题——ADC测量电位器的电压
  • 基于FPGA的六层电梯智能控制系统 矩阵键盘-数码管 上板仿真均验证通过
  • 深入解析Python爬虫技术:从基础到实战的功能工具开发指南
  • python文件打包无法导入ultralytics模块
  • 4月12日随笔
  • 【区块链安全 | 第三十九篇】合约审计之delegatecall(一)
  • 通信中的 “bps“ 含义及详解
  • linux小白对系统环境变量的一些不解和迷惑解析
  • Python(10.2)Python可变与不可变类型内存机制解密:从底层原理到工程实践
  • C 语言 - 右左法则与实践练习题 答案解析
  • dede网站限制IP浏览/济宁百度推广公司
  • 推广自己的店铺推广语/搜索引擎优化的内容包括
  • 网站的域名每年都要续费/产品营销推广策略
  • 淘宝网站网页图片怎么做/网络营销服务的特点
  • 用c语言可以做网站吗/电子商务网络营销
  • 洛阳做网站的公司/做好网络推广