当前位置: 首页 > news >正文

python transformers库笔记(BertForTokenClassification类)

BertForTokenClassification类

        BertForTokenclassification类是Hugging Face transformers库中专门为基于BERT的序列标注任务(如命名实体识别NER、词性标注POS)设计的模型类。它在BERT的基础上添加了一个线性分类层,用于对每个token进行分类。

1、特点

        任务类型:专为Token-level分类设计,即对输入序列中的每一个token预测一个标签。典型应用有命名实体识别(NER)、词性标注(POS)、语义角色标注(SRL)

2、模型架构

BERT Base Model (bert-base-uncased等)↓
[CLS] Token 1 Token 2 ... Token N [SEP]  (输出隐藏状态)↓
Dropout Layer (可选)↓
Linear Classifier (hidden_size → num_labels)↓
Softmax (输出每个 token 的标签概率)

3、关键组件

        BERT编辑器:提取上下文相关的token表示(支持所有BERT变体)

        分类头:将每个token的隐藏状态映射到标签空间(hidden_size→num_labels)

        CRF层(可选):可通过扩展添加条件随机场层,提升标签间依赖建模(需自定义实现)

4、使用方法

 (1)加载预训练模型

import torch
from transformers import BertForTokenClassification, BertTokenizerFastmodel = BertForTokenClassification.from_pretrained('chinese-bert-wwm',num_labels=10,  # 标签数量id2label={0: 'O', 1: 'B-质量差', 2: 'I-质量差', ......}  # 标签映射
)
tokenizer = BertTokenizerFast.from_pretrained('chinese-bert-wwm')

(2)数据预处理

text = '容易碎裂。质量太差,不值这个价。'
input = tokenizer(text,return_tensor='pt',trucation=True,padding=True,return_offsets_mapping=True
)
# 假设0=O,1=B-质量差,2=I-质量差,3=B-易碎裂,4=I-易碎裂
labels = [3, 4, 4, 4, 4, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0]
inputs["labels"] = torch.tensor([labels])

(3)模型推理

outputs = model(**inputs)
logits = outputs.logits  # 形状:(batch_size, seq_len, num_labels)# 获取预测标签
predictions = torch.argmax(logits, dim=-1)[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])# 打印结果
for token, pred in zip(tokens, predictions):print(f"{token:15}→{model.config.id2label.get(pred, 'UNK')}")

输出示例

[CLS]          →O
容              →B-易碎裂
易              →I-易碎裂
碎              →I-易碎裂
裂              →I-易碎裂
。              →O
质              →B-质量差
量              →I-质量差
太              →I-质量差
差              →I-质量差
,              →O
不              →O
值              →O
这              →O
个              →O
价              →O
。              →O
[SEP]          →O
http://www.dtcms.com/a/269470.html

相关文章:

  • 【牛客刷题】小红的与运算
  • node.js中yarn、npm、cnpm详解
  • 精益管理与数字化转型的融合:中小制造企业降本增效的双重引擎
  • 算法训练营DAY29 第八章 贪心算法 part02
  • 实战Linux进程状态观察:R、S、D、T、Z状态详解与实验模拟
  • 联通线路物理服务器选择的关键要点
  • No Hack No CTF 2025Web部分个人WP
  • Django双下划线查询
  • 微信小程序控制空调之接收MQTT消息
  • 如何利用AI大模型对已有创意进行评估,打造杀手级的广告创意
  • deepseek实战教程-第九篇开源模型智能体开发框架solon-ai
  • Python爬取知乎评论:多线程与异步爬虫的性能优化
  • React18+TypeScript状态管理最佳实践
  • Jenkins 使用宿主机的Docker
  • 深入解析 structuredClone API:现代JS深拷贝的终极方案
  • Ubuntu 版本号与别名对照表(部分精选)
  • Java使用接口AES进行加密+微信小程序接收解密
  • Linux Ubuntu系统下载
  • Docker企业级应用:从入门到生产环境最佳实践
  • any实现(基于LLVM中libcxx实现分析)
  • 深入理解Java虚拟机(JVM):从内存管理到性能优化
  • 基于Java+Maven+Testng+Selenium+Log4j+Allure+Jenkins搭建一个WebUI自动化框架(1)搭建框架基本雏形
  • C++11标准库算法:深入理解std::find, std::find_if与std::find_if_not
  • iOS Widget 开发-3:Widget 的种类与尺寸(主屏、锁屏、灵动岛)
  • el-button传入icon用法可能会出现的问题
  • Unity开发如何解决iOS闪退问题
  • 数据分析-59-SPC统计过程控制XR图和XS图和IMR图和CPK分析图
  • 手机解压软件 7z:高效便捷的解压缩利器
  • 【机器学习笔记 Ⅲ】5 强化学习
  • C++异步编程入门