当前位置: 首页 > news >正文

PyTorch Image Models (timm) 技术指南

timm

  • PyTorch Image Models (timm) 技术指南
      • 功能概述
    • 一、引言
    • 二、timm 库概述
    • 三、安装 timm 库
    • 四、模型加载与推理示例
      • 4.1 通用推理流程
      • 4.2 具体模型示例
        • 4.2.1 ResNeXt50-32x4d
        • 4.2.2 EfficientNet-V2 Small 模型
        • 4.2.3 DeiT-3 large 模型
        • 4.2.4 RepViT-M2 模型
        • 4.2.5 ResNet-RS-101
        • 4.2.6 Vision Transformer (ViT)
        • 4.2.7 Swin Transformer
        • 4.2.8 Swin Transformer V2
        • 4.2.9 Swin Transformer V2 Cr
        • 4.2.10 Levit
      • 4.3 加载自定义模型
      • 4.4 提取模型的中间特征
      • 4.5 冻结模型的部分层
      • 4.6 创建模型时指定输入图像尺寸
      • 4.7 数据预处理阶段调整图像尺寸
      • 4.8 调整输出分类个数
      • 4.9 综合示例
        • 更多模型说明
    • 五、timm 库近期更新
      • 5.1 2025 年 2 月 21 日更新
      • 5.2 其他更新
    • 六、分布式训练支持
    • 七、学习率调度器
      • 7.1 余弦退火调度器(CosineLRScheduler)
      • 7.2 多步学习率调度器(MultiStepLRScheduler)
    • 八、总结
    • 九、参考资料

PyTorch Image Models (timm) 技术指南

timm(PyTorch Image Models)是一个广泛使用的 PyTorch 库,它集合了大量的图像模型、层、实用工具、优化器、调度器、数据加载器/增强器以及参考训练/验证脚本。以下是对 timm 库的详细介绍,包括功能、模型案例、加载与使用示例以及相关教程的信息。

功能概述

  1. 丰富的图像模型:包含众多预训练的图像分类、目标检测、语义分割等模型,如 ResNet、EfficientNet、ViT 等。
  2. 实用工具:提供了一系列用于模型训练、验证和推理的实用工具,如优化器、调度器、数据加载器和增强器等。
  3. 模型构建与管理:支持轻松构建和管理不同类型的模型,包括模型的初始化、权重加载和保存等。
  4. 分布式训练:支持分布式训练,方便在多个 GPU 或节点上进行高效训练。

一、引言

在深度学习领域,图像分类、目标检测等任务常常需要使用预训练的图像模型。PyTorch Image Models (timm) 是一个功能强大的库,它提供了大量预训练的图像模型,涵盖了各种架构,方便开发者快速搭建和训练自己的模型。本文将详细介绍 timm 库的使用,包括模型加载、推理以及近期更新的模型和功能。

二、timm 库概述

timm 是一个基于 PyTorch 的图像模型库,它收集了众多先进的图像模型,如 ResNet、ViT、Swin Transformer 等,并提供了预训练的权重。通过 timm,开发者可以轻松地加载这些模型,进行图像分类、特征提取等任务。

三、安装 timm 库

在使用 timm 之前,需要先安装该库。可以使用以下命令进行安装:

pip install timm

四、模型加载与推理示例

4.1 通用推理流程

以下是一个通用的使用 timm 加载模型并进行推理的示例代码:

import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练模型
model = timm.create_model('model_name', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Predicted class index: {predicted_idx.item()}")

在上述代码中,'model_name' 需要替换为具体的模型名称,'path_to_your_image.jpg' 需要替换为实际的图像文件路径。

4.2 具体模型示例

4.2.1 ResNeXt50-32x4d
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 ResNeXt50-32x4d 模型
model = timm.create_model('resnext50_32x4d', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"ResNeXt50-32x4d Predicted class index: {predicted_idx.item()}")
4.2.2 EfficientNet-V2 Small 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 EfficientNet-V2 Small 模型
model = timm.create_model('efficientnetv2_s', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"EfficientNet-V2 Small Predicted class index: {predicted_idx.item()}")
4.2.3 DeiT-3 large 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 DeiT-3 large 模型
model = timm.create_model('deit3_large_patch16_384', pretrained=True)
model.eval()# 定义图像预处理转换,注意输入尺寸为 384x384
transform = transforms.Compose([transforms.Resize(384),transforms.CenterCrop(384),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"DeiT-3 large Predicted class index: {predicted_idx.item()}")
4.2.4 RepViT-M2 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 RepViT-M2 模型
model = timm.create_model('repvit_m2', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"RepViT-M2 Predicted class index: {predicted_idx.item()}")
4.2.5 ResNet-RS-101
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 ResNet-RS-101 模型
model = timm.create_model('resnetrs101', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"ResNet-RS-101 Predicted class index: {predicted_idx.item()}")
4.2.6 Vision Transformer (ViT)
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 ViT-Base/32 模型
model = timm.create_model('vit_base_patch32_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"ViT-Base/32 Predicted class index: {predicted_idx.item()}")
4.2.7 Swin Transformer
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Swin Transformer 模型
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer Predicted class index: {predicted_idx.item()}")
4.2.8 Swin Transformer V2
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Swin Transformer V2 模型
model = timm.create_model('swinv2_base_window12_192_22k', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer V2 Predicted class index: {predicted_idx.item()}")
4.2.9 Swin Transformer V2 Cr
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Swin Transformer V2 Cr 模型
model = timm.create_model('swinv2_cr_base_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer V2 Cr Predicted class index: {predicted_idx.item()}")
4.2.10 Levit
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Levit 模型
model = timm.create_model('levit_256', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Levit Predicted class index: {predicted_idx.item()}")

4.3 加载自定义模型

如果你需要加载自定义的模型,可以使用 timm.create_model 函数,并指定模型的名称和相关参数:

import timm# 创建自定义的 EfficientNet 模型
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
print(model)

4.4 提取模型的中间特征

import torch
import timm# 加载预训练的模型
model = timm.create_model('resnet18', pretrained=True, features_only=True)# 生成随机输入
x = torch.randn(1, 3, 224, 224)# 提取中间特征
features = model(x)
for i, feat in enumerate(features):print(f"Feature {i} shape: {feat.shape}")

4.5 冻结模型的部分层

import torch
import timm
from timm.utils.model import freeze# 加载预训练的模型
model = timm.create_model('resnet18', pretrained=True)# 冻结模型的前几层
submodules = [n for n, _ in model.named_children()]
freeze(model, submodules[:submodules.index('layer2') + 1])# 检查冻结情况
print(model.layer2[0].conv1.weight.requires_grad)  # 输出: False
print(model.layer3[0].conv1.weight.requires_grad)  # 输出: True

在使用 timm 库加载预训练模型后,我们经常需要根据具体的任务需求调整模型的参数,例如输入图像尺寸、输出分类个数等。下面将结合提供的代码片段详细介绍如何进行这些参数的调整。

4.6 创建模型时指定输入图像尺寸

部分模型在创建时可以通过 img_size 参数指定输入图像的尺寸。以下是一个使用 SwinTransformer 模型的示例:

import timm
import torch# 加载预训练的 SwinTransformer 模型,并指定输入图像尺寸为 384x384
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=(384, 384))
model.eval()# 随机生成一个符合指定尺寸的输入张量进行测试
input_tensor = torch.randn(1, 3, 384, 384)
with torch.no_grad():output = model(input_tensor)
print("Output shape:", output.shape)

在上述代码中,我们通过 img_size=(384, 384) 指定了输入图像的尺寸为 384x384。

4.7 数据预处理阶段调整图像尺寸

除了在创建模型时指定输入图像尺寸,还需要在数据预处理阶段将输入图像调整为指定的尺寸。可以使用 torchvision.transforms 来实现这一点,示例如下:

from torchvision import transforms
from PIL import Image# 定义图像预处理转换,将图像调整为 384x384
transform = transforms.Compose([transforms.Resize((384, 384)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)
print("Output shape:", output.shape)

在这个示例中,我们使用 transforms.Resize((384, 384)) 将输入图像调整为 384x384 的尺寸。

4.8 调整输出分类个数

输出分类个数的调整通常在创建模型时通过 num_classes 参数来实现。以下是一个使用 MetaFormer 模型的示例:

import timm# 加载预训练的 MetaFormer 模型,并指定输出分类个数为 10
model = timm.create_model('metaformer', pretrained=True, num_classes=10)
model.eval()# 随机生成一个输入张量进行测试
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():output = model(input_tensor)
print("Output shape:", output.shape)

在上述代码中,我们通过 num_classes=10 指定了模型的输出分类个数为 10。

4.9 综合示例

下面是一个综合示例,展示了如何同时调整输入图像尺寸和输出分类个数:

import timm
import torch
from torchvision import transforms
from PIL import Image# 加载预训练的 SwinTransformer 模型,调整输入图像尺寸为 384x384,输出分类个数为 10
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=(384, 384), num_classes=10)
model.eval()# 定义图像预处理转换,将图像调整为 384x384
transform = transforms.Compose([transforms.Resize((384, 384)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)
print("Output shape:", output.shape)

在这个综合示例中,我们同时调整了输入图像尺寸和输出分类个数,并进行了图像预处理和推理操作。

通过以上方法,我们可以根据具体的任务需求灵活调整预训练模型的输入图像尺寸和输出分类个数。

更多模型说明

除了上述示例中的模型,timm 库还包含了许多其他的模型,如 Aggregating Nested Transformers、BEiT、Big Transfer ResNetV2 (BiT) 等。你可以在 timm 的官方文档 https://huggingface.co/docs/timm 中找到完整的模型列表。

要使用其他模型,只需将 timm.create_model 函数中的模型名称替换为你想要使用的模型名称即可。例如,要使用 BEiT 模型,可以使用以下代码:

import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 BEiT 模型
model = timm.create_model('beit_base_patch16_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"BEiT Predicted class index: {predicted_idx.item()}")

五、timm 库近期更新

5.1 2025 年 2 月 21 日更新

  • 新增 SigLIP 2 ViT 图像编码器:可从 https://huggingface.co/collections/timm/siglip-2-67b8e72ba08b09dd97aecaf9 获取。
  • 新增 ‘SO150M2’ ViT 权重:使用 SBB 配方训练,在 ImageNet 上取得了很好的效果。例如,vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k 的 top-1 准确率达到 88.1%。
  • 更新 InternViT - 300M ‘2.5’ 权重
  • 发布 1.0.15 版本

5.2 其他更新

在 2025 年 1 月至 2024 年 10 月期间,timm 库还进行了许多其他更新,包括添加新的优化器(如 Kron Optimizer、MARS 优化器等)、支持新的模型(如 convnext_nano、AIM - v2 编码器等)、修复一些 bug 以及改进代码结构等。

六、分布式训练支持

timm 库还提供了分布式训练的支持,相关代码在 timm/utils/distributed.py 中。以下是一些关键函数的介绍:

  • reduce_tensor:用于在分布式环境中对张量进行规约操作。
  • distribute_bn:确保每个节点具有相同的运行时 BN 统计信息。
  • init_distributed_device:初始化分布式训练设备。

以下是一个简单的分布式训练初始化示例:

import torch
from timm.utils.distributed import init_distributed_deviceargs = type('', (), {})()  # 创建一个空的参数对象
device = init_distributed_device(args)
print(f"Device: {device}, World size: {args.world_size}, Rank: {args.rank}")

七、学习率调度器

timm 库提供了多种学习率调度器,可在 timm/scheduler 目录下找到相关代码。以下是一些常见的调度器及其使用示例:

7.1 余弦退火调度器(CosineLRScheduler)

import torch
import timm
from timm.scheduler.scheduler_factory import create_scheduler# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 定义调度器参数
scheduler_args = type('', (), {'sched': 'cosine','epochs': 100,'decay_epochs': 30,'warmup_epochs': 5
})()# 创建调度器
scheduler, num_epochs = create_scheduler(scheduler_args, optimizer)# 训练循环
for epoch in range(num_epochs):# 训练代码...scheduler.step(epoch)

7.2 多步学习率调度器(MultiStepLRScheduler)

import torch
import timm
from timm.scheduler.scheduler_factory import create_scheduler# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 定义调度器参数
scheduler_args = type('', (), {'sched': 'multistep','epochs': 100,'decay_milestones': [30, 60],'decay_rate': 0.1,'warmup_epochs': 5
})()# 创建调度器
scheduler, num_epochs = create_scheduler(scheduler_args, optimizer)# 训练循环
for epoch in range(num_epochs):# 训练代码...scheduler.step(epoch)

八、总结

PyTorch Image Models (timm) 是一个非常实用的图像模型库,它提供了丰富的预训练模型和便捷的使用接口,同时支持分布式训练和多种学习率调度器。通过本文的介绍,你可以快速上手 timm 库,进行图像分类等任务的开发。希望本文对你有所帮助,祝你在深度学习领域取得更好的成果!

九、参考资料

  • timm 官方文档:https://huggingface.co/docs/timm
  • timm 代码库:https://github.com/rwightman/pytorch-image-models

相关文章:

  • SRS流媒体服务器(7)源码分析之拉流篇
  • 进程守护服务优点
  • 《解锁Claude4:开启AI交互新体验》
  • SRS流媒体服务器之RTC播放环境搭建
  • 蓝桥杯单片机答题技巧
  • log日志最佳实践
  • openssl 使用生成key pem
  • C#创建桌面快捷方式:使用 WSH 实现快捷方式生成
  • 机器学习-模型选择与调优
  • Python Day32 学习
  • LeetCode 每日一题 2025/5/19-2025/5/25
  • 每日算法刷题计划Day15 5.25:leetcode不定长滑动窗口求子数组个数越短越合法3道题,用时1h
  • python 实现从座位图中识别不同颜色和数量的座位并以JSON格式输出的功能
  • GO 语言基础3 struct 结构体
  • C++ 定义一个结构体,用class还是struct
  • day 36
  • 自定义 win10 命令
  • 人工智能数学基础实验(四):最大似然估计的-AI 模型训练与参数优化
  • 人工智能数学基础实验(五):牛顿优化法-电动汽车充电站选址优化
  • Pandas数据规整
  • 58同城做网站/网络营销策划书2000字
  • 06年可以做相册视频的网站/淘宝美工培训
  • 网站好坏标准/外贸快车
  • 做平面设计什么素材网站好使/百度一下官方网址
  • 乐清网站建设honmau/培训学校机构有哪些
  • 嘉兴云推广网站/google广告投放技巧