当前位置: 首页 > news >正文

6.2 实现文档加载和切分和简易向量数据库的功能

文章目录

  • 背景
  • 资料的加载和切分
    • 代码测试
    • 结果
  • 向量数据库
    • 定义基类
    • 准备数据
    • 编写本地向量数据库
      • 代码
      • 测试存储是否成功
      • 结果
      • 测试能否根据query查找相似度高的资料
      • 结果
  • 完整代码

背景

之前我们已经完成了Embedding模块的工作,也就是说我们现在已经可以做到将资料转化为embedding,从而可以与用户输入的提问进行相似度计算了,但这仅仅只是一小步的工作,目前我们还缺少

  1. 资料的加载和切分的工具
  2. 通过向量数据库,将向量后的文档存储下来

资料的加载和切分

特别感谢知乎用户“不要葱姜蒜”的博文,本项目也有参考他的项目的地方

动手实现一个最小RAG——TinyRAG ——不要葱姜蒜

这里我用了他的文档加载和切分工具,不过稍微修改了一下,原代码在获取chunk的时候是会去掉所有空格的,我在这里把空格保留了

代码我就不解释了,并不复杂

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   utils.py
@Time    :   2024/02/11 09:52:26
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   None
'''
import os
from typing import Dict, List, Optional, Tuple, Unionimport PyPDF2
import markdown
import html2text
import json
from tqdm import tqdm
import tiktoken
from bs4 import BeautifulSoup
import re
from TinyRAG import base_tiktoken_dir
from TinyRAG.Embedding.BGE_base_zh import BGEBaseZHos.environ["TIKTOKEN_CACHE_DIR"] = base_tiktoken_dir
enc = tiktoken.get_encoding("cl100k_base") # 用于计算文本长度的编码器class ReadFiles:"""class to read files"""def __init__(self, path: str) -> None:self._path = pathself.file_list = self.get_files()def get_files(self):# args:dir_path,目标文件夹路径file_list = []for filepath, dirnames, filenames in os.walk(self._path):# os.walk 函数将递归遍历指定文件夹for filename in filenames:# 通过后缀名判断文件类型是否满足要求if filename.endswith(".md"):# 如果满足要求,将其绝对路径加入到结果列表file_list.append(os.path.join(filepath, filename))elif filename.endswith(".txt"):file_list.append(os.path.join(filepath, filename))elif filename.endswith(".pdf"):file_list.append(os.path.join(filepath, filename))return file_listdef get_content(self, max_token_len: int = 600, cover_content: int = 150):docs = []# 读取文件内容for file in self.file_list:content = self.read_file_content(file)chunk_content = self.get_chunk(content, max_token_len=max_token_len, cover_content=cover_content)docs.extend(chunk_content)return docs@classmethoddef get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):chunk_text = []curr_len = 0curr_chunk = ''token_len = max_token_len - cover_contentlines = text.splitlines()  # 假设以换行符分割文本为行for line in lines:# line = line.replace(' ', '') # 注释掉这行 保留空格line_len = len(enc.encode(line))# print(line_len, line)if line_len > max_token_len:# 如果单行长度就超过限制,则将其分割成多个块num_chunks = (line_len + token_len - 1) // token_lenfor i in range(num_chunks):start = i * token_lenend = start + token_len# 避免跨单词分割while not line[start:end].rstrip().isspace():start += 1end += 1if start >= line_len:breakcurr_chunk = curr_chunk[-cover_content:] + line[start:end]chunk_text.append(curr_chunk)# 处理最后一个块start = (num_chunks - 1) * token_lencurr_chunk = curr_chunk[-cover_content:] + line[start:end]chunk_text.append(curr_chunk)if curr_len + line_len <= token_len:curr_chunk += linecurr_chunk += '\n'curr_len += line_lencurr_len += 1else:chunk_text.append(curr_chunk)curr_chunk = curr_chunk[-cover_content:] + linecurr_len = line_len + cover_contentif curr_chunk:chunk_text.append(curr_chunk)return chunk_text@classmethoddef read_file_content(cls, file_path: str):# 根据文件扩展名选择读取方法if file_path.endswith('.pdf'):return cls.read_pdf(file_path)elif file_path.endswith('.md'):return cls.read_markdown(file_path)elif file_path.endswith('.txt'):return cls.read_text(file_path)else:raise ValueError("Unsupported file type")@classmethoddef read_pdf(cls, file_path: str):# 读取PDF文件with open(file_path, 'rb') as file:reader = PyPDF2.PdfReader(file)text = ""for page_num in range(len(reader.pages)):text += reader.pages[page_num].extract_text()return text@classmethoddef read_markdown(cls, file_path: str):# 读取Markdown文件with open(file_path, 'r', encoding='utf-8') as file:md_text = file.read()html_text = markdown.markdown(md_text)# 使用BeautifulSoup从HTML中提取纯文本soup = BeautifulSoup(html_text, 'html.parser')plain_text = soup.get_text()# 使用正则表达式移除网址链接text = re.sub(r'http\S+', '', plain_text)return text@classmethoddef read_text(cls, file_path: str):# 读取文本文件with open(file_path, 'r', encoding='utf-8') as file:return file.read()class Documents:"""获取已分好类的json格式文档"""def __init__(self, path: str = '') -> None:self.path = pathdef get_content(self):with open(self.path, mode='r', encoding='utf-8') as f:content = json.load(f)return content

代码测试

这里我就直接用一段普通的文本来进行测试了,我随便挖了点文字

if __name__ == '__main__':s = """这是一个测试文本,用于检查文本分割功能。这段文本应该被分割成多个块,以便适应指定的最大长度。每个块的长度不应超过600个token,且每个块的前150个字符应被覆盖。这段文本的长度应该足够长,以便测试分割功能的有效性。随机生成一段超过600个字符的长文本以确保分割功能能够正确处理长文本。用户刚刚问我是否会保留空格,我得先回顾一下之前的对话内容,用户之前关注的是不同模型对文本的处理方式以及空格和符号的处理。在这个背景下,我被问到是否保留空格,这表明用户想要了解我自身的文本处理机制,可能对我处理文本的具体细节感兴趣或者想进一步比较不同模型的处理方式。
我清楚地知道,我作为一个人工智能助手,我的文本处理机制是按照既定的规则和模型架构设计的,这些设计就是在服务用户的过程中经过大量数据训练和优化而来的。对于用户的这个问题,我的回答要基于我自身的模型架构和设计目标,同时也要符合实际情况,不能夸大或者虚构我的功能。
在分析了用户的问题后,我决定从我自身的设计和功能出发,直白且准确地回答用户的问题。我思考着,既然我的设计决定了我在处理用户输入时会保留空格,那我就直接告诉用户这个事实,这样可以满足用户对我的文本处理机制的好奇,也能体现出我的回答是基于实际情况的,而不是无端猜测或者随意编造的。
我准备直接回答用户的问题,把我的处理方式说明清楚,这样既能满足用户的好奇心,也能体现出我的诚实和可靠。This is a test text to check the text segmentation function.This text should be split into multiple chunks to fit the specified maximum length.Each chunk should not exceed 600 tokens, and the first 150 characters of each chunk should be covered.This text should be long enough to test the effectiveness of the segmentation function.Randomly generate a long text that exceeds 600 charactersto ensure that the segmentation function can correctly handle long texts.The user just asked me if I would keep spaces, and I need to review the previous conversation content first. The user was previously concerned about how different models handle text, including spaces and punctuation. In this context, I was asked whether I retain spaces, indicating that the user is interested in understanding my own text processing mechanism, possibly to compare it with other models.
I clearly know that as an AI assistant, my text processing mechanism is designed according to established rules and model architecture, which have been trained and optimized through a large amount of data in the process of serving users. Regarding the user's question, my answer should be based on my own model architecture and design goals, while also conforming to the actual situation, without exaggerating or fabricating my capabilities.
I analyzed the user's question and decided to answer it directly and accurately based on my own design and functionality. I thought, since my design determines that I will retain spaces when processing user input, I would just tell the user this fact, which can satisfy the user's curiosity about my text processing mechanism and also reflect that my answer is based on actual conditions, rather than unfounded speculation or random fabrication."""res = ReadFiles.get_chunk(s, max_token_len=600, cover_content=150)embedding_model = BGEBaseZH()embeddings = embedding_model.get_embedding(res)print(embeddings.shape)print("All chunks processed.")

结果

可以看到我们的工具把这段文字切分成了3段
在这里插入图片描述

向量数据库

定义基类

"""
-*- coding: UTF-8 -*-
@Author  :Leezed
@Date    :2025/6/30 23:21 
"""
from typing import Listclass VectorStore:def __init__(self) -> None:"""初始化向量存储类"""passdef persist(self, file_path):# 数据库持久化,本地保存raise NotImplementedError("This method should be overridden by subclasses.")def load_vector(self):# 从本地加载数据库raise NotImplementedError("This method should be overridden by subclasses.")def query(self, query: str,  k: int = 1) -> List[str]:# 根据问题检索相关的文档片段raise NotImplementedError("This method should be overridden by subclasses.")

简单定义一个基类,相当于抽象类的功能,但是不去具体实现,主要目的是指定我们的向量数据最基本的需要实现的功能

  1. persist(self, file_path) 的功能是将资料保存下来,注意,这里我的意思是把资料保存成一种方便数据库读取和修改的形式,避免老是去读取本地的各种格式的文档
  2. def load_vector(self): 的功能是读取之前持久化保存下来的数据
  3. query(self, query: str, k: int = 1) -> List[str]: 这个就更简单了,这个就是实现根据输入的问题,返回与问题最匹配的前k个资料的函数了

准备数据

在这里插入图片描述
这里我图省事就没有自己去找资料了,用了别的博主的,具体的连接在上面发了,当然如果不愿意,去找一份合适的资料也可以自己整理md格式的文件啥的

编写本地向量数据库

这里我叫他本地向量数据库的原因是后续我还想实现一个联网爬虫找资料,然后在实现数据库的功能的,所以为了区分我这里叫他本地向量数据库。

这个数据库的实现思路是很简单的,没有用什么花里胡哨的功能,毕竟只是一个简易的项目,后续还可以再精进

  1. 当用户调用persist函数是,根据用户传进来的文件路径,获取该路径下的所有资料,进行向量化后存储。
  2. 存储方式采用json文件的形式,所以我说他很简陋,这个形式实现的功能有:
    1. 存储所有分块后的资料的原文,向量作为value
    2. 以原文的md5最为key
    3. 形成一个键值对的形式,方便批量取出原来的向量进行相似度计算
    4. 同时也能方便对于相同的资料,避免重复的存进json
  3. 当用户调用query是计算相似度

代码

import osfrom TinyRAG.VectorBase.VectorBase import VectorStore
from TinyRAG import base_local_vector_base_dir, base_data_dir
from TinyRAG.utils import ReadFiles
import json
from TinyRAG.Embedding.BaseEmbedding import BaseEmbedding
from TinyRAG.Embedding.BGE_base_zh import BGEBaseZH
import hashlib
import numpy as np
from typing import Listclass LocalVectorBase(VectorStore):def __init__(self, embedding_model: BaseEmbedding, path=base_local_vector_base_dir, json_file="vector.json"):"""初始化本地向量存储类:param path: 存储路径"""super().__init__()self.path = path# 检查当前目录是否存在if not self.path.endswith('/'):self.path += '/'if not os.path.exists(self.path):os.makedirs(self.path)print(f"Directory {self.path} created.")# 判断目录下是否存在向量文件self.vector_file = os.path.join(self.path, json_file)if not os.path.exists(self.vector_file):with open(self.vector_file, 'w', encoding='utf-8') as f:json.dump({}, f, ensure_ascii=False, indent=4)print(f"Vector file {self.vector_file} created.")# 加载向量文件self.vectors = self.load_vector()# 嵌入模型self.embedding_model = embedding_modeldef persist(self, file_path):"""数据库持久化,本地保存"""# 这里实现将文件转成向量存储到本地files = ReadFiles(file_path)docs = files.get_content()vectors = {}embeddings = self.embedding_model.get_embedding(docs)for i, doc in enumerate(docs):vector = {"text": doc,"embedding": embeddings[i].tolist()  # 将numpy数组转换为列表以便于JSON序列化}# 根据doc的内容一个md5哈希值作为唯一标识符doc_hash = hashlib.md5(doc.encode('utf-8')).hexdigest()vectors[doc_hash] = vector# 检查目前的向量文件中是否已经存在相同的向量existing_vectors = self.load_vector()for key, value in vectors.items():if key not in existing_vectors:existing_vectors[key] = value# 保存向量到文件with open(self.vector_file, 'w', encoding='utf-8') as f:json.dump(existing_vectors, f, ensure_ascii=False, indent=4)self.vectors = existing_vectorsdef load_vector(self):"""从本地加载数据库"""with open(self.vector_file, 'r', encoding='utf-8') as f:vectors = json.load(f)return vectorsdef query(self, query: str, k: int = 1) -> List[str]:"""根据问题检索相关的文档片段:param query: 查询字符串:param EmbeddingModel: 嵌入模型:param k: 返回的结果数量:return: 文档片段列表"""query_embedding = self.embedding_model.get_embedding(query)# 取出self.vectors中的所有向量all_embeddings = np.array([vector['embedding'] for vector in self.vectors.values()])# 计算查询向量与所有向量的相似度similarities = np.dot(all_embeddings, query_embedding)# 获取相似度最高的k个索引top_k_indices = np.argsort(similarities)[-k:][::-1]# 返回相似度最高的k个文档片段results = []docs = list(self.vectors.values())for index in top_k_indices:results.append(docs[index]['text'])return results

测试存储是否成功

if __name__ == '__main__':embedding_model = BGEBaseZH()local_vector_base = LocalVectorBase(embedding_model)local_vector_base.persist(base_data_dir)

结果

在这里插入图片描述

测试能否根据query查找相似度高的资料

if __name__ == '__main__':embedding_model = BGEBaseZH()local_vector_base = LocalVectorBase(embedding_model)# local_vector_base.persist(base_data_dir)query  = "请你讲讲git push的用法"res = local_vector_base.query(query,k = 3)for i, r in enumerate(res):print(f"Result {i + 1}: {r}")# print(res)

结果

能很成功的输出相关的资料
在这里插入图片描述

完整代码

完整的代码可以去我的github拿

https://github.com/Leezed525/pytorch_toy

http://www.dtcms.com/a/264279.html

相关文章:

  • 【在 FastAdmin 中取消特定字段的搜索功能】
  • Conda 虚拟环境克隆与 PyCharm 配置教程
  • 高阶数据结构------并查集
  • uniapp+vue3 中使用echart 以及echart文件过大需要分包的记录
  • 吸烟行为检测数据集介绍-2,108张图片 公共场所禁烟监控 健康行为研究
  • SpringCloud系列(45)--SpringCloud Bus简介
  • UE5 - 制作《塞尔达传说》中林克的技能 - 18 - 磁力抓取器
  • 强化学习【chapter0】-学习路线图
  • Java Selenium反爬虫技术方案
  • 07 Springboot+netty+mqtt服务端实现【重构】
  • 数据结构之带头双向循环链表
  • 苍穹外卖系列问题之Day11_05营业额统计代码开发2 StringUtils.join(dateList,“,“)报错
  • Cross-modal Information Flow in Multimodal Large Language Models
  • 【1.6 漫画数据库设计实战 - 从零开始设计高性能数据库】
  • 2025年主流大厂Java后端面试题主题深度解析
  • 推客系统小程序终极指南:从0到1构建自动裂变增长引擎,实现业绩10倍增长!
  • 快速手搓一个MCP服务指南(九): FastMCP 服务器组合技术:构建模块化AI应用的终极方案
  • 【大模型学习 | BLIP2原理】
  • 「Java流程控制」for循环结构
  • langchain从入门到精通(三十二)——RAG优化策略(八)自查询检索器实现动态数据过滤
  • 腾讯 iOA 零信任产品:安全远程访问的革新者
  • Redis-渐进式遍历
  • Java后端调用外部接口标准流程详解
  • python+uniapp基于微信小程序的PS社区系统
  • 使用D435i运行ORB-SLAM3时,纯视觉模式与视觉-惯性模式的位姿矩阵定义问题探讨
  • 基于SpringBoot + HTML 的网上书店系统
  • 转录组分析流程(六):列线图
  • Kafka 生产者和消费者高级用法
  • c++学习(八、函数指针和线程)
  • EasyExcel实现Excel复杂格式导出:合并单元格与样式设置实战