超越CNN:GCN如何重塑图像处理
目录
写在前面
一、GCN处理图像的优势
二、构建数据集
三、定义模型
四、训练代码
五、推理代码
六、总结
写在前面
GCN 用于图像处理时,并没有 CNN 中 “固定形状、滑动遍历” 的卷积核,但存在承担 “特征变换” 功能的权重矩阵,其作用与 CNN 卷积核的 “参数化特征提取” 本质相通,只是适配图结构的操作形式不同。
下面我将用GCN完成一个简单的图像分类任务,这项任务的核心是数据处理——构建图数据。
要了解GCN的基础知识,戳这里:一图看懂图卷积网络GCN
要了解GCN处理图像的计算过程,戳这里:图卷积网络GCN:图像理解的新视角
一、GCN处理图像的优势
那GCN比CNN有哪些优势呢?举个例子,假设你要识别猫的图片。
CNN 会扫描整张图 → 识别猫 → 识别椅子 → 组合信息,它看到的是一个规则的像素网格。优点是对规则、结构一致的图像(例如方形照片)非常高效。缺点是它只能处理“规则网格”,也就是固定排列的像素点。
GCN 把图像看作一个图(Graph),节点(Node)可以是每个像素、每个超像素(superpixel)或者图像里的某些“关键区域”。边(Edge)代表这些节点之间的关系,比如:像素之间颜色相似度、空间距离、是否属于同一物体等。然后,GCN 在图上传播信息——每个节点都会根据它的邻居更新自己的特征。
GCN 把“猫的身体各部分”“椅子的腿”“背景”等区域当作节点;根据它们之间的关系(比如“猫在椅子上”)建立边;让这些节点相互传递信息;最终得出一个更“结构化”的理解。GCN 知道:“猫”和“椅子”不是孤立存在的,它们之间有关系。这类“结构关系”是 CNN 不容易直接捕捉的。
二、SLIC 算法
下面步骤会用到SLIC 算法,这里简单介绍一下。
SLIC(Simple Linear Iterative Clustering) 是一种常用的 超像素分割算法。
它的作用是把一张图像切成一堆颜色相近、空间相邻的小块区域(称为“超像素”)。这些超像素比像素更“聪明”——每个块大致代表图像中的一个局部区域(比如一块天空、一片草地、一只眼睛)。它能让后续算法(如 GCN、目标检测、分割)更高效地处理图像结构。
SLIC 的核心思想很简单:在颜色空间和空间位置上,把相似的像素聚成一类。它本质上是 在五维空间中做 K-means 聚类。
这 5 个维度是:
-
三个颜色维度(通常是 Lab 空间中的 L, a, b);
-
两个位置维度(像素的 x, y 坐标)。
所以每个像素都可以表示为一个五维向量:(L,a,b,x,y)
计算步骤:
1.初始化聚类中心
-
把图像均匀分成若干个格子;
-
在每个格子的中心挑一个像素作为初始聚类中心。
2.定义距离度量(颜色+空间)
对每个像素,计算它到聚类中心的“距离”:
其中:
-
:颜色差(Lab空间)
-
:空间距离
-
S:超像素的期望大小
-
m:平衡系数(控制颜色 vs 空间的重要性)
当 m小时,更注重颜色一致;当 m大时,更注重空间连续。
3.分配像素
每个像素根据距离 D 选择最近的中心归类。
4.更新聚类中心
对每个聚类,重新计算平均的颜色和坐标,然后更新中心点。
5.迭代
重复分配和更新,直到聚类中心稳定。
6.后处理(可选)
去掉孤立的小块,保证区域连通。
直观理解(打个比方)
想象你在画布上撒满了彩色小珠子:SLIC 就像在画布上放很多小“吸铁石”,每个吸铁石会吸引周围 颜色相近 的小珠子;经过几轮吸附和调整,画布就被自然地分成了一些 颜色块 —— 这就是超像素。
四、构建数据集
这里是任务的核心——生成“超像素”或者“关键区域”。我们可以使用目标检测或者分割模型先识别出猫和椅子等物体作为“超像素”,然后再建立联系;这里我们简化问题,使用基于颜色的SLIC 算法来生成“超像素”。
通常,一个图像对应一个 Data
对象(来自 torch_geometric.data.Data
)。我们可以把多个图像封装进 Dataset
。构建数据集代码:
from torch_geometric.data import Dataset, DataLoader
from skimage.segmentation import slic
from skimage.color import rgb2lab
import numpy as np
from PIL import Image
import os
import torchclass CatDogGraphDataset(Dataset):def __init__(self, root_dir):self.root_dir = root_dirself.samples = []for label, cls in enumerate(['cat', 'dog']):folder = os.path.join(root_dir, cls)for img_name in os.listdir(folder):if img_name.endswith('.jpg'):self.samples.append((os.path.join(folder, img_name), label))def __len__(self):return len(self.samples)def __getitem__(self, idx):img_path, label = self.samples[idx]img = Image.open(img_path).convert("RGB").resize((64, 64))img_np = np.array(img)# 生成超像素segments = slic(img_np, n_segments=75, compactness=10)num_nodes = segments.max() + 1# 节点特征lab_img = rgb2lab(img_np)node_features = []for i in range(num_nodes):mask = segments == imean_color = lab_img[mask].mean(axis=0)node_features.append(mean_color)x = torch.tensor(node_features, dtype=torch.float)# 构造边(相邻的超像素)edges = set()for i in range(63):for j in range(63):if segments[i, j] != segments[i, j + 1]:edges.add((segments[i, j], segments[i, j + 1]))if segments[i, j] != segments[i + 1, j]:edges.add((segments[i, j], segments[i + 1, j]))if len(edges) == 0:edges.add((0, 0)) # 避免空图edge_index = torch.tensor(list(zip(*edges)), dtype=torch.long)return Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))s定义模型
模型很简单,由两层GCN组成,每层都之后是ReLU操作,然后经过global_mean_pool,最后送入全连接输出结果:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_poolclass SimpleGCN(torch.nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super().__init__()self.conv1 = GCNConv(input_dim, hidden_dim)self.conv2 = GCNConv(hidden_dim, hidden_dim)self.fc = torch.nn.Linear(hidden_dim, output_dim)def forward(self, x, edge_index, batch):x = F.relu(self.conv1(x, edge_index))x = F.relu(self.conv2(x, edge_index))x = global_mean_pool(x, batch) # 聚合所有节点特征 → 图级别表示x = self.fc(x)return x
五、训练代码
使用交叉熵损失,Adam优化器:
from torch_geometric.loader import DataLoader
import torch.optim as optim# 加载数据
train_dataset = CatDogGraphDataset('data/cats_and_dogs/train')
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 初始化模型
model = SimpleGCN(input_dim=3, hidden_dim=32, output_dim=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()# 训练循环
for epoch in range(10):model.train()total_loss = 0for batch in train_loader:optimizer.zero_grad()out = model(batch.x, batch.edge_index, batch.batch)loss = criterion(out, batch.y)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
六、推理代码
推理时,流程一样,只是不计算梯度。
test_dataset = CatDogGraphDataset('data/cats_and_dogs/test')
test_loader = DataLoader(test_dataset, batch_size=1)model.eval()
correct = 0
total = 0with torch.no_grad():for batch in test_loader:out = model(batch.x, batch.edge_index, batch.batch)pred = out.argmax(dim=1)correct += int((pred == batch.y).sum())total += batch.y.size(0)print(f"Test Accuracy: {correct / total:.2%}")
七、总结
一句话总结:
CNN:在像素网格上卷积提取局部特征;
GCN:在区域关系图上传递并融合结构信息。
对于“猫坐在椅子上”这种有结构关系的图像,GCN 能更好地理解语义。
GCN用于图像处理就介绍到这里。
关注不迷路(*^▽^*),暴富入口==》 https://bbs.csdn.net/topics/619691583