BERT 模型准备与转换详细操作流程
在尝试复现极客专栏《PyTorch 深度学习实战|24 | 文本分类:如何使用BERT构建文本分类模型?》时候,构建模型这一步骤专栏老师一笔带过,对于新手有些不友好,经过一阵摸索,终于调通了,现在总结一下整体流程。
1. 获取必要脚本文件
首先,我们需要从 Transformers 的 GitHub 仓库中找到相关文件:
# 克隆 Transformers 仓库
git clone https://github.com/huggingface/transformers.git
cd transformers
在仓库中,我们需要找到以下关键文件:
src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py
(用于 TF1.x 模型)src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py
(用于 TF2.x 模型)src/transformers/models/bert/modeling_bert.py
(BERT 的 PyTorch 实现)
2. 下载预训练模型
接下来,我们需要下载 Google 提供的预训练 BERT 模型。根据你的需求,我们选择"BERT-Base, Multilingual Cased"版本,它支持104种语言。
访问 Google 的 BERT GitHub 页面:https://github.com/google-research/bert
在该页面中找到"BERT-Base, Multilingual Cased"的下载链接,或直接使用以下命令下载:
mkdir bert-base-multilingual-cased
cd bert-base-multilingual-cased# 下载模型文件
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip
解压后,你会得到以下文件:
bert_model.ckpt.data-00000-of-00001
bert_model.ckpt.index
bert_model.ckpt.meta
bert_config.json
vocab.txt
3. 模型转换
现在,我们使用之前找到的转换脚本将 TensorFlow 模型转换为 PyTorch 格式:
# 回到 transformers 目录
cd ../transformers# 执行转换脚本(针对 TF2.x 模型)
python src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/pytorch_model.bin
如果你下载的是 TF1.x 模型,则使用:
python src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/pytorch_model.bin
注意,此处需要安装tensorflow。
4. 准备完整的 PyTorch 模型目录
转换完成后,我们需要确保模型目录包含所有必要文件:
cd ../bert-base-multilingual-cased# 复制 bert_config.json 为 config.json(Transformers 库需要)
cp bert_config.json config.json
现在,你的模型目录应该包含以下三个关键文件:
config.json
:模型配置文件,包含了所有用于训练的参数设置pytorch_model.bin
:转换后的 PyTorch 模型权重文件vocab.txt
:词表文件,用于识别模型支持的各种语言的字符
5. 验证模型转换成功
为了验证模型转换是否成功,我们可以编写一个简单的脚本来加载模型并进行测试:
from transformers import BertTokenizer, BertModel# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)# 测试多语言能力
texts = ["Hello, how are you?", # 英语"你好,最近怎么样?", # 中文"Hola, ¿cómo estás?" # 西班牙语
]for text in texts:inputs = tokenizer(text, return_tensors="pt")outputs = model(**inputs)print(f"Text: {text}")print(f"Shape of last hidden states: {outputs.last_hidden_state.shape}")print("---")
6. 使用模型进行下游任务
现在你可以使用这个转换好的模型进行各种下游任务,如文本分类、命名实体识别等:
from transformers import BertTokenizer, BertForSequenceClassification
import torch# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)# 初始化分类模型(假设有2个类别)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)# 准备输入
text = "这是一个测试文本"
inputs = tokenizer(text, return_tensors="pt")# 前向传播
outputs = model(**inputs)
logits = outputs.logits# 获取预测结果
predicted_class = torch.argmax(logits, dim=1).item()
print(f"预测类别: {predicted_class}")
注意事项
-
模型文件大小:BERT-Base 模型文件通常较大(约400MB+),请确保有足够的磁盘空间和内存。
-
路径问题:在执行转换脚本时,确保正确指定了所有文件的路径。
-
命名约定:Transformers 库期望配置文件名为
config.json
,而不是bert_config.json
,所以需要进行复制或重命名。 -
TensorFlow 版本:根据你下载的模型版本(TF1.x 或 TF2.x),选择正确的转换脚本。
-
checkpoint 文件:转换脚本中的
--tf_checkpoint_path
参数应该指向不带后缀的 checkpoint 文件名(如bert_model.ckpt
),而不是具体的.index
或.data
文件。
通过以上步骤,你就可以成功地将 Google 预训练的 BERT 模型转换为 PyTorch 格式,并在你的项目中使用它了。这个多语言版本的 BERT 模型支持 104 种语言,非常适合多语言自然语言处理任务。