将.pt文件执行图像比对
目录
1. 加载模型
2. 图像预处理
3. 提取图像特征
4. 计算相似度
调用API或封装函数即可实现端到端比对
使用.pt
文件进行图像比对通常涉及以下步骤:
1. 加载模型
python
import torch# 假设模型是PyTorch保存的权重文件
model = YourModelClass() # 需与保存时的模型结构一致
model.load_state_dict(torch.load('model.pt'))
model.eval() # 切换到推理模式
2. 图像预处理
使用torchvision.transforms
标准化输入:
python
from torchvision import transformspreprocess = transforms.Compose([transforms.Resize((224, 224)), # 根据模型要求调整尺寸transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化
])
3. 提取图像特征
将图像输入模型,获取特征向量:
python
from PIL import Imagedef get_features(image_path):img = Image.open(image_path).convert('RGB')img_tensor = preprocess(img).unsqueeze(0) # 添加batch维度with torch.no_grad():features = model(img_tensor) # 假设模型输出特征向量return features.squeeze() # 去除batch维度
4. 计算相似度
使用余弦相似度或欧氏距离:
python
import torch.nn.functional as F# 假设features1和features2是两张图的特征向量
cos_sim = F.cosine_similarity(features1, features2, dim=0) # 值越接近1越相似
euclidean_dist = torch.norm(features1 - features2, p=2) # 值越小越相似
调用API或封装函数即可实现端到端比对
# 端到端图像比对函数封装
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Imageclass ImageComparator:def __init__(self, model_path, device='cpu'):self.device = torch.device(device)self.model = torch.load(model_path, map_location=device)self.model.eval()self.preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])def extract_features(self, img_path):img = Image.open(img_path).convert('RGB')return self.model(self.preprocess(img).unsqueeze(0).to(self.device))[0]def compare(self, img1_path, img2_path):with torch.no_grad():f1, f2 = self.extract_features(img1_path), self.extract_features(img2_path)return F.cosine_similarity(f1, f2, dim=0).item()# 使用示例
comparator = ImageComparator('model.pt', device='cuda' if torch.cuda.is_available() else 'cpu')
similarity = comparator.compare('img1.jpg', 'img2.jpg')
print(f"Similarity Score: {similarity:.4f}")