当前位置: 首页 > news > 正文

使用 SQLite 数据库筛选人脸信息

# 主要用于筛选人脸信息(比如:编号 0 的人的文件夹里有很多张属于 0 的人脸照片,但同时又掺杂了一些非常模糊的照片或其他人的照片。通过这个方法可以把掺杂进来的模糊照片和其他人的人脸排序到最后,这样清理时就不需要到处寻找那些不合格的照片)

import os
import shutil

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from PIL import Image
import torch
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
import sqlite3
import threading

# 1. Load the pretrained face-embedding network (facenet_pytorch InceptionResnetV1
#    with 'vggface2' weights) in eval mode, on the GPU when CUDA is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = InceptionResnetV1(pretrained='vggface2').to(device).eval()

# 2. Image preprocessing applied before every forward pass.
transform = transforms.Compose(
    [
        transforms.Resize((160, 160)),  # input size used for FaceNet
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 rescales [0, 1] to [-1, 1]
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


# 3. Extract the feature vector of a single image
def extract_feature(image_path):
    """Return the embedding of one image as a flat NumPy vector.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and passed through ``model`` on ``device`` with gradient
    tracking disabled.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        1-D ``numpy.ndarray``: the model output for this image, flattened from
        the single-item batch.

    Raises:
        OSError: if Pillow cannot open or identify the file
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    # Open in a context manager so the file handle is released deterministically;
    # Pillow may keep the file open after decoding (e.g. multi-frame GIFs).
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')  # convert() loads pixels into a new image
    batch = transform(rgb).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        feature = model(batch).cpu().numpy().flatten()
    return feature


# 4. Create the SQLite database
def create_database(db_path):
    """Create the database file and the ``features`` table if they do not exist.

    The table holds one row per (person, image) pair with the feature vector
    stored as a BLOB; ``(person_id, image_path)`` is the primary key.

    Args:
        db_path: Path of the SQLite database file. A missing parent directory
            is created first, because ``sqlite3.connect`` cannot create
            directories and would fail with "unable to open database file".
    """
    parent = os.path.dirname(db_path)
    if parent:  # empty for bare file names and ":memory:"
        os.makedirs(parent, exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS features (
                person_id TEXT,
                image_path TEXT,
                feature_vector BLOB,
                PRIMARY KEY (person_id, image_path)
            )
        ''')
        conn.commit()
    finally:
        # sqlite3's connection context manager only commits/rolls back, it
        # does not close, so close explicitly on every path.
        conn.close()


# 5. Save a feature vector to the database
def save_feature_to_db(db_path, person_id, image_path, feature):
    """Store one image's feature vector unless that image is already stored.

    Rows that already exist for ``(person_id, image_path)`` are left untouched
    and a "skipping" message is printed, so re-running the script after an
    interruption does not fail on the primary key.

    Args:
        db_path: Path of the SQLite database holding the ``features`` table.
        person_id: Identifier of the person (the folder name).
        image_path: Path of the image the feature was extracted from.
        feature: Array-like feature vector. It is stored as float32 bytes,
            the dtype this file's readers decode ``feature_vector`` with.
    """
    # Normalise the dtype so the BLOB always round-trips via frombuffer(float32).
    feature_blob = np.asarray(feature, dtype=np.float32).tobytes()

    conn = sqlite3.connect(db_path)
    try:
        # One atomic statement instead of SELECT-then-INSERT; the primary key
        # makes SQLite ignore duplicates, which leaves rowcount at 0.
        cursor = conn.execute('''
            INSERT OR IGNORE INTO features (person_id, image_path, feature_vector)
            VALUES (?, ?, ?)
        ''', (person_id, image_path, feature_blob))
        inserted = cursor.rowcount
        conn.commit()
    finally:
        conn.close()  # closed on the duplicate path as well

    if inserted == 0:
        print(f"Feature for {person_id} - {image_path} already exists,  skipping")


# 6. Process one person's folder: extract features and save them to the database
def process_folder(db_path, folder_path, person_id):
    """Extract and store the feature vector of every image in ``folder_path``.

    Files without a supported image extension are ignored. A failure on one
    image (e.g. a corrupt file) is reported with the image path and skipped,
    so a single bad image does not abort the whole run.

    Args:
        db_path: Path of the SQLite database holding the ``features`` table.
        folder_path: Directory containing one person's images.
        person_id: Identifier stored with every row (the folder name).
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    for image_name in os.listdir(folder_path):
        if not image_name.lower().endswith(image_exts):
            continue  # not an image file
        image_path = os.path.join(folder_path, image_name)
        try:
            feature = extract_feature(image_path)
            save_feature_to_db(db_path, person_id, image_path, feature)
        except Exception as e:  # deliberate best-effort: report and continue
            # Include the path: the point of this tool is finding bad images.
            print(f"Failed to process {image_path}: {e}")


# 7. Get a person's mean feature vector from the database
def get_avg_feature(db_path, person_id):
    """Return the element-wise mean of all feature vectors stored for a person.

    Args:
        db_path: Path of the SQLite database holding the ``features`` table.
        person_id: Identifier of the person (the folder name).

    Returns:
        ``numpy`` array: the mean of the stored vectors, decoded as float32.
        If the person has no rows, this is NumPy's mean of an empty list
        (a NaN scalar, emitted together with a ``RuntimeWarning``).
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            'SELECT feature_vector FROM features WHERE person_id = ?',
            (person_id,),
        ).fetchall()
    finally:
        conn.close()  # released even if the query raises

    # Decode every BLOB as float32, matching how the vectors are written.
    features = [np.frombuffer(blob, dtype=np.float32) for (blob,) in rows]
    return np.mean(features, axis=0)


# 8. Sort images by Euclidean distance and copy them under sequential names
def sort_and_rename_images(db_path, out_path, person_id):
    """Copy a person's images into ``out_path/person_id`` ordered by distance.

    Every stored image is ranked by the Euclidean distance between its feature
    vector and the person's mean vector (closest first) and copied as
    ``0000<ext>``, ``0001<ext>``, ... so that outliers, such as blurry faces or
    other people's faces, end up last. Source files are not modified.

    Args:
        db_path: Path of the SQLite database holding the ``features`` table.
        out_path: Root output directory; one sub-directory per person is used.
        person_id: Identifier of the person (the folder name).
    """
    avg_feature = get_avg_feature(db_path, person_id)

    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            'SELECT image_path, feature_vector FROM features WHERE person_id = ?',
            (person_id,),
        ).fetchall()
    finally:
        conn.close()

    # Distance of each image's feature to the person's mean feature.
    scored = []
    for image_path, feature_blob in rows:
        feature = np.frombuffer(feature_blob, dtype=np.float32)
        distance = euclidean_distances([feature], [avg_feature])[0][0]
        scored.append((image_path, distance))
    scored.sort(key=lambda item: item[1])  # stable: ties keep query order

    # The database persists across runs, so it may reference images that have
    # since been deleted; skip those instead of crashing in shutil.copy.
    existing = []
    for image_path, _ in scored:
        if os.path.isfile(image_path):
            existing.append(image_path)
        else:
            print(f"Missing source image, skipping: {image_path}")

    if not existing:
        return
    person_dir = os.path.join(out_path, person_id)
    os.makedirs(person_dir, exist_ok=True)

    for idx, image_path in enumerate(existing):
        # Keep the source extension: shutil.copy copies bytes unchanged, so a
        # PNG/BMP/GIF must not be written out with a ".jpg" name.
        ext = os.path.splitext(image_path)[1].lower() or '.jpg'
        shutil.copy(image_path, os.path.join(person_dir, f"{idx:04d}{ext}"))


# 9. Entry point
def main(db_path=r'D:\FS_project2\Feature_extraction\sql_database\features.db2',
         base_path=r'D:\FS_project2\Feature_extraction\peopel_crop',
         out_path=r'D:\FS_project2\Feature_extraction\out'):
    """Build the feature database, then write each person's images sorted by distance.

    The defaults are the original hard-coded paths, so ``main()`` behaves as
    before; pass other paths to run on a different dataset.

    Args:
        db_path: SQLite database that caches feature vectors between runs.
        base_path: Directory with one sub-directory of face images per person;
            each sub-directory name is used as the person id.
        out_path: Root directory the sorted copies are written to.
    """
    create_database(db_path)

    # Every sub-directory of base_path is one person.
    person_ids = [
        name for name in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, name))
    ]

    # Step 1: extract features for every person and store them in the database.
    for person_id in person_ids:
        process_folder(db_path, os.path.join(base_path, person_id), person_id)
        print(f"Processed folder: {person_id}")

    # Step 2: copy each person's images in order of distance to their mean feature.
    for person_id in person_ids:
        sort_and_rename_images(db_path, out_path, person_id)
        print(f"Sorted and renamed folder: {person_id}")


if __name__ == "__main__":
    main()

相关文章:

  • Tomcat添加到Windows系统服务中,服务名称带空格
  • FreeRTOS低功耗总结
  • 【进阶】JVM篇
  • Kernel之Tcpdump和Netfilter
  • CVE-2022-41352 漏洞分析与利用
  • 基于SpringBoot的在线交通服务管理系统
  • 持有无人机飞手执照,会组装调试维修入伍参军技术详解
  • 104、二叉树的最大深度
  • 同步buck型降压DCDC芯片外围电路详解
  • 一款利器提升 StarRocks 表结构设计效率
  • 图片旋转方向分类:从零开始构建深度学习模型
  • 10、《Thymeleaf模板引擎:动态页面开发全攻略》
  • 如何有效防止TikTok多店铺入驻时IP关联问题?
  • [鸿蒙笔记-基础篇_自定义构建函数及自定义公共样式]
  • 网络安全技术复习总结
  • 【Python深入浅出㊷】探索Python3中scikit-learn的无限可能
  • QtWebEngine::initialize()
  • MySQL查看存储过程和存储函数
  • 2025 AutoCable 中国汽车线束线缆及连接技术创新峰会启动报名!
  • vscode本地和远程对应分支没有同步提交数量
  • 武汉旅游体育集团有限公司原党委书记、董事长董志向被查
  • 警方通报男子地铁上拍视频致乘客恐慌受伤:列车运行一度延误,已行拘
  • 长江画派创始人之一、美术家鲁慕迅逝世,享年98岁
  • 讲座预告|全球贸易不确定情况下企业创新生态构建
  • 市自规局公告收回新校区建设用地,宿迁学院:需变更建设主体
  • 胳膊一抬就疼,炒菜都成问题?警惕这种“炎症”找上门