open webui源码分析10-四个特征之搜索
open webui支持四个可插拔的特征,可用于加强大模型的能力,四个特征分别是记忆(memory)、搜索(web search)、文生图(image generation)和代码解析(code interpreter),今天先从搜索开始分析。
一、启用搜索
open webui缺省启动时是不支持联网搜索的,需要通过设置相关环境变量后才能使用。具体环境变量包括:
ENABLE_WEB_SEARCH:true为启用。缺省为禁用。
WEB_SEARCH_ENGINE:使用的搜索引擎名,比如google_pse,tavily等
{引擎名大写}_API_KEY:访问引擎时的API_KEY,比如TAVILY_API_KEY
这些环境变量可以在制作镜像时配置,也可以在运行容器时指定。
比如:我们想使用tavily搜索引擎,则在启动open webui时,通过--env传入,具体如下:
--env ENABLE_WET_SEARCH=true --env WEB_SEARCH_ENGINE=tavily --env TAVILY_API_KEY={你申请的KEY}
启用搜索后,如下对话窗口可以看到联网搜索按钮,用户可以自由选择是否启用:
二、源码分析
1)请求数据
启用联网搜索有,发起会话请求时数据与未启用时的区别是features中的web_search被设置为true。
2)源码分析
相关源码入口仍然在process_chat_payload函数中,具体如下:
async def process_chat_payload(request, form_data, user, metadata, model):
……
features = form_data.pop("features", None)
if features:……
'''
**********************看这里*********************************。
检查前端数据中的features['web_search'],如果为true,则调用
chat_web_search_handler方法
'''
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
)……
……
下面分析chat_web_search_handler方法源码:
本方法流程如下:
1)通知前端要生成输入生成搜索问题
2)调用大模型根据原始输入生成搜索问题
3)调用process_web_search,并把返回的数据作为文件追加到表单的files中,作为上下文
4)把搜索结果的简略信息推送到前端
async def chat_web_search_handler(
request: Request, form_data: dict, extra_params: dict, user
):
event_emitter = extra_params["__event_emitter__"]
await event_emitter(推送发起搜索通知给前端
{
"type": "status",
"data": {
"action": "web_search",
"description": "Generating search query",
"done": False,
},
}
)messages = form_data["messages"]
user_message = get_last_user_message(messages)#用户当前输入queries = []
try:
res = await generate_queries(#调用大模型生成搜索问题
request,
{
"model": form_data["model"],
"messages": messages,
"prompt": user_message,
"type": "web_search",
},
user,
)#response是大模型针对用户问题分析后得到的搜索问题列表
response = res["choices"][0]["message"]["content"]
try:
'''
具体的查询问题被{}包围,比如:
\n \"queries\": [\"-current korean president\"]\n}
'''
bracket_start = response.find("{")
bracket_end = response.rfind("}") + 1if bracket_start == -1 or bracket_end == -1:
raise Exception("No JSON object found in the response")response = response[bracket_start:bracket_end]
queries = json.loads(response)
queries = queries.get("queries", [])
except Exception as e:
queries = [response]except Exception as e:
log.exception(e)
queries = [user_message]#以下是两个 防错处理。
#如果大模型返回的问题为"",则把用户原始问题作为查询问题
if len(queries) == 1 and queries[0].strip() == "":
queries = [user_message]
if len(queries) == 0:#如果大模型未返回需要查询的问题,推送状态通知到前端
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search query generated",
"done": True,
},
}
)
return form_dataawait event_emitter(#推送开始搜索通知到前端
{
"type": "status",
"data": {
"action": "web_search",
"description": "Searching the web",
"done": False,
},
}
)try:
'''
进行网络搜索,返回数据包括status, collection_names, filenames, loaded_count
'''
results = await process_web_search(
request,
SearchForm(queries=queries),
user=user,
)if results:
files = form_data.get("files", [])if results.get("collection_names"):#看这里
for col_idx, collection_name in enumerate(
results.get("collection_names")
):
files.append(#把加工后的搜索内容作为一个文件追加到files列表中
{
"collection_name": collection_name,
"name": ", ".join(queries),
"type": "web_search",
"urls": results["filenames"],
"queries": queries,
}
)
elif results.get("docs"):
# Invoked when bypass embedding and retrieval is set to True
docs = results["docs"]
files.append(
{
"docs": docs,
"name": ", ".join(queries),
"type": "web_search",
"urls": results["filenames"],
"queries": queries,
}
)form_data["files"] = files#更新表单中的files
await event_emitter(#把搜索结果简略信息推送到前端
{
"type": "status",
"data": {
"action": "web_search",
"description": "Searched {{count}} sites",
"urls": results["filenames"],
"done": True,
},
}
)
else:#如果没有搜索到任何结果,则推送通知到前端
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "No search results found",
"done": True,
"error": True,
},
}
)except Exception as e:
log.exception(e)
await event_emitter(
{
"type": "status",
"data": {
"action": "web_search",
"description": "An error occurred while searching the web",
"queries": queries,
"done": True,
"error": True,
},
}
)return form_data
搜索执行逻辑在process_web_search需要重点分析,具体代码如下:
本方法流程如下:
1)针对每个搜索问题创建一个异步任务search_web,并分别启动
2)收集多个异步任务的搜索结果
3)收集所有搜索结果的网页地址,并使用WebBaseLoader加载所有网页
4)把WebBaseLoader加载的数据插入到向量库中
@router.post("/process/web/search")
async def process_web_search(
request: Request, form_data: SearchForm, user=Depends(get_verified_user)
):urls = []
try:
logging.info(
f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}"
)search_tasks = [#针对每个搜索问题,创建一个异步搜索任务
run_in_threadpool(
search_web, #调用对应搜索引擎
request,
request.app.state.config.WEB_SEARCH_ENGINE,
query,
)
for query in form_data.queries
]'''
每个异步任务搜索一个问题,返回结果为SearchResult(link, title, content)组成的数组
汇总后的返回结果为二维数组[[SearchResult1,SearchResult2, SearchResult3],…… ]
'''
search_results = await asyncio.gather(*search_tasks)#启动所有的异步搜索任务
for result in search_results:#遍历外层数组
if result:
for item in result: #遍历内层数组
if item and item.link:
urls.append(item.link) #把每个搜索结果的地址增加到urls中#这行代码有些费解,urls已经是数组了,为什么还要转成dict?
urls = list(dict.fromkeys(urls))
log.debug(f"urls: {urls}")except Exception as e:
log.exception(e)raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)try:
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:#不须考虑
search_results = [
item for result in search_results for item in result if result
]docs = [
Document(
page_content=result.snippet,
metadata={
"source": result.link,
"title": result.title,
"snippet": result.snippet,
"link": result.link,
},
)
for result in search_results
if hasattr(result, "snippet")
]
else: #缺省用WebBaseLoader加载网页,也可以在启动容器时配置
loader = get_web_loader(
urls,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload() #docs为Document对象数组urls = [#用加载网页的source组成urls列表
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
]if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
#该分支不须关注
return {
"status": True,
"collection_name": None,
"filenames": urls,
"docs": [
{
"content": doc.page_content,
"metadata": doc.metadata,
}
for doc in docs
],
"loaded_count": len(docs),
}
else:#关键代码在这里
'''计算所有根据搜索问题的sha256摘要,用前面的63字节拼接到web-search-之后作
为集合名
'''
collection_name = (
f"web-search-{calculate_sha256_string('-'.join(form_data.queries))}"[
:63
]
)try:
#把所有加载的网页内容增加向量库中前面所生成的集合名指定的集合中
await run_in_threadpool(
save_docs_to_vector_db,
request,
docs,
collection_name,
overwrite=True,
user=user,
)
except Exception as e:
log.debug(f"error saving docs: {e}")return {#返回数据
"status": True,
"collection_names": [collection_name],
"filenames": urls,
"loaded_count": len(docs),
}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
现在再看一下search_web的代码。该方法根据系统在启动时设置的搜索引擎,调用对应的引擎的API进行搜索,并返回搜索结果,以使用tavily为例。
代码很简明,不做分析。
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
……
elif engine == "tavily":
if request.app.state.config.TAVILY_API_KEY:
return search_tavily(
request.app.state.config.TAVILY_API_KEY,#环境变量传入
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT, #缺省为3
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,#缺省为空
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")……
以上代码中 search_tavily是调用搜索引擎的核心代码,需要分析一下:
本方法主要逻辑如下:
1)组织HTTP请求,包括在头部设置API_KEY和内容类型,在HTTP BODY中设置搜索问题及返回结果数量
2)发送HTTP请求到tavily的API服务
3)解析结果并做白名单检查后返回搜索结果
def search_tavily(
api_key: str,
query: str, #查询语句,也就是前面大模型根据用户问题生成的搜索问题
count: int,#返回结果数
filter_list: Optional[list[str]] = None,#过滤器,实际就是白名单,用来限制搜索结果的范围
# **kwargs,
) -> list[SearchResult]:
url = "https://api.tavily.com/search"#tavily API地址'''
以下代码组织搜索请求,包括在HTTP header设置Content_Type和在Authorization设置
API_KEY,在HTTP Body中设置查询内容和返回结果数量
'''
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
data = {"query": query, "max_results": count}'''
调用tavily API。返回结果形式如下:
{
"query":"搜索问题",
"follow_up_questions":null,
"answer":null,
"images":[],"
"results":[
{
"url":"搜索结果网页地址",
"title":"网页标题",
"score":"得分",
"published_date":"网页发布日期",
"content":"网页内容",
"raw_content":"原始内容"
},
……
]
}
'''
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()json_response = response.json()
results = json_response.get("results", [])#提取应答json中的 results
if filter_list:
results = get_filtered_results(results, filter_list)#过滤掉不在白名单中的网页return [#返回SearchResult列表
SearchResult(
link=result["url"],
title=result.get("title", ""),
snippet=result.get("content"),
)
for result in results
]
联网搜索相关代码分析到此结束。