Visual Prompt Tuning (VPT)
Method Notes
I. Network Structure
Trainable parameters
prompt module
PromptedVisionTransformer(
  (transformer): PromptedTransformer(
    (embeddings): Embeddings((patch_embeddings): Conv2...affine=True))
    (prompt_dropout): Dropout(p=0.0, inplace=False)
    (prompt_proj): Identity()
  )
  (head): Identity()
)
head
MLP(
  (projection): Sequential()
  (last_layer): Linear(in_features=768, out_features=200, bias=True)
)
Overall structure
ViT(
  (enc): PromptedVisionTransformer((transformer): PromptedTransformer((embeddings): Embeddings((patc...)
  (head): MLP(
    (projection): Sequential()
    (last_layer): Linear(in_features=768, out_features=200, bias=True)
  )
)
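The forward flow implied by this printout: an image goes through the prompted ViT encoder (enc) to produce a 768-d CLS feature, which the MLP head maps to 200 class logits. Below is a minimal sketch with a hypothetical ViTWrapper class (the plain Linear layer stands in for the repo's MLP head):

import torch.nn as nn

class ViTWrapper(nn.Module):
    # Hypothetical wrapper mirroring the printed structure, not the repo's exact class.
    def __init__(self, enc, feat_dim=768, num_classes=200):
        super().__init__()
        self.enc = enc                                 # PromptedVisionTransformer
        self.head = nn.Linear(feat_dim, num_classes)   # stands in for the MLP head

    def forward(self, x):
        feat = self.enc(x)       # (batch_size, 768): CLS feature from the prompted ViT
        return self.head(feat)   # (batch_size, 200): class logits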
The parent class that PromptedTransformer inherits from:
class PromptedTransformer(Transformer):
The parent class contains the backbone network setup:
class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)
The ViT encoder layer (Block) used inside Encoder is defined as:
class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights
Here, embeddings is the ViT patch projection layer, and encoder is the stack of 12 encoder layers (Block).
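In VPT only the prompt parameters and the classification head are updated; the pretrained backbone stays frozen. A minimal sketch of how that split might be enforced (the substring matching on parameter names is an assumption, not the repo's exact logic):

import torch.nn as nn

def freeze_backbone(model: nn.Module):
    # Keep only prompt-related parameters and the head trainable;
    # everything else (the pretrained ViT backbone) is frozen.
    for name, param in model.named_parameters():
        param.requires_grad = ("prompt" in name) or name.startswith("head")
    # Return the names that remain trainable, e.g. prompt_embeddings,
    # deep_prompt_embeddings, head.last_layer.weight, head.last_layer.bias
    return [n for n, p in model.named_parameters() if p.requires_grad]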
II. Dataset Configuration
1. Converting the annotations to train.json
Only the CUB_200_2011 dataset is used; its annotation files need to be converted into JSON.
The conversion script is as follows:
import json
import os

def process_files(label_file, image_file, class_file, output_train, output_val):
    # File 1: train/test split (0 = test set, 1 = training set)
    with open(label_file, 'r') as f:
        labels = [int(line.split()[1]) for line in f.readlines()]
    # File 2: image paths
    with open(image_file, 'r') as f:
        image_paths = [line.split()[1] for line in f.readlines()]
    # File 3: class labels
    with open(class_file, 'r') as f:
        classes = [int(line.split()[1]) for line in f.readlines()]

    # Make sure the three files have the same number of entries
    assert len(labels) == len(image_paths) == len(classes), "The three files have different numbers of entries"

    # Build the training and test dictionaries
    train_data = {}
    val_data = {}
    for i in range(len(labels)):
        img_path = image_paths[i]
        class_id = classes[i]
        if labels[i] == 1:  # training set
            train_data[img_path] = class_id
        else:               # test set
            val_data[img_path] = class_id

    # Write the JSON files
    with open(output_train, 'w') as f:
        json.dump(train_data, f, indent=4)
    with open(output_val, 'w') as f:
        json.dump(val_data, f, indent=4)

    print(f"Training set saved to {output_train} with {len(train_data)} samples")
    print(f"Validation set saved to {output_val} with {len(val_data)} samples")

# File paths (replace with your actual paths)
label_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\train_test_split.txt'
image_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\images.txt'
class_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\image_class_labels.txt'
output_train = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\train.json'
output_val = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\val.json'

# Run the conversion
process_files(label_file, image_file, class_file, output_train, output_val)
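A quick sanity check of the generated file, reusing output_train from the script above (a sketch; the expected count assumes the standard CUB_200_2011 split of 5,994 training and 5,794 test images):

import json

with open(output_train, 'r') as f:
    train_data = json.load(f)

print(len(train_data))                     # standard CUB split: 5994 training images
rel_path, class_id = next(iter(train_data.items()))
print(rel_path, class_id)                  # relative image path -> class id (1..200)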
2. Inputs
inputs --- (images)
targets --- (class labels)
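A hypothetical minimal Dataset (not the repo's loader; names are illustrative) showing what inputs and targets amount to: an image tensor and an integer class label read from the JSON built above.

import json, os
from PIL import Image
from torch.utils.data import Dataset

class CUBJsonDataset(Dataset):
    # Hypothetical loader: maps the {relative_path: class_id} JSON entries
    # to (image_tensor, label) pairs.
    def __init__(self, json_path, image_root, transform=None):
        with open(json_path) as f:
            self.items = list(json.load(f).items())
        self.image_root = image_root
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        rel_path, class_id = self.items[idx]
        img = Image.open(os.path.join(self.image_root, rel_path)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, class_id - 1   # class ids in the txt files are 1-based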
III. Training Flow
1. Handling the prompt
embedding_output = self.incorporate_prompt(x)
def incorporate_prompt(self, x):
    # combine prompt embeddings with image-patch embeddings
    B = x.shape[0]
    # after CLS token, all before image patches
    x = self.embeddings(x)  # (batch_size, 1 + n_patches, hidden_dim)  (8, 197, 768)
    x = torch.cat((
        x[:, :1, :],
        self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
        x[:, 1:, :]
    ), dim=1)  # (8, 202, 768)
    # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
    return x
The class token used here is defined inside Embeddings:
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
To incorporate the prompt, the input first goes through patch embedding, and the result is then concatenated as [cls, prompt, patch_tokens].
The prompt here is a sequence of 5 tokens and is a trainable parameter.
self.prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim))  # (1, 5, 768), registered as a trainable parameter
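A toy shape check of the concatenation above (dummy tensors; a 224x224 input with 16x16 patches gives 196 patch tokens):

import torch

B, hidden = 8, 768
cls_tok = torch.zeros(B, 1, hidden)    # [CLS] token
prompts = torch.zeros(B, 5, hidden)    # 5 trainable prompt tokens
patches = torch.zeros(B, 196, hidden)  # 14 x 14 patch tokens

x = torch.cat((cls_tok, prompts, patches), dim=1)
print(x.shape)  # torch.Size([8, 202, 768])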
2. After concatenation, the sequence is fed through the ViT's stacked encoder layers for self-attention.
def forward(self, hidden_states):
    attn_weights = []
    for layer_block in self.layer:
        hidden_states, weights = layer_block(hidden_states)
        if self.vis:
            attn_weights.append(weights)
    encoded = self.encoder_norm(hidden_states)
    return encoded, attn_weights
This yields the encoder output. Inside PromptedVisionTransformer, the CLS token is then taken and passed through its head, which is an Identity():

x = x[:, 0]
logits = self.head(x)  # note: this head is Identity()

so from the outer ViT wrapper's point of view,

x = self.enc(x)  # batch_size x self.feat_dim  (8, 768)
3. The CLS token feature is fed into the head network to predict the class logits.
The output at this point:
x = self.head(x) # (8,200)
4. Compute the loss against the labels
loss = self.cls_criterion(outputs, targets, self.cls_weights)
where the loss is computed as:
loss = F.cross_entropy(logits, targets, weight, reduction="none")
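Because reduction="none" returns one loss value per sample, the trainer still applies the class weights and reduces to a scalar afterwards. A toy sketch of this step with random tensors and no class weights:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 200)            # head output: (batch_size, num_classes)
targets = torch.randint(0, 200, (8,))   # integer class labels
loss = F.cross_entropy(logits, targets, reduction="none")  # per-sample losses, shape (8,)
loss = loss.mean()                      # reduce to a scalar before backward()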
IV. The Deep Variant
Main difference
The prompt setup is different.
Initialization in the deep variant:
if self.prompt_config.DEEP:  # noqa
    total_d_layer = config.transformer["num_layers"] - 1
    self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
        total_d_layer, num_tokens, prompt_dim))
    # xavier_uniform initialization
    nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
That is, in addition to the shallow (input-layer) prompt, prompts for the remaining 11 layers are also initialized. The deep forward pass then swaps in the prompt for each layer:
def forward_deep_prompt(self, embedding_output):
    attn_weights = []
    hidden_states = None
    weights = None
    B = embedding_output.shape[0]
    num_layers = self.vit_config.transformer["num_layers"]  # 12

    for i in range(num_layers):
        if i == 0:
            hidden_states, weights = self.encoder.layer[i](embedding_output)
        else:
            if i <= self.deep_prompt_embeddings.shape[0]:
                deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                    self.deep_prompt_embeddings[i-1]).expand(B, -1, -1))
                hidden_states = torch.cat((
                    hidden_states[:, :1, :],
                    deep_prompt_emb,
                    hidden_states[:, (1+self.num_tokens):, :]
                ), dim=1)
            hidden_states, weights = self.encoder.layer[i](hidden_states)

        if self.encoder.vis:
            attn_weights.append(weights)

    encoded = self.encoder.encoder_norm(hidden_states)
    return encoded, attn_weights
In other words, the first layer is handled exactly as in the shallow variant, while each subsequent layer first replaces the prompt positions with that layer's own deep prompt embeddings before the encoder block is applied.
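With the shapes noted above (5 prompt tokens, hidden dimension 768, 12 layers), the prompt parameter count works out as follows (a rough back-of-the-envelope sketch):

num_tokens, prompt_dim, num_layers = 5, 768, 12
shallow_params = 1 * num_tokens * prompt_dim                  # 3,840 (input-layer prompt)
deep_extra = (num_layers - 1) * num_tokens * prompt_dim       # 42,240 (prompts for layers 2-12)
print(shallow_params, shallow_params + deep_extra)            # 3840 46080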