当前位置: 首页 > news >正文

图机器学习(18)——使用图构建文档主题分类模型

图机器学习(18)——使用图构建文档主题分类模型

    • 0. 前言
    • 1. 构建文档主题分类器
    • 2. 模型训练
    • 3. 模型性能评估

0. 前言

我们已经学习了如何处理非结构化信息,并掌握了如何用图结构表征这类信息。我们从 Reuters-21578 基准数据集出发,通过标准 NLP 工具对文本信息进行标记和结构化处理。随后利用这些高层特征构建了多种网络类型:基于知识的网络、二分网络、节点子集的投影网络,以及反映数据集主题关联的网络。通过局部和全局属性展示了这些量化指标如何表征和描述结构各异的网络类型。在本节中,将介绍利用这些图结构构建机器学习模型,运用无监督技术识别语义社区,将主题/话题相似的文档进行聚类。

1. 构建文档主题分类器

为演示图结构的应用价值,我们将重点利用二分实体-文档图提供的拓扑信息和实体关联来训练多标签分类器,从而实现文档主题预测。具体而言分析浅层机器学习方法,使用从二分网络中提取的嵌入来训练传统分类器,例如随机森林分类器。

(1) 选取前 10 个主题进行建模,这些主题已具备足够的文档数据用于模型训练与评估:

from collections import Counter
topics = Counter([label for document_labels in corpus["label"] for label in document_labels]).most_common(10)

输出结果如下所示,显示了我们将在后续分析中关注的主题名称:

[('earn', 3964), ('acq', 2369), ('money-fx', 717),
('grain', 582), ('crude', 578), ('trade', 485),
('interest', 478), ('ship', 286), ('wheat', 283),
('corn', 237)]

(2) 在训练主题分类器时,我们需要将焦点限制在仅属于这些标签的文档上,获得过滤后的语料库:

topicsList = [topic[0] for topic in topics]
topicsSet = set(topicsList)
dataset = corpus[corpus["label"].apply(lambda x: len(topicsSet.intersection(x))>0)]

完成数据集提取与结构化处理后,即可开始训练主题模型并评估其性能。接下来,我们将利用网络信息实现主题分类任务的浅层学习方法。

(3) 首先,在二分图上应用 Node2Vec 算法计算嵌入向量。经过过滤的文档-文档网络通常具有包含大量孤立节点的外围结构,这些节点无法从拓扑信息中受益;而未过滤的文档-文档网络则存在过多边连接,会导致算法可扩展性问题。因此,使用二分图对于有效利用拓扑信息及实体与文档间的关联至关重要:

from node2vec import Node2Vec
node2vec = Node2Vec(G, dimensions=10)
model = node2vec.fit(window=20)
embeddings = model.wv

其中,嵌入向量的维度以及用于生成随机游走的窗口大小都属于需要通过交叉验证优化的超参数。

(4) 为提高计算效率,可以预先计算向量集并保存到本地,并在优化过程中直接调用存储结果。该方案基于半监督或转导学习场景的假设,即在训练时我们已经掌握整个数据集的连接信息(除标签外)。将嵌入向量存储至文件:

pd.DataFrame(embeddings.vectors,index=embeddings.index2word
).to_pickle(f"graphEmbeddings_{dimension}_{window}.p")

(5) 这些嵌入可以集成到 scikit-learn 转换器中,以便在网格搜索交叉验证过程中使用:

from sklearn.base import BaseEstimatorclass EmbeddingsTransformer(BaseEstimator):def __init__(self, embeddings_file):self.embeddings_file = embeddings_filedef fit(self, *args, **kwargs):self.embeddings = pd.read_pickle(self.embeddings_file)return selfdef transform(self, X):return self.embeddings.loc[X.index]def fit_transform(self, X, y):return self.fit().transform(X)

(6) 构建建模训练管道时,我们先将语料库划分为训练集和测试集:

def train_test_split(corpus):graphIndex = [index for index in corpus.index if index in graphEmbeddings.embeddings.index]train_idx = [idx for idx in graphIndex if "training/" in idx]test_idx = [idx for idx in graphIndex if "test/" in idx]
return corpus.loc[train_idx], corpus.loc[test_idx]
train, test = train_test_split(dataset)

(7) 构建函数提取特征和标签:

def get_labels(corpus, topicsList=topicsList):return corpus["label"].apply(lambda labels: pd.Series({label: 1 for label in labels}).reindex(topicsList).fillna(0))[topicsList]
def get_features(corpus):return corpus["parsed"] #graphEmbeddings.transform(corpus["parsed"])
def get_features_and_labels(corpus):return get_features(corpus), get_labels(corpus)

(8) 实例化建模管道:

from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier 
from sklearn.multioutput import MultiOutputClassifier
pipeline = Pipeline([("embeddings", graphEmbeddings),("model", model)
]) 

2. 模型训练

定义参数空间以及交叉验证网格搜索配置:

param_grid = {"embeddings__embeddings_file": files,"model__estimator__n_estimators": [50, 100], "model__estimator__max_features": [0.2,0.3, "auto"], #"model__estimator__max_depth": [3, 5]
}
grid_search = GridSearchCV(pipeline, param_grid=param_grid, cv=5, n_jobs=-1, scoring=lambda y_true, y_pred: f1_score(y_true, y_pred,average='weighted'))

最后,使用 sklearn APIfit 方法训练主题模型:

model = grid_search.fit(features, labels)

3. 模型性能评估

确定最佳模型后,我们可以在测试数据集上使用该模型评估其性能。定义辅助函数用于获取一组预测结果:

def get_predictions(model, features):return pd.DataFrame(model.predict(features), columns=topicsList, index=features.index
)
preds = get_predictions(model, get_features(test))
labels = get_labels(test)

查看训练分类器的性能:

from sklearn.metrics import classification_report
print(classification_report(labels, preds))

输出结果如下所示:

              precision    recall  f1-score   support0       0.97      0.94      0.95      10871       0.93      0.74      0.83       7192       0.79      0.45      0.57       1793       0.96      0.64      0.77       1494       0.95      0.59      0.73       1895       0.95      0.45      0.61       1176       0.87      0.41      0.56       1317       0.83      0.21      0.34        898       0.69      0.34      0.45        719       0.61      0.25      0.35        56micro avg       0.94      0.72      0.81      2787macro avg       0.85      0.50      0.62      2787
weighted avg       0.92      0.72      0.79      2787samples avg       0.76      0.75      0.75      2787

可以尝试调整分析流程的类型和超参数,变换不同模型,并在编码嵌入向量时试验不同取值。上述方法属于转导式学习,因为它使用了基于完整数据集训练的嵌入向量。这种情形在半监督任务中十分常见——当标注信息仅存在于少量数据点子集时,我们的任务就是推断所有未知样本的标签。

http://www.dtcms.com/a/293536.html

相关文章:

  • 使用idea 将一个git分支的部分记录合并到git另一个分支
  • 阿里云ODPS十五周年重磅升级发布:为AI而生的数据平台
  • 第七章 Pytorch构建模型详解【构建CIFAR10模型结构】
  • Cmake、VS2019、C++、openGLopenCV环境安装
  • idea部署新项目时,用自定义的maven出现的问题解决
  • charles手机端抓包 ios 安卓通用
  • 【js(5)原型与原型链】
  • 反向传播及优化器
  • 【图像翻转+图像的仿射变换】——图像预处理(OpenCV)
  • 网络--VLAN技术
  • Ruby 命令行选项详解
  • C++ std::list概念与使用案例
  • Web后端实战:登录认证(JWT令牌生成和Filter过滤器Interceptor拦截器)
  • 前端ApplePay支付-H5全流程实战指南
  • 使用Docker搭建SearXNG搜索引擎
  • AI聊天方案:vue+nodeJs+SSE
  • 变频器带动电机:全方位解析参数变化
  • MCP与企业数据集成:ERP、CRM、数据仓库的统一接入
  • 第一层nginx访问url如何透传到第二层nginx
  • OpenLayers 快速入门(九)Extent 介绍
  • Leetcode力扣解题记录--第240题(矩阵搜索)
  • 数据科学与大数据技术和统计学有什么区别?​
  • 关于针对 DT_REG 出现红色波浪线的问题(编译错误/IDE警告),以下是 精准解决方案,保持你的代码功能完全不变:
  • 【Linux-云原生-笔记】Haproxy相关
  • 基于Python(Django)+MongoDB实现的(Web)新闻采集和订阅系统
  • 模拟实现消息队列项目
  • 使用PEghost恢复系统(笔记版)
  • OpenEuler系统架构下编译redis的RPM包
  • [Mediatek] MTK openwrt-21.02 wifi 没启动问题
  • Android Multidex 完全解析:解决64K方法数限制