Python Day 52 Check-in
Generative Adversarial Networks (GANs)
Knowledge review:
The idea behind GANs: focus on where the loss comes from (see the first sketch below this list)
Generator and discriminator
The nn.Sequential container: suited to strictly sequential computation; it simplifies the forward pass (see the second sketch below this list)
LeakyReLU: avoids the "dying ReLU" problem, where units stop activating entirely
PS: if you have energy to spare, it is worth finding a video that explains the GAN loss function in depth; if you only need to apply GANs, this is optional
Homework: on the heart-disease dataset, where the patient class is imbalanced, use a GAN to learn from and generate patient samples, then compare the F1 scores with and without the GAN-generated data.
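On the loss question: the generator G and discriminator D play the minimax game min_G max_D E_x[log D(x)] + E_z[log(1 - D(G(z)))]. In practice (as in the training loop further down) both sides are trained with binary cross-entropy: D is pushed toward 1 on real samples and 0 on fakes, while G is trained to make D output 1 on its fakes. Here is a minimal sketch of where each loss term comes from; the discriminator outputs are made-up numbers standing in for real networks:

import torch
import torch.nn as nn

bce = nn.BCELoss()

# Hypothetical discriminator outputs for a batch of 4 samples
d_real = torch.tensor([[0.9], [0.8], [0.7], [0.95]])  # D(x) on real data
d_fake = torch.tensor([[0.2], [0.3], [0.1], [0.25]])  # D(G(z)) on generated data

ones = torch.ones_like(d_real)    # target label "real"
zeros = torch.zeros_like(d_fake)  # target label "fake"

# Discriminator loss: label real samples 1 and fake samples 0
d_loss = (bce(d_real, ones) + bce(d_fake, zeros)) / 2

# Generator loss (non-saturating form): push D toward calling fakes real
g_loss = bce(d_fake, ones)

print(f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")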
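On nn.Sequential and LeakyReLU: the two modules below are equivalent, but the Sequential version collapses the forward pass to a single call. TinyNetExplicit and TinyNetSequential are illustrative names, not part of the assignment code. LeakyReLU(0.2) returns x for positive inputs and 0.2*x for negative ones, so the gradient never vanishes completely the way it can for a ReLU unit stuck in the negative region:

import torch
import torch.nn as nn

# Without nn.Sequential: every layer is applied by hand in forward()
class TinyNetExplicit(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.act = nn.LeakyReLU(0.2)  # f(x) = x if x > 0 else 0.2 * x
        self.fc2 = nn.Linear(16, 1)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

# With nn.Sequential: layers run in declaration order, forward() is one line
class TinyNetSequential(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(8, 16),
            nn.LeakyReLU(0.2),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.net(x)

x = torch.randn(4, 8)
print(TinyNetExplicit()(x).shape, TinyNetSequential()(x).shape)  # both (4, 1)

The full homework script below puts these pieces together.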
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

# Matplotlib display settings (SimHei font for CJK glyphs, proper minus signs)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Select the compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
# =============================
# Data loading and preprocessing
# =============================
def prepare_cardiac_data():
    """Load the heart-disease dataset and preprocess it."""
    try:
        # Read the data file
        cardiac_df = pd.read_csv(r'D:\桌面\研究项目\打卡文件\python60-days-challenge-master\heart.csv')
        print(f"Dataset loaded, shape: {cardiac_df.shape}")

        # Print a data summary
        print("\nDataset summary:")
        cardiac_df.info()

        # Visualize the target distribution
        plt.figure(figsize=(6, 4))
        sns.countplot(x='target', data=cardiac_df)
        plt.title('Heart disease diagnosis distribution')
        plt.xlabel('Diagnosis')
        plt.ylabel('Sample count')
        plt.xticks([0, 1], ['Negative', 'Positive'])
        plt.show()

        # Rename the feature columns
        cardiac_df.columns = ['age', 'gender', 'chest_pain', 'resting_bp', 'cholesterol',
                              'fasting_bs', 'resting_ecg', 'max_hr', 'exercise_angina',
                              'st_depression', 'st_slope', 'major_vessels', 'thalassemia',
                              'target']

        # Mark the discrete columns as categorical
        categorical_cols = ['gender', 'chest_pain', 'fasting_bs', 'resting_ecg',
                            'exercise_angina', 'st_slope', 'major_vessels', 'thalassemia']
        for col in categorical_cols:
            cardiac_df[col] = cardiac_df[col].astype('category')

        # One-hot encode the categorical columns; pandas keeps the numeric
        # columns first (in their original order) and appends the dummies
        cardiac_df = pd.get_dummies(cardiac_df, drop_first=True)

        # Separate features and labels
        features = cardiac_df.drop(columns='target').values
        labels = cardiac_df['target'].values

        # Inspect the class distribution
        class_dist = pd.Series(labels).value_counts()
        print(f"\nClass distribution:\n{class_dist}")
        print(f"Class ratio: {class_dist[0] / class_dist[1]:.2f}:1")
        return features, labels
    except FileNotFoundError:
        print("Error: data file not found, please check the path")
        return None, None
    except Exception as e:
        print(f"Data processing error: {str(e)}")
        return None, None

# Run data preparation
features, labels = prepare_cardiac_data()
if features is None or labels is None:
    exit()

# Normalize features to [-1, 1], the range of the generator's Tanh output
normalizer = MinMaxScaler(feature_range=(-1, 1))
# NOTE: fitting the scaler on the full dataset before the split leaks test-set
# statistics; for a stricter evaluation, fit on the training portion only
normalized_features = normalizer.fit_transform(features)

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    normalized_features, labels, test_size=0.2, random_state=42
)

# Extract the positive (patient) samples from the training set
positive_samples = X_train[y_train == 1]
positive_labels = y_train[y_train == 1]
print(f"Positive samples in the training set: {len(positive_samples)}")

# =============================
# Conditional GAN definition
# =============================
class CGANGenerator(nn.Module):
    """Generator: maps (noise, condition) to a synthetic feature vector."""
    def __init__(self, noise_dim, feature_dim, label_dim):
        super(CGANGenerator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim + label_dim, 64),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(64),
            nn.Linear(64, 128),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(128),
            nn.Linear(128, feature_dim),
            nn.Tanh()  # outputs in [-1, 1], matching the MinMaxScaler range
        )

    def forward(self, noise, condition):
        combined_input = torch.cat([noise, condition], dim=1)
        return self.net(combined_input)

class CGANDiscriminator(nn.Module):
    """Discriminator: scores (sample, condition) pairs as real or fake."""
    def __init__(self, feature_dim, label_dim):
        super(CGANDiscriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_dim + label_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid()  # probability that the input is real
        )

    def forward(self, inputs, condition):
        combined_input = torch.cat([inputs, condition], dim=1)
        return self.net(combined_input)

# =============================
# CGAN training
# =============================
# Network hyperparameters
NOISE_DIM = 10
FEATURE_DIM = X_train.shape[1]
CONDITION_DIM = 1
TRAIN_EPOCHS = 1000
BATCH_SIZE = 32
LEARNING_RATE = 0.0002
BETA1 = 0.5

# Build the data loader over positive samples only
positive_dataset = TensorDataset(
    torch.FloatTensor(positive_samples),
    torch.FloatTensor(positive_labels).view(-1, 1)
)
positive_loader = DataLoader(positive_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Instantiate the networks
generator_net = CGANGenerator(NOISE_DIM, FEATURE_DIM, CONDITION_DIM).to(device)
discriminator_net = CGANDiscriminator(FEATURE_DIM, CONDITION_DIM).to(device)

# Loss function and optimizers
loss_func = nn.BCELoss()
gen_optim = optim.Adam(generator_net.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
dis_optim = optim.Adam(discriminator_net.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))

# Training-history records
gen_losses, dis_losses = [], []

for epoch in range(TRAIN_EPOCHS):
    epoch_gen_loss, epoch_dis_loss = 0, 0
    batch_count = 0
    for batch_data, batch_labels in positive_loader:
        real_data = batch_data.to(device)
        real_labels = batch_labels.to(device)
        current_batch_size = real_data.size(0)
        batch_count += 1

        # Target labels for the BCE loss
        valid_labels = torch.ones(current_batch_size, 1).to(device)
        fake_labels = torch.zeros(current_batch_size, 1).to(device)

        # ---------------------
        # Train the discriminator
        # ---------------------
        dis_optim.zero_grad()

        # Score real samples
        real_pred = discriminator_net(real_data, real_labels)
        dis_real_loss = loss_func(real_pred, valid_labels)

        # Generate synthetic samples
        noise_input = torch.randn(current_batch_size, NOISE_DIM).to(device)
        synthetic_data = generator_net(noise_input, real_labels)

        # Score synthetic samples (detached so no generator gradients flow here)
        fake_pred = discriminator_net(synthetic_data.detach(), real_labels)
        dis_fake_loss = loss_func(fake_pred, fake_labels)

        # Total discriminator loss
        dis_total_loss = (dis_real_loss + dis_fake_loss) / 2
        dis_total_loss.backward()
        dis_optim.step()
        epoch_dis_loss += dis_total_loss.item()

        # ---------------------
        # Train the generator
        # ---------------------
        gen_optim.zero_grad()

        # The generator wants the discriminator to score its samples as real
        validity = discriminator_net(synthetic_data, real_labels)
        gen_loss = loss_func(validity, valid_labels)
        gen_loss.backward()
        gen_optim.step()
        epoch_gen_loss += gen_loss.item()

    # Average losses for this epoch
    avg_gen_loss = epoch_gen_loss / batch_count
    avg_dis_loss = epoch_dis_loss / batch_count
    gen_losses.append(avg_gen_loss)
    dis_losses.append(avg_dis_loss)

    # Report progress periodically
    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{TRAIN_EPOCHS}], D loss: {avg_dis_loss:.4f}, G loss: {avg_gen_loss:.4f}")

print("CGAN training finished!")

# Plot the training curves
plt.figure(figsize=(10, 5))
plt.plot(gen_losses, label='Generator loss')
plt.plot(dis_losses, label='Discriminator loss')
plt.title('Training loss curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# =============================
# Generate synthetic data
# =============================
generator_net.eval()

# Generate as many synthetic samples as there are real positive samples
synth_count = len(positive_samples)
noise = torch.randn(synth_count, NOISE_DIM).to(device)
synth_conditions = torch.ones(synth_count, 1).to(device)

with torch.no_grad():
    generated_features = generator_net(noise, synth_conditions).cpu().numpy()

# Map the synthetic samples back to the original feature scale
generated_features = normalizer.inverse_transform(generated_features)
generated_labels = np.ones(synth_count)

# Map the training data back as well, for the distribution plots
original_train = normalizer.inverse_transform(X_train)

# =============================
# Distribution comparison plots
# =============================
plt.figure(figsize=(14, 10))
selected_features = ['age', 'resting_bp', 'cholesterol', 'max_hr']
# After get_dummies the numeric columns come first, in their original order:
# age=0, resting_bp=1, cholesterol=2, max_hr=3 (st_depression=4)
feature_indices = [0, 1, 2, 3]

for i, idx in enumerate(feature_indices):
    plt.subplot(2, 2, i + 1)
    sns.kdeplot(original_train[y_train == 1, idx], label='Real samples', color='blue')
    sns.kdeplot(generated_features[:, idx], label='Synthetic samples', color='orange')
    plt.title(f'Distribution of {selected_features[i]}')
    plt.xlabel('Feature value')
    plt.ylabel('Density')
    plt.legend()

plt.tight_layout()
plt.show()

# =============================
# Model evaluation
# =============================
# Train the baseline model (original, imbalanced training set)
base_model = RandomForestClassifier(random_state=42)
base_model.fit(X_train, y_train)
base_pred = base_model.predict(X_test)
base_probs = base_model.predict_proba(X_test)[:, 1]

# Train the augmented model (training set plus re-normalized GAN samples)
X_augmented = np.vstack([X_train, normalizer.transform(generated_features)])
y_augmented = np.hstack([y_train, generated_labels])
augmented_model = RandomForestClassifier(random_state=42)
augmented_model.fit(X_augmented, y_augmented)
aug_pred = augmented_model.predict(X_test)
aug_probs = augmented_model.predict_proba(X_test)[:, 1]

# Compare performance
base_f1 = f1_score(y_test, base_pred)
aug_f1 = f1_score(y_test, aug_pred)

print("\nModel performance comparison:")
print(f"Baseline model F1 score: {base_f1:.4f}")
print(f"Augmented model F1 score: {aug_f1:.4f}")

print("\nBaseline model classification report:")
print(classification_report(y_test, base_pred))

print("\nAugmented model classification report:")
print(classification_report(y_test, aug_pred))

# Compute evaluation metrics
base_cm = confusion_matrix(y_test, base_pred)
aug_cm = confusion_matrix(y_test, aug_pred)

base_fpr, base_tpr, _ = roc_curve(y_test, base_probs)
aug_fpr, aug_tpr, _ = roc_curve(y_test, aug_probs)
base_auc = auc(base_fpr, base_tpr)
aug_auc = auc(aug_fpr, aug_tpr)

# =============================
# Result visualization
# =============================
# Confusion-matrix comparison
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
sns.heatmap(base_cm, annot=True, fmt='d', cmap='Blues')
plt.title('Baseline model confusion matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')

plt.subplot(1, 2, 2)
sns.heatmap(aug_cm, annot=True, fmt='d', cmap='Greens')
plt.title('Augmented model confusion matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.tight_layout()
plt.show()

# ROC curve comparison
plt.figure(figsize=(8, 6))
plt.plot(base_fpr, base_tpr, 'b-', lw=2, label=f'Baseline model (AUC={base_auc:.2f})')
plt.plot(aug_fpr, aug_tpr, 'g-', lw=2, label=f'Augmented model (AUC={aug_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('ROC curve comparison')
plt.legend(loc="lower right")
plt.show()

# F1 score comparison
plt.figure(figsize=(8, 5))
plt.bar(['Baseline', 'Augmented'], [base_f1, aug_f1], color=['blue', 'green'])
plt.ylim(0, 1)
plt.title('Model performance comparison')
plt.ylabel('F1 score')
for i, score in enumerate([base_f1, aug_f1]):
    plt.text(i, score + 0.02, f'{score:.4f}', ha='center')
plt.show()
@浙大疏锦行