文搜图/图搜图
文搜图/图搜图
- 1.环境
- 2.建集合
- 3.图入集合
- 4.输入向量化
- 5.文搜图
- 6.图搜图
- 7.参考博文
- 8.仓库代码
1.环境
linux安装docker,修改镜像源,安装docker-compose
#1.安装docker
sudo apt update
sudo apt install docker.io
sudo systemctl start docker
sudo docker --version
#2.修改docker镜像源
sudo su
vi /etc/docker/daemon.json
{"registry-mirrors": ["https://rsk59qvc.mirror.aliyuncs.com"]
}
sudo systemctl restart docker  # 重启docker
docker info  # 查看镜像源是否修改成功
#3.安装docker-compose
sudo curl -L "https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
docker-compose --version
安装启动milvus(2.5.0)容器和可视化attu容器
wget https://github.com/milvus-io/milvus/releases/download/v2.5.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
sudo docker-compose up -d
docker run -d --name attu -p 8000:3000 -e MILVUS_URL=host.docker.internal:19530 zilliz/attu:v2.5
#windows的话按如下安装milvus
管理员身份打开powershell,
Invoke-WebRequest https://raw.githubusercontent.com/milvus-io/milvus/refs/heads/master/scripts/standalone_embed.bat -OutFile standalone.bat
.\standalone.bat start
docker ps -a (查看是否成功)
本地浏览器访问 http://localhost:8000/#/connect ,即可通过attu可视化查看milvus库
python环境
pip install cn-clip ipython
pip install pymilvus==2.5.0
2.建集合
# 1. Create the Milvus client helpers.
from pymilvus import MilvusClient, DataType
import torch
import time


def create_schema():
    """Build the schema used for the image-embedding collection.

    Uses the module-level ``milvus_client``.  The schema has an
    auto-generated INT64 primary key ``id``, a 512-dimensional float vector
    field ``vectors`` and a VARCHAR field ``filepath`` (at most 200
    characters).  Dynamic fields are enabled.

    Returns:
        The schema object returned by ``milvus_client.create_schema`` with
        the three fields added.
    """
    schema = milvus_client.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
        description="",
    )
    # The keyword is ``description``; the previous spelling ``descrition``
    # meant the field descriptions were never set.
    schema.add_field(
        field_name="id",
        datatype=DataType.INT64,
        description='ids',
        is_primary=True,
    )
    schema.add_field(
        field_name="vectors",
        datatype=DataType.FLOAT_VECTOR,
        description='embedding vectors',
        dim=512,
    )
    schema.add_field(
        field_name="filepath",
        datatype=DataType.VARCHAR,
        description='file path',
        max_length=200,
    )
    return schema
def create_collection(collection_name, schema, timeout=3):
    """Create a collection and wait until Milvus reports that it exists.

    Uses the module-level ``milvus_client``.

    Args:
        collection_name: Name of the collection to create.
        schema: Schema object passed to ``milvus_client.create_collection``.
        timeout: Seconds to keep polling ``has_collection`` before giving up.

    Returns:
        True when the collection is visible, False when the create call
        raised or the collection did not appear within ``timeout`` seconds.
    """
    try:
        milvus_client.create_collection(
            collection_name=collection_name,
            schema=schema,
            shards_num=2,
        )
    except Exception as e:
        print(f"创建集合的过程中出现了错误: {e}")
        return False
    else:
        print(f"开始创建集合:{collection_name}")

    # Poll once per second until the collection shows up or time runs out.
    start_time = time.time()
    while not milvus_client.has_collection(collection_name):
        if time.time() - start_time > timeout:
            print(f"创建集合 {collection_name} 超时")
            return False
        time.sleep(1)

    print(f"集合 {collection_name} 创建成功")
    return True
class CollectionDeletionError(Exception):
    """Raised when an existing collection could not be dropped."""


def check_and_drop_collection(collection_name):
    """Drop the collection if it already exists.

    Uses the module-level ``milvus_client``.

    Args:
        collection_name: Name of the collection to check.

    Returns:
        True when the collection does not exist or was dropped, False when
        the drop call raised.
    """
    if not milvus_client.has_collection(collection_name):
        return True

    print(f"集合 {collection_name} 已经存在")
    try:
        milvus_client.drop_collection(collection_name)
    except Exception as e:
        print(f"删除集合时出现错误: {e}")
        return False
    else:
        print(f"删除集合:{collection_name}")
        return True
collection_name = "w_cc"
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)

# Abort when a stale collection exists and cannot be dropped.
if not check_and_drop_collection(collection_name):
    raise CollectionDeletionError('删除集合失败')
else:
    # Build the schema, then create the collection and wait for it to appear.
    schema = create_schema()
    create_collection(collection_name, schema)
3.图入集合
# 2.向量化图像与文字,并把图像入库,创建索引,使用倒排索引(IVF_FLAT),检索效率高,准确性也不错。度量方式使用余弦相似度(COSINE)。
import cn_clip.clip as clip # 导入可用模型的函数
from cn_clip.clip import available_models
import torch
from PIL import Image
import os
from glob import glob
from tqdm import tqdm
import time
import cn_clip.clip as clip # 导入可用模型的函数
from cn_clip.clip import available_models
import torch
# 用于图片处理
from PIL import Image
from pymilvus import MilvusClient
# List the models available in chinese-clip.
print("Available models:", available_models())

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Name of the model to load.
model_name = "ViT-B-16"

# Load the chinese-clip model and its matching preprocessing function.
# model: provides the image encoder (encode_image) and text encoder (encode_text)
# preprocess: image preprocessing function (normalisation, resizing, ...)
# download_root: directory where the downloaded model is stored
model, preprocess = clip.load_from_name(model_name, device=device, download_root='./chinese_clip_model')

# Evaluation mode: turns off dropout and other training-only behaviour.
model.eval()

# Must be the collection created in step 2 and searched in step 4 ("w_cc");
# the previous value "multimodal_chinese_clip" named a collection that is
# never created in this walkthrough.
collection_name = "w_cc"
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)
def encode_image(image_path):
    """Encode one image file into a normalised embedding.

    Uses the module-level ``model``, ``preprocess`` and ``device``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The embedding as a Python list (``squeeze().tolist()`` of the model
        output, after division by its norm along the last dimension).
    """
    # Open inside a context manager so the file handle is released as soon
    # as convert() has produced its own RGB copy of the pixels.
    with Image.open(image_path) as image_file:
        raw_image = image_file.convert('RGB')

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        processed_image = preprocess(raw_image).unsqueeze(0).to(device)
        image_features = model.encode_image(processed_image)
        # Normalise the features.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.squeeze().tolist()
def encode_text(text_list):
    """Encode a list of texts into normalised embeddings.

    Uses the module-level ``clip``, ``model`` and ``device``.

    Args:
        text_list: Texts handed to ``clip.tokenize``.

    Returns:
        A list with one embedding (itself a list) per row of the encoder
        output.
    """
    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        # Tokenise the text, including special-symbol handling.
        text_tokens = clip.tokenize(text_list).to(device)
        text_features = model.encode_text(text_tokens)
        # Normalise the features.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return [f.squeeze().tolist() for f in text_features]
def process_images_and_insert(input_dir_path, ext_list, batch_size=100):
    """Embed every matching image under a directory and insert it into Milvus.

    Uses the module-level ``milvus_client``, ``collection_name`` and
    ``encode_image``.  Images that fail to encode and batches that fail to
    insert are reported with ``print`` and skipped.

    Args:
        input_dir_path: Root directory that is searched recursively.
        ext_list: File extensions (including the dot) to look for.
        batch_size: Number of image paths handled per insert call.
    """
    # Collect every image path (recursive search), one glob per extension.
    image_paths = []
    for ext in ext_list:
        print(f"正在查找扩展名: {ext}")
        pattern = os.path.join(input_dir_path, f"**/*{ext}")
        print(f"搜索模式: {pattern}")
        image_paths.extend(glob(pattern, recursive=True))

    # On case-insensitive file systems two extensions that differ only in
    # case match the same files; drop repeats while keeping the order.
    image_paths = list(dict.fromkeys(image_paths))

    total_images = len(image_paths)
    print(f"总计需要处理 {total_images} 张图片")

    total_start_time = time.time()

    with tqdm(total=total_images, desc="处理图片并插入数据") as progress_bar:
        for batch_start in range(0, total_images, batch_size):
            batch_paths = image_paths[batch_start: batch_start + batch_size]
            batch_start_time = time.time()

            # Embed the images of the current batch.
            batch_data = []
            for image_path in batch_paths:
                try:
                    image_embedding = encode_image(image_path)
                except Exception as e:
                    print(f"处理图片 {image_path} 时出错: {str(e)}")
                    continue
                batch_data.append({
                    "vectors": image_embedding,
                    "filepath": image_path,
                })

            if not batch_data:
                continue

            # Insert the whole batch with one call.
            try:
                milvus_client.insert(
                    collection_name=collection_name,
                    data=batch_data,
                )
            except Exception as e:
                print(f"插入批次 {batch_start} 时失败: {str(e)}")
                continue

            batch_duration = time.time() - batch_start_time
            # The bar advances by the number of images actually inserted.
            progress_bar.update(len(batch_data))
            progress_bar.set_postfix({"批次耗时": batch_duration})

    total_duration = time.time() - total_start_time
    print(f"\n所有图片处理完成!总耗时: {total_duration}")
    # A coarse clock can report zero elapsed time; avoid dividing by zero.
    speed = total_images / total_duration if total_duration > 0 else 0.0
    print(f"平均处理速度: {speed:.1f}张/秒")
# Ingest settings; mind the upper/lower-case spelling of the extensions.
ext_list = ['.JPEG', '.jpg', '.png']
batch_size = 300
input_dir_path = "lhq_1024_jpg_5000"

process_images_and_insert(input_dir_path, ext_list, batch_size)
def create_index(collection_name):
    """Create an IVF_FLAT index with the COSINE metric on the ``vectors`` field.

    Uses the module-level ``milvus_client``.

    Args:
        collection_name: Collection on which the index is created.
    """
    # Prepare the index parameters.
    index_params = milvus_client.prepare_index_params()
    index_params.add_index(
        index_name="IVF_FLAT",
        # Field the index is built on.
        field_name="vectors",
        index_type="IVF_FLAT",
        metric_type="COSINE",
        params={"nlist": 512},
    )
    milvus_client.create_index(
        collection_name=collection_name,
        index_params=index_params,
    )
create_index(collection_name)

# Load the collection.
print(f"正在加载集合 {collection_name}")
milvus_client.load_collection(collection_name=collection_name)
print(f"集合 {collection_name} 加载完成")

# Verify the load state reported by Milvus.
state = str(milvus_client.get_load_state(collection_name=collection_name)['state'])
print("集合加载完成" if state == 'Loaded' else "集合加载失败")

# Print the result of a count(*) query on the collection.
print(
    milvus_client.query(
        collection_name=collection_name,
        output_fields=["count(*)"],
    )
)
4.输入向量化
from PIL import Image
from pymilvus import MilvusClient
import cn_clip.clip as clip # 导入可用模型的函数
from cn_clip.clip import available_models
import torch
from PIL import Image
import os
from glob import glob
from tqdm import tqdm
import time
import cn_clip.clip as clip # 导入可用模型的函数
from cn_clip.clip import available_models
import torch
# 用于图片处理
from PIL import Image
from pymilvus import MilvusClient
collection_name = "w_cc"
uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=uri)

print("Available models:", available_models())

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Name of the model to load.
model_name = "ViT-B-16"
model, preprocess = clip.load_from_name(model_name, device=device, download_root='./chinese_clip_model')

# Evaluation mode: turns off dropout and other training-only behaviour.
model.eval()
def encode_image(image_path):
    """Encode one image file into a normalised embedding.

    Uses the module-level ``model``, ``preprocess`` and ``device``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The embedding as a Python list (``squeeze().tolist()`` of the model
        output, after division by its norm along the last dimension).
    """
    # Open inside a context manager so the file handle is released as soon
    # as convert() has produced its own RGB copy of the pixels.
    with Image.open(image_path) as image_file:
        raw_image = image_file.convert('RGB')

    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        processed_image = preprocess(raw_image).unsqueeze(0).to(device)
        image_features = model.encode_image(processed_image)
        # Normalise the features.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.squeeze().tolist()
def encode_text(text_list):
    """Encode a list of texts into normalised embeddings.

    Uses the module-level ``clip``, ``model`` and ``device``.

    Args:
        text_list: Texts handed to ``clip.tokenize``.

    Returns:
        A list with one embedding (itself a list) per row of the encoder
        output.
    """
    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        # Tokenise the text, including special-symbol handling.
        text_tokens = clip.tokenize(text_list).to(device)
        text_features = model.encode_text(text_tokens)
        # Normalise the features.
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return [f.squeeze().tolist() for f in text_features]


def vector_search(vector, field_name, limit, output_fields):
    """Run a vector search against the module-level collection.

    Args:
        vector: Value passed as ``data`` to ``milvus_client.search``; the
            callers in this file pass a list holding one embedding.
        field_name: Name of the vector field to search (``anns_field``).
        limit: Maximum number of hits requested.
        output_fields: Fields returned with each hit.

    Returns:
        Whatever ``milvus_client.search`` returns.
    """
    res = milvus_client.search(
        collection_name=collection_name,
        data=vector,
        anns_field=field_name,
        limit=limit,
        output_fields=output_fields,
    )
    return res

# from IPython.display import display
# from PIL import Image# # 定义显示图片检索结果的函数
def create_concatenated_image(res, images_per_row=2, images_per_column=2, image_size=(400, 400)):
    """Paste the images named in a search result onto one grid image.

    Args:
        res: Search result: an iterable of hit lists, where each hit exposes
            ``hit["entity"]["filepath"]``.
        images_per_row: Number of grid columns.
        images_per_column: Number of grid rows.
        image_size: ``(width, height)`` every image is resized to.  The
            resize does not preserve the aspect ratio.

    Returns:
        An RGB image of size ``(image_size[0] * images_per_row,
        image_size[1] * images_per_column)``.  No fill colour is passed to
        ``Image.new``, so empty cells keep Pillow's default colour (black,
        per the Pillow documentation).
    """
    width = image_size[0] * images_per_row
    height = image_size[1] * images_per_column
    concatenated_image = Image.new("RGB", (width, height))

    # Anything past this many images would land outside the canvas, so
    # there is no point in opening those files.
    capacity = images_per_row * images_per_column

    result_images = []
    all_hits = (hit for result in res for hit in result)
    for hit in all_hits:
        if len(result_images) >= capacity:
            break
        filename = hit["entity"]["filepath"]
        try:
            # resize() returns a new image, so the source file can be
            # closed immediately.
            with Image.open(filename) as img:
                result_images.append(img.resize(image_size))
        except Exception as e:
            print(f"无法加载图片 {filename}: {e}")
            continue

    # Fill the grid row by row.
    for idx, img in enumerate(result_images):
        x = idx % images_per_row
        y = idx // images_per_row
        concatenated_image.paste(img, (x * image_size[0], y * image_size[1]))

    return concatenated_image
5.文搜图
query_text = ["小桥流水人家"]
query_embedding = encode_text(query_text)[0]
field_name = "vectors"
limit = 10
output_fields = ["filepath"]
res = vector_search([query_embedding], field_name, limit, output_fields)
print(f"查询文本: {query_text}")
print(f"检索结果:")

# Build the grid image with create_concatenated_image.
# NOTE: a 2 x 2 grid has room for 4 images although up to `limit` hits are returned.
result_image = create_concatenated_image(res, 2, 2, (400, 400))

# Save the grid image locally; create the output directory first because
# save() fails when it does not exist.
output_path = "./output/concatenated_image.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
result_image.save(output_path)
print(f"拼接图像已保存到: {output_path}")
做成接口
# NOTE(review): FastAPI and BaseModel are not imported anywhere in the visible
# part of this file; this snippet presumably needs
# `from fastapi import FastAPI` and `from pydantic import BaseModel` - confirm.
app = FastAPI()


# Request body of the text-to-image search endpoint.
class QueryRequest(BaseModel):
    query_text: str


@app.post("/text-search-images/")
async def search_images(query_request: QueryRequest):
    """Return the file paths of the images most similar to the query text.

    Args:
        query_request: Request body carrying ``query_text``.

    Returns:
        ``{"images": [...]}`` with the ``filepath`` of every hit.
    """
    # Text supplied by the user.
    query_text = query_request.query_text
    # Embedding of the query text.
    query_embedding = encode_text([query_text])[0]

    field_name = "vectors"
    limit = 10
    output_fields = ["filepath"]
    res = vector_search([query_embedding], field_name, limit, output_fields)

    # Flatten the per-query hit lists into one list of file paths.
    image_paths = [hit["entity"]["filepath"] for hits in res for hit in hits]
    return {"images": image_paths}
post测试
import requests# 测试文本查询相似图片的接口
def test_text_search_images(query_text):
    """POST a query text to the text-search endpoint and print the outcome.

    Args:
        query_text: Text sent as ``query_text`` in the JSON body.
    """
    url = "http://127.0.0.1:8001/text-search-images/"
    # Without a timeout the call would wait forever if the server never answers.
    response = requests.post(url, json={"query_text": query_text}, timeout=30)
    if response.status_code == 200:
        print("查询相似图片成功:")
        # Print the returned list of image paths.
        print(response.json())
    else:
        print("查询相似图片失败:", response.status_code)


if __name__ == "__main__":
    # Example query text.
    query_text = "小桥流水人家"
    test_text_search_images(query_text)
效果
6.图搜图
query_image = '目标/上海/屏幕截图.png'
query_embedding = encode_image(query_image)
field_name = "vectors"
limit = 5
output_fields = ["filepath"]
res = vector_search([query_embedding], field_name, limit, output_fields)

# Flatten the per-query hit lists into one list of file paths.
image_paths = []
for images in res:
    image_paths.extend([image["entity"]["filepath"] for image in images])
print(image_paths)

print(f"查询图片")
query_image_save_path = './output/query_image.png'

print(f"图片检索结果:")
concatenated_image = create_concatenated_image(res, images_per_row=3, images_per_column=3, image_size=(300, 300))
concatenated_image_save_path = './output/retrieved_images.png'
# save() fails when the target directory is missing, so create it first.
os.makedirs(os.path.dirname(concatenated_image_save_path), exist_ok=True)
concatenated_image.save(concatenated_image_save_path)
print(f"检索结果图像已保存到: {concatenated_image_save_path}")
做成接口
# Request body of the image-to-image search endpoint.
class ImageQueryRequest(BaseModel):
    image_path: str


# Backend endpoint: find images similar to the image at the given path.
@app.post("/search-similar-images/")
async def search_similar_images(request: ImageQueryRequest):
    """Return the file paths of the images most similar to a query image.

    Args:
        request: Request body carrying ``image_path``.

    Returns:
        ``{"similar_images": [...]}`` with the ``filepath`` of every hit.
    """
    # SECURITY NOTE(review): image_path comes straight from the client and is
    # opened on the server by encode_image; restrict it to an allowed
    # directory before exposing this endpoint to untrusted callers.
    image_path = request.image_path

    # Embedding of the query image.
    query_embedding = encode_image(image_path)

    field_name = "vectors"
    limit = 10
    output_fields = ["filepath"]
    res = vector_search([query_embedding], field_name, limit, output_fields)

    # Flatten the per-query hit lists into one list of file paths.
    image_paths = [hit["entity"]["filepath"] for hits in res for hit in hits]
    return {"similar_images": image_paths}

# Backend endpoint: upload an image and display it.
@app.post("/show-image/")
async def show_image(image: UploadFile = File(...)):
    """Save an uploaded image under ``./temp_images`` and open it with PIL.

    Args:
        image: The uploaded file.

    Returns:
        ``{"message": ...}`` naming the path the upload was written to.
    """
    # The filename is chosen by the client: keep only its last path component
    # so an upload cannot be written outside ./temp_images.
    safe_name = os.path.basename(image.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    # Save the upload to the temporary directory.
    temp_image_path = f"./temp_images/{safe_name}"
    os.makedirs(os.path.dirname(temp_image_path), exist_ok=True)
    with open(temp_image_path, "wb") as f:
        f.write(await image.read())

    # Open the saved file with PIL and display it.
    img = Image.open(temp_image_path)
    img.show()
    return {"message": f"图片已显示,路径: {temp_image_path}"}
post测试
import requests
def test_search_similar_images(image_path):
    """POST an image path to the similar-image endpoint and print the outcome.

    Args:
        image_path: Path sent as ``image_path`` in the JSON body.
    """
    url = "http://127.0.0.1:8001/search-similar-images/"
    # Without a timeout the call would wait forever if the server never answers.
    response = requests.post(url, json={"image_path": image_path}, timeout=30)
    if response.status_code == 200:
        print("查询相似图片成功:")
        print(response.json())
    else:
        print("查询相似图片失败:", response.status_code)
def test_show_image(image_path):
    """Upload an image file to the show-image endpoint and print the outcome.

    Args:
        image_path: Local path of the image file to upload.
    """
    url = "http://127.0.0.1:8001/show-image/"
    # Open the image file and send it as a multipart POST request.
    with open(image_path, "rb") as img_file:
        files = {"image": img_file}
        # Without a timeout the call would wait forever if the server never answers.
        response = requests.post(url, files=files, timeout=30)

    if response.status_code == 200:
        print("图片显示成功:")
        print(response.json())
    else:
        print("图片显示失败:", response.status_code)


if __name__ == "__main__":
    image_path = 'query_image.jpg'
    test_search_similar_images(image_path)
    test_show_image(image_path)
7.参考博文
[1]https://mp.weixin.qq.com/s/wW_3X7CquqeuEdu4-zn3qg
8.仓库代码
https://github.com/Turing-dz/text_img_search_img