当前位置: 首页 > news >正文

Wan2.1 图生视频 多卡推理批量生成视频

Wan2.1 图生视频 多卡推理批量生成视频

flyfish

视频生成的实践效果展示

Phantom 视频生成的实践
Phantom 视频生成的流程
Phantom 视频生成的命令

Wan2.1 图生视频 支持批量生成
Wan2.1 文生视频 支持批量生成、参数化配置和多语言提示词管理
Wan2.1 加速推理方法
Wan2.1 通过首尾帧生成视频

AnyText2 在图片里玩文字而且还是所想即所得
Python 实现从 MP4 视频文件中平均提取指定数量的帧

config.json

{"task": "i2v-14B","size": "832*480","frame_num": null,"ckpt_dir": "/media/models/Wan-AI/Wan2___1-I2V-14B-480P/","offload_model": null,"ulysses_size": 2,"ring_size": 1,"t5_fsdp": false,"t5_cpu": true,"dit_fsdp": true,"save_file": null,"prompt": null,"use_prompt_extend": false,"prompt_extend_method": "local_qwen","prompt_extend_model": null,"prompt_extend_target_lang": "zh","base_seed": -1,"image": null,"first_frame": null,"last_frame": null,"sample_solver": "unipc","sample_steps": null,"sample_shift": null,"sample_guide_scale": 5.0
}

prompt.json

[{"prompt": "Dragon Playing with Pearl: A warrior wields a red-tasseled spear, summoning seven dragon-like phantom spear tips amid swirling ink shadows that twist air into a shredding vortex; visuals include ink-black shadows, molten fire-red tassel, and a violent air vortex. ","image_paths": ["images/1.png"]},{"prompt": "Slicing the Sky, Chopping the Moon: The warrior leaps, slashing the spear diagonally like lightning to create a glowing vacuum rift with azure electricity, then traces a lunar arc that solidifies space to trap enemies; visuals feature a billowing black cape, crackling rift, and frozen lunar arc. ","image_paths": ["images/1.png"]}
]

流程

WanI2VApp.run()
├─ 主应用启动
├─ 加载配置/验证参数
│  ├─ 设置 frame_num=81 等默认值
│  └─ 校验任务和分辨率合法性
├─ 初始化分布式环境
│  ├─ 多GPU时启动进程组
│  ├─ 同步随机种子
│  └─ 验证分布式参数
├─ 模型单例加载(核心优化点)
│  ├─ 创建 WanI2V 模型实例
│  ├─ 加载 checkpoint 到 GPU
│  └─ 日志:"Creating WanI2V pipeline (first time)."
└─ 批量处理图片循环(N张图片)├─ 读取 prompt 和 image_paths├─ 对每张图片:│  ├─ 打开图片并转换格式│  ├─ 提示词扩展处理│  │  ├─ 调用 DashScope/Qwen 扩展器│  │  ├─ 分布式环境广播扩展结果│  │  └─ 失败时回退到原始提示词│  ├─ 复用模型推理│  │  ├─ 调用 model.generate() 方法│  │  ├─ 传入分辨率、帧数等参数│  │  └─ 日志:"Generating video with existing model."│  └─ 保存视频│     ├─ 生成默认文件名(含时间戳和提示词)│     └─ 调用 cache_video 保存为 MP4└─ 模型资源清理(主进程执行)├─ 删除模型实例(del self.model)├─ 清理 GPU 缓存(torch.cuda.empty_cache())└─ 日志:"Model resources cleaned up."

模型加载的时序图

┌──────────────────────────────────────────────────────────┐
│                      WanI2VApp.run()                      │
│  ┌─────────────────┐  ┌─────────────────┐  ┌────────────┐ │
│  │ 加载配置/验证参数 │  │ 初始化分布式环境 │  │ 加载模型   │ │
│  └─────────────────┘  └─────────────────┘  └──────┬─────┘ │
│                                                     │     │
│  ┌───────────────────────────────────────────────┐  │     │
│  │  遍历 prompt.json 中的每个 prompt 和 image   │  │     │
│  ├───────────────────────────────────────────────┤  │     │
│  │  ┌────────────┐  ┌────────────┐  ┌──────────┐  │     │
│  │  │ 处理提示词 │  │ 推理生成视频 │  │ 保存视频 │  │     │
│  │  └────────────┘  └────────────┘  └──────────┘  │     │
│  └───────────────────────────────────────────────┘  │     │
│                                                     │     │
│  ┌──────────────────────────┐                       │     │
│  │ 清理模型资源(仅主进程) │                       │     │
│  └──────────────────────────┘                       │     │
└──────────────────────────────────────────────────────────┘↑                  ↑                  ↑│                  │                  │模型首次加载           复用模型推理         释放模型资源

代码

import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
import jsonwarnings.filterwarnings('ignore')import torch, random
import torch.distributed as dist
from PIL import Imageimport wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2boolclass ArgsValidator:@staticmethoddef validate(args):# Basic checkassert args.ckpt_dir is not None, "Please specify the checkpoint directory."assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"# The default sampling steps are 40 for image-to-video tasks.if args.sample_steps is None:args.sample_steps = 40if args.sample_shift is None:args.sample_shift = 3.0 if args.size in ["832*480", "480*832"] else 5.0# The default number of frames are 81.if args.frame_num is None:args.frame_num = 81args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)# Size checkassert args.size in SUPPORTED_SIZES[args.task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"return argsclass ConfigLoader:@staticmethoddef load_config():# 从配置文件读取参数with open('config.json', 'r') as f:config = json.load(f)# 创建一个命名空间来存储参数class ArgsNamespace:def __init__(self, **kwargs):self.__dict__.update(kwargs)args = ArgsNamespace(**config)return argsclass LoggerInitializer:@staticmethoddef initialize(rank):# loggingif rank == 0:# set formatlogging.basicConfig(level=logging.INFO,format="[%(asctime)s] %(levelname)s: %(message)s",handlers=[logging.StreamHandler(stream=sys.stdout)])else:logging.basicConfig(level=logging.ERROR)class DistributedEnv:def __init__(self, args):self.args = argsself.rank = int(os.getenv("RANK", 0))self.world_size = int(os.getenv("WORLD_SIZE", 1))self.local_rank = int(os.getenv("LOCAL_RANK", 0))self.device = self.local_rankdef initialize(self):if self.args.offload_model is None:self.args.offload_model = False if self.world_size > 1 else Truelogging.info(f"offload_model is not specified, set to {self.args.offload_model}.")if self.world_size > 1:torch.cuda.set_device(self.local_rank)dist.init_process_group(backend="nccl",init_method="env://",rank=self.rank,world_size=self.world_size)else:assert not (self.args.t5_fsdp or self.args.dit_fsdp), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."assert not (self.args.ulysses_size > 1 or self.args.ring_size > 1), f"context parallel are not supported in non-distributed environments."if self.args.ulysses_size > 1 or self.args.ring_size > 1:assert self.args.ulysses_size * self.args.ring_size == self.world_size, f"The number of ulysses_size and ring_size should be equal to the world size."from xfuser.core.distributed import (initialize_model_parallel,init_distributed_environment)init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),ring_degree=self.args.ring_size,ulysses_degree=self.args.ulysses_size,)if dist.is_initialized():base_seed = [self.args.base_seed] if self.rank == 0 else [None]dist.broadcast_object_list(base_seed, src=0)self.args.base_seed = base_seed[0]return self.args, self.rank, self.deviceclass PromptProcessor:def __init__(self, args, rank, device):self.args = argsself.rank = rankself.device = devicedef process(self, img):if not self.args.use_prompt_extend:return self.args.promptlogging.info("Extending prompt ...")if self.rank == 0:if self.args.prompt_extend_method == "dashscope":prompt_expander = DashScopePromptExpander(model_name=self.args.prompt_extend_model, is_vl=True)elif self.args.prompt_extend_method == "local_qwen":prompt_expander = QwenPromptExpander(model_name=self.args.prompt_extend_model,is_vl=True,device=self.rank)else:raise NotImplementedError(f"Unsupport prompt_extend_method: {self.args.prompt_extend_method}")prompt_output = prompt_expander(self.args.prompt,tar_lang=self.args.prompt_extend_target_lang,image=img,seed=self.args.base_seed)if prompt_output.status == False:logging.info(f"Extending prompt failed: {prompt_output.message}")logging.info("Falling back to original prompt.")input_prompt = self.args.promptelse:input_prompt = prompt_output.promptinput_prompt = [input_prompt]else:input_prompt = [None]if dist.is_initialized():dist.broadcast_object_list(input_prompt, src=0)self.args.prompt = input_prompt[0]logging.info(f"Extended prompt: {self.args.prompt}")return self.args.promptclass VideoGenerator:_instance = None  # 单例实例@classmethoddef get_instance(cls, args, rank, device):# 如果实例不存在,创建新实例if cls._instance is None:cls._instance = cls(args, rank, device)return cls._instancedef __init__(self, args, rank, device):# 初始化只执行一次self.args = argsself.rank = rankself.device = deviceself.cfg = WAN_CONFIGS[args.task]# 加载模型logging.info("Creating WanI2V pipeline (first time).")self.model = wan.WanI2V(config=self.cfg,checkpoint_dir=self.args.ckpt_dir,device_id=self.device,rank=self.rank,t5_fsdp=self.args.t5_fsdp,dit_fsdp=self.args.dit_fsdp,use_usp=(self.args.ulysses_size > 1 or self.args.ring_size > 1),t5_cpu=self.args.t5_cpu,)def generate(self, prompt, img):# 复用已加载的模型进行推理logging.info("Generating video with existing model.")video = self.model.generate(prompt,img,max_area=MAX_AREA_CONFIGS[self.args.size],frame_num=self.args.frame_num,shift=self.args.sample_shift,sample_solver=self.args.sample_solver,sampling_steps=self.args.sample_steps,guide_scale=self.args.sample_guide_scale,seed=self.args.base_seed,offload_model=self.args.offload_model)return video@classmethoddef cleanup(cls):# 清理模型资源(如在应用结束时调用)if cls._instance and hasattr(cls._instance, 'model'):del cls._instance.modeltorch.cuda.empty_cache()logging.info("Model resources cleaned up.")cls._instance = Noneclass VideoSaver:def __init__(self, args, rank):self.args = argsself.rank = rankself.cfg = WAN_CONFIGS[args.task]def save(self, video):if self.rank != 0:returnif self.args.save_file is None:formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")formatted_prompt = self.args.prompt.replace(" ", "_").replace("/","_")[:50]suffix = '.mp4'self.args.save_file = f"{self.args.task}_{self.args.size.replace('*','x') if sys.platform=='win32' else self.args.size}_{self.args.ulysses_size}_{self.args.ring_size}_{formatted_prompt}_{formatted_time}" + suffixlogging.info(f"Saving generated video to {self.args.save_file}")cache_video(tensor=video[None],save_file=self.args.save_file,fps=self.cfg.sample_fps,nrow=1,normalize=True,value_range=(-1, 1))class WanI2VApp:def __init__(self):self.args = Noneself.rank = 0self.device = 0def run(self):# 加载配置config_loader = ConfigLoader()self.args = config_loader.load_config()# 验证参数validator = ArgsValidator()self.args = validator.validate(self.args)# 初始化日志LoggerInitializer.initialize(self.rank)# 初始化分布式环境dist_env = DistributedEnv(self.args)self.args, self.rank, self.device = dist_env.initialize()logging.info(f"Generation job args: {self.args}")logging.info(f"Generation model config: {WAN_CONFIGS[self.args.task]}")# 获取单例模型生成器(只加载一次模型)generator = VideoGenerator.get_instance(self.args, self.rank, self.device)# 从prompt.json文件读取prompt和image_pathswith open('prompt.json', 'r') as f:prompt_list = json.load(f)for prompt_info in prompt_list:self.args.prompt = prompt_info["prompt"]image_paths = prompt_info["image_paths"]for image_path in image_paths:logging.info(f"Input prompt: {self.args.prompt}")logging.info(f"Input image: {image_path}")img = Image.open(image_path).convert("RGB")# 处理promptprompt_processor = PromptProcessor(self.args, self.rank, self.device)prompt = prompt_processor.process(img)# 复用已加载的模型生成视频video = generator.generate(prompt, img)# 保存视频saver = VideoSaver(self.args, self.rank)saver.save(video)# 清理模型资源(可选,在所有推理完成后调用)if self.rank == 0:VideoGenerator.cleanup()logging.info("Finished.")if __name__ == "__main__":app = WanI2VApp()app.run()

执行流程

1. 主应用初始化与配置加载

WanI2VApp 启动:创建主应用实例并调用 run() 方法。
ConfigLoader 加载配置:从 config.json 读取参数(如 taskckpt_dir 等)。
ArgsValidator 验证参数:设置默认值(如 frame_num=81)并校验合法性。

2. 环境与资源初始化

LoggerInitializer 初始化日志:主进程(rank=0)输出INFO,其他进程输出ERROR。
DistributedEnv 初始化分布式环境

  • 多GPU时启动进程组(dist.init_process_group)。
  • 同步随机种子(base_seed)确保结果可复现。
3. 模型单例加载(核心优化点)

VideoGenerator.get_instance() 调用

  • 首次调用时,创建单例实例并加载模型(wan.WanI2V)。
  • 日志提示:Creating WanI2V pipeline (first time).
  • 模型加载完成后,实例保存在 VideoGenerator._instance 中。
4. 批量处理提示词与图片

读取 prompt.json:遍历所有 promptimage_paths
PromptProcessor 扩展提示词

  • 对每张图片,使用 DashScopeQwen 扩展提示词。
  • 扩展失败时回退到原始提示词。
    VideoGenerator.generate() 推理
  • 复用已加载的模型实例(self.model)。
  • 日志提示:Generating video with existing model.
  • 每次推理仅执行计算,不重复加载模型。
5. 结果保存与资源清理

VideoSaver 保存视频:主进程将结果保存为MP4文件。
VideoGenerator.cleanup() 释放资源
- 应用结束时删除模型实例(del self.model)。
- 调用 torch.cuda.empty_cache() 清理GPU缓存。
- 日志提示:Model resources cleaned up.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.dtcms.com/a/213612.html

相关文章:

  • 视频问答功能播放器(视频问答)视频弹题功能实例
  • ffmpeg转换竖屏(画面是横屏旋转90度的竖屏文件格式)视频到横屏
  • 网易互娱游戏研发实习一面
  • 在 ElementUI 中实现 Table 单元格合并
  • 萤石云实际视频实时接入(生产环境)
  • Node.js全局对象详解:console、process与核心功能
  • [ARM][架构] 01.ARMv7 特权等级与核心寄存器
  • 代码随想录算法训练营第60期第四十八天打卡
  • 开源 FcDesigner 表单设计器组件事件详解
  • 算法打卡第七天
  • 【ARTS】【LeetCode-59】螺旋矩阵
  • Debian系统安装Python详细教程及常见问题解答
  • Leetcode 3563. Lexicographically Smallest String After Adjacent Removals
  • Steam发布游戏过程的若干问题
  • 【计算机网络】IP 协议深度解析:从基础到实战
  • 晚期NSCLC临床试验终点与分析策略
  • 重学计算机网络之命令整理
  • 【Bug】--node命令加载失败
  • 重磅升级!Docusign IAM 2025 V1 版本上线,重塑智能协议新体验
  • 计算机网络学习(八)——MAC
  • 云服务器Ubuntu系统安装Docker教程和失败原因
  • 《三维点如何映射到图像像素?——相机投影模型详解》
  • 游戏引擎学习第310天:利用网格划分完成排序加速优化
  • 算力服务器的应用场景都有哪些
  • 猿大师办公助手网页编辑Office/wps支持服务器文件多线程下载吗?
  • uboot常用命令之eMMC/SD卡命令
  • vector的实现
  • CollUtil详解
  • 游戏引擎学习第311天:支持手动排序
  • 终端没有5G图标-不支持特定NSA频段组合