基于Chinese-CLIP与ChromaDB的中文图像检索功能实现
本文按“原理 → 代码 → 讲解”三层展开,读者只需具备 Python 基础即可跟随完成一个可落地的以文搜图应用。
一、整体思路
- 把图片和文字都转成固定长度的向量(768 维)。
- 把图片向量提前存入向量数据库。
- 查询时把文字转成向量,再找出最相似的图片向量。
实现依赖两个核心组件:
- Chinese-CLIP:中文多模态模型,负责向量化。
- ChromaDB:轻量级向量数据库,负责存储与检索。
二、准备工作
软件
- Python ≥ 3.8
- 显卡可选,如有 NVIDIA GPU 请提前装好 CUDA 驱动。
安装依赖
pip install torch
pip install transformers chromadb pillow numpy
数据
在任意位置新建文件夹,例如 D:/photos
,把待检索的 .jpg
或 .png
图片全部放进去。
三、分步实现
1. 载入模型
from transformers import ChineseCLIPModel, ChineseCLIPProcessor
import torchdevice = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "OFA-Sys/chinese-clip-vit-large-patch14-336px"
model = ChineseCLIPModel.from_pretrained(model_name).to(device)
processor = ChineseCLIPProcessor.from_pretrained(model_name)
要点
processor
负责把图片或文本转成模型所需的张量。- 首次运行会自动下载约 1 GB 权重,后续离线可用。
2. 图片预处理
模型输入要求 336×336 像素,并保持原始比例居中裁剪。
from PIL import Imagedef load_image(image_path: str, out_size=(336, 336)) -> Image.Image:target_w, target_h = out_sizewith Image.open(image_path) as img:img = img.convert("RGB")ow, oh = img.sizescale = max(target_w / ow, target_h / oh)new_w, new_h = int(ow * scale + 0.5), int(oh * scale + 0.5)img = img.resize((new_w, new_h), Image.LANCZOS)left = (new_w - target_w) // 2top = (new_h - target_h) // 2img = img.crop((left, top, left + target_w, top + target_h))return img
3. 特征提取
把图片或文本变成 768 维向量,并对向量做 L2 归一化,使后续相似度计算简化为点积。
import numpy as npdef images_to_vectors(images):inputs = processor(images=images, return_tensors="pt").to(device)with torch.no_grad():vec = model.get_image_features(**inputs)vec = vec / vec.norm(p=2, dim=-1, keepdim=True)return vec.cpu().numpy()def texts_to_vectors(texts):inputs = processor(text=texts, padding=True, return_tensors="pt").to(device)with torch.no_grad():vec = model.get_text_features(**inputs)vec = vec / vec.norm(p=2, dim=-1, keepdim=True)return vec.cpu().numpy()
4. 构建向量数据库
ChromaDB 会在本地目录保存数据,支持增量写入。
import chromadb, uuid
from pathlib import PathDATA_DIR = Path("D:/photos")
DB_PATH = "images.chroma_db"def build_database(data_dir=DATA_DIR):client = chromadb.PersistentClient(DB_PATH)collection = client.get_or_create_collection(name="photos",metadata={"hnsw:space": "cosine"})existing = set(collection.get()["uris"]) # 已入库图片paths = [p for p in data_dir.rglob("*.jpg") if str(p) not in existing]if not paths:print("没有新图片需要入库")returnbatch_size = 32for i in range(0, len(paths), batch_size):batch_paths = paths[i:i+batch_size]images = [load_image(p) for p in batch_paths]vectors = images_to_vectors(images)ids = [str(uuid.uuid4()) for _ in batch_paths]uris = [str(p) for p in batch_paths]collection.add(embeddings=vectors.tolist(), ids=ids, uris=uris)print(f"已入库 {len(batch_paths)} 张")
运行一次即可:
build_database()
5. 文字查询
def search(text, top_k=5):client = chromadb.PersistentClient(DB_PATH)collection = client.get_collection("photos")vec = texts_to_vectors([text])[0]hits = collection.query(query_embeddings=[vec.tolist()],n_results=top_k,include=["uris", "distances"])return list(zip(hits["uris"][0], hits["distances"][0]))
返回示例
results = search("海棠", top_k=5)
for path, dist in results:print(f"{dist:.3f} {path}")
6. 结果可视化(可选)
import matplotlib.pyplot as pltdef show_results(results, cols=5):n = len(results)rows = (n + cols - 1) // colsfig, axes = plt.subplots(rows, cols, figsize=(cols*2.5, rows*2.5))axes = axes.flatten() if n > 1 else [axes]for ax, (path, dist) in zip(axes, results):img = Image.open(path)ax.imshow(img)ax.set_title(f"{dist:.2f}")ax.axis("off")plt.tight_layout(); plt.show()
调用:
results = search("海棠", top_k=10)
show_results(results)
四、常见问题
- CPU 运行慢?单张图片约 200 ms;GPU 可降至 50 ms。
- 内存不足?把
batch_size
降到 8 或更小。 - 想支持更多格式?把
rglob("*.jpg")
改成rglob("*")
并自行过滤扩展名。
五、下一步可扩展
- 混合查询:同时输入文字 + 参考图片,把两个向量平均后再搜索。
- 过滤条件:在
collection.add
时附加元数据(时间、标签),查询时加 where 条件。 - 分布式部署:把 ChromaDB 换成 Milvus / Weaviate,即可横向扩展。
至此,你已拥有一个完整、可维护、易扩展的中文图文检索应用。