当前位置: 首页 > news >正文

CLIP模型使用方法

本文主要记录CLIP模型的原理,安装,基本使用。

1.CLIP模型原理

CLIP模型是文本图像匹配模型。使用4亿个图文对进行对比训练。

首先,每个batch的图文对分别进行图像编码和文本编码。

图像编码使用ViT,输出[batch_size,picture_feature]

文本编码使用Transformer,输出[batch_size,text_feature]

对比训练:在batch中计算所有图像与文本的余弦相似度,每个图文对的相似度最高,处在对角线上,其余相似度低。这样就可以进行训练得到CLIP模型。

2.CLIP安装

参考:安装文本-图像对比学习模型CLIP的方法

3.基本使用

3.1 单张图片分类

输入一张图片,设置几个类别,输出softmax概率分布。

def classify_image_with_clip():
    '''
    使用CLIP模型推理图像的文本描述
    可以调整的部分:
    1.选择文本描述,这里为"a diagram", "a dog", "a cat"
    2.加载CLIP模型,这里为ViT-B/32
    3.选择输入图像,这里为CLIP.png
    '''
    device = "cuda" if torch.cuda.is_available() else "cpu" #选择GPU或CPU
    model, preprocess = clip.load("ViT-B/32", device=device) #加载视觉模型ViT-B/32

    image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) #对输入图像进行预处理,增加一个batch维度并转移到GPU或CPU
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) #选择文本描述,并转移到GPU或CPU

    with torch.no_grad(): #关闭梯度计算
        logits_per_image, logits_per_text = model(image, text)  #计算余弦相似度
        probs = logits_per_image.softmax(dim=-1).cpu().numpy() #softmax归一化后取出概率值

    print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]] #打印概率值

3.2使用CLIP推理CIFAR100图像的类别

选择CIFAR100的任一张图,输出类别。

def classify_image_with_clip_cifar100():
    '''
    使用CLIP模型推理CIFAR100图像的文本描述
    可以调整的部分:
    1.选择某一张CIFAR100图像,这里为第3637张图像
    2.加载CLIP模型,这里为ViT-B/32
    3.扩展CIFAR100的文本描述,这里为"a photo of a {class}"
    '''
    device = "cuda" if torch.cuda.is_available() else "cpu" #选择GPU或CPU
    model, preprocess = clip.load('ViT-B/32', device) #加载视觉模型ViT-B/32

    cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False) #下载CIFAR100数据集,该数据集包含100个类别的图像

    image, class_id = cifar100[3637] #选择第3637张图像及其类别
    image_input = preprocess(image).unsqueeze(0).to(device) #对输入图像进行预处理,增加一个batch维度并转移到GPU或CPU
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) #扩展cifar100的文本描述

    with torch.no_grad(): #关闭梯度计算
        image_features = model.encode_image(image_input) #计算图像特征
        text_features = model.encode_text(text_inputs) #计算文本特征

    image_features /= image_features.norm(dim=-1, keepdim=True) #归一化图像特征
    text_features /= text_features.norm(dim=-1, keepdim=True) #归一化文本特征
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) #计算余弦相似度
    values, indices = similarity[0].topk(5) #取出前5个最相似的文本描述

    #打印预测结果
    print("\nTop predictions:\n")
    for value, index in zip(values, indices):
        print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

3.3使用CLIP训练Logistic回归模型,分类CIFAR100

使用CLIP训练CIFAR100分类器

def train_classifier_with_clip_cifar100():
    '''
    使用CLIP模型训练CIFAR100图像分类器
    可以调整的部分:
    1.选择训练集和测试集的比例,这里为8:2
    2.选择逻辑回归分类器,这里为LogisticRegression
    3.设置分类器超参数,这里为C=0.316, max_iter=1000
    4.加载模型,这里为ViT-B/32
    '''

    device = "cuda" if torch.cuda.is_available() else "cpu" #选择GPU或CPU
    model, preprocess = clip.load('ViT-B/32', device) #加载视觉模型ViT-B/32

    root = os.path.expanduser("~/.cache") #下载CIFAR100数据集的路径
    train = CIFAR100(root, download=True, train=True, transform=preprocess) #下载训练集
    test = CIFAR100(root, download=True, train=False, transform=preprocess) #下载测试集

    def get_features(dataset):
        '''
        使用batch计算数据集的特征
        例:dataset有10000张图像
        1.划分为100个batch
        2.每个batch计算100张图像的特征
        3.将本batch的特征和标签添加合并到一个数组
        4.将100个batch的特征合并后返回
        '''
        all_features = [] #特征列表
        all_labels = [] #标签列表

        with torch.no_grad(): #关闭梯度计算
            for images, labels in tqdm(DataLoader(dataset, batch_size=100)): #加载数据集,batchsize设置为100,并显示进度条
                features = model.encode_image(images.to(device)) #计算图像特征

                all_features.append(features) #将特征添加到列表
                all_labels.append(labels) #将标签添加到列表

        return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy() #将特征和标签合并成一个数组并转移到CPU


    train_features, train_labels = get_features(train) #计算训练集的特征
    test_features, test_labels = get_features(test) #计算测试集的特征

    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) #设置逻辑回归分类器

    # 使用训练集训练分类器
    # 输入train_features:[num_samples, feature_dim], train_labels:[num_samples]
    # y = sigmoid(train_features * x + b) [num_samples, num_classes]
    # 使用y和train_labels计算损失函数
    # 优化器更新参数
    classifier.fit(train_features, train_labels) #使用训练集训练分类器

    predictions = classifier.predict(test_features) #使用测试集预测标签
    accuracy = np.mean((test_labels == predictions).astype(float)) * 100. #计算准确率
    print(f"Accuracy = {accuracy:.3f}") #打印准确率

相关文章:

  • MapReduce技术概述**
  • java的jiraapi设置超时时间
  • Excel·VBA江西省预算一体化工资表一键处理
  • 【LeetCode 热题 100】11. 盛最多水的容器 | python 【中等】
  • Unity Shader编程】之基础纹理
  • Mac 上如何安装Mysql? 如何配置 Mysql?以及如何开启并使用MySQL
  • MySQL:CRUD(增删查改)
  • 鸿蒙Next-应用检测、安装以及企业内部商店的实现
  • 日语学习-日语知识点小记-构建基础-JLPT-N4N5阶段(19):(1)家族の呼び方(よびかた):家族人物(2)助词练习
  • FreeRTOS任务状态查询
  • 数据库原理5
  • JavaScript系列06-深入理解 JavaScript 事件系统:从原生事件到 React 合成事件
  • Greenplum6.19集群搭建
  • 【云岚到家】-实战问题(上)
  • ubuntu 20.04 C++ 源码编译 cuda版本 opencv4.5.0
  • STM32驱动OLED屏幕全解析:从原理到温度显示实战(上) | 零基础入门STM32第五十三步
  • 人工智能与深度学习的应用案例解析及代码实现
  • Ae 效果详解:VR 色差
  • 计算机视觉cv2入门之图像的读取,显示,与保存
  • 【江协科技STM32】ADC数模转换器-学习笔记
  • 网站开发框架拓扑/百度统计登录
  • 贪便宜网站/网络广告策划书模板范文
  • 做网站python和php哪个好学/如何设计一个网站页面
  • 宝鸡手机版网站建设/企业网站开发制作
  • 怎么做视频网站首页/资源搜索引擎
  • 公司网站策划方案/南宁seo