当前位置: 首页 > news >正文

使用自定义LLM和Embedding模型部署Vanna:基于RAG的Text-to-SQL生成

使用自定义LLM和Embedding模型部署Vanna:基于RAG的Text-to-SQL生成

说明:

  • 首次发表日期:2024-07-12
  • Vanna Github地址: https://github.com/vanna-ai/vanna
  • Vanna官方文档: https://vanna.ai/

部署Vanna时我们可以选择使用什么大模型和向量数据库,比如OPEN AI和ChromaDB等这些官方支持的。

但是存在一个问题,为了保证数据不存在泄露风险,部署自己的大模型服务比较安全。

Vanna官方文档中说明可以使用自定义大模型的,不过没有给出具体的例子,本文提供一个例子以供参考。

继承VannaBase,并调用自己的大模型实现接口

一般我们的大模型服务,不过是第三方的还是自己部署的,大多都有提供和OPEN AI兼容的接口;所以,我们只需要复制一下Vanna提供的OpenAI_Chat类,进行少量修改,使其可以调用自定义模型即可,代码如下:

class OpenAICompatibleLLM(VannaBase):def __init__(self, client=None, config=None):VannaBase.__init__(self, config=config)# default parameters - can be overrided using configself.temperature = 0.5self.max_tokens = 500if "temperature" in config:self.temperature = config["temperature"]if "max_tokens" in config:self.max_tokens = config["max_tokens"]if "api_type" in config:raise Exception("Passing api_type is now deprecated. Please pass an OpenAI client instead.")if "api_version" in config:raise Exception("Passing api_version is now deprecated. Please pass an OpenAI client instead.")if client is not None:self.client = clientreturnif "api_base" not in config:raise Exception("Please passing api_base")if "api_key" not in config:raise Exception("Please passing api_key")self.client = OpenAI(api_key=config["api_key"], base_url=config["api_base"])def system_message(self, message: str) -> any:return {"role": "system", "content": message}def user_message(self, message: str) -> any:return {"role": "user", "content": message}def assistant_message(self, message: str) -> any:return {"role": "assistant", "content": message}def submit_prompt(self, prompt, **kwargs) -> str:if prompt is None:raise Exception("Prompt is None")if len(prompt) == 0:raise Exception("Prompt is empty")num_tokens = 0for message in prompt:num_tokens += len(message["content"]) / 4if kwargs.get("model", None) is not None:model = kwargs.get("model", None)print(f"Using model {model} for {num_tokens} tokens (approx)")response = self.client.chat.completions.create(model=model,messages=prompt,max_tokens=self.max_tokens,stop=None,temperature=self.temperature,)elif kwargs.get("engine", None) is not None:engine = kwargs.get("engine", None)print(f"Using model {engine} for {num_tokens} tokens (approx)")response = self.client.chat.completions.create(engine=engine,messages=prompt,max_tokens=self.max_tokens,stop=None,temperature=self.temperature,)elif self.config is not None and "engine" in self.config:print(f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)")response = self.client.chat.completions.create(engine=self.config["engine"],messages=prompt,max_tokens=self.max_tokens,stop=None,temperature=self.temperature,)elif self.config is not None and "model" in self.config:print(f"Using model {self.config['model']} for {num_tokens} tokens (approx)")response = self.client.chat.completions.create(model=self.config["model"],messages=prompt,max_tokens=self.max_tokens,stop=None,temperature=self.temperature,)else:if num_tokens > 3500:model = "kimi"else:model = "doubao"print(f"Using model {model} for {num_tokens} tokens (approx)")response = self.client.chat.completions.create(model=model,messages=prompt,max_tokens=self.max_tokens,stop=None,temperature=self.temperature,)for choice in response.choices:if "text" in choice:return choice.textreturn response.choices[0].message.content

继承Qdrant_VectorStore类并使用自己的Embedding服务

class CustomQdrant_VectorStore(Qdrant_VectorStore):def __init__(self,config={}):self.embedding_model_name = config.get("embedding_model_name", "beg-m3")self.embedding_api_base = config.get("embedding_api_base", "https://xxxxxxxxxxx.com")self.embedding_api_key = config.get("embedding_api_key", "sk-xxxxxxxxxxxxxxx")super().__init__(config)def generate_embedding(self, data: str, **kwargs) -> List[float]:def _get_error_string(response: requests.Response) -> str:try:if response.content:return response.json()["detail"]except Exception:passtry:response.raise_for_status()except requests.HTTPError as e:return str(e)return "Unknown error"request_body = {"model": self.embedding_model_name,"input": data,}request_body.update(kwargs)response = requests.post(url=f"{self.embedding_api_base}/v1/embeddings",json=request_body,headers={"Authorization": f"Bearer {self.embedding_api_key}"},)if response.status_code != 200:raise RuntimeError(f"Failed to create the embeddings, detail: {_get_error_string(response)}")result = response.json()embeddings = [d["embedding"] for d in result["data"]]return embeddings[0]

启动服务

  1. 定义一个CustomVanna类,继承CustomQdrant_VectorStoreOpenAICompatibleLLM
  2. 构建一个CustomVanna,在其中指定自己的大模型服务和Embedding服务的参数
  3. 链接数据库,比如mysql
  4. 启动服务
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):def __init__(self, llm_config=None, vector_store_config=None):CustomQdrant_VectorStore.__init__(self, config=vector_store_config)OpenAICompatibleLLM.__init__(self, config=llm_config)vn = CustomVanna(vector_store_config={"client": QdrantClient(host="xxxxx", port=6333)},llm_config={"api_key": "sk-xxxxxxxxxxxx","api_base": "https://xxxxxxxxxxxxxxxxxx/v1","model": "xxxxxxx",},
)vn.connect_to_mysql(host='xxxxx', dbname='xxx', user='xxx', password='xxx', port=3306)from vanna.flask import VannaFlaskApp
app = VannaFlaskApp(vn)
app.run()

文章转载自:

http://NNzqyism.tphrx.cn
http://KWLBlfr1.tphrx.cn
http://AryPGSyt.tphrx.cn
http://JvR5RxAr.tphrx.cn
http://cLCP5hj2.tphrx.cn
http://VC459RkM.tphrx.cn
http://9t8h8lom.tphrx.cn
http://04Q3t9TK.tphrx.cn
http://oNqvSkZu.tphrx.cn
http://4eRF540A.tphrx.cn
http://gCLwmb5Q.tphrx.cn
http://nK3SJxZt.tphrx.cn
http://c9C0kZud.tphrx.cn
http://gHf25gkt.tphrx.cn
http://z59PrZPN.tphrx.cn
http://5p8fvyn6.tphrx.cn
http://XXjeZf8H.tphrx.cn
http://HaAZcIbA.tphrx.cn
http://Dlu0byja.tphrx.cn
http://xlAjy7pP.tphrx.cn
http://F1cKsqKT.tphrx.cn
http://13kjFnxF.tphrx.cn
http://52YeqAOG.tphrx.cn
http://4CevAVob.tphrx.cn
http://xnplkJd1.tphrx.cn
http://s77isAEk.tphrx.cn
http://AEblXKWu.tphrx.cn
http://fUexUI4w.tphrx.cn
http://I9XzGPfi.tphrx.cn
http://Z4xcghz2.tphrx.cn
http://www.dtcms.com/a/380285.html

相关文章:

  • DataCollatorForCompletionOnlyLM解析(93)
  • 淘宝RecGPT:通过LLM增强推荐
  • Vue3 中使用 DOMPurify 对渲染动态 HTML 进行安全净化处理
  • 比较 iPhone:全面比较 iPhone 17 系列
  • 【Doris】集群介绍
  • 从“能写”到“能干活”:大模型工具调用(Function-Calling)的工程化落地指南
  • golang程序内存泄漏分析方法论
  • Go 语言 MQTT 消息队列学习指导文档
  • 基于数据挖掘技术构建电信5G客户预测模型的研究与应用
  • 【AI】pickle模块常见用途
  • 智慧园区,智启未来 —— 重塑高效、绿色、安全的产业新生态
  • MySQL 8新特性
  • 腾讯开源Youtu-GraphRAG
  • QT M/V架构开发实战:QStringListModel介绍
  • 【数据结构】Java集合框架:List与ArrayList
  • 开发避坑指南(48):Java Stream 判断List元素的属性是否包含指定的值
  • postgresql 数据库备份、重新构建容器
  • 大数据电商流量分析项目实战:Spark SQL 基础(四)
  • vmware ubuntu18设置共享文件夹的几个重要点
  • 每日一题(5)
  • Lumerical licence center 无法连接的问题
  • Java网络编程(2):(socket API编程:UDP协议的 socket API -- 回显程序)
  • Java 类加载机制双亲委派与自定义类加载器
  • OpenLayers数据源集成 -- 章节九:必应地图集成详解
  • 前端调试工具有哪些?常用前端调试工具推荐、前端调试工具对比与最佳实践
  • 【C++练习】16.C++将一个十进制转换为二进制
  • 公司本地服务器上搭建部署的办公系统web项目网站,怎么让外网访问?有无公网IP下的2种通用方法教程
  • 【C++】string类 模拟实现
  • 【系列文章】Linux中的并发与竞争[02]-原子操作
  • 微信小程序 -开发邮箱注册验证功能