Result:
Input: 帮我把厕所的灯打开一下吧
Model output: tensor([2]) tensor([[0, 8, 8, 8, 2, 6, 8, 1, 0, 4, 8, 8, 8, 4]])
Intent: ['turn_on_light']
Entities: ['O' 'O' 'O' 'B-LOC' 'I-LOC' 'O' 'B-DEV' 'B-ACT' 'I-ACT' 'O' 'O' 'O']
Command: {'ACT': '打开', 'LOC': '厕所', 'DEV': '灯', 'VAL': ''}
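The raw indices can be decoded by hand: LabelEncoder orders its classes alphabetically, so the intent index 2 maps to 'turn_on_light' and the entity indices map to the BIO tags shown above once the first and last positions ([CLS] / [SEP]) are dropped, exactly as test() does. A minimal sketch, using only the label lists defined in train() below:

from sklearn.preprocessing import LabelEncoder

# Same label lists as in train(); LabelEncoder sorts classes alphabetically.
intent_encoder = LabelEncoder().fit(["unknown", "turn_on_light", "turn_off_light", "adjust_brightness"])
print(intent_encoder.inverse_transform([2]))
# ['turn_on_light']

entity_encoder = LabelEncoder().fit(["O", "B-ACT", "I-ACT", "B-LOC", "I-LOC", "B-DEV", "I-DEV", "B-VAL", "I-VAL"])
# Drop the first and last positions ([CLS] / [SEP]) of the raw output before decoding.
print(entity_encoder.inverse_transform([8, 8, 8, 2, 6, 8, 1, 0, 4, 8, 8, 8]))
# ['O' 'O' 'O' 'B-LOC' 'I-LOC' 'O' 'B-DEV' 'B-ACT' 'I-ACT' 'O' 'O' 'O']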
Complete code:
"""
@Title: Intent Understanding & Entity Recognition
@Time: 2025/9/8
@Author: Michael Jie
"""import math
import pickle
import time

import torch
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from torch import nn, Tensor
from torch.optim import Adam
from torch.utils.data import Dataset
from tqdm import tqdm

from ztools import TokenizerBert, MjTrain, MjUtil

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer
tokenizer = TokenizerBert()
# Label encoders
intent_encoder = LabelEncoder()
entity_encoder = LabelEncoder()


# Embedding
class Embedding(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 p: float = 0.1) -> None:
        """Learnable word embeddings and positional encodings.

        Args:
            d_model: feature dimension, defaults to 512
            p: dropout rate, defaults to 0.1
        """
        super(Embedding, self).__init__()
        self.d_model = d_model
        # Word embedding; BertTokenizer vocabulary (21128 entries), padding index 0
        self.lut = nn.Embedding(21128, d_model, 0)
        # Position embedding, maximum sequence length 512
        self.pe = nn.Embedding(512, d_model)
        # Dropout layer
        self.dropout = nn.Dropout(p)

    def forward(self, x: Tensor) -> Tensor:
        # Position indices [0, 1, 2, ..., S - 1], built on the same device as the input
        pos = torch.arange(x.size(1), device=x.device).unsqueeze(0)  # (1, S)
        x = math.sqrt(self.d_model) * (self.lut(x) + self.pe(pos))
        return self.dropout(x)
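# A quick shape check for Embedding (illustrative sketch, not part of the original
# script; it assumes randomly initialised weights and CPU tensors):
# >>> emb = Embedding(d_model=512)
# >>> x = torch.randint(0, 21128, (2, 10))   # (batch_size, seq_len) of BertTokenizer ids
# >>> emb(x).shape
# torch.Size([2, 10, 512])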
# Intent & entity recognition model
class IntentAndNER(nn.Module):
    def __init__(self,
                 num_intent_classes: int,
                 num_entity_classes: int,
                 d_model: int = 512,
                 p: float = 0.1,
                 num_heads: int = 8,
                 num_layers: int = 4) -> None:
        """BERT-style Transformer encoder for joint intent classification and NER.

        Changes in shape:
            (batch_size, seq_len)
            -(embedding)> (batch_size, seq_len, d_model)
            -(encoder)>   (batch_size, seq_len, d_model)
            -(linear)>    (batch_size, num_intent_classes) &
                          (batch_size, seq_len, num_entity_classes)

        Args:
            num_intent_classes: number of intent labels
            num_entity_classes: number of entity labels
            d_model: feature dimension, defaults to 512
            p: dropout rate, defaults to 0.1
            num_heads: number of attention heads, defaults to 8
            num_layers: number of encoder layers, defaults to 4
        """
        super(IntentAndNER, self).__init__()
        # Embedding
        self.embedding = Embedding(d_model, p)
        # Encoder
        layer = nn.TransformerEncoderLayer(d_model, num_heads, activation="gelu",
                                           dropout=p, batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        # Output heads
        self.intent_linear = nn.Linear(d_model, num_intent_classes)
        self.entity_linear = nn.Linear(d_model, num_entity_classes)
        # Loss function (ignores the -100 labels by default)
        self.loss_fun = nn.CrossEntropyLoss()

    def forward(self,
                x: Tensor,
                padding_mask: Tensor = None,
                tgt_intent: Tensor = None,
                tgt_entity: Tensor = None) -> tuple[Tensor, ...]:
        # (N, S) -> (N, S, E)
        x = self.embedding(x)
        y = self.encoder(x, src_key_padding_mask=padding_mask)
        # Take the relevant tokens as classification outputs ([CLS] for the intent)
        y_intent = self.intent_linear(y[:, 0, :])  # (N, E_intent)
        y_entity = self.entity_linear(y)  # (N, S, E_entity)
        # Compute the loss and f1 metrics
        if tgt_intent is not None and tgt_entity is not None:
            loss_intent = self.loss_fun(y_intent, tgt_intent)
            loss_entity = self.loss_fun(y_entity.view(-1, y_entity.size(-1)), tgt_entity.view(-1))
            loss = loss_intent + loss_entity
            # Compute f1 for intents and entities separately
            ids_intent = torch.argmax(y_intent, dim=-1)  # (N,)
            f1_intent = f1_score(tgt_intent.tolist(), ids_intent.tolist(),
                                 average="weighted", zero_division=0.0)
            # Build a mask to exclude the -100 labels
            mask = tgt_entity != -100
            ids_entity = torch.argmax(y_entity, dim=-1)  # (N, S)
            f1_entity = f1_score(tgt_entity[mask].tolist(), ids_entity[mask].tolist(),
                                 average="weighted", zero_division=0.0)
            return y_intent, y_entity, loss, f1_intent, f1_entity
        return y_intent, y_entity

    def predict(self,
                x: Tensor,
                padding_mask: Tensor = None) -> tuple[Tensor, ...]:
        with torch.no_grad():  # inference mode
            y_intent, y_entity = self.forward(x, padding_mask)
            y_intent = torch.argmax(y_intent, dim=-1)  # (N,)
            y_entity = torch.argmax(y_entity, dim=-1)  # (N, S)
        return y_intent, y_entity
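# A minimal inference sketch for IntentAndNER (illustrative only: random weights,
# with sizes matching the label sets used in train() below):
# >>> model = IntentAndNER(num_intent_classes=4, num_entity_classes=9, d_model=512)
# >>> x = torch.randint(0, 21128, (1, 14))   # one tokenized sentence: [CLS] ... [SEP]
# >>> intent, entity = model.predict(x)
# >>> intent.shape, entity.shape
# (torch.Size([1]), torch.Size([1, 14]))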
# Dataset
class TextDataset(Dataset):
    def __init__(self, lst: list[dict]) -> None:
        self.lst = lst  # data

    def __getitem__(self, ids):
        # Return the raw text sample; batching is handled in collate_fn
        obj = self.lst[ids]
        return obj["text"], obj["intent"], obj["entity"]

    def __len__(self):
        return len(self.lst)
# Align entity labels with the tokenized input
def align_entity(labels: list[int],
                 tokens: list[int]) -> list[int]:
    """[0, 1, 2, 5, 0, 3, 4] &
    [101, 2828, 778, 2428, 6444, 1168, 8145, 110, 102, 0] ->
    [-100, 0, 1, 2, 5, 0, 3, 4, -100, -100]

    Args:
        labels: original label sequence
        tokens: input token sequence

    Returns:
        new_labels: aligned and padded label sequence
    """
    new_labels = []
    i = 0
    for token in tokens:
        # [CLS], [SEP], [PAD]
        if token in [101, 102, 0]:
            new_labels.append(-100)
        else:
            new_labels.append(labels[i])
            i += 1
    return new_labels
# collate_fn
def collate_fn(batch):
    text, intent, entity = zip(*batch)
    # Inputs
    obj = tokenizer.tokenizer(text, padding=True, return_tensors="pt")
    text2ids = obj["input_ids"].to(device)  # (N, T)
    # Boolean padding mask: True marks positions the encoder should ignore
    mask = (obj["attention_mask"] == 0).to(device)  # (N, T)
    # Intent labels
    intent2ids = intent_encoder.transform(intent)  # (N,)
    intent2ids = torch.tensor(intent2ids, dtype=torch.long, device=device)
    # Entity labels
    entity2ids = []  # (N, T)
    for i, item in enumerate(entity):
        temp = align_entity(entity_encoder.transform(item), text2ids[i].tolist())
        entity2ids.append(temp)
    entity2ids = torch.tensor(entity2ids, dtype=torch.long, device=device)
    return text2ids, mask, intent2ids, entity2ids
# Training
def train():
    epoch = 30  # maximum number of epochs
    lengths = 0.7, 0.3  # train / validation split ratio
    batch_size = 5, 10  # batch sizes
    d_model = 768  # feature dimension
    p = 0.2  # dropout rate
    num_heads = 12  # number of attention heads
    num_layers = 12  # number of encoder layers
    lr = 1e-5  # learning rate

    print("Loading dataset...")
    # Data
    texts = MjUtil.load_json("intent_classification/data/datasets/text.json")
    dataset = TextDataset(texts)
    loader_train, loader_val = MjTrain.split_dataset(dataset, lengths, batch_size, collate_fn)
    # Intent labels
    intent_labels = ["unknown", "turn_on_light", "turn_off_light", "adjust_brightness"]
    intent_encoder.fit(intent_labels)
    with open("intent_classification/data/results/temp/intent_labels.pkl", "wb") as f:
        pickle.dump(intent_encoder, f)
    # Entity labels
    entity_labels = ["O", "B-ACT", "I-ACT", "B-LOC", "I-LOC", "B-DEV", "I-DEV", "B-VAL", "I-VAL"]
    entity_encoder.fit(entity_labels)
    with open("intent_classification/data/results/temp/entity_labels.pkl", "wb") as f:
        pickle.dump(entity_encoder, f)

    print("Building model...")
    model = IntentAndNER(len(intent_labels), len(entity_labels), d_model, p, num_heads, num_layers)
    model.to(device)
    optimizer = Adam(model.parameters(), lr)  # optimizer

    print("Starting training...")
    for i in range(epoch):
        time.sleep(1)
        # Training batches
        loss_train = 0  # loss
        f1_train_i, f1_train_e = 0, 0  # f1 metrics
        model.train()  # training mode
        for data in tqdm(loader_train):
            _, _, loss, f1_i, f1_e = model(*data)
            loss_train += loss.item()
            f1_train_i += f1_i
            f1_train_e += f1_e
            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        loss_train /= len(loader_train)
        f1_train_i /= len(loader_train)
        f1_train_e /= len(loader_train)
        print("Train loss: {:.4f}, f1: {:.4f}, {:.4f}".format(loss_train, f1_train_i, f1_train_e))
        time.sleep(2)
        # Validation batches
        loss_val = 0  # loss
        f1_val_i, f1_val_e = 0, 0  # f1 metrics
        model.eval()  # evaluation mode
        for data in tqdm(loader_val):
            _, _, loss, f1_i, f1_e = model(*data)
            loss_val += loss.item()
            f1_val_i += f1_i
            f1_val_e += f1_e
        loss_val /= len(loader_val)
        f1_val_i /= len(loader_val)
        f1_val_e /= len(loader_val)
        print("Validation loss: {:.4f}, f1: {:.4f}, {:.4f}".format(loss_val, f1_val_i, f1_val_e))
        time.sleep(2)
        # Save the loss and f1 metrics
        with open("intent_classification/data/results/temp/loss.csv", "a+", encoding="utf-8") as f:
            f.write("{:.4f}".format(loss_train) + "," + "{:.4f}".format(f1_train_i) + "," +
                    "{:.4f}".format(f1_train_e) + "," + "{:.4f}".format(loss_val) + "," +
                    "{:.4f}".format(f1_val_i) + "," + "{:.4f}".format(f1_val_e) + "\n")
        print("Saved loss and f1 metrics for epoch {}...\n".format(i + 1))
        # Save the model
        if (i + 1) % 10 == 0:
            torch.save(model, "intent_classification/data/results/temp/model_" + str(i + 1) + ".pth")
            print("Saved the model at epoch {}...".format(i + 1))
# Testing
def test():
    # Load the model (on PyTorch >= 2.6 this may additionally need weights_only=False)
    model_path = "intent_classification/data/results/final/20250909/model.pth"
    model = torch.load(model_path)
    model.to(device)
    model.eval()  # evaluation mode
    # Prepare the input
    text = "帮我把厕所的灯打开一下吧"
    obj = tokenizer.tokenizer(text, return_tensors="pt")
    inputs = obj["input_ids"].to(device)
    # Run prediction
    intent, entity = model.predict(inputs)
    print("Model output:", intent, entity)
    # Decode
    with open("intent_classification/data/results/final/20250909/intent_labels.pkl", "rb") as f:
        encoder = pickle.load(f)
    out = encoder.inverse_transform(intent.tolist())
    print("Intent:", out)
    with open("intent_classification/data/results/final/20250909/entity_labels.pkl", "rb") as f:
        encoder = pickle.load(f)
    out = encoder.inverse_transform(entity.view(-1).tolist()[1:-1])
    print("Entities:", out)
    print("Command:", entity2command(text, out))
# Convert BIO entity tags into a command
def entity2command(text: str,
                   entity: list[str]) -> dict:
    command = {"ACT": "", "LOC": "", "DEV": "", "VAL": ""}
    k, v = "", ""
    for t, e in zip(list(text), entity):
        if e.startswith("B-"):
            # Save the previous entity
            command[k] = v
            # Start a new entity
            k, v = e.replace("B-", ""), t
        elif e.startswith("I-"):
            v += t
        else:
            # Save the current entity
            command[k] = v
    # Save the last entity and drop the initial placeholder key
    command[k] = v
    command.pop("", None)
    return command


if __name__ == '__main__':
    # train()
    test()
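As a quick sanity check (a sketch assuming the definitions above are importable), entity2command can be exercised on the demo sentence without loading the model; it reproduces the command shown at the top:

text = "帮我把厕所的灯打开一下吧"
tags = ["O", "O", "O", "B-LOC", "I-LOC", "O", "B-DEV", "B-ACT", "I-ACT", "O", "O", "O"]
print(entity2command(text, tags))
# {'ACT': '打开', 'LOC': '厕所', 'DEV': '灯', 'VAL': ''}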