当前位置: 首页 > news >正文

ATPrompt:基于属性的视觉提示

项目论文:https://arxiv.org/abs/2412.09442

项目代码:GitHub - zhengli97/ATPrompt: [ICCV 2025] Official PyTorch Code for "Advancing Textual Prompt Learning with Anchored Attributes"


一、背景

提示学习:已有的文本模版提示具有两个问题:(1) 传统的固定的文本提示往往不是最优,(2) 针对性设计的文本模板费时费力,且不同数据集之间无法泛化通用。CoOp首先提出了将多个可学习词元(learnable soft token)与类别词元(class token)级联的形式,以此让模型自己学出适合的文本提示。

本文讨论了提示学习现有的缺点:例如CoOp引入软文本标记和硬类别标记相结合作为输入,但是这种形式将软提示限制在一维的、预定义的类别空间内与图像对齐,从而限制了它们在未知类别上的适用性。因此,基于当前文本形式进行训练更有可能过拟合已知类别,降低了它们对未知类别的零样本泛化能力。

于是,ATPrompt提出利用属性作为桥梁来增强图像与未知类别的对齐。为VLM引入基于属性锚定的文本提示方法。通过将多个固定的通用属性标记整合到可学习的软提示中,将软提示的学习空间从原来的一类别层面扩展到多维属性层面。软标记在训练过程中不仅能获得特定于类别的表示,还能获得与属性相关的通用表示。

该方法有两个创新点:(1)属性搜索:引入可微分的属性搜索,旨在从搜索空间V找到具有代表性的属性V。为了使搜索空间连续,将离散的属性选择放宽为对所有可能的候选属性进行softmax加权求和。于是属性搜索变成为候选池学习权重向量α。(2)联合学习属性权重α和软提示标记θ,通过最小化验证损失Lval和最小化训练损失Ltrain来学习,采用交叉优化算法解决这个双层优化问题,其中两个损失函数均使用交叉熵损失函数,搜索之后,选择权重最高的属性组合。

训练过程如下


二、实际代码实战

首先是按照github给出的要求:创建环境并安装依赖项。

创建conda环境:

# Create a conda environment
conda create -y -n atprompt python=3.8# Activate the environment
conda activate atprompt# Install torch (requires version >= 1.8.1) and torchvision
# Please refer to https://pytorch.org/ if you need a different cuda version
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

克隆 ATPrompt 代码存储库和安装要求

git clone https://github.com/zhengli97/ATPrompt.gitcd ATPrompt/
# Install requirementspip install -r requirements.txtcd ..

安装 dassl 库:

cd Dassl.pytorch/# Install dependencies
pip install -r requirements.txt# Install this library (no need to re-build if the source code is modified)
python setup.py develop

接下来下载具体的数据集,按照官方给出的数据集格式进行准备:不需要全部下载,我这里只对擦了caltech101和stanfordcas。

$DATA/
|–– imagenet/
|–– caltech-101/
|–– oxford_pets/
|–– stanford_cars/

接下来可以下载预训练权重到本地(可选):

需要在/trainers/coop.py下修改到本地路径。

接下来可以直接进行训练,按照论文的说法,训练主要按照base基类进行训练,然后对new新类进行zero-shot测试和泛化测试。

第一步要修改数据集的路径:scripts/coop/base2new_train.sh的第四行

然后使用ATPrompt进行训练,比如对caltech101进行训练:

# CoOp+ATPrompt, dataset=caltech101
sh scripts/coop/atp_base2new_train.sh caltech101

训练会进行5轮,选取最优种子。

如果想进行对比实验,不使用ATPrompt,可以使用:

# Vanilla CoOp
sh scripts/coop/vanilla_base2new_train.sh imagenet

对于泛化训练过程也是类似:


复现实验结果,使用Caltech101和Stanfordcars数据集,使用和不使用 ATPrompt 的情况下,从base到new的泛化实验。ATPrompt对baseline的Coop方法的性能有所提高。

以上为全部内容!

http://www.dtcms.com/a/560707.html

相关文章:

  • 手机如何制作网站教程网站双线选择
  • upload文件上传漏洞浅析
  • GitHub 热榜项目 - 日榜(2025-11-02)
  • 网站稿件管理发布系统中山网站建设半江红
  • 【Qt开发】布局管理器(二)-> QHBoxLayout水平布局
  • Linux 6.17:最新的驱动程序、快速的网络和可靠的内存
  • 【Ubuntu】虚拟机 Ubuntu 挂载 宿主机 Windows文件夹
  • 将go-tcp项目部署到docker容器下运行
  • 华为OD机试双机位A卷 - 几何平均值最大的子数组 (C++ Python JAVA JS GO)
  • PostgreSQL死锁排查攻略:从日志分析到实时监控
  • 佛山响应式网站开发二级域名免费分发站
  • 【软考架构】案例分析-Web应用设计(应用服务器概念)
  • C++中的过滤器模式:原理、实现与应用
  • Kanass实践指南(4) - 测试团队如何通过kanass管理跟踪用例与缺陷
  • 天河做网站技术松江做网站费用
  • 面试Redis篇—————缓存穿透问题及解决策略
  • 【ComfyUI】通用 文生图转视频
  • 怎样建网站?西湖区住房和城市建设局网站
  • 教做宝宝衣服的网站济南网站优化多少钱
  • 分布式文件存储服务设计与实现优化
  • Qt-Nice-Frameless-Window: 一个跨平台无边框窗口(Frameless Window)解决方案
  • 跨平台游戏引擎 Axmol-2.9.1 发布
  • Redis性能优化避坑指南
  • 【Cache缓存】两路组相连和全相连
  • 青岛门头设计制作长春百度关键词优化
  • 青海网站制作的公司天津市网站建设公司
  • 数据结构04:链表的概念及实现单链表
  • springCloud二-SkyWalking3-性能剖析-⽇志上传-告警管理-接入飞书
  • 【项目基础】vue-class-component、vue-property-decorator、vuex-class、GeoJson
  • JWT 是由哪三个部分组成?如何使用JWT进行身份认证?