当前位置: 首页 > news >正文

BERT 模型准备与转换详细操作流程

在尝试复现极客专栏《PyTorch 深度学习实战|24 | 文本分类:如何使用BERT构建文本分类模型?》时候,构建模型这一步骤专栏老师一笔带过,对于新手有些不友好,经过一阵摸索,终于调通了,现在总结一下整体流程。

1. 获取必要脚本文件

首先,我们需要从 Transformers 的 GitHub 仓库中找到相关文件:

# 克隆 Transformers 仓库
git clone https://github.com/huggingface/transformers.git
cd transformers

在仓库中,我们需要找到以下关键文件:

  • src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py(用于 TF1.x 模型)
  • src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py(用于 TF2.x 模型)
  • src/transformers/models/bert/modeling_bert.py(BERT 的 PyTorch 实现)

2. 下载预训练模型

接下来,我们需要下载 Google 提供的预训练 BERT 模型。根据你的需求,我们选择"BERT-Base, Multilingual Cased"版本,它支持104种语言。

访问 Google 的 BERT GitHub 页面:https://github.com/google-research/bert

在该页面中找到"BERT-Base, Multilingual Cased"的下载链接,或直接使用以下命令下载:

mkdir bert-base-multilingual-cased
cd bert-base-multilingual-cased# 下载模型文件
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip

解压后,你会得到以下文件:

  • bert_model.ckpt.data-00000-of-00001
  • bert_model.ckpt.index
  • bert_model.ckpt.meta
  • bert_config.json
  • vocab.txt

3. 模型转换

现在,我们使用之前找到的转换脚本将 TensorFlow 模型转换为 PyTorch 格式:

# 回到 transformers 目录
cd ../transformers# 执行转换脚本(针对 TF2.x 模型)
python src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/pytorch_model.bin

如果你下载的是 TF1.x 模型,则使用:

python src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/pytorch_model.bin

注意,此处需要安装tensorflow。

4. 准备完整的 PyTorch 模型目录

转换完成后,我们需要确保模型目录包含所有必要文件:

cd ../bert-base-multilingual-cased# 复制 bert_config.json 为 config.json(Transformers 库需要)
cp bert_config.json config.json

现在,你的模型目录应该包含以下三个关键文件:

  1. config.json:模型配置文件,包含了所有用于训练的参数设置
  2. pytorch_model.bin:转换后的 PyTorch 模型权重文件
  3. vocab.txt:词表文件,用于识别模型支持的各种语言的字符

5. 验证模型转换成功

为了验证模型转换是否成功,我们可以编写一个简单的脚本来加载模型并进行测试:

from transformers import BertTokenizer, BertModel# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)# 测试多语言能力
texts = ["Hello, how are you?",  # 英语"你好,最近怎么样?",    # 中文"Hola, ¿cómo estás?"   # 西班牙语
]for text in texts:inputs = tokenizer(text, return_tensors="pt")outputs = model(**inputs)print(f"Text: {text}")print(f"Shape of last hidden states: {outputs.last_hidden_state.shape}")print("---")

6. 使用模型进行下游任务

现在你可以使用这个转换好的模型进行各种下游任务,如文本分类、命名实体识别等:

from transformers import BertTokenizer, BertForSequenceClassification
import torch# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)# 初始化分类模型(假设有2个类别)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)# 准备输入
text = "这是一个测试文本"
inputs = tokenizer(text, return_tensors="pt")# 前向传播
outputs = model(**inputs)
logits = outputs.logits# 获取预测结果
predicted_class = torch.argmax(logits, dim=1).item()
print(f"预测类别: {predicted_class}")

注意事项

  1. 模型文件大小:BERT-Base 模型文件通常较大(约400MB+),请确保有足够的磁盘空间和内存。

  2. 路径问题:在执行转换脚本时,确保正确指定了所有文件的路径。

  3. 命名约定:Transformers 库期望配置文件名为 config.json,而不是 bert_config.json,所以需要进行复制或重命名。

  4. TensorFlow 版本:根据你下载的模型版本(TF1.x 或 TF2.x),选择正确的转换脚本。

  5. checkpoint 文件:转换脚本中的 --tf_checkpoint_path 参数应该指向不带后缀的 checkpoint 文件名(如 bert_model.ckpt),而不是具体的 .index.data 文件。

通过以上步骤,你就可以成功地将 Google 预训练的 BERT 模型转换为 PyTorch 格式,并在你的项目中使用它了。这个多语言版本的 BERT 模型支持 104 种语言,非常适合多语言自然语言处理任务。

相关文章:

  • 科学计算库 Numpy
  • 软件工程核心知识全景图:从需求到部署的系统化构建指南
  • 【AI智能体】Spring AI MCP 服务常用开发模式实战详解
  • 命令行中SSH本地端口转发和反向远程端口转发
  • 计算机网络课程设计--基于TCP协议的文件传输系统
  • linux VFS简介
  • 笔式胰岛素简单拆解
  • SAP金属行业解决方案:无锡哲讯科技助力企业数字化转型与高效运营
  • P99延迟:系统性能优化的关键指标
  • 408考研逐题详解:2010年第3题——后序线索二叉树
  • Docker容器自动更新利器:Watchtower
  • 自动化测试01
  • 如何用AI开发完整的小程序<9>—UI自适应与游戏页优化
  • oracle rac - starwind san 磁盘共享篇
  • SpringBoot+Vue服装商城系统 附带详细运行指导视频
  • 设计模式精讲 Day 10:外观模式(Facade Pattern)
  • 华为云Flexus+DeepSeek征文|Dify-LLM平台部署教程与Flexus X实例优势解析
  • CTE vs 子查询:深入拆解PostgreSQL复杂SQL的隐藏性能差异
  • JavaScript 的 “==” 存在的坑
  • 大零售生态下开源链动2+1模式、AI智能名片与S2B2C商城小程序的协同创新研究
  • 广州做网站 timhi/seo是什么意思中文翻译
  • 设计网站推荐 猪/各大搜索引擎收录入口
  • 营销方案网站/sem是什么的缩写
  • 最优网站/深圳疫情防控最新消息
  • 做任务用手机号登录网站/品牌推广
  • 个人备案的域名拿来做别的网站/优化法治化营商环境