当前位置: 首页 > wzjs >正文

谢岗仿做网站深圳网站制作设计

谢岗仿做网站,深圳网站制作设计,北京建站模板制作,国内有哪些比较好的做定制旅游网站完整代码在文末,可以一键运行。 1. 核心原理 Codebook是一种离散表征学习方法,其核心思想是将连续特征空间映射到离散的码本空间。我们的实现方案包含三个关键组件: 1.1 ViT编码器 class ViTEncoder(nn.Module):def __init__(self, codebo…

完整代码在文末,可以一键运行。

在这里插入图片描述

1. 核心原理

Codebook是一种离散表征学习方法,其核心思想是将连续特征空间映射到离散的码本空间。我们的实现方案包含三个关键组件:

1.1 ViT编码器

class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")self.proj = nn.Linear(768, codebook_dim)def forward(self, x):outputs = self.vit(x).last_hidden_statepatch_embeddings = outputs[:, 1:, :]  # 移除CLS tokenreturn self.proj(patch_embeddings)
  • 使用预训练的ViT-Base模型提取图像特征
  • 移除CLS token,保留196个图像块特征
  • 线性投影调整特征维度适配Codebook

1.2 Codebook量化层

class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)def quantize(self, z):# 计算L2距离distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 最近邻查找indices = torch.argmin(distances, dim=1)return indices, self.codebook(indices)
  • 使用可学习的Embedding层存储离散码本
  • 通过L2距离计算实现最近邻查找
  • 支持EMA更新(代码中已注释部分)

1.3 ViT解码器

class ViTDecoder(nn.Module):def __init__(self):self.head = nn.Sequential(nn.ConvTranspose2d(768, 384, 4, 2, 1),nn.ReLU(),... # 更多上采样层nn.Conv2d(48, 3, 1))
  • 使用转置卷积逐步上采样
  • 最终输出224x224分辨率图像
  • 与编码器形成对称结构

2. 训练策略

2.1 多目标损失函数

total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss
  • MSE Loss: 像素级重建误差
  • Perceptual Loss: VGG16特征匹配
  • Codebook Loss: 码本向量优化
  • Commitment Loss: 编码器输出稳定性

2.2 优化技巧

opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}
], lr=3e-4)
  • 分层学习率设置
  • EMA指数平滑更新
  • 混合精度训练支持
  • 动态学习率调整

3. 完整训练流程

3.1 数据准备

transform_train = transforms.Compose([transforms.Resize(224),transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(...)
])
  • CIFAR-10数据集
  • 随机裁剪+翻转增强
  • Batch Size=4适配显存

3.2 训练监控

# TensorBoard记录
writer.add_scalar('Loss/total', total_loss.item(), global_step)
writer.add_image('Reconstruction', grid, global_step)# 控制台日志
print(f"[Epoch {epoch+1:03d}] Loss: {total_loss.item():.4f}")

完整代码

from transformers import ViTModel, ViTConfig
import torch.nn as nn
import torch
import time
from tqdm import tqdm
class ViTEncoder(nn.Module):def __init__(self, codebook_dim=512):super().__init__()# 加载预训练ViT-Base模型self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")# 调整输出维度匹配Codebookself.proj = nn.Linear(768, codebook_dim)  # 网页2/6中的线性嵌入策略def forward(self, x):outputs = self.vit(x).last_hidden_state  # [batch, num_patches+1, 768]patch_embeddings = outputs[:, 1:, :]     # 移除CLS tokenreturn self.proj(patch_embeddings)       # [batch, 196, 512]class Codebook(nn.Module):def __init__(self, num_embeddings=16384, embedding_dim=512):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)  # 网页1的EMA更新可在此扩展def quantize(self, z):"""量化输入特征向量参数:z: 输入特征 [batch, num_patches, embedding_dim]返回:indices: 最近邻码本索引 [batch, num_patches]quantized: 量化后的特征 [batch, num_patches, embedding_dim]"""# 重塑输入为二维矩阵 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 计算L2距离 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近邻indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢复原始形状quantized = self.codebook(indices)  # [batch, num_patches, dim]return indices, quantized
class ViTDecoder(nn.Module):def __init__(self, in_dim=512):super().__init__()# 反向映射ViT的patch嵌入self.proj = nn.Linear(in_dim, 768)config = ViTConfig()config.is_decoder = True  # 网页7中的解码器模式self.transformer = ViTModel(config).encoder  self.head = nn.Sequential(# 14x14 -> 28x28nn.ConvTranspose2d(768, 384, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 28x28 -> 56x56nn.ConvTranspose2d(384, 192, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 56x56 -> 112x112 nn.ConvTranspose2d(192, 96, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 112x112 -> 224x224nn.ConvTranspose2d(96, 48, kernel_size=4, stride=2, padding=1),nn.ReLU(),# 最终调整到3通道nn.Conv2d(48, 3, kernel_size=1))def forward(self, x):x = self.proj(x)  # [batch, 196, 768]x = self.transformer(x).last_hidden_statex = x.permute(0, 2, 1).view(-1, 768, 14, 14)  # 恢复空间布局return self.head(x)  # 输出[1, 3, 224, 224]
# encoder = ViTEncoder()
# codebooker = Codebook()
# decoder = ViTDecoder()# data = torch.randn(1, 3, 224, 224)
# output = encoder(data)
# print(output.shape)
# indices, quantized = codebooker.quantize(output)
# print(indices.shape, quantized.shape)
# reconstructed = decoder(quantized)
# print(reconstructed.shape)from torchvision import transforms
import torchvision
import torch.nn.functional as F
# 数据增强和预处理
transform_train = transforms.Compose([transforms.Resize(224),  # 调整图像尺寸适配模型transforms.RandomCrop(224, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainloader = torch.DataLoader(trainset, batch_size=64, shuffle=True)
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)batch_size = 4  # 增大batch size加速训练
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vgg16# 初始化TensorBoard
writer = SummaryWriter('runs/codebook_experiment')# 改进的Codebook类(增加EMA更新)
class Codebook(nn.Module):def __init__(self, num_embeddings=1024, embedding_dim=512, commitment_cost=0.25, decay=0.99):super().__init__()self.codebook = nn.Embedding(num_embeddings, embedding_dim)nn.init.normal_(self.codebook.weight)self.commitment_cost = commitment_costself.decay = decayself.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))self.ema_w = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))nn.init.normal_(self.ema_w)def quantize(self, z):# 重塑输入为二维矩阵 [batch*num_patches, embedding_dim]batch, num_patches, dim = z.shapez_flat = z.reshape(-1, dim)  # [batch*num_patches, dim]# 计算L2距离 ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2z_norm = torch.sum(z_flat ** 2, dim=1, keepdim=True)  # [batch*num_patches, 1]e_norm = torch.sum(self.codebook.weight ** 2, dim=1)  # [num_embeddings]dot_product = torch.matmul(z_flat, self.codebook.weight.t())  # [batch*num_patches, num_embeddings]distances = z_norm - 2 * dot_product + e_norm.unsqueeze(0)# 找到最近邻indices = torch.argmin(distances, dim=1)  # [batch*num_patches]indices = indices.reshape(batch, num_patches)  # 恢复原始形状quantized = self.codebook(indices)  # [batch, num_patches, dim]# 新增EMA更新# if self.training:#     with torch.no_grad():#         encodings = F.one_hot(indices, self.codebook.num_embeddings).float()#         self.ema_cluster_size = self.decay * self.ema_cluster_size + (1 - self.decay) * torch.sum(encodings, 0)#         n = torch.sum(self.ema_cluster_size)#         self.ema_cluster_size = ((self.ema_cluster_size + 1e-5) / (n + self.codebook.num_embeddings * 1e-5) * n)#         dw = torch.matmul(encodings.t(), z_flat)#         self.ema_w = nn.Parameter(self.ema_w * self.decay + (1 - self.decay) * dw)#         self.codebook.weight.data = self.ema_w / self.ema_cluster_size.unsqueeze(1)return indices, quantized
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化组件
encoder = ViTEncoder().to(device)
codebook = Codebook(commitment_cost=0.25, decay=0.95).to(device)
decoder = ViTDecoder().to(device)
vgg = vgg16(pretrained=True).features[:16].eval().to(device)  # 用于感知损失# 优化器分开设置
opt = torch.optim.Adam([{'params': encoder.parameters()},{'params': decoder.parameters()},{'params': codebook.parameters(), 'lr': 1e-4}  # 更小的学习率
], lr=3e-4)# 训练循环
for epoch in range(100):avg_loss = 0start_time = time.time()  # 记录epoch开始时间for batch_idx, (images, _) in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}", ncols=80)):images = images.to(device)# 前向传播z = encoder(images)indices, quantized = codebook.quantize(z)recon = decoder(quantized)# 多目标损失计算mse_loss = F.mse_loss(recon, images)# 感知损失(VGG特征匹配)with torch.no_grad():real_features = vgg(images)recon_features = vgg(recon)percep_loss = F.mse_loss(recon_features, real_features)# Codebook相关损失commitment_loss = codebook.commitment_cost * F.mse_loss(z.detach(), quantized)codebook_loss = F.mse_loss(z, quantized.detach())# 总损失total_loss = mse_loss + 0.1*percep_loss + codebook_loss + commitment_loss# 反向传播opt.zero_grad()total_loss.backward()opt.step()# 记录数据avg_loss += total_loss.item()if batch_idx % 50 == 0:# 记录TensorBoard数据writer.add_scalar('Loss/total', total_loss.item(), epoch*len(trainloader)+batch_idx)writer.add_scalars('Loss/components', {'mse': mse_loss.item(),'perceptual': percep_loss.item(),'codebook': codebook_loss.item(),'commitment': commitment_loss.item()}, epoch*len(trainloader)+batch_idx)# 保存重建样本comparison = torch.cat([images[:4], recon[:4]])grid = vutils.make_grid(comparison.cpu(), nrow=4, normalize=True)writer.add_image('Reconstruction', grid, epoch*len(trainloader)+batch_idx)# 打印epoch统计信息avg_loss /= len(trainloader)print(f"Epoch {epoch+1}: Avg Loss {avg_loss:.4f}")# 保存模型检查点if (epoch+1) % 10 == 0:torch.save({'encoder': encoder.state_dict(),'codebook': codebook.state_dict(),'decoder': decoder.state_dict(),'opt': opt.state_dict()}, f'checkpoint_epoch{epoch+1}.pth')writer.close()

通过本实践,我们实现了从特征提取到离散表征学习的完整流程。Codebook技术可广泛应用于图像压缩、生成模型等领域,期待读者在此基础上探索更多可能性。

http://www.dtcms.com/wzjs/463072.html

相关文章:

  • 企业网站管理系统用哪个好品牌推广方案策划书
  • 深圳产品型网站建设16种营销模型
  • 网站界面设计的优点站长之家关键词查询
  • 关键词seo如何优化网站排名优化查询
  • 搭建手机网站创建网站的基本步骤
  • 网站建设和设计网站死链检测工具
  • iis网站配置教程自媒体平台
  • 淄博网站建设服务抖音推广运营
  • 如何申请个人网站国际新闻最新消息今天 新闻
  • 成都网站logo设计seo确定关键词
  • 佛山建站模板2021年经典营销案例
  • 兼职做一篇微信的网站手机自动排名次的软件
  • 企业网站模板网 凡建站免费建设个人网站
  • 一个网站可以做多少地区词网络营销就是
  • 德州网站seo北京快速优化排名
  • 网站建设标准流程及外包注意事项上海网站seo
  • pc网站自动跳转wap百度搜索推广多少钱
  • 免费打广告网站付费推广
  • 现在哪个公司的网络比较好优化大师的优化项目有哪7个
  • 做网站需要有服务器百度号码认证平台官网
  • 主流网站开发采用网站收录平台
  • 营销类网站百度投放广告联系谁
  • 滁州市大滁城建设网站怎么找网站
  • 网站建设及推广套餐网络优化
  • 做三级分销商城网站设计seo优化推广技巧
  • 大型网站的建设包括那些内容培训学校招生营销方案
  • 厦门网站关键词优化深圳龙岗区疫情最新消息
  • html论坛网站模板下载关键词排名查询网站
  • 云南网招聘网站优化人员通常会将目标关键词放在网站首页中的
  • 网和网站的区别我要下载百度