当前位置: 首页 > news >正文

Torch Geometric环境下无线通信网络拓扑推理节点数据缺失实验

节点数据缺失样本生成:

gcn_dataset_incomplete.py

#作者:zhouzhichao
#创建时间:2025/5/30
#内容:生成残缺数据集用于实验import h5py
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import negative_samplingbase_dir = "D:\\无线通信网络认知\\论文1\\experiment\\直推式拓扑推理实验\\拓扑生成\\200样本\\"N = 30
grapg_size = N
train_n = 31
M = 3000class graph_data(InMemoryDataset):def __init__(self, root, signals=None, tp_list = None, transform=None, pre_transform=None):self.signals = signalsself.tp_list = tp_listsuper().__init__(root, transform, pre_transform)# self.data, self.slices = torch.load(self.processed_paths[0])self.data = torch.load(self.processed_paths[0])# 返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致@propertydef processed_file_names(self):return ['gcn_data.pt']# 生成数据集所用的方法def process(self):signals = self.signalstp_list =self.tp_list# tp = Tp[:,:,k]X = torch.tensor(signals, dtype=torch.float)# 所有的边Edge_index = torch.tensor(tp_list, dtype=torch.long)# 所有的边1标签edge_label = np.ones((tp_list.shape[1]))# edge_label = np.zeros((tp_list.shape[1]))Edge_label = torch.tensor(edge_label, dtype=torch.float)neg_edge_index = negative_sampling(edge_index=Edge_index, num_nodes=grapg_size,num_neg_samples=Edge_index.shape[1], method='sparse')# 拼接正负样本索引Edge_label_index = Edge_indexperm = torch.randperm(Edge_index.size(1))Edge_index = Edge_index[:, perm]Edge_index = Edge_index[:, :train_n]Edge_label_index = torch.cat([Edge_label_index, neg_edge_index],dim=-1,)# 拼接正负样本Edge_label = torch.cat([Edge_label,Edge_label.new_zeros(neg_edge_index.size(1))], dim=0)data = Data(x=X, edge_index=Edge_index, edge_label_index=Edge_label_index, edge_label=Edge_label)torch.save(data, self.processed_paths[0])# data_list.append(data)# data_, slices = self.collate(data_list)  # 将不同大小的图数据对齐,填充# torch.save((data_, slices), self.processed_paths[0])for snr in [40]:print("snr: ", snr)mat_file = h5py.File(base_dir + str(N) + '_nodes_dataset_snr-' + str(snr) + '_M_' + str(M) + '.mat', 'r')Signals = mat_file["Signals"][()]Tp = mat_file["Tp"][()]Tp_list = mat_file["Tp_list"][()]n = 200for i in range(n):for erase in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]:L = int(Signals.shape[0]*erase)signals = Signals[:,:,i]for j in range(N):start_idx = np.random.randint(0, Signals.shape[0] - L-5)signals[start_idx:start_idx+L,j]=0tp_list = np.array(mat_file[Tp_list[0, i]])root = "gcn_data-"+str(i)+"_erase-"+str(erase)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)graph_data(root, signals = signals, tp_list = tp_list)print("")print("...图数据生成完成...")

关键部分(在一张图的每个节点的随机位置把L长度的数据清空):

L = int(Signals.shape[0]*erase)
signals = Signals[:,:,i]
for j in range(N):start_idx = np.random.randint(0, Signals.shape[0] - L-5)signals[start_idx:start_idx+L,j]=0tp_list = np.array(mat_file[Tp_list[0, i]])
root = "gcn_data-"+str(i)+"_erase-"+str(erase)+"_N_"+str(N)+"_snr_"+str(snr)+"_train_n_"+str(train_n)+"_M_"+str(M)
graph_data(root, signals = signals, tp_list = tp_list)

生成结果:

对比实验:

gcn_erase_test.py

#作者:zhouzhichao
#创建时间:25年5月30日
#内容:进行残缺数据实验import sys
import torch
import random
import numpy as np
import pandas as pd
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
sys.path.append('D:\无线通信网络认知\论文1\大修意见\Reviewer1-1 阈值相似性图对比实验')
from gcn_dataset import graph_data
print(torch.__version__)
print(torch.cuda.is_available())
from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_scoremode = "gcn"class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv1 = GCNConv(Input_L, 1000)self.conv2 = GCNConv(1000, 20)# self.conv3 = GCNConv(2000, 256)# self.conv4 = GCNConv(256, 128)def encode(self, x, edge_index):x1 = self.conv1(x, edge_index)x1_1 = x1.relu()x2 = self.conv2(x1_1, edge_index)x2_2 = x2.relu()return x2_2def decode(self, z, edge_label_index):distance_squared = torch.sum((z[edge_label_index[0]] - z[edge_label_index[1]]) ** 2, dim=-1)return distance_squareddef decode_all(self, z):prob_adj = z @ z.t()  # 得到所有边概率矩阵return (prob_adj > 0).nonzero(as_tuple=False).t()  # 返回概率大于0的边,以edge_index的形式@torch.no_grad()def test(self, gcn_data):model.eval()z = model.encode(gcn_data.x, gcn_data.edge_index)out = model.decode(z, gcn_data.edge_label_index).view(-1)out = 1 - outout_np = out.cpu().numpy()labels_np = gcn_data.edge_label.cpu().numpy()roc_auc = roc_auc_score(labels_np, out_np)pred_labels = (out_np > -0.5).astype(int)accuracy = accuracy_score(labels_np, pred_labels)# 计算精度 (Precision)precision = precision_score(labels_np, pred_labels, zero_division=1)# 计算召回率 (Recall)recall = recall_score(labels_np, pred_labels, zero_division=1)return roc_auc, accuracy, precision, recallN = 30
train_n = 31
M = 3000snr = 40# M = 10000# print("train_n: ", train_n)
# gcn_data = graph_data("gcn_data")
# print("M: ", M)
print("snr: ", snr)
for erase in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:accuracy_list = []precision_list = []recall_list = []for i in range(200):root = "gcn_data-" + str(i) + "_erase-" + str(erase) + "_N_" + str(N) + "_snr_" + str(snr) + "_train_n_" + str(train_n) + "_M_" + str(M)gcn_data = graph_data(root)Input_L = gcn_data.x.shape[1]model = Net()# model = Net().to(device)optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)criterion = torch.nn.BCEWithLogitsLoss()model.train()def train():optimizer.zero_grad()z = model.encode(gcn_data.x, gcn_data.edge_index)out = model.decode(z, gcn_data.edge_label_index).view(-1)out = 1 - outloss = criterion(out, gcn_data.edge_label)loss.backward()optimizer.step()return lossmin_loss = 99999count = 0#早停for epoch in range(100000):loss = train()if loss<min_loss:min_loss = losscount = 0print("erase:  ", str(erase), "   i:  ", str(i), "   epoch:  ", epoch, "   loss: ",round(loss.item(), 4), "   min_loss: ", round(min_loss.item(), 4))count = count + 1if count>100:break# print("erase:  ",str(erase),"   i:  ",str(i),"   epoch:  ",epoch,"   loss: ",round(loss.item(),2), "   min_loss: ",round(min_loss.item(),2))roc_auc, accuracy, precision, recall = model.test(gcn_data)accuracy_list.append(accuracy)precision_list.append(precision)recall_list.append(recall)data = {'accuracy_list': accuracy_list,'precision_list': precision_list,'recall_list': recall_list}# 创建一个 DataFramedf = pd.DataFrame(data)## # 保存到 Excel 文件file_path = 'D:\无线通信网络认知\论文1\大修意见\Reviewer1-2 缺失数据实验\\erase-' + str(erase) +'.xlsx'df.to_excel(file_path, index=False)

相关文章:

  • YOLOv8 移动端升级:借助 GhostNetv2 主干网络,实现高效特征提取
  • MySQL中count(1)和count(*)的区别及细节
  • python连接邮箱,下载附件,并且定时更新的方案
  • 【机器学习】支持向量机
  • 【速通RAG实战:进阶】17、AI视频打点全攻略:从技术实现到媒体工作流提效的实战指南
  • AUTOSAR图解==>AUTOSAR_EXP_AIADASAndVMC
  • JWT 原理与设计上的缺陷及利用
  • 设计模式——适配器设计模式(结构型)
  • 数字化转型进阶:精读41页华为数字化转型实践【附全文阅读】
  • leetcode动态规划—买卖股票系列
  • Python----目标检测(《基于区域提议网络的实时目标检测方法》和Faster R-CNN)
  • 每日算法刷题Day19 5.31:leetcode二分答案3道题,用时1h
  • 34.x64汇编写法(一)
  • 端午安康(Python)
  • 现代数据湖架构全景解析:存储、表格式、计算引擎与元数据服务的协同生态
  • 【Web API系列】WebTransportSendStream接口深度解析:构建高性能实时数据传输的基石
  • 开源是什么?我们为什么要开源?
  • 谷歌工作自动化——仙盟大衍灵机——仙盟创梦IDE
  • Java中的引用类型以及区别的特点
  • 第十四章 MQTT订阅
  • 烟台制作网站的公司简介/东莞seo技术
  • 做网站带后台多少钱/品牌运营策划方案
  • 个人无网站怎样做cps广告/站长工具域名
  • 比较好的做外贸网站/北京seo外包平台
  • 苏州做门户网站的公司/南宁网站建设
  • 惠州建设网站/关键词优化排名要多少钱