A Detailed Walkthrough of the Wan2.1 Video Generation Model Code
Diffusion Models Code Walkthrough: From Basics to Practice
Preface: The first draft of this post was written eight months ago, right after Wan2.1 was open-sourced. Since then, Wan2.1 has become the most widely used base model for video generation research; by a rough count, more than 100 top-conference and top-journal papers built on Wan2.1 have appeared within half a year. This post walks through the principles and the code of this modern video generation model in detail.
Contents
VAE Design
Design comparison: Wan2.2 vs. Wan2.1 VAE
Code Walkthrough
Core DiT Architecture

VAE Design
Even today, in comparisons across many open-source VAEs, the reconstruction metrics of the Wan2.1 and Wan2.2 VAEs remain highly competitive.
Design comparison: Wan2.2 vs. Wan2.1 VAE
- Much higher compression ratio
  - Wan2.1: 4×8×8 (T×H×W, 256×)
  - Wan2.2: 4×16×16 (1024×)
  At the same resolution, a Wan2.2 latent therefore has roughly 1/4 as many spatio-temporal positions as a Wan2.1 latent, which noticeably reduces the activation memory of the diffusion backbone (a small back-of-the-envelope sketch follows this list).
- Reconstruction quality goes up, not down
  - With its asymmetric encoder/decoder plus residual-sampling design, Wan2.2 still edges out Wan2.1 on PSNR despite the much higher compression.
  - In the official 720P test it reaches 32.5 dB PSNR, more than 2 dB above Wan2.1's 30.1 dB.
- Wider latent channels
  - Wan2.1: 16 latent channels
  - Wan2.2: 48 latent channels
  The extra channels compensate for the information lost to heavier compression, so fine detail is preserved better.
- Speed / memory gains
  - On an RTX 4090, Wan2.2-TI2V-5B with the new VAE brings 5-second 720P generation down from the minutes-scale of Wan2.1-14B to roughly 155 s (multi-GPU) or 534 s (single GPU), and it fits within 24 GB of VRAM.
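To make the compression arithmetic concrete, here is a tiny back-of-the-envelope sketch in Python. The ratios and channel counts are the ones quoted above; the "1 + 4n" temporal chunking is explained in the code walkthrough below, and the Wan2.2 side is derived purely from the quoted ratios rather than from its actual code:

# latent shape for an 81-frame 720p (1280x720) clip, given a T x H x W compression ratio
def latent_shape(frames, height, width, t_down, s_down, z_channels):
    t_lat = 1 + (frames - 1) // t_down          # "1 + 4n" causal chunking along time
    return (z_channels, t_lat, height // s_down, width // s_down)

wan21 = latent_shape(81, 720, 1280, t_down=4, s_down=8,  z_channels=16)
wan22 = latent_shape(81, 720, 1280, t_down=4, s_down=16, z_channels=48)
print(wan21)   # (16, 21, 90, 160) -> 21 * 90 * 160 = 302400 positions per channel
print(wan22)   # (48, 21, 45, 80)  -> 21 * 45 * 80  =  75600 positions, i.e. 1/4 of Wan2.1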
Code Walkthrough
The core of the Wan2.1 VAE is a "causal + streaming" VAE.
- Initialization: all hyperparameters are fixed up front
dim=128                                  # base channel count
z_dim=4                                  # latent channels (the released checkpoint uses z_dim=16, see _video_vae below; the encoder outputs 2 * z_dim channels for mu / log_var)
dim_mult=[1, 2, 4, 4]                    # 4 stages, channels go 128→256→512→512
temperal_downsample=[True, True, False]
# which of the downsampling transitions also halve the time axis: two of them do, so the total temporal compression is 4×
Note that these are only the class defaults; the checkpoint config in _video_vae below overrides them to dim=96 and temperal_downsample=[False, True, True].
Encoder3d and Decoder3d are the "3D causal residual networks" from the report: their kernels are already split into spatial 2D plus temporal 1D causal convolutions, so no layer ever peeks at future frames.
- Forward: the classic three-stage VAE
x_recon, mu, log_var = model(x)
encode produces μ and σ, reparameterize draws the sample, and decode maps the latent back to video.
The only unusual part: encode and decode both run chunk by chunk internally, so the full video never has to sit in GPU memory at once.
- encode: split the video into causal chunks of "1 + 4×n" frames
t = x.shape[2]
iter_ = 1 + (t - 1) // 4                 # first send 1 frame, then 4 frames per chunk
for i in range(iter_):
    if i == 0:
        out = encoder(x[:, :, :1, ...])  # frame 0 goes through on its own
    else:
        out_ = encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, ...])
        out = torch.cat([out, out_], 2)  # concatenate along the time axis
This preserves temporal causality: each chunk only ever sees frames that come before it, never frames from later chunks.
- feat_cache / feat_idx form a hidden-state cache used inside CausalConv3d: when moving from one chunk to the next, the last hidden frames of the previous chunk are carried over, exactly like the h_t of an RNN (a minimal sketch of this mechanism follows after this walkthrough).
- Latent normalization: scale[0] is the per-channel mean, scale[1] is the reciprocal std (1/std)
mu = (mu - scale[0]) * scale[1]
During training, scale is the global per-channel mean/std tracked as EMA-style running statistics; at inference you can simply pass 0 and 1, or feed in dataset statistics as an offline normalization, so that the diffusion model sees roughly N(0, 1) inputs. The WanVAE wrapper below hard-codes exactly these 16 per-channel values.
- decode: a sliding window over latent frames, equally causal
for i in range(iter_):
    out_ = decoder(z[:, :, i:i + 1, ...])   # feed one latent frame at a time
    out = torch.cat([out, out_], 2)
The decoder carries its own feat_cache, so:
- memory usage depends only on the window length, not on the total duration → effectively unbounded video length;
- the output frame rate is the latent frame rate times the temporal upsampling factor (4×), which matches the roughly 6 fps latent frame rate quoted in the report.
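The CausalConv3d implementation itself is not reproduced in this post, so here is only a minimal, simplified sketch of the idea behind feat_cache (the class TinyCausalConv3d and every detail below are illustrative, not the actual Wan2.1 code): pad the temporal axis on the past side only, and when streaming chunk by chunk, let the next chunk reuse the tail of the previous one as that padding.

import torch
import torch.nn as nn

class TinyCausalConv3d(nn.Module):
    # Illustrative only -- NOT the real Wan2.1 CausalConv3d.
    # A temporal conv that pads only toward the past; when a video is processed
    # chunk by chunk, the last (kernel_t - 1) frames of the previous chunk are
    # reused as that padding (the "h_t" that feat_cache carries across chunks).
    def __init__(self, channels, kernel_t=3):
        super().__init__()
        self.pad_t = kernel_t - 1
        self.conv = nn.Conv3d(channels, channels,
                              kernel_size=(kernel_t, 3, 3), padding=(0, 1, 1))

    def forward(self, x, cache=None):
        # x: [B, C, T, H, W]; cache: trailing frames handed over by the previous chunk
        if cache is None:
            cache = x.new_zeros(x.shape[0], x.shape[1], self.pad_t, *x.shape[3:])
        x = torch.cat([cache, x], dim=2)            # pad with the past only -> causal
        new_cache = x[:, :, -self.pad_t:].detach()  # tail for the next chunk
        return self.conv(x), new_cache

# streaming over 1 + 4 + 4 frames gives the same output as one full pass
conv = TinyCausalConv3d(8)
video = torch.randn(1, 8, 9, 32, 32)
full, _ = conv(video)
cache, outs = None, []
for chunk in (video[:, :, :1], video[:, :, 1:5], video[:, :, 5:9]):
    y, cache = conv(chunk, cache)
    outs.append(y)
print(torch.allclose(full, torch.cat(outs, dim=2), atol=1e-5))   # True

Because each output frame only ever looks at frames at or before its own position, processing the clip in chunks of 1 + 4 + 4 + ... gives the same result as processing it in one go, which is what makes the streaming encode/decode loops above possible.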
Below is the original VAE code (Encoder3d, Decoder3d, CausalConv3d and count_conv3d are defined elsewhere in the same file):
class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## split the encoder input x along time into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
The authors also put an outer layer around this class, essentially a wrapper. You can tell the coding style differs from the rest of the project; it was clearly written by someone else:
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    with torch.device('meta'):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f'loading {pretrained_path}')
    model.load_state_dict(
        torch.load(pretrained_path, map_location=device), assign=True)

    return model


class WanVAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]
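Using the wrapper is then straightforward. A minimal usage sketch (the checkpoint path is whatever you downloaded; the input is a [C, T, H, W] tensor in [-1, 1], and the latent shapes follow the 4×8×8 compression discussed above):

import torch

vae = WanVAE(z_dim=16, vae_pth='cache/vae_step_411000.pth', device='cuda')

video = torch.rand(3, 17, 480, 832, device='cuda') * 2 - 1   # one clip, [C, T, H, W] in [-1, 1]
latents = vae.encode([video])      # list of [16, 1 + (T - 1) // 4, H / 8, W / 8]
print(latents[0].shape)            # torch.Size([16, 5, 60, 104])
recon = vae.decode(latents)        # list of [3, T, H, W], clamped to [-1, 1]
print(recon[0].shape)              # torch.Size([3, 17, 480, 832])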
Core DiT Architecture
Wan2.1 uses a fairly classic transformer backbone, rather than the single-stream / dual-stream hybrid design that has become popular in recent diffusion models.
- Forward pass: one pipeline, four steps
① Patchify
x = [C, F, H, W]                      # one latent video
x = patch_embedding(u.unsqueeze(0))   # [1, 2048, F, H/2, W/2]
x = x.flatten(2).transpose(1, 2)      # [1, L, 2048], L = F * (H/2) * (W/2)
L is computed per video, the videos are then concatenated into one batch, and anything shorter than seq_len is zero-padded (a shape sketch follows).
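A standalone toy version of this step, using the class defaults dim=2048 and patch_size=(1, 2, 2) (the real model does this inside WanModel.forward; the frame/height/width numbers here are arbitrary):

import torch
import torch.nn as nn

in_dim, dim, patch_size = 16, 2048, (1, 2, 2)
patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)

u = torch.randn(in_dim, 5, 60, 104)       # one latent video [C, F, H, W]
x = patch_embedding(u.unsqueeze(0))       # [1, 2048, 5, 30, 52]
grid_sizes = x.shape[2:]                  # (F, H/2, W/2) -- kept for unpatchify later
x = x.flatten(2).transpose(1, 2)          # [1, L, 2048] with L = 5 * 30 * 52 = 7800
print(x.shape, grid_sizes)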
② Timestep embedding
e  = time_embedding(sinusoidal(t))   # [B, 2048]
e0 = time_projection(e)              # [B, 6, 2048], split into six modulation vectors
The six chunks provide the shift / scale / gate for the self-attention branch and for the FFN branch of every WanAttentionBlock, while the Head takes the unprojected e for its own output scale-shift (x = self.head(x, e) in the forward pass).
A single unflatten(1, (6, dim)) performs the split; a shape sketch follows.
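A standalone sketch of the shapes involved. sinusoidal_embedding_1d from the Wan code is replaced by a random stand-in, and the names given to the six chunks follow the shift/scale/gate reading above (WanAttentionBlock itself is not reproduced in this post, so treat the naming as an interpretation rather than the original code):

import torch
import torch.nn as nn

freq_dim, dim, B = 256, 2048, 2
time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

sin_t = torch.randn(B, freq_dim)                 # stand-in for sinusoidal_embedding_1d(freq_dim, t)
e = time_embedding(sin_t)                        # [B, 2048], also fed directly to the Head
e0 = time_projection(e).unflatten(1, (6, dim))   # [B, 6, 2048]
shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn = e0.unbind(1)
print(e.shape, e0.shape)                         # [2, 2048] and [2, 6, 2048]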
③ Text / image conditioning
- Text: the T5 embeddings go through Linear → GELU → Linear to project them to 2048 dims, then are padded to a fixed length of 512 tokens.
- Image: the CLIP image tokens go through MLPProj, giving 257×2048 tokens that are prepended to the text, i.e. an image-first cross-attention sequence.
Which path is used depends on model_type: t2v uses text only, while i2v and flf2v build the img_emb projection and additionally consume clip_fea; vace shares the i2v-style cross-attention but does not create img_emb in this class (a toy shape sketch follows).
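The conditioning path in shapes, as a toy example. MLPProj from the Wan code is replaced here by a single Linear stand-in, and the prompt length of 77 tokens is arbitrary:

import torch
import torch.nn as nn

text_dim, dim, text_len = 4096, 2048, 512
text_embedding = nn.Sequential(nn.Linear(text_dim, dim),
                               nn.GELU(approximate='tanh'), nn.Linear(dim, dim))

t5_tokens = torch.randn(77, text_dim)            # one prompt, variable length
padded = torch.cat([t5_tokens, t5_tokens.new_zeros(text_len - t5_tokens.size(0), text_dim)])
context = text_embedding(padded.unsqueeze(0))    # [1, 512, 2048]

clip_proj = nn.Linear(1280, dim)                 # stand-in for MLPProj
context_clip = clip_proj(torch.randn(1, 257, 1280))    # 257 CLIP image tokens -> [1, 257, 2048]
context = torch.cat([context_clip, context], dim=1)    # image tokens first: [1, 769, 2048]
print(context.shape)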
④ 32 transformer blocks + Head
Inside each WanAttentionBlock, in order:
- Self-attention (windowed or global, controlled by window_size)
- Cross-attention (over the text / image tokens)
- Feed-forward network (GELU, ffn_dim = 8192)
all with pre-RMSNorm, residual connections, and the scale-shift / gating driven by the six chunks of e0.
Finally, the Head applies a time-modulated norm plus a linear layer that maps 2048 back to out_dim * patch_size[0] * patch_size[1] * patch_size[2], and unpatchify rearranges the patches back into [C_out, F, H/8, W/8].
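WanAttentionBlock itself is not reproduced in this post, so the following is only a rough sketch of the DiT-style scale-shift-gate modulation pattern described above, with LayerNorm and a Linear standing in for the real RMSNorm and attention/FFN modules; the names and exact arithmetic are a paraphrase, not the original code:

import torch
import torch.nn as nn

def modulated_branch(x, norm, branch, shift, scale, gate):
    # adaLN-style modulation: scale-shift the normed input, run the branch, gate the residual
    # x: [B, L, D]; shift / scale / gate: [B, 1, D] (one chunk of e0, unsqueezed)
    return x + gate * branch(norm(x) * (1 + scale) + shift)

B, L, D = 2, 16, 64
x = torch.randn(B, L, D)
shift, scale, gate = (torch.randn(B, 1, D) for _ in range(3))
x = modulated_branch(x, nn.LayerNorm(D), nn.Linear(D, D), shift, scale, gate)
print(x.shape)   # torch.Size([2, 16, 64])

Under this reading, each block applies the pattern once around self-attention and once around the FFN, with cross-attention as a plain residual in between. The full WanModel code follows (WanAttentionBlock, Head, MLPProj, rope_params and sinusoidal_embedding_1d are defined elsewhere in the same file):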
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to False):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()

        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        if model_type == 'i2v' or model_type == 'flf2v':
            self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        clip_fea=None,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode or first-last-frame-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v' or self.model_type == 'flf2v':
            assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 (x2) x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
That is all that fits here; the walkthrough continues in the next post.
