# Notes: Transformer that takes three-view line segments and outputs length/width/height.
# Task: PyTorch Transformer model; input = orthographic three-view segment
# coordinates (3 views x 4 segments x 2 points x 2 coords) plus per-view codes;
# output = the box's length, width and height. Also includes a projection-based
# cube data generator (coordinates anchored at the minimum corner (0, 0)).
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Cube data generator based on projection matrices
class CubeDataGenerator:
    """Generate synthetic training data: orthographic three-view line
    segments of axis-aligned boxes, plus their true dimensions.

    Each box is drawn as three views (front, side, top). Every view is the
    box's rectangular outline as 4 line segments, each segment a pair of
    2-D points, with the minimum corner pinned at (0, 0).
    """

    def __init__(self, min_size=1.0, max_size=10.0, batch_size=32):
        self.min_size = min_size    # minimum box dimension
        self.max_size = max_size    # maximum box dimension (exclusive)
        self.batch_size = batch_size

    def generate(self):
        """Return (line_segments, lengths, view_codes).

        line_segments: (batch, 3 views, 4 segments, 2 points, 2 coords)
        lengths:       (batch, 3) -> (length l, width w, height h)
        view_codes:    (batch, 3, 3) one-hot view identity codes
        """
        B = self.batch_size
        # Uniform random box dimensions in [min_size, max_size).
        lengths = torch.rand(B, 3) * (self.max_size - self.min_size) + self.min_size
        l, w, h = lengths[:, 0], lengths[:, 1], lengths[:, 2]

        def rectangle(x, y):
            # Build the 4 counter-clockwise outline segments of an (x by y)
            # rectangle with its lower-left corner at the origin, batched:
            # (0,0)->(x,0)->(x,y)->(0,y)->(0,0). Returns (B, 4, 2, 2).
            zero = torch.zeros_like(x)
            corners = torch.stack([
                torch.stack([zero, zero], dim=-1),  # (0, 0)
                torch.stack([x, zero], dim=-1),     # (x, 0)
                torch.stack([x, y], dim=-1),        # (x, y)
                torch.stack([zero, y], dim=-1),     # (0, y)
            ], dim=1)                               # (B, 4, 2)
            starts = corners
            ends = corners.roll(-1, dims=1)         # next corner, wrapping
            return torch.stack([starts, ends], dim=2)

        # Vectorized over the batch (the original looped in Python per sample).
        segs = torch.zeros(B, 3, 4, 2, 2)
        segs[:, 0] = rectangle(l, h)  # front view (length x height)
        segs[:, 1] = rectangle(w, h)  # side view  (width  x height)
        segs[:, 2] = rectangle(l, w)  # top view   (length x width)

        # One-hot code identifying each of the 3 views.
        view_codes = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)
        return segs, lengths, view_codes
class TransformerCubePredictor(nn.Module):
    """Transformer that maps three-view line-segment drawings of a box to
    its predicted (length, width, height).

    Args:
        d_model:    embedding / transformer hidden size.
        num_heads:  attention heads (must divide d_model).
        num_layers: number of TransformerEncoder layers.
        input_dim:  coordinate dimensionality of each segment endpoint
                    (2 for planar views).
        output_dim: number of predicted sizes (3: length, width, height).
    """

    def __init__(self, d_model=64, num_heads=8, num_layers=2, input_dim=2, output_dim=3):
        super().__init__()
        self.d_model = d_model
        # Each segment flattens to (2 endpoints * input_dim coords) values.
        # Bug fix: the original used input_dim * 2 * 2, which never matched
        # the 4-value flattened segment (8 != 4 with input_dim=2) and
        # crashed at the first forward pass.
        self.line_segment_embed = nn.Linear(input_dim * 2, d_model)
        # Project the one-hot view identity code into the same space.
        self.view_embed = nn.Linear(3, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        # Regression head: pooled token -> (length, width, height).
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, line_segments, view_codes):
        """Predict box sizes.

        line_segments: (batch, views, lines, 2 points, input_dim coords)
        view_codes:    (batch, views, 3) one-hot view codes
        returns:       (batch, output_dim)
        """
        batch_size, num_views, num_lines, _, _ = line_segments.shape
        # Flatten each segment's endpoints: (batch, views, lines, 2*coords).
        flat = line_segments.reshape(batch_size, num_views, num_lines, -1)
        line_emb = self.line_segment_embed(flat)
        # Broadcast each view's code to all of its segments.
        view_emb = self.view_embed(
            view_codes.unsqueeze(2).expand(-1, -1, num_lines, -1)
        )
        tokens = line_emb + view_emb
        # (seq_len = views*lines, batch, d_model) — the TransformerEncoder
        # here is not batch_first.
        # Bug fix: the original permute(1, 0, 2, 3) put batch ahead of the
        # line axis, so the reshape interleaved segments from different
        # samples into one sequence.
        tokens = tokens.permute(1, 2, 0, 3).reshape(
            num_views * num_lines, batch_size, -1
        )
        encoded = self.transformer_encoder(tokens)
        # Mean-pool over the sequence for a fixed-size representation.
        pooled = encoded.mean(dim=0)
        return self.fc(pooled)
if __name__ == "__main__":
    # Smoke test: generate a tiny batch and run a single forward pass.
    data_gen = CubeDataGenerator(min_size=1.0, max_size=10.0, batch_size=2)
    segments, target_sizes, codes = data_gen.generate()

    predictor = TransformerCubePredictor(input_dim=2)
    estimates = predictor(segments, codes)

    print("True Sizes:\n", target_sizes)
    print("Predicted Sizes:\n", estimates)
# Training script
def train(num_epochs=100, lr=0.001, batch_size=32, log_every=10):
    """Train a TransformerCubePredictor on freshly generated synthetic data.

    Generalized from the original hard-coded script: all defaults match the
    original behavior exactly, so `train()` is a drop-in call.

    Args:
        num_epochs: training iterations (one fresh batch per epoch).
        lr:         Adam learning rate.
        batch_size: samples per generated batch.
        log_every:  epoch interval between progress prints.

    Returns:
        The trained model (additive: the original returned None implicitly).
    """
    generator = CubeDataGenerator(min_size=1.0, max_size=10.0, batch_size=batch_size)
    model = TransformerCubePredictor(input_dim=2)
    criterion = nn.MSELoss()  # regression target: (length, width, height)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(num_epochs):
        # Fresh synthetic batch every epoch (effectively infinite data).
        line_segments, true_sizes, view_codes = generator.generate()

        optimizer.zero_grad()
        predicted_sizes = model(line_segments, view_codes)
        loss = criterion(predicted_sizes, true_sizes)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    print("Training completed.")
    return model
# Entry point: run the training loop when this file is executed directly.
# NOTE(review): the file contains a second __main__ guard above; in a single
# file only both run in order — confirm whether both are intended.
if __name__ == "__main__":train()