
Annotations for the functions used in RAG: LangChain core functions

LangChain Core Functions

from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.prompts import ChatPromptTemplate

TextLoader

Source location: langchain-community/libs/community/langchain_community/document_loaders/__init__.py

from langchain_community.document_loaders.text import (
    TextLoader,
)

# text.py
import logging
from pathlib import Path
from typing import Iterator, Optional, Union

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.document_loaders.helpers import detect_file_encodings

logger = logging.getLogger(__name__)


class TextLoader(BaseLoader):
    """Load text file.

    Args:
        file_path: Path to the file to load.

        encoding: File encoding to use. If `None`, the file will be loaded
        with the default system encoding.

        autodetect_encoding: Whether to try to autodetect the file encoding
            if the specified encoding fails.
    """

    def __init__(
        self,
        file_path: Union[str, Path],
        encoding: Optional[str] = None,
        autodetect_encoding: bool = False,
    ):
        """Initialize with file path."""
        self.file_path = file_path
        self.encoding = encoding
        self.autodetect_encoding = autodetect_encoding

    def lazy_load(self) -> Iterator[Document]:
        """Load from file path."""
        text = ""
        try:
            with open(self.file_path, encoding=self.encoding) as f:
                text = f.read()
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    logger.debug(f"Trying encoding: {encoding.encoding}")
                    try:
                        with open(self.file_path, encoding=encoding.encoding) as f:
                            text = f.read()
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {self.file_path}") from e

        metadata = {"source": str(self.file_path)}
        yield Document(page_content=text, metadata=metadata)
  • TextLoader inherits from BaseLoader (the abstract base class for all LangChain document loaders) and must implement the lazy_load method (the core logic for lazily loading documents). Its docstring states its purpose, loading plain-text files (.txt and similar formats), and explains the three key parameters.
  • file_path: the path of the text file to load (required); accepts a string (e.g. "./data.txt") or a Path object (e.g. Path("./data.txt")).
  • encoding: optional; the file encoding to use (e.g. "utf-8", "gbk"). Defaults to None, in which case the system default encoding is used.
  • autodetect_encoding: boolean, defaults to False. If True, the loader tries to autodetect the file encoding and reload whenever decoding with the specified encoding fails.
  • lazy_load: the abstract method BaseLoader requires. It loads documents lazily (returning an iterator of Document objects), avoiding the memory spike of reading a large file in one go. Each Document it yields contains:
    • page_content: the text that was read (the core data).
    • metadata: the metadata (here only a source field recording the file path, so the document's origin can be traced later).
    • yield produces an iterator (rather than a list built all at once), which is what makes the loading "lazy" (especially useful for large files, keeping memory usage low).
  • The detect_file_encodings helper inspects the file's byte stream (BOM headers, byte patterns characteristic of common charsets) to guess candidate encodings, returned sorted by confidence.
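A minimal usage sketch tying these pieces together (the file name ./data.txt is a placeholder; load() simply collects the lazy_load() iterator into a list):

from pathlib import Path

from langchain_community.document_loaders import TextLoader

# Try UTF-8 first; fall back to encoding autodetection if decoding fails.
loader = TextLoader(Path("./data.txt"), encoding="utf-8", autodetect_encoding=True)
docs = loader.load()  # collects lazy_load() into a list[Document]
print(docs[0].metadata)            # {'source': 'data.txt'}
print(docs[0].page_content[:100])  # first 100 characters of the file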

RecursiveCharacterTextSplitter

Source location: langchain/libs/text-splitters/langchain_text_splitters/__init__.py

from langchain_text_splitters.character import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
)

# character.py
from __future__ import annotations

import re
from typing import Any, Literal, Optional, Union

from langchain_text_splitters.base import Language, TextSplitter


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(
        self,
        separator: str = "\n\n",
        is_separator_regex: bool = False,  # noqa: FBT001,FBT002
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> list[str]:
        """Split into chunks without re-inserting lookaround separators."""
        # 1. Determine split pattern: raw regex or escaped literal
        sep_pattern = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        # 2. Initial split (keep separator if requested)
        splits = _split_text_with_regex(
            text, sep_pattern, keep_separator=self._keep_separator
        )
        # 3. Detect zero-width lookaround so we never re-insert it
        lookaround_prefixes = ("(?=", "(?<!", "(?<=", "(?!")
        is_lookaround = self._is_separator_regex and any(
            self._separator.startswith(p) for p in lookaround_prefixes
        )
        # 4. Decide merge separator:
        #    - if keep_separator or lookaround -> don't re-insert
        #    - else -> re-insert literal separator
        merge_sep = ""
        if not (self._keep_separator or is_lookaround):
            merge_sep = self._separator
        # 5. Merge adjacent splits and return
        return self._merge_splits(splits, merge_sep)


def _split_text_with_regex(
    text: str, separator: str, *, keep_separator: Union[bool, Literal["start", "end"]]
) -> list[str]:
    if separator:  # non-empty separator
        if keep_separator:  # keep the separator in the result
            # The regex group ( ... ) keeps the separator in the split output
            _splits = re.split(f"({separator})", text)
            # Attach the separator to the start or end of each piece,
            # depending on the keep_separator mode
            if keep_separator == "end":
                splits = [
                    _splits[i] + _splits[i + 1] for i in range(0, len(_splits) - 1, 2)
                ]
            else:  # "start" or True
                splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            # Handle odd-length split lists (complete the last piece)
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = (
                [_splits[0], *splits]
                if keep_separator == "start"
                else [*splits, _splits[-1]]
            )
        else:  # don't keep the separator: split directly
            splits = re.split(separator, text)
    else:  # empty separator (fallback): split into single characters
        splits = list(text)
    # Filter out empty-string pieces
    return [s for s in splits if s != ""]


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: Union[bool, Literal["start", "end"]] = True,  # noqa: FBT001,FBT002
        is_separator_regex: bool = False,  # noqa: FBT001,FBT002
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        final_chunks = []  # the final list of size-compliant chunks
        # 1. Pick this level's separator (try from highest to lowest priority)
        separator = separators[-1]  # fallback separator (the last one, e.g. "")
        new_separators = []  # lower-priority separators for the next recursion level
        for i, _s in enumerate(separators):
            # In non-regex mode, escape special characters (e.g. "." -> "\.")
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":  # empty separator (fallback: split per character)
                separator = _s
                break
            # If the current separator occurs in the text, use it at this level
            if re.search(_separator, text):
                separator = _s  # this level's separator
                new_separators = separators[i + 1 :]  # lower-priority ones for recursion
                break
        # 2. Split the text with the chosen separator
        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(
            text, _separator, keep_separator=self._keep_separator
        )
        # 3. Process the pieces: keep compliant ones, recurse into oversized ones
        _good_splits = []  # buffer of compliant pieces at the current level
        # Merge separator: when separators are not kept, re-insert the original
        # separator while merging (avoiding semantic breaks)
        _merge_sep = "" if self._keep_separator else separator
        for s in splits:
            # Piece fits within chunk_size: buffer it
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            # Piece is too long: flush the buffer, then subdivide the piece
            else:
                if _good_splits:  # merge buffered pieces into the result first
                    merged_text = self._merge_splits(_good_splits, _merge_sep)
                    final_chunks.extend(merged_text)
                    _good_splits = []  # clear the buffer
                # No lower-priority separator left: keep the oversized piece as-is
                if not new_separators:
                    final_chunks.append(s)
                # Otherwise recurse into the oversized piece
                else:
                    recursive_splits = self._split_text(s, new_separators)
                    final_chunks.extend(recursive_splits)
        # 4. Flush any remaining compliant pieces
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _merge_sep)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        """Split the input text into smaller chunks based on predefined separators.

        Args:
            text (str): The input text to be split.

        Returns:
            List[str]: A list of text chunks obtained after splitting.
        """
        return self._split_text(text, self._separators)

    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        """Return an instance of this class based on a specific language.

        This method initializes the text splitter with language-specific separators.

        Args:
            language (Language): The language to configure the text splitter for.
            **kwargs (Any): Additional keyword arguments to customize the splitter.

        Returns:
            RecursiveCharacterTextSplitter: An instance of the text splitter configured
            for the specified language.
        """
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, is_separator_regex=True, **kwargs)

    @staticmethod
    def get_separators_for_language(language: Language) -> list[str]:
        """Retrieve a list of separators specific to the given language.

        Args:
            language (Language): The language for which to get the separators.

        Returns:
            List[str]: A list of separators appropriate for the specified language.
        """
        # Each list runs from coarse, language-aware separators (class/function
        # definitions, control flow) down to the generic fallbacks
        # "\n\n", "\n", " ", "".
        if language in (Language.C, Language.CPP):
            return ["\nclass ", "\nvoid ", "\nint ", "\nfloat ", "\ndouble ",
                    "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                    "\n\n", "\n", " ", ""]
        if language == Language.GO:
            return ["\nfunc ", "\nvar ", "\nconst ", "\ntype ",
                    "\nif ", "\nfor ", "\nswitch ", "\ncase ",
                    "\n\n", "\n", " ", ""]
        if language == Language.JAVA:
            return ["\nclass ", "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ",
                    "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                    "\n\n", "\n", " ", ""]
        if language == Language.KOTLIN:
            return ["\nclass ", "\npublic ", "\nprotected ", "\nprivate ",
                    "\ninternal ", "\ncompanion ", "\nfun ", "\nval ", "\nvar ",
                    "\nif ", "\nfor ", "\nwhile ", "\nwhen ", "\ncase ", "\nelse ",
                    "\n\n", "\n", " ", ""]
        if language == Language.JS:
            return ["\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ",
                    "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                    "\ndefault ", "\n\n", "\n", " ", ""]
        if language == Language.TS:
            return ["\nenum ", "\ninterface ", "\nnamespace ", "\ntype ", "\nclass ",
                    "\nfunction ", "\nconst ", "\nlet ", "\nvar ",
                    "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                    "\ndefault ", "\n\n", "\n", " ", ""]
        if language == Language.PHP:
            return ["\nfunction ", "\nclass ", "\nif ", "\nforeach ", "\nwhile ",
                    "\ndo ", "\nswitch ", "\ncase ", "\n\n", "\n", " ", ""]
        if language == Language.PROTO:
            return ["\nmessage ", "\nservice ", "\nenum ", "\noption ", "\nimport ",
                    "\nsyntax ", "\n\n", "\n", " ", ""]
        if language == Language.PYTHON:
            return ["\nclass ", "\ndef ", "\n\tdef ",  # class/function definitions first
                    "\n\n", "\n", " ", ""]
        if language == Language.RST:
            return ["\n=+\n", "\n-+\n", "\n\\*+\n",  # section titles
                    "\n\n.. *\n\n",  # directive markers
                    "\n\n", "\n", " ", ""]
        if language == Language.RUBY:
            return ["\ndef ", "\nclass ", "\nif ", "\nunless ", "\nwhile ", "\nfor ",
                    "\ndo ", "\nbegin ", "\nrescue ", "\n\n", "\n", " ", ""]
        if language == Language.ELIXIR:
            return ["\ndef ", "\ndefp ", "\ndefmodule ", "\ndefprotocol ",
                    "\ndefmacro ", "\ndefmacrop ",
                    "\nif ", "\nunless ", "\nwhile ", "\ncase ", "\ncond ",
                    "\nwith ", "\nfor ", "\ndo ", "\n\n", "\n", " ", ""]
        if language == Language.RUST:
            return ["\nfn ", "\nconst ", "\nlet ",
                    "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ", "\nconst ",
                    "\n\n", "\n", " ", ""]
        if language == Language.SCALA:
            return ["\nclass ", "\nobject ", "\ndef ", "\nval ", "\nvar ",
                    "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ",
                    "\n\n", "\n", " ", ""]
        if language == Language.SWIFT:
            return ["\nfunc ", "\nclass ", "\nstruct ", "\nenum ",
                    "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
                    "\n\n", "\n", " ", ""]
        if language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (starting with level 2)
                "\n#{1,6} ",
                # Note the alternative (underlined) heading syntax is not handled here
                "```\n",  # end of code block
                # Horizontal lines
                "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n",
                "\n\n", "\n", " ", "",
            ]
        if language == Language.LATEX:
            return [
                # First, try to split along LaTeX sections
                "\n\\\\chapter{", "\n\\\\section{", "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}", "\n\\\\begin{itemize}",
                "\n\\\\begin{description}", "\n\\\\begin{list}",
                "\n\\\\begin{quote}", "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}", "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}", "$$", "$",
                " ", "",
            ]
        if language == Language.HTML:
            return ["<body", "<div", "<p", "<br", "<li",
                    "<h1", "<h2", "<h3", "<h4", "<h5", "<h6",
                    "<span", "<table", "<tr", "<td", "<th", "<ul", "<ol",
                    "<header", "<footer", "<nav",
                    "<head", "<style", "<script", "<meta", "<title", ""]
        if language == Language.CSHARP:
            return ["\ninterface ", "\nenum ", "\nimplements ", "\ndelegate ",
                    "\nevent ", "\nclass ", "\nabstract ",
                    "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ", "\nreturn ",
                    "\nif ", "\ncontinue ", "\nfor ", "\nforeach ", "\nwhile ",
                    "\nswitch ", "\nbreak ", "\ncase ", "\nelse ",
                    "\ntry ", "\nthrow ", "\nfinally ", "\ncatch ",
                    "\n\n", "\n", " ", ""]
        if language == Language.SOL:
            return ["\npragma ", "\nusing ",
                    "\ncontract ", "\ninterface ", "\nlibrary ",
                    "\nconstructor ", "\ntype ", "\nfunction ", "\nevent ",
                    "\nmodifier ", "\nerror ", "\nstruct ", "\nenum ",
                    "\nif ", "\nfor ", "\nwhile ", "\ndo while ", "\nassembly ",
                    "\n\n", "\n", " ", ""]
        if language == Language.COBOL:
            return ["\nIDENTIFICATION DIVISION.", "\nENVIRONMENT DIVISION.",
                    "\nDATA DIVISION.", "\nPROCEDURE DIVISION.",
                    "\nWORKING-STORAGE SECTION.", "\nLINKAGE SECTION.",
                    "\nFILE SECTION.", "\nINPUT-OUTPUT SECTION.",
                    "\nOPEN ", "\nCLOSE ", "\nREAD ", "\nWRITE ", "\nIF ", "\nELSE ",
                    "\nMOVE ", "\nPERFORM ", "\nUNTIL ", "\nVARYING ", "\nACCEPT ",
                    "\nDISPLAY ", "\nSTOP RUN.", "\n", " ", ""]
        if language == Language.LUA:
            return ["\nlocal ", "\nfunction ", "\nif ", "\nfor ", "\nwhile ",
                    "\nrepeat ", "\n\n", "\n", " ", ""]
        if language == Language.HASKELL:
            return ["\nmain :: ", "\nmain = ", "\nlet ", "\nin ", "\ndo ", "\nwhere ",
                    "\n:: ", "\n= ", "\ndata ", "\nnewtype ", "\ntype ", "\n:: ",
                    "\nmodule ", "\nimport ", "\nqualified ", "\nimport qualified ",
                    "\nclass ", "\ninstance ", "\ncase ", "\n| ",
                    "\ndata ", "\n= {", "\n, ", "\n\n", "\n", " ", ""]
        if language == Language.POWERSHELL:
            return ["\nfunction ", "\nparam ", "\nif ", "\nforeach ", "\nfor ",
                    "\nwhile ", "\nswitch ", "\nclass ",
                    "\ntry ", "\ncatch ", "\nfinally ", "\n\n", "\n", " ", ""]
        if language == Language.VISUALBASIC6:
            vis = r"(?:Public|Private|Friend|Global|Static)\s+"
            return [
                # Split along definitions
                rf"\n(?!End\s){vis}?Sub\s+",
                rf"\n(?!End\s){vis}?Function\s+",
                rf"\n(?!End\s){vis}?Property\s+(?:Get|Let|Set)\s+",
                rf"\n(?!End\s){vis}?Type\s+",
                rf"\n(?!End\s){vis}?Enum\s+",
                # Split along control flow statements
                r"\n(?!End\s)If\s+", r"\nElseIf\s+", r"\nElse\s+",
                r"\nSelect\s+Case\s+", r"\nCase\s+", r"\nFor\s+", r"\nDo\s+",
                r"\nWhile\s+", r"\nWith\s+",
                r"\n\n", r"\n", " ", "",
            ]
        if language in Language._value2member_map_:
            msg = f"Language {language} is not implemented yet!"
            raise ValueError(msg)
        msg = (
            f"Language {language} is not supported! Please choose from {list(Language)}"
        )
        raise ValueError(msg)
  • Its core job is to recursively split over-long text (entire documents, long conversations) into chunks that fit an LLM's context window, while preserving as much semantic integrity as possible (no arbitrary cuts through sentences or paragraphs):
    • split long text into chunks of a specified length, fitting LLM input limits;
    • split preferentially on natural semantic separators (paragraphs, sentences), reducing semantic breaks (smarter than naive fixed-size character slicing);
    • support custom split rules, accommodating text in different languages and formats (Chinese, English, code).
  • "Recursive" is the core design: separators are tried level by level in priority order; if a chunk is still too long after one level of splitting, the next lower-priority separator is applied, until every chunk fits. The default priority (high to low):
    1. Paragraph separator "\n\n" (split on blank lines first, keeping paragraphs whole);
    2. Line separator "\n" (when a paragraph cannot be split further, fall back to newlines, keeping sentences whole);
    3. Space separator " " (then fall back to spaces, keeping words and phrases whole);
    4. Character separator "" (last resort: split character by character; avoided where possible, but it guarantees compliant chunk sizes).
  • Key parameters:
    • separators: defines the split priority; separators earlier in the list are tried first (blank lines for paragraphs, then newlines for sentences). The default ["\n\n", "\n", " ", ""] suits general text, and it can be customized (e.g. ["\nclass ", "\ndef ", "\n"] for code).
    • keep_separator: controls whether separators are kept in the resulting chunks, with three accepted values (note that, per _split_text_with_regex above, True behaves like "start"):
      • "end": the separator is appended to the end of a chunk (e.g. "paragraph 1\n\n");
      • True / "start": the separator is prepended to the start of a chunk (e.g. "\n\nparagraph 2");
      • False: separators are dropped from the pieces (the merge step then re-inserts the literal separator, avoiding semantic breaks).
    • is_separator_regex: if True, the entries in separators are interpreted as regular expressions (e.g. the Markdown separator "\n#{1,6} " matches headings); if False they are treated as literals (special characters are escaped automatically, so . is not treated as a regex wildcard).
  • Its extension point is per-language separator configuration, via the from_language class method and the get_separators_for_language static method, solving the problem that generic separators do not fit specific languages (code, Markdown); see the sketch after this list.
    • Python code: split first on "\nclass " (class definitions) and "\ndef " (function definitions), avoiding cuts through class/function logic;
    • Markdown: split first on "\n#{1,6} " (headings) and "```\n" (end of a code block), preserving Markdown structure;
    • HTML: split first on tags such as <body, <div and <p, avoiding cuts through HTML nodes;
    • C/C++ code: split first on "\nclass " (classes) and "\nvoid " (functions), following the code's logic.
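A minimal sketch of both ways of configuring the splitter. The texts and the tiny chunk sizes are purely illustrative; by default lengths are counted in characters (length_function=len):

from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

text = "First paragraph.\n\nSecond paragraph, noticeably longer than the first one.\n\nThird."

# 1. Generic text: default separator priority, separators kept at chunk starts
splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],  # the default priority order
    chunk_size=40,
    chunk_overlap=0,
    keep_separator="start",
)
for chunk in splitter.split_text(text):
    print(repr(chunk))

# 2. Code: from_language() installs the Python separators shown above
#    and sets is_separator_regex=True
py_splitter = RecursiveCharacterTextSplitter.from_language(
    Language.PYTHON, chunk_size=60, chunk_overlap=0
)
python_code = (
    "class Greeter:\n"
    "    def greet(self):\n"
    "        return 'hi'\n\n"
    "def main():\n"
    "    print(Greeter().greet())\n"
)
for chunk in py_splitter.split_text(python_code):
    print("----")
    print(chunk)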

HuggingFaceEmbeddings

Source location: langchain/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py

from __future__ import annotations

from typing import Any, Optional

from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field

from langchain_huggingface.utils.import_utils import (
    IMPORT_ERROR,
    is_ipex_available,
    is_optimum_intel_available,
    is_optimum_intel_version,
)

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
_MIN_OPTIMUM_VERSION = "1.22"


class HuggingFaceEmbeddings(BaseModel, Embeddings):
    """HuggingFace sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers`` python package installed.

    Example:
        .. code-block:: python

            from langchain_huggingface import HuggingFaceEmbeddings

            model_name = "sentence-transformers/all-mpnet-base-v2"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': False}
            hf = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
    """

    model_name: str = Field(default=DEFAULT_MODEL_NAME, alias="model")
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the Sentence Transformer model, such as `device`,
    `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`.
    See also the Sentence Transformer documentation:
    https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer"""
    encode_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method for the documents of
    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
    `precision`, `normalize_embeddings`, and more.
    See also the Sentence Transformer documentation:
    https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
    query_encode_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method for the query of
    the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`,
    `precision`, `normalize_embeddings`, and more.
    See also the Sentence Transformer documentation:
    https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode"""
    multi_process: bool = False
    """Run encode() on multiple GPUs."""
    show_progress: bool = False
    """Whether to show a progress bar."""

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # 1. Check that the sentence_transformers package is installed
        try:
            import sentence_transformers
        except ImportError as exc:
            msg = (
                "Could not import sentence_transformers. "
                "Install with `pip install sentence-transformers`."
            )
            raise ImportError(msg) from exc

        # 2. Handle the optional Intel IPEX backend (CPU optimization)
        if self.model_kwargs.get("backend", "torch") == "ipex":
            # Check that the ipex dependencies are installed and recent enough
            if not is_optimum_intel_available() or not is_ipex_available():
                msg = f"Backend: ipex {IMPORT_ERROR.format('optimum[ipex]')}"
                raise ImportError(msg)
            if is_optimum_intel_version("<", _MIN_OPTIMUM_VERSION):
                msg = f"Backend: ipex requires optimum-intel>={_MIN_OPTIMUM_VERSION}."
                raise ImportError(msg)
            # Import the Intel-optimized model class
            from optimum.intel import IPEXSentenceTransformer

            model_cls = IPEXSentenceTransformer
        else:
            # Default: sentence_transformers' SentenceTransformer class
            model_cls = sentence_transformers.SentenceTransformer

        # 3. Instantiate the model (_client is the object that performs encoding)
        self._client = model_cls(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )

    model_config = ConfigDict(
        extra="forbid",
        protected_namespaces=(),
        populate_by_name=True,
    )

    def _embed(
        self, texts: list[str], encode_kwargs: dict[str, Any]
    ) -> list[list[float]]:
        import sentence_transformers

        # 1. Preprocess the text: replace newlines (some models are sensitive to them)
        texts = [x.replace("\n", " ") for x in texts]
        # 2. Encode, in either multi-process or single-process mode
        if self.multi_process:
            # Multi-process encoding (suited to multiple GPUs; work is
            # distributed through a process pool)
            pool = self._client.start_multi_process_pool()
            embeddings = self._client.encode_multi_process(texts, pool)
            sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
        else:
            # Single-process encoding: call the model's encode method directly
            embeddings = self._client.encode(
                texts,
                show_progress_bar=self.show_progress,  # progress bar toggle
                **encode_kwargs,  # encoding options (normalization, batch size, ...)
            )
        # 3. Validate and convert the result into plain Python lists
        if isinstance(embeddings, list):
            raise TypeError(
                "Expected embeddings to be a Tensor or numpy array, got list."
            )
        return embeddings.tolist()  # Tensor / numpy array -> list[list[float]]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._embed(texts, self.encode_kwargs)

    def embed_query(self, text: str) -> list[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embed_kwargs = (
            self.query_encode_kwargs
            if len(self.query_encode_kwargs) > 0
            else self.encode_kwargs
        )
        return self._embed([text], embed_kwargs)[0]

HuggingFaceEmbeddings is the core class (not a function) in LangChain for loading pretrained models from the Hugging Face ecosystem as embedding models. Its job is to convert text (sentences, paragraphs, and so on) into dense fixed-dimensional vectors (embeddings) that capture the text's semantics, providing the foundation for vector stores (such as InMemoryVectorStore) and similarity retrieval (such as matching relevant documents in RAG).

  1. Model loading: at initialization, model_name names a Hugging Face model (e.g. BAAI/bge-small-zh-v1.5); internally the sentence-transformers library loads the model and its tokenizer;
  2. Text preprocessing: input text is tokenized and special tokens (such as [CLS] and [SEP]) are added, converting it into the input format the model accepts;
  3. Vector generation: the preprocessed text is run through the model, a vector is produced from the output layer (typically pooled hidden states, e.g. mean pooling or the [CLS] state, depending on the model), and post-processing such as normalization is applied according to the configuration;
  4. Output: the resulting vector is returned to LangChain's vector-store or retrieval logic.

The initialization parameters of HuggingFaceEmbeddings determine how the model is loaded and how vectors are generated. The core ones:

Parameter | Type | Description
model_name | str | Hugging Face model name or local path (e.g. BAAI/bge-small-zh-v1.5, sentence-transformers/all-MiniLM-L6-v2); defaults to sentence-transformers/all-mpnet-base-v2.
model_kwargs | dict | Optional; arguments passed to model loading, e.g. {'device': 'cpu'} (run device: cpu or cuda).
encode_kwargs | dict | Optional; arguments passed to the model's encode method, e.g. {'normalize_embeddings': True} (L2-normalize vectors for more reliable similarity computation).
  • Inherits BaseModel (the Pydantic base class): brings parameter validation, type checking, and configuration management (e.g. model_config).

  • Inherits Embeddings (the LangChain abstract base class): must implement the two core methods embed_documents (document embedding) and embed_query (query embedding) to plug into LangChain's vector-processing ecosystem.

  • model_name: defaults to sentence-transformers/all-mpnet-base-v2 (a classic English embedding model); accepts a Hugging Face model name or a local path.

  • model_kwargs: arguments passed to model initialization (e.g. device="cuda" to run on GPU, revision="main" to pin a model version, trust_remote_code=True to allow custom model code).

  • encode_kwargs / query_encode_kwargs: encoding options for documents and queries respectively (e.g. normalize_embeddings=True for vector normalization, batch_size=32 for batch size); documents and queries can use different strategies (e.g. stricter normalization for queries).

  • multi_process: whether to enable multi-process encoding (for multi-GPU setups, speeding up large-scale document encoding).

  • _embed is the shared internal method implementing the full preprocess → encode → convert pipeline; both embed_documents and embed_query call it.

Commonly recommended models, one of which the sketch below uses:

  • Chinese: BAAI/bge-small-zh-v1.5 (lightweight, 384-dim), BAAI/bge-large-zh-v1.5 (higher accuracy, 1024-dim);

  • English: sentence-transformers/all-MiniLM-L6-v2 (lightweight, 384-dim), sentence-transformers/all-mpnet-base-v2 (higher accuracy, 768-dim);

  • Multilingual: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 (50+ languages, 384-dim).
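A minimal usage sketch (assumes the sentence-transformers package is installed; the model downloads on first use):

from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},  # L2-normalize for cosine search
)

doc_vectors = embeddings.embed_documents(
    ["LangChain loads documents.", "RAG retrieves relevant context."]
)
query_vector = embeddings.embed_query("How does RAG find context?")
print(len(doc_vectors), len(query_vector))  # 2 vectors, each 384-dimensional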

InMemoryVectorStore

InMemoryVectorStore is a lightweight vector store class in LangChain core that keeps vectors and documents in memory. Its job is to hold text chunks (Document objects) and their embeddings temporarily, and to provide similarity-based retrieval over them (e.g. cosine-similarity matching); it is a common tool during development and testing. Its core capabilities:

  1. In-memory storage: documents and their embeddings live in memory (nothing is written to disk), with no dependency on an external database (Chroma, FAISS, and the like);
  2. Similarity retrieval: given the embedding of a user query, compute its similarity to the stored vectors (cosine similarity by default) and return the most relevant documents;
  3. Easy integration: works hand in hand with LangChain embedding models (e.g. HuggingFaceEmbeddings) and text splitters (e.g. RecursiveCharacterTextSplitter) to assemble a RAG pipeline quickly.

Source location: langchain/libs/core/langchain_core/vectorstores/in_memory.py

# in_memory.py
"""In-memory vector store."""

from __future__ import annotations

import json
import uuid
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
)

from typing_extensions import override

from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.load import dumpd, load
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores.utils import _cosine_similarity as cosine_similarity
from langchain_core.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence

    from langchain_core.embeddings import Embeddings
    from langchain_core.indexing import UpsertResponse


class InMemoryVectorStore(VectorStore):
    """In-memory vector store implementation.

    Uses a dictionary, and computes cosine similarity for search using numpy.

    Setup:
        Install ``langchain-core``.

        .. code-block:: bash

            pip install -U langchain-core

    Key init args — indexing params:
        embedding_function: Embeddings
            Embedding function to use.

    Instantiate:
        .. code-block:: python

            from langchain_core.vectorstores import InMemoryVectorStore
            from langchain_openai import OpenAIEmbeddings

            vector_store = InMemoryVectorStore(OpenAIEmbeddings())

    Add Documents:
        .. code-block:: python

            from langchain_core.documents import Document

            document_1 = Document(id="1", page_content="foo", metadata={"baz": "bar"})
            document_2 = Document(id="2", page_content="thud", metadata={"bar": "baz"})
            document_3 = Document(id="3", page_content="i will be deleted :(")

            documents = [document_1, document_2, document_3]
            vector_store.add_documents(documents=documents)

    Inspect documents:
        .. code-block:: python

            top_n = 10
            for index, (id, doc) in enumerate(vector_store.store.items()):
                if index < top_n:
                    # docs have keys 'id', 'vector', 'text', 'metadata'
                    print(f"{id}: {doc['text']}")
                else:
                    break

    Delete Documents:
        .. code-block:: python

            vector_store.delete(ids=["3"])

    Search:
        .. code-block:: python

            results = vector_store.similarity_search(query="thud", k=1)
            for doc in results:
                print(f"* {doc.page_content} [{doc.metadata}]")

        .. code-block:: none

            * thud [{'bar': 'baz'}]

    Search with filter:
        .. code-block:: python

            def _filter_function(doc: Document) -> bool:
                return doc.metadata.get("bar") == "baz"

            results = vector_store.similarity_search(
                query="thud", k=1, filter=_filter_function
            )
            for doc in results:
                print(f"* {doc.page_content} [{doc.metadata}]")

        .. code-block:: none

            * thud [{'bar': 'baz'}]

    Search with score:
        .. code-block:: python

            results = vector_store.similarity_search_with_score(query="qux", k=1)
            for doc, score in results:
                print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")

        .. code-block:: none

            * [SIM=0.832268] foo [{'baz': 'bar'}]

    Async:
        .. code-block:: python

            # add documents
            # await vector_store.aadd_documents(documents=documents)

            # delete documents
            # await vector_store.adelete(ids=["3"])

            # search
            # results = vector_store.asimilarity_search(query="thud", k=1)

            # search with score
            results = await vector_store.asimilarity_search_with_score(query="qux", k=1)
            for doc, score in results:
                print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")

        .. code-block:: none

            * [SIM=0.832268] foo [{'baz': 'bar'}]

    Use as Retriever:
        .. code-block:: python

            retriever = vector_store.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
            )
            retriever.invoke("thud")

        .. code-block:: none

            [Document(id='2', metadata={'bar': 'baz'}, page_content='thud')]
    """

    def __init__(self, embedding: Embeddings) -> None:
        """Initialize with the given embedding function.

        Args:
            embedding: embedding function to use.
        """
        # TODO: would be nice to change to
        # dict[str, Document] at some point (will be a breaking change)
        # Core storage: a dict keyed by document ID, whose values are
        # sub-dicts holding the document's fields
        self.store: dict[str, dict[str, Any]] = {}
        # Embedding model used to turn text into vectors
        # (e.g. HuggingFaceEmbeddings, OpenAIEmbeddings)
        self.embedding = embedding

    @property
    @override
    def embeddings(self) -> Embeddings:
        return self.embedding

    @override
    def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
        if ids:
            for _id in ids:
                self.store.pop(_id, None)  # drop the ID's entry; ignore unknown IDs

    @override
    async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
        self.delete(ids)  # pure in-memory operation, no IO: reuse the sync version

    @override
    def add_documents(
        self,
        documents: list[Document],
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        # 1. Extract the documents' text and generate the embeddings
        texts = [doc.page_content for doc in documents]
        vectors = self.embedding.embed_documents(texts)  # synchronous embedding
        # 2. Validate the ids length (if provided, it must match the documents)
        if ids and len(ids) != len(texts):
            msg = (
                f"ids must be the same length as texts. "
                f"Got {len(ids)} ids and {len(texts)} texts."
            )
            raise ValueError(msg)
        # 3. ID iterator: user-provided IDs -> the document's own ID -> a new UUID
        id_iterator: Iterator[Optional[str]] = (
            iter(ids) if ids else iter(doc.id for doc in documents)
        )
        # 4. Store each (document, vector) pair in self.store
        ids_ = []  # the final IDs of the added documents
        for doc, vector in zip(documents, vectors):
            doc_id = next(id_iterator)
            doc_id_ = doc_id or str(uuid.uuid4())  # generate a UUID when no ID exists
            ids_.append(doc_id_)
            self.store[doc_id_] = {
                "id": doc_id_,
                "vector": vector,
                "text": doc.page_content,
                "metadata": doc.metadata,
            }
        return ids_

    @override
    async def aadd_documents(
        self,
        documents: list[Document],
        ids: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> list[str]:
        """Add documents to the store."""
        texts = [doc.page_content for doc in documents]
        vectors = await self.embedding.aembed_documents(texts)
        if ids and len(ids) != len(texts):
            msg = (
                f"ids must be the same length as texts. "
                f"Got {len(ids)} ids and {len(texts)} texts."
            )
            raise ValueError(msg)
        id_iterator: Iterator[Optional[str]] = (
            iter(ids) if ids else iter(doc.id for doc in documents)
        )
        ids_: list[str] = []
        for doc, vector in zip(documents, vectors):
            doc_id = next(id_iterator)
            doc_id_ = doc_id or str(uuid.uuid4())
            ids_.append(doc_id_)
            self.store[doc_id_] = {
                "id": doc_id_,
                "vector": vector,
                "text": doc.page_content,
                "metadata": doc.metadata,
            }
        return ids_

    @override
    def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
        """Get documents by their ids.

        Args:
            ids: The ids of the documents to get.

        Returns:
            A list of Document objects.
        """
        documents = []
        for doc_id in ids:
            # Fetch the entry from self.store and convert it back to a Document
            doc = self.store.get(doc_id)
            if doc:
                documents.append(
                    Document(
                        id=doc["id"],
                        page_content=doc["text"],
                        metadata=doc["metadata"],
                    )
                )
        return documents

    @deprecated(
        alternative="VectorStore.add_documents",
        message=(
            "This was a beta API that was added in 0.2.11. It'll be removed in 0.3.0."
        ),
        since="0.2.29",
        removal="1.0",
    )
    def upsert(self, items: Sequence[Document], /, **_kwargs: Any) -> UpsertResponse:
        """[DEPRECATED] Upsert documents into the store.

        Args:
            items: The documents to upsert.

        Returns:
            The upsert response.
        """
        vectors = self.embedding.embed_documents([item.page_content for item in items])
        ids = []
        for item, vector in zip(items, vectors):
            doc_id = item.id or str(uuid.uuid4())
            ids.append(doc_id)
            self.store[doc_id] = {
                "id": doc_id,
                "vector": vector,
                "text": item.page_content,
                "metadata": item.metadata,
            }
        return {
            "succeeded": ids,
            "failed": [],
        }

    @deprecated(
        alternative="VectorStore.aadd_documents",
        message=(
            "This was a beta API that was added in 0.2.11. It'll be removed in 0.3.0."
        ),
        since="0.2.29",
        removal="1.0",
    )
    async def aupsert(
        self, items: Sequence[Document], /, **_kwargs: Any
    ) -> UpsertResponse:
        """[DEPRECATED] Upsert documents into the store.

        Args:
            items: The documents to upsert.

        Returns:
            The upsert response.
        """
        vectors = await self.embedding.aembed_documents(
            [item.page_content for item in items]
        )
        ids = []
        for item, vector in zip(items, vectors):
            doc_id = item.id or str(uuid.uuid4())
            ids.append(doc_id)
            self.store[doc_id] = {
                "id": doc_id,
                "vector": vector,
                "text": item.page_content,
                "metadata": item.metadata,
            }
        return {
            "succeeded": ids,
            "failed": [],
        }

    @override
    async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
        """Async get documents by their ids.

        Args:
            ids: The ids of the documents to get.

        Returns:
            A list of Document objects.
        """
        return self.get_by_ids(ids)

    def _similarity_search_with_score_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        filter: Optional[Callable[[Document], bool]] = None,  # noqa: A002
    ) -> list[tuple[Document, float, list[float]]]:
        # 1. Snapshot all documents (fixed order, so scores line up with docs)
        docs = list(self.store.values())
        # 2. Apply the optional filter function
        if filter is not None:
            docs = [
                doc
                for doc in docs
                if filter(Document(page_content=doc["text"], metadata=doc["metadata"]))
            ]
        if not docs:
            return []  # nothing matched the filter
        # 3. Cosine similarity between the query vector and every document vector;
        #    cosine_similarity takes 2-D arrays and returns a [1, n_docs] matrix
        similarity = cosine_similarity([embedding], [doc["vector"] for doc in docs])[0]
        # 4. Sort by similarity (descending) and keep the top-k indices
        top_k_idx = similarity.argsort()[::-1][:k]
        # 5. Convert to (Document, similarity score, vector) tuples
        return [
            (
                Document(
                    id=doc_dict["id"],
                    page_content=doc_dict["text"],
                    metadata=doc_dict["metadata"],
                ),
                float(similarity[idx].item()),
                doc_dict["vector"],
            )
            for idx in top_k_idx
            # Assign using walrus operator to avoid multiple lookups
            if (doc_dict := docs[idx])
        ]

    def similarity_search_with_score_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        filter: Optional[Callable[[Document], bool]] = None,  # noqa: A002
        **_kwargs: Any,
    ) -> list[tuple[Document, float]]:
        """Search for the most similar documents to the given embedding.

        Args:
            embedding: The embedding to search for.
            k: The number of documents to return.
            filter: A function to filter the documents.

        Returns:
            A list of tuples of Document objects and their similarity scores.
        """
        return [
            (doc, similarity)
            for doc, similarity, _ in self._similarity_search_with_score_by_vector(
                embedding=embedding, k=k, filter=filter
            )
        ]

    @override
    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> list[tuple[Document, float]]:
        embedding = self.embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding,
            k,
            **kwargs,
        )

    @override
    async def asimilarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[tuple[Document, float]]:
        embedding = await self.embedding.aembed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding,
            k,
            **kwargs,
        )

    @override
    def similarity_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        **kwargs: Any,
    ) -> list[Document]:
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    @override
    async def asimilarity_search_by_vector(
        self, embedding: list[float], k: int = 4, **kwargs: Any
    ) -> list[Document]:
        return self.similarity_search_by_vector(embedding, k, **kwargs)

    @override
    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        return [doc for doc, _ in self.similarity_search_with_score(query, k, **kwargs)]

    @override
    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> list[Document]:
        return [
            doc
            for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
        ]

    @override
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        *,
        filter: Optional[Callable[[Document], bool]] = None,
        **kwargs: Any,
    ) -> list[Document]:
        # 1. Prefetch fetch_k similar documents (more than the final k,
        #    leaving room for the diversity selection)
        prefetch_hits = self._similarity_search_with_score_by_vector(
            embedding=embedding,
            k=fetch_k,
            filter=filter,
        )
        try:
            import numpy as np
        except ImportError as e:
            msg = (
                "numpy must be installed to use max_marginal_relevance_search "
                "pip install numpy"
            )
            raise ImportError(msg) from e
        # 2. Run the MMR algorithm to pick indices balancing similarity and
        #    diversity; lambda_mult: 0 -> diversity only, 1 -> similarity only,
        #    default 0.5 balances the two
        mmr_chosen_indices = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            [vector for _, _, vector in prefetch_hits],
            k=k,
            lambda_mult=lambda_mult,
        )
        return [prefetch_hits[idx][0] for idx in mmr_chosen_indices]

    @override
    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        embedding_vector = self.embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )

    @override
    async def amax_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> list[Document]:
        embedding_vector = await self.embedding.aembed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )

    @classmethod
    @override
    def from_texts(
        cls,
        texts: list[str],
        embedding: Embeddings,
        metadatas: Optional[list[dict]] = None,
        **kwargs: Any,
    ) -> InMemoryVectorStore:
        # 1. Instantiate the vector store
        store = cls(
            embedding=embedding,
        )
        # 2. add_texts (an inherited method) wraps texts + metadata into Documents
        store.add_texts(texts=texts, metadatas=metadatas, **kwargs)
        return store

    @classmethod
    @override
    async def afrom_texts(
        cls,
        texts: list[str],
        embedding: Embeddings,
        metadatas: Optional[list[dict]] = None,
        **kwargs: Any,
    ) -> InMemoryVectorStore:
        store = cls(
            embedding=embedding,
        )
        await store.aadd_texts(texts=texts, metadatas=metadatas, **kwargs)
        return store

    @classmethod
    def load(
        cls, path: str, embedding: Embeddings, **kwargs: Any
    ) -> InMemoryVectorStore:
        """Load a vector store from a file.

        Args:
            path: The path to load the vector store from.
            embedding: The embedding to use.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            A VectorStore object.
        """
        # Read the JSON file and rebuild the store
        path_: Path = Path(path)
        with path_.open("r") as f:
            store = load(json.load(f))  # `load` revives the serialized dict
        vectorstore = cls(embedding=embedding, **kwargs)
        vectorstore.store = store
        return vectorstore

    def dump(self, path: str) -> None:
        """Dump the vector store to a file.

        Args:
            path: The path to dump the vector store to.
        """
        # Serialize self.store into a JSON file
        path_: Path = Path(path)
        path_.parent.mkdir(exist_ok=True, parents=True)  # ensure the parent dir exists
        with path_.open("w") as f:
            json.dump(dumpd(self.store), f, indent=2)  # dumpd handles special types
  1. Initialization and data storage: initialization takes an embedding model (used to turn text into vectors); documents can then be loaded via from_texts or from_documents. Internally, document text is converted to vectors by the embedding model and stored in memory, together with the document metadata (source, sequence number, and so on), in a dict structure.

  2. Similarity retrieval flow: when similarity_search or a related method is called:

    • the user's query text is first converted into a query vector by the embedding model;
    • the similarity between the query vector and every stored document vector is computed (cosine similarity by default);
    • the top k most relevant documents are returned, sorted by similarity (default k=4).
  3. Core storage self.store: of type dict[str, dict[str, Any]]; the outer key is the document's unique ID (a string), and each value is a sub-dict with four fields:

    • id: the document ID (identical to the outer key);
    • vector: the embedding of the document text (list[float]);
    • text: the original document text (page_content);
    • metadata: the document metadata (source, author, and so on; a dict).
  4. Document operations, the CRUD core:

    • add: add_documents (sync) and aadd_documents (async)
    • delete: delete (sync) and adelete (async)
    • get: get_by_ids (fetch by ID)
  5. InMemoryVectorStore supports two retrieval modes, exact cosine-similarity search and MMR search (similarity plus diversity), both built on vector similarity computation.

    • Low-level similarity computation: _similarity_search_with_score_by_vector implements the full filter → compute similarity → sort → convert pipeline;

      Method | Description
      similarity_search_with_score_by_vector | takes a vector, returns a list of (Document, similarity score) pairs
      similarity_search_with_score | takes query text (embeds it automatically), returns (Document, score) pairs
      similarity_search | takes query text, returns only the matching Documents (scores dropped)
      asimilarity_search | async similarity_search (embeds the query asynchronously)
    • Diversity-aware retrieval: max_marginal_relevance_search (MMR) adds diversity on top of plain similarity, so the returned documents are not highly redundant; suited to scenarios that need coverage of more aspects (e.g. long-document QA).

  6. Persistence: by default the data lives only in memory and is lost when the process exits; dump/load provide simple file persistence. Convenience constructor: from_texts (create a store from a list of strings). A combined usage sketch follows.
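A minimal end-to-end sketch (model name as in the embeddings section; the texts, scores, and file path are illustrative):

from langchain_core.vectorstores import InMemoryVectorStore
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

store = InMemoryVectorStore.from_texts(
    texts=["Cats purr.", "Dogs bark.", "Parrots can mimic speech."],
    embedding=embeddings,
    metadatas=[{"source": "pets.txt"}] * 3,
)

# Cosine-similarity search with scores
for doc, score in store.similarity_search_with_score("Which pet makes sounds?", k=2):
    print(f"{score:.3f}  {doc.page_content}  {doc.metadata}")

# MMR search trades similarity against diversity (lambda_mult=0.5 by default)
diverse = store.max_marginal_relevance_search("pet sounds", k=2, fetch_k=3)

# Simple persistence: dump to JSON, then reload with the same embedding model
store.dump("./vector_store.json")
restored = InMemoryVectorStore.load("./vector_store.json", embedding=embeddings)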

ChatPromptTemplate

ChatPromptTemplate is the template class in LangChain core for building chat-model prompts. Designed for multi-turn conversation, it organizes messages by role (system, user, assistant) in a structured way, guarantees that the prompt fed to a chat model (GPT-3.5/4, Claude, Qwen, and others) matches the format the model expects, and supports filling in variables dynamically; it is the core component connecting user input to the LLM.

  • Core capabilities:
    • Role-based message organization: define what each role (system, user, assistant) says in a multi-turn conversation, making the sender and the context of each message explicit;
    • Dynamic variable filling: embed variables in the template (e.g. {context}, {question}) and substitute concrete values at run time (retrieved documents, the user's question);
    • Model compatibility: produce prompts in the format chat models accept (typically a list of messages with role and content), with no manual format conversion.

Source location: langchain/libs/core/langchain_core/prompts/chat.py

class ChatPromptTemplate(BaseChatPromptTemplate):
    """Prompt template for chat models.

    Use to create flexible templated prompts for chat models.

    Examples:
        .. versionchanged:: 0.2.24

            You can pass any Message-like formats supported by
            ``ChatPromptTemplate.from_messages()`` directly to ``ChatPromptTemplate()``
            init.

        .. code-block:: python

            from langchain_core.prompts import ChatPromptTemplate

            template = ChatPromptTemplate([
                ("system", "You are a helpful AI bot. Your name is {name}."),
                ("human", "Hello, how are you doing?"),
                ("ai", "I'm doing well, thanks!"),
                ("human", "{user_input}"),
            ])

            prompt_value = template.invoke({
                "name": "Bob",
                "user_input": "What is your name?"
            })
            # Output:
            # ChatPromptValue(
            #    messages=[
            #        SystemMessage(content='You are a helpful AI bot. Your name is Bob.'),
            #        HumanMessage(content='Hello, how are you doing?'),
            #        AIMessage(content="I'm doing well, thanks!"),
            #        HumanMessage(content='What is your name?')
            #    ]
            # )

    Messages Placeholder:
        .. code-block:: python

            # In addition to Human/AI/Tool/Function messages,
            # you can initialize the template with a MessagesPlaceholder
            # either using the class directly or with the shorthand tuple syntax:

            template = ChatPromptTemplate([
                ("system", "You are a helpful AI bot."),
                # Means the template will receive an optional list of messages under
                # the "conversation" key
                ("placeholder", "{conversation}")
                # Equivalently:
                # MessagesPlaceholder(variable_name="conversation", optional=True)
            ])

            prompt_value = template.invoke({
                "conversation": [
                    ("human", "Hi!"),
                    ("ai", "How can I assist you today?"),
                    ("human", "Can you make me an ice cream sundae?"),
                    ("ai", "No.")
                ]
            })
            # Output:
            # ChatPromptValue(
            #    messages=[
            #        SystemMessage(content='You are a helpful AI bot.'),
            #        HumanMessage(content='Hi!'),
            #        AIMessage(content='How can I assist you today?'),
            #        HumanMessage(content='Can you make me an ice cream sundae?'),
            #        AIMessage(content='No.'),
            #    ]
            # )

    Single-variable template:
        If your prompt has only a single input variable (i.e., 1 instance of
        "{variable_name}"), and you invoke the template with a non-dict object, the
        prompt template will inject the provided argument into that variable location.

        .. code-block:: python

            from langchain_core.prompts import ChatPromptTemplate

            template = ChatPromptTemplate([
                ("system", "You are a helpful AI bot. Your name is Carl."),
                ("human", "{user_input}"),
            ])

            prompt_value = template.invoke("Hello, there!")
            # Equivalent to
            # prompt_value = template.invoke({"user_input": "Hello, there!"})
            # Output:
            #  ChatPromptValue(
            #     messages=[
            #         SystemMessage(content='You are a helpful AI bot. Your name is Carl.'),
            #         HumanMessage(content='Hello, there!'),
            #     ]
            # )
    """  # noqa: E501

    messages: Annotated[list[MessageLike], SkipValidation()]
    """List of messages consisting of either message prompt templates or messages."""
    validate_template: bool = False
    """Whether or not to try validating the template."""

    def __init__(
        self,
        messages: Sequence[MessageLikeRepresentation],
        *,
        template_format: PromptTemplateFormat = "f-string",
        **kwargs: Any,
    ) -> None:
        """Create a chat prompt template from a variety of message formats.

        Args:
            messages: sequence of message representations.
                A message can be represented using the following formats:
                (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
                (message type, template); e.g., ("human", "{user_input}"),
                (4) 2-tuple of (message class, template), (5) a string which is
                shorthand for ("human", template); e.g., "{user_input}".
            template_format: format of the template. Defaults to "f-string".
            input_variables: A list of the names of the variables whose values are
                required as inputs to the prompt.
            optional_variables: A list of the names of the variables for placeholder
                or MessagePlaceholder that are optional. These variables are auto
                inferred from the prompt and user need not provide them.
            partial_variables: A dictionary of the partial variables the prompt
                template carries. Partial variables populate the template so that you
                don't need to pass them in every time you call the prompt.
            validate_template: Whether to validate the template.
            input_types: A dictionary of the types of the variables the prompt template
                expects. If not provided, all variables are assumed to be strings.

        Returns:
            A chat prompt template.

        Examples:
            Instantiation from a list of message templates:

            .. code-block:: python

                template = ChatPromptTemplate([
                    ("human", "Hello, how are you?"),
                    ("ai", "I'm doing well, thanks!"),
                    ("human", "That's good to hear."),
                ])

            Instantiation from mixed message formats:

            .. code-block:: python

                template = ChatPromptTemplate([
                    SystemMessage(content="hello"),
                    ("human", "Hello, how are you?"),
                ])
        """
        # 1. Normalize every input message format into a standard message template
        messages_ = [
            _convert_to_message_template(message, template_format)
            for message in messages
        ]
        # 2. Infer input variables, optional variables, and partial variables
        input_vars: set[str] = set()  # required variables (e.g. {question})
        optional_variables: set[str] = set()  # optional (e.g. MessagesPlaceholder)
        partial_vars: dict[str, Any] = {}  # pre-filled variables (default empty list)
        for _message in messages_:
            # MessagesPlaceholder: optional slot for injected conversation history
            if isinstance(_message, MessagesPlaceholder) and _message.optional:
                partial_vars[_message.variable_name] = []
                optional_variables.add(_message.variable_name)
            # Collect input variables from message templates
            # (e.g. {user_input} in a HumanMessagePromptTemplate)
            elif isinstance(
                _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
            ):
                input_vars.update(_message.input_variables)
        # 3. Assemble the parent-class init arguments
        #    (input, optional, and partial variables)
        kwargs = {
            "input_variables": sorted(input_vars),
            "optional_variables": sorted(optional_variables),
            "partial_variables": partial_vars,
            **kwargs,
        }
        # Call the parent initializer with the normalized message templates
        cast("type[ChatPromptTemplate]", super()).__init__(messages=messages_, **kwargs)

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates.

        Args:
            other: Another prompt template.

        Returns:
            Combined prompt template.
        """
        # 1. Merge partial variables
        partials = {**self.partial_variables}
        # Need to check that other has partial variables since it may not be
        # a ChatPromptTemplate.
        if hasattr(other, "partial_variables") and other.partial_variables:
            partials.update(other.partial_variables)
        # Allow for easy combining
        # 2. Support several operand types
        if isinstance(other, ChatPromptTemplate):
            # Concatenate the message lists of two ChatPromptTemplates
            return ChatPromptTemplate(messages=self.messages + other.messages).partial(
                **partials
            )
        if isinstance(
            # A single message template / message
            other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
        ):
            return ChatPromptTemplate(messages=[*self.messages, other]).partial(
                **partials
            )
        if isinstance(other, (list, tuple)):
            # A list of messages (converted to a ChatPromptTemplate first)
            other_ = ChatPromptTemplate.from_messages(other)
            return ChatPromptTemplate(messages=self.messages + other_.messages).partial(
                **partials
            )
        if isinstance(other, str):
            # A plain string (converted to a HumanMessagePromptTemplate)
            prompt = HumanMessagePromptTemplate.from_template(other)
            return ChatPromptTemplate(messages=[*self.messages, prompt]).partial(
                **partials
            )
        msg = f"Unsupported operand type for +: {type(other)}"
        raise NotImplementedError(msg)

    @model_validator(mode="before")
    @classmethod
    def validate_input_variables(cls, values: dict) -> Any:
        """Validate input variables.

        If input_variables is not set, it will be set to the union of
        all input variables in the messages.

        Args:
            values: values to validate.

        Returns:
            Validated values.

        Raises:
            ValueError: If input variables do not match.
        """
        # 1. Re-infer the input variables (so they match the message templates)
        messages = values["messages"]
        input_vars: set = set()
        optional_variables = set()
        input_types: dict[str, Any] = values.get("input_types", {})
        for message in messages:
            if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
                input_vars.update(message.input_variables)
            # Handle optional MessagesPlaceholder variables
            if isinstance(message, MessagesPlaceholder):
                if "partial_variables" not in values:
                    values["partial_variables"] = {}
                if (
                    message.optional
                    and message.variable_name not in values["partial_variables"]
                ):
                    values["partial_variables"][message.variable_name] = []
                    optional_variables.add(message.variable_name)
                if message.variable_name not in input_types:
                    input_types[message.variable_name] = list[AnyMessage]
        if "partial_variables" in values:
            input_vars -= set(values["partial_variables"])
        if optional_variables:
            input_vars -= optional_variables
        # 2. Check user-provided input_variables against the inferred set
        #    (only when validate_template=True)
        if "input_variables" in values and values.get("validate_template"):
            if input_vars != set(values["input_variables"]):
                msg = (
                    "Got mismatched input_variables. "
                    f"Expected: {input_vars}. "
                    f"Got: {values['input_variables']}"
                )
                raise ValueError(msg)
        else:
            values["input_variables"] = sorted(input_vars)
        if optional_variables:
            values["optional_variables"] = sorted(optional_variables)
        # 3. Record the inferred input types
        values["input_types"] = input_types
        return values

    @classmethod
    def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
        """Create a chat prompt template from a template string.

        Creates a chat template consisting of a single message assumed to be from
        the human.

        Args:
            template: template string
            **kwargs: keyword arguments to pass to the constructor.

        Returns:
            A new instance of this class.
        """
        prompt_template = PromptTemplate.from_template(template, **kwargs)
        message = HumanMessagePromptTemplate(prompt=prompt_template)
        return cls.from_messages([message])

    @classmethod
    @deprecated("0.0.1", alternative="from_messages", pending=True)
    def from_role_strings(
        cls, string_messages: list[tuple[str, str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role, template) tuples.

        Args:
            string_messages: list of (role, template) tuples.

        Returns:
            a chat prompt template.
        """
        return cls(
            messages=[
                ChatMessagePromptTemplate.from_template(template, role=role)
                for role, template in string_messages
            ]
        )

    @classmethod
    @deprecated("0.0.1", alternative="from_messages", pending=True)
    def from_strings(
        cls, string_messages: list[tuple[type[BaseMessagePromptTemplate], str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role class, template) tuples.

        Args:
            string_messages: list of (role class, template) tuples.

        Returns:
            a chat prompt template.
        """
        return cls.from_messages(string_messages)

    @classmethod
    def from_messages(
        cls,
        messages: Sequence[MessageLikeRepresentation],
        template_format: PromptTemplateFormat = "f-string",
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a variety of message formats.

        Examples:
            Instantiation from a list of message templates:

            .. code-block:: python

                template = ChatPromptTemplate.from_messages([
                    ("human", "Hello, how are you?"),
                    ("ai", "I'm doing well, thanks!"),
                    ("human", "That's good to hear."),
                ])

            Instantiation from mixed message formats:

            .. code-block:: python

                template = ChatPromptTemplate.from_messages([
                    SystemMessage(content="hello"),
                    ("human", "Hello, how are you?"),
                ])

        Args:
            messages: sequence of message representations.
                A message can be represented using the following formats:
                (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
                (message type, template); e.g., ("human", "{user_input}"),
                (4) 2-tuple of (message class, template), (5) a string which is
                shorthand for ("human", template); e.g., "{user_input}".
            template_format: format of the template. Defaults to "f-string".

        Returns:
            a chat prompt template.
        """
        # Builds a template from the mixed message formats above;
        # the recommended way to instantiate
        return cls(messages, template_format=template_format)

    def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Format the chat template into a list of finalized messages.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            list of formatted messages.
        """
        # 1. Merge partial (pre-filled) variables with user-supplied ones
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        result = []
        # 2. Walk the message templates, filling in variables one by one
        for message_template in self.messages:
            if isinstance(message_template, BaseMessage):
                # A finished message (no variables): append as-is
                result.extend([message_template])
            elif isinstance(
                # A message template: delegate to its own format_messages
                message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
            ):
                message = message_template.format_messages(**kwargs)
                result.extend(message)
            else:
                msg = f"Unexpected input: {message_template}"
                raise ValueError(msg)  # noqa: TRY004
        return result

    async def aformat_messages(self, **kwargs: Any) -> list[BaseMessage]:
        """Async format the chat template into a list of finalized messages.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            list of formatted messages.

        Raises:
            ValueError: If unexpected input.
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        result = []
        for message_template in self.messages:
            if isinstance(message_template, BaseMessage):
                result.extend([message_template])
            elif isinstance(
                message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
            ):
                message = await message_template.aformat_messages(**kwargs)
                result.extend(message)
            else:
                msg = f"Unexpected input: {message_template}"
                raise ValueError(msg)  # noqa: TRY004
        return result

    def partial(self, **kwargs: Any) -> ChatPromptTemplate:
        """Get a new ChatPromptTemplate with some input variables already filled in.

        Args:
            **kwargs: keyword arguments to use for filling in template variables. Ought
                to be a subset of the input variables.

        Returns:
            A new ChatPromptTemplate.

        Example:
            .. code-block:: python

                from langchain_core.prompts import ChatPromptTemplate

                template = ChatPromptTemplate.from_messages([
                    ("system", "You are an AI assistant named {name}."),
                    ("human", "Hi I'm {user}"),
                    ("ai", "Hi there, {user}, I'm {name}."),
                    ("human", "{input}"),
                ])
                template2 = template.partial(user="Lucy", name="R2D2")

                template2.format_messages(input="hello")
        """
        # 1. Copy the current template's configuration
        prompt_dict = self.__dict__.copy()
        # 2. Remove the pre-filled variables from the input variables
        prompt_dict["input_variables"] = list(
            set(self.input_variables).difference(kwargs)
        )
        # 3. Merge partial variables (new pre-filled values override old ones)
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def append(self, message: MessageLikeRepresentation) -> None:
        """Append a message to the end of the chat template.

        Args:
            message: representation of a message to append.
        """
        self.messages.append(_convert_to_message_template(message))

    def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
        """Extend the chat template with a sequence of messages.

        Args:
            messages: sequence of message representations to append.
        """
        self.messages.extend(
            [_convert_to_message_template(message) for message in messages]
        )

    @overload
    def __getitem__(self, index: int) -> MessageLike: ...

    @overload
    def __getitem__(self, index: slice) -> ChatPromptTemplate: ...

    def __getitem__(
        self, index: Union[int, slice]
    ) -> Union[MessageLike, ChatPromptTemplate]:
        """Use to index into the chat template."""
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self.messages))
            messages = self.messages[start:stop:step]
            return ChatPromptTemplate.from_messages(messages)
        return self.messages[index]

    def __len__(self) -> int:
        """Get the length of the chat template."""
        return len(self.messages)

    @property
    def _prompt_type(self) -> str:
        """Name of prompt type. Used for serialization."""
        return "chat"

    def save(self, file_path: Union[Path, str]) -> None:
        """Save prompt to file.

        Args:
            file_path: path to file.
        """
        raise NotImplementedError

    @override
    def pretty_repr(self, html: bool = False) -> str:
        """Human-readable representation.

        Args:
            html: Whether to format as HTML. Defaults to False.

        Returns:
            Human-readable representation.
        """
        # TODO: handle partials
        return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
  • Initialization (__init__): message conversion and variable inference. The _convert_to_message_template helper converts the accepted message formats (tuples, strings, BaseMessage) into LangChain's standard message templates (HumanMessagePromptTemplate, SystemMessagePromptTemplate, and so on). Variable inference is automatic: there is no need to specify input_variables by hand; every variable appearing in the message templates (e.g. {user_input}, {context}) is extracted, and optional variables (such as a MessagesPlaceholder's {conversation}) are marked, reducing configuration errors. Examples:
    • ("human", "{user_input}") → converted to a HumanMessagePromptTemplate;
    • SystemMessage(content="You are an assistant") → kept as a BaseMessage;
    • "Answer {question}" → shorthand for a HumanMessagePromptTemplate.
  • Core methods: template creation, variable filling, and extension (see the sketch after this list)
    • Creation: from_messages (the most common entry point), which accepts:
      • BaseMessagePromptTemplate: a standard message template (e.g. HumanMessagePromptTemplate);
      • BaseMessage: an already-built message (e.g. SystemMessage(content="You are an assistant"));
      • a (role name, template) tuple, e.g. ("system", "You are the {name} assistant");
      • a (message class, template) tuple, e.g. (HumanMessage, "Answer {question}");
      • a plain string, shorthand for ("human", string), e.g. "Answer {question}".
    • Variable filling: format_messages (sync) and aformat_messages (async) substitute the variables into the templates and produce the list of BaseMessage objects a chat model accepts;
    • Partial pre-filling: the partial method pre-fills some variables so they need not be passed on every call (good for fixed configuration, such as the role name in a system message).
    • Template concatenation: the __add__ method supports joining templates or messages with the + operator, flexibly extending the prompt structure.
    • Dynamic modification: append and extend add messages to a template after it is created, for dynamic conversation scenarios.
  • Variable validation: validate_input_variables is a model validator that checks the input variables automatically, preventing mismatched arguments. With validate_template=True, a mismatch between the variables passed in and those the template expects (e.g. a missing {context}) raises immediately, cutting debugging cost.
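A minimal sketch of RAG-style usage (the variable names {context}/{question} are illustrative):

from langchain_core.prompts import ChatPromptTemplate

# A typical RAG prompt: {context} carries retrieved chunks, {question} the query.
prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer using only the provided context.\n\nContext:\n{context}"),
    ("human", "{question}"),
])

messages = prompt.format_messages(
    context="InMemoryVectorStore keeps vectors in memory.",
    question="Where does InMemoryVectorStore keep its vectors?",
)
for m in messages:
    print(type(m).__name__, ":", m.content)

# partial() pre-fills a variable; later calls only supply the remaining ones
fixed = prompt.partial(context="(no documents retrieved)")
print(fixed.format_messages(question="Hello?")[0].content)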

References

  • LangChain
