Modifying the ankh library to load local models on a server without internet access
Contents
I. How the problem arose
II. Solving the problem
1. Modify ankh_transformers.py
2. Modify the first __init__.py
3. Modify the second __init__.py
III. Conclusion
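I. How the problem arose
This section explains where the problem came from. If you don't want to read the background, skip straight to the solution steps below.
I planned to deploy three modules from this paper, ChemBERTa, ESM2 and Ankh, and ran into some problems along the way. Ankh was by far the most troublesome.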


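I deployed these modules because my senior labmate, a PhD student, considers all three highly general-purpose and likely to be useful again later, and I agree with him.

Second, the first two modules also load model weights hosted on Hugging Face, but they do so through the Transformers package, which can load a model from a local directory. Ankh, by contrast, is a pip-installed package whose model-loading functions are built in and only load models online. Those functions only need to reach the Hugging Face website once to download the complete model files, but our lab server is kept off the internet for security reasons, so even that single download is impossible. My plan of downloading the models on a PC and uploading them to the server therefore does not work for Ankh out of the box, and that is the problem to solve.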
For ChemBERTa and ESM2, switching to local files only requires pointing Transformers at the uploaded folders instead of the Hub repository names:

```python
# Old version (online)
# Load the ChemBERTa model
from transformers import AutoModel, AutoTokenizer

model_name = f'DeepChem/{model_descriptor}'
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./huggingface')
model = AutoModel.from_pretrained(model_name, cache_dir='./huggingface')
model.to(device).eval()
```

```python
# New version (local)
# Load the ChemBERTa model from the uploaded local folder
from transformers import AutoModel, AutoTokenizer

if model_descriptor == "ChemBERTa-77M-MLM":
    local_model_path = "/root/GEMS/huggingface/ChemBERTa-77M-MLM"
elif model_descriptor == "ChemBERTa-10M-MLM":
    local_model_path = "/root/GEMS/huggingface/ChemBERTa-10M-MLM"
else:
    local_model_path = "/root/GEMS/huggingface/ChemBERTa-77M-MLM"
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
model = AutoModel.from_pretrained(local_model_path)
model.to(device).eval()
```

```python
# Old version (online)
# Load ESM model and tokenizer from HuggingFace
if checkpoint == 't6':
    model_name = "facebook/esm2_t6_8M_UR50D"
    model_descriptor = 'esm2_t6_8M_UR50D'
    embedding_size = 320
if checkpoint == 't12':
    model_name = "facebook/esm2_t12_35M_UR50D"
    model_descriptor = 'esm2_t12_35M_UR50D'
    embedding_size = 480
if checkpoint == 't30':
    model_name = "facebook/esm2_t30_150M_UR50D"
    model_descriptor = 'esm2_t30_150M_UR50D'
    embedding_size = 640
if checkpoint == 't33':
    model_name = "facebook/esm2_t33_650M_UR50D"
    model_descriptor = 'esm2_t33_650M_UR50D'
    embedding_size = 1280
```

```python
# New version (local)
# Load ESM model and tokenizer from the local folders (files downloaded from HuggingFace)
if checkpoint == 't6':
    model_name = "/root/GEMS/huggingface/esm2_t6_8M_UR50D"
    model_descriptor = 'esm2_t6_8M_UR50D'
    embedding_size = 320
if checkpoint == 't12':
    model_name = "/root/GEMS/huggingface/esm2_t12_35M_UR50D"
    model_descriptor = 'esm2_t12_35M_UR50D'
    embedding_size = 480
if checkpoint == 't30':
    model_name = "/root/GEMS/huggingface/esm2_t30_150M_UR50D"
    model_descriptor = 'esm2_t30_150M_UR50D'
    embedding_size = 640
if checkpoint == 't33':
    model_name = "/root/GEMS/huggingface/esm2_t33_650M_UR50D"
    model_descriptor = 'esm2_t33_650M_UR50D'
    embedding_size = 1280
```
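The two if-chains above map a short name to a local folder plus some metadata. As an optional tidy-up, the same mapping can live in a lookup table. This is a minimal sketch, assuming the author's layout where every model sits in a folder named after its descriptor under `/root/GEMS/huggingface`; `LOCAL_ROOT`, `resolve_esm2` and `resolve_chemberta` are hypothetical names, not part of GEMS.

```python
import posixpath
from typing import Dict, Tuple

# Hypothetical constant: the folder the downloads were uploaded to (the author's layout)
LOCAL_ROOT = "/root/GEMS/huggingface"

# checkpoint -> (model descriptor == local folder name, embedding size), copied from the if-chains above
ESM2_CHECKPOINTS: Dict[str, Tuple[str, int]] = {
    "t6": ("esm2_t6_8M_UR50D", 320),
    "t12": ("esm2_t12_35M_UR50D", 480),
    "t30": ("esm2_t30_150M_UR50D", 640),
    "t33": ("esm2_t33_650M_UR50D", 1280),
}

# ChemBERTa folders that exist locally; anything else falls back to 77M like the else-branch
CHEMBERTA_MODELS = ("ChemBERTa-77M-MLM", "ChemBERTa-10M-MLM")


def resolve_esm2(checkpoint: str, root: str = LOCAL_ROOT) -> Tuple[str, str, int]:
    """Return (model_name, model_descriptor, embedding_size) for a local ESM2 checkpoint."""
    if checkpoint not in ESM2_CHECKPOINTS:
        # The if-chain leaves the variables undefined here; failing loudly is clearer
        raise ValueError(f"unknown checkpoint {checkpoint!r}, expected one of {sorted(ESM2_CHECKPOINTS)}")
    descriptor, embedding_size = ESM2_CHECKPOINTS[checkpoint]
    # Server paths are POSIX, so join with posixpath regardless of the local OS
    return posixpath.join(root, descriptor), descriptor, embedding_size


def resolve_chemberta(model_descriptor: str, root: str = LOCAL_ROOT) -> str:
    """Return the local ChemBERTa folder, defaulting to ChemBERTa-77M-MLM."""
    name = model_descriptor if model_descriptor in CHEMBERTA_MODELS else "ChemBERTa-77M-MLM"
    return posixpath.join(root, name)
```

Usage would be `model_name, model_descriptor, embedding_size = resolve_esm2(checkpoint)` and `local_model_path = resolve_chemberta(model_descriptor)`; supporting another checkpoint then means adding one dictionary entry instead of another branch.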
Here is the usage shown on GitHub:

Below is the output when the script calls ankh.load_base_model():
```text
root@mytry:~/GEMS# python ankh_features.py --data_dir example_dataset --ankh_base True
True
cuda:0
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5ce750>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 4b41b4a5-5783-4d3d-ad31-b80fe817461d)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json
Retrying in 1s [Retry 1/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5dd590>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 7ae9f02a-1d0b-4f5a-abea-fd51055aa003)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json
Retrying in 2s [Retry 2/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5de8d0>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 243d445c-904e-47f2-b220-fa82afe63e0c)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json
Retrying in 4s [Retry 3/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5dfc10>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: f00207b6-eda0-492f-a1f3-f8194ce9c027)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json
Retrying in 8s [Retry 4/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5f4dd0>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 47112e6f-fc9c-45e9-a43e-8056b9969e58)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json
Retrying in 8s [Retry 5/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5f5dd0>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 69ec0e26-4857-44ce-af45-35262dcf635d)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/tokenizer_config.json
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5dd250>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: d2d72a9a-1979-4413-bab3-d2b635c325c6)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/config.json
Retrying in 1s [Retry 1/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5ce610>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 8e38cfd9-34ef-4b0e-b534-43a8f0e0426a)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/config.json
Retrying in 2s [Retry 2/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5f4c50>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 57151e14-2231-453b-b93b-698ac21152b0)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/config.json
Retrying in 4s [Retry 3/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5f6a50>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 3632e73d-a11a-423c-92e9-7aa3169db682)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/config.json
Retrying in 8s [Retry 4/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb5f7ad0>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 37872765-2668-4b7b-8e38-51ea9f6ffa33)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/config.json
Retrying in 8s [Retry 5/5].
'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb604b10>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 22f070e3-3f9b-42c2-8f88-18e495cbeba7)')' thrown while requesting HEAD https://huggingface.co/ElnaggarLab/ankh-base/resolve/main/config.json
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connection.py", line 198, in _new_conn
    sock = connection.create_connection(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/urllib3/util/connection.py", line 85, in create_connection
    raise err
  File "/opt/conda/lib/python3.11/site-packages/urllib3/util/connection.py", line 73, in create_connection
    sock.connect(sa)
OSError: [Errno 101] Network is unreachable

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connectionpool.py", line 787, in urlopen
    response = self._make_request(
               ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connectionpool.py", line 488, in _make_request
    raise new_e
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connectionpool.py", line 464, in _make_request
    self._validate_conn(conn)
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connectionpool.py", line 1093, in _validate_conn
    conn.connect()
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connection.py", line 753, in connect
    self.sock = sock = self._new_conn()
                       ^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connection.py", line 213, in _new_conn
    raise NewConnectionError(
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x7f38cb604b10>: Failed to establish a new connection: [Errno 101] Network is unreachable

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/requests/adapters.py", line 644, in send
    resp = conn.urlopen(
           ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connectionpool.py", line 841, in urlopen
    retries = retries.increment(
              ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/urllib3/util/retry.py", line 519, in increment
    raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb604b10>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 1543, in _get_metadata_or_catch_error
    metadata = get_hf_file_metadata(
               ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 1460, in get_hf_file_metadata
    r = _request_wrapper(
        ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 283, in _request_wrapper
    response = _request_wrapper(
               ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 306, in _request_wrapper
    response = http_backoff(method=method, url=url, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_http.py", line 325, in http_backoff
    raise err
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_http.py", line 306, in http_backoff
    response = session.request(method=method, url=url, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/requests/sessions.py", line 589, in request
    resp = self.send(prep, **send_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/requests/sessions.py", line 703, in send
    r = adapter.send(request, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_http.py", line 95, in send
    return super().send(request, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/requests/adapters.py", line 677, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /ElnaggarLab/ankh-base/resolve/main/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f38cb604b10>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 22f070e3-3f9b-42c2-8f88-18e495cbeba7)')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/transformers/utils/hub.py", line 479, in cached_files
    hf_hub_download(
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 1007, in hf_hub_download
    return _hf_hub_download_to_cache_dir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 1114, in _hf_hub_download_to_cache_dir
    _raise_on_head_call_error(head_call_error, force_download, local_files_only)
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/file_download.py", line 1658, in _raise_on_head_call_error
    raise LocalEntryNotFoundError(
huggingface_hub.errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/GEMS/ankh_features.py", line 55, in <module>
    model, tokenizer = ankh.load_base_model()
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/ankh/models/ankh_transformers.py", line 104, in load_base_model
    tokenizer = AutoTokenizer.from_pretrained(AvailableModels.ANKH_BASE.value)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py", line 1093, in from_pretrained
    config = AutoConfig.from_pretrained(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py", line 1332, in from_pretrained
    config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/configuration_utils.py", line 662, in get_config_dict
    config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/configuration_utils.py", line 721, in _get_config_dict
    resolved_config_file = cached_file(
                           ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/utils/hub.py", line 322, in cached_file
    file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/utils/hub.py", line 553, in cached_files
    raise OSError(
OSError: We couldn't connect to 'https://huggingface.co' to load the files, and couldn't find them in the cached files.
Check your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.
```
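This is clearly a network problem: load_base_model() tries to fetch tokenizer_config.json and config.json for ElnaggarLab/ankh-base from huggingface.co, the server cannot reach that host, and the files are not in the local cache either.
The official Ankh repository is agemagician/Ankh (Ankh: Optimized Protein Language Model) at https://github.com/agemagician/Ankh, if you want to take a look.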
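Before patching anything, it can help to confirm that the server really cannot open a connection to the Hub. A minimal sketch using only the standard library; `can_reach` is a hypothetical helper, and port 443 is simply the HTTPS port that appears in the log above.

```python
import socket


def can_reach(host: str = "huggingface.co", port: int = 443, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port can be opened within `timeout` seconds."""
    try:
        # create_connection resolves the host and connects; the socket is closed on exit
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers DNS failures, refused connections, timeouts and "[Errno 101] Network is unreachable"
        return False
```

On a server producing the `[Errno 101] Network is unreachable` errors above, `print(can_reach())` should print `False`, while on the PC used for downloading it should print `True`.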
II. Solving the problem
Under the hood, ankh also loads its models through Transformers, so we only need to modify the following files:
1. Modify ankh_transformers.py
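On GitHub this file sits in the models folder under src/ankh. On the server, you can locate the installed package with the following commands:

```text
root@mytry:~/GEMS# python
Python 3.11.11 | packaged by conda-forge | (main, Dec 5 2024, 14:17:24) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import ankh
>>> print(ankh.__file__)
/opt/conda/lib/python3.11/site-packages/ankh/__init__.py
```

ankh.__file__ points to the package's top-level __init__.py, so ankh_transformers.py is in the models subfolder next to it: /opt/conda/lib/python3.11/site-packages/ankh/models/ankh_transformers.py.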
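To get the path of a specific submodule directly, the standard library's `importlib.util.find_spec` can be used. This is a minimal sketch; `module_file` is a hypothetical helper name. Note that for a dotted name, `find_spec` imports the parent packages (here `ankh` and `ankh.models`) in order to search them.

```python
import importlib.util
from typing import Optional


def module_file(module_name: str) -> Optional[str]:
    """Return the source file path of an installed module, or None if it cannot be found."""
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # Raised when a parent package of a dotted name is missing
        return None
    # spec is None for a missing top-level module; origin is the .py file for normal modules
    return spec.origin if spec is not None else None
```

On a server set up like the one above, `print(module_file("ankh.models.ankh_transformers"))` should print `/opt/conda/lib/python3.11/site-packages/ankh/models/ankh_transformers.py`, the same path that appears in the traceback.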
Next, edit this file by appending the following two functions to the very end of it:
```python
# These comments are optional. local_path is my default value; change it to wherever you
# stored the model files. local_path must be the folder that directly contains config.json
# and the tokenizer files (e.g. /root/GEMS/huggingface/ankh-base), not its parent folder.
# With a correct default you can simply call
#     model, tokenizer = ankh.load_local_base_model()
# Otherwise, just pass the path explicitly when calling, e.g.
#     model, tokenizer = ankh.load_local_base_model("/root/GEMS/huggingface/ankh-base")
# The same applies to the large model below.
# Tuple, Union, AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration and
# get_specified_model are already available in ankh_transformers.py, which is why
# these functions are appended to that file.

def load_local_base_model(
    local_path: str = "/root/GEMS/huggingface/ankh-base",
    generation: bool = False,
    output_attentions: bool = False,
    model_format: str = 'pt',
) -> Tuple[Union[T5EncoderModel, T5ForConditionalGeneration], AutoTokenizer]:
    """Loads the base model and its tokenizer from a local path.

    Args:
        local_path (str, optional): Local folder containing the model files.
            Defaults to "/root/GEMS/huggingface/ankh-base".
        generation (bool, optional): If True, `T5ForConditionalGeneration` is
            returned, otherwise `T5EncoderModel`. Defaults to False.
        output_attentions (bool, optional): Whether to return the attention or
            not. Defaults to False.
        model_format (str, optional): The model format, currently supports
            'tf' and 'pt'. Defaults to 'pt'.

    Returns:
        Tuple[Union[T5EncoderModel, T5ForConditionalGeneration], AutoTokenizer]:
            The T5 model and its tokenizer.
    """
    # Load tokenizer from local path
    tokenizer = AutoTokenizer.from_pretrained(local_path)
    # Load model from local path using the same logic as online loading
    model = get_specified_model(
        path=local_path,
        generation=generation,
        output_attentions=output_attentions,
        model_format=model_format,
    )
    return model, tokenizer


def load_local_large_model(
    # Assumed folder name, following the ankh-base convention; adjust to your layout
    local_path: str = "/root/GEMS/huggingface/ankh-large",
    generation: bool = False,
    output_attentions: bool = False,
    model_format: str = 'pt',
) -> Tuple[Union[T5EncoderModel, T5ForConditionalGeneration], AutoTokenizer]:
    """Loads the large model and its tokenizer from a local path.

    Args and Returns are the same as for load_local_base_model.
    """
    # Load tokenizer from local path
    tokenizer = AutoTokenizer.from_pretrained(local_path)
    # Load model from local path using the same logic as online loading
    model = get_specified_model(
        path=local_path,
        generation=generation,
        output_attentions=output_attentions,
        model_format=model_format,
    )
    return model, tokenizer
```
For reference, here is what the model page looks like on the Hugging Face website and what the downloaded folder looks like:


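In Xftp, the left pane is my local download and the right pane is the folder on the server that the files were copied to. So, to get around the network problem, all you need to do is change

```python
model, tokenizer = ankh.load_base_model()
```

to:

```python
model, tokenizer = ankh.load_local_base_model("/root/GEMS/huggingface/ankh-base")
```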
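If local_path points at the wrong folder (for example the parent `/root/GEMS/huggingface` instead of `/root/GEMS/huggingface/ankh-base`), the problem only surfaces once `from_pretrained` runs. A small pre-flight check can report what is missing up front. This is a minimal sketch; `check_local_model_dir` is a hypothetical helper, and it only checks the two files the failed run above tried to download (`config.json` and `tokenizer_config.json`). The weight file is deliberately not checked, since its exact name is an assumption I can't confirm for every download.

```python
from pathlib import Path
from typing import Iterable, List

# The two files the failing online run requested from the Hub (see the log above)
REQUIRED_FILES = ("config.json", "tokenizer_config.json")


def check_local_model_dir(local_path: str, required: Iterable[str] = REQUIRED_FILES) -> List[str]:
    """Return the required files missing from local_path; raise if the folder itself is missing."""
    root = Path(local_path)
    if not root.is_dir():
        raise FileNotFoundError(f"Model directory not found: {root}")
    # Only top-level files are checked, matching how from_pretrained reads a local folder
    return [name for name in required if not (root / name).is_file()]
```

A call such as `missing = check_local_model_dir("/root/GEMS/huggingface/ankh-base")` placed just before `ankh.load_local_base_model(...)` should return an empty list when the upload is complete.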
2. Modify the first __init__.py
Modifying ankh_transformers.py alone is not enough; the __init__.py in the same models folder also needs to be updated:
```python
# Old code
from ankh.models.convbert_binary_classification import (
    ConvBertForBinaryClassification,
)
from ankh.models.convbert_multiclass_classification import (
    ConvBertForMultiClassClassification,
)
from ankh.models.convbert_multilabel_classification import (
    ConvBertForMultiLabelClassification,
)
from ankh.models.convbert_regression import ConvBertForRegression
from .ankh_transformers import (
    get_available_models,
    load_base_model,
    load_large_model,
    load_model,
)
```

```python
# New code
from ankh.models.convbert_binary_classification import (
    ConvBertForBinaryClassification,
)
from ankh.models.convbert_multiclass_classification import (
    ConvBertForMultiClassClassification,
)
from ankh.models.convbert_multilabel_classification import (
    ConvBertForMultiLabelClassification,
)
from ankh.models.convbert_regression import ConvBertForRegression
from .ankh_transformers import (
    get_available_models,
    load_base_model,
    load_large_model,
    load_model,
    load_local_base_model,
    load_local_large_model,
)
# The only change: the names of the two new functions are added
```
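To verify that an edited `__init__.py` actually exposes the new names, a small helper can import the package and list what is missing. A minimal sketch; `missing_exports` is a hypothetical name.

```python
import importlib
from typing import Iterable, List


def missing_exports(module_name: str, names: Iterable[str]) -> List[str]:
    """Import `module_name` and return the entries of `names` it does not expose."""
    module = importlib.import_module(module_name)
    # hasattr also sees functions re-exported via `from ... import ...` in __init__.py
    return [name for name in names if not hasattr(module, name)]
```

Run it in a fresh Python process, because a module already imported in a running interpreter or notebook is cached in `sys.modules` and will not pick up the edit. After this step, `missing_exports("ankh.models", ["load_local_base_model", "load_local_large_model"])` should return `[]`.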
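3. Modify the second __init__.py
Compared with the first __init__.py, this one sits one level further out: it is the package's top-level __init__.py. The relationship is as follows: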
```text
root@mytry:~/GEMS# cd /opt/conda/lib/python3.11/site-packages/ankh
root@mytry:/opt/conda/lib/python3.11/site-packages/ankh# ls
__init__.py __pycache__ extract.py models utils.py
root@mytry:/opt/conda/lib/python3.11/site-packages/ankh# cd models/
root@mytry:/opt/conda/lib/python3.11/site-packages/ankh/models# ls
__init__.py ankh_transformers.py convbert_multiclass_classification.py convbert_regression.py
__pycache__ convbert_binary_classification.py convbert_multilabel_classification.py layers.py
```

The corresponding folder structure:

```text
|- __init__.py
|- __pycache__
|- extract.py
|- utils.py
|- models
   |- __init__.py
   |- ankh_transformers.py
   |- convbert_multiclass_classification.py
   |- convbert_regression.py
   |- __pycache__
   |- convbert_binary_classification.py
   |- convbert_multilabel_classification.py
   |- layers.py
```
Modify this __init__.py as follows:
```python
# Old code
from .models import get_available_models, load_base_model, load_large_model, load_model
from .utils import FastaDataset, CSVDataset
from .models import (
    ConvBertForBinaryClassification,
)
from .models import (
    ConvBertForMultiClassClassification,
)
from .models import ConvBertForRegression
from .models import ConvBertForMultiLabelClassification
from typing import Union

available_tasks = {
    "binary": ConvBertForBinaryClassification,
    "regression": ConvBertForRegression,
    "multiclass": ConvBertForMultiClassClassification,
    "multilabel": ConvBertForMultiLabelClassification,
}


def get_available_tasks():
    return list(available_tasks.keys())


def load_downstream_model(
    task,
) -> Union[
    ConvBertForBinaryClassification,
    ConvBertForMultiClassClassification,
    ConvBertForRegression,
    ConvBertForMultiLabelClassification,
]:
    return available_tasks[task]


__version__ = "1.0"
```

```python
# New code (only the first import line changes)
from .models import get_available_models, load_base_model, load_large_model, load_model, load_local_base_model, load_local_large_model
from .utils import FastaDataset, CSVDataset
from .models import (
    ConvBertForBinaryClassification,
)
from .models import (
    ConvBertForMultiClassClassification,
)
from .models import ConvBertForRegression
from .models import ConvBertForMultiLabelClassification
from typing import Union

available_tasks = {
    "binary": ConvBertForBinaryClassification,
    "regression": ConvBertForRegression,
    "multiclass": ConvBertForMultiClassClassification,
    "multilabel": ConvBertForMultiLabelClassification,
}


def get_available_tasks():
    return list(available_tasks.keys())


def load_downstream_model(
    task,
) -> Union[
    ConvBertForBinaryClassification,
    ConvBertForMultiClassClassification,
    ConvBertForRegression,
    ConvBertForMultiLabelClassification,
]:
    return available_tasks[task]


__version__ = "1.0"
```
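Why do both `__init__.py` files need editing? This follows from ordinary Python import behaviour: a function defined in a submodule only becomes an attribute of a package if that package's `__init__.py` imports it. The self-contained toy package below (the name `demo_pkg` and its contents are hypothetical, not ankh code) mimics the two-level layout. As far as I can tell, after the inner edit alone `ankh.models.load_local_base_model` already works, and the outer edit is what makes the shorter `ankh.load_local_base_model` call used above work.

```python
import importlib
import sys
import tempfile
from pathlib import Path

PKG = "demo_pkg"  # hypothetical package name standing in for `ankh`


def _purge_demo_modules() -> None:
    """Forget any previously imported copy of the toy package."""
    for name in list(sys.modules):
        if name == PKG or name.startswith(PKG + "."):
            del sys.modules[name]


def build_toy_package(root: Path, reexport_at_top: bool) -> None:
    """Write a two-level package shaped like ankh/ and ankh/models/."""
    models = root / PKG / "models"
    models.mkdir(parents=True)
    # Plays the role of ankh/models/ankh_transformers.py after step 1
    (models / "ankh_transformers.py").write_text(
        "def load_local_base_model(local_path):\n    return ('model', 'tokenizer', local_path)\n"
    )
    # Plays the role of ankh/models/__init__.py after step 2
    (models / "__init__.py").write_text("from .ankh_transformers import load_local_base_model\n")
    # Plays the role of ankh/__init__.py: it always imports the models subpackage,
    # but only re-exports the new function if step 3 was done
    top = "from .models import load_local_base_model\n" if reexport_at_top else "from . import models\n"
    (root / PKG / "__init__.py").write_text(top)


def probe(reexport_at_top: bool) -> dict:
    """Import the toy package and report where the new loader is reachable."""
    with tempfile.TemporaryDirectory() as tmp:
        build_toy_package(Path(tmp), reexport_at_top)
        sys.path.insert(0, tmp)
        _purge_demo_modules()
        importlib.invalidate_caches()
        try:
            pkg = importlib.import_module(PKG)
            return {
                "pkg.models.load_local_base_model": hasattr(pkg.models, "load_local_base_model"),
                "pkg.load_local_base_model": hasattr(pkg, "load_local_base_model"),
            }
        finally:
            sys.path.remove(tmp)
            _purge_demo_modules()
```

`probe(False)` reports the loader reachable only through the `models` subpackage, while `probe(True)` reports it reachable at the top level as well.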
That's all the changes needed. Here is the result of my run:
```text
root@mytry:~/GEMS# python ankh_features.py --data_dir example_dataset --ankh_base True
True
cuda:0
Number of Proteins to be processed: 100
Model Name: ankh_base
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00, 4.36it/s]
Time taken for 100 proteins: 22.968058824539185 seconds
```
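III. Conclusion
I haven't used the larger models yet, such as ankh3_large and ankh3_xl listed on GitHub. Adapting this approach to them would be a natural extension of this post and is worth trying.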
