【NLP 38、实践 ⑩ NER 命名实体识别任务 Bert 实现】
去做具体的事,然后稳稳托举自己
—— 25.3.16
一、配置文件 config.py
1.模型与数据路径
model_path:模型训练完成后保存的位置。例如:保存最终的模型权重文件。
schema_path:数据结构定义文件,通常用于描述数据的格式(如字段名、标签类型)。
在NER任务中,可能定义实体类别(如 {"PERSON": "人名", "ORG": "组织"}
)。
train_data_path:训练数据集路径,通常为标注好的文本文件(如 train.txt
或 JSON
格式)。
valid_data_path: 验证数据集路径,用于模型训练时的性能评估和超参数调优。
vocab_path:字符词汇表文件,记录模型中使用的字符集(如中文字符、字母、数字等)。
2.模型架构
max_length:输入文本的最大序列长度。超过此长度的文本会被截断或填充(如用 [PAD]
)。
hidden_size:模型隐藏层神经元的数量,影响模型容量和计算复杂度。
num_layers:模型的堆叠层数(如LSTM、Transformer的编码器/解码器层数)。
class_num:任务类别总数。例如:NER任务中可能有9种实体类型。
vocab_size:词表大小
3.训练配置
epoch:训练轮数。每轮遍历整个训练数据集一次。
batch_size:每次梯度更新所使用的样本数量。较小的批次可能更适合内存受限的环境。
optimizer:优化器类型,用于调整模型参数。Adam是常用优化器,结合动量梯度下降。
learning_rate:学习率,控制参数更新的步长。值过小可能导致训练缓慢,过大易过拟合。
use_crf:是否启用条件随机场(CRF)层。在序列标注任务(如NER)中,CRF可捕捉标签间的依赖关系,提升准确性。
4.预训练模型
bert_path:预训练BERT模型的路径。BERT是一种强大的预训练语言模型,此处可能用于微调或特征提取。
# -*- coding: utf-8 -*-
"""
配置参数信息
"""
Config = {
"model_path": "model_output",
"schema_path": "ner_data/schema.json",
"train_data_path": "ner_data/train",
"valid_data_path": "ner_data/test",
"vocab_path":"chars.txt",
"max_length": 100,
"hidden_size": 256,
"num_layers": 2,
"epoch": 20,
"batch_size": 16,
"optimizer": "adam",
"learning_rate": 1e-3,
"use_crf": False,
"class_num": 9,
"bert_path": r"F:\人工智能NLP/NLP资料\week6 语言模型/bert-base-chinese",
"vocab_size": 20000
}
二、数据加载 loader.py
1.初始化数据加载类
data_path:数据文件存储路径
config:包含训练 / 数据配置的字典
self.config:保存包含训练 / 数据配置的字典
self.path:保存数据文件存储路径
self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具
self.sentences:初始化句子列表
self.schema:加载实体标签与索引的映射关系表
self.load:调用 load()
方法从 data_path
加载原始数据,进行分词、编码、填充/截断等预处理。
def __init__(self, data_path, config):
self.config = config
self.path = data_path
self.tokenizer = load_vocab(config["bert_path"])
self.sentences = []
self.schema = self.load_schema(config["schema_path"])
self.load()
2.加载数据并预处理
① 初始化数据容器 ——>
② 文件读取与分段处理 ——>
③ 逐段解析字符与标签 ——>
④ 句子编码与填充 ——>
⑤ 数据封装与返回
self.data:列表,存储预处理后的数据样本,每个样本由输入张量和标签张量组成
sentenece:保存原始文本句子的拼接结果,便于后续可视化或调试。
open():打开文件并返回文件对象,支持读/写/追加等模式。
参数名 | 类型 | 说明 |
---|---|---|
file | 字符串 | 文件路径(绝对/相对路径) |
mode | 字符串 | 打开模式(如 r -只读、w -写入、a -追加) |
encoding | 字符串 | 文件编码(如 utf-8 ,文本模式需指定) |
errors | 字符串 | 编码错误处理方式(如 ignore 、replace ) |
文件对象.read():读取文件内容,返回字符串或字节流
参数名 | 类型 | 说明 |
---|---|---|
size | 整数 | 可选,指定读取的字节数(默认读取全部内容) |
split():按分隔符分割字符串,返回子字符串列表
参数名 | 类型 | 说明 |
---|---|---|
delimiter | 字符串 | 分隔符(默认空格) |
maxsplit | 整数 | 可选,最大分割次数(默认-1表示全部) |
strip():去除字符串首尾指定字符(默认空白字符)
参数名 | 类型 | 说明 |
---|---|---|
chars | 字符串 | 可选,指定需去除的字符集合 |
join():用分隔符连接可迭代对象的元素,返回新字符串
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 需连接的元素集合(如列表、元组) |
sep | 字符串 | 分隔符(默认空字符串) |
列表.append():在列表末尾添加元素
参数名 | 类型 | 说明 |
---|---|---|
obj | 任意类型 | 要添加的元素 |
def load(self):
self.data = []
with open(self.path, encoding="utf8") as f:
segments = f.read().split("\n\n")
for segment in segments:
sentenece = []
labels = [8] # cls_token
for line in segment.split("\n"):
if line.strip() == "":
continue
char, label = line.split()
sentenece.append(char)
labels.append(self.schema[label])
sentence = "".join(sentenece)
self.sentences.append(sentence)
input_ids = self.encode_sentence(sentenece)
labels = self.padding(labels, -1)
# print(self.decode(sentence, labels))
# input()
self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
return
3.加载字 / 词表
BertTokenizer.from_pretrained():是 Hugging Face Transformers 库中用于加载预训练 BERT 分词器的核心方法。
def load_vocab(vocab_path):
return BertTokenizer.from_pretrained(vocab_path)
4.加载映射关系表
加载位于指定路径的 JSON 格式的模式文件,并将其内容解析为 Python 对象以便在数据生成过程中使用。
open():打开文件并返回文件对象,用于读写文件内容。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
file_name | str | 无 | 文件路径(需包含扩展名) |
mode | str | 'r' | 文件打开模式: - 'r' : 只读- 'w' : 只写(覆盖原文件)- 'a' : 追加写入- 'b' : 二进制模式- 'x' : 创建新文件(若存在则报错) |
buffering | int | None | 缓冲区大小(仅二进制模式有效) |
encoding | str | None | 文件编码(仅文本模式有效,如 'utf-8' ) |
newline | str | '\n' | 行结束符(仅文本模式有效) |
closefd | bool | True | 是否在文件关闭时自动关闭文件描述符 |
dir_fd | int | -1 | 文件描述符(高级用法,通常忽略) |
flags | int | 0 | Linux 系统下的额外标志位 |
mode | str | 无 | (重复参数,实际使用中只需指定 mode ) |
json.load():从已打开的 JSON 文件对象中加载数据,并将其转换为 Python 对象(如字典、列表)。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
fp | io.TextIO | 无 | 已打开的文件对象(需处于读取模式) |
indent | int/str | None | 缩进空格数(美化输出,如 4 或 " " ) |
sort_keys | bool | False | 是否对 JSON 键进行排序 |
load_hook | callable | None | 自定义对象加载回调函数 |
object_hook | callable | None | 自定义对象解析回调函数 |
def load_schema(self, path):
with open(path, encoding="utf8") as f:
return json.load(f)
5.封装数据
DataLoader():PyTorch 模型训练的标配工具,通过合理的参数配置(如 batch_size
、num_workers
、shuffle
),可以显著提升数据加载效率,尤其适用于大规模数据集和复杂预处理任务。其与 Dataset
类的配合使用,是构建高效训练管道的核心。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
dataset | Dataset | None | 必须参数,自定义数据集对象(需继承 torch.utils.data.Dataset )。 |
batch_size | int | 1 | 每个批次的样本数量。 |
shuffle | bool | False | 是否在每个 epoch 开始时打乱数据顺序(训练时推荐设为 True )。 |
num_workers | int | 0 | 使用多线程加载数据的工人数量(需大于 0 时生效)。 |
pin_memory | bool | False | 是否将数据存储在 pinned memory 中(加速 GPU 数据传输)。 |
drop_last | bool | False | 如果数据集长度无法被 batch_size 整除,是否丢弃最后一个不完整的批次。 |
persistent_workers | bool | False | 是否保持工作线程在 epoch 之间持续运行(减少多线程初始化开销)。 |
worker_init_fn | callable | None | 自定义工作线程初始化函数。 |
# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
dg = DataGenerator(data_path, config)
dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
return dl
6.对于输入文本做截断 / 填充
#补齐或截断输入的序列,使其可以在一个batch内运算
def padding(self, input_id, pad_token=0):
input_id = input_id[:self.config["max_length"]]
input_id += [pad_token] * (self.config["max_length"] - len(input_id))
return input_id
7.类内魔术方法
__len__():用于定义对象的“长度”,通过内置函数 len()
调用时返回该值。它通常用于容器类(如列表、字典、自定义数据结构),表示容器中元素的个数
__getitem__():允许对象通过索引或键值访问元素,支持 obj[index]
或 obj[key]
语法。它使对象表现得像序列(如列表)或映射(如字典)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
8.对于输入的文本编码
self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具
self.tokenizer.encode():Hugging Face Transformers 库中 BertTokenizer
的核心方法,用于将原始文本转换为模型可处理的输入形式。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
text | str 或 List[str] | 必填 | 输入文本(单句或句子对)。 |
text_pair | str | None | 第二段文本(用于句子对任务,如问答),与 text 拼接后生成 [CLS] text [SEP] text_pair [SEP] |
add_special_tokens | bool | True | 是否添加 [CLS] 和 [SEP] 标记。关闭后仅返回原始分词索引 |
max_length | int | 512 | 最大序列长度。超长文本会被截断,不足则填充 |
padding | str 或 bool | False | 填充策略:True /'longest' (按批次最长填充)、'max_length' (按 max_length 填充) |
truncation | str 或 bool | False | 截断策略:True (按 max_length 截断)、'only_first' (仅截断第一句) |
return_tensors | str | None | 返回张量类型:
|
return_attention_mask | bool | True | 是否生成 attention_mask ,标识有效内容(1)与填充部分(0) |
def encode_sentence(self, text, padding=True):
return self.tokenizer.encode(text,
padding="max_length",
max_length=self.config["max_length"],
truncation=True)
9.对于编码后的输入文本作解码
(04+)
: 匹配以0
(B-LOCATION)开头,后接多个4
(I-LOCATION)的连续标签
(15+)
、(26+)
、(37+)
:分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。
sentence:输入的原句(添加 $
后的版本),用于根据标签索引提取实体文本。
lables:模型输出的标签序列,转换为字符串后通过正则匹配定位实体位置。
results:存储提取的实体,键为实体类型(如 "LOCATION"
),值为该类型实体的文本列表。
location:正则匹配结果,通过 span()
获取实体在 sentence
中的起止位置,用于提取具体文本片段。
join():将可迭代对象(列表、元组等)中的元素按指定分隔符连接成一个字符串。调用该方法的字符串作为分隔符。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
iterable | 可迭代对象 | 必填 | 需连接的元素集合,所有元素必须是字符串类型。若为空,返回空字符串。 |
str():将其他数据类型(整数、浮点数、布尔值等)转换为字符串类型。支持格式化输出和复杂对象的字符串表示。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
object | 任意类型 | 必填 | 需转换的对象,如整数、列表、字典等。 |
encoding | 字符串 | 可选 | 编码格式(仅对字节类型有效),如 utf-8 。 |
errors | 字符串 | 可选 | 编码错误处理策略,如 ignore 、replace 。 |
defaultdict():创建字典的子类,为不存在的键自动生成默认值。需指定 default_factory
(如 list
、int
)定义默认值类型。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
default_factory | 可调用对象或无参数函数 | None | 用于生成默认值的函数。若未指定,访问不存在的键会抛出 KeyError 。 |
**kwargs | 关键字参数 | 可选 | 其他初始化字典的键值对,如 name="Alice" 。 |
re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match
对象
参数名 | 类型 | 说明 |
---|---|---|
pattern | str 或正则表达式对象 | 要匹配的正则表达式模式 |
string | str | 要搜索的字符串 |
flags | int (可选) | 正则匹配标志(如 re.IGNORECASE ) |
.span():返回正则匹配的起始和结束索引(左闭右开区间)
列表.append():向列表末尾添加单个元素,直接修改原列表
参数名 | 类型 | 说明 |
---|---|---|
element | 任意 | 要添加的元素 |
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence) + 2]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
print("location", s, e)
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
print("org", s, e)
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
print("per", s, e)
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
print("time", s, e)
results["TIME"].append(sentence[s:e])
return results
完整代码
# -*- coding: utf-8 -*-
import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from transformers import BertTokenizer
"""
数据加载
"""
class DataGenerator:
def __init__(self, data_path, config):
self.config = config
self.path = data_path
self.tokenizer = load_vocab(config["bert_path"])
self.sentences = []
self.schema = self.load_schema(config["schema_path"])
self.load()
def load(self):
self.data = []
with open(self.path, encoding="utf8") as f:
segments = f.read().split("\n\n")
for segment in segments:
sentenece = []
labels = [8] # cls_token
for line in segment.split("\n"):
if line.strip() == "":
continue
char, label = line.split()
sentenece.append(char)
labels.append(self.schema[label])
sentence = "".join(sentenece)
self.sentences.append(sentence)
input_ids = self.encode_sentence(sentenece)
labels = self.padding(labels, -1)
# print(self.decode(sentence, labels))
# input()
self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
return
def encode_sentence(self, text, padding=True):
return self.tokenizer.encode(text,
padding="max_length",
max_length=self.config["max_length"],
truncation=True)
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence) + 2]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
print("location", s, e)
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
print("org", s, e)
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
print("per", s, e)
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
print("time", s, e)
results["TIME"].append(sentence[s:e])
return results
# 补齐或截断输入的序列,使其可以在一个batch内运算
def padding(self, input_id, pad_token=0):
input_id = input_id[:self.config["max_length"]]
input_id += [pad_token] * (self.config["max_length"] - len(input_id))
return input_id
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
def load_schema(self, path):
with open(path, encoding="utf8") as f:
return json.load(f)
def load_vocab(vocab_path):
return BertTokenizer.from_pretrained(vocab_path)
# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
dg = DataGenerator(data_path, config)
dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
return dl
if __name__ == "__main__":
from config import Config
dg = DataGenerator("ner_data/train", Config)
dl = DataLoader(dg, batch_size=32)
for x, y in dl:
print(x.shape, y.shape)
print(x[1], y[1])
input()
三、模型建立 model.py
1.模型初始化
hidden_size:定义LSTM隐藏层的维度(即每个时间步输出的特征数量
vocab_size:词表大小,即嵌入层(Embedding)可处理的词汇总数
max_length:输入序列的最大长度,用于数据预处理(如截断或填充)
class_num:分类任务的类别数量,决定线性层(nn.Linear
)的输出维度
num_layers:堆叠的LSTM层数,用于增加模型复杂度
BertModel.from_pretrained():加载预训练的 BERT 模型,支持从本地或 Hugging Face 模型库加载
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
pretrained_model_name | 字符串 | 无 | 预训练模型名称或路径(如 bert-base-chinese ) |
config | 字典/类 | 默认配置 | 自定义模型配置,覆盖默认参数(如隐藏层维度、注意力头数) |
cache_dir | 字符串 | None | 模型缓存目录 |
output_hidden_states | 布尔值 | False | 是否返回所有隐藏层输出(用于特征提取) |
nn.Linear():实现全连接层的线性变换(y = xW^T + b
)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
in_features | 整数 | 无 | 输入特征维度(如词向量维度 hidden_size ) |
out_features | 整数 | 无 | 输出特征维度(如分类类别数 class_num ) |
bias | 布尔值 | True | 是否启用偏置项 |
CRF():条件随机场层,用于序列标注任务中约束标签转移逻辑。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
num_tags | 整数 | 无 | 标签类别数(如 class_num ) |
batch_first | 布尔值 | False | 输入张量是否为 (batch_size, seq_len) 格式 |
torch.nn.CrossEntropyLoss():计算交叉熵损失,常用于分类任务
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
ignore_index | 整数 | -1 | 忽略指定索引的标签(如填充符 -1 ) |
reduction | 字符串 | mean | 损失聚合方式(可选 none 、sum 、mean ) |
def __init__(self, config):
super(TorchModel, self).__init__()
hidden_size = config["hidden_size"]
vocab_size = config["vocab_size"] + 1
max_length = config["max_length"]
class_num = config["class_num"]
num_layers = config["num_layers"]
# self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)
# self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
self.classify = nn.Linear(hidden_size * 2, class_num)
self.crf_layer = CRF(class_num, batch_first=True)
self.use_crf = config["use_crf"]
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1) #loss采用交叉熵损失
2.前向计算
计算流程
输入 x → 嵌入层 → LSTM层 → 分类层 → 分支判断:
│
├── 有 target → CRF? → 是:计算 CRF 损失
│ │
│ └→ 否:计算交叉熵损失
│
└── 无 target → CRF? → 是:解码最优标签序列
│
└→ 否:返回预测 logits
gt():张量的逐元素比较函数,返回布尔型张量,标记输入张量中大于指定值的元素位置。常用于生成掩码(如忽略填充符)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
other | Tensor/标量 | 无 | 比较的阈值或张量。若为标量,则张量中每个元素与该值比较;若为张量,需与输入张量形状相同。 |
out | Tensor | None | 可选输出张量,用于存储结果。 |
shape():返回张量的维度信息,描述各轴的大小。
view():调整张量的形状,支持自动推断维度(通过-1
占位符)。常用于数据展平或维度转换。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
*shape | 可变参数 | 无 | 目标形状的维度序列,如view(2, 3) 或view(-1, 28) ,-1 表示自动计算。 |
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, target=None):
x = self.embedding(x) #input shape:(batch_size, sen_len)
x, _ = self.layer(x) #input shape:(batch_size, sen_len, input_dim)
predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)
if target is not None:
if self.use_crf:
mask = target.gt(-1)
# loss 是 crf 的相反数,即 - crf(predict, target, mask)
return - self.crf_layer(predict, target, mask, reduction="mean")
else:
#(number, class_num), (number)
return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
else:
if self.use_crf:
return self.crf_layer.decode(predict)
else:
return predict
3.选择优化器
Adam():自适应矩估计优化器(Adaptive Moment Estimation),结合动量和 RMSProp 的优点。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
lr | float | 1e-3 | 学习率。 |
betas | tuple | (0.9, 0.999) | 动量系数(β₁, β₂)。 |
eps | float | 1e-8 | 防止除零误差。 |
weight_decay | float | 0 | 权重衰减率。 |
amsgrad | bool | False | 是否启用 AMSGrad 优化。 |
foreach | bool | False | 是否为每个参数单独计算梯度。 |
SGD():随机梯度下降优化器(Stochastic Gradient Descent)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
lr | float | 1e-3 | 学习率。 |
momentum | float | 0 | 动量系数(如 momentum=0.9 )。 |
weight_decay | float | 0 | 权重衰减率。 |
dampening | float | 0 | 动力衰减系数(用于 SGD with Momentum)。 |
nesterov | bool | False | 是否启用 Nesterov 动量。 |
foreach | bool | False | 是否为每个参数单独计算梯度。 |
parameters():返回模型所有可训练参数的迭代器,常用于参数初始化或梯度清零。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
filter | callable | None | 过滤条件函数(如 lambda p: p.requires_grad )。默认返回所有参数。 |
def choose_optimizer(config, model):
optimizer = config["optimizer"]
learning_rate = config["learning_rate"]
if optimizer == "adam":
return Adam(model.parameters(), lr=learning_rate)
elif optimizer == "sgd":
return SGD(model.parameters(), lr=learning_rate)
4.模型建立
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
import torch
from transformers import BertModel
"""
建立网络模型结构
"""
class TorchModel(nn.Module):
def __init__(self, config):
super(TorchModel, self).__init__()
hidden_size = config["hidden_size"]
vocab_size = config["vocab_size"] + 1
max_length = config["max_length"]
class_num = config["class_num"]
num_layers = config["num_layers"]
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
# self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)
self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
self.classify = nn.Linear(hidden_size * 2, class_num)
self.crf_layer = CRF(class_num, batch_first=True)
self.use_crf = config["use_crf"]
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1) #loss采用交叉熵损失
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, target=None):
x = self.embedding(x) #input shape:(batch_size, sen_len)
x, _ = self.layer(x) #input shape:(batch_size, sen_len, input_dim)
predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)
if target is not None:
if self.use_crf:
mask = target.gt(-1)
# loss 是 crf 的相反数,即 - crf(predict, target, mask)
return - self.crf_layer(predict, target, mask, reduction="mean")
else:
#(number, class_num), (number)
return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
else:
if self.use_crf:
return self.crf_layer.decode(predict)
else:
return predict
def choose_optimizer(config, model):
optimizer = config["optimizer"]
learning_rate = config["learning_rate"]
if optimizer == "adam":
return Adam(model.parameters(), lr=learning_rate)
elif optimizer == "sgd":
return SGD(model.parameters(), lr=learning_rate)
if __name__ == "__main__":
from config import Config
model = TorchModel(Config)
四、模型效果测试 evaluate.py
1.模型流程
Ⅰ、数据准备与初始化
加载验证集:从指定路径加载预处理后的验证数据,保持数据顺序以避免随机性干扰评估结果
初始化统计字典:为每个实体类别(如LOCATION、PERSON等)创建计数器,记录“正确识别数”“样本实体数”等统计指标
Ⅱ、模型推理与预测
切换评估模式:调用
model.eval()
关闭Dropout等训练层,确保推理稳定性批次处理:
数据迁移至GPU:若CUDA可用,将输入ID和标签移至GPU加速计算
无梯度预测:在
torch.no_grad()
上下文中执行模型推理,减少内存占用输出处理:若未使用CRF层,通过
torch.argmax
直接获取预测标签序列;若使用CRF,需解码最优路径
Ⅲ、实体解码与对齐
标签序列转换:将数值标签拼接为字符串(如
[0,4,4]
→"044"
),并截取与句子长度对齐正则匹配实体:
规则定义:通过正则表达式匹配标签模式(如
04+
表示LOCATION实体),提取连续B-I标签对应的文本片段索引对齐:根据匹配的起止位置从原始句子中截取实体(例如
"04+"
匹配到索引3-5,则提取句子[3:5]
)
Ⅳ、统计与评估指标计算
对比真实与预测实体:遍历每个句子的实体列表,统计以下指标:
正确识别数:预测实体存在于真实列表中的数量。
样本实体数:真实实体总数。
识别出实体数:预测实体总数
计算指标:
精确率(Precision):正确识别数 / 识别出实体数。
召回率(Recall):正确识别数 / 样本实体数。
F1值:精确率与召回率的调和平均
输出结果:按实体类别输出指标,并计算宏平均(Macro-F1)和微平均(Micro-F1)
Ⅴ、关键设计细节
标签编码规则:采用BIO格式(如B-LOCATION=0,I-LOCATION=4),确保实体连续性
异常处理:添加
1e-5
平滑项避免除零错误,增强数值稳定性性能优化:禁用梯度计算、GPU加速、批次处理提升效率
2.初始化
Ⅰ、加载配置文件、模型及日志模块 ——>
Ⅱ、读取验证集数据(固定顺序,避免随机性干扰评估)——>
Ⅲ、初始化统计字典
stats_dict
,按实体类别记录正确识别数、样本实体数等
config:存储运行时配置,例如数据路径、超参数(如批次大小 batch_size
)、是否使用CRF层等。通过 config["valid_data_path"]
动态获取验证集路径。
model:待评估的模型实例,用于调用预测方法(如 model(input_id)
),需提前完成训练和加载。
logger:记录运行日志,例如输出评估指标(准确率、F1值)到文件或控制台,便于调试和监控。
valid_data:验证数据集,用于模型训练时的性能评估和超参数调优。
load_data():数据加载类中,用torch自带的DataLoader类封装数据的函数
def __init__(self, config, model, logger):
self.config = config
self.model = model
self.logger = logger
self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
3.统计模型效果
Ⅰ、解码实体 ——> Ⅱ、对比结果
len():返回对象的元素数量(字符串、列表、元组、字典等)
参数名 | 类型 | 说明 |
---|---|---|
object | 任意可迭代对象 | 如字符串、列表、字典等 |
torch.argmax():返回张量中最大值所在的索引
参数名 | 类型 | 说明 |
---|---|---|
input | Tensor | 输入张量 |
dim | int | 沿指定维度查找最大值 |
keepdim | bool | 是否保持输出维度一致 |
cpu():将张量从GPU移动到CPU内存
zip():将多个可迭代对象打包成元组列表
参数名 | 类型 | 说明 |
---|---|---|
iterables | 多个可迭代对象 | 如列表、元组、字符串 |
.detach():从计算图中分离张量,阻止梯度传播
.tolist():将张量或数组转换为Python列表
def write_stats(self, labels, pred_results, sentences):
assert len(labels) == len(pred_results) == len(sentences)
if not self.config["use_crf"]:
pred_results = torch.argmax(pred_results, dim=-1)
for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
if not self.config["use_crf"]:
pred_label = pred_label.cpu().detach().tolist()
true_label = true_label.cpu().detach().tolist()
true_entities = self.decode(sentence, true_label)
pred_entities = self.decode(sentence, pred_label)
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
self.stats_dict[key]["样本实体数"] += len(true_entities[key])
self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
return
4.可视化统计模型效果
精确率 (Precision):正确预测实体数 / 总预测实体数
召回率 (Recall):正确预测实体数 / 总真实实体数
F1值:精确率与召回率的调和平均
F1:F1分数:准确率与召回率的调和平均数,综合衡量模型的精确性与覆盖能力。
F1_scores:存储四个实体类别的 F1 分数,用于计算宏观平均。
precision:准确率:模型预测为某类实体的结果中,正确的比例。反映模型预测的精确度。
recall:召回率:真实存在的某类实体中,被模型正确识别的比例。反映模型对实体的覆盖能力。
key:当前处理的实体类别(如 "PERSON"
、"LOCATION"
)。
correct_pred:总正确识别数:所有类别中被正确识别的实体总数。
total_pred:总识别实体数:模型预测出的所有实体数量(含错误识别)。
true_enti:总样本实体数:验证数据中真实存在的所有实体数量。
micro_precision:微观准确率:全局视角下的准确率,所有实体类别的正确识别数与总识别数的比例。
micro_recall:微观召回率:全局视角下的召回率,所有实体类别的正确识别数与总样本实体数的比例。
micro_f1:微观F1分数:微观准确率与微观召回率的调和平均数。
列表.append():在列表末尾添加元素
参数名 | 类型 | 说明 |
---|---|---|
element | 任意 | 要添加的元素 |
logger.info():记录日志信息(需配置日志模块)
参数名 | 类型 | 说明 |
---|---|---|
format | str | 格式化字符串 |
*args | 可变参数 | 格式化参数 |
sum():计算可迭代对象的元素总和
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 如列表、元组 |
start | 数值(可选) | 初始累加值 |
列表推导式:通过简洁语法生成新列表,语法:[表达式 for item in iterable if 条件]
def show_stats(self):
F1_scores = []
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
F1 = (2 * precision * recall) / (precision + recall + 1e-5)
F1_scores.append(F1)
self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
micro_precision = correct_pred / (total_pred + 1e-5)
micro_recall = correct_pred / (true_enti + 1e-5)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
self.logger.info("Micro-F1 %f" % micro_f1)
self.logger.info("--------------------")
return
5.评估模型效果
模型切换为评估模式:关闭Dropout等训练层
批次处理数据:
提取原始句子
sentences
将数据迁移至GPU(若可用)
预测时禁用梯度计算(
torch.no_grad()
)优化内存统计结果:调用
write_stats
对比预测与真实标签
epoch:当前训练轮次,用于日志。
logger:记录日志的工具。
stats_dict:统计字典,记录各实体类别的指标。
valid_data:验证数据集,通常由 load_data
加载(如 config["valid_data_path"]
指定路径)
index: 循环中的批次索引
batch_data: 循环中的数据。
sentences:当前批次的原始句子
pred_results:模型预测结果
write_stats():写入统计信息
show_stats():显示统计结果
logger.info():记录日志信息(需配置日志模块)
参数名 | 类型 | 说明 |
---|---|---|
format | str | 格式化字符串 |
*args | 可变参数 | 格式化参数 |
defaultdict():创建带有默认值工厂的字典
参数名 | 类型 | 说明 |
---|---|---|
default_factory | 可调用对象 | 如int、list、自定义函数 |
model.eval():将模型设置为评估模式(关闭Dropout等训练层)
enumerate():返回索引和元素组成的枚举对象
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 如列表、字符串 |
start | int(可选) | 起始索引,默认为0 |
torch.cuda.is_available():检查当前环境是否支持CUDA(GPU加速)
cuda():将张量或模型移动到GPU
参数名 | 类型 | 说明 |
---|---|---|
device | int/str | 指定GPU设备号,如"cuda:0" |
torch.no_grad():禁用梯度计算,节省内存并加速推理
def eval(self, epoch):
self.logger.info("开始测试第%d轮模型效果:" % epoch)
self.stats_dict = {"LOCATION": defaultdict(int),
"TIME": defaultdict(int),
"PERSON": defaultdict(int),
"ORGANIZATION": defaultdict(int)}
self.model.eval()
for index, batch_data in enumerate(self.valid_data):
sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
if torch.cuda.is_available():
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
with torch.no_grad():
pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
self.write_stats(labels, pred_results, sentences)
self.show_stats()
return
6.解码
标签序列预处理:将数值标签拼接为字符串(如
[0,4,4]
→"044"
)正则匹配实体:
04+
:B-LOCATION(0)后接多个I-LOCATION(4)
15+
:B-ORGANIZATION(1)后接I-ORGANIZATION(5)其他实体类别同理
索引对齐:根据匹配位置截取原始句子中的实体文本
Ⅰ、输入预处理
在原句首添加 $
符号,通常用于对齐标签与字符位置(例如避免索引越界)
sentence = "$" + sentence
Ⅱ、标签序列转换
将整数标签序列转换为字符串,并截取长度与 sentence
对齐
str.join():将可迭代对象中的字符串元素按指定分隔符连接成一个新字符串
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 元素必须为字符串类型 |
str():将对象转换为字符串表示形式,支持自定义类的 __str__
方法
参数名 | 类型 | 说明 |
---|---|---|
object | 任意 | 要转换的对象 |
len():返回对象的长度或元素个数(适用于字符串、列表、字典等)
参数名 | 类型 | 说明 |
---|---|---|
object | 可迭代对象 | 如字符串、列表等 |
列表推导式:通过简洁语法生成新列表,支持条件过滤和多层循环
[expression for item in iterable if condition]
部分 | 类型 | 说明 |
---|---|---|
expression | 表达式 | 对 item 处理后的结果 |
item | 变量 | 迭代变量 |
iterable | 可迭代对象 | 如列表、range() 生成的序列 |
condition | 条件表达式 (可选) | 过滤不符合条件的元素 |
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
Ⅲ、初始化结果容器
创建默认值为列表的字典,存储四类实体:
(LOCATION、ORGANIZATION、PERSON、TIME)的识别结果
defaultdict():创建默认值字典,当键不存在时自动生成默认值(基于工厂函数)
参数名 | 类型 | 说明 |
---|---|---|
default_factory | 可调用对象 | 如 int 、list 或自定义函数 |
results = defaultdict(list)
Ⅳ、正则表达式匹配
(04+)
: 匹配以0
(B-LOCATION)开头,后接多个4
(I-LOCATION)的连续标签
(15+)
、(26+)
、(37+)
:分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。
re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match
对象
参数名 | 类型 | 说明 |
---|---|---|
pattern | str 或正则表达式对象 | 要匹配的正则表达式模式 |
string | str | 要搜索的字符串 |
flags | int (可选) | 正则匹配标志(如 re.IGNORECASE ) |
.span():返回正则匹配的起始和结束索引(左闭右开区间)
列表.append():向列表末尾添加单个元素,直接修改原列表
参数名 | 类型 | 说明 |
---|---|---|
element | 任意 | 要添加的元素 |
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
Ⅴ、完整代码
'''
{
"B-LOCATION": 0,
"B-ORGANIZATION": 1,
"B-PERSON": 2,
"B-TIME": 3,
"I-LOCATION": 4,
"I-ORGANIZATION": 5,
"I-PERSON": 6,
"I-TIME": 7,
"O": 8
}
'''
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
results["TIME"].append(sentence[s:e])
return results
7.完整代码
# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data
"""
模型效果测试
"""
class Evaluator:
def __init__(self, config, model, logger):
self.config = config
self.model = model
self.logger = logger
self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
def eval(self, epoch):
self.logger.info("开始测试第%d轮模型效果:" % epoch)
self.stats_dict = {"LOCATION": defaultdict(int),
"TIME": defaultdict(int),
"PERSON": defaultdict(int),
"ORGANIZATION": defaultdict(int)}
self.model.eval()
for index, batch_data in enumerate(self.valid_data):
sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
if torch.cuda.is_available():
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
with torch.no_grad():
pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
self.write_stats(labels, pred_results, sentences)
self.show_stats()
return
def write_stats(self, labels, pred_results, sentences):
assert len(labels) == len(pred_results) == len(sentences)
if not self.config["use_crf"]:
pred_results = torch.argmax(pred_results, dim=-1)
for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
if not self.config["use_crf"]:
pred_label = pred_label.cpu().detach().tolist()
true_label = true_label.cpu().detach().tolist()
true_entities = self.decode(sentence, true_label)
pred_entities = self.decode(sentence, pred_label)
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
self.stats_dict[key]["样本实体数"] += len(true_entities[key])
self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
return
def show_stats(self):
F1_scores = []
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
F1 = (2 * precision * recall) / (precision + recall + 1e-5)
F1_scores.append(F1)
self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
micro_precision = correct_pred / (total_pred + 1e-5)
micro_recall = correct_pred / (true_enti + 1e-5)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
self.logger.info("Micro-F1 %f" % micro_f1)
self.logger.info("--------------------")
return
'''
{
"B-LOCATION": 0,
"B-ORGANIZATION": 1,
"B-PERSON": 2,
"B-TIME": 3,
"I-LOCATION": 4,
"I-ORGANIZATION": 5,
"I-PERSON": 6,
"I-TIME": 7,
"O": 8
}
'''
def decode(self, sentence, labels):
sentence = "$" + sentence
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
results["TIME"].append(sentence[s:e])
return results
五、主函数文件 main.py
① 环境初始化与配置加载 ——>
② 数据加载与预处理 ——>
③ 模型初始化与硬件适配 ——>
④ 优化器与评估器初始化 ——>
⑤ 训练循环与参数更新 ——>
⑥ 模型评估与权重保存
1.导入文件
# -*- coding: utf-8 -*-
import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
2.日志配置
logging.basicConfig():配置日志系统的基础参数(一次性设置,应在首次日志调用前调用)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
filename | 字符串 | 否 | None | 日志输出文件名(若指定,日志写入文件而非控制台) |
filemode | 字符串 | 否 | 'a' | 文件打开模式(如'w' 覆盖,'a' 追加) |
format | 字符串 | 否 | 基础格式 | 日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s' ) |
datefmt | 字符串 | 否 | 无 | 时间格式(如'%Y-%m-%d %H:%M:%S' ) |
level | 整数 | 否 | WARNING | 日志级别(如logging.INFO 、logging.DEBUG ) |
stream | 对象 | 否 | None | 指定日志输出流(如sys.stderr ,与filename 互斥) |
logging.getLogger():获取或创建指定名称的日志记录器(Logger)。若name
为None
,返回根日志记录器
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
name | 字符串 | 否 | None | 日志记录器名称(分层结构,如'module.sub' ) |
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
3.主函数 main
Ⅰ、创建模型保存目录
os.path.isdir():检查指定路径是否为目录(文件夹)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
path | 字符串 | 是 | 无 | 要检查的路径(绝对或相对) |
os.mkdir():创建单个目录(若父目录不存在会抛出异常)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
path | 字符串 | 是 | 无 | 要创建的目录路径 |
mode | 整数 | 否 | 0o777 | 目录权限(八进制格式,某些系统可能忽略此参数) |
#创建保存模型的目录
if not os.path.isdir(config["model_path"]):
os.mkdir(config["model_path"])
Ⅱ、加载训练数据
#加载训练数据
train_data = load_data(config["train_data_path"], config)
Ⅲ、加载模型
#加载模型
model = TorchModel(config)
Ⅳ、检查GPU并迁移模型
torch.cuda.is_available():检查系统是否满足 CUDA 环境要求
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
cuda():将张量或模型移动到GPU显存,加速计算
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
device | int/str | 否 | 指定GPU设备(如0 或"cuda:0" ) | tensor.cuda(device=0) |
non_blocking | bool | 否 | 是否异步传输数据(默认False) | tensor.cuda(non_blocking=True) |
# 标识是否使用gpu
cuda_flag = torch.cuda.is_available()
if cuda_flag:
logger.info("gpu可以使用,迁移模型至gpu")
model = model.cuda()
Ⅴ、加载优化器
#加载优化器
optimizer = choose_optimizer(config, model)
Ⅵ、加载评估器
#加载效果测试类
evaluator = Evaluator(config, model, logger)
Ⅶ、模型训练 ⭐
① Epoch循环控制
range():Python 内置函数,用于生成一个不可变的整数序列,核心功能是为循环控制提供高效的数值迭代支持
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
start | 整数 | 0 | 序列起始值(包含)。若省略,则默认从 0 开始。例如 range(3) 等价于 range(0,3) 。 |
stop | 整数 | 必填 | 序列结束值(不包含)。例如 range(2, 5) 生成 2,3,4 |
step | 整数 | 1 | 步长(正/负): - 正步长需满足 start < stop ,否则无输出(如 range(5, 2) 无效)。- 负步长需满足 start > stop ,例如 range(5, 0, -1) 生成 5,4,3,2,1 **不能为 0 **(否则触发 ValueError ) |
for epoch in range(config["epoch"]):
epoch += 1
② 模型设置训练模式
train_loss:计算当前批次的损失值,通常结合损失函数(如交叉熵、均方误差)使用
model.train():设置模型为训练模式,启用Dropout、BatchNorm等层的训练行为
参数 | 类型 | 默认值 | 说明 | 示例 |
---|---|---|---|---|
mode | bool | True | 是否启用训练模式(True)或评估模式(False) | model.train(True) |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
③ Batch数据遍历
enumerate():遍历可迭代对象时返回索引和元素,支持自定义起始索引
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
iterable | Iterable | 是 | 可迭代对象(如列表、生成器) | enumerate(["a", "b"]) |
start | int | 否 | 索引起始值(默认0) | enumerate(data, start=1) |
for index, batch_data in enumerate(train_data):
④ 梯度清零与设备切换
optimizer.zero_grad():清空模型参数的梯度,防止梯度累积
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
set_to_none | bool | 否 | 是否将梯度置为None (高效但危险) | optimizer.zero_grad(True) |
cuda():将张量或模型移动到GPU显存,加速计算
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
device | int/str | 否 | 指定GPU设备(如0 或"cuda:0" ) | tensor.cuda(device=0) |
non_blocking | bool | 否 | 是否异步传输数据(默认False) | tensor.cuda(non_blocking=True) |
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
⑤ 前向传播与损失计算
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
⑥ 反向传播与参数更新
loss.backward():反向传播计算梯度,基于损失值更新模型参数的.grad
属性
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
retain_graph | bool | 否 | 是否保留计算图(用于多次反向传播) | loss.backward(retain_graph=True) |
optimizer.step():根据梯度更新模型参数,执行优化算法(如SGD、Adam)
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
closure | Callable | 否 | 重新计算损失的闭包函数(如LBFGS) | optimizer.step(closure) |
loss.backward()
optimizer.step()
⑦ 损失记录与日志输出
列表.append():在列表末尾添加元素,直接修改原列表
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
object | Any | 是 | 要添加到列表末尾的元素 | train_loss.append(loss.item()) |
int():将字符串或浮点数转换为整数,支持进制转换
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
x | str/float | 是 | 待转换的值(如字符串或浮点数) | int("10", base=2) (输出2进制10=2) |
base | int | 否 | 进制(默认10) |
len():返回对象(如列表、字符串)的长度或元素个数
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
obj | Sequence/Collection | 是 | 可计算长度的对象(如列表、字符串) | len([1, 2, 3]) (返回3) |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
⑧ Epoch评估与日志
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
⑨ 完整训练代码
#训练
for epoch in range(config["epoch"]):
epoch += 1
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
for index, batch_data in enumerate(train_data):
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
logger.info("epoch average loss: %f" % np.mean(train_loss))
evaluator.eval(epoch)
Ⅷ、模型保存
os.path.join():Python 中用于拼接路径的核心函数,其核心价值在于自动处理不同操作系统的路径分隔符,从而保证代码的跨平台兼容性
参数 | 类型 | 必填 | 说明 |
---|---|---|---|
path1 | 字符串 | 是 | 初始路径组件 |
*paths | 可变参数 | 否 | 后续路径组件(可传多个) |
model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
# torch.save(model.state_dict(), model_path)
return model, train_data
4.调用模型预测
# -*- coding: utf-8 -*-
import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
"""
模型训练主程序
"""
def main(config):
#创建保存模型的目录
if not os.path.isdir(config["model_path"]):
os.mkdir(config["model_path"])
#加载训练数据
train_data = load_data(config["train_data_path"], config)
#加载模型
model = TorchModel(config)
# 标识是否使用gpu
cuda_flag = torch.cuda.is_available()
if cuda_flag:
logger.info("gpu可以使用,迁移模型至gpu")
model = model.cuda()
#加载优化器
optimizer = choose_optimizer(config, model)
#加载效果测试类
evaluator = Evaluator(config, model, logger)
#训练
for epoch in range(config["epoch"]):
epoch += 1
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
for index, batch_data in enumerate(train_data):
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
logger.info("epoch average loss: %f" % np.mean(train_loss))
evaluator.eval(epoch)
# 保存模型
model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
torch.save(model.state_dict(), model_path)
return model, train_data
if __name__ == "__main__":
model, train_data = main(Config)