当前位置: 首页 > news >正文

【NLP 38、实践 ⑩ NER 命名实体识别任务 Bert 实现】

去做具体的事,然后稳稳托举自己

                                        —— 25.3.16

一、配置文件 config.py

1.模型与数据路径

model_path:模型训练完成后保存的位置。例如:保存最终的模型权重文件。

schema_path:数据结构定义文件,通常用于描述数据的格式(如字段名、标签类型)。
在NER任务中,可能定义实体类别(如 {"PERSON": "人名", "ORG": "组织"})。

train_data_path:训练数据集路径,通常为标注好的文本文件(如 train.txt 或 JSON 格式)。

valid_data_path: 验证数据集路径,用于模型训练时的性能评估和超参数调优。

vocab_path:​字符词汇表文件,记录模型中使用的字符集(如中文字符、字母、数字等)。


2.模型架构

max_length:输入文本的最大序列长度。超过此长度的文本会被截断或填充(如用 [PAD])。

hidden_size:模型隐藏层神经元的数量,影响模型容量和计算复杂度。

num_layers:模型的堆叠层数(如LSTM、Transformer的编码器/解码器层数)。

class_num:任务类别总数。例如:NER任务中可能有9种实体类型。

vocab_size:词表大小


3.训练配置

epoch:训练轮数。每轮遍历整个训练数据集一次。

batch_size:每次梯度更新所使用的样本数量。较小的批次可能更适合内存受限的环境。

optimizer:优化器类型,用于调整模型参数。Adam是常用优化器,结合动量梯度下降。

learning_rate:学习率,控制参数更新的步长。值过小可能导致训练缓慢,过大易过拟合。

use_crf:是否启用条件随机场(CRF)​层。在序列标注任务(如NER)中,CRF可捕捉标签间的依赖关系,提升准确性。


4.预训练模型

bert_path:预训练BERT模型的路径。BERT是一种强大的预训练语言模型,此处可能用于微调或特征提取。

# -*- coding: utf-8 -*-

"""
配置参数信息
"""

Config = {
    "model_path": "model_output",
    "schema_path": "ner_data/schema.json",
    "train_data_path": "ner_data/train",
    "valid_data_path": "ner_data/test",
    "vocab_path":"chars.txt",
    "max_length": 100,
    "hidden_size": 256,
    "num_layers": 2,
    "epoch": 20,
    "batch_size": 16,
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "use_crf": False,
    "class_num": 9,
    "bert_path": r"F:\人工智能NLP/NLP资料\week6 语言模型/bert-base-chinese",
    "vocab_size": 20000
}


二、数据加载 loader.py

1.初始化数据加载类

data_path:数据文件存储路径

config:包含训练 / 数据配置的字典

self.config:保存包含训练 / 数据配置的字典

self.path:保存数据文件存储路径

self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具

self.sentences:初始化句子列表

self.schema:加载实体标签与索引的映射关系表

self.load:调用 load() 方法从 data_path 加载原始数据,进行分词、编码、填充/截断等预处理。

    def __init__(self, data_path, config):
        self.config = config
        self.path = data_path
        self.tokenizer = load_vocab(config["bert_path"])
        self.sentences = []
        self.schema = self.load_schema(config["schema_path"])
        self.load()

2.加载数据并预处理

① 初始化数据容器 ——>

② 文件读取与分段处理 ——>

③ 逐段解析字符与标签 ——>

④ 句子编码与填充 ——>

⑤ 数据封装与返回 

self.data:列表,存储预处理后的数据样本,每个样本由输入张量和标签张量组成

sentenece:保存原始文本句子的拼接结果,便于后续可视化或调试。

open():打开文件并返回文件对象,支持读/写/追加等模式。

参数名类型说明
file字符串文件路径(绝对/相对路径)
mode字符串打开模式(如 r-只读、w-写入、a-追加)
encoding字符串文件编码(如 utf-8,文本模式需指定)
errors字符串编码错误处理方式(如 ignorereplace

文件对象.read():读取文件内容,返回字符串或字节流

参数名类型说明
size整数可选,指定读取的字节数(默认读取全部内容)

split():按分隔符分割字符串,返回子字符串列表

参数名类型说明
delimiter字符串分隔符(默认空格)
maxsplit整数可选,最大分割次数(默认-1表示全部)

strip():去除字符串首尾指定字符(默认空白字符)

参数名类型说明
chars字符串可选,指定需去除的字符集合

join():用分隔符连接可迭代对象的元素,返回新字符串

参数名类型说明
iterable可迭代对象需连接的元素集合(如列表、元组)
sep字符串分隔符(默认空字符串)

列表.append():在列表末尾添加元素

参数名类型说明
obj任意类型要添加的元素
    def load(self):
        self.data = []
        with open(self.path, encoding="utf8") as f:
            segments = f.read().split("\n\n")
            for segment in segments:
                sentenece = []
                labels = [8]  # cls_token
                for line in segment.split("\n"):
                    if line.strip() == "":
                        continue
                    char, label = line.split()
                    sentenece.append(char)
                    labels.append(self.schema[label])
                sentence = "".join(sentenece)
                self.sentences.append(sentence)
                input_ids = self.encode_sentence(sentenece)
                labels = self.padding(labels, -1)
                # print(self.decode(sentence, labels))
                # input()
                self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
        return

3.加载字 / 词表

BertTokenizer.from_pretrained():是 Hugging Face Transformers 库中用于加载预训练 BERT 分词器的核心方法。

def load_vocab(vocab_path):
    return BertTokenizer.from_pretrained(vocab_path)

4.加载映射关系表 

        加载位于指定路径的 JSON 格式的模式文件,并将其内容解析为 Python 对象以便在数据生成过程中使用。

open():打开文件并返回文件对象,用于读写文件内容。

参数名类型默认值说明
file_namestr文件路径(需包含扩展名)
modestr'r'文件打开模式:
'r': 只读
'w': 只写(覆盖原文件)
'a': 追加写入
'b': 二进制模式
'x': 创建新文件(若存在则报错)
bufferingintNone缓冲区大小(仅二进制模式有效)
encodingstrNone文件编码(仅文本模式有效,如 'utf-8'
newlinestr'\n'行结束符(仅文本模式有效)
closefdboolTrue是否在文件关闭时自动关闭文件描述符
dir_fdint-1文件描述符(高级用法,通常忽略)
flagsint0Linux 系统下的额外标志位
modestr(重复参数,实际使用中只需指定 mode

json.load():从已打开的 JSON 文件对象中加载数据,并将其转换为 Python 对象(如字典、列表)。

参数名类型默认值说明
fpio.TextIO已打开的文件对象(需处于读取模式)
indentint/strNone缩进空格数(美化输出,如 4 或 " "
sort_keysboolFalse是否对 JSON 键进行排序
load_hookcallableNone自定义对象加载回调函数
object_hookcallableNone自定义对象解析回调函数
    def load_schema(self, path):
        with open(path, encoding="utf8") as f:
            return json.load(f)

5.封装数据

DataLoader():PyTorch 模型训练的标配工具,通过合理的参数配置(如 batch_sizenum_workersshuffle),可以显著提升数据加载效率,尤其适用于大规模数据集和复杂预处理任务。其与 Dataset 类的配合使用,是构建高效训练管道的核心。

参数名类型默认值说明
datasetDatasetNone必须参数,自定义数据集对象(需继承 torch.utils.data.Dataset)。
batch_sizeint1每个批次的样本数量。
shuffleboolFalse是否在每个 epoch 开始时打乱数据顺序(训练时推荐设为 True)。
num_workersint0使用多线程加载数据的工人数量(需大于 0 时生效)。
pin_memoryboolFalse是否将数据存储在 pinned memory 中(加速 GPU 数据传输)。
drop_lastboolFalse如果数据集长度无法被 batch_size 整除,是否丢弃最后一个不完整的批次。
persistent_workersboolFalse是否保持工作线程在 epoch 之间持续运行(减少多线程初始化开销)。
worker_init_fncallableNone自定义工作线程初始化函数。
# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
    dg = DataGenerator(data_path, config)
    dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
    return dl

6.对于输入文本做截断 / 填充

    #补齐或截断输入的序列,使其可以在一个batch内运算
    def padding(self, input_id, pad_token=0):
        input_id = input_id[:self.config["max_length"]]
        input_id += [pad_token] * (self.config["max_length"] - len(input_id))
        return input_id

7.类内魔术方法

__len__():用于定义对象的“长度”,通过内置函数 len() 调用时返回该值。它通常用于容器类(如列表、字典、自定义数据结构),表示容器中元素的个数

__getitem__():允许对象通过索引或键值访问元素,支持 obj[index] 或 obj[key] 语法。它使对象表现得像序列(如列表)或映射(如字典)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

8.对于输入的文本编码

self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具

self.tokenizer.encode():Hugging Face Transformers 库中 BertTokenizer 的核心方法,用于将原始文本转换为模型可处理的输入形式。

参数名类型默认值说明
textstr 或 List[str]必填输入文本(单句或句子对)。
text_pairstrNone第二段文本(用于句子对任务,如问答),与 text 拼接后生成 [CLS] text [SEP] text_pair [SEP]
add_special_tokensboolTrue是否添加 [CLS] 和 [SEP] 标记。关闭后仅返回原始分词索引
max_lengthint512最大序列长度。超长文本会被截断,不足则填充
paddingstr 或 boolFalse填充策略:True/'longest'(按批次最长填充)、'max_length'(按 max_length 填充)
truncationstr 或 boolFalse截断策略:True(按 max_length 截断)、'only_first'(仅截断第一句)
return_tensorsstrNone

返回张量类型:

'pt'(PyTorch)、'tf'(TensorFlow)、'np'(NumPy)

return_attention_maskboolTrue是否生成 attention_mask,标识有效内容(1)与填充部分(0)
    def encode_sentence(self, text, padding=True):
        return self.tokenizer.encode(text,
                                     padding="max_length",
                                     max_length=self.config["max_length"],
                                     truncation=True)

9.对于编码后的输入文本作解码

(04+): 匹配以 0(B-LOCATION)开头,后接多个 4(I-LOCATION)的连续标签

(15+)(26+)(37+)分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。

sentence:输入的原句(添加 $ 后的版本),用于根据标签索引提取实体文本。

lables:模型输出的标签序列,转换为字符串后通过正则匹配定位实体位置。

results:存储提取的实体,键为实体类型(如 "LOCATION"),值为该类型实体的文本列表。

location:正则匹配结果,通过 span() 获取实体在 sentence 中的起止位置,用于提取具体文本片段。

join():将可迭代对象(列表、元组等)中的元素按指定分隔符连接成一个字符串。调用该方法的字符串作为分隔符。

参数名类型默认值说明
iterable可迭代对象必填需连接的元素集合,所有元素必须是字符串类型。若为空,返回空字符串。

str():将其他数据类型(整数、浮点数、布尔值等)转换为字符串类型。支持格式化输出和复杂对象的字符串表示。

参数名类型默认值说明
object任意类型必填需转换的对象,如整数、列表、字典等。
encoding字符串可选编码格式(仅对字节类型有效),如 utf-8
errors字符串可选编码错误处理策略,如 ignorereplace

defaultdict():创建字典的子类,为不存在的键自动生成默认值。需指定 default_factory(如 listint)定义默认值类型。

参数名类型默认值说明
default_factory可调用对象或无参数函数None用于生成默认值的函数。若未指定,访问不存在的键会抛出 KeyError
**kwargs关键字参数可选其他初始化字典的键值对,如 name="Alice"

re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match 对象

参数名类型说明
patternstr 或正则表达式对象要匹配的正则表达式模式
stringstr要搜索的字符串
flagsint (可选)正则匹配标志(如 re.IGNORECASE

.span():返回正则匹配的起始和结束索引(左闭右开区间)

列表.append():向列表末尾添加单个元素,直接修改原列表

参数名类型说明
element任意要添加的元素
    def decode(self, sentence, labels):
        sentence = "$" + sentence
        labels = "".join([str(x) for x in labels[:len(sentence) + 2]])
        results = defaultdict(list)
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            print("location", s, e)
            results["LOCATION"].append(sentence[s:e])
        for location in re.finditer("(15+)", labels):
            s, e = location.span()
            print("org", s, e)
            results["ORGANIZATION"].append(sentence[s:e])
        for location in re.finditer("(26+)", labels):
            s, e = location.span()
            print("per", s, e)
            results["PERSON"].append(sentence[s:e])
        for location in re.finditer("(37+)", labels):
            s, e = location.span()
            print("time", s, e)
            results["TIME"].append(sentence[s:e])
        return results

完整代码

# -*- coding: utf-8 -*-

import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from transformers import BertTokenizer

"""
数据加载
"""


class DataGenerator:
    def __init__(self, data_path, config):
        self.config = config
        self.path = data_path
        self.tokenizer = load_vocab(config["bert_path"])
        self.sentences = []
        self.schema = self.load_schema(config["schema_path"])
        self.load()

    def load(self):
        self.data = []
        with open(self.path, encoding="utf8") as f:
            segments = f.read().split("\n\n")
            for segment in segments:
                sentenece = []
                labels = [8]  # cls_token
                for line in segment.split("\n"):
                    if line.strip() == "":
                        continue
                    char, label = line.split()
                    sentenece.append(char)
                    labels.append(self.schema[label])
                sentence = "".join(sentenece)
                self.sentences.append(sentence)
                input_ids = self.encode_sentence(sentenece)
                labels = self.padding(labels, -1)
                # print(self.decode(sentence, labels))
                # input()
                self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
        return

    def encode_sentence(self, text, padding=True):
        return self.tokenizer.encode(text,
                                     padding="max_length",
                                     max_length=self.config["max_length"],
                                     truncation=True)

    def decode(self, sentence, labels):
        sentence = "$" + sentence
        labels = "".join([str(x) for x in labels[:len(sentence) + 2]])
        results = defaultdict(list)
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            print("location", s, e)
            results["LOCATION"].append(sentence[s:e])
        for location in re.finditer("(15+)", labels):
            s, e = location.span()
            print("org", s, e)
            results["ORGANIZATION"].append(sentence[s:e])
        for location in re.finditer("(26+)", labels):
            s, e = location.span()
            print("per", s, e)
            results["PERSON"].append(sentence[s:e])
        for location in re.finditer("(37+)", labels):
            s, e = location.span()
            print("time", s, e)
            results["TIME"].append(sentence[s:e])
        return results

    # 补齐或截断输入的序列,使其可以在一个batch内运算
    def padding(self, input_id, pad_token=0):
        input_id = input_id[:self.config["max_length"]]
        input_id += [pad_token] * (self.config["max_length"] - len(input_id))
        return input_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_schema(self, path):
        with open(path, encoding="utf8") as f:
            return json.load(f)


def load_vocab(vocab_path):
    return BertTokenizer.from_pretrained(vocab_path)


# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
    dg = DataGenerator(data_path, config)
    dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
    return dl


if __name__ == "__main__":
    from config import Config

    dg = DataGenerator("ner_data/train", Config)
    dl = DataLoader(dg, batch_size=32)
    for x, y in dl:
        print(x.shape, y.shape)
        print(x[1], y[1])
        input()

三、模型建立 model.py

1.模型初始化

hidden_size:定义LSTM隐藏层的维度(即每个时间步输出的特征数量

vocab_size:词表大小,即嵌入层(Embedding)可处理的词汇总数

max_length:输入序列的最大长度,用于数据预处理(如截断或填充)

class_num:分类任务的类别数量,决定线性层(nn.Linear)的输出维度

num_layers:堆叠的LSTM层数,用于增加模型复杂度

BertModel.from_pretrained():加载预训练的 BERT 模型,支持从本地或 Hugging Face 模型库加载

参数名类型默认值说明
pretrained_model_name字符串预训练模型名称或路径(如 bert-base-chinese
config字典/类默认配置自定义模型配置,覆盖默认参数(如隐藏层维度、注意力头数)
cache_dir字符串None模型缓存目录
output_hidden_states布尔值False是否返回所有隐藏层输出(用于特征提取)

nn.Linear():实现全连接层的线性变换(y = xW^T + b

参数名类型默认值说明
in_features整数输入特征维度(如词向量维度 hidden_size
out_features整数输出特征维度(如分类类别数 class_num
bias布尔值True是否启用偏置项

CRF():条件随机场层,用于序列标注任务中约束标签转移逻辑。

参数名类型默认值说明
num_tags整数标签类别数(如 class_num
batch_first布尔值False输入张量是否为 (batch_size, seq_len) 格式

torch.nn.CrossEntropyLoss():计算交叉熵损失,常用于分类任务

参数名类型默认值说明
ignore_index整数-1忽略指定索引的标签(如填充符 -1
reduction字符串mean损失聚合方式(可选 nonesummean
    def __init__(self, config):
        super(TorchModel, self).__init__()
        hidden_size = config["hidden_size"]
        vocab_size = config["vocab_size"] + 1
        max_length = config["max_length"]
        class_num = config["class_num"]
        num_layers = config["num_layers"]
        # self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)
        # self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
        self.classify = nn.Linear(hidden_size * 2, class_num)
        self.crf_layer = CRF(class_num, batch_first=True)
        self.use_crf = config["use_crf"]
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失

2.前向计算 

计算流程

输入 x → 嵌入层 → LSTM层 → 分类层 → 分支判断:
          │
          ├── 有 target → CRF? → 是:计算 CRF 损失
          │                 │
          │                 └→ 否:计算交叉熵损失
          │
          └── 无 target → CRF? → 是:解码最优标签序列
                            │
                            └→ 否:返回预测 logits

gt():张量的逐元素比较函数,返回布尔型张量,标记输入张量中大于指定值的元素位置。常用于生成掩码(如忽略填充符)

参数名类型默认值说明
otherTensor/标量比较的阈值或张量。若为标量,则张量中每个元素与该值比较;若为张量,需与输入张量形状相同。
outTensorNone可选输出张量,用于存储结果。

shape():返回张量的维度信息,描述各轴的大小。

view():调整张量的形状,支持自动推断维度(通过-1占位符)。常用于数据展平或维度转换。

参数名类型默认值说明
*shape可变参数目标形状的维度序列,如view(2, 3)view(-1, 28)-1表示自动计算。
    #当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, target=None):
        x = self.embedding(x)  #input shape:(batch_size, sen_len)
        x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)
        predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)

        if target is not None:
            if self.use_crf:
                mask = target.gt(-1)
                # loss 是 crf 的相反数,即 - crf(predict, target, mask)
                return - self.crf_layer(predict, target, mask, reduction="mean")
            else:
                #(number, class_num), (number)
                return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
        else:
            if self.use_crf:
                return self.crf_layer.decode(predict)
            else:
                return predict

3.选择优化器 

Adam():自适应矩估计优化器(Adaptive Moment Estimation),结合动量和 RMSProp 的优点。

参数名类型默认值说明
lrfloat1e-3学习率。
betastuple(0.9, 0.999)动量系数(β₁, β₂)。
epsfloat1e-8防止除零误差。
weight_decayfloat0权重衰减率。
amsgradboolFalse是否启用 AMSGrad 优化。
foreachboolFalse是否为每个参数单独计算梯度。

SGD():随机梯度下降优化器(Stochastic Gradient Descent)

参数名类型默认值说明
lrfloat1e-3学习率。
momentumfloat0动量系数(如 momentum=0.9)。
weight_decayfloat0权重衰减率。
dampeningfloat0动力衰减系数(用于 SGD with Momentum)。
nesterovboolFalse是否启用 Nesterov 动量。
foreachboolFalse是否为每个参数单独计算梯度。

parameters():返回模型所有可训练参数的迭代器,常用于参数初始化或梯度清零。

参数名类型默认值说明
filtercallableNone过滤条件函数(如 lambda p: p.requires_grad)。默认返回所有参数。
def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)

4.模型建立

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
import torch
from transformers import BertModel

"""
建立网络模型结构
"""

class TorchModel(nn.Module):
    def __init__(self, config):
        super(TorchModel, self).__init__()
        hidden_size = config["hidden_size"]
        vocab_size = config["vocab_size"] + 1
        max_length = config["max_length"]
        class_num = config["class_num"]
        num_layers = config["num_layers"]
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        # self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)
        self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
        self.classify = nn.Linear(hidden_size * 2, class_num)
        self.crf_layer = CRF(class_num, batch_first=True)
        self.use_crf = config["use_crf"]
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失

    #当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, target=None):
        x = self.embedding(x)  #input shape:(batch_size, sen_len)
        x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)
        predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)

        if target is not None:
            if self.use_crf:
                mask = target.gt(-1)
                # loss 是 crf 的相反数,即 - crf(predict, target, mask)
                return - self.crf_layer(predict, target, mask, reduction="mean")
            else:
                #(number, class_num), (number)
                return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
        else:
            if self.use_crf:
                return self.crf_layer.decode(predict)
            else:
                return predict


def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)


if __name__ == "__main__":
    from config import Config
    model = TorchModel(Config)

四、模型效果测试 evaluate.py

1.模型流程

Ⅰ、​数据准备与初始化

加载验证集:从指定路径加载预处理后的验证数据,保持数据顺序以避免随机性干扰评估结果​

初始化统计字典:为每个实体类别(如LOCATION、PERSON等)创建计数器,记录“正确识别数”“样本实体数”等统计指标


Ⅱ、​模型推理与预测

切换评估模式:调用model.eval()关闭Dropout等训练层,确保推理稳定性​

批次处理

        ​数据迁移至GPU:若CUDA可用,将输入ID和标签移至GPU加速计算​

        无梯度预测:在torch.no_grad()上下文中执行模型推理,减少内存占用

        ​输出处理:若未使用CRF层,通过torch.argmax直接获取预测标签序列;若使用CRF,需解码最优路径


​Ⅲ、实体解码与对齐

标签序列转换:将数值标签拼接为字符串(如[0,4,4]"044"),并截取与句子长度对齐

正则匹配实体

        规则定义:通过正则表达式匹配标签模式(如04+表示LOCATION实体),提取连续B-I标签对应的文本片段

        ​索引对齐:根据匹配的起止位置从原始句子中截取实体(例如"04+"匹配到索引3-5,则提取句子[3:5]


Ⅳ、​统计与评估指标计算

对比真实与预测实体:遍历每个句子的实体列表,统计以下指标:​

        正确识别数:预测实体存在于真实列表中的数量。

        样本实体数:真实实体总数。

        ​识别出实体数:预测实体总数​

计算指标

        ​精确率(Precision)​:正确识别数 / 识别出实体数。

        召回率(Recall)​:正确识别数 / 样本实体数。

        F1值:精确率与召回率的调和平均​

输出结果:按实体类别输出指标,并计算宏平均(Macro-F1)和微平均(Micro-F1)


Ⅴ、关键设计细节

标签编码规则:采用BIO格式(如B-LOCATION=0,I-LOCATION=4),确保实体连续性​

异常处理:添加1e-5平滑项避免除零错误,增强数值稳定性​

性能优化:禁用梯度计算、GPU加速、批次处理提升效率


2.初始化

Ⅰ、加载配置文件、模型及日志模块 ——>

Ⅱ、读取验证集数据(固定顺序,避免随机性干扰评估)——>

Ⅲ、初始化统计字典 stats_dict,按实体类别记录正确识别数、样本实体数等

config:存储运行时配置,例如数据路径、超参数(如批次大小 batch_size)、是否使用CRF层等。通过 config["valid_data_path"] 动态获取验证集路径。

model:待评估的模型实例,用于调用预测方法(如 model(input_id)),需提前完成训练和加载。

logger:记录运行日志,例如输出评估指标(准确率、F1值)到文件或控制台,便于调试和监控。

valid_data:验证数据集,用于模型训练时的性能评估和超参数调优。

load_data():数据加载类中,用torch自带的DataLoader类封装数据的函数

    def __init__(self, config, model, logger):
        self.config = config
        self.model = model
        self.logger = logger
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)

 3.统计模型效果

Ⅰ、解码实体 ——> Ⅱ、对比结果

len():返回对象的元素数量(字符串、列表、元组、字典等)

参数名类型说明
object任意可迭代对象如字符串、列表、字典等

torch.argmax():返回张量中最大值所在的索引

参数名类型说明
inputTensor输入张量
dimint沿指定维度查找最大值
keepdimbool是否保持输出维度一致

cpu():将张量从GPU移动到CPU内存

zip():将多个可迭代对象打包成元组列表

参数名类型说明
iterables多个可迭代对象如列表、元组、字符串

.detach():从计算图中分离张量,阻止梯度传播

.tolist():将张量或数组转换为Python列表

    def write_stats(self, labels, pred_results, sentences):
        assert len(labels) == len(pred_results) == len(sentences)
        if not self.config["use_crf"]:
            pred_results = torch.argmax(pred_results, dim=-1)
        for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
            if not self.config["use_crf"]:
                pred_label = pred_label.cpu().detach().tolist()
            true_label = true_label.cpu().detach().tolist()
            true_entities = self.decode(sentence, true_label)
            pred_entities = self.decode(sentence, pred_label)

            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
                self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
                self.stats_dict[key]["样本实体数"] += len(true_entities[key])
                self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
        return

4.可视化统计模型效果

精确率 (Precision):正确预测实体数 / 总预测实体数

召回率 (Recall):正确预测实体数 / 总真实实体数​

F1值:精确率与召回率的调和平均

F1:F1分数:准确率与召回率的调和平均数,综合衡量模型的精确性与覆盖能力。

F1_scores:存储四个实体类别的 F1 分数,用于计算宏观平均。

precision:准确率:模型预测为某类实体的结果中,正确的比例。反映模型预测的精确度。

recall:召回率:真实存在的某类实体中,被模型正确识别的比例。反映模型对实体的覆盖能力。

key:当前处理的实体类别(如 "PERSON""LOCATION")。

correct_pred:总正确识别数:所有类别中被正确识别的实体总数。

total_pred:总识别实体数:模型预测出的所有实体数量(含错误识别)。

true_enti:总样本实体数:验证数据中真实存在的所有实体数量。

micro_precision:微观准确率:全局视角下的准确率,所有实体类别的正确识别数与总识别数的比例。

micro_recall:微观召回率:全局视角下的召回率,所有实体类别的正确识别数与总样本实体数的比例。

micro_f1:微观F1分数:微观准确率与微观召回率的调和平均数。

列表.append():在列表末尾添加元素

参数名类型说明
element任意要添加的元素

logger.info():记录日志信息(需配置日志模块)

参数名类型说明
formatstr格式化字符串
*args可变参数格式化参数

sum():计算可迭代对象的元素总和

参数名类型说明
iterable可迭代对象如列表、元组
start数值(可选)初始累加值

列表推导式:通过简洁语法生成新列表,语法:[表达式 for item in iterable if 条件]

    def show_stats(self):
        F1_scores = []
        for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
            recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
            F1 = (2 * precision * recall) / (precision + recall + 1e-5)
            F1_scores.append(F1)
            self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
        self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
        correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        micro_precision = correct_pred / (total_pred + 1e-5)
        micro_recall = correct_pred / (true_enti + 1e-5)
        micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
        self.logger.info("Micro-F1 %f" % micro_f1)
        self.logger.info("--------------------")
        return

5.评估模型效果

模型切换为评估模式:关闭Dropout等训练层​

批次处理数据

         提取原始句子 sentences

   将数据迁移至GPU(若可用)

         预测时禁用梯度计算(torch.no_grad())优化内存

统计结果:调用 write_stats 对比预测与真实标签

epoch:当前训练轮次,用于日志。

logger:记录日志的工具。

stats_dict:统计字典,记录各实体类别的指标。

valid_data:验证数据集,通常由 load_data 加载(如 config["valid_data_path"] 指定路径)

index: 循环中的批次索引

batch_data: 循环中的数据。

sentences:当前批次的原始句子

pred_results:模型预测结果

write_stats():写入统计信息

show_stats():显示统计结果

logger.info():记录日志信息(需配置日志模块)

参数名类型说明
formatstr格式化字符串
*args可变参数格式化参数

defaultdict():创建带有默认值工厂的字典

参数名类型说明
default_factory可调用对象如int、list、自定义函数

model.eval():将模型设置为评估模式(关闭Dropout等训练层)

enumerate():返回索引和元素组成的枚举对象

参数名类型说明
iterable可迭代对象如列表、字符串
startint(可选)起始索引,默认为0

torch.cuda.is_available():检查当前环境是否支持CUDA(GPU加速)

cuda():将张量或模型移动到GPU

参数名类型说明
deviceint/str指定GPU设备号,如"cuda:0"

torch.no_grad():禁用梯度计算,节省内存并加速推理

    def eval(self, epoch):
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        self.stats_dict = {"LOCATION": defaultdict(int),
                           "TIME": defaultdict(int),
                           "PERSON": defaultdict(int),
                           "ORGANIZATION": defaultdict(int)}
        self.model.eval()
        for index, batch_data in enumerate(self.valid_data):
            sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            with torch.no_grad():
                pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
            self.write_stats(labels, pred_results, sentences)
        self.show_stats()
        return

6.解码

标签序列预处理:将数值标签拼接为字符串(如 [0,4,4] → "044"

正则匹配实体

   04+B-LOCATION(0)后接多个I-LOCATION(4)

   15+B-ORGANIZATION(1)后接I-ORGANIZATION(5)

           其他实体类别同理

索引对齐:根据匹配位置截取原始句子中的实体文本

Ⅰ、输入预处理

在原句首添加 $ 符号,通常用于对齐标签与字符位置(例如避免索引越界)

        sentence = "$" + sentence

Ⅱ、标签序列转换

将整数标签序列转换为字符串,并截取长度与 sentence 对齐

str.join():将可迭代对象中的字符串元素按指定分隔符连接成一个新字符串

参数名类型说明
iterable可迭代对象元素必须为字符串类型

str():将对象转换为字符串表示形式,支持自定义类的 __str__ 方法

参数名类型说明
object任意要转换的对象

len():返回对象的长度或元素个数(适用于字符串、列表、字典等)

参数名类型说明
object可迭代对象如字符串、列表等

列表推导式:通过简洁语法生成新列表,支持条件过滤和多层循环

        [expression for item in iterable if condition]

部分类型说明
expression表达式对 item 处理后的结果
item变量迭代变量
iterable可迭代对象如列表、range() 生成的序列
condition条件表达式 (可选)过滤不符合条件的元素
        labels = "".join([str(x) for x in labels[:len(sentence)+1]])

Ⅲ、初始化结果容器

创建默认值为列表的字典,存储四类实体:

        (LOCATION、ORGANIZATION、PERSON、TIME)的识别结果

defaultdict():创建默认值字典,当键不存在时自动生成默认值(基于工厂函数)

参数名类型说明
default_factory可调用对象如 intlist 或自定义函数
        results = defaultdict(list)

Ⅳ、正则表达式匹配

    (04+): 匹配以 0(B-LOCATION)开头,后接多个 4(I-LOCATION)的连续标签

    (15+)(26+)(37+)分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。

re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match 对象

参数名类型说明
patternstr 或正则表达式对象要匹配的正则表达式模式
stringstr要搜索的字符串
flagsint (可选)正则匹配标志(如 re.IGNORECASE

.span():返回正则匹配的起始和结束索引(左闭右开区间)

列表.append():向列表末尾添加单个元素,直接修改原列表

参数名类型说明
element任意要添加的元素
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            results["LOCATION"].append(sentence[s:e])

Ⅴ、完整代码 

    '''
    {
      "B-LOCATION": 0,
      "B-ORGANIZATION": 1,
      "B-PERSON": 2,
      "B-TIME": 3,
      "I-LOCATION": 4,
      "I-ORGANIZATION": 5,
      "I-PERSON": 6,
      "I-TIME": 7,
      "O": 8
    }
    '''
    def decode(self, sentence, labels):
        sentence = "$" + sentence
        labels = "".join([str(x) for x in labels[:len(sentence)+1]])
        results = defaultdict(list)
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            results["LOCATION"].append(sentence[s:e])
        for location in re.finditer("(15+)", labels):
            s, e = location.span()
            results["ORGANIZATION"].append(sentence[s:e])
        for location in re.finditer("(26+)", labels):
            s, e = location.span()
            results["PERSON"].append(sentence[s:e])
        for location in re.finditer("(37+)", labels):
            s, e = location.span()
            results["TIME"].append(sentence[s:e])
        return results

7.完整代码 

# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data

"""
模型效果测试
"""

class Evaluator:
    def __init__(self, config, model, logger):
        self.config = config
        self.model = model
        self.logger = logger
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)


    def eval(self, epoch):
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        self.stats_dict = {"LOCATION": defaultdict(int),
                           "TIME": defaultdict(int),
                           "PERSON": defaultdict(int),
                           "ORGANIZATION": defaultdict(int)}
        self.model.eval()
        for index, batch_data in enumerate(self.valid_data):
            sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            with torch.no_grad():
                pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
            self.write_stats(labels, pred_results, sentences)
        self.show_stats()
        return

    def write_stats(self, labels, pred_results, sentences):
        assert len(labels) == len(pred_results) == len(sentences)
        if not self.config["use_crf"]:
            pred_results = torch.argmax(pred_results, dim=-1)
        for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
            if not self.config["use_crf"]:
                pred_label = pred_label.cpu().detach().tolist()
            true_label = true_label.cpu().detach().tolist()
            true_entities = self.decode(sentence, true_label)
            pred_entities = self.decode(sentence, pred_label)

            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
                self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
                self.stats_dict[key]["样本实体数"] += len(true_entities[key])
                self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
        return

    def show_stats(self):
        F1_scores = []
        for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
            # 正确率 = 识别出的正确实体数 / 识别出的实体数
            # 召回率 = 识别出的正确实体数 / 样本的实体数
            precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
            recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
            F1 = (2 * precision * recall) / (precision + recall + 1e-5)
            F1_scores.append(F1)
            self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
        self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
        correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
        micro_precision = correct_pred / (total_pred + 1e-5)
        micro_recall = correct_pred / (true_enti + 1e-5)
        micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
        self.logger.info("Micro-F1 %f" % micro_f1)
        self.logger.info("--------------------")
        return

    '''
    {
      "B-LOCATION": 0,
      "B-ORGANIZATION": 1,
      "B-PERSON": 2,
      "B-TIME": 3,
      "I-LOCATION": 4,
      "I-ORGANIZATION": 5,
      "I-PERSON": 6,
      "I-TIME": 7,
      "O": 8
    }
    '''
    def decode(self, sentence, labels):
        sentence = "$" + sentence
        labels = "".join([str(x) for x in labels[:len(sentence)+1]])
        results = defaultdict(list)
        for location in re.finditer("(04+)", labels):
            s, e = location.span()
            results["LOCATION"].append(sentence[s:e])
        for location in re.finditer("(15+)", labels):
            s, e = location.span()
            results["ORGANIZATION"].append(sentence[s:e])
        for location in re.finditer("(26+)", labels):
            s, e = location.span()
            results["PERSON"].append(sentence[s:e])
        for location in re.finditer("(37+)", labels):
            s, e = location.span()
            results["TIME"].append(sentence[s:e])
        return results


五、主函数文件 main.py

① 环境初始化与配置加载 ——>

② 数据加载与预处理 ——>

③ 模型初始化与硬件适配 ——>

④ 优化器与评估器初始化 ——>

⑤ 训练循环与参数更新 ——>

⑥ 模型评估与权重保存

1.导入文件

# -*- coding: utf-8 -*-

import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

2.日志配置

logging.basicConfig():配置日志系统的基础参数(一次性设置,应在首次日志调用前调用)

参数名类型是否必需默认值说明
filename字符串None日志输出文件名(若指定,日志写入文件而非控制台)
filemode字符串'a'文件打开模式(如'w'覆盖,'a'追加)
format字符串基础格式日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s'
datefmt字符串时间格式(如'%Y-%m-%d %H:%M:%S'
level整数WARNING日志级别(如logging.INFOlogging.DEBUG
stream对象None指定日志输出流(如sys.stderr,与filename互斥)

logging.getLogger():获取或创建指定名称的日志记录器(Logger)。若nameNone,返回根日志记录器

参数名类型是否必需默认值说明
name字符串None日志记录器名称(分层结构,如'module.sub'
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

3.主函数 main  

Ⅰ、创建模型保存目录

os.path.isdir():检查指定路径是否为目录(文件夹)

参数名类型是否必需默认值说明
path字符串要检查的路径(绝对或相对)

os.mkdir():创建单个目录(若父目录不存在会抛出异常)

参数名类型是否必需默认值说明
path字符串要创建的目录路径
mode整数0o777目录权限(八进制格式,某些系统可能忽略此参数)
    #创建保存模型的目录
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])

Ⅱ、加载训练数据

    #加载训练数据
    train_data = load_data(config["train_data_path"], config)

Ⅲ、加载模型

    #加载模型
    model = TorchModel(config)

Ⅳ、检查GPU并迁移模型

torch.cuda.is_available():检查系统是否满足 CUDA 环境要求

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)

cuda():将张量或模型移动到GPU显存,加速计算

参数类型必须说明示例
deviceint/str指定GPU设备(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否异步传输数据(默认False)tensor.cuda(non_blocking=True)
    # 标识是否使用gpu
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用,迁移模型至gpu")
        model = model.cuda()

Ⅴ、加载优化器

    #加载优化器
    optimizer = choose_optimizer(config, model)

Ⅵ、加载评估器

    #加载效果测试类
    evaluator = Evaluator(config, model, logger)

Ⅶ、模型训练 ⭐

① Epoch循环控制

range():Python 内置函数,用于生成一个不可变的整数序列,​核心功能是为循环控制提供高效的数值迭代支持

参数名类型默认值说明
start整数0序列起始值(包含)。若省略,则默认从 0 开始。例如 range(3) 等价于 range(0,3)
stop整数必填序列结束值(不包含)。例如 range(2, 5) 生成 2,3,4
step整数1步长(正/负):
- ​正步长需满足 start < stop,否则无输出(如 range(5, 2) 无效)。
- ​负步长需满足 start > stop,例如 range(5, 0, -1) 生成 5,4,3,2,1
​**不能为 0**​(否则触发 ValueError
for epoch in range(config["epoch"]):
    epoch += 1
② 模型设置训练模式 

train_loss:计算当前批次的损失值,通常结合损失函数(如交叉熵、均方误差)使用

model.train():设置模型为训练模式,启用Dropout、BatchNorm等层的训练行为

参数类型默认值说明示例
modeboolTrue是否启用训练模式(True)或评估模式(False)model.train(True)

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []

③ Batch数据遍历

enumerate():遍历可迭代对象时返回索引和元素,支持自定义起始索引

参数类型必须说明示例
iterableIterable可迭代对象(如列表、生成器)enumerate(["a", "b"])
startint索引起始值(默认0)enumerate(data, start=1)
        for index, batch_data in enumerate(train_data):

④ 梯度清零与设备切换

optimizer.zero_grad():清空模型参数的梯度,防止梯度累积

参数类型必须说明示例
set_to_nonebool是否将梯度置为None(高效但危险)optimizer.zero_grad(True)

cuda():将张量或模型移动到GPU显存,加速计算

参数类型必须说明示例
deviceint/str指定GPU设备(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否异步传输数据(默认False)tensor.cuda(non_blocking=True)
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]

⑤ 前向传播与损失计算
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)

⑥ 反向传播与参数更新

loss.backward():反向传播计算梯度,基于损失值更新模型参数的.grad属性

参数类型必须说明示例
retain_graphbool是否保留计算图(用于多次反向传播)loss.backward(retain_graph=True)

optimizer.step():根据梯度更新模型参数,执行优化算法(如SGD、Adam)

参数类型必须说明示例
closureCallable重新计算损失的闭包函数(如LBFGS)optimizer.step(closure)
            loss.backward()
            optimizer.step()

⑦ 损失记录与日志输出

列表.append():在列表末尾添加元素,直接修改原列表

参数类型必须说明示例
objectAny要添加到列表末尾的元素train_loss.append(loss.item())

int():将字符串或浮点数转换为整数,支持进制转换

参数类型必须说明示例
xstr/float待转换的值(如字符串或浮点数)int("10", base=2)(输出2进制10=2)
baseint进制(默认10)

len():返回对象(如列表、字符串)的长度或元素个数

参数类型必须说明示例
objSequence/Collection可计算长度的对象(如列表、字符串)len([1, 2, 3])(返回3)

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)

⑧ Epoch评估与日志

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)

⑨ 完整训练代码
    #训练
    for epoch in range(config["epoch"]):
        epoch += 1
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []
        for index, batch_data in enumerate(train_data):
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        evaluator.eval(epoch)

Ⅷ、模型保存

os.path.join():Python 中用于拼接路径的核心函数,其核心价值在于自动处理不同操作系统的路径分隔符,从而保证代码的跨平台兼容性

参数类型必填说明
path1字符串初始路径组件
*paths可变参数后续路径组件(可传多个)
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
    # torch.save(model.state_dict(), model_path)
    return model, train_data

4.调用模型预测

# -*- coding: utf-8 -*-

import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

"""
模型训练主程序
"""

def main(config):
    #创建保存模型的目录
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])
    #加载训练数据
    train_data = load_data(config["train_data_path"], config)
    #加载模型
    model = TorchModel(config)
    # 标识是否使用gpu
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用,迁移模型至gpu")
        model = model.cuda()
    #加载优化器
    optimizer = choose_optimizer(config, model)
    #加载效果测试类
    evaluator = Evaluator(config, model, logger)
    #训练
    for epoch in range(config["epoch"]):
        epoch += 1
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []
        for index, batch_data in enumerate(train_data):
            optimizer.zero_grad()
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            loss = model(input_id, labels)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        evaluator.eval(epoch)
    # 保存模型
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
    torch.save(model.state_dict(), model_path)
    return model, train_data

if __name__ == "__main__":
    model, train_data = main(Config)

相关文章:

  • Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实战指南
  • Springboot中的 Mapper 无法找到的 可能原因及解决方案
  • 一个简单的井字棋(Tic-Tac-Toe)游戏的C语言实现
  • 程序化广告行业(20/89):交易模式深度剖析与价值解读
  • 基于51单片机的多功能时钟闹钟proteus仿真
  • 前端内存优化实战指南:从内存泄漏到性能巅峰
  • IMX6ULL_Pro开发板的串口应用程序实例(利用TTY子系统去使用串口)
  • 蓝桥杯[阶段总结] 二分,前缀和
  • C语言动态内存管理(上)
  • Compose 实践与探索十二 —— 附带效应
  • Webpack 基础
  • SLC跨头协作机制
  • 解析 Bootloader:嵌入式系统中不可或缺的启动程序
  • 蓝桥杯备考---- 图的存储与遍历
  • Matlab 基于SVPWM的VF三电平逆变器异步电机速度控制
  • 【Agent】OpenManus-Agent架构详细分析
  • 0-1背包问题 之 分割等和子集以及变形问题
  • 嵌入式SDIO 总线面试题及参考答案
  • 验证与调参——交叉验证/ 网格搜索/贝叶斯优化/随机搜索
  • 第7章 站在对象模型的尖端3: RTTI
  • 梅花奖在上海|朱洁静:穿越了人生暴风雨,舞台是最好良药
  • 构筑高地共伴成长,第六届上海创新创业青年50人论坛在沪举行
  • 中华人民共和国和俄罗斯联邦关于全球战略稳定的联合声明
  • 两部门部署中小学幼儿园教师招聘工作:吸纳更多高校毕业生从教
  • 顾家家居:拟定增募资近20亿元,用于家居产品生产线的改造和扩建等
  • 中美“第二阶段”贸易协定是否会在会谈中提出?商务部回应