当前位置: 首页 > news >正文

基于Chinese-CLIP与ChromaDB的中文图像检索功能实现

本文按“原理 → 代码 → 讲解”三层展开,读者只需具备 Python 基础即可跟随完成一个可落地的以文搜图应用。

一、整体思路

  1. 把图片和文字都转成固定长度的向量(768 维)。
  2. 把图片向量提前存入向量数据库。
  3. 查询时把文字转成向量,再找出最相似的图片向量。

实现依赖两个核心组件:

  • Chinese-CLIP:中文多模态模型,负责向量化。
  • ChromaDB:轻量级向量数据库,负责存储与检索。

二、准备工作

软件

  • Python ≥ 3.8
  • 显卡可选,如有 NVIDIA GPU 请提前装好 CUDA 驱动。

安装依赖

pip install torch 
pip install transformers  chromadb pillow numpy

数据
在任意位置新建文件夹,例如 D:/photos,把待检索的 .jpg.png 图片全部放进去。

三、分步实现

1. 载入模型

from transformers import ChineseCLIPModel, ChineseCLIPProcessor
import torchdevice = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "OFA-Sys/chinese-clip-vit-large-patch14-336px"
model = ChineseCLIPModel.from_pretrained(model_name).to(device)
processor = ChineseCLIPProcessor.from_pretrained(model_name)

要点

  • processor 负责把图片或文本转成模型所需的张量。
  • 首次运行会自动下载约 1 GB 权重,后续离线可用。

2. 图片预处理

模型输入要求 336×336 像素,并保持原始比例居中裁剪。

from PIL import Imagedef load_image(image_path: str, out_size=(336, 336)) -> Image.Image:target_w, target_h = out_sizewith Image.open(image_path) as img:img = img.convert("RGB")ow, oh = img.sizescale = max(target_w / ow, target_h / oh)new_w, new_h = int(ow * scale + 0.5), int(oh * scale + 0.5)img = img.resize((new_w, new_h), Image.LANCZOS)left = (new_w - target_w) // 2top = (new_h - target_h) // 2img = img.crop((left, top, left + target_w, top + target_h))return img

3. 特征提取

把图片或文本变成 768 维向量,并对向量做 L2 归一化,使后续相似度计算简化为点积。

import numpy as npdef images_to_vectors(images):inputs = processor(images=images, return_tensors="pt").to(device)with torch.no_grad():vec = model.get_image_features(**inputs)vec = vec / vec.norm(p=2, dim=-1, keepdim=True)return vec.cpu().numpy()def texts_to_vectors(texts):inputs = processor(text=texts, padding=True, return_tensors="pt").to(device)with torch.no_grad():vec = model.get_text_features(**inputs)vec = vec / vec.norm(p=2, dim=-1, keepdim=True)return vec.cpu().numpy()

4. 构建向量数据库

ChromaDB 会在本地目录保存数据,支持增量写入。

import chromadb, uuid
from pathlib import PathDATA_DIR = Path("D:/photos")
DB_PATH  = "images.chroma_db"def build_database(data_dir=DATA_DIR):client = chromadb.PersistentClient(DB_PATH)collection = client.get_or_create_collection(name="photos",metadata={"hnsw:space": "cosine"})existing = set(collection.get()["uris"])     # 已入库图片paths = [p for p in data_dir.rglob("*.jpg") if str(p) not in existing]if not paths:print("没有新图片需要入库")returnbatch_size = 32for i in range(0, len(paths), batch_size):batch_paths = paths[i:i+batch_size]images = [load_image(p) for p in batch_paths]vectors = images_to_vectors(images)ids  = [str(uuid.uuid4()) for _ in batch_paths]uris = [str(p) for p in batch_paths]collection.add(embeddings=vectors.tolist(), ids=ids, uris=uris)print(f"已入库 {len(batch_paths)} 张")

运行一次即可:

build_database()

5. 文字查询

def search(text, top_k=5):client = chromadb.PersistentClient(DB_PATH)collection = client.get_collection("photos")vec = texts_to_vectors([text])[0]hits = collection.query(query_embeddings=[vec.tolist()],n_results=top_k,include=["uris", "distances"])return list(zip(hits["uris"][0], hits["distances"][0]))

返回示例

results = search("海棠", top_k=5)
for path, dist in results:print(f"{dist:.3f}  {path}")

6. 结果可视化(可选)

import matplotlib.pyplot as pltdef show_results(results, cols=5):n = len(results)rows = (n + cols - 1) // colsfig, axes = plt.subplots(rows, cols, figsize=(cols*2.5, rows*2.5))axes = axes.flatten() if n > 1 else [axes]for ax, (path, dist) in zip(axes, results):img = Image.open(path)ax.imshow(img)ax.set_title(f"{dist:.2f}")ax.axis("off")plt.tight_layout(); plt.show()

调用:

results = search("海棠", top_k=10)
show_results(results)

四、常见问题

  • CPU 运行慢?单张图片约 200 ms;GPU 可降至 50 ms。
  • 内存不足?把 batch_size 降到 8 或更小。
  • 想支持更多格式?把 rglob("*.jpg") 改成 rglob("*") 并自行过滤扩展名。

五、下一步可扩展

  1. 混合查询:同时输入文字 + 参考图片,把两个向量平均后再搜索。
  2. 过滤条件:在 collection.add 时附加元数据(时间、标签),查询时加 where 条件。
  3. 分布式部署:把 ChromaDB 换成 Milvus / Weaviate,即可横向扩展。

至此,你已拥有一个完整、可维护、易扩展的中文图文检索应用。

http://www.dtcms.com/a/278816.html

相关文章:

  • 人工智能如何重构能源系统以应对气候变化?
  • 动态规划题解——单词拆分【LeetCode】
  • openEuler系统PCIE降速方法简介
  • 【2025/07/14】GitHub 今日热门项目
  • Self - RAG工作步骤
  • 【HTML】五子棋(精美版)
  • 【Java EE】多线程-初阶 认识线程(Thread)
  • 【C语言进阶】指针面试题详解(2)
  • 面试 | JS 面试题 整理(更ing)2/34
  • Android 16系统源码_窗口动画(二)窗口显示动画源码调用流程
  • 护照阅读器:国外证件识别的 OCR “解码师”
  • Python 中调用阿里云 OCR(Optical Character Recognition,光学字符识别)服务
  • STM32介绍和GPIO
  • stm32-Modbus主机移植程序理解以及实战
  • argus/nvarguscamerasrc 远程显示报错
  • 项目一第一天
  • 纯数学专业VS应用数学专业:这两个哪个就业面更广?
  • C++后端面试八股文
  • Linux 基础命令详解:从入门到实践(1)
  • JAVA 并发 ThreadLocal
  • RestAssured(Java)使用详解
  • 19.数据增强技术
  • 管程! 解决互斥,同步问题的现代化手段(操作系统os)
  • Java行为型模式---模板方法模式
  • Imx6ull用网线与电脑连接
  • SpringBoot JAR 反编译替换文件
  • 【嵌入式汇编基础】-操作系统基础(三)
  • 【每日刷题】移动零
  • LabVIEW-Origin 船模数据处理系统
  • 【爬虫】Python实现爬取京东商品信息(超详细)