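Speeding up ESMC-6B model API calls with multithreading, and fixing the 403 Forbidden problem
Preface
This post is only a supplement to my earlier article on the same topic.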
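Fixing the 403 Forbidden problem
Over the past few days I have been using the ESMC-6B API and found that my requests were being rejected with 403 Forbidden.
After a lot of digging, it turned out that the API can only be reached through a proxy that gets around the firewall (singled out yet again, apparently).
That means the server needs access to a VPN. Setting one up on the server itself seemed like too much trouble, so I went with SSH port forwarding instead.
First, run the following command on your local machine to set up an SSH reverse tunnel: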
ssh -R 127.0.0.1:7890:127.0.0.1:7890 user@remote_server.com
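Then, in a screen session on the server, run the following so that all of that terminal's network traffic is sent through the proxy on port 7890: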
export https_proxy=http://127.0.0.1:7890 http_proxy=http://127.0.0.1:7890 all_proxy=socks5://127.0.0.1:7890
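With this in place, every network request made from that one terminal on the server is forwarded through the SSH reverse tunnel to the Clash port on the local machine, so the server effectively uses the local machine's VPN.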
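Before starting a long job, it can be worth checking from Python that the tunnel is actually listening and that the proxy variables are visible to the process. The sketch below is not part of the original setup: `port_is_open` and `proxy_report` are hypothetical helper names, and it assumes the proxy sits on 127.0.0.1:7890 as in the commands above.

```python
import socket
import urllib.request


def port_is_open(host: str = "127.0.0.1", port: int = 7890, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused or timed out: nothing is listening on this end of the tunnel.
        return False


def proxy_report() -> dict:
    """Return the proxy settings that Python's standard library picks up."""
    return urllib.request.getproxies()


if __name__ == "__main__":
    print("tunnel listening:", port_is_open())
    print("proxies seen by Python:", proxy_report())
```

If the first line prints False, the ssh session on the local machine has probably dropped; if the second is empty, the export was likely run in a different terminal from the one running the script.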
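Optimizing API calls with multithreading
Verdict: I had forgotten everything from my computer networking course.
Another way to speed things up is multithreading (silly me, I had not thought of it before). The official team added a tutorial on this in their January 2025 update.
It seems to work reasonably well: in one afternoon run of 1 hour 40 minutes I got through 6,600+ sequences, which averages out to roughly one sequence per second.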
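Threads help here because each API call spends almost all of its time waiting on the network, so several requests can be in flight at once. A minimal sketch of the idea, where `fake_api_call` is a hypothetical stand-in that just sleeps instead of contacting Forge:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_api_call(seq_id: str, delay: float = 0.1) -> str:
    """Hypothetical stand-in for a network request: wait, then return a result."""
    time.sleep(delay)
    return f"{seq_id}:done"


def run_serial(ids):
    """Issue the calls one after another."""
    return [fake_api_call(i) for i in ids]


def run_threaded(ids, max_workers: int = 8):
    """Issue the calls from a thread pool; executor.map keeps results in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(fake_api_call, ids))


def timed(fn, *args):
    """Return (result, elapsed_seconds) for fn(*args)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


if __name__ == "__main__":
    ids = [f"p{i}" for i in range(8)]
    _, t_serial = timed(run_serial, ids)
    _, t_threaded = timed(run_threaded, ids)
    print(f"serial: {t_serial:.2f}s, threaded: {t_threaded:.2f}s")
```

The real speedup depends on how many concurrent requests the service accepts, so the worker count is something to tune rather than a fixed recommendation.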
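As for token approval, I waited a little over a month before mine was approved... If yours has not come through yet, all you can really do is keep waiting...
Also mind the length limit: proteins longer than 2048 residues cannot be sent to the embedding endpoint.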
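Because the limit applies to the sequence itself, it helps to measure length only after header lines and line breaks have been removed. The sketch below shows one way to do that; `read_first_fasta_record` and `within_limit` are hypothetical helpers, and the 2048 cap is simply the figure quoted above.

```python
MAX_LEN = 2048  # length limit quoted above


def read_first_fasta_record(path: str) -> str:
    """Return the first sequence in a FASTA file, joining wrapped lines."""
    parts = []
    seen_header = False
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if seen_header or parts:
                    break  # a second record starts here
                seen_header = True
                continue
            parts.append(line)
    return "".join(parts)


def within_limit(seq: str, max_len: int = MAX_LEN) -> bool:
    """True if the sequence is non-empty and no longer than max_len."""
    return 0 < len(seq) <= max_len
```

Reading the file this way also handles FASTA files whose sequence is wrapped over several lines, which a plain second-line read would truncate.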
Code
This uses version 3.1.3 of the esm library.
from esm.sdk import client
from getpass import getpass
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence
import os
from tqdm import tqdm
from time import sleep
import pickle
import torch
from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    LogitsConfig,
    LogitsOutput,
    ProteinType,
)
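# Adjust this to match your own input format.
def read_seq(seqfilepath):
    with open(seqfilepath, "r") as f:
        f.readline()  # skip the FASTA header line
        # Strip the trailing newline so len(seq) counts residues only.
        seq = f.readline().strip()
    return seq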
def embed_sequence(model: ESM3InferenceClient, protein_id: str, sequence: str) -> tuple:
    """Embed one sequence and return (protein_id, embeddings summed over dim=1)."""
    protein = ESMProtein(sequence=sequence)
    # A returned ESMProteinError signals a failed request: print it, wait 1 s, and retry.
    while True:
        protein_tensor = model.encode(protein)
        if isinstance(protein_tensor, ESMProteinError):
            print(protein_tensor)
            sleep(1)
            continue
        break
    # Same retry pattern for the logits/embedding call.
    while True:
        logits_output = model.logits(protein_tensor, LogitsConfig(sequence=True, return_embeddings=True))
        if isinstance(logits_output, ESMProteinError):
            print(logits_output)
            sleep(1)
            continue
        break
    return protein_id, logits_output.embeddings.sum(dim=1)
def batch_embed(model: ESM3InferenceClient, inputs, embedding_dir):
    """Forge supports auto-batching, so batch_embed() is as simple as running a collection
    of embed calls in parallel using a thread pool.
    """
    error_list = []
    with ThreadPoolExecutor(max_workers=16) as executor:
        # Keep each protein_id next to its future so a failure is attributed to the right protein.
        futures = [
            (protein_id, executor.submit(embed_sequence, model, protein_id, sequence))
            for protein_id, sequence in inputs.items()
        ]
        total = len(futures)
        for i, (protein_id, future) in enumerate(futures):
            try:
                _, emb = future.result()
                out_dir = os.path.join(embedding_dir, protein_id)
                # Create the per-protein output directory in case it does not exist yet.
                os.makedirs(out_dir, exist_ok=True)
                with open(os.path.join(out_dir, "origin_seq_emb_6b.pkl"), "wb") as f:
                    pickle.dump(emb, f)
                print(i, "/", total, " Success ", protein_id)
            except Exception as e:
                print(i, "/", total, f" Error: {e}")
                error_list.append(protein_id)
    return error_list
token = getpass("Token from Forge console: ")
model = client(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token=token)
seq_dict = {}
data_dir = "input path (change this to your own)"
embedding_dir = "output path"
for protein_id in tqdm(os.listdir(data_dir)):
    seq_path = os.path.join(data_dir, protein_id, "seq.fasta")
    # Strip whitespace so a trailing newline is not counted towards the length.
    seq = read_seq(seq_path).strip()
    # Skip sequences over the 2048 length limit.
    if len(seq) > 2048:
        continue
    # Skip proteins whose embedding is already saved, so an interrupted run can be resumed.
    if os.path.exists(os.path.join(embedding_dir, protein_id, "origin_seq_emb_6b.pkl")):
        continue
    seq_dict[protein_id] = seq
error_list = batch_embed(model, seq_dict, embedding_dir)
# Record the proteins that failed so they can be retried later.
import json
with open("error_list.json", "w") as f:
    json.dump(error_list, f, indent=4)
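The retry loops in embed_sequence wait one second and try again indefinitely, so a sequence that always fails would keep its worker thread busy forever. One possible alternative, sketched here under assumptions rather than taken from the official tutorial, is a bounded retry with exponential backoff. `call_with_retry` is a hypothetical helper, and its `is_error` predicate stands in for the `isinstance(result, ESMProteinError)` check used above.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retry(
    fn: Callable[[], T],
    is_error: Callable[[T], bool],
    max_attempts: int = 5,
    base_delay: float = 1.0,
) -> T:
    """Call fn() until is_error(result) is False; raise after max_attempts tries."""
    delay = base_delay
    for attempt in range(1, max_attempts + 1):
        result = fn()
        if not is_error(result):
            return result
        if attempt < max_attempts:
            time.sleep(delay)
            delay *= 2  # double the wait after each failed attempt
    raise RuntimeError(f"still failing after {max_attempts} attempts: {result!r}")
```

Inside embed_sequence this might be called as `call_with_retry(lambda: model.encode(protein), lambda r: isinstance(r, ESMProteinError))`. A RuntimeError raised in the worker would then surface from `future.result()`, be caught by the `except` in batch_embed, and put the protein in error_list.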