python打卡day53@浙大疏锦行
知识点回顾:
- 对抗生成网络的思想:关注损失从何而来
- 生成器、判别器
- nn.sequential容器:适合于按顺序运算的情况,简化前向传播写法
- leakyReLU介绍:避免relu的神经元失活现象
ps;如果你学有余力,对于gan的损失函数的理解,建议去找找视频看看,如果只是用,没必要学
作业:对于心脏病数据集,对于病人这个不平衡的样本用GAN来学习并生成病人样本,观察不用GAN和用GAN的F1分数差异。
一、数据预处理(修改 src/data/preprocessing.py )
def split_minority_class(data_df):# 提取少数类(病人样本)minority = data_df[data_df.target == 1]return minority.drop('target', axis=1).values
二、GAN网络定义(新增 src/models/gan.py )
class Generator(nn.Sequential):def __init__(self, input_dim, output_dim):super().__init__(nn.Linear(input_dim, 128),nn.LeakyReLU(0.2),nn.Linear(128, 256),nn.LeakyReLU(0.2),nn.Linear(256, output_dim),nn.Tanh())class Discriminator(nn.Sequential):def __init__(self, input_dim):super().__init__(nn.Linear(input_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 128),nn.LeakyReLU(0.2),nn.Linear(128, 1),nn.Sigmoid())
三、训练流程(修改 src/models/train.py )
# GAN训练循环
for epoch in range(epochs):for real_data in minority_loader:# 生成假数据z = torch.randn(batch_size, latent_dim)fake_data = generator(z)# 判别器训练d_loss_real = criterion(discriminator(real_data), real_labels)d_loss_fake = criterion(discriminator(fake_data.detach()), fake_labels)d_loss = (d_loss_real + d_loss_fake) / 2# 生成器训练g_loss = criterion(discriminator(fake_data), real_labels)
四、评估对比(新增 src/visualization/evaluate.py )
def compare_f1(original_f1, gan_f1):plt.figure(figsize=(10,6))plt.bar(['Original', 'GAN Augmented'], [original_f1, gan_f1])plt.title('F1 Score Comparison')plt.savefig('reports/figures/f1_comparison.png')
执行流程
1.安装依赖
pip install imbalanced-learn
2.训练GAN生成样本
3.分别训练基线模型和增强模型
4.生成对比报告