
A Detailed Walkthrough of the Wan2.1 Video Generation Model Code

Diffusion Models Code Walkthrough: Getting Started and Hands-On

Preface: The first draft of this post was written eight months ago, when Wan2.1 had just been open-sourced. Since then, Wan2.1 has become the most widely used base model for video generation research; by an informal count, more than 100 top-conference and journal papers built on Wan2.1 appeared within half a year. This post walks through the principles and code of this modern video generation model in detail.

Contents

VAE Design

Design comparison between the Wan2.2 and Wan2.1 VAEs

Code Walkthrough

Core DiT Architecture


VAE Design

Even today, across comparisons of many open-source VAEs, the reconstruction metrics of Wan2.1 and Wan2.2 remain highly competitive.

Design comparison between the Wan2.2 and Wan2.1 VAEs:

  1. Much higher compression ratio

    • Wan2.1: 8×8×4 (256×)

    • Wan2.2: 16×16×4 (1024×)
      At the same resolution, the Wan2.2 latent grid therefore has only 1/4 as many spatio-temporal positions as Wan2.1's, which cuts the downstream token count and memory accordingly (see the shape sketch after this list).

  2. Reconstruction quality improves rather than degrades

    • With its "asymmetric encoder-decoder + residual sampling" design, Wan2.2 still achieves slightly better PSNR than Wan2.1 despite the higher compression ratio.

    • Official tests report 32.5 dB PSNR on 720P video, more than 2 dB above Wan2.1's 30.1 dB.

  3. Wider latent channels

    • Wan2.1: 16 latent channels

    • Wan2.2: 48 latent channels
      The extra channels compensate for the information lost to the higher compression, so fine detail is preserved better.

  4. Speed / memory gains

    • On an RTX 4090, Wan2.2-TI2V-5B with the new VAE shortens the generation of a 5 s 720P video from the minutes-level of Wan2.1-14B to roughly 155 s (multi-GPU) or 534 s (single GPU), while fitting within 24 GB of VRAM.
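As a quick sanity check on those compression figures, a minimal shape sketch (the 81-frame 720P clip is an illustrative assumption; the channel counts follow the numbers above):

# Illustrative latent-shape check; the 81-frame 720P clip is an assumed example.
C, T, H, W = 3, 81, 720, 1280

# Wan2.1 VAE: 8x8 spatial, 4x temporal (first frame kept), 16 latent channels
wan21 = (16, 1 + (T - 1) // 4, H // 8, W // 8)      # (16, 21, 90, 160)

# Wan2.2 VAE: 16x16 spatial, 4x temporal, 48 latent channels
wan22 = (48, 1 + (T - 1) // 4, H // 16, W // 16)    # (48, 21, 45, 80)

positions_21 = wan21[1] * wan21[2] * wan21[3]
positions_22 = wan22[1] * wan22[2] * wan22[3]
print(positions_21 / positions_22)                  # 4.0 -> Wan2.2 keeps 1/4 the positions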

Code Walkthrough

The core of the Wan2.1 VAE is a "causal + streaming" VAE.


  1. Initialization: every hyperparameter is hard-coded up front

dim=128            # base channel count
z_dim=4            # final latent channel count (the report says 16; the 4 here is a single component and gets ×2 internally for μ / log σ²)
dim_mult=[1,2,4,4] # 4 downsampling levels, channels go 128→256→512→512
temperal_downsample=[True,True,False]
# which downsampling stages also subsample time: stages 2 and 3 do 2×, the last one does not

Encoder3d and Decoder3d are the "3D causal residual networks" described in the report; internally the kernels are already factored into "2D spatial + 1D causal temporal" convolutions, so the model never peeks at future frames.
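The real CausalConv3d in Wan additionally threads a feat_cache across chunks (visible in the full code further down). As a minimal sketch of just the "pad only on the past side of the time axis" idea (TinyCausalConv3d is a made-up name for illustration, not Wan's class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalConv3d(nn.Module):
    """Illustrative causal 3D conv: pad only the *past* side of the time axis,
    pad spatial dims symmetrically, so the output at frame t never sees frames > t."""

    def __init__(self, in_ch, out_ch, kernel_size=(3, 3, 3)):
        super().__init__()
        kt, kh, kw = kernel_size
        self.time_pad = kt - 1                        # all temporal padding goes to the left
        self.space_pad = (kw // 2, kw // 2, kh // 2, kh // 2)
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size)

    def forward(self, x):                             # x: [B, C, T, H, W]
        # F.pad order for 5D input: (w_left, w_right, h_left, h_right, t_left, t_right)
        x = F.pad(x, self.space_pad + (self.time_pad, 0))
        return self.conv(x)

# quick check: T/H/W are preserved
y = TinyCausalConv3d(3, 8)(torch.randn(1, 3, 5, 16, 16))
print(y.shape)  # torch.Size([1, 8, 5, 16, 16])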


  2. Forward pass: the classic three-stage VAE

x_recon, mu, log_var = model(x)
  • encode produces μ and σ

  • reparameterize draws the sample

  • decode reconstructs the video from the latent
    The only special part: encode/decode both run in chunks internally, so the full video never has to sit in GPU memory at once.


  3. encode: slice the video into causal chunks of "1 + 4×n" frames

t = x.shape[2]
iter_ = 1 + (t - 1) // 4          # send 1 frame first, then one chunk per 4 frames
for i in range(iter_):
    if i == 0:
        out = encoder(x[:, :, :1, ...])                          # frame 0 goes through alone
    else:
        out_ = encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, ...])
        out = cat([out, out_], 2)                                # concatenate back along time

  • This preserves temporal causality: no chunk ever sees information from frames that come after it.

  • feat_cache / feat_idx form the hidden-state cache used inside CausalConv3d:
    when crossing a chunk boundary, the last hidden state of the previous chunk is carried over, exactly like an RNN's h_t. (A small worked example of the chunk boundaries follows below.)
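For concreteness, a tiny worked example of the chunk boundaries (the 17-frame clip length is just an illustrative choice):

t = 17                          # e.g. a 17-frame clip (1 + 4*4), illustrative
iter_ = 1 + (t - 1) // 4        # -> 5 encoder calls
chunks = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
print(chunks)                   # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]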


  4. Latent normalization: scale[0] = mean, scale[1] = 1/std

mu = (mu - scale[0]) * scale[1]

During training, scale holds EMA-tracked global per-channel mean/std statistics; at inference you can simply pass 0/1, or feed in dataset statistics for offline normalization, so the diffusion model always sees inputs close to N(0, 1). Note that the WanVAE wrapper below stores scale[1] as 1.0 / std, which is why it multiplies here.
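A minimal sketch of that round trip, using placeholder statistics in the wrapper's scale = [mean, 1.0 / std] convention:

import torch

z_dim = 16
mean = torch.zeros(z_dim)               # placeholder statistics; the wrapper below ships real ones
std = torch.ones(z_dim)
scale = [mean.view(1, z_dim, 1, 1, 1), (1.0 / std).view(1, z_dim, 1, 1, 1)]

mu = torch.randn(1, z_dim, 5, 16, 16)
z = (mu - scale[0]) * scale[1]          # encode side: roughly N(0, 1) per channel
mu_back = z / scale[1] + scale[0]       # decode side undoes it
print(torch.allclose(mu, mu_back))      # True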


  5. decode: a frame-by-frame sliding window, equally causal

for i in range(iter_):
    out_ = decoder(z[:, :, i:i + 1, ...])   # feed one latent frame at a time
    out = cat([out, out_], 2)

The decoder also carries a feat_cache, so:

  • memory usage depends only on the window length, not on the total duration → arbitrarily long generation

  • output frame rate = latent frame rate × temporal upsampling factor (4×), which matches the 6 fps latent rate mentioned in the report. (A quick frame-count check follows below.)
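And the frame-count bookkeeping itself, as a minimal sketch (the helper names are made up for illustration):

def latent_frames(t):
    """Number of latent frames the causal VAE produces for a t-frame clip."""
    return 1 + (t - 1) // 4

def decoded_frames(lt):
    """Number of video frames recovered from lt latent frames (first frame stays 1:1)."""
    return 1 + (lt - 1) * 4

for t in (1, 5, 17, 81):                 # 81 frames is Wan2.1's default clip length
    lt = latent_frames(t)
    print(t, "->", lt, "->", decoded_frames(lt))
# 1 -> 1 -> 1, 5 -> 2 -> 5, 17 -> 5 -> 17, 81 -> 21 -> 81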

Here is the original VAE code:

class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## split the encode input x along time into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

The authors wrap this in one more outer layer, essentially a wrapper class. You can also see that the coding style is not uniform across the project; this part was clearly written by someone else:

def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    with torch.device('meta'):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f'loading {pretrained_path}')
    model.load_state_dict(
        torch.load(pretrained_path, map_location=device), assign=True)

    return model


class WanVAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]
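To see the wrapper end to end, a hypothetical usage sketch (this assumes a GPU, that the checkpoint file at the default path actually exists locally, and an arbitrary clip size; shapes in the comments follow the 8×8 spatial / 4× temporal compression described above):

import torch

# Hypothetical usage of the WanVAE wrapper above; the checkpoint path is a placeholder.
vae = WanVAE(z_dim=16, vae_pth='cache/vae_step_411000.pth', dtype=torch.float16, device='cuda')

video = torch.rand(3, 81, 480, 832, device='cuda') * 2 - 1   # [C, T, H, W] in [-1, 1]
latents = vae.encode([video])          # list with one tensor of shape [16, 21, 60, 104]
recon = vae.decode(latents)            # list with one tensor of shape [3, 81, 480, 832]
print(latents[0].shape, recon[0].shape)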

Core DiT Architecture

Wan2.1 adopts a fairly classic Transformer architecture, rather than the single-stream / dual-stream design that has been so popular in the diffusion community.

  1. Forward pass: one pipeline, done in 4 steps
    Patchify


# each input video u has shape [C, F, H, W]
x = patch_embedding(u.unsqueeze(0))  # [1, 2048, F, H/2, W/2]
x = x.flatten(2).transpose(1, 2)     # [1, L, 2048], L = F * H/2 * W/2

Each video gets its own L; the videos are then concatenated into one batch, zero-padded up to seq_len.
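A minimal sketch of that per-video patchify-and-pad logic, using toy sizes (the tensors and seq_len here are illustrative; patch_embedding mirrors the nn.Conv3d defined in WanModel below):

import torch
import torch.nn as nn

# toy sizes for the sketch; the real model uses dim=2048 and much larger seq_len
dim, in_dim, patch_size, seq_len = 2048, 16, (1, 2, 2), 512
patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)

# two latent videos of different lengths, each [C_in, F, H, W]
videos = [torch.randn(16, 5, 16, 16), torch.randn(16, 3, 16, 16)]

tokens = [patch_embedding(u.unsqueeze(0)).flatten(2).transpose(1, 2) for u in videos]
seq_lens = [t.size(1) for t in tokens]            # per-video L = F * (H/2) * (W/2)
x = torch.cat([torch.cat([t, t.new_zeros(1, seq_len - t.size(1), dim)], dim=1)
               for t in tokens])                  # zero-pad each to seq_len, then batch
print(seq_lens, x.shape)                          # [320, 192] torch.Size([2, 512, 2048])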

Timestep embedding

e  = time_embedding(sinusoidal(t))   # [B, 2048]
e0 = time_projection(e)              # [B, 6, 2048], split into 6 pieces for layer-scale + gates

The 6 pieces feed, respectively:

  • each WanAttentionBlock's self-attn gate and cross-attn gate

  • the Head's output scale-shift
    A single unflatten(1, (6, dim)) handles the split. (A minimal modulation sketch follows below.)
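As a minimal sketch of the scale-shift-gate pattern those pieces implement; the exact chunk-to-sublayer assignment lives inside WanAttentionBlock, so the indexing below is only illustrative:

import torch
import torch.nn as nn

B, L, dim = 2, 16, 2048
x = torch.randn(B, L, dim)                 # token stream inside one block
e0 = torch.randn(B, 6, dim)                # = time_projection(e).unflatten(1, (6, dim))

# illustrative adaLN-style split; the chunk-to-sublayer mapping here is an assumption
shift_msa, scale_msa, gate_msa, shift_ffn, scale_ffn, gate_ffn = e0.unbind(1)

norm = nn.LayerNorm(dim, elementwise_affine=False)
attn = lambda h: h                         # stand-in for the real self-attention
ffn = lambda h: h                          # stand-in for the real feed-forward

h = attn(norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1))
x = x + gate_msa.unsqueeze(1) * h          # gated residual
h = ffn(norm(x) * (1 + scale_ffn.unsqueeze(1)) + shift_ffn.unsqueeze(1))
x = x + gate_ffn.unsqueeze(1) * h
print(x.shape)                             # torch.Size([2, 16, 2048])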

Text / image conditioning

  • Text: the T5 features go through Linear → GELU → Linear to reach 2048 dims, then are padded to length 512.

  • Image: the CLIP image tokens pass through MLPProj to give 257×2048 tokens, which are prepended to the text tokens, forming an image-first cross-attention sequence.
    The task type is controlled by model_type:
    t2v uses text only, while i2v / flf2v / vace additionally consume clip_fea. (A small shape sketch of this concatenation follows below.)
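A shape-level sketch of building that cross-attention context, loosely mirroring the forward code below (img_proj here is a stand-in for the real MLPProj, and the 77-token prompt is an illustrative choice):

import torch
import torch.nn as nn

dim, text_len, text_dim, clip_dim = 2048, 512, 4096, 1280

text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
                               nn.Linear(dim, dim))
img_proj = nn.Linear(clip_dim, dim)        # stand-in for MLPProj

t5_tokens = torch.randn(77, text_dim)      # one prompt with 77 valid tokens
padded = torch.cat([t5_tokens, t5_tokens.new_zeros(text_len - 77, text_dim)])
context = text_embedding(padded.unsqueeze(0))                # [1, 512, 2048]

clip_fea = torch.randn(1, 257, clip_dim)   # CLIP image tokens (i2v-style conditioning)
context = torch.cat([img_proj(clip_fea), context], dim=1)    # image tokens come first
print(context.shape)                       # torch.Size([1, 769, 2048])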

32-layer Transformer + Head
Each WanAttentionBlock runs, in order:

  1. Self-Attention (windowed or global)

  2. Cross-Attention (text / image)

  3. SwiGLU-FFN
    all with pre-RMSNorm + residual connections + layer-scale (the 6 pieces of e0).
    Finally, the Head is a linear layer + pixel-shuffle that maps 2048 back to out_dim * patch_size[0] * patch_size[1] * patch_size[2],
    and unpatchify then restores [C_out, F, H/8, W/8].


class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to False):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """
        super().__init__()

        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        if model_type == 'i2v' or model_type == 'flf2v':
            self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        clip_fea=None,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode or first-last-frame-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v' or self.model_type == 'flf2v':
            assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 (x2) x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
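Finally, a hypothetical smoke test of the backbone with tiny sizes (this assumes the surrounding Wan modules such as WanAttentionBlock, Head, rope_params and sinusoidal_embedding_1d are importable, and that an attention backend is available on your machine; all tensors are random placeholders):

import torch

# Hypothetical t2v smoke test; 2 blocks instead of 32 to keep it light.
model = WanModel(model_type='t2v', num_layers=2).eval()

latent = torch.randn(16, 5, 16, 16)        # [C_in, F, H, W], stand-in for a VAE latent
t5_ctx = torch.randn(77, 4096)             # one prompt's T5 features, [L, C]
t = torch.tensor([500])                    # diffusion timestep, shape [B]
seq_len = 5 * 8 * 8                        # F * H/2 * W/2 after the (1, 2, 2) patchify

with torch.no_grad():
    out = model([latent], t, [t5_ctx], seq_len=seq_len)
print(out[0].shape)                        # torch.Size([16, 5, 16, 16])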

Running out of space here; the next post continues from this point.
