自然语言处理实战——基于k近邻法的文本分类
目录
一、引言
二、20newsgroups 数据集核心介绍
1. 基本信息
2. 数据结构
三、基础配置与资源准备功能
1. 环境与样式配置
2. 依赖资源自动准备
四、数据处理核心功能(文本分类基础流程)
1. 本地数据加载(load_data)
2. 文本预处理(preprocess_text)
3. 文本向量化(vectorize_text)
五、模型训练与评估核心功能
1. k-NN 模型训练(train_knn)
2. k 值自动调优(optimize_k)
3. 模型评估(evaluate_model)
六、可视化分析功能
可视化保存统一优化点:
七、相似文本检索(retrieve_similar)
八、数据持久化功能(保存与加载)
1. 关键对象保存(save_objects)
2. 保存对象加载(load_saved_objects)
九、主函数(main):全流程串联功能
十、整体功能亮点总结
十一、基于k近邻法的文本分类的Python代码完整实现
十二、程序运行结果展示
十三、实验结果分析
(一)实验数据基础分析
1. 数据规模与分布
2. 文本预处理效果
(二)特征提取效果分析
1. 类别特征区分度
2. 向量维度合理性
(三)参数调优分析
1. 调优结果
2. 调优价值
(四)模型分类性能分析
1. 整体性能指标
2. 各类别详细性能
3. 混淆矩阵分析
(五)相似文本检索功能分析
1. 检索准确性
2. 相似度指标
(六)实验结论与总结
1. 核心结论
2. 潜在优化方向
3. 实验价值
十四、总结
一、引言
本文的实战内容是一个功能完整、易用性强的基于k近邻法的文本分类,该方法将文本转化为向量(如词袋、TF-IDF)后,通过 k 个最相似样本的类别投票,完成新闻分类、垃圾邮件识别等任务,适合小数据集快速落地。本文将详细的介绍该项目的所有功能以及Python代码完整实现。
二、20newsgroups 数据集核心介绍
20newsgroups 是文本分类领域的经典基准数据集,也是本文项目中完成文本分类任务的数据集。数据集下载地址:https://people.csail.mit.edu/jrennie/20Newsgroups/20news-18828.tar.gz。将下载好的数据集解压并保存到指定的目录中,便于后续Python代码能够正常访问。核心信息如下:
1. 基本信息
- 来源:由美国麻省理工学院(MIT)Jason Rennie 于 1999 年整理发布,数据源自 1980 年代末至 1990 年代初的 Usenet 新闻组(类似早期论坛)。
- 核心用途:作为文本分类、主题识别、文本聚类的基准数据集,用于评估模型对 “不同主题文本” 的区分能力。
- 规模:主流版本(20news-18828)含 18828 条文本样本,分为 20 个主题类别。
- 编码格式:默认采用 latin1 编码,兼容英文文本,无中文乱码问题。
2. 数据结构
- 类别划分:20 个类别覆盖 6 大领域,代码中选取 3 个典型类别(rec.sport.hockey 冰球、talk.politics.misc 杂项政治、sci.space 太空)用于演示,具体分类如下:
- 体育:rec.sport.hockey、rec.sport.baseball 等;
- 政治:talk.politics.misc、talk.politics.guns 等;
- 科技 / 太空:sci.space、sci.med、sci.tech 等;
- 计算机 / 硬件:comp.sys.ibm.pc.hardware 等;
- 宗教 / 哲学:alt.atheism、talk.religion.misc 等;
- 娱乐 / 爱好:rec.motorcycles、rec.autos 等。
- 文本特点:每条样本为新闻组用户发帖内容,含主题讨论、观点表达等,长度几十到几百词;原始文本包含标题、正文、脚注、引用,代码中通过过滤仅保留核心正文。
- 训练测试集划分:官方默认 60% 为训练集(~11314 条)、40% 为测试集(~7514 条),按 “时间戳” 拆分(训练集为早期发帖,测试集为后期发帖),模拟真实时序泛化场景,避免数据泄露。
三、基础配置与资源准备功能
1. 环境与样式配置
- 核心功能:统一绘图风格、解决中文乱码、适配负号显示,为后续可视化提供基础环境。
- 关键细节:
- 采用
seaborn-v0_8-whitegrid风格,图表整洁美观; - 配置
SimHei字体支持中文,axes.unicode_minus=False避免负号显示异常; - 全局统一配置,无需在每个可视化函数中重复设置。
- 采用
2. 依赖资源自动准备
- NLTK 资源下载:自动下载
stopwords(英文停用词表)和wordnet(词形还原词典),首次运行自动完成,无需用户手动操作。 - 目录自动创建:
- 核心对象保存目录:
SAVE_DIR = "G:/save_data/20news_save/",用于存储模型、向量器等核心产物; - 可视化保存目录:
VIS_SAVE_DIR = "G:/save_data/20news_visualizations/",专门存储所有图表,与核心对象分离,便于管理; - 两个目录均支持 “不存在则自动创建”,避免保存失败。
- 核心对象保存目录:
四、数据处理核心功能(文本分类基础流程)
1. 本地数据加载(load_data)
- 核心目标:解决在线数据集下载 403 错误,从本地加载 20newsgroups 数据集。
- 具体功能:
- 加载指定 3 个类别(
rec.sport.hockey冰球、talk.politics.misc杂项政治、sci.space太空)的文本数据; - 数据校验:打印加载的总样本数,若样本数为 0 则抛出错误,提示用户检查路径和文件夹结构;
- 自动拆分训练集(60%)和测试集(40%):用
stratify参数保证两类数据的类别分布一致,避免数据偏斜影响模型泛化; - 输出:训练文本(
X_train)、训练标签(y_train)、测试文本(X_test)、测试标签(y_test)、类别名称(categories)。
- 加载指定 3 个类别(
2. 文本预处理(preprocess_text)
- 核心目标:清洗文本噪声,标准化词汇,为向量化和模型训练铺路。
- 多步处理流程(按顺序):
- 过滤特殊字符 / 数字:仅保留英文单词和空格,去除标点、数字等无关噪声;
- 小写化:统一文本大小写(如 “Hockey”→“ hockey ”),避免词汇重复计算;
- 分词:按空格拆分文本为单词列表;
- 去停用词:过滤无意义通用词汇(如 “the”“is”),聚焦核心语义;
- 词形还原:用
WordNetLemmatizer标准化词汇(如 “running”→“run”“cars”→“car”),减少词汇变体; - 空文本处理:若预处理后文本为空,返回空字符串,避免后续向量化报错。
3. 文本向量化(vectorize_text)
- 核心目标:将文本转化为 k-NN 算法可计算的数值向量(k-NN 基于向量距离判断相似度)。
- 关键设计:
- 采用 TF-IDF(词频 - 逆文档频率)方法:通过权重突出类别特征词(如 “hockey” 在冰球类文本中权重更高),比单纯词频更适合文本分类;
- 限制最大特征数(
max_features=5000):减少高维稀疏向量对计算效率的影响,平衡性能与速度; - 数据泄露防护:训练集执行 “拟合 + 转化”(
fit_transform),测试集仅执行 “转化”(transform),复用训练集词汇表; - 输出:训练集向量(
train_vec)、测试集向量(test_vec)、拟合好的 TF-IDF 向量器(tfidf,用于后续检索和新文本向量化)。
五、模型训练与评估核心功能
1. k-NN 模型训练(train_knn)
- 核心目标:构建基于余弦距离的 k-NN 分类模型,“记忆” 训练数据的向量和标签。
- 参数设计:
n_neighbors=k:近邻数量(由后续调优确定);metric='cosine':选择余弦距离,适配高维文本向量(衡量向量方向相似度,抗稀疏性优于欧氏距离);- 模型特性:k-NN 是 “懒惰学习”,无复杂训练过程,仅通过
fit方法存储训练数据,适合小规模文本分类; - 输出:训练好的 k-NN 模型(
knn)。
2. k 值自动调优(optimize_k)
- 核心目标:解决 k 值选择难题(k 过小易过拟合、k 过大易欠拟合),找到最优近邻数量。
- 具体功能:
- 候选 k 值范围:3、5、7、9、11(覆盖常见合理范围);
- 调优方法:网格搜索(
GridSearchCV)+5 折交叉验证,以准确率为评价指标,避免单一数据划分的偶然性; - 可视化 k 值影响:绘制 “k 值 - 交叉验证准确率” 折线图,标注每个 k 值的具体准确率,直观展示最优 k 值;
- 新增保存功能:
- 高清保存折线图(
dpi=300),满足报告 / 论文使用; - 文件名含时间戳(
k_value_impact_20240520_163300.png),避免覆盖; bbox_inches='tight'防止标题 / 标签被截断;
- 高清保存折线图(
- 输出:最优 k 值(
best_k)、最优 k 值对应的交叉验证准确率(best_k_score),并打印保存路径。
3. 模型评估(evaluate_model)
- 核心目标:量化模型分类性能,定位错误模式,直观展示模型效果。
- 具体功能:
- 计算测试集准确率:快速判断模型整体表现;
- 混淆矩阵热力图:直观展示 “真实类别→预测类别” 的错误分布(如哪两类文本易混淆);
- 新增保存功能:高清保存热力图(
dpi=300),文件名含时间戳(confusion_matrix_20240520_163400.png);
- 新增保存功能:高清保存热力图(
- 输出分类报告:包含每个类别的精确率、召回率、F1 值,全面评估模型对不同类别的表现(如是否存在某类别预测精度低的问题);
- 输出:测试集预测标签(
y_pred)、混淆矩阵(cm)、测试集准确率(acc),并打印混淆矩阵保存路径。
六、可视化分析功能
所有可视化图表均支持 “显示 + 高清保存” 双重功能,覆盖数据特征、模型效果、参数影响全维度,保存逻辑统一优化:
| 可视化功能 | 核心作用 | 保存细节亮点 |
|---|---|---|
类别分布可视化(visualize_class_distribution) | 验证训练集 / 测试集类别是否均衡,避免数据偏斜 | 文件名:class_distribution_20240520_163000.png;自动适配中文标签,高清保存 |
文本长度分布对比(visualize_text_length) | 验证预处理的噪声过滤效果(预处理后文本更简洁) | 双子图对比(预处理前 / 后);dpi=300保证清晰度;文件名含时间戳 |
类别特征词云(visualize_category_wordcloud) | 直观展示每个类别的核心特征词(TF-IDF 权重最高) | 替换类别名特殊字符为下划线(如rec.sport.hockey→rec_sport_hockey),避免文件名报错;每个类别单独保存词云图 |
k 值影响可视化(optimize_k内置) | 辅助理解 k 值对模型性能的影响,验证最优 k 值选择 | 标注每个 k 值的准确率;保存后不影响原有可视化效果 |
混淆矩阵热力图(evaluate_model内置) | 定位模型错误模式(如类别混淆情况) | 保留标注数字和中文标签;bbox_inches='tight'防止边界截断 |
相似文本检索相似度图(retrieve_similar内置) | 直观展示检索结果的相关性高低 | 截取查询前 20 个字符,过滤特殊字符(如 “hockey game rules and players”→“hockey game rules”);文件名含查询关键词和时间戳 |
可视化保存统一优化点:
- 高清保存:
dpi=300,满足学术报告、演示等场景需求; - 防覆盖:文件名含精确到秒的时间戳(
%Y%m%d_%H%M%S),彻底避免文件覆盖; - 防报错:自动过滤 / 替换文件名中的特殊字符(如
/、:、.),避免保存失败; - 易查找:保存后打印文件路径,方便用户快速定位图表。
七、相似文本检索(retrieve_similar)
- 核心目标:利用 k-NN 的近邻查找能力,实现 “输入查询文本→返回最相关文本” 的检索功能,超出基础分类任务。
- 具体流程:
- 对查询文本执行与训练数据一致的预处理和向量化(保证格式统一);
- 查找训练集中与查询向量最相似的
top_k个样本,计算相似度(1 - 余弦距离,范围 0-1,值越高越相关); - 可视化相似度分布(条形图,含保存功能);
- 输出检索结果详情:文本内容预览(前 200 字符)、类别、相似度;
- 输入参数:查询文本(
query)、TF-IDF 向量器(tfidf)、k-NN 模型(knn)、训练文本(train_texts)等,支持自定义top_k(默认 3); - 应用场景:文本查重、相关内容推荐、主题溯源等。
八、数据持久化功能(保存与加载)
1. 关键对象保存(save_objects)
- 核心目标:避免重复执行预处理、训练等耗时步骤,支持后续快速复用。
- 保存内容(覆盖全流程核心产物):
- 数据层:预处理后的训练 / 测试文本(
X_train_clean/X_test_clean); - 特征层:TF-IDF 向量器(
tfidf)、训练 / 测试向量(X_train_vec/X_test_vec); - 模型层:最优 k 值(
best_k)、训练好的 k-NN 模型(knn_model); - 结果层:测试集预测标签(
y_pred)、混淆矩阵(cm)、测试准确率(acc); - 辅助层:类别名称(
categories),便于后续加载后识别类别;
- 数据层:预处理后的训练 / 测试文本(
- 保存机制:用
pickle序列化打包为单个.pkl文件,文件名含时间戳,避免覆盖。
2. 保存对象加载(load_saved_objects)
- 核心目标:快速加载之前保存的所有对象,直接用于后续任务,无需重复运行全流程。
- 应用场景:
- 直接用加载的模型预测新文本(无需重新预处理、训练);
- 重新绘制可视化图表(如混淆矩阵、词云),无需重新计算;
- 基于历史结果优化参数(如调整
max_features),对比不同配置的效果;
- 使用方式:输入保存文件路径,返回包含所有对象的字典,键与保存时一致,直接调用即可。
九、主函数(main):全流程串联功能
- 核心目标:实现 “一键运行”,按逻辑顺序串联所有模块,无需用户手动调用单个函数。
- 执行流程:
- 加载本地数据 → 2. 可视化类别分布(自动保存) → 3. 文本预处理 → 4. 可视化文本长度变化(自动保存) → 5. 文本向量化 → 6. 可视化类别词云(自动保存) → 7. k 值调优(自动保存 k 值影响图) → 8. 训练最优 k 值模型 → 9. 模型评估(自动保存混淆矩阵) → 10. 相似文本检索演示(自动保存相似度图) → 11. 保存所有关键对象;
- 用户体验优化:每个步骤均打印进度提示(如 “数据加载完成”“文本向量化完成”),保存后打印文件路径,便于跟踪运行状态。
十、整体功能亮点总结
- 全流程覆盖:从数据加载到结果保存,无需额外补充代码,一站式完成文本分类 + 检索任务;
- 鲁棒性强:包含数据校验、异常处理、路径适配、特殊字符过滤,降低运行报错概率;
- 可视化丰富且实用:6 类图表覆盖 “数据 - 模型 - 结果” 全维度,新增自动保存功能,满足报告 / 演示需求;
- 易用性极高:自动创建目录、时间戳防覆盖、高清保存、中文适配,无需专业知识也能顺畅运行;
- 可复用性好:数据持久化支持快速复用,避免重复计算,节省时间;
- 扩展性强:参数可配置(如
max_features、top_k、候选 k 值),支持根据需求调整优化。
十一、基于k近邻法的文本分类的Python代码完整实现
import re
import numpy as np
import nltk
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import time
from wordcloud import WordCloud
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_files
from sklearn.model_selection import train_test_split# 配置可视化样式
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False# 下载NLTK资源(首次运行需执行)
nltk.download('stopwords')
nltk.download('wordnet')# 路径配置
SAVE_DIR = "G:/save_data/20news_save/" # 核心对象保存目录
VIS_SAVE_DIR = "G:/save_data/20news_visualizations/" # 可视化图表保存目录
# 创建目录(不存在则自动创建)
for dir_path in [SAVE_DIR, VIS_SAVE_DIR]:if not os.path.exists(dir_path):os.makedirs(dir_path)# 数据加载
def load_data():data_path = r"G:/datasets/20news-18828/"categories = ['rec.sport.hockey', 'talk.politics.misc', 'sci.space']all_data = load_files(container_path=data_path,categories=categories,shuffle=True,random_state=42,encoding='latin1')print(f"加载到的总样本数:{len(all_data.data)}")if len(all_data.data) == 0:raise ValueError("未加载到数据!请检查路径和文件夹结构")X_train, X_test, y_train, y_test = train_test_split(all_data.data,all_data.target,test_size=0.4,random_state=42,stratify=all_data.target)print(f"数据加载完成:训练集{len(X_train)}条,测试集{len(X_test)}条")return X_train, y_train, X_test, y_test, categories# 文本预处理
def preprocess_text(text):text = re.sub(r'[^a-zA-Z\s]', '', text)text = text.lower()words = text.split()stop_words = set(stopwords.words('english'))words = [w for w in words if w not in stop_words]lemmatizer = WordNetLemmatizer()words = [lemmatizer.lemmatize(w) for w in words]return ' '.join(words) if words else ''# 文本向量化
def vectorize_text(train_texts, test_texts, max_features=5000):tfidf = TfidfVectorizer(max_features=max_features)train_vec = tfidf.fit_transform(train_texts)test_vec = tfidf.transform(test_texts)print(f"文本向量化完成:向量维度{train_vec.shape[1]}")return train_vec, test_vec, tfidf# 模型训练
def train_knn(train_vec, train_labels, k=5):knn = KNeighborsClassifier(n_neighbors=k, metric='cosine')knn.fit(train_vec, train_labels)return knn# 模型评估
def evaluate_model(model, test_vec, test_labels, categories):y_pred = model.predict(test_vec)acc = accuracy_score(test_labels, y_pred)print(f"\n模型准确率:{acc:.4f}")# 混淆矩阵热力图cm = confusion_matrix(test_labels, y_pred)plt.figure(figsize=(8, 6))sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',xticklabels=categories, yticklabels=categories)plt.xlabel('预测类别')plt.ylabel('真实类别')plt.title(f'k-NN模型混淆矩阵(准确率:{acc:.4f})')plt.tight_layout()# 保存混淆矩阵timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())save_path = os.path.join(VIS_SAVE_DIR, f"confusion_matrix_{timestamp}.png")plt.savefig(save_path, dpi=300, bbox_inches='tight') # dpi=300保证高清plt.show()print(f"混淆矩阵已保存至:{save_path}")# 分类报告print("\n分类报告:")print(classification_report(test_labels, y_pred, target_names=categories))return y_pred, cm, acc# k值调优
def optimize_k(train_vec, train_labels):param_grid = {'n_neighbors': [3, 5, 7, 9, 11]}grid_search = GridSearchCV(KNeighborsClassifier(metric='cosine'),param_grid,cv=5,scoring='accuracy')grid_search.fit(train_vec, train_labels)# 可视化k值影响k_values = [param['n_neighbors'] for param in grid_search.cv_results_['params']]mean_scores = grid_search.cv_results_['mean_test_score']plt.figure(figsize=(8, 5))plt.plot(k_values, mean_scores, 'o-', color='orange')plt.xlabel('k值(近邻数量)')plt.ylabel('5折交叉验证准确率')plt.title('k值对模型性能的影响')plt.xticks(k_values)for k, score in zip(k_values, mean_scores):plt.text(k, score + 0.005, f'{score:.4f}', ha='center')plt.tight_layout()# 保存k值影响图timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())save_path = os.path.join(VIS_SAVE_DIR, f"k_value_impact_{timestamp}.png")plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()print(f"k值影响图已保存至:{save_path}")best_k = grid_search.best_params_['n_neighbors']best_score = grid_search.best_score_print(f"\n最优k值:{best_k},对应交叉验证准确率:{best_score:.4f}")return best_k, best_score# 相似文本检索
def retrieve_similar(query, tfidf, knn, train_texts, train_labels, categories, top_k=3):query_clean = preprocess_text(query)query_vec = tfidf.transform([query_clean])distances, indices = knn.kneighbors(query_vec, n_neighbors=top_k)similarities = 1 - distances[0]# 可视化相似度plt.figure(figsize=(8, 5))bars = plt.bar([f'第{i + 1}篇' for i in range(top_k)], similarities, color='skyblue')plt.ylim(0, 1.0)plt.xlabel('检索结果')plt.ylabel('相似度(越高越相关)')plt.title(f'与查询“{query}”相关的文本相似度')for bar, sim in zip(bars, similarities):height = bar.get_height()plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,f'{sim:.4f}', ha='center', va='bottom')plt.tight_layout()# 保存相似度条形图timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())# 替换查询中的特殊字符,避免文件名报错safe_query = re.sub(r'[^\w\s]', '', query)[:20] # 截取前20个字符save_path = os.path.join(VIS_SAVE_DIR, f"similarity_search_{safe_query}_{timestamp}.png")plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()print(f"相似度检索图已保存至:{save_path}")# 输出结果results = []for i, idx in enumerate(indices[0]):results.append({'文本内容': train_texts[idx][:200] + '...','类别': categories[train_labels[idx]],'相似度': similarities[i]})return results# 可视化函数
# 类别分布(训练集vs测试集)
def visualize_class_distribution(y_train, y_test, categories):plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)sns.countplot(x=y_train)plt.xticks(ticks=range(len(categories)), labels=categories, rotation=45, ha='right')plt.xlabel('类别')plt.ylabel('样本数量')plt.title('训练集类别分布')plt.subplot(1, 2, 2)sns.countplot(x=y_test)plt.xticks(ticks=range(len(categories)), labels=categories, rotation=45, ha='right')plt.xlabel('类别')plt.ylabel('样本数量')plt.title('测试集类别分布')plt.tight_layout()# 保存类别分布图timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())save_path = os.path.join(VIS_SAVE_DIR, f'class_distribution_{timestamp}.png')plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()print(f"类别分布图已保存至:{save_path}")# 文本长度分布
def visualize_text_length(X_raw, X_clean, title_prefix="训练集文本长度分布(词数)"):raw_lengths = [len(text.split()) for text in X_raw]clean_lengths = [len(text.split()) for text in X_clean]plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)sns.histplot(raw_lengths, bins=30, color='salmon', kde=True)plt.xlabel('词数')plt.ylabel('样本数量')plt.title(f'{title_prefix}(预处理前)')plt.xlim(0, 500)plt.subplot(1, 2, 2)sns.histplot(clean_lengths, bins=30, color='lightgreen', kde=True)plt.xlabel('词数')plt.ylabel('样本数量')plt.title(f'{title_prefix}(预处理后)')plt.xlim(0, 500)plt.tight_layout()# 保存文本长度分布图timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())save_path = os.path.join(VIS_SAVE_DIR, f'text_length_distribution_{timestamp}.png')plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()print(f"文本长度分布图已保存至:{save_path}")# 类别特征词云(TF-IDF权重最高的词)
def visualize_category_wordcloud(train_vec, tfidf, y_train, categories, top_n=50):feature_names = tfidf.get_feature_names_out()for label, category in enumerate(categories):category_vecs = train_vec[y_train == label]avg_weights = np.array(category_vecs.mean(axis=0)).flatten()top_indices = avg_weights.argsort()[-top_n:][::-1]top_words = {feature_names[i]: avg_weights[i] for i in top_indices}plt.figure(figsize=(8, 6))wc = WordCloud(width=800, height=600, background_color='white').generate_from_frequencies(top_words)plt.imshow(wc)plt.axis('off')plt.title(f'类别“{category}”的特征词云(TF-IDF权重)')plt.tight_layout()# 保存词云图timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())safe_category = re.sub(r'[^\w\s]', '_', category) # 替换特殊字符为下划线save_path = os.path.join(VIS_SAVE_DIR, f'wordcloud_{safe_category}_{timestamp}.png')plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()print(f"类别[{category}]词云已保存至:{save_path}")# 保存功能函数
def save_objects(X_train_clean, X_test_clean,tfidf, X_train_vec, X_test_vec,best_k, best_k_score,knn_model,y_pred, cm, acc,categories
):save_dict = {"X_train_clean": X_train_clean,"X_test_clean": X_test_clean,"tfidf_vectorizer": tfidf,"X_train_vec": X_train_vec,"X_test_vec": X_test_vec,"best_k": best_k,"best_k_cv_score": best_k_score,"knn_model": knn_model,"y_pred": y_pred,"confusion_matrix": cm,"test_accuracy": acc,"categories": categories}timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())save_path = os.path.join(SAVE_DIR, f"knn_text_classification_{timestamp}.pkl")with open(save_path, 'wb') as f:pickle.dump(save_dict, f)print(f"\n所有对象已保存至:{save_path}")print("保存内容包含:预处理文本、TF-IDF向量器、文本向量、最优k值、k-NN模型、分类结果")# 加载功能函数
def load_saved_objects(save_file_path):if not os.path.exists(save_file_path):raise FileNotFoundError(f"保存文件不存在:{save_file_path}")with open(save_file_path, 'rb') as f:saved_dict = pickle.load(f)print(f"已加载保存文件:{save_file_path}")return saved_dict# 主函数
def main():# 加载数据X_train, y_train, X_test, y_test, categories = load_data()# 可视化类别分布visualize_class_distribution(y_train, y_test, categories)# 文本预处理print("\n开始文本预处理...")X_train_clean = [preprocess_text(text) for text in X_train]X_test_clean = [preprocess_text(text) for text in X_test]# 可视化文本长度变化visualize_text_length(X_train, X_train_clean)# 文本向量化X_train_vec, X_test_vec, tfidf = vectorize_text(X_train_clean, X_test_clean)# 可视化类别词云visualize_category_wordcloud(X_train_vec, tfidf, y_train, categories)# k值调优best_k, best_k_score = optimize_k(X_train_vec, y_train)# 训练最优k值模型print(f"\n用最优k={best_k}训练模型...")knn_model = train_knn(X_train_vec, y_train, k=best_k)# 模型评估y_pred, cm, test_acc = evaluate_model(knn_model, X_test_vec, y_test, categories)# 相似文本检索示例print("\n===== 相似文本检索示例 =====")query = "hockey game rules and players"similar_texts = retrieve_similar(query, tfidf, knn_model, X_train, y_train, categories, top_k=2)for i, doc in enumerate(similar_texts):print(f"\n第{i + 1}篇相关文本(相似度:{doc['相似度']:.4f})")print(f"类别:{doc['类别']}")print(f"内容:{doc['文本内容']}")# 保存所有关键对象save_objects(X_train_clean=X_train_clean, X_test_clean=X_test_clean,tfidf=tfidf, X_train_vec=X_train_vec, X_test_vec=X_test_vec,best_k=best_k, best_k_score=best_k_score,knn_model=knn_model,y_pred=y_pred, cm=cm, acc=test_acc,categories=categories)# 运行主函数
if __name__ == "__main__":main()
十二、程序运行结果展示
加载到的总样本数:2761
数据加载完成:训练集1656条,测试集1105条

类别分布图已保存至:G:/save_data/20news_visualizations/class_distribution_20251108_182854.png
开始文本预处理...

文本长度分布图已保存至:G:/save_data/20news_visualizations/text_length_distribution_20251108_183004.png
文本向量化完成:向量维度5000

类别[rec.sport.hockey]词云已保存至:G:/save_data/20news_visualizations/wordcloud_rec_sport_hockey_20251108_183054.png

类别[talk.politics.misc]词云已保存至:G:/save_data/20news_visualizations/wordcloud_talk_politics_misc_20251108_183143.png

类别[sci.space]词云已保存至:G:/save_data/20news_visualizations/wordcloud_sci_space_20251108_183232.png

k值影响图已保存至:G:/save_data/20news_visualizations/k_value_impact_20251108_183320.png
最优k值:5,对应交叉验证准确率:0.9758
用最优k=5训练模型...
模型准确率:0.9638

混淆矩阵已保存至:G:/save_data/20news_visualizations/confusion_matrix_20251108_183433.png
分类报告:
precision recall f1-score support
rec.sport.hockey 0.96 0.98 0.97 400
talk.politics.misc 0.96 0.98 0.97 395
sci.space 0.97 0.92 0.95 310
accuracy 0.96 1105
macro avg 0.96 0.96 0.96 1105
weighted avg 0.96 0.96 0.96 1105
===== 相似文本检索示例 =====

相似度检索图已保存至:G:/save_data/20news_visualizations/similarity_search_hockey game rules an_20251108_183609.png
第1篇相关文本(相似度:0.2845)
类别:rec.sport.hockey
内容:From: drozinst@db.erau.edu (Drozinski Tim)
Subject: Re: Ulf and all...
joe13+@pitt.edu (Joseph B Stiehm) writes:
>>the way he does, does not belong in the NHL. There have been cheap shot artists
>...
第2篇相关文本(相似度:0.2721)
类别:rec.sport.hockey
内容:From: filinuk@staff.dccs.upenn.edu (Geoff Filinuk)
Subject: Wash/Isl OT game
This comes indirectly from Al Morgani who works in the studio
for ESPN hockey.
The management of ESPN was reluctant to...
所有对象已保存至:G:/save_data/20news_save/knn_text_classification_20251108_183740.pkl
保存内容包含:预处理文本、TF-IDF向量器、文本向量、最优k值、k-NN模型、分类结果
十三、实验结果分析
(一)实验数据基础分析
1. 数据规模与分布
本次实验基于 20newsgroups 数据集的 3 个核心类别(rec.sport.hockey冰球、talk.politics.misc杂项政治、sci.space太空)展开,数据加载与拆分结果如下:
- 总样本量:2761 条,其中训练集 1656 条(占比 60%),测试集 1105 条(占比 40%),符合文本分类任务经典的训练 - 测试集划分比例。
- 类别分布:从类别分布图可见,三个类别在训练集和测试集中的样本数量相对均衡 —— 训练集中冰球类、政治类、太空类样本数量接近,测试集分别为 400 条、395 条、310 条,无严重类别偏斜。这种均衡分布避免了模型对多数类的过度拟合,为公平评估模型泛化能力提供了基础。
2. 文本预处理效果
文本预处理(去特殊字符、小写化、去停用词、词形还原)的效果通过长度分布直方图得到直观验证:
- 预处理前:文本词数分布较分散,部分文本词数接近 500,包含大量冗余信息(如停用词、无意义符号)。
- 预处理后:文本词数分布更集中,整体词数显著减少,冗余噪声被有效过滤,核心语义信息得以保留。这一结果表明预处理流程成功降低了文本向量的维度和稀疏性,为后续向量化和模型计算提升了效率。
(二)特征提取效果分析
本次实验采用 TF-IDF 进行文本向量化,向量维度限制为 5000,特征提取效果通过类别词云和向量维度验证:
1. 类别特征区分度
三个类别的 TF-IDF 特征词云呈现出极强的主题区分性:
rec.sport.hockey(冰球类):核心特征词集中于hockey(冰球)、team(球队)、league(联盟)、NHL(国家冰球联盟)、player(球员)等,完全贴合冰球运动主题。talk.politics.misc(政治类):核心特征词包括government(政府)、president(总统)、law(法律)、people(民众)等,精准反映政治领域讨论焦点。sci.space(太空类):虽文档中词云描述存在部分字符乱码,但核心语义特征仍围绕太空探索相关主题,与其他两类形成明确区分。
2. 向量维度合理性
TF-IDF 向量维度设置为 5000,既保留了高频核心词汇的特征信息,又避免了过高维度导致的 “维度灾难”。实验结果显示,该维度下模型能有效捕捉类别差异,为 k-NN 算法的余弦相似度计算提供了高质量的特征输入。
(三)参数调优分析
k 值是 k-NN 算法的核心参数,本次实验通过网格搜索(候选 k 值:3、5、7、9、11)+5 折交叉验证进行调优:
1. 调优结果
- 最优 k 值:5,对应交叉验证准确率达 0.9758,为所有候选值中的最高值。
- k 值影响规律:从 k 值对模型性能的影响图可见,k=3 时准确率略低(受噪声样本干扰),k=5 时达到峰值,k>5 后准确率呈轻微下降趋势(过度平滑导致类别区分度降低)。这一规律完全符合 k-NN 算法特性:k 值过小易过拟合,k 值过大易欠拟合,验证了调优结果的合理性。
2. 调优价值
通过自动调优避免了人工试错的主观性,确保模型在 “欠拟合 - 过拟合” 之间达到最优平衡,为后续测试集上的优异表现奠定了参数基础。
(四)模型分类性能分析
1. 整体性能指标
- 测试集准确率:0.9638,即模型对 1105 条测试样本的分类正确率超过 96%,整体性能优异,表明 k-NN 算法结合 TF-IDF 特征能有效解决文本分类问题。
- 宏观平均(macro avg):精确率、召回率、F1 值均为 0.96;加权平均(weighted avg):三项指标同样为 0.96,说明模型对三个类别的分类性能均衡,无明显偏向性。
2. 各类别详细性能
从分类报告可见,三个类别的表现均处于高水准:
rec.sport.hockey(冰球类):精确率 0.96,召回率 0.98,F1 值 0.97,是三类中表现最佳的类别。高召回率表明几乎所有真实冰球类文本都被正确识别,这得益于冰球类主题特征鲜明、核心词汇集中度高。talk.politics.misc(政治类):精确率 0.96,召回率 0.98,F1 值 0.97,与冰球类表现持平。政治类文本的核心议题(如政府、法律)区分度强,模型能有效捕捉其语义特征。sci.space(太空类):精确率 0.97,召回率 0.92,F1 值 0.95,是三类中召回率略低的类别。推测原因是太空类测试集样本量相对较少(310 条),且部分文本可能涉及交叉主题,导致少量样本被误分,但 0.95 的 F1 值仍处于优秀水平。
3. 混淆矩阵分析
混淆矩阵热力图显示,类别间误分情况极少:
- 冰球类(真实):仅少量样本误分为政治类或太空类,误分率低于 2%。
- 政治类(真实):误分数量可忽略,核心原因是政治类文本的语义特征与其他两类差异显著。
- 太空类(真实):存在少量误分至政治类的情况,但整体误分率低于 8%,未影响模型整体性能。
混淆矩阵结果表明,k-NN 算法基于余弦相似度的近邻投票机制能有效区分三个类别的文本,验证了特征提取和参数调优的有效性。
(五)相似文本检索功能分析
以查询 “hockey game rules and players”(冰球比赛规则和球员)为例,检索功能表现如下:
1. 检索准确性
- 检索结果:返回的 2 篇最相关文本均属于
rec.sport.hockey(冰球类),类别匹配准确率 100%,无跨类别误检索。 - 内容相关性:第一篇文本讨论 NHL 球员行为规范,第二篇涉及 ESPN 冰球赛事转播,均与查询中的 “冰球比赛”“球员” 核心主题高度相关,表明 TF-IDF 向量能有效捕捉文本语义相似性。
2. 相似度指标
- 两篇文本的相似度分别为 0.2845 和 0.2721,虽绝对值不高,但考虑到文本语义的复杂性和英文文本的词汇多样性,该相似度已能有效区分相关与无关文本。这一结果也符合 k-NN 算法的相似性判断逻辑 —— 聚焦 “相对相似性” 而非 “绝对阈值”。
(六)实验结论与总结
1. 核心结论
- 数据处理有效性:均衡的类别分布、合理的预处理流程和 TF-IDF 特征提取,为模型提供了高质量的输入数据,是实验成功的基础。
- 模型参数最优性:k=5 是本次任务的最优近邻数量,该参数下模型在 “拟合能力 - 泛化能力” 之间达到平衡,交叉验证准确率 0.9758、测试集准确率 0.9638 均证明了参数调优的价值。
- 功能完整性:不仅实现了高准确率的文本分类(三类 F1 值≥0.95),还具备有效的相似文本检索功能,满足 NLP 基础任务的核心需求。
- 持久化价值:实验保存了预处理文本、TF-IDF 向量器、训练好的模型等关键对象,为后续复用(如新增文本分类、检索功能扩展)提供了便利。
2. 潜在优化方向
- 太空类性能提升:可通过增加太空类样本量、优化文本预处理(如保留专业术语)等方式提升其召回率。
- 相似度阈值调整:相似文本检索可引入动态阈值(如基于训练集相似度分布设定),进一步提升检索结果的相关性。
- 向量维度优化:可尝试调整 TF-IDF 的
max_features参数(如 3000、8000),探索更优的特征维度配置。
3. 实验价值
本次实验验证了 k-NN 算法在小规模文本分类任务中的有效性,其 “简单易实现、无需复杂训练” 的特点使其适合快速落地;同时,完整的可视化分析和数据持久化设计,为同类文本分类任务提供了可复用的实验框架和参考标准。
十四、总结
本文介绍了一个基于k近邻算法的文本分类系统,使用20newsgroups数据集中的3个类别(冰球、政治、太空)作为示例。系统包含完整的文本处理流程:数据加载与预处理、TF-IDF向量化、k值调优(k=5最优)、模型训练与评估。实验结果显示模型准确率达96.38%,并通过词云、混淆矩阵等可视化分析验证了分类效果。系统还支持相似文本检索功能,并实现了关键对象持久化保存。该方案具有实现简单、运行高效的特点,适合小规模文本分类任务的快速落地。
