当前位置: 首页 > news >正文

LlamaIndex中使用本地LLM和Embedding

LlamaIndex默认会调用OpenAI的text-davinci-002模型对应的API,用于获得大模型输出,这种方式在很多情况下对国内用户不太方便,如果本地有大模型可以部署,可以按照以下方式在LlamaIndex中使用本地的LLM和Embedding(这里LLM使用chatglm2-6b,Embedding使用m3e-base):

import torch
from transformers import AutoModel, AutoTokenizer
from llama_index.llms import HuggingFaceLLM
from llama_index import VectorStoreIndex, ServiceContext
from llama_index import LangchainEmbedding, ServiceContext
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from llama_index import Prompt, PromptHelper
from llama_index.node_parser import SimpleNodeParser
from llama_index.langchain_helpers.text_splitter import TextSplitter, TokenTextSplitter
from llama_index import set_global_service_context

# 需要使用GPU才能运行
device = 'cuda'

# 自定义输入大模型的prompt
TEMPLATE_STR = """我们在下面提供了上下文信息

{context_str}
根据此信息,请回答问题:{query_str}
"""
QA_TEMPLATE = Prompt(TEMPLATE_STR)

# 加载本地LLM,需提供本地LLM模型文件的路径
llm_tokenizer = AutoTokenizer.from_pretrained('/models/chatglm2-6b/', trust_remote_code=True, device=device)
llm_model = AutoModel.from_pretrained('/models/chatglm2-6b/', trust_remote_code=True, device=device)
chatglm2 = HuggingFaceLLM(model=llm_model, tokenizer=llm_tokenizer)

# 加载本地Embedding,需提供本地Embedding模型文件的路径
embed_tokenizer = AutoTokenizer.from_pretrained('/models/moka-ai/m3e-base/', trust_remote_code=True, device=device)
embed_model = LangchainEmbedding(
                HuggingFaceEmbeddings(model_name='/models/moka-ai/m3e-base/'), 
                tokenizer=embed_tokenizer)

node_parser = SimpleNodeParser(text_splitter=TokenTextSplitter(tokenizer=embed_tokenizer))
prompt_helper = PromptHelper(tokenizer=llm_tokenizer)
service_context = ServiceContext.from_defaults(
                llm=chatglm2, 
                prompt_helper=prompt_helper, 
                embed_model=embed_model, 
                node_parser=node_parser)
set_global_service_context(service_context)

documents = SimpleDirectoryReader('/path/to/your/files').load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

 # 查询引擎
query_engine = index.as_query_engine(text_qa_template=QA_TEMPLATE)
# 聊天引擎
chat_eigine = index.as_chat_engine()

response = query_engine.query("your question")
print(response)

相关文章:

  • ue5 Arch vis AI traffic system 车辆系统添加不同种类的车
  • FPGA DSP:Vivado 中带有 DDS 的 FIR 滤波器
  • VS code + Cline + 阿里百炼
  • python获取网页内容 靠谱的做法
  • Linux /etc/fstab文件详解:自动挂载配置指南(中英双语)
  • DDD - 实现限界上下文集成的四种方式
  • 数据库之MySQL——事务(一)
  • 如何使用3D高斯分布进行环境建模
  • 07.Docker 数据管理
  • CORS跨域问题常见解决办法
  • 正确清理C盘空间
  • 使用LangChain构建第一个ReAct Agent
  • 开源的 LLM 应用开发平台-Dify 部署和使用
  • Linux 命令 mount 完全指南(中英双语)
  • 力扣-贪心-376 摆动序列
  • 【云服务器】云服务器内存不够用,开启SWAP交换分区
  • 深蓝学院自主泊车第3次作业-IPM
  • 跟着 Lua 5.1 官方参考文档学习 Lua (6)
  • java网络编程
  • 【Leetcode 每日一题】2506. 统计相似字符串对的数目
  • 对话|蓬皮杜策展人布莱昂:抽象风景中的中国审美
  • 湖南张家界警方公告宣布一名外国人居留许可作废
  • 国铁集团:铁路五一假期运输收官,多项运输指标创历史新高
  • 外交部:应美方请求举行贸易代表会谈,中方反对美滥施关税立场没有变化
  • 央行、证监会:科技创新债券含公司债券、企业债券、非金融企业债务融资工具等
  • 韩正出席庆祝中国欧盟建交50周年招待会并致辞