timm Tutorial Translation: (1) Overview
https://timm.fast.ai/
1. How to use
1.1 Create a model
import timm
import torch
model = timm.create_model('resnet34')
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 1000])
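Creating a model with timm is that simple. The create_model function is a factory method that can be used to create any of the more than 300 models in the timm library.
To create a pretrained model, just pass pretrained=True.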
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/tmabraham/.cache/torch/hub/checkpoints/resnet34-43635321.pth
To create a model with a custom number of classes, just pass num_classes=<number_of_classes>.
import timm
import torch
model = timm.create_model('resnet34', num_classes=10)
x = torch.randn(1, 3, 224, 224)
model(x).shape
torch.Size([1, 10])
1.2 List Models with Pretrained Weights
timm.list_models() returns the complete list of models available in timm. To see the full list of pretrained models, pass pretrained=True to list_models.
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]
(592,['adv_inception_v3','bat_resnext26ts','beit_base_patch16_224','beit_base_patch16_224_in22k','beit_base_patch16_384'])
In the output above, a total of 592 models with pretrained weights are available in timm. The exact count depends on the installed timm version.
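Because the result is printed above as an ordinary Python list of strings, standard list tools can be used to explore it. The sketch below uses only the five names shown in the output above as a stand-in for the full list, so it runs without timm. Grouping by the text before the first underscore is an illustrative heuristic, not a timm convention.

```python
from collections import Counter

# Stand-in for timm.list_models(pretrained=True): the first five names from the output above.
sample_names = [
    'adv_inception_v3',
    'bat_resnext26ts',
    'beit_base_patch16_224',
    'beit_base_patch16_224_in22k',
    'beit_base_patch16_384',
]

# Keep only names that start with a given prefix.
beit_names = [name for name in sample_names if name.startswith('beit')]

# Count names per "family", taken here as the text before the first underscore (a heuristic).
family_counts = Counter(name.split('_')[0] for name in sample_names)

print(beit_names)
print(family_counts)
```

With the real list, sample_names would simply be replaced by the value returned from timm.list_models(pretrained=True).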
1.3 Search for model architectures by Wildcard
It is also possible to search for model architectures with a wildcard pattern, as shown below:
all_densenet_models = timm.list_models('*densenet*')
all_densenet_models
['densenet121','densenet121d','densenet161','densenet169','densenet201','densenet264','densenet264d_iabn','densenetblur121d','tv_densenet121']
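To see why '*densenet*' also returns tv_densenet121, it helps to look at how shell-style wildcards match. The sketch below uses Python's standard fnmatch module on a small hand-written list, under the assumption that timm's filter follows the same shell-style rules (the source does not say how the matching is implemented). The helper name match_names and the sample list are hypothetical.

```python
from fnmatch import fnmatchcase

# Hypothetical sample of model names, taken from outputs shown in this tutorial.
sample_names = [
    'densenet121',
    'tv_densenet121',
    'densenetblur121d',
    'resnet34',
    'vit_tiny_patch16_224',
]

def match_names(pattern, names):
    """Return the names that match a shell-style wildcard pattern (case-sensitive)."""
    return [name for name in names if fnmatchcase(name, pattern)]

# '*' matches any run of characters, including none, so a leading '*' allows prefixes.
anywhere = match_names('*densenet*', sample_names)
# Without the leading '*', the name must start with 'densenet'.
prefix_only = match_names('densenet*', sample_names)

print(anywhere)     # ['densenet121', 'tv_densenet121', 'densenetblur121d']
print(prefix_only)  # ['densenet121', 'densenetblur121d']
```

Dropping the leading '*' is a simple way to exclude variants that carry a prefix such as tv_.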
1.4 Fine-tune timm model in fastai
The fastai library has support for fine-tuning models from timm:
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)
learn.fine_tune(1)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.201583 | 0.024980 | 0.006766 | 00:08 |
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.040622 | 0.024036 | 0.005413 | 00:10 |