当前位置: 首页 > news >正文

6月14日day53打卡

对抗生成网络

知识点回顾

  1. 对抗生成网络思想关注损失从何而来
  2. 生成判别器
  3. nn.sequential容器适合于按顺序运算情况简化前向传播写法
  4. leakyReLU介绍避免relu神经元失活现象

ps;如果你学有余力对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学

作业:对于心脏病数据集对于病人这个不平衡样本用GAN来学习并生成病人样本观察不用GAN和用GAN的F1分数差异
 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")# 设置中文字体(解决中文显示问题)
plt.rcParams['font.sans-serif'] = ['SimHei']  # Windows系统常用黑体字体
plt.rcParams['axes.unicode_minus'] = False    # 正常显示负号
data = pd.read_csv('D:\桌面\研究项目\打卡文件\python60-days-challenge-master\heart.csv')    #读取数据# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# ---------------------------
# 1. 加载并预处理心脏病数据集
# ---------------------------def load_heart_disease_data():"""加载心脏病数据集并进行预处理"""try:# 读取CSV文件df = pd.read_csv("D:\桌面\研究项目\打卡文件\python60-days-challenge-master\heart.csv")print(f"成功加载heart.csv,数据形状: {df.shape}")# 显示数据基本信息print("\n数据基本信息:")df.info()# 显示数据集行数和列数rows, columns = df.shape# 数据可视化 - 目标变量分布plt.figure(figsize=(6, 4))sns.countplot(x='target', data=df)plt.title('心脏病患者分布')plt.xlabel('是否患病')plt.ylabel('样本数')plt.xticks([0, 1], ['健康', '患病'])plt.show()# 重命名列以便更好理解df.columns = ['age', 'sex', 'chest_pain_type', 'resting_blood_pressure', 'cholesterol', 'fasting_blood_sugar', 'rest_ecg', 'max_heart_rate_achieved', 'exercise_induced_angina', 'st_depression', 'st_slope', 'num_major_vessels', 'thalassemia', 'target']# 转换分类特征categorical_features = ['sex', 'chest_pain_type', 'fasting_blood_sugar', 'rest_ecg', 'exercise_induced_angina', 'st_slope', 'num_major_vessels', 'thalassemia']for feature in categorical_features:df[feature] = df[feature].astype('object')# 独热编码df = pd.get_dummies(df, drop_first=True)# 划分特征和目标变量X = df.drop(columns='target').valuesy = df['target'].values# 打印类别分布class_counts = pd.Series(y).value_counts()print(f"\n类别分布: \n{class_counts}")print(f"不平衡比例: {class_counts[0]/class_counts[1]:.2f}:1")return X, yexcept FileNotFoundError:print("错误:未找到heart.csv文件,请检查文件路径")return None, Noneexcept Exception as e:print(f"数据加载错误: {str(e)}")return None, None# 加载数据
X, y = load_heart_disease_data()
if X is None or y is None:exit()# 数据缩放
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X)# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 提取少数类样本 (心脏病患者)
X_minority = X_train[y_train == 1]
y_minority = y_train[y_train == 1]print(f"训练集中少数类样本数量: {len(X_minority)}")# ---------------------------
# 2. 定义条件GAN模型
# ---------------------------class ConditionalGenerator(nn.Module):def __init__(self, input_dim, output_dim, label_dim):super(ConditionalGenerator, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim + label_dim, 64),nn.LeakyReLU(0.2),nn.BatchNorm1d(64),nn.Linear(64, 128),nn.LeakyReLU(0.2),nn.BatchNorm1d(128),nn.Linear(128, output_dim),nn.Tanh()  # 输出范围为[-1, 1],与数据缩放范围一致)def forward(self, z, labels):# 合并噪声和标签input_tensor = torch.cat([z, labels], dim=1)return self.model(input_tensor)class ConditionalDiscriminator(nn.Module):def __init__(self, input_dim, label_dim):super(ConditionalDiscriminator, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim + label_dim, 128),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(128, 64),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(64, 1),nn.Sigmoid()  # 输出概率值)def forward(self, x, labels):# 合并输入特征和标签input_tensor = torch.cat([x, labels], dim=1)return self.model(input_tensor)# ---------------------------
# 3. 训练CGAN模型
# ---------------------------# 模型参数
LATENT_DIM = 10
INPUT_DIM = X_train.shape[1]
LABEL_DIM = 1  # 二分类问题
EPOCHS = 1000
BATCH_SIZE = 32
LR = 0.0002
BETA1 = 0.5# 创建数据加载器
minority_dataset = TensorDataset(torch.FloatTensor(X_minority), torch.FloatTensor(y_minority).view(-1, 1)
)
minority_dataloader = DataLoader(minority_dataset, batch_size=BATCH_SIZE, shuffle=True)# 实例化模型
generator = ConditionalGenerator(LATENT_DIM, INPUT_DIM, LABEL_DIM).to(device)
discriminator = ConditionalDiscriminator(INPUT_DIM, LABEL_DIM).to(device)# 定义损失函数和优化器
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=(BETA1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=(BETA1, 0.999))# 训练循环
g_losses, d_losses = [], []for epoch in range(EPOCHS):epoch_g_loss, epoch_d_loss = 0, 0batches_per_epoch = 0for i, (real_data, real_labels) in enumerate(minority_dataloader):real_data = real_data.to(device)real_labels = real_labels.to(device)batch_size = real_data.size(0)batches_per_epoch += 1# 创建真实和虚假标签real_targets = torch.ones(batch_size, 1).to(device)fake_targets = torch.zeros(batch_size, 1).to(device)# ---------------------#  训练判别器# ---------------------d_optimizer.zero_grad()# 用真实数据训练real_validity = discriminator(real_data, real_labels)d_real_loss = criterion(real_validity, real_targets)# 生成假数据z = torch.randn(batch_size, LATENT_DIM).to(device)fake_labels = real_labels  # 生成与真实样本相同类别的数据fake_data = generator(z, fake_labels)# 用假数据训练fake_validity = discriminator(fake_data.detach(), fake_labels)d_fake_loss = criterion(fake_validity, fake_targets)# 总判别器损失d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()d_optimizer.step()epoch_d_loss += d_loss.item()# ---------------------#  训练生成器# ---------------------g_optimizer.zero_grad()# 生成假数据fake_data = generator(z, fake_labels)fake_validity = discriminator(fake_data, fake_labels)# 生成器损失g_loss = criterion(fake_validity, real_targets)g_loss.backward()g_optimizer.step()epoch_g_loss += g_loss.item()# 计算平均损失avg_g_loss = epoch_g_loss / batches_per_epochavg_d_loss = epoch_d_loss / batches_per_epochg_losses.append(avg_g_loss)d_losses.append(avg_d_loss)# 每100个epoch打印一次损失if (epoch + 1) % 100 == 0:print(f"Epoch [{epoch+1}/{EPOCHS}], D_loss: {avg_d_loss:.4f}, G_loss: {avg_g_loss:.4f}")print("CGAN训练完成!")# 绘制训练损失曲线
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label='生成器损失')
plt.plot(d_losses, label='判别器损失')
plt.title('训练损失曲线')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.show()# ---------------------------
# 4. 生成合成样本
# ---------------------------# 设置为评估模式
generator.eval()# 生成与少数类样本数量相同的合成数据
num_samples_to_generate = len(X_minority)
z = torch.randn(num_samples_to_generate, LATENT_DIM).to(device)
labels = torch.ones(num_samples_to_generate, 1).to(device)  # 标签为1,表示心脏病患者with torch.no_grad():synthetic_data = generator(z, labels).cpu().numpy()# 逆缩放合成数据
synthetic_data = scaler.inverse_transform(synthetic_data)# 为合成数据创建标签
synthetic_labels = np.ones(num_samples_to_generate)# 逆缩放原始训练数据用于可视化
X_train_original = scaler.inverse_transform(X_train)# ---------------------------
# 5. 可视化原始数据和合成数据
# ---------------------------# 可视化特征分布对比
plt.figure(figsize=(14, 10))
feature_names = ['年龄', '血压', '胆固醇', '最大心率']
feature_indices = [0, 3, 4, 7]  # 对应数据集中的特征索引for i, idx in enumerate(feature_indices):plt.subplot(2, 2, i+1)# 绘制原始少数类样本的特征分布sns.kdeplot(X_train_original[y_train == 1, idx], label='原始数据', color='blue')# 绘制合成样本的特征分布sns.kdeplot(synthetic_data[:, idx], label='合成数据', color='orange')plt.title(f'{feature_names[i]}分布对比')plt.xlabel('特征值')plt.ylabel('密度')plt.legend()plt.tight_layout()
plt.show()# ---------------------------
# 6. 比较模型性能
# ---------------------------# 6.1 使用原始数据训练的模型
model_original = RandomForestClassifier(random_state=42)
model_original.fit(X_train, y_train)
y_pred_original = model_original.predict(X_test)
y_pred_prob_original = model_original.predict_proba(X_test)[:, 1]# 6.2 使用增强数据训练的模型
# 将合成数据添加到训练集中
X_train_augmented = np.vstack([X_train, scaler.transform(synthetic_data)])
y_train_augmented = np.hstack([y_train, synthetic_labels])model_augmented = RandomForestClassifier(random_state=42)
model_augmented.fit(X_train_augmented, y_train_augmented)
y_pred_augmented = model_augmented.predict(X_test)
y_pred_prob_augmented = model_augmented.predict_proba(X_test)[:, 1]# 6.3 比较F1分数
f1_original = f1_score(y_test, y_pred_original)
f1_augmented = f1_score(y_test, y_pred_augmented)print("\n模型性能比较:")
print(f"原始数据训练的模型 F1 分数: {f1_original:.4f}")
print(f"增强数据训练的模型 F1 分数: {f1_augmented:.4f}")# 打印详细分类报告
print("\n原始数据训练的模型分类报告:")
print(classification_report(y_test, y_pred_original))print("\n增强数据训练的模型分类报告:")
print(classification_report(y_test, y_pred_augmented))# 计算混淆矩阵
cm_original = confusion_matrix(y_test, y_pred_original)
cm_augmented = confusion_matrix(y_test, y_pred_augmented)# 计算ROC曲线
fpr_original, tpr_original, _ = roc_curve(y_test, y_pred_prob_original)
fpr_augmented, tpr_augmented, _ = roc_curve(y_test, y_pred_prob_augmented)
roc_auc_original = auc(fpr_original, tpr_original)
roc_auc_augmented = auc(fpr_augmented, tpr_augmented)# ---------------------------
# 7. 可视化评估结果
# ---------------------------# 7.1 混淆矩阵对比
plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)
sns.heatmap(cm_original, annot=True, fmt='d', cmap='Blues')
plt.title('原始数据模型混淆矩阵')
plt.xlabel('预测类别')
plt.ylabel('真实类别')plt.subplot(1, 2, 2)
sns.heatmap(cm_augmented, annot=True, fmt='d', cmap='Greens')
plt.title('增强数据模型混淆矩阵')
plt.xlabel('预测类别')
plt.ylabel('真实类别')plt.tight_layout()
plt.show()# 7.2 ROC曲线对比
plt.figure(figsize=(8, 6))
plt.plot(fpr_original, tpr_original, color='blue', lw=2, label=f'原始数据 (AUC = {roc_auc_original:.2f})')
plt.plot(fpr_augmented, tpr_augmented, color='green', lw=2, label=f'增强数据 (AUC = {roc_auc_augmented:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('假阳性率 (1-特异性)')
plt.ylabel('真阳性率 (敏感性)')
plt.title('ROC曲线比较')
plt.legend(loc="lower right")
plt.show()# 7.3 F1分数对比
plt.figure(figsize=(8, 5))
plt.bar(['原始数据', '增强数据'], [f1_original, f1_augmented], color=['#636EFA', '#EF553B'])
plt.ylim(0, 1)
plt.title('使用GAN增强前后的模型F1分数对比')
plt.ylabel('F1分数')
for i, v in enumerate([f1_original, f1_augmented]):plt.text(i, v + 0.02, f'{v:.4f}', ha='center')
plt.show()    
使用设备: cuda
成功加载heart.csv,数据形状: (303, 14)数据基本信息:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):#   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  0   age       303 non-null    int64  1   sex       303 non-null    int64  2   cp        303 non-null    int64  3   trestbps  303 non-null    int64  4   chol      303 non-null    int64  5   fbs       303 non-null    int64  6   restecg   303 non-null    int64  7   thalach   303 non-null    int64  8   exang     303 non-null    int64  9   oldpeak   303 non-null    float6410  slope     303 non-null    int64  11  ca        303 non-null    int64  12  thal      303 non-null    int64  13  target    303 non-null    int64  
dtypes: float64(1), int64(13)
memory usage: 33.3 KB

类别分布: 
1    165
0    138
Name: count, dtype: int64
不平衡比例: 0.84:1
训练集中少数类样本数量: 133
Epoch [100/1000], D_loss: 0.5933, G_loss: 0.9950
Epoch [200/1000], D_loss: 0.6231, G_loss: 0.9423
Epoch [300/1000], D_loss: 0.5518, G_loss: 1.0263
Epoch [400/1000], D_loss: 0.5745, G_loss: 0.9765
Epoch [500/1000], D_loss: 0.5725, G_loss: 0.8739
Epoch [600/1000], D_loss: 0.6261, G_loss: 0.8975
Epoch [700/1000], D_loss: 0.5993, G_loss: 0.9735
Epoch [800/1000], D_loss: 0.6198, G_loss: 0.9446
Epoch [900/1000], D_loss: 0.5936, G_loss: 0.9563
Epoch [1000/1000], D_loss: 0.6497, G_loss: 0.9020
CGAN训练完成!

 

模型性能比较:
原始数据训练的模型 F1 分数: 0.8333
增强数据训练的模型 F1 分数: 0.8333原始数据训练的模型分类报告:precision    recall  f1-score   support0       0.79      0.90      0.84        291       0.89      0.78      0.83        32accuracy                           0.84        61macro avg       0.84      0.84      0.84        61
weighted avg       0.84      0.84      0.84        61增强数据训练的模型分类报告:precision    recall  f1-score   support0       0.79      0.90      0.84        291       0.89      0.78      0.83        32accuracy                           0.84        61macro avg       0.84      0.84      0.84        61
weighted avg       0.84      0.84      0.84        61

 

实验总结报告

一、数据分布分析

1. 原始类别分布

  • 数据集规模:心脏病数据集共 303 条记录。
  • 类别分布:患病(target=1)样本 165 条,健康(target=0)样本 138 条。
  • 不平衡程度:不平衡比例约为 0.84:1,属于轻度不平衡数据集。
  • 实验动机:因存在轻度不平衡,引入 GAN 进行数据增强,试图补充少数类(健康样本相对患病样本为少数类)数据,优化模型对少数类的识别能力。

2. 合成数据分布

  • 特征分布对比(年龄、血压、胆固醇、最大心率等关键特征):
    • 整体趋势:合成数据与原始数据的核密度曲线形态接近,说明 GAN 生成的合成数据能一定程度模拟原始数据的分布模式,具备补充真实数据的潜力。
    • 现存差异:部分特征(如胆固醇分布)的曲线形态仍有差异,反映生成数据与真实数据存在一定 gap。
  • 优化方向:后续可尝试调整 GAN 网络结构(如增加层数、调整隐层维度)、训练参数(如学习率、迭代次数)以提升生成效果。

二、GAN 训练过程分析

损失曲线特征

  • 前期阶段:判别器损失快速下降,生成器损失快速上升,体现判别器对真假数据的辨别能力快速提升,生成器努力学习生成更“逼真”数据。
  • 后期阶段:二者损失进入相对稳定的波动状态,模型逐渐达到博弈平衡。
  • 现存问题:损失值未趋近于理想低水平(如判别器损失未接近 0.5 附近稳定),可能因训练轮次不足、网络结构未充分拟合数据分布,或数据集特征复杂导致。
  • 优化方向:可考虑延长训练周期、改进网络架构(如引入注意力机制、调整激活函数)。

三、模型评估指标分析

1. 混淆矩阵

  • 分类格局:原始数据与增强数据训练的随机森林模型,混淆矩阵结构几乎一致。
    • 健康样本(target=0):真实为健康的样本中,预测正确 26 例、错误 3 例。
    • 患病样本(target=1):真实为患病的样本中,预测正确 25 例、错误 7 例。
  • 结论:GAN 增强数据未显著改变模型对不同类别样本的分类正误格局。
  • 可能原因:原始数据集不平衡程度较低,少量合成数据补充对模型决策边界影响有限;或生成数据质量不足,未提供足够有价值的新信息。

2. ROC 曲线与 AUC

  • 曲线形态:两条 ROC 曲线形态接近,AUC 均为 0.93。
  • 结论:模型在原始数据和增强数据上对正负样本的区分能力相当,整体预测性能良好,但增强数据未带来 AUC 提升。
  • 可能原因:合成数据质量或数量不足,未有效强化模型区分能力;可尝试生成更多高质量数据或结合其他增强策略(如 SMOTE 结合 GAN)。

3. F1 分数

  • 指标结果:原始数据与增强数据训练模型的 F1 分数均为 0.8333。
  • 结论:GAN 增强后,模型在平衡精确率(precision)和召回率(recall)方面未体现优势;结合分类报告,两类样本的精确率、召回率无明显变化。
  • 可能原因:当前实验设置下,GAN 增强未有效改善模型性能,需从数据生成质量、模型训练策略(如调整分类器参数、尝试其他分类模型)等方面优化。

四、整体结论与优化方向

本次实验中,GAN 虽能一定程度模拟数据分布,但因原始数据集不平衡程度不高、生成数据质量和数量有限等因素,未显著提升分类模型性能。
后续优化方向

  1. 优化 GAN 训练:调整网络结构(如增加层数、引入注意力机制)、延长训练周期、优化超参数(如学习率、批次大小)。
  2. 探索多策略数据增强:结合 SMOTE 等传统方法与 GAN,提升合成数据质量与多样性。
  3. 改进分类模型:尝试其他分类器(如 XGBoost、LightGBM)或调整随机森林参数,挖掘模型潜力。

@浙大疏锦行 

相关文章:

  • openeuler系统每日生成系统运行报告并发送至指定邮箱
  • AVL树的实现
  • 人工智能基础知识笔记九:数据规范化技巧
  • 解决文明6 内存相关内容报错EXCEPTION_ACCESS_VIOLATION
  • 吃透 Golang 基础:方法
  • 鸿蒙NEXT-Data类型数据通过AppStore获取后找原本一样的数据(值一样)但是比较结果却为false
  • 电路问题处理:SGMII链路中的AC耦合电容摆放位置
  • Linux【8】-----Linux系统编程(并发编程原理与应用)
  • 【系统分析师】2009年真题:综合知识-答案及详解
  • JavaEE初阶第一期:计算机是如何 “思考” 的(上)
  • 如何统一从一个共享的配置文件(比如 config.yaml 和 config.py)加载配置信息,避免硬编码。
  • 14:00开始面试,14:06就出来了,问的问题有点变态。。。
  • leetcode0765. 情侣牵手-hard
  • 【多线程初阶】详解线程池(上)
  • 什么是旋转开关?
  • 芯片测试之 trim修调测试详解
  • 掌握C#枚举:从交通灯看懂状态管理
  • 【0.3 漫画数据结构与算法】
  • Sngine 4.0.4海外社交平台PHP源码 – 多语言支持短视频和博客订阅(源码下载)
  • pdb/sdf转pdbqt错误、pdbqt小分子文件对接蛋白质dock分数为0
  • 企业的门户网站一般用什么做/泉州百度搜索推广
  • 支付宝网站设计分析/百度广告联盟官网
  • 网站手机版怎么做/端口扫描站长工具
  • 单页网站上传教程/宣传推广策略
  • 湘潭公司网站建设/淘宝seo培训
  • 安丘网站建设报价/淘宝关键词指数