当前位置: 首页 > news >正文

遥感机器学习入门实战教程 | Sklearn 案例②:PCA + k-NN 分类与评估

延续第①篇:在“仅用训练像素拟合”的前提下,完成降维→分类→评估→整图预测的最小可用工作流,并将所有结果自动保存到时间戳文件夹。

🎯 本篇目标

  • 无数据泄露前提下:仅用训练像素 fit StandardScalerPCA
  • 使用 k-NN 完成分类,输出 OA / AA / Kappaprecision / recall / F1
  • 同时绘制并保存:计数版归一化版混淆矩阵、PCA 累计解释方差曲线
  • 可选:整图像素级预测并保存可视化图与 .npz 结果包
  • 统一一个带时间戳的结果文件夹,便于归档复现实验

📂 数据说明

  • KSC.mat:高光谱数据立方体,形状 (H, W, B)
  • KSC_gt.mat:标签图,形状 (H, W),0 表示背景/无标签,1…C 为类别编号
  • 将二者放入同一目录,例如:
your_path/├─ KSC.mat└─ KSC_gt.mat

只需在脚本里将 DATA_DIR = r"your_path" 改为你的真实路径,其余保持默认即可一键运行。

🧩 方法要点回顾(与第①篇保持一致)

  1. 分层划分:仅在有标签像素上进行训练/测试划分(stratify=labels)。
  2. 无泄露拟合scaler.fitpca.fit 只在训练像素上完成;随后对整图仅做 transform
  3. 统一变换:整图 (H*W, B) 扁平化 → StandardScaler.transformPCA.transform → 还原为 (H, W, PCA_DIM)
  4. 评估:汇报 OA / AA / Kappa 与详细分类报告,同时保存原始与归一化混淆矩阵。
  5. 整图预测(可选):输出 1…C 的类别图,未标注区域可置 0 便于对照。

⚙️ 环境与依赖

  • Python 3.8+
  • numpyscipyscikit-learnmatplotlib
  • 终端安装示例:pip install numpy scipy scikit-learn matplotlib

💻 一键可跑完整脚本(自动保存所有结果)

复制粘贴即可运行。仅需修改第一段参数里的 DATA_DIR

# -*- coding: utf-8 -*-
"""
Sklearn案例②-优化版:无泄露PCA + k-NN分类(结果直存)
数据:KSC / KSC_gt
1) 统一时间戳结果文件夹,所有图表/指标直接保存
2) 可视化优化:高DPI、紧凑布局、清晰坐标与标注
3) 混淆矩阵双版本:计数 & 归一化(百分比)
4) 指标保存:txt + csv;混淆矩阵保存:csv
5) 整图预测:可选,仅保存图片;同时保存预测数组(npz)
"""import os
import time
import json
import numpy as np
import scipy.io as sio
import matplotlib
import matplotlib.pyplot as pltfrom sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (confusion_matrix, classification_report,accuracy_score, cohen_kappa_score)# ===== 可视化中文支持 =====
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False# ===== 参数区(按需修改)=====
DATA_DIR = r"your_path"      # ←←← 修改为你的数据路径
PCA_DIM = 30                  # PCA主成分数
TRAIN_RATIO = 0.3             # 训练占比(仅在有标签像素上分层抽样)
SEED = 42                     # 随机数种子
K = 5                         # k-NN 的 k
DO_FULLMAP = True             # 是否进行整图预测与保存
SAVE_ROOT = os.path.join(DATA_DIR, "knn_results")# ===== 工具:统一保存路径 & 画图助手 =====
def make_save_dir(root):t = time.strftime("%Y%m%d_%H%M%S")d = os.path.join(root, f"run_{t}")os.makedirs(d, exist_ok=True)return ddef save_fig(path, dpi=220):plt.tight_layout()plt.savefig(path, dpi=dpi, bbox_inches="tight")plt.close()def save_text(path, text):with open(path, "w", encoding="utf-8") as f:f.write(text)# ===== 1. 加载数据 =====
X = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32)
Y = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int)
h, w, b = X.shape# 仅取非0标签像素的坐标与标签(转为0-based)
coords = np.argwhere(Y != 0)                    # (N,2)
labels = Y[coords[:, 0], coords[:, 1]] - 1      # (N,)
num_classes = labels.max() + 1# ===== 结果目录 =====
SAVE_DIR = make_save_dir(SAVE_ROOT)
print(f"[INFO] 结果将保存至:{SAVE_DIR}")# ===== 2. 训练/测试分层划分(仅在有标签像素上)=====
train_ids, test_ids = train_test_split(np.arange(len(coords)),train_size=TRAIN_RATIO,stratify=labels,random_state=SEED
)# ===== 3. 无泄露:仅用训练像素拟合 Scaler 与 PCA =====
train_pixels_raw = X[coords[train_ids, 0], coords[train_ids, 1]]   # (N_train, b)
scaler = StandardScaler().fit(train_pixels_raw)
pca = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(train_pixels_raw))# 整幅图统一变换(仅transform)
X_flat = X.reshape(-1, b)
X_flat_std = scaler.transform(X_flat)
X_flat_pca = pca.transform(X_flat_std)
X_pca = X_flat_pca.reshape(h, w, PCA_DIM)# 训练/测试样本(从降维后的整图中抽取)
X_train = X_pca[coords[train_ids, 0], coords[train_ids, 1]]
y_train = labels[train_ids]
X_test  = X_pca[coords[test_ids, 0],  coords[test_ids, 1]]
y_test  = labels[test_ids]# ===== 4. k-NN 训练与预测 =====
clf = KNeighborsClassifier(n_neighbors=K, weights='distance', n_jobs=-1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)# ===== 5. 指标评估并保存 =====
oa = accuracy_score(y_test, y_pred)                 # Overall Accuracy
kappa = cohen_kappa_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred, labels=np.arange(num_classes)).astype(np.int64)# AA(平均每类召回率,等价于召回的宏平均)
per_class_recall = np.diag(cm) / np.maximum(cm.sum(axis=1), 1)
aa = float(np.nanmean(per_class_recall))# 保存分类报告与指标(txt)
cls_report = classification_report(y_test, y_pred, digits=4)
metrics_txt = (f"影像尺寸: {h}x{w}x{b}, 有标签像素: {len(labels)}, 类别数: {num_classes}\n"f"PCA_DIM: {PCA_DIM}, 训练占比: {TRAIN_RATIO}, 种子: {SEED}, k: {K}\n\n"f"OA: {oa*100:.2f}%\nAA: {aa*100:.2f}%\nKappa: {kappa:.6f}\n\n"f"分类报告(precision/recall/F1):\n{cls_report}\n"
)
save_text(os.path.join(SAVE_DIR, "metrics_and_report.txt"), metrics_txt)# 另存为csv/json
np.savetxt(os.path.join(SAVE_DIR, "confusion_matrix_counts.csv"), cm, fmt="%d", delimiter=",")
norm_cm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
np.savetxt(os.path.join(SAVE_DIR, "confusion_matrix_normalized.csv"), norm_cm, fmt="%.6f", delimiter=",")with open(os.path.join(SAVE_DIR, "metrics.json"), "w", encoding="utf-8") as f:json.dump({"OA": oa, "AA": aa, "Kappa": kappa}, f, ensure_ascii=False, indent=2)# ===== 6. 可视化:混淆矩阵(计数版)=====
plt.figure(figsize=(7, 6))
plt.imshow(cm, interpolation='nearest')
plt.title("混淆矩阵(计数)")
plt.xlabel("预测类别")
plt.ylabel("真实类别")
cbar = plt.colorbar(fraction=0.046, pad=0.04)
cbar.ax.set_ylabel("像素数", rotation=90)ticks = np.arange(num_classes)
plt.xticks(ticks, ticks)
plt.yticks(ticks, ticks)thresh = cm.max() / 2.0 if cm.max() > 0 else 1
for i in range(num_classes):for j in range(num_classes):val = cm[i, j]plt.text(j, i, str(val),ha="center", va="center",color="white" if val > thresh else "black", fontsize=9)
save_fig(os.path.join(SAVE_DIR, "confusion_matrix_counts.png"))# ===== 7. 可视化:混淆矩阵(归一化百分比)=====
plt.figure(figsize=(7, 6))
plt.imshow(norm_cm, interpolation='nearest', vmin=0, vmax=1)
plt.title("混淆矩阵(按真实类归一化)")
plt.xlabel("预测类别")
plt.ylabel("真实类别")
cbar = plt.colorbar(fraction=0.046, pad=0.04)
cbar.ax.set_ylabel("比例", rotation=90)plt.xticks(ticks, ticks)
plt.yticks(ticks, ticks)for i in range(num_classes):for j in range(num_classes):val = norm_cm[i, j] * 100.0plt.text(j, i, f"{val:.1f}%",ha="center", va="center",color="white" if val > 50 else "black", fontsize=9)
save_fig(os.path.join(SAVE_DIR, "confusion_matrix_normalized.png"))# ===== 8. 可视化:PCA累计解释方差 =====
cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.figure(figsize=(7.2, 4.5))
plt.plot(np.arange(1, len(cum_var) + 1), cum_var, marker='o', linewidth=1.5)
plt.axhline(0.95, linestyle='--', linewidth=1, label="95% 阈值")
plt.axvline(PCA_DIM, linestyle='--', linewidth=1, label=f"n={PCA_DIM}")
plt.xlabel("主成分数")
plt.ylabel("累计解释方差比")
plt.title("PCA累计解释方差曲线")
plt.grid(alpha=0.3)
plt.legend(frameon=False)
save_fig(os.path.join(SAVE_DIR, "pca_cumvar.png"))# ===== 9. 可选:整图预测(仅保存)=====
if DO_FULLMAP:print("[INFO] 正在进行整图像素级预测(可能较耗时)...")pred_flat = clf.predict(X_flat_pca)            # 对 (h*w, PCA_DIM) 预测pred_map_1based = pred_flat.reshape(h, w) + 1  # 1..C# 仅显示有标签区域:未标注位置置0,更利于对照pred_vis = pred_map_1based.copy()pred_vis[Y == 0] = 0# 构建调色板(0为黑色背景,1..C使用tab20循环)from matplotlib.colors import ListedColormapbase_cmap = plt.get_cmap('tab20')colors = [(0, 0, 0, 1)] + [base_cmap(i % 20) for i in range(num_classes)]cmap = ListedColormap(colors)plt.figure(figsize=(8.2, 6.2))im = plt.imshow(pred_vis, cmap=cmap, vmin=0, vmax=num_classes)plt.title("k-NN 预测整图(0=背景,1..C=类别,仅渲染有标注区域)")plt.axis('off')cbar = plt.colorbar(im, fraction=0.046, pad=0.04,ticks=np.arange(0, num_classes + 1, max(1, num_classes // 10)))cbar.ax.set_ylabel("类别ID", rotation=90)save_fig(os.path.join(SAVE_DIR, "full_prediction_masked.png"))# 保存原始预测数组,便于后续分析/制图(含两版)np.savez_compressed(os.path.join(SAVE_DIR, "predictions.npz"),pred_map_1based=pred_map_1based.astype(np.int16),pred_vis_masked=pred_vis.astype(np.int16))# ===== 10. 记录运行配置 =====
run_cfg = {"DATA_DIR": DATA_DIR,"PCA_DIM": PCA_DIM,"TRAIN_RATIO": TRAIN_RATIO,"SEED": SEED,"K": K,"DO_FULLMAP": DO_FULLMAP,"H": int(h), "W": int(w), "BANDS": int(b),"CLASSES": int(num_classes),
}
with open(os.path.join(SAVE_DIR, "run_config.json"), "w", encoding="utf-8") as f:json.dump(run_cfg, f, ensure_ascii=False, indent=2)print(f"[DONE] 所有结果已保存:{SAVE_DIR}")

🔍 关键步骤逐段解读

① 训练/测试分层划分(仅对有标签像素)

coords = np.argwhere(Y != 0)
labels = Y[coords[:, 0], coords[:, 1]] - 1
train_ids, test_ids = train_test_split(np.arange(len(coords)),train_size=TRAIN_RATIO,stratify=labels,random_state=SEED
)
  • coords 是所有有标签像素的坐标集合;
  • labels 转为 0-based,有利于 sklearn;
  • stratify=labels 保障训练/测试的类别比例一致。

② 仅用训练像素 fit 标准化与 PCA

train_pixels_raw = X[coords[train_ids, 0], coords[train_ids, 1]]
scaler = StandardScaler().fit(train_pixels_raw)
pca = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(train_pixels_raw)
)
  • 核心原则fit 严格只用训练像素,避免测试信息泄露;
  • 顺序:先标准化再 PCA。

③ 整图统一变换(只 transform

X_flat = X.reshape(-1, b)
X_flat_std = scaler.transform(X_flat)
X_flat_pca = pca.transform(X_flat_std)
X_pca = X_flat_pca.reshape(h, w, PCA_DIM)
  • 将整图展平后做相同参数的标准化与 PCA;
  • 变换完成后,再还原形状以便抽取像素或做整图预测。

④ 训练 k-NN 并评估

clf = KNeighborsClassifier(n_neighbors=K, weights='distance', n_jobs=-1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
  • weights='distance' 常能缓解近邻投票的硬边界;
  • 指标含 OA / AA / Kappa 与完整分类报告。

⑤ 可视化与结果保存

  • 混淆矩阵:计数与按真实类归一化两个版本;
  • PCA 累计解释方差:辅助选择 PCA_DIM
  • 所有图表与表格自动保存,不阻塞运行。

⑥ 整图预测(可选)

pred_flat = clf.predict(X_flat_pca)
pred_map_1based = pred_flat.reshape(h, w) + 1
pred_vis = pred_map_1based.copy()
pred_vis[Y == 0] = 0
  • 输出整图预测,同时将未标注像素置 0 便于与 KSC_gt 对照;
  • 保存 .npz,方便后续在 GIS / 制图环境复用。

📦 运行后你将得到

位于 your_path/knn_results/run_YYYYmmdd_HHMMSS/ 的结果集,包括:

  • 图片:confusion_matrix_counts.pngconfusion_matrix_normalized.pngpca_cumvar.pngfull_prediction_masked.png
  • 表格/文本:metrics_and_report.txtconfusion_matrix_counts.csvconfusion_matrix_normalized.csvmetrics.json
  • 数据:predictions.npz(整图预测数组)
  • 配置:run_config.json(记录全部关键参数)
    在这里插入图片描述

结果展示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

🔗 下一篇预告

第③篇将基于第②篇的特征与划分设置,引入SVM / 随机森林等经典模型,进行对比实验与调参,形成稳定的基线组合

http://www.dtcms.com/a/337045.html

相关文章:

  • AWS Neptune:图数据库的强大潜力
  • 【LLM1】大型语言模型的基本生成机制
  • 将 iPhone 连接到 Windows 11 的完整指南
  • Chromium base 库中的 Observer 模式实现:ObserverList 与 ObserverListThreadSafe 深度解析
  • AI 在金融领域的落地案例
  • 强化学习-CH2 状态价值和贝尔曼等式
  • 算法详细讲解:数据结构 - 单链表与双链表
  • Nacos-6--Naco的QUIC协议实现高可用的工作原理
  • cesium中实时获取鼠标精确坐标和高度
  • IB数学课程知识点有哪些?IB数学课程辅导机构怎么选?
  • GitLab 安全漏洞 CVE-2025-7739 解决方案
  • GitLab 安全漏洞 CVE-2025-6186 解决方案
  • AI全链路赋能:smardaten2.0实现软件开发全流程智能化突破
  • Leetcode 3651. Minimum Cost Path with Teleportations
  • 嵌入式 C++ 语言编程规范文档个人学习版(参考《Google C++ 编码规范中文版》)
  • USB基础 -- 字符串描述符 (String Descriptor) 系统整理文档
  • 2025年8月更新!Windows 7 旗舰版 (32位+64位 轻度优化+离线驱动)
  • hla mHAg
  • cortex-m中断技巧
  • 数组学习2
  • 十年回望:Vue 与 React 的设计哲学、演进轨迹与生态博弈
  • idea部署到docker
  • 静配中心配药智能化:基于高并发架构的Go语言实现
  • MySQL 函数大赏:聚合、日期、字符串等函数剖析
  • Ps切片后无法导出原因(存储为web所用格式)为灰色,及解决文案
  • Day119 持续集成docker+jenkins
  • Dockerfile优化指南:利用多阶段构建将Docker镜像体积减小90%
  • 【音频信号发生器】基本应用
  • LAMP 架构部署:Linux+Apache+MariaDB+PHP
  • C# 使用注册表开机自启