The complete code is at the end of the article and can be run in one go.

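1. Core principles

A codebook is a discrete representation learning method. Its core idea is to map a continuous feature space onto a discrete codebook space. Our implementation consists of three key components: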
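To make the continuous-to-discrete mapping concrete, here is a toy sketch with a hand-written 2-D codebook of three entries. The values and the helper name `quantize_toy` are made up for illustration and are not part of the article's code. Each continuous vector is replaced by the index of its nearest codebook entry.

```python
import torch


def quantize_toy(z, codebook):
    """Map each row of z to the index of its nearest codebook row (L2 distance)."""
    distances = torch.cdist(z, codebook)  # [num_vectors, num_codes]
    indices = distances.argmin(dim=1)     # discrete codes
    return indices, codebook[indices]     # codes and the vectors they stand for


# Hypothetical 3-entry codebook and four continuous feature vectors.
codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]])
z = torch.tensor([[0.2, -0.1], [0.9, 1.2], [3.5, 0.4], [0.6, 0.7]])

indices, quantized = quantize_toy(z, codebook)
print(indices.tolist())  # [0, 1, 2, 1]
```

Two different inputs (the second and fourth) land on the same code, which is exactly the information loss that quantization trades for a compact discrete representation.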

1.1 ViT encoder

import torch.nn as nn
from transformers import ViTModel

class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.proj = nn.Linear(768, codebook_dim)

    def forward(self, x):
        outputs = self.vit(x).last_hidden_state
        patch_embeddings = outputs[:, 1:, :]  # drop the CLS token
        return self.proj(patch_embeddings)
  • Uses a pretrained ViT-Base model to extract image features
  • Drops the CLS token, keeping the 196 image-patch features
  • A linear projection adapts the feature dimension to the codebook
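The figure of 196 follows from the patch grid: a 224x224 input cut into 16x16 patches gives 14 patches per side. A small sketch of that arithmetic (the helper name `vit_patch_grid` is hypothetical):

```python
def vit_patch_grid(image_size=224, patch_size=16):
    """Return (patches per side, total patches) for a square image."""
    per_side = image_size // patch_size
    return per_side, per_side * per_side


per_side, num_patches = vit_patch_grid()
print(per_side, num_patches)  # 14 196
print(num_patches + 1)        # 197 tokens once the CLS token is included
```

This is why `last_hidden_state` has 197 positions and slicing with `[:, 1:, :]` leaves 196.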

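1.2 Codebook quantization layer

import torch
import torch.nn as nn

class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)

    def quantize(self, z):
        # Squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z, e> + ||e||^2
        z_flat = z.reshape(-1, z.shape[-1])  # [batch*num_patches, dim]
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)
        dot_product = z_flat @ self.codebook.weight.t()
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)
        # Nearest-neighbour lookup
        indices = torch.argmin(distances, dim=1).reshape(z.shape[:-1])
        return indices, self.codebook(indices)
  • Stores the discrete codebook in a learnable Embedding layer
  • Finds the nearest neighbour by computing L2 distances
  • Supports EMA updates (the part that is commented out in the code)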
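The EMA update mentioned in the last bullet is disabled in the listing. A minimal sketch of what such an update might look like, written as a standalone function on plain tensors, is given below. The function name `ema_update`, its arguments, and the smoothing constant `eps` are assumptions for illustration, not the article's implementation.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def ema_update(codebook, cluster_size, embed_sum, z_flat, indices,
               decay=0.99, eps=1e-5):
    """Update codebook [K, D] in place from encoder outputs z_flat [N, D].

    cluster_size [K] and embed_sum [K, D] are running averages of how many
    vectors each code received and of the sum of those vectors.
    """
    num_codes = codebook.shape[0]
    one_hot = F.one_hot(indices, num_codes).to(z_flat.dtype)  # [N, K]
    cluster_size.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    embed_sum.mul_(decay).add_(one_hot.t() @ z_flat, alpha=1 - decay)
    # Laplace smoothing keeps unused codes from dividing by zero.
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + num_codes * eps) * n
    codebook.copy_(embed_sum / smoothed.unsqueeze(1))


# Toy usage: 4 codes of dimension 8, 32 encoder outputs.
codes = torch.randn(4, 8)
sizes = torch.zeros(4)
sums = torch.zeros(4, 8)
z_demo = torch.randn(32, 8)
idx_demo = torch.cdist(z_demo, codes).argmin(dim=1)
ema_update(codes, sizes, sums, z_demo, idx_demo)
```

With this kind of update each code drifts toward the running mean of the encoder outputs assigned to it, so the codebook no longer needs a gradient-based loss term of its own.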

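1.3 ViT decoder

class ViTDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Sequential(
            nn.ConvTranspose2d(768, 384, 4, 2, 1),
            nn.ReLU(),
            # ... more upsampling layers
            nn.Conv2d(48, 3, 1),
        )
  • Upsamples step by step with transposed convolutions
  • Outputs an image at 224x224 resolution
  • Forms a symmetric structure with the encoder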
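The path from the 14x14 patch grid to a 224x224 image can be checked with the transposed-convolution size formula. The sketch below assumes four stride-2 layers with kernel size 4 and padding 1, as in the full listing at the end of the article; the helper name `conv_transpose2d_out` is hypothetical.

```python
def conv_transpose2d_out(size, kernel_size=4, stride=2, padding=1):
    """Output side length of ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel_size


size = 14
sizes = []
for _ in range(4):
    size = conv_transpose2d_out(size)
    sizes.append(size)
print(sizes)  # [28, 56, 112, 224]
```

Each layer doubles the spatial size exactly, so four of them turn 14 into 224 without any cropping or padding tricks.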

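2. Training strategy

2.1 Multi-objective loss function

total_loss = mse_loss + 0.1 * percep_loss + codebook_loss + commitment_loss
  • MSE Loss: pixel-level reconstruction error
  • Perceptual Loss: VGG16 feature matching
  • Codebook Loss: optimizes the codebook vectors
  • Commitment Loss: keeps the encoder output stable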
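The two codebook terms differ only in which side is detached. The sketch below follows the usual VQ-VAE formulation and also adds a straight-through estimator so that the reconstruction loss can reach the encoder; the helper name `vq_losses` is hypothetical and this is one possible arrangement, not necessarily the article's exact code.

```python
import torch
import torch.nn.functional as F


def vq_losses(z, quantized, commitment_cost=0.25):
    """Return (straight-through quantized features, codebook loss, commitment loss)."""
    # Codebook loss: the encoder output is detached, so only the codes move.
    codebook_loss = F.mse_loss(quantized, z.detach())
    # Commitment loss: the codes are detached, so only the encoder moves.
    commitment_loss = commitment_cost * F.mse_loss(z, quantized.detach())
    # Straight-through estimator: forward value equals `quantized`,
    # backward gradient is copied to `z` as if quantization were identity.
    quantized_st = z + (quantized - z).detach()
    return quantized_st, codebook_loss, commitment_loss


# Toy tensors standing in for encoder outputs and their nearest codes.
z = torch.randn(2, 196, 512, requires_grad=True)
quantized = torch.randn(2, 196, 512, requires_grad=True)
quantized_st, codebook_loss, commitment_loss = vq_losses(z, quantized)
```

Feeding `quantized_st` (rather than `quantized`) to the decoder is what lets the MSE and perceptual terms train the encoder despite the non-differentiable argmin.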

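2.2 Optimization tips

opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4},
], lr=3e-4)
  • Per-group learning rates: 3e-4 for the encoder and decoder, 1e-4 for the codebook
  • EMA (exponential moving average) updates
  • Mixed-precision training support
  • Dynamic learning-rate adjustment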
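Mixed precision and learning-rate scheduling could be wired into a training step roughly as follows. This is a sketch under assumptions: the tiny `model` and `codebook` modules stand in for the real components, cosine annealing is just one possible schedule, and mixed precision is enabled only when a GPU is present. Newer PyTorch releases may emit a deprecation warning for `torch.cuda.amp.GradScaler`.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'  # mixed precision only on GPU in this sketch

# Hypothetical stand-ins for the encoder/decoder and the codebook.
model = nn.Linear(8, 8).to(device)
codebook = nn.Embedding(16, 8).to(device)

opt = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4},  # smaller rate for the codebook
], lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(x):
    """One optimisation step with optional autocast and loss scaling."""
    opt.zero_grad()
    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss = (model(x) - x).pow(2).mean() + codebook.weight.pow(2).mean()
    scaler.scale(loss).backward()  # a no-op scale when AMP is disabled
    scaler.step(opt)
    scaler.update()
    return loss.item()


initial_lrs = [group['lr'] for group in opt.param_groups]
for _ in range(5):
    train_step(torch.randn(4, 8, device=device))
    scheduler.step()  # decays every group from its own base rate
final_lrs = [group['lr'] for group in opt.param_groups]
```

The scheduler keeps the per-group ratio intact: both rates shrink along the same cosine curve, each starting from its own base value.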

3. Full training pipeline

3.1 Data preparation

transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(...),
])
  • CIFAR-10 dataset
  • Random-crop and horizontal-flip augmentation
  • Batch size of 4, to fit in GPU memory

3.2 Training monitoring

# TensorBoard logging
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)
# Console logging
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")

Complete code

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.models import vgg16
from tqdm import tqdm
from transformers import ViTConfig, ViTModel


class ViTEncoder(nn.Module):
    def __init__(self, codebook_dim=512):
        super().__init__()
        # Load the pretrained ViT-Base model
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # Project the output dimension to match the codebook
        self.proj = nn.Linear(768, codebook_dim)

    def forward(self, x):
        outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]
        patch_embeddings = outputs[:, 1:, :]     # drop the CLS token
        return self.proj(patch_embeddings)       # [batch, 196, 512]


# Codebook with an optional EMA update
class Codebook(nn.Module):
    def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.normal_(self.codebook.weight)
        self.commitment_cost = commitment_cost
        self.decay = decay
        # EMA statistics are buffers, so the optimizer does not track them
        self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('ema_w', self.codebook.weight.detach().clone())

    def quantize(self, z):
        """Quantize the input feature vectors.

        Args:
            z: input features [batch, num_patches, embedding_dim]
        Returns:
            indices: nearest codebook indices [batch, num_patches]
            quantized: quantized features [batch, num_patches, embedding_dim]
        """
        # Reshape the input into a 2-D matrix [batch*num_patches, embedding_dim]
        batch, num_patches, dim = z.shape
        z_flat = z.reshape(-1, dim)

        # Squared L2 distance: ||z - e||^2 = ||z||^2 - 2<z, e> + ||e||^2
        z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)          # [batch*num_patches, 1]
        e_norm = torch.sum(self.codebook.weight ** 2, dim=1)          # [num_embeddings]
        dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]
        distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)

        # Find the nearest neighbour
        indices = torch.argmin(distances, dim=1)       # [batch*num_patches]
        indices = indices.reshape(batch, num_patches)  # restore the original shape
        quantized = self.codebook(indices)             # [batch, num_patches, dim]

        # Optional EMA update (disabled):
        # if self.training:
        #     with torch.no_grad():
        #         encodings = F.one_hot(indices.reshape(-1), self.codebook.num_embeddings).float()
        #         self.ema_cluster_size.mul_(self.decay).add_(encodings.sum(0), alpha=1 - self.decay)
        #         n = self.ema_cluster_size.sum()
        #         cluster_size = (self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n
        #         dw = torch.matmul(encodings.t(), z_flat)
        #         self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)
        #         self.codebook.weight.copy_(self.ema_w / cluster_size.unsqueeze(1))
        return indices, quantized


class ViTDecoder(nn.Module):
    def __init__(self, in_dim=512):
        super().__init__()
        # Map codebook vectors back to the ViT patch-embedding width
        self.proj = nn.Linear(in_dim, 768)
        config = ViTConfig()
        config.is_decoder = True
        self.transformer = ViTModel(config).encoder
        self.head = nn.Sequential(
            # 14x14 -> 28x28
            nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 28x28 -> 56x56
            nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 56x56 -> 112x112
            nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # 112x112 -> 224x224
            nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Final projection to 3 channels
            nn.Conv2d(48, 3, kernel_size=1),
        )

    def forward(self, x):
        x = self.proj(x)  # [batch, 196, 768]
        x = self.transformer(x).last_hidden_state
        x = x.permute(0, 2, 1).reshape(-1, 768, 14, 14)  # restore the spatial layout
        return self.head(x)  # [batch, 3, 224, 224]


# Quick shape check:
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()
# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)

# Data augmentation and preprocessing
transform_train = transforms.Compose([
    transforms.Resize(224),  # resize images to the model's input size
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load the CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
batch_size = 4  # small batch size to fit in GPU memory
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize TensorBoard
writer = SummaryWriter('runs/codebook_experiment')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize the components
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # used for the perceptual loss
for p in vgg.parameters():
    p.requires_grad_(False)  # VGG is a fixed feature extractor

# Separate optimizer parameter groups
opt = torch.optim.Adam([
    {'params': encoder.parameters()},
    {'params': decoder.parameters()},
    {'params': codebook.parameters(), 'lr': 1e-4},  # smaller learning rate
], lr=3e-4)

# Training loop
for epoch in range(100):
    avg_loss = 0
    start_time = time.time()  # record when the epoch started
    for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):
        images = images.to(device)
        global_step = epoch * len(trainloader) + batch_idx

        # Forward pass
        z = encoder(images)
        indices, quantized = codebook.quantize(z)
        # Straight-through estimator: argmin has no gradient, so copy the
        # decoder's gradient from the quantized features back to z
        quantized_st = z + (quantized - z).detach()
        recon = decoder(quantized_st)

        # Multi-objective loss
        mse_loss = F.mse_loss(recon, images)

        # Perceptual loss (VGG feature matching); only the target is computed
        # under no_grad, so gradients still flow through the reconstruction
        with torch.no_grad():
            real_features = vgg(images)
        recon_features = vgg(recon)
        percep_loss = F.mse_loss(recon_features, real_features)

        # Codebook-related losses
        codebook_loss = F.mse_loss(quantized, z.detach())  # moves the codes toward the encoder outputs
        commitment_loss = codebook.commitment_cost * F.mse_loss(z, quantized.detach())  # keeps the encoder near its codes

        # Total loss
        total_loss = mse_loss + 0.1 * percep_loss + codebook_loss + commitment_loss

        # Backward pass
        opt.zero_grad()
        total_loss.backward()
        opt.step()

        # Bookkeeping
        avg_loss += total_loss.item()
        if batch_idx % 50 == 0:
            # Log to TensorBoard
            writer.add_scalar('Loss/total', total_loss.item(), global_step)
            writer.add_scalars('Loss/components', {
                'mse': mse_loss.item(),
                'perceptual': percep_loss.item(),
                'codebook': codebook_loss.item(),
                'commitment': commitment_loss.item(),
            }, global_step)
            # Save reconstruction samples
            comparison = torch.cat([images[:4], recon[:4]])
            grid = vutils.make_grid(comparison.detach().cpu(), nrow=4, normalize=True)
            writer.add_image('Reconstruction', grid, global_step)

    # Print per-epoch statistics
    avg_loss /= len(trainloader)
    print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f} ({time.time() - start_time:.1f}s)")

    # Save a model checkpoint
    if (epoch + 1) % 10 == 0:
        torch.save({
            'encoder': encoder.state_dict(),
            'codebook': codebook.state_dict(),
            'decoder': decoder.state_dict(),
            'opt': opt.state_dict(),
        }, f'checkpoint_epoch{epoch+1}.pth')

writer.close()

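In this walkthrough we implemented the full pipeline from feature extraction to discrete representation learning. Codebook techniques can be applied widely in areas such as image compression and generative models, and readers are encouraged to explore further on this basis.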
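As a rough illustration of the compression angle, the sketch below counts the bits needed to store one image as codebook indices, assuming 196 tokens and a 1024-entry codebook as in this article. The helper names are hypothetical, and the figure ignores the cost of storing the model and the codebook itself, as well as any reconstruction error.

```python
import math


def token_bits(num_tokens=196, codebook_size=1024):
    """Bits needed to store one image as fixed-length codebook indices."""
    return num_tokens * math.ceil(math.log2(codebook_size))


def raw_bits(height=224, width=224, channels=3, bits_per_channel=8):
    """Bits in the uncompressed image."""
    return height * width * channels * bits_per_channel


compressed = token_bits()
raw = raw_bits()
print(compressed)                  # 1960
print(raw)                         # 1204224
print(round(raw / compressed, 1))  # 614.4
```

The same index sequence is also what a generative model would be trained on when the codebook is used as a tokenizer for images.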
