当前位置: 首页 > news >正文

Vision Prompt Tune(视觉提示微调)

方法记录

一、网络结构

可训练的参数

prompt module 

PromptedVisionTransformer((transformer): PromptedTransformer((embeddings): Embeddings((patch_embeddings): Conv2...affine=True))(prompt_dropout): Dropout(p=0.0, inplace=False)(prompt_proj): Identity())(head): Identity()
)

head 

MLP((projection): Sequential()(last_layer): Linear(in_features=768, out_features=200, bias=True)
)

总的结构 

ViT((enc): PromptedVisionTransformer((transformer): PromptedTransformer((embeddings): Embeddings((patc...)(head): MLP((projection): Sequential()(last_layer): Linear(in_features=768, out_features=200, bias=True))
)

 这里 继承的父类

class PromptedTransformer(Transformer):

而父类里面包含了骨干网络的设置

class Transformer(nn.Module):def __init__(self, config, img_size, vis):super(Transformer, self).__init__()self.embeddings = Embeddings(config, img_size=img_size)self.encoder = Encoder(config, vis)

Encoder中的ViT的Encoder Layer层的定义

class Block(nn.Module):def __init__(self, config, vis):super(Block, self).__init__()self.hidden_size = config.hidden_sizeself.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)self.ffn = Mlp(config)self.attn = Attention(config, vis)def forward(self, x):h = xx = self.attention_norm(x)x, weights = self.attn(x)x = x + hh = xx = self.ffn_norm(x)x = self.ffn(x)x = x + hreturn x, weights

其中的embeddings就是 ViT的 patch projection layer ,而encoder 就是 12个 Encoder Layer层

二、数据集配置

1、train.json 文件的转换

只使用了CUB_200_2011数据集,需要进行转换为json文件。

转换脚本如下:

import json
import osdef process_files(label_file, image_file, class_file, output_train, output_val):# 读取文档1:训练/测试划分 (0表示测试集,1表示训练集)with open(label_file, 'r') as f:labels = [int(line.split()[1]) for line in f.readlines()]# 读取文档2:图片路径with open(image_file, 'r') as f:image_paths = [line.split()[1] for line in f.readlines()]# 读取文档3:类别标签with open(class_file, 'r') as f:classes = [int(line.split()[1]) for line in f.readlines()]# 确保三个文件的长度一致assert len(labels) == len(image_paths) == len(classes), "三个文件的条目数不一致"# 创建训练集和测试集的字典train_data = {}val_data = {}for i in range(len(labels)):img_path = image_paths[i]class_id = classes[i]if labels[i] == 1:  # 训练集train_data[img_path] = class_idelse:  # 测试集val_data[img_path] = class_id# 写入JSON文件with open(output_train, 'w') as f:json.dump(train_data, f, indent=4)with open(output_val, 'w') as f:json.dump(val_data, f, indent=4)print(f"训练集已保存到 {output_train},包含 {len(train_data)} 个样本")print(f"验证集已保存到 {output_val},包含 {len(val_data)} 个样本")# 文件路径
label_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\train_test_split.txt'   # 替换为实际文件路径
image_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\images.txt'   # 替换为实际文件路径
class_file = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\image_class_labels.txt' # 替换为实际文件路径
output_train = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\train.json'
output_val = r'E:\ZJX\Datasets\CUB_200_2011\CUB_200_2011\val.json'# 处理文件
process_files(label_file, image_file, class_file, output_train, output_val)

2、输入

inputs --- (图片)

 targets --- (类别标签)

三、训练流程 

1、prompt的处理

embedding_output = self.incorporate_prompt(x)
    def incorporate_prompt(self, x):# combine prompt embeddings with image-patch embeddingsB = x.shape[0]# after CLS token, all before image patchesx = self.embeddings(x)  # (batch_size, 1 + n_patches, hidden_dim)  (8,197,768)x = torch.cat((x[:, :1, :],self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),x[:, 1:, :]), dim=1)  # (8,202,768)# (batch_size, cls_token + n_prompt + n_patches, hidden_dim)return x

其中的类别token定义在 Embedding内部

self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

整合提示,首先对输入 进行patch Embedding,然后在拼接 [cls, prompt, x_token] 。

这里的 prompt 是长度为 5 的token,并且是可训练的参数。

self.prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim))  # (1,5,768)  初始化为参数

2、拼接后送入ViT的连续Encoder层进行注意力运算。

 def forward(self, hidden_states):attn_weights = []for layer_block in self.layer:hidden_states, weights = layer_block(hidden_states)if self.vis:attn_weights.append(weights)encoded = self.encoder_norm(hidden_states)return encoded, attn_weights

得到

 

x = x[:, 0]logits = self.head(x)  # 注意,这里的head是 Identity()
x = self.enc(x)  # batch_size x self.feat_dim  (8,768)

3、取 cls token温度送入head网络,预测类别token

此时的输出

x = self.head(x)  # (8,200)

4、与标签计算损失

loss = self.cls_criterion(outputs, targets, self.cls_weights)

其中的loss

loss = F.cross_entropy(logits, targets, weight, reduction="none")

四、deep的变体

主要区别

prompt的设置不同

deep的初始化

            if self.prompt_config.DEEP:  # noqa Falsetotal_d_layer = config.transformer["num_layers"]-1self.deep_prompt_embeddings = nn.Parameter(torch.zeros(total_d_layer, num_tokens, prompt_dim))# xavier_uniform initializationnn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

其中

即既包括了shallow的还包括了其他11层的初始化 

    def forward_deep_prompt(self, embedding_output):attn_weights = []hidden_states = Noneweights = NoneB = embedding_output.shape[0]num_layers = self.vit_config.transformer["num_layers"]  # 12for i in range(num_layers):if i == 0:hidden_states, weights = self.encoder.layer[i](embedding_output)else:if i <= self.deep_prompt_embeddings.shape[0]:deep_prompt_emb = self.prompt_dropout(self.prompt_proj(self.deep_prompt_embeddings[i-1]).expand(B, -1, -1))hidden_states = torch.cat((hidden_states[:, :1, :],deep_prompt_emb,hidden_states[:, (1+self.num_tokens):, :]), dim=1)hidden_states, weights = self.encoder.layer[i](hidden_states)if self.encoder.vis:attn_weights.append(weights)encoded = self.encoder.encoder_norm(hidden_states)return encoded, attn_weights

即第一层和shallow 一样的处理,不过之后的每层都会替换对应层设置的prompt 提示。

相关文章:

  • 在ARM 架构的 Mac 上 更新Navicat到17后连接Oracle时报错:未加载 Oracle 库。
  • Windows 系统中修改文件默认打开方式
  • Java多线程实现之线程池详解
  • 机器人教学和实践的可编程智能仿生机器人平台——智能六足机器人
  • 临时抱佛脚v2
  • Vue Electron 使用来给若依系统打包成exe程序,出现登录成功但是不跳转页面(已解决)
  • OpenSSL 无法验证 DevSidecar 的自签名证书
  • 目标检测yolo算法
  • Windows 上安装 devsidecar 后,使用 WSL ubuntu ssl 报错
  • 机器视觉开发-边缘提取
  • Java-43 深入浅出 Nginx - 基本配置方式 nginx.conf Events块 HTTP块 反向代理 负载均衡
  • 永磁同步电机无速度算法--基于稳态卡尔曼滤波器SSEKF的滑模观测器
  • 实战使用docker compose 搭建 Redis 主从复制集群
  • 【docker】docker registry搭建私有镜像仓库
  • Linux 杀进程指令详解:`kill -9 PID` 和 `kill -15 PID` 有什么区别?
  • 云计算迁移策略:分步框架与优势
  • 开源生态新势能: 驱动国产 DevSecOps 与 AI 工程新进展
  • Vim鼠标右键复制问题解决方法
  • 自定义鼠标效果 - 浏览器扩展使用教程
  • (新手友好)MySQL学习笔记(8):存储过程,自定义函数,游标
  • 线上推广工作是做什么的/seo站长工具平台
  • wordpress cgi漏洞/百度seo还有前景吗
  • 河南移动商城网站建设/自动连点器
  • 网站模糊背景/新闻发布稿
  • h5响应式集团网站推荐/百度官方电话人工服务电话
  • 网站建设软件下载/seo推广思路