PyTorch 进行人脸训练和 Java 实现人脸识别的解决方案
下面我将为你提供一个结合 PyTorch 进行人脸训练和 Java 实现人脸识别的解决方案。这个方案分为两个主要部分:用 PyTorch 训练人脸特征提取模型,以及用 Java 加载模型并实现人脸识别功能。
第一部分:PyTorch 人脸特征训练
首先我们需要训练一个能够提取人脸特征的模型,这里使用基于 ResNet 的架构并结合 Triplet Loss 进行训练:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np
from sklearn.model_selection import train_test_split# 1. 数据集定义
class FaceDataset(Dataset):
    """Dataset yielding ``(image_tensor, label)`` pairs from image file paths.

    Parameters:
        image_paths: list of paths to face images.
        labels: integer class label per path (same length as ``image_paths``).
        transform: optional callable applied to the PIL image (e.g. a
            ``transforms.Compose``); when ``None`` the raw PIL image is returned.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Force RGB so grayscale/RGBA files still produce 3 channels.
        img = Image.open(path).convert('RGB')
        target = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, target
class FaceFeatureExtractor(nn.Module):
    """ResNet-50 backbone mapping a face image to a fixed-size embedding.

    The ImageNet classifier head is replaced by a linear projection to
    ``embedding_size`` dimensions.  When ``num_classes`` is provided, an
    auxiliary classification head is attached and ``forward`` returns
    ``(features, logits)``; otherwise it returns ``features`` alone.
    """

    def __init__(self, embedding_size=128, num_classes=None):
        super(FaceFeatureExtractor, self).__init__()
        backbone = models.resnet50(pretrained=True)
        # Swap the final FC layer for an embedding projection.
        backbone.fc = nn.Linear(backbone.fc.in_features, embedding_size)
        self.resnet = backbone
        self.embedding_size = embedding_size
        # Optional classification head (used only as a training aid).
        if num_classes is not None:
            self.classifier = nn.Linear(embedding_size, num_classes)
        else:
            self.classifier = None

    def forward(self, x):
        emb = self.resnet(x)
        if self.classifier is None:
            return emb
        return emb, self.classifier(emb)
class TripletLoss(nn.Module):
    """Margin-based triplet loss over Euclidean distances between embeddings.

    Loss per sample: ``max(0, ||a - p|| - ||a - n|| + margin)``, averaged
    over the batch.
    """

    def __init__(self, margin=0.5):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # L2 distances for the anchor/positive and anchor/negative pairs.
        d_ap = torch.norm(anchor - positive, p=2, dim=1)
        d_an = torch.norm(anchor - negative, p=2, dim=1)
        # Hinge: penalize only when the negative is not at least
        # `margin` farther from the anchor than the positive.
        per_sample = (d_ap - d_an + self.margin).clamp(min=0)
        return per_sample.mean()
def prepare_data(data_dir):
    """Scan ``data_dir`` organized as ``data_dir/<person_name>/<image>``.

    Each immediate subdirectory is one identity; image files inside it share
    one integer label.

    Parameters:
        data_dir: root directory of the per-person face folders.

    Returns:
        tuple of (image_paths, labels, label_map, num_classes) where
        ``label_map`` maps person name -> integer label and ``num_classes``
        is the number of identities found.
    """
    image_paths = []
    labels = []
    label_map = {}
    current_label = 0
    # FIX: os.listdir order is filesystem-dependent; sort so the
    # name -> label assignment is deterministic across runs and platforms.
    for person_name in sorted(os.listdir(data_dir)):
        person_dir = os.path.join(data_dir, person_name)
        if not os.path.isdir(person_dir):
            continue  # skip stray files at the top level
        label_map[person_name] = current_label
        for img_file in sorted(os.listdir(person_dir)):
            # FIX: case-insensitive match so '.JPG'/'.PNG' files are not
            # silently skipped.
            if img_file.lower().endswith(('.jpg', '.png', '.jpeg')):
                image_paths.append(os.path.join(person_dir, img_file))
                labels.append(current_label)
        current_label += 1
    return image_paths, labels, label_map, current_label
def train_model(data_dir, epochs=50, batch_size=32, embedding_size=128):
    """Train the face feature extractor with a combined triplet + CE loss.

    Expects ``data_dir`` organized as ``data_dir/<person_name>/<image>``.
    Saves weights to 'face_feature_extractor.pth' and the name->label map to
    'label_map.npy' as side effects, then returns ``(model, label_map)``.
    """
    # Augmentation + ImageNet normalization (must match inference-side
    # preprocessing in the Java client).
    transform = transforms.Compose([
        transforms.Resize((150, 150)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Gather paths/labels and hold out 20% for validation.
    image_paths, labels, label_map, num_classes = prepare_data(data_dir)
    train_paths, val_paths, train_labels, val_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42)
    # Datasets and loaders.
    train_dataset = FaceDataset(train_paths, train_labels, transform)
    val_dataset = FaceDataset(val_paths, val_labels, transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    # Model, losses, optimizer.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = FaceFeatureExtractor(embedding_size=embedding_size, num_classes=num_classes).to(device)
    # Combined objective: triplet loss on embeddings + cross-entropy on the
    # auxiliary classification head.
    triplet_criterion = TripletLoss()
    ce_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # Training loop.
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            # Embeddings plus auxiliary logits.
            features, logits = model(images)
            # Naive in-batch triplet construction: random anchors, a random
            # same-label sample as positive, a random different-label sample
            # as negative.  Production code should use a smarter mining
            # strategy.
            # NOTE(review): the positive pool includes the anchor itself, so
            # anchor == positive is possible; and if a batch contains only
            # one class, the negative pool is empty and torch.randint(0, 0,
            # ...) raises — confirm batches always mix classes.
            anchor_idx = torch.randint(0, len(images), (len(images)//2,))
            positive_idx = [torch.where(labels == labels[i])[0][torch.randint(0, len(torch.where(labels == labels[i])[0]), (1,))] for i in anchor_idx]
            negative_idx = [torch.where(labels != labels[i])[0][torch.randint(0, len(torch.where(labels != labels[i])[0]), (1,))] for i in anchor_idx]
            anchor = features[anchor_idx]
            positive = features[positive_idx]
            negative = features[negative_idx]
            # Combined loss (CE weighted 0.5).
            triplet_loss = triplet_criterion(anchor.squeeze(), positive.squeeze(), negative.squeeze())
            ce_loss = ce_criterion(logits, labels)
            loss = triplet_loss + 0.5 * ce_loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
        # Dataset-weighted average training loss for the epoch.
        train_loss = train_loss / len(train_loader.dataset)
        # Validation: cross-entropy only (no triplet mining on eval batches).
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                features, logits = model(images)
                ce_loss = ce_criterion(logits, labels)
                val_loss += ce_loss.item() * images.size(0)
        val_loss = val_loss / len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    # Persist weights and the name -> label mapping.
    torch.save(model.state_dict(), 'face_feature_extractor.pth')
    np.save('label_map.npy', label_map)
    return model, label_map
def export_to_onnx(model, input_size=(1, 3, 150, 150), output_path='face_model.onnx'):
    """Export the trained model to ONNX for consumption by the Java client.

    Parameters:
        model: the trained ``FaceFeatureExtractor``.
            NOTE(review): assumes the model already lives on the device
            selected below (true when it comes from ``train_model``) —
            confirm for other callers.
        input_size: shape of the dummy tracing input (batch, C, H, W).
        output_path: destination .onnx file path.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dummy_input = torch.randn(*input_size).to(device)
    # Evaluation mode so dropout/batch-norm behave deterministically in the
    # exported graph.
    model.eval()
    # Export with a dynamic batch dimension on both input and output.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    print(f"模型已导出为ONNX格式: {output_path}")


if __name__ == "__main__":
    # Expected dataset layout: data_dir/<person_name>/<image>.jpg
    data_dir = "path/to/your/face_dataset"
    model, label_map = train_model(data_dir, epochs=50)
    export_to_onnx(model)
第二部分:Java 实现人脸识别
Java 实现人脸识别需要使用 OpenCV 进行人脸检测和预处理,通过 ONNX Runtime(配合 ND4J 处理张量)加载 ONNX 模型进行特征提取,然后通过计算特征向量距离实现人脸识别。
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.objdetect.CascadeClassifier;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.onnxruntime.OrtEnvironment;
import org.nd4j.onnxruntime.OrtSession;import java.io.*;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.stream.Collectors;public class FaceRecognition {// 加载OpenCV库static {System.loadLibrary(Core.NATIVE_LIBRARY_NAME);}private final CascadeClassifier faceDetector;private final OrtEnvironment env;private final OrtSession session;private final Map<String, float[]> faceFeatures;private final float THRESHOLD = 0.6f; // 相似度阈值,根据实际情况调整public FaceRecognition(String modelPath, String featuresPath) throws Exception {// 初始化人脸检测器faceDetector = new CascadeClassifier("haarcascade_frontalface_default.xml");// 初始化ONNX模型env = OrtEnvironment.getEnvironment();session = env.createSession(modelPath);// 加载已知人脸特征faceFeatures = loadFaceFeatures(featuresPath);}// 从文件加载人脸特征private Map<String, float[]> loadFaceFeatures(String path) throws IOException {Map<String, float[]> features = new HashMap<>();File file = new File(path);if (!file.exists()) {return features;}try (BufferedReader br = new BufferedReader(new FileReader(file))) {String line;while ((line = br.readLine()) != null) {String[] parts = line.split("\t");if (parts.length == 2) {String name = parts[0];float[] feature = Arrays.stream(parts[1].split(",")).mapToFloat(Float::parseFloat).toArray();features.put(name, feature);}}}return features;}// 保存人脸特征到文件public void saveFaceFeatures(String path) throws IOException {try (BufferedWriter bw = new BufferedWriter(new FileWriter(path))) {for (Map.Entry<String, float[]> entry : faceFeatures.entrySet()) {bw.write(entry.getKey() + "\t");bw.write(Arrays.stream(entry.getValue()).mapToObj(String::valueOf).collect(Collectors.joining(",")));bw.newLine();}}}// 检测并提取人脸public Mat detectAndExtractFace(Mat image) {MatOfRect faceDetections = new MatOfRect();faceDetector.detectMultiScale(image, faceDetections);// 如果检测到人脸,返回第一个人脸区域if (faceDetections.empty()) {return null;}Rect faceRect = faceDetections.toArray()[0];Mat face = new Mat(image, faceRect);// 预处理:调整大小、转为RGB、归一化Imgproc.resize(face, face, new Size(150, 150));Imgproc.cvtColor(face, face, Imgproc.COLOR_BGR2RGB);return face;}// 
提取人脸特征public float[] extractFeatures(Mat face) throws Exception {if (face == null) {return null;}// 转换为模型输入格式 (1, 3, 150, 150)float[] data = new float[3 * 150 * 150];int idx = 0;for (int c = 0; c < 3; c++) {for (int h = 0; h < 150; h++) {for (int w = 0; w < 150; w++) {double[] pixel = face.get(h, w);// 归一化:ImageNet的均值和标准差data[idx++] = (float)((pixel[c] / 255.0 - 0.485) / 0.229);}}}// 创建输入张量INDArray input = Nd4j.create(data, new int[]{1, 3, 150, 150});// 运行模型OrtSession.Result result = session.run(Collections.singletonMap("input", env.createTensor(input.data().asFloat())));// 提取输出特征float[] features = result.get(0).getValue().asFloat();// 特征归一化INDArray featureArray = Nd4j.create(features);featureArray = Transforms.unitVec(featureArray);return featureArray.data().asFloat();}// 计算特征相似度(余弦相似度)private float calculateSimilarity(float[] feature1, float[] feature2) {float dotProduct = 0.0f;float norm1 = 0.0f;float norm2 = 0.0f;for (int i = 0; i < feature1.length; i++) {dotProduct += feature1[i] * feature2[i];norm1 += feature1[i] * feature1[i];norm2 += feature2[i] * feature2[i];}return dotProduct / (float)(Math.sqrt(norm1) * Math.sqrt(norm2));}// 识别人脸public String recognizeFace(float[] features) {if (features == null || faceFeatures.isEmpty()) {return "Unknown";}String bestMatch = "Unknown";float maxSimilarity = 0.0f;for (Map.Entry<String, float[]> entry : faceFeatures.entrySet()) {float similarity = calculateSimilarity(features, entry.getValue());if (similarity > maxSimilarity && similarity > THRESHOLD) {maxSimilarity = similarity;bestMatch = entry.getKey();}}return bestMatch;}// 注册新人脸public void registerFace(String name, float[] features) {if (name != null && features != null) {faceFeatures.put(name, features);}}// 主方法示例public static void main(String[] args) {try {// 初始化人脸识别器FaceRecognition recognizer = new FaceRecognition("face_model.onnx", "face_features.txt");// 处理测试图片Mat image = Imgcodecs.imread("test_face.jpg");Mat face = recognizer.detectAndExtractFace(image);if (face 
!= null) {float[] features = recognizer.extractFeatures(face);// 识别人脸String result = recognizer.recognizeFace(features);System.out.println("识别结果: " + result);// 如果是未知人脸,可以选择注册if ("Unknown".equals(result)) {// recognizer.registerFace("new_person", features);// recognizer.saveFaceFeatures("face_features.txt");}} else {System.out.println("未检测到人脸");}} catch (Exception e) {e.printStackTrace();}}
}
实现说明
整体流程:
- 先用 PyTorch 训练人脸特征提取模型,将人脸图像转换为固定维度的特征向量
- 将训练好的模型导出为 ONNX 格式,方便跨平台使用
- 在 Java 中使用 OpenCV 检测人脸,通过 ONNX Runtime(配合 ND4J)加载 ONNX 模型提取特征
- 通过计算特征向量之间的余弦相似度来识别人脸
环境依赖:
- Python: PyTorch, torchvision, Pillow, numpy, scikit-learn
- Java: OpenCV, Deeplearning4j, ND4J, ONNX Runtime
使用方法:
- 准备人脸数据集,按人名分类存放
- 运行 PyTorch 代码训练模型并导出 ONNX 格式
- 在 Java 项目中添加依赖,配置 OpenCV
- 使用 Java 代码进行人脸检测、特征提取和识别
注意事项:
- 人脸数据集质量对识别效果影响很大,建议每个人至少提供 5-10 张不同角度和光照的照片
- 阈值 THRESHOLD 需要根据实际测试结果调整
- 生产环境中需要优化三元组选择策略,提高模型性能
- Java 代码中需要正确配置 OpenCV 库文件路径
这个方案实现了从模型训练到实际应用的完整流程,可以根据具体需求进行优化和扩展。