当前位置: 首页 > news >正文

视频分类 pytorchvideo

目录

1. 速度 vs 精度分析

r2plus1d_r50 推理代码:

x3d_xs推理代码:

R(2+1)D

X3D(轻量级,速度快)

I3D(经典 3D CNN)

替换分类层(适配你的任务)


1. 速度 vs 精度分析

| 模型 | 计算量/速度 | 精度 | 适合你需求的程度 |
| --- | --- | --- | --- |
| X3D | ⭐⭐⭐⭐⭐ 最快 | ⭐⭐⭐⭐ 较高 | 🏆 最佳选择 |
| R(2+1)D | ⭐⭐⭐ 中等 | ⭐⭐⭐⭐⭐ 很高 | ⭐⭐⭐ 不错但稍慢 |
| I3D | ⭐⭐ 最慢 | ⭐⭐⭐⭐ 较高 | ⭐⭐ 不太适合 |
# Install the library first (this is a shell command, not Python):
#   pip install pytorchvideo
import torch
from pytorchvideo.models import hub

# Option 1: build the architecture through pytorchvideo's hub module without
# loading pretrained weights.
backbone = getattr(hub, "r2plus1d_r50")(pretrained=False)

# Option 2: fetch the same entry point through torch.hub with pretrained
# weights. This rebinds `backbone`, so the pretrained model is the one kept.
backbone = torch.hub.load("facebookresearch/pytorchvideo", model="r2plus1d_r50", pretrained=True)

r2plus1d_r50 推理代码:

224*224 分类需要60ms

import time

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")


def load_x3d_xs_model(model_name="r2plus1d_r50"):
    """Load a pretrained pytorchvideo hub model, move it to ``device`` and set eval mode.

    The function name is kept for backward compatibility; the default entry
    point is ``r2plus1d_r50``, not X3D-XS. Pass ``model_name="x3d_xs"`` or
    ``"x3d_s"`` to load one of the X3D variants instead.

    Args:
        model_name: Name of the entry point in the
            ``facebookresearch/pytorchvideo`` torch.hub repository.

    Returns:
        The loaded model, on ``device`` and in evaluation mode.
    """
    model = torch.hub.load("facebookresearch/pytorchvideo", model=model_name, pretrained=True)
    model = model.to(device)
    model.eval()
    return model


def preprocess_for_x3d_xs(video_frames, target_size=182, crop_size=72, num_frames=13):
    """Turn a list of frames into a normalised ``[1, C, T, H, W]`` clip tensor.

    Each frame is resized so its shorter side is ``target_size``, centre-cropped
    to ``crop_size`` x ``crop_size``, converted to a tensor and normalised with
    mean 0.45 / std 0.225 per channel.

    Args:
        video_frames: Sequence of frames, each a ``numpy.ndarray`` (converted
            with ``Image.fromarray``) or a PIL image.
        target_size: Size passed to ``transforms.Resize``.
        crop_size: Size passed to ``transforms.CenterCrop``.
        num_frames: Maximum number of frames kept. Longer clips are subsampled
            at evenly spaced indices; shorter clips are passed through as-is
            (no padding).

    Returns:
        A float tensor of shape ``[1, C, T, crop_size, crop_size]``.

    Raises:
        ValueError: If ``video_frames`` is empty.
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames must contain at least one frame")

    mean = [0.45, 0.45, 0.45]
    std = [0.225, 0.225, 0.225]

    # Uniformly subsample clips that are longer than the requested length.
    if len(video_frames) > num_frames:
        indices = np.linspace(0, len(video_frames) - 1, num_frames, dtype=int)
        video_frames = [video_frames[i] for i in indices]

    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    processed_frames = [
        transform(Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame)
        for frame in video_frames
    ]

    # Stack to [T, C, H, W], reorder to [C, T, H, W], then add the batch dim.
    video_tensor = torch.stack(processed_frames).permute(1, 0, 2, 3)
    return video_tensor.unsqueeze(0)


# Usage example
model = load_x3d_xs_model()

# Time ten forward passes on random frames.
for _ in range(10):
    dummy_frames = [np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8) for _ in range(13)]
    input_tensor = preprocess_for_x3d_xs(dummy_frames)
    print(input_tensor.shape)
    input_tensor = input_tensor.to(device)

    # CUDA kernels run asynchronously, so synchronise around the timed region.
    # torch.cuda.synchronize() fails on machines without CUDA, hence the guard.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        output = model(input_tensor)
    if device.type == "cuda":
        torch.cuda.synchronize()
    print(f"输出形状: {output.shape}", time.time() - start)

x3d_xs推理代码:

import os

import torch
from torch import nn

# Set before the remaining imports, matching the original statement order.
# NOTE(review): presumably this must precede the pytorchvideo import - confirm.
os.environ["FFCV_DISABLE_IOPATH"] = "1"

import cv2
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from pytorchvideo.models.x3d import create_x3d
from torchvision.io import read_video


class VideoNormalize(nn.Module):
    """Channel-wise normalisation for a ``[C, T, H, W]`` video tensor."""

    def __init__(self, mean, std):
        super().__init__()
        # Shaped [C, 1, 1, 1] so they broadcast over time, height and width.
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1, 1))

    def forward(self, x):
        # x: [C, T, H, W]
        return (x - self.mean) / self.std


class X3DVideoClassifier:
    """Video classifier built on a pretrained X3D model from torch.hub.

    The final projection layer is replaced by a new ``Linear`` layer with
    ``num_classes`` outputs. That layer is freshly initialised, so predictions
    are not meaningful until the model has been fine-tuned.
    """

    # Hub entry points accepted for ``model_type``.
    SUPPORTED_MODELS = ("x3d_xs", "x3d_s", "x3d_m")

    def __init__(self, model_type='x3d_xs', num_classes=2, device='auto'):
        """Initialise the classifier.

        Args:
            model_type: Model variant ('x3d_xs', 'x3d_s', 'x3d_m').
            num_classes: Number of output classes.
            device: Device to run on ('auto', 'cuda', 'cpu'); 'auto' picks
                CUDA when available.

        Raises:
            ValueError: If ``model_type`` is not a supported variant.
        """
        self.model_type = model_type
        self.num_classes = num_classes
        self.device = device
        if device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self._load_model()
        self._setup_transforms()

    def _load_model(self):
        """Load the pretrained model selected by ``self.model_type`` and swap its head."""
        if self.model_type not in self.SUPPORTED_MODELS:
            raise ValueError(f"不支持的模型类型: {self.model_type}")

        # Load the variant that was actually requested (previously this was
        # hard-coded to "x3d_s" regardless of model_type).
        self.model = torch.hub.load("facebookresearch/pytorchvideo", self.model_type, pretrained=True)

        # Replace the final classification layer with one sized for this task.
        head = self.model.blocks[-1]
        head.proj = torch.nn.Linear(head.proj.in_features, self.num_classes)

        self.model.to(self.device)
        self.model.eval()
        print(f"已加载 {self.model_type} 模型到 {self.device}")

    def _setup_transforms(self):
        """Build the preprocessing pipeline applied to a ``[T, H, W, C]`` tensor."""
        self.transform = T.Compose([
            T.Lambda(lambda x: x / 255.0),  # scale pixel values to [0, 1]
            T.Lambda(lambda x: x.permute(3, 0, 1, 2)),  # [T, H, W, C] -> [C, T, H, W]
            # NOTE(review): the pretrained head may pool over a window sized for
            # the model's native crop - confirm 72x72 input is large enough.
            T.Resize((72, 72)),  # resize to 72x72
            VideoNormalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
        ])

    def load_video(self, video_path, max_frames=16):
        """Read up to ``max_frames`` leading frames of a video file.

        Args:
            video_path: Path to the video file.
            max_frames: Maximum number of frames to read.

        Returns:
            A float tensor of shape ``[T, H, W, C]`` in RGB order.

        Raises:
            FileNotFoundError: If ``video_path`` does not exist.
            ValueError: If no frame could be read from the file.
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")

        # Read the video with OpenCV; release the capture even if a frame
        # conversion raises.
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while len(frames) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; convert to RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        if not frames:
            raise ValueError("无法从视频中读取帧")

        # Convert to a tensor [T, H, W, C].
        return torch.from_numpy(np.array(frames)).float()

    def preprocess_video(self, video_tensor):
        """Apply the transforms, add a batch dimension and move to the device.

        Returns:
            A tensor of shape ``[1, C, T, H, W]`` on ``self.device``.
        """
        processed = self.transform(video_tensor)
        processed = processed.unsqueeze(0)
        return processed.to(self.device)

    def _classify(self, input_tensor, class_names=None):
        """Run the model on a preprocessed batch and package the top-1 result."""
        with torch.no_grad():
            outputs = self.model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)

        confidence = confidence.item()
        predicted_idx = predicted_idx.item()

        # Fall back to a generic label when no name is available for the index.
        if class_names and len(class_names) > predicted_idx:
            predicted_class = class_names[predicted_idx]
        else:
            predicted_class = f"Class {predicted_idx}"

        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'class_index': predicted_idx,
            'probabilities': probabilities.cpu().numpy()[0],
        }

    def predict(self, video_path, class_names=None):
        """Classify a video file.

        Args:
            video_path: Path to the video file.
            class_names: Optional list of class names indexed by class id.

        Returns:
            A dict with keys 'predicted_class', 'confidence', 'class_index'
            and 'probabilities'.
        """
        print(f"正在加载视频: {video_path}")
        video_tensor = self.load_video(video_path)
        print(f"视频帧数: {video_tensor.shape[0]}")

        input_tensor = self.preprocess_video(video_tensor)
        print(f"输入张量形状: {input_tensor.shape}")

        return self._classify(input_tensor, class_names)

    def predict_from_tensor(self, video_tensor, class_names=None):
        """Classify a raw ``[T, H, W, C]`` video tensor; returns the same dict as ``predict``."""
        input_tensor = self.preprocess_video(video_tensor)
        return self._classify(input_tensor, class_names)


# Usage example
def main():
    """Demo: classify a video file, falling back to a random tensor on any error."""
    # Build the classifier: x3d_s variant, two classes, device chosen automatically.
    classifier = X3DVideoClassifier(model_type='x3d_s', num_classes=2, device='auto')

    # Class names (adapt these to your own task).
    class_names = ["类别A", "类别B"]

    # Example 1: predict from a file.
    video_path = r"C:\Users\Administrator\Videos\xiaoxia.mp4"
    separator = "=" * 50
    try:
        result = classifier.predict(video_path, class_names)
        print("\n" + separator)
        print("视频分类结果:")
        print(f"预测类别: {result['predicted_class']}")
        print(f"置信度: {result['confidence']:.4f}")
        print(f"类别索引: {result['class_index']}")
        print("各类别概率:")
        for idx, prob in enumerate(result['probabilities']):
            if idx < len(class_names):
                label = class_names[idx]
            else:
                label = f"Class {idx}"
            print(f"  {label}: {prob:.4f}")
        print(separator)
    except Exception as e:
        print(f"错误: {e}")
        print("使用随机张量进行演示...")

        # Example 2: demonstrate with a random tensor shaped [T, H, W, C].
        random_video = torch.randn(16, 72, 72, 3)
        result = classifier.predict_from_tensor(random_video, class_names)
        print("\n随机张量演示结果:")
        print(f"预测类别: {result['predicted_class']}")
        print(f"置信度: {result['confidence']:.4f}")


if __name__ == "__main__":
    main()

R(2+1)D

R(2+1)D 将 3D 卷积分解为空间 2D 卷积和时间 1D 卷积,在性能和效率上取得了很好的平衡。

import torch
from pytorchvideo.models import resnet

# R(2+1)D-18, pretrained on Kinetics-400.
# NOTE(review): confirm that `create_r2plus1d` is exposed by the `resnet`
# module and that it accepts `model_depth=18` and `pretrained`. If it does
# not, the torch.hub entry point
#   torch.hub.load("facebookresearch/pytorchvideo", model="r2plus1d_r50", pretrained=True)
# is a known way to obtain pretrained R(2+1)D weights.
model = resnet.create_r2plus1d(
    input_channel=3,       # RGB
    model_depth=18,        # ResNet18 backbone
    model_num_class=400,   # number of Kinetics-400 classes
    pretrained=True,
)

X3D(轻量级,速度快)


from pytorchvideo.models import x3d

# X3D-Medium (XS, S and L variants also exist).
# NOTE(review): confirm that `create_x3d` accepts `model_depth`, `pretrained`
# and `model_variant`. If it does not, pretrained variants can be loaded by
# name through torch.hub, e.g.
#   torch.hub.load("facebookresearch/pytorchvideo", "x3d_s", pretrained=True)
model = x3d.create_x3d(
    input_channel=3,
    model_num_class=400,   # Kinetics-400
    model_depth=50,
    pretrained=True,
    model_variant="M",     # XS / S / M / L
)

I3D(经典 3D CNN)

from pytorchvideo.models import i3d

# I3D with a ResNet-50 backbone.
# NOTE(review): confirm that pytorchvideo ships an `i3d` module exposing
# `create_kinetics_resnet50`; verify against the installed version.
model = i3d.create_kinetics_resnet50(
    pretrained=True,
    model_num_class=400,   # Kinetics-400
)

替换分类层(适配你的任务)

假设你的视频只有 num_classes=5

 
num_classes = 5

# Replace the final classification layer with one sized for the new task.
# Models that expose a `blocks` list keep the classifier as `proj` on the last
# block; the X3D hub model used earlier in this article is handled the same
# way. Models that expose `head` instead take the second branch.
if hasattr(model, "blocks"):
    old_proj = model.blocks[-1].proj
    model.blocks[-1].proj = torch.nn.Linear(old_proj.in_features, num_classes)
else:
    old_proj = model.head.proj
    model.head.proj = torch.nn.Linear(old_proj.in_features, num_classes)


文章转载自:

http://EoLiZraT.xrLwr.cn
http://kg58rbDY.xrLwr.cn
http://lcfAShqu.xrLwr.cn
http://efPjJu1n.xrLwr.cn
http://EMhuYvrD.xrLwr.cn
http://uAR7MGj1.xrLwr.cn
http://s0JMTYht.xrLwr.cn
http://yhxuX207.xrLwr.cn
http://L0ZsqwwV.xrLwr.cn
http://wppWWO1z.xrLwr.cn
http://ECxyTXDB.xrLwr.cn
http://cStphvU6.xrLwr.cn
http://Ii6UwYhZ.xrLwr.cn
http://lfL7SsDS.xrLwr.cn
http://U1rK3PRD.xrLwr.cn
http://dh4s6JIa.xrLwr.cn
http://08kFT6QV.xrLwr.cn
http://11PHnEqK.xrLwr.cn
http://kADMz4oh.xrLwr.cn
http://fUDsimZv.xrLwr.cn
http://oPYoG5c7.xrLwr.cn
http://gTtSSvY7.xrLwr.cn
http://znKrvGOk.xrLwr.cn
http://Lxg8Qbj2.xrLwr.cn
http://biQ6sgxZ.xrLwr.cn
http://GfOtFRAX.xrLwr.cn
http://q5ZNDPrp.xrLwr.cn
http://PyrfwqeJ.xrLwr.cn
http://jifUH7M1.xrLwr.cn
http://oeFMsEE9.xrLwr.cn
http://www.dtcms.com/a/384718.html

相关文章:

  • RabbitMQ 基础概念与原理
  • 专题:2025中国消费市场趋势与数字化转型研究报告|附360+份报告PDF、数据仪表盘汇总下载
  • 预制菜行业新风向:企业运营与商家协同发展的实践启示
  • 晶台光耦 KL6N137 :以精密光电技术驱动智能开关性能提升
  • 贪心算法应用:最短作业优先(SJF)调度问题详解
  • javaee初阶 文件IO
  • 如何调整滚珠丝杆的反向间隙?
  • Python项目中的包添加后为什么要进行可编辑安装?
  • daily notes[45]
  • 基于51单片机的蓝牙体温计app设计
  • Git版本控制完全指南
  • 【CSS】一个自适应大小的父元素,如何让子元素的宽高比一直是2:1
  • 前端通过地址生成自定义二维码实战(带源码)
  • Android Doze低电耗休眠模式 与 WorkManager
  • 用 Go 重写 adbkit:原理、架构与实现实践
  • 通过Magisk service.d 脚本实现手机开机自动开启无线 ADB
  • NineData社区版 V4.5.0 正式发布!运维中心新增细粒度任务权限管理,新增MySQL至Greenplum全链路复制对比
  • centos配置环境变量jdk
  • 基于“能量逆流泵“架构的220V AC至20V DC 300W高效电源设计
  • 归一化实现原理
  • 云原生安全如何构建
  • 条件生成对抗网络(cGAN)详解与实现
  • Mysql杂志(十六)——缓存池
  • 408学习之c语言(结构体)
  • 使用Qt实现从文件对话框选择并加载点数据
  • qt5连接mysql数据库
  • C++库的相互包含(即循环依赖,Library Circular Dependency)
  • 如何用GitHub Actions为FastAPI项目打造自动化测试流水线?
  • LVS与Keepalived详解(二)LVS负载均衡实现实操
  • 闪电科创-无人机轨迹预测SCI/EI会议辅导