
Deep Learning and Remote Sensing for Beginners (7) | CNN vs. CNN + Morphological Profiles (MP): Is Feature Engineering Worth It?

Previous installments in this series: https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzkwMTE0MjI4NQ==&action=getalbum&album_id=3959684907432935425#wechat_redirect

In the previous posts we covered a lot of feature-mining techniques; today we test experimentally whether feature mining actually helps.

In remote sensing image classification, performance gains usually depend on high-quality features. What happens when deep learning meets traditional feature engineering? In this post we compare a plain CNN with a CNN that fuses Morphological Profiles (MP) to examine the value of feature engineering in remote sensing classification.

I. Why do feature mining at all?

Remote sensing imagery is rich in spectral, spatial, and textural information, but the raw data is typically high-dimensional, highly redundant, and noisy. Feeding raw data directly into a model not only increases the computational burden, it can also hurt classification accuracy because of the excess of uninformative content.

The core value of feature mining (feature engineering) lies in:

  • Dimensionality reduction with quality gain: keep the key information, discard redundancy and noise
  • Explicit modeling: turn domain knowledge into quantifiable features
  • Learning support: give the model representations that are easier to learn from

In remote sensing, Morphological Profiles (MP) are a classic feature-engineering method: they use mathematical morphology operations to capture spatial structure in the image, which makes them well suited to the object contours and textures commonly found in remote sensing scenes.
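To make this concrete before we get to the full pipeline, here is a minimal, self-contained sketch (a made-up toy array, not the KSC data) showing how grayscale opening reacts to a small bright structure; the difference "original - opening" is exactly the kind of channel the MP features below are built from:

import numpy as np
import scipy.ndimage as ndi

# Toy single-band "image": one bright speck on a dark background
img = np.zeros((7, 7), dtype=np.float32)
img[3, 3] = 1.0

# 3x3 square structuring element (the experiment below uses disks of several radii)
fp = np.ones((3, 3), dtype=bool)

opened = ndi.grey_opening(img, footprint=fp)   # removes bright structures smaller than fp
closed = ndi.grey_closing(img, footprint=fp)   # would fill dark gaps smaller than fp

print(opened.max())          # 0.0 -> the small bright speck is erased by opening
print((img - opened).max())  # 1.0 -> the "top-hat" residual highlights that structure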

II. Experimental design: a controlled comparison of CNN and CNN+MP

We use the KSC (Kennedy Space Center) hyperspectral dataset and design a controlled comparison:

  • Dataset: the KSC hyperspectral image (176 bands, 614×512 pixels) and its ground-truth land-cover labels
  • Core comparison:
    1. spectral features after PCA dimensionality reduction only (CNN model)
    2. PCA features fused with multi-scale morphological profiles (CNN+MP model)
  • Evaluation metrics: overall accuracy (OA), average accuracy (AA), the Kappa coefficient, and inference speed (a quick sketch of these metrics follows this list)
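A minimal sketch of how the three accuracy metrics are computed with scikit-learn; the labels below are made-up toy values, not KSC results:

import numpy as np
from sklearn.metrics import accuracy_score, cohen_kappa_score, confusion_matrix

# Toy ground truth / predictions for a 3-class problem (illustrative values only)
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 2, 0, 2])

oa = accuracy_score(y_true, y_pred)                  # overall accuracy: fraction correct
kappa = cohen_kappa_score(y_true, y_pred)            # chance-corrected agreement
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
aa = np.mean(np.diag(cm) / cm.sum(axis=1))           # average accuracy: mean per-class recall

print(f"OA={oa:.3f}  AA={aa:.3f}  Kappa={kappa:.3f}")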

III. Code walkthrough

1. Basic utilities and parameter setup

First, import the required libraries and set up the basic utilities:

import os, time, random
import numpy as np
import scipy.io as sio
import scipy.ndimage as ndi
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (classification_report, confusion_matrix,
                             accuracy_score, cohen_kappa_score)
from sklearn.model_selection import train_test_split
import matplotlib
if os.name == 'nt':
    try:
        matplotlib.use('TkAgg')
    except Exception:
        matplotlib.use('Agg')
else:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
matplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")
torch.backends.cudnn.benchmark = True

Basic utility functions:

  • set_seeds: fix the random seeds so the experiments are reproducible
  • average_accuracy: compute the average accuracy (AA)
  • make_disk: build the disk-shaped structuring element used by the morphological operations
def set_seeds(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def average_accuracy(y_true, y_pred, num_classes):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    with np.errstate(divide='ignore', invalid='ignore'):
        per_class = np.diag(cm) / cm.sum(axis=1).clip(min=1)
    return float(np.nanmean(per_class))

def make_disk(radius: int) -> np.ndarray:
    if radius <= 0:
        return np.array([[1]], dtype=bool)
    r = radius
    y, x = np.ogrid[-r:r+1, -r:r+1]
    return (x*x + y*y <= r*r)
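As a quick sanity check (toy inputs, not part of the original script), the helpers behave as expected:

print(make_disk(1).astype(int))
# [[0 1 0]
#  [1 1 1]
#  [0 1 0]]  -> a radius-1 disk footprint

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 0, 2, 2]
print(average_accuracy(y_true, y_pred, num_classes=3))  # (1.0 + 0.5 + 1.0) / 3 ≈ 0.833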

2. Building the Morphological Profile (MP) features

Morphological profiles extract the spatial structure of the image with multi-scale opening and closing operations:

  • Opening (removes small bright structures): erosion followed by dilation
  • Closing (fills small dark regions): dilation followed by erosion
def build_morphological_profiles(X_pca_img: np.ndarray,
                                 k_components: int = 8,
                                 radii=(1, 3, 5, 7)) -> np.ndarray:
    """Input: X_pca_img (H, W, C). Apply multi-scale grayscale opening/closing to the
    first k principal components, build attribute-difference channels, and concatenate."""
    H, W, C = X_pca_img.shape
    k = int(min(k_components, C))
    feats = []
    for i in range(k):
        pc = X_pca_img[..., i]                            # i-th principal component
        for r in radii:                                   # multi-scale processing
            fp = make_disk(r)                             # disk structuring element of radius r
            opened = ndi.grey_opening(pc, footprint=fp)   # grayscale opening
            closed = ndi.grey_closing(pc, footprint=fp)   # grayscale closing
            # Difference features: original - opening (highlights bright structures),
            # closing - original (highlights dark structures)
            feats.append((pc - opened).astype(np.float32, copy=False))
            feats.append((closed - pc).astype(np.float32, copy=False))
    if not feats:                                         # k == 0: return the raw PCA features
        return X_pca_img
    mp_stack = np.stack(feats, axis=-1).astype(np.float32, copy=False)
    # Concatenate the original PCA features with the morphological features
    return np.concatenate([X_pca_img, mp_stack], axis=-1).astype(np.float32, copy=False)
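A quick way to verify the channel bookkeeping, on random stand-in data (this snippet assumes the imports and functions above are already defined): with the settings used later (30 PCA channels, MP_K=8, radii (1, 3, 5, 7)), each selected component contributes two difference channels per radius, so the enhanced cube should end up with 30 + 8 × 4 × 2 = 94 channels.

import numpy as np

# Random stand-in for a PCA-reduced cube (H, W, 30); the real data is loaded in main()
X_fake = np.random.rand(32, 32, 30).astype(np.float32)

X_mp = build_morphological_profiles(X_fake, k_components=8, radii=(1, 3, 5, 7))
print(X_fake.shape, "->", X_mp.shape)   # (32, 32, 30) -> (32, 32, 94)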

3. Dataset and model definitions

(1) Hyperspectral image patch dataset

class HSIPatchDataset(Dataset):
    def __init__(self, patches, labels):
        # Convert to PyTorch tensors, channel-first (H, W, C -> C, H, W)
        self.X = torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)
        self.y = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
(2) A simple CNN model

We use a lightweight CNN architecture that is well suited to hyperspectral patch classification:

class SimpleCNN(nn.Module):
    def __init__(self, in_ch, num_classes, width=32):
        super().__init__()
        self.net = nn.Sequential(
            # First conv block: input channels -> 32 channels
            nn.Conv2d(in_ch, width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width), nn.ReLU(inplace=True),
            # Second conv block: 32 -> 32 channels
            nn.Conv2d(width, width, 3, padding=1, bias=False),
            nn.BatchNorm2d(width), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, ceil_mode=True),            # spatial downsampling
            # Third conv block: 32 -> 64 channels
            nn.Conv2d(width, width*2, 3, padding=1, bias=False),
            nn.BatchNorm2d(width*2), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)                     # adaptive pooling to 1x1
        )
        self.fc = nn.Linear(width*2, num_classes)       # classification head

    def forward(self, x):
        x = self.net(x).flatten(1)                      # flatten features
        return self.fc(x)
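A small shape check on dummy tensors (illustrative values: KSC has 13 land-cover classes and the patches used later are 5×5) confirms that the adaptive pooling makes the classifier head see a 64-dimensional vector regardless of the number of input channels:

import torch

model = SimpleCNN(in_ch=30, num_classes=13, width=32)
x = torch.randn(4, 30, 5, 5)                          # batch of 4 dummy patches
print(model(x).shape)                                 # torch.Size([4, 13])
print(sum(p.numel() for p in model.parameters()))     # total parameter count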

4. Training and evaluation functions

(1) Model evaluation function

@torch.no_grad()  # disable gradient computation to speed up inference
def evaluate(model, loader, device):
    model.eval()                                      # switch to evaluation mode
    ys, ps = [], []
    for xb, yb in loader:
        xb = xb.to(device)
        pred = model(xb).argmax(dim=1).cpu().numpy()  # take the most probable class
        ys.extend(yb.numpy()); ps.extend(pred)
    return accuracy_score(ys, ps), np.array(ys), np.array(ps)
(2) Model training function

def train_model(name, model, train_loader, test_loader, device,
                epochs=20, lr=1e-3, weight_decay=1e-4):
    model = model.to(device)                          # move the model to the device
    # Optimizer and learning-rate scheduler
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.5, patience=3)
    crit = nn.CrossEntropyLoss()                      # cross-entropy loss
    best_acc, best_path = 0.0, f"best_{name}.pth"     # track the best checkpoint
    for ep in range(1, epochs + 1):
        model.train(); loss_sum = 0.0                 # switch to training mode
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()                           # reset gradients
            loss = crit(model(xb), yb)                # compute the loss
            loss.backward(); opt.step()               # backprop and parameter update
            loss_sum += loss.item() * xb.size(0)      # accumulate the loss
        # Evaluate on the test set
        acc, _, _ = evaluate(model, test_loader, device)
        sch.step(acc)                                 # adjust the learning rate by accuracy
        print(f"[{name}] Epoch {ep:02d}/{epochs} | Loss {loss_sum/len(train_loader.dataset):.4f} | Acc {acc:.4f}")
        # Save the best model
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), best_path)
    # Load the best weights
    try:
        state = torch.load(best_path, map_location=device, weights_only=True)
    except TypeError:
        state = torch.load(best_path, map_location=device)
    model.load_state_dict(state)
    return model, best_acc
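If you want to smoke-test the training loop before touching real data, a tiny run on random patches is enough (hypothetical shapes chosen to match the pipeline below; accuracy will hover around chance because the labels are random):

import numpy as np, torch
from torch.utils.data import DataLoader

Xr = np.random.rand(64, 5, 5, 30).astype(np.float32)   # 64 fake 5x5 patches, 30 channels
yr = np.random.randint(0, 13, size=64)                  # random labels for 13 classes
loader = DataLoader(HSIPatchDataset(Xr, yr), batch_size=16, shuffle=True)

model = SimpleCNN(in_ch=30, num_classes=13)
model, best = train_model("smoke", model, loader, loader, torch.device('cpu'), epochs=2)
print(f"best accuracy on random labels: {best:.3f}")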

5. Full-image prediction function

@torch.inference_mode()
def predict_full_image_by_coords(model, X_img_feat, patch_size, device,
                                 batch_size=4096, title="FullPred", show=True):
    H, W, C = X_img_feat.shape
    m = patch_size // 2
    padded = np.pad(X_img_feat, ((m, m), (m, m), (0, 0)), mode='reflect')
    coords = np.mgrid[0:H, 0:W].reshape(2, -1).T
    pred_map = np.zeros((H, W), dtype=np.int32)
    t0 = time.time()
    for i in range(0, len(coords), batch_size):
        batch_coords = coords[i:i + batch_size]
        patches = np.empty((len(batch_coords), patch_size, patch_size, C), dtype=np.float32)
        for k, (r, c) in enumerate(batch_coords):
            patches[k] = padded[r:r + patch_size, c:c + patch_size, :]
        tensor = torch.from_numpy(patches).permute(0, 3, 1, 2).to(device)
        preds = model(tensor).argmax(dim=1).cpu().numpy() + 1
        for (r, c), p in zip(batch_coords, preds):
            pred_map[r, c] = p
    elapsed = time.time() - t0
    try:
        fig = plt.figure(figsize=(10, 7.5))
        cmap = matplotlib.colormaps.get_cmap('tab20')
        vmin, vmax = pred_map.min(), pred_map.max()
        if vmin == vmax:
            vmin, vmax = 0, 1
        im = plt.imshow(pred_map, cmap=cmap, interpolation='nearest', vmin=vmin, vmax=vmax)
        cbar = plt.colorbar(im, shrink=0.85); cbar.set_label('Predicted class', rotation=90)
        plt.title(title, fontsize=14, weight='bold'); plt.axis('off'); plt.tight_layout()
        if matplotlib.get_backend().lower() == 'agg':
            plt.savefig(f"{title}_map.png", bbox_inches='tight')
        else:
            plt.show(block=True)
        plt.close(fig)
    except Exception as e:
        print(f"[Full-image visualization failed] {e}")
    return pred_map, elapsed

6. Main workflow

def main():
    set_seeds(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Configuration
    DATA_DIR = r" "  # change this to your data path
    X_FILE,  Y_FILE  = "KSC.mat", "KSC_gt.mat"
    PCA_DIM        = 30
    PATCH_SIZE     = 5
    TRAIN_RATIO    = 0.30
    EPOCHS         = 20
    BATCH_SIZE     = 64
    LR             = 1e-3
    WEIGHT_DECAY   = 1e-4
    NUM_WORKERS    = 0 if os.name == 'nt' else min(4, os.cpu_count() or 0)
    PIN_MEMORY     = (device.type == 'cuda')
    PREDICT_BATCH_SIZE = 4096
    MP_K           = 8              # first K principal components used to build MP
    RADII          = (1, 3, 5, 7)   # MP radii

    # Data loading (strictly by the specified keys)
    print("Loading data (fixed keys: KSC / KSC_gt)...")
    X_mat = sio.loadmat(os.path.join(DATA_DIR, X_FILE))
    Y_mat = sio.loadmat(os.path.join(DATA_DIR, Y_FILE))
    if "KSC" not in X_mat:
        raise KeyError(f"Key 'KSC' not found in {X_FILE}. Please check the file or key name.")
    if "KSC_gt" not in Y_mat:
        raise KeyError(f"Key 'KSC_gt' not found in {Y_FILE}. Please check the file or key name.")
    X_img = X_mat["KSC"]
    Y_img = Y_mat["KSC_gt"]

    # Shape assertion and dtype conversion
    assert X_img.ndim == 3 and Y_img.ndim == 2, "Expected X: (H, W, B), Y: (H, W)"
    h, w, bands = X_img.shape
    print(f"Image size: {h}x{w}, bands: {bands}")
    X_img = X_img.astype(np.float32, copy=False)

    # Labeled indices and stratified split
    labeled_rc = np.array([(i, j) for i in range(h) for j in range(w) if Y_img[i, j] != 0])
    labels_all = np.array([Y_img[i, j] - 1 for i, j in labeled_rc], dtype=np.int64)
    num_classes = int(np.max(labels_all) + 1)
    print(f"Labeled samples: {len(labeled_rc)}, classes: {num_classes}")
    train_ids, test_ids = train_test_split(
        np.arange(len(labeled_rc)), test_size=1 - TRAIN_RATIO,
        stratify=labels_all, random_state=42)

    # Leak-free StandardScaler + PCA (fitted on training pixels only)
    print("Fitting StandardScaler/PCA (training pixels only)...")
    train_pixels = np.array([X_img[i, j] for i, j in labeled_rc[train_ids]], dtype=np.float32)
    scaler = StandardScaler().fit(train_pixels)
    pca = PCA(n_components=PCA_DIM, random_state=42).fit(scaler.transform(train_pixels))
    X_pca_img = pca.transform(scaler.transform(X_img.reshape(-1, bands))).astype(np.float32)
    X_pca_img = X_pca_img.reshape(h, w, PCA_DIM)

    # Build the morphological-profile features
    print("Building Morphological Profiles...")
    X_feat_img = build_morphological_profiles(X_pca_img, k_components=MP_K, radii=RADII)
    print(f"PCA channels: {X_pca_img.shape[-1]} -> enhanced channels: {X_feat_img.shape[-1]}")

    # Patch extraction helper
    def extract_patches(X_img_any, sel_ids):
        m = PATCH_SIZE // 2
        H, W, C = X_img_any.shape
        padded = np.pad(X_img_any, ((m, m), (m, m), (0, 0)), mode='reflect')
        patches = np.empty((len(sel_ids), PATCH_SIZE, PATCH_SIZE, C), dtype=np.float32)
        labs = np.empty((len(sel_ids),), dtype=np.int64)
        for n, k in enumerate(sel_ids):
            i, j = labeled_rc[k]
            patches[n] = padded[i:i + PATCH_SIZE, j:j + PATCH_SIZE, :]
            labs[n] = labels_all[k]
        return patches, labs

    # Extract training and test patches
    Xtr_pca,  ytr = extract_patches(X_pca_img,  train_ids)
    Xte_pca,  yte = extract_patches(X_pca_img,  test_ids)
    Xtr_feat, _   = extract_patches(X_feat_img, train_ids)
    Xte_feat, _   = extract_patches(X_feat_img, test_ids)

    # Build data loaders
    train_loader_pca = DataLoader(HSIPatchDataset(Xtr_pca,  ytr), batch_size=BATCH_SIZE,
                                  shuffle=True,  num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    test_loader_pca  = DataLoader(HSIPatchDataset(Xte_pca,  yte), batch_size=BATCH_SIZE,
                                  shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    train_loader_feat = DataLoader(HSIPatchDataset(Xtr_feat, ytr), batch_size=BATCH_SIZE,
                                   shuffle=True,  num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    test_loader_feat  = DataLoader(HSIPatchDataset(Xte_feat, yte), batch_size=BATCH_SIZE,
                                   shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # Model comparison experiments
    results = []

    # 1) CNN (PCA only)
    in_ch_pca = X_pca_img.shape[-1]
    model_pca = SimpleCNN(in_ch=in_ch_pca, num_classes=num_classes, width=32)
    print("\n===== Training: CNN (PCA only) =====")
    model_pca, _ = train_model("CNN_PCA", model_pca, train_loader_pca, test_loader_pca,
                               device, epochs=EPOCHS, lr=LR, weight_decay=WEIGHT_DECAY)
    acc_pca, y_true_pca, y_pred_pca = evaluate(model_pca, test_loader_pca, device)
    aa_pca = average_accuracy(y_true_pca, y_pred_pca, num_classes)
    kappa_pca = cohen_kappa_score(y_true_pca, y_pred_pca)
    # Per-patch inference timing
    model_pca.eval()
    warm = torch.randn(32, in_ch_pca, PATCH_SIZE, PATCH_SIZE).to(device); _ = model_pca(warm)
    t0 = time.time(); xb = torch.randn(512, in_ch_pca, PATCH_SIZE, PATCH_SIZE).to(device); _ = model_pca(xb)
    patch_sec_pca = (time.time() - t0) / 512.0
    # Full-image prediction
    print("[CNN (PCA only)] Full-image prediction...")
    pred_map_pca, full_sec_pca = predict_full_image_by_coords(
        model_pca, X_pca_img, patch_size=PATCH_SIZE, device=device,
        batch_size=PREDICT_BATCH_SIZE, title="CNN_PCA_FullPred",
        show=(matplotlib.get_backend().lower() != 'agg'))
    results.append({"Model": "CNN(PCA)", "OA": acc_pca, "AA": aa_pca, "Kappa": kappa_pca,
                    "PatchSec": patch_sec_pca, "FullImgSec": full_sec_pca})

    # 2) CNN + MP
    in_ch_feat = X_feat_img.shape[-1]
    model_feat = SimpleCNN(in_ch=in_ch_feat, num_classes=num_classes, width=32)
    print("\n===== Training: CNN + Morphological Profiles (PCA+MP) =====")
    model_feat, _ = train_model("CNN_PCA_MP", model_feat, train_loader_feat, test_loader_feat,
                                device, epochs=EPOCHS, lr=LR, weight_decay=WEIGHT_DECAY)
    acc_feat, y_true_feat, y_pred_feat = evaluate(model_feat, test_loader_feat, device)
    aa_feat = average_accuracy(y_true_feat, y_pred_feat, num_classes)
    kappa_feat = cohen_kappa_score(y_true_feat, y_pred_feat)
    # Per-patch inference timing
    model_feat.eval()
    warm = torch.randn(32, in_ch_feat, PATCH_SIZE, PATCH_SIZE).to(device); _ = model_feat(warm)
    t0 = time.time(); xb = torch.randn(512, in_ch_feat, PATCH_SIZE, PATCH_SIZE).to(device); _ = model_feat(xb)
    patch_sec_feat = (time.time() - t0) / 512.0
    # Full-image prediction
    print("[CNN+MP] Full-image prediction...")
    pred_map_feat, full_sec_feat = predict_full_image_by_coords(
        model_feat, X_feat_img, patch_size=PATCH_SIZE, device=device,
        batch_size=PREDICT_BATCH_SIZE, title="CNN_PCA_MP_FullPred",
        show=(matplotlib.get_backend().lower() != 'agg'))
    results.append({"Model": "CNN(PCA+MP)", "OA": acc_feat, "AA": aa_feat, "Kappa": kappa_feat,
                    "PatchSec": patch_sec_feat, "FullImgSec": full_sec_feat})

    # Result table
    df = pd.DataFrame(results).sort_values(by="OA", ascending=False)
    print("\n=== Comparison results (CNN vs CNN+MP) ===")
    print(df.to_string(index=False))

    # Confusion matrices
    for (name, y_true, y_pred) in [("CNN(PCA)", y_true_pca, y_pred_pca),
                                   ("CNN(PCA+MP)", y_true_feat, y_pred_feat)]:
        try:
            plt.figure(figsize=(10, 7))
            class_names = [f"Class {i+1}" for i in range(num_classes)]
            sns.heatmap(confusion_matrix(y_true, y_pred), annot=True, fmt='d', cmap="Blues",
                        xticklabels=class_names, yticklabels=class_names, cbar=False, square=True)
            plt.xlabel("Predicted label"); plt.ylabel("True label")
            plt.title(f"{name} test-set confusion matrix", fontsize=14, weight='bold')
            if matplotlib.get_backend().lower() == 'agg':
                plt.tight_layout(); plt.savefig(f"{name}_confusion.png", bbox_inches='tight')
            else:
                plt.tight_layout(); plt.show(block=True)
            plt.close()
        except Exception as e:
            print(f"[Confusion matrix plotting failed] {name}: {e}")

    # OA / runtime bar charts
    try:
        fig1 = plt.figure(figsize=(7.5, 4.5))
        ax1 = sns.barplot(data=df, x="Model", y="OA")
        ax1.set_ylim(0, 1.0); ax1.set_title("OA comparison (higher is better)", fontsize=13, weight='bold')
        for p in ax1.patches:
            ax1.annotate(f"{p.get_height():.3f}", (p.get_x()+p.get_width()/2, p.get_height()),
                         ha='center', va='bottom', fontsize=9)
        plt.tight_layout()
        if matplotlib.get_backend().lower() == 'agg':
            plt.savefig("compare_OA.png", bbox_inches='tight')
        else:
            plt.show(block=True)
        plt.close(fig1)

        fig2 = plt.figure(figsize=(7.5, 4.5))
        ax2 = sns.barplot(data=df, x="Model", y="FullImgSec")
        ax2.set_title("Full-image prediction time (s, lower is better)", fontsize=13, weight='bold')
        for p in ax2.patches:
            ax2.annotate(f"{p.get_height():.1f}s", (p.get_x()+p.get_width()/2, p.get_height()),
                         ha='center', va='bottom', fontsize=9)
        plt.tight_layout()
        if matplotlib.get_backend().lower() == 'agg':
            plt.savefig("compare_FullImgSec.png", bbox_inches='tight')
        else:
            plt.show(block=True)
        plt.close(fig2)
    except Exception as e:
        print(f"[Comparison plot failed] {e}")

    print("\nDone.")


# Program entry point
if __name__ == "__main__":
    try:
        import multiprocessing as mp
        mp.set_start_method("spawn", force=True)
        mp.freeze_support()
    except Exception:
        pass
    main()

IV. Experimental results

1. Accuracy comparison

(Figures: comparison table, OA and full-image prediction-time bar charts, and the test-set confusion matrices for the two models.)

The results show that fusing morphological profiles (MP) yields a clear improvement in model accuracy.

This indicates that explicit feature engineering (such as MP) supplies the model with additional spatial-structure information, and the gains are most pronounced for land-cover classes with distinctive shapes and textures (e.g., buildings, vegetation, water bodies).

2. Classification maps

(Figures: full-image classification maps produced by CNN(PCA) and CNN(PCA+MP).)

V. Conclusion: is feature engineering worth it?

  1. Feature engineering still matters in remote sensing
    Deep learning models are good at learning features automatically, but the peculiarities of remote sensing data (high dimensionality, strong noise, spatial-spectral coupling) mean that domain-knowledge-driven feature engineering such as MP can effectively assist model learning, especially when labeled samples are limited.

  2. Balancing accuracy against efficiency
    Feature engineering adds computational cost, but for accuracy-first tasks such as remote sensing classification, the accuracy gained from well-chosen feature enhancement far outweighs the efficiency lost.

  3. Best practice
    Do not rely blindly on end-to-end learning, and do not cling to traditional feature engineering either. Combining deep learning's automatic feature learning with domain-knowledge-driven feature engineering is what delivers better results in remote sensing tasks.

Feel free to follow my official WeChat account for more content!
