PyTorch Image Models (timm) 技术指南
timm
- PyTorch Image Models (timm) 技术指南
- 功能概述
- 一、引言
- 二、timm 库概述
- 三、安装 timm 库
- 四、模型加载与推理示例
- 4.1 通用推理流程
- 4.2 具体模型示例
- 4.2.1 ResNeXt50-32x4d
- 4.2.2 EfficientNet-V2 Small 模型
- 4.2.3 DeiT-3 large 模型
- 4.2.4 RepViT-M2 模型
- 4.2.5 ResNet-RS-101
- 4.2.6 Vision Transformer (ViT)
- 4.2.7 Swin Transformer
- 4.2.8 Swin Transformer V2
- 4.2.9 Swin Transformer V2 Cr
- 4.2.10 Levit
- 4.3 加载自定义模型
- 4.4 提取模型的中间特征
- 4.5 冻结模型的部分层
- 4.6 创建模型时指定输入图像尺寸
- 4.7 数据预处理阶段调整图像尺寸
- 4.8 调整输出分类个数
- 4.9 综合示例
- 更多模型说明
- 五、timm 库近期更新
- 5.1 2025 年 2 月 21 日更新
- 5.2 其他更新
- 六、分布式训练支持
- 七、学习率调度器
- 7.1 余弦退火调度器(CosineLRScheduler)
- 7.2 多步学习率调度器(MultiStepLRScheduler)
- 八、总结
- 九、参考资料
PyTorch Image Models (timm) 技术指南
timm
(PyTorch Image Models)是一个广泛使用的 PyTorch 库,它集合了大量的图像模型、层、实用工具、优化器、调度器、数据加载器/增强器以及参考训练/验证脚本。以下是对 timm
库的详细介绍,包括功能、模型案例、加载与使用示例以及相关教程的信息。
功能概述
- 丰富的图像模型:包含众多预训练的图像分类、目标检测、语义分割等模型,如 ResNet、EfficientNet、ViT 等。
- 实用工具:提供了一系列用于模型训练、验证和推理的实用工具,如优化器、调度器、数据加载器和增强器等。
- 模型构建与管理:支持轻松构建和管理不同类型的模型,包括模型的初始化、权重加载和保存等。
- 分布式训练:支持分布式训练,方便在多个 GPU 或节点上进行高效训练。
一、引言
在深度学习领域,图像分类、目标检测等任务常常需要使用预训练的图像模型。PyTorch Image Models (timm)
是一个功能强大的库,它提供了大量预训练的图像模型,涵盖了各种架构,方便开发者快速搭建和训练自己的模型。本文将详细介绍 timm
库的使用,包括模型加载、推理以及近期更新的模型和功能。
二、timm 库概述
timm
是一个基于 PyTorch 的图像模型库,它收集了众多先进的图像模型,如 ResNet、ViT、Swin Transformer 等,并提供了预训练的权重。通过 timm
,开发者可以轻松地加载这些模型,进行图像分类、特征提取等任务。
三、安装 timm 库
在使用 timm
之前,需要先安装该库。可以使用以下命令进行安装:
pip install timm
四、模型加载与推理示例
4.1 通用推理流程
以下是一个通用的使用 timm
加载模型并进行推理的示例代码:
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练模型
model = timm.create_model('model_name', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Predicted class index: {predicted_idx.item()}")
在上述代码中,'model_name'
需要替换为具体的模型名称,'path_to_your_image.jpg'
需要替换为实际的图像文件路径。
4.2 具体模型示例
4.2.1 ResNeXt50-32x4d
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 ResNeXt50-32x4d 模型
model = timm.create_model('resnext50_32x4d', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"ResNeXt50-32x4d Predicted class index: {predicted_idx.item()}")
4.2.2 EfficientNet-V2 Small 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 EfficientNet-V2 Small 模型
model = timm.create_model('efficientnetv2_s', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"EfficientNet-V2 Small Predicted class index: {predicted_idx.item()}")
4.2.3 DeiT-3 large 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 DeiT-3 large 模型
model = timm.create_model('deit3_large_patch16_384', pretrained=True)
model.eval()# 定义图像预处理转换,注意输入尺寸为 384x384
transform = transforms.Compose([transforms.Resize(384),transforms.CenterCrop(384),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"DeiT-3 large Predicted class index: {predicted_idx.item()}")
4.2.4 RepViT-M2 模型
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 RepViT-M2 模型
model = timm.create_model('repvit_m2', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"RepViT-M2 Predicted class index: {predicted_idx.item()}")
4.2.5 ResNet-RS-101
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 ResNet-RS-101 模型
model = timm.create_model('resnetrs101', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"ResNet-RS-101 Predicted class index: {predicted_idx.item()}")
4.2.6 Vision Transformer (ViT)
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 ViT-Base/32 模型
model = timm.create_model('vit_base_patch32_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"ViT-Base/32 Predicted class index: {predicted_idx.item()}")
4.2.7 Swin Transformer
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Swin Transformer 模型
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer Predicted class index: {predicted_idx.item()}")
4.2.8 Swin Transformer V2
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Swin Transformer V2 模型
model = timm.create_model('swinv2_base_window12_192_22k', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer V2 Predicted class index: {predicted_idx.item()}")
4.2.9 Swin Transformer V2 Cr
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Swin Transformer V2 Cr 模型
model = timm.create_model('swinv2_cr_base_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Swin Transformer V2 Cr Predicted class index: {predicted_idx.item()}")
4.2.10 Levit
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 Levit 模型
model = timm.create_model('levit_256', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"Levit Predicted class index: {predicted_idx.item()}")
4.3 加载自定义模型
如果你需要加载自定义的模型,可以使用 timm.create_model
函数,并指定模型的名称和相关参数:
import timm# 创建自定义的 EfficientNet 模型
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
print(model)
4.4 提取模型的中间特征
import torch
import timm# 加载预训练的模型
model = timm.create_model('resnet18', pretrained=True, features_only=True)# 生成随机输入
x = torch.randn(1, 3, 224, 224)# 提取中间特征
features = model(x)
for i, feat in enumerate(features):print(f"Feature {i} shape: {feat.shape}")
4.5 冻结模型的部分层
import torch
import timm
from timm.utils.model import freeze# 加载预训练的模型
model = timm.create_model('resnet18', pretrained=True)# 冻结模型的前几层
submodules = [n for n, _ in model.named_children()]
freeze(model, submodules[:submodules.index('layer2') + 1])# 检查冻结情况
print(model.layer2[0].conv1.weight.requires_grad) # 输出: False
print(model.layer3[0].conv1.weight.requires_grad) # 输出: True
在使用 timm
库加载预训练模型后,我们经常需要根据具体的任务需求调整模型的参数,例如输入图像尺寸、输出分类个数等。下面将结合提供的代码片段详细介绍如何进行这些参数的调整。
4.6 创建模型时指定输入图像尺寸
部分模型在创建时可以通过 img_size
参数指定输入图像的尺寸。以下是一个使用 SwinTransformer
模型的示例:
import timm
import torch# 加载预训练的 SwinTransformer 模型,并指定输入图像尺寸为 384x384
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=(384, 384))
model.eval()# 随机生成一个符合指定尺寸的输入张量进行测试
input_tensor = torch.randn(1, 3, 384, 384)
with torch.no_grad():output = model(input_tensor)
print("Output shape:", output.shape)
在上述代码中,我们通过 img_size=(384, 384)
指定了输入图像的尺寸为 384x384。
4.7 数据预处理阶段调整图像尺寸
除了在创建模型时指定输入图像尺寸,还需要在数据预处理阶段将输入图像调整为指定的尺寸。可以使用 torchvision.transforms
来实现这一点,示例如下:
from torchvision import transforms
from PIL import Image# 定义图像预处理转换,将图像调整为 384x384
transform = transforms.Compose([transforms.Resize((384, 384)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)
print("Output shape:", output.shape)
在这个示例中,我们使用 transforms.Resize((384, 384))
将输入图像调整为 384x384 的尺寸。
4.8 调整输出分类个数
输出分类个数的调整通常在创建模型时通过 num_classes
参数来实现。以下是一个使用 MetaFormer
模型的示例:
import timm# 加载预训练的 MetaFormer 模型,并指定输出分类个数为 10
model = timm.create_model('metaformer', pretrained=True, num_classes=10)
model.eval()# 随机生成一个输入张量进行测试
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():output = model(input_tensor)
print("Output shape:", output.shape)
在上述代码中,我们通过 num_classes=10
指定了模型的输出分类个数为 10。
4.9 综合示例
下面是一个综合示例,展示了如何同时调整输入图像尺寸和输出分类个数:
import timm
import torch
from torchvision import transforms
from PIL import Image# 加载预训练的 SwinTransformer 模型,调整输入图像尺寸为 384x384,输出分类个数为 10
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=(384, 384), num_classes=10)
model.eval()# 定义图像预处理转换,将图像调整为 384x384
transform = transforms.Compose([transforms.Resize((384, 384)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)
print("Output shape:", output.shape)
在这个综合示例中,我们同时调整了输入图像尺寸和输出分类个数,并进行了图像预处理和推理操作。
通过以上方法,我们可以根据具体的任务需求灵活调整预训练模型的输入图像尺寸和输出分类个数。
更多模型说明
除了上述示例中的模型,timm
库还包含了许多其他的模型,如 Aggregating Nested Transformers、BEiT、Big Transfer ResNetV2 (BiT) 等。你可以在 timm
的官方文档 https://huggingface.co/docs/timm 中找到完整的模型列表。
要使用其他模型,只需将 timm.create_model
函数中的模型名称替换为你想要使用的模型名称即可。例如,要使用 BEiT 模型,可以使用以下代码:
import torch
import timm
from PIL import Image
from torchvision import transforms# 加载预训练的 BEiT 模型
model = timm.create_model('beit_base_patch16_224', pretrained=True)
model.eval()# 定义图像预处理转换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载图像
image = Image.open('path_to_your_image.jpg')
image = transform(image).unsqueeze(0)# 进行推理
with torch.no_grad():output = model(image)# 获取预测结果
_, predicted_idx = torch.max(output, 1)
print(f"BEiT Predicted class index: {predicted_idx.item()}")
五、timm 库近期更新
5.1 2025 年 2 月 21 日更新
- 新增 SigLIP 2 ViT 图像编码器:可从 https://huggingface.co/collections/timm/siglip-2-67b8e72ba08b09dd97aecaf9 获取。
- 新增 ‘SO150M2’ ViT 权重:使用 SBB 配方训练,在 ImageNet 上取得了很好的效果。例如,
vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k
的 top-1 准确率达到 88.1%。 - 更新 InternViT - 300M ‘2.5’ 权重。
- 发布 1.0.15 版本。
5.2 其他更新
在 2025 年 1 月至 2024 年 10 月期间,timm
库还进行了许多其他更新,包括添加新的优化器(如 Kron Optimizer、MARS 优化器等)、支持新的模型(如 convnext_nano
、AIM - v2 编码器等)、修复一些 bug 以及改进代码结构等。
六、分布式训练支持
timm
库还提供了分布式训练的支持,相关代码在 timm/utils/distributed.py
中。以下是一些关键函数的介绍:
reduce_tensor
:用于在分布式环境中对张量进行规约操作。distribute_bn
:确保每个节点具有相同的运行时 BN 统计信息。init_distributed_device
:初始化分布式训练设备。
以下是一个简单的分布式训练初始化示例:
import torch
from timm.utils.distributed import init_distributed_deviceargs = type('', (), {})() # 创建一个空的参数对象
device = init_distributed_device(args)
print(f"Device: {device}, World size: {args.world_size}, Rank: {args.rank}")
七、学习率调度器
timm
库提供了多种学习率调度器,可在 timm/scheduler
目录下找到相关代码。以下是一些常见的调度器及其使用示例:
7.1 余弦退火调度器(CosineLRScheduler)
import torch
import timm
from timm.scheduler.scheduler_factory import create_scheduler# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 定义调度器参数
scheduler_args = type('', (), {'sched': 'cosine','epochs': 100,'decay_epochs': 30,'warmup_epochs': 5
})()# 创建调度器
scheduler, num_epochs = create_scheduler(scheduler_args, optimizer)# 训练循环
for epoch in range(num_epochs):# 训练代码...scheduler.step(epoch)
7.2 多步学习率调度器(MultiStepLRScheduler)
import torch
import timm
from timm.scheduler.scheduler_factory import create_scheduler# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 定义调度器参数
scheduler_args = type('', (), {'sched': 'multistep','epochs': 100,'decay_milestones': [30, 60],'decay_rate': 0.1,'warmup_epochs': 5
})()# 创建调度器
scheduler, num_epochs = create_scheduler(scheduler_args, optimizer)# 训练循环
for epoch in range(num_epochs):# 训练代码...scheduler.step(epoch)
八、总结
PyTorch Image Models (timm)
是一个非常实用的图像模型库,它提供了丰富的预训练模型和便捷的使用接口,同时支持分布式训练和多种学习率调度器。通过本文的介绍,你可以快速上手 timm
库,进行图像分类等任务的开发。希望本文对你有所帮助,祝你在深度学习领域取得更好的成果!
九、参考资料
timm
官方文档:https://huggingface.co/docs/timmtimm
代码库:https://github.com/rwightman/pytorch-image-models