Missing-node-data experiment for wireless communication network topology inference with Torch Geometric
Generating samples with missing node data:
gcn_dataset_incomplete.py
# Author: zhouzhichao
# Created: 2025/5/30
# Description: generate incomplete (partially erased) datasets for the experiment
import h5py
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_sampling

base_dir = "D:\\无线通信网络认知\\论文1\\experiment\\直推式拓扑推理实验\\拓扑生成\\200样本\\"

N = 30
grapg_size = N
train_n = 31
M = 3000

class graph_data(InMemoryDataset):
    def __init__(self, root, signals=None, tp_list=None, transform=None, pre_transform=None):
        self.signals = signals
        self.tp_list = tp_list
        super().__init__(root, transform, pre_transform)
        # self.data, self.slices = torch.load(self.processed_paths[0])
        self.data = torch.load(self.processed_paths[0])

    # File name expected by process(); the dataset saved later must use the same name
    @property
    def processed_file_names(self):
        return ['gcn_data.pt']

    # Build and save the graph data
    def process(self):
        signals = self.signals
        tp_list = self.tp_list
        # tp = Tp[:, :, k]
        X = torch.tensor(signals, dtype=torch.float)
        # All (positive) edges
        Edge_index = torch.tensor(tp_list, dtype=torch.long)
        # Label 1 for every positive edge
        edge_label = np.ones((tp_list.shape[1]))
        # edge_label = np.zeros((tp_list.shape[1]))
        Edge_label = torch.tensor(edge_label, dtype=torch.float)
        # Sample as many negative (non-existing) edges as positives
        neg_edge_index = negative_sampling(edge_index=Edge_index, num_nodes=grapg_size,
                                           num_neg_samples=Edge_index.shape[1], method='sparse')
        # Supervision edges start from all positives ...
        Edge_label_index = Edge_index
        # ... while only train_n randomly chosen positive edges are kept for message passing
        perm = torch.randperm(Edge_index.size(1))
        Edge_index = Edge_index[:, perm]
        Edge_index = Edge_index[:, :train_n]
        # Concatenate positive and negative supervision edge indices
        Edge_label_index = torch.cat([Edge_label_index, neg_edge_index], dim=-1)
        # Concatenate positive and negative labels
        Edge_label = torch.cat([Edge_label, Edge_label.new_zeros(neg_edge_index.size(1))], dim=0)
        data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)
        torch.save(data, self.processed_paths[0])
        # data_list.append(data)
        # data_, slices = self.collate(data_list)  # align/pad graphs of different sizes
        # torch.save((data_, slices), self.processed_paths[0])

for snr in [40]:
    print("snr: ", snr)
    mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')
    Signals = mat_file["Signals"][()]
    Tp = mat_file["Tp"][()]
    Tp_list = mat_file["Tp_list"][()]
    n = 200
    for i in range(n):
        for erase in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
            L = int(Signals.shape[0] * erase)
            # copy() so each erase ratio starts from the clean signal
            # (a plain slice is a view, and the erasures would otherwise accumulate in Signals)
            signals = Signals[:, :, i].copy()
            for j in range(N):
                start_idx = np.random.randint(0, Signals.shape[0] - L - 5)
                signals[start_idx:start_idx + L, j] = 0
            tp_list = np.array(mat_file[Tp_list[0, i]])
            root = "gcn_data-" + str(i) + "_erase-" + str(erase) + "_N_" + str(N) + "_snr_" + str(snr) + "_train_n_" + str(train_n) + "_M_" + str(M)
            graph_data(root, signals=signals, tp_list=tp_list)
            print("")

print("...图数据生成完成...")
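After the script finishes, each (i, erase) combination has its own root folder containing a processed gcn_data.pt. Below is a minimal sanity-check sketch; the root name is just an example following the naming scheme above, and on newer torch versions torch.load may need weights_only=False to deserialize a PyG Data object.

import torch

# Example root folder, following the naming scheme used by the generation script
root = "gcn_data-0_erase-0.1_N_30_snr_40_train_n_31_M_3000"

# InMemoryDataset writes its processed file under <root>/processed/
data = torch.load(root + "/processed/gcn_data.pt")  # add weights_only=False if your torch requires it

print(data)
print("x:", data.x.shape)                                # erased signal matrix
print("edge_index:", data.edge_index.shape)              # the train_n edges kept for message passing
print("edge_label_index:", data.edge_label_index.shape)  # positive + negative supervision edges
print("edge_label:", data.edge_label.shape)              # 1 for positive, 0 for sampled negative edges
print("erased fraction:", (data.x == 0).float().mean().item())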
Key part (for every node of one graph, a segment of length L is zeroed out starting at a random position):
L = int(Signals.shape[0] * erase)
# copy() so each erase ratio starts from the clean signal instead of accumulating zeros
signals = Signals[:, :, i].copy()
for j in range(N):
    start_idx = np.random.randint(0, Signals.shape[0] - L - 5)
    signals[start_idx:start_idx + L, j] = 0
tp_list = np.array(mat_file[Tp_list[0, i]])
root = "gcn_data-" + str(i) + "_erase-" + str(erase) + "_N_" + str(N) + "_snr_" + str(snr) + "_train_n_" + str(train_n) + "_M_" + str(M)
graph_data(root, signals=signals, tp_list=tp_list)
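To make the erasure pattern concrete, here is a small self-contained sketch of the same operation on a made-up signal matrix (dimensions and seed are illustrative only): each column (node) gets one contiguous run of L zeros starting at its own random position.

import numpy as np

np.random.seed(0)
T, N = 20, 4          # made-up signal length and node count, for illustration
erase = 0.3
signals = np.random.randn(T, N)

L = int(T * erase)    # number of consecutive samples to erase per node
for j in range(N):
    start_idx = np.random.randint(0, T - L - 5)   # random start, kept away from the end as in the script
    signals[start_idx:start_idx + L, j] = 0

# each column now contains exactly one zeroed run of length L
print((signals == 0).sum(axis=0))   # -> [6 6 6 6]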
Generation results:
Comparison experiment:
gcn_erase_test.py
# Author: zhouzhichao
# Created: 2025/5/30
# Description: run the experiment on incomplete (partially erased) data
import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\无线通信网络认知\论文1\大修意见\Reviewer1-1 阈值相似性图对比实验')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())
from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score

mode = "gcn"

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(Input_L, 1000)
        self.conv2 = GCNConv(1000, 20)
        # self.conv3 = GCNConv(2000, 256)
        # self.conv4 = GCNConv(256, 128)

    def encode(self, x, edge_index):
        x1 = self.conv1(x, edge_index)
        x1_1 = x1.relu()
        x2 = self.conv2(x1_1, edge_index)
        x2_2 = x2.relu()
        return x2_2

    def decode(self, z, edge_label_index):
        # squared Euclidean distance between the two endpoint embeddings of each candidate edge
        distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)
        return distance_squared

    def decode_all(self, z):
        prob_adj = z @ z.t()  # score matrix over all node pairs
        return (prob_adj > 0).nonzero(as_tuple=False).t()  # edges with score > 0, in edge_index form

    @torch.no_grad()
    def test(self, gcn_data):
        model.eval()
        z = model.encode(gcn_data.x, gcn_data.edge_index)
        out = model.decode(z, gcn_data.edge_label_index).view(-1)
        out = 1 - out  # 1 - squared distance: small distance -> high edge score
        out_np = out.cpu().numpy()
        labels_np = gcn_data.edge_label.cpu().numpy()
        roc_auc = roc_auc_score(labels_np, out_np)
        pred_labels = (out_np > -0.5).astype(int)
        accuracy = accuracy_score(labels_np, pred_labels)
        # Precision
        precision = precision_score(labels_np, pred_labels, zero_division=1)
        # Recall
        recall = recall_score(labels_np, pred_labels, zero_division=1)
        return roc_auc, accuracy, precision, recall

N = 30
train_n = 31
M = 3000
snr = 40
# M = 10000
# print("train_n: ", train_n)
# gcn_data = graph_data("gcn_data")
# print("M: ", M)
print("snr: ", snr)
for erase in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    accuracy_list = []
    precision_list = []
    recall_list = []
    for i in range(200):
        root = "gcn_data-" + str(i) + "_erase-" + str(erase) + "_N_" + str(N) + "_snr_" + str(snr) + "_train_n_" + str(train_n) + "_M_" + str(M)
        gcn_data = graph_data(root)
        Input_L = gcn_data.x.shape[1]
        model = Net()
        # model = Net().to(device)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
        criterion = torch.nn.BCEWithLogitsLoss()
        model.train()

        def train():
            optimizer.zero_grad()
            z = model.encode(gcn_data.x, gcn_data.edge_index)
            out = model.decode(z, gcn_data.edge_label_index).view(-1)
            out = 1 - out  # use 1 - squared distance as the logit
            loss = criterion(out, gcn_data.edge_label)
            loss.backward()
            optimizer.step()
            return loss

        min_loss = 99999
        count = 0
        # Early stopping
        for epoch in range(100000):
            loss = train()
            if loss < min_loss:
                min_loss = loss
                count = 0
                print("erase: ", str(erase), " i: ", str(i), " epoch: ", epoch, " loss: ",
                      round(loss.item(), 4), " min_loss: ", round(min_loss.item(), 4))
            count = count + 1
            if count > 100:
                break
            # print("erase: ", str(erase), " i: ", str(i), " epoch: ", epoch, " loss: ", round(loss.item(), 2), " min_loss: ", round(min_loss.item(), 2))
        roc_auc, accuracy, precision, recall = model.test(gcn_data)
        accuracy_list.append(accuracy)
        precision_list.append(precision)
        recall_list.append(recall)

    data = {'accuracy_list': accuracy_list,
            'precision_list': precision_list,
            'recall_list': recall_list}
    # Build a DataFrame
    df = pd.DataFrame(data)
    # Save to an Excel file
    file_path = 'D:\无线通信网络认知\论文1\大修意见\Reviewer1-2 缺失数据实验\\erase-' + str(erase) + '.xlsx'
    df.to_excel(file_path, index=False)
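The script leaves one erase-<ratio>.xlsx file per erase ratio, each with 200 rows of per-graph metrics. A hypothetical post-processing step (not part of the original scripts) to condense them into one mean-metric row per erase ratio could look like this:

import pandas as pd

# Directory used by gcn_erase_test.py to store the per-erase result files
result_dir = 'D:\\无线通信网络认知\\论文1\\大修意见\\Reviewer1-2 缺失数据实验\\'

summary = []
for erase in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    df = pd.read_excel(result_dir + 'erase-' + str(erase) + '.xlsx')
    summary.append({
        'erase': erase,
        'accuracy': df['accuracy_list'].mean(),
        'precision': df['precision_list'].mean(),
        'recall': df['recall_list'].mean(),
    })

summary_df = pd.DataFrame(summary)
print(summary_df)
# summary_df.to_excel(result_dir + 'erase-summary.xlsx', index=False)  # optional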