当前位置: 首页 > news >正文

联邦学习论文分享:Federated Learning with GAN-based Data Synthesis for Non-IID Clients

摘要

  • 问题背景

    • 联邦学习允许多个客户端在不共享原始数据的情况下协同训练模型,但当各客户端的数据分布不一致(非IID)时,模型性能会受到影响。

  • 提出的方案

    • 作者提出了一种新框架,叫 Synthetic Data Aided Federated Learning(SDA-FL),通过共享合成数据来缓解非IID问题。

    • 每个客户端预训练一个本地的生成对抗网络(GAN),生成差分隐私的合成数据

    • 这些合成数据上传到参数服务器(PS),由服务器构建一个全局共享的合成数据集

  • 核心技术点

    • 服务器采用迭代伪标签机制(iterative pseudo labeling)为合成数据生成可靠的标签。

    • 客户端在训练时结合本地真实数据和带伪标签的合成数据,使得不同客户端的数据分布更接近,从而提升本地模型一致性,改善全局模型聚合效果。

  • 实验结果

    • 在多个基准数据集上,SDA-FL在监督和半监督学习场景下均显著优于现有基线方法。

引言

1. 背景与问题

  • 联邦学习(FL)允许客户端在不共享本地数据的情况下协同训练全局模型(FedAvg 是典型算法)。

  • 当客户端数据是 IID 时效果良好,但在非IID(数据分布偏斜)时:

    • 不同客户端学习到的模型差异大(本地模型不一致)

    • 全局模型聚合效果下降,性能显著降低。

2. 现有方法及局限

  • 基于模型正则化的方法

    • 通过全局模型或其他客户端的本地模型信息来约束本地模型

    • 限制:无法在极端非IID情况下取得显著提升。

  • 数据增强方法

    • 生成合成样本(例如混合真实样本)

    • 限制:缺乏隐私保护,存在数据泄露风险。

3. 作者提出的 SDA-FL 框架

  • 核心思想:通过共享差分隐私的合成数据解决非IID问题,而不是直接共享真实数据。

  • 流程概述

    1. 每个客户端预训练本地的 差分隐私 GAN 生成合成数据。

    2. 将合成数据上传到参数服务器(PS),构建全局合成数据集。

    3. 迭代伪标签机制

      • PS 根据收到的本地模型更新合成数据的伪标签

      • 随着训练轮次增加,伪标签的可信度提高

      • 有助于提升本地更新和全局聚合效果

  • 优点

    • 避免直接共享真实数据,保护隐私

    • 兼容多种现有联邦学习方法

    • 可用于监督学习和半监督学习场景

    • 不需要真实数据标签也能工作

4. 实验验证

  • 实验会验证:

    • SDA-FL 的性能提升

    • 隐私预算的影响

    • 各关键步骤(如伪标签机制)的有效性

相关工作

1. Non-IID Challenges in Federated Learning

  • 问题:非IID数据分布是联邦学习的核心障碍,它会导致本地模型差异过大,从而使聚合模型性能下降。

  • 已有解决思路

    1. 修改本地目标函数:利用全局模型或其他客户端的模型信息来缓解客户端漂移(client drift),但在严重非IID场景下效果有限。

    2. 调整模型结构:允许不同客户端使用不同的本地模型结构,以适应各自的数据分布。

    3. 优化参数服务器操作:例如改进模型聚合方式、客户端选择策略、客户端聚类或分类器校准等。

2. Data Augmentation and Privacy Preserving

  • 数据共享思路:近年来,一些研究尝试通过数据增强或数据共享来缓解非IID问题。

  • 代表性方法

    1. Mixup 技术:通过混合客户端的真实样本来生成新的全局数据集,帮助缓解非IID问题。但频繁数据交换可能带来隐私泄露风险。

    2. GAN 增强:在服务器上利用客户端上传的少量样本训练生成器,再下发给客户端使用。但上传真实样本违背隐私保护原则。

    3. FedDPGAN:让所有客户端在联邦学习框架下协作训练一个全局生成模型,以补充本地数据。但问题是:

      • GAN 训练需要频繁交换生成模型,通信开销很大

      • 同时存在被对抗攻击利用的风险

前提知识

1. Federated Learning (FL) 基础

  • 介绍了经典的 FedAvg 算法流程:

    1. 每轮训练,选取一部分客户端下载全局模型参数 wt​。

    2. 客户端用本地数据 Dk=(Xk,Yk) 通过 SGD 更新本地模型。

    3. 将更新后的本地模型上传到服务器。

    4. 服务器对模型进行加权聚合,得到新的全局模型。

  • 问题:在非IID场景下,本地模型差异过大,聚合模型性能显著下降。

  • 解决思路:用生成对抗网络(GAN)生成高质量的合成数据,在客户端和服务器之间共享,用于缓解数据分布不一致问题。

2. 差分隐私生成对抗网络(DP-GAN)

  • 为了稳定训练,采用 WGAN-GP(Wasserstein GAN with Gradient Penalty)

  • 在训练中引入 差分隐私 (DP),通过在判别器梯度更新时加入高斯噪声来保护隐私。

  • 公式解释了差分隐私的定义 (ϵ,δ),以及噪声方差与采样概率、批次数之间的关系。

  • 关键点:判别器加噪声 → 生成器通过反向传播间接接收噪声 → 保证生成器输出的合成数据也满足 DP。

3. 伪标签机制(Pseudo Labeling)

  • 由于半监督场景下标签稀缺,无法直接训练有条件的生成模型(如 ACGAN)。

  • 作者采用 伪标签策略

    • 给合成样本预测类别,如果某一类别的预测概率 最大且超过阈值 τ,则将其作为伪标签。

    • 否则不赋标签。

  • 这样能保证只有高置信度的合成数据才被赋予伪标签,用于更新模型。

  • 在整个联邦学习过程中,伪标签会随着本地模型的不断改进而迭代更新,逐步提升标签质量。

核心算法

1. 全局合成数据集构建 (Global Synthetic Dataset Construction)

  • 每个客户端先用本地数据预训练一个 GAN,生成差分隐私的合成数据。

  • 客户端上传合成数据到服务器(PS),由 PS 构建一个全局共享的合成数据集。

  • 伪标签生成

    • 与已有方法不同,SDA-FL 不是用其他客户端的模型,而是用本地模型给对应的合成数据打伪标签(因为它们训练于同一数据分布,匹配性更高)。

    • PS 在每一轮 FL 收到客户端模型后,更新全局合成数据的伪标签。随着本地模型变强,伪标签质量逐步提升。

2. 基于合成数据的模型训练 (Synthetic Data Aided Model Training)

  • 客户端训练

    • 客户端使用真实数据和带高置信度伪标签的合成数据进行联合训练。

    • 采用 Mixup 技术(将真实数据和合成数据线性插值),生成更均匀的数据分布。

    • 定义两类损失函数:

      1. Mixup 损失 ℓ1​:结合真实样本和合成样本。

      2. 真实数据损失 ℓ2​:避免在早期训练阶段伪标签不准带来的不稳定性。

    • 最终更新规则:

  • 服务器训练

    • 传统 FL 的 PS 仅做模型聚合,没有数据。

    • SDA-FL 的 PS 保存全局合成数据集,用它直接更新全局模型,进一步提升效果。

3. 模型训练与合成数据更新的交互 (Interplay between Model Training and Dataset Updating)

  • 每一轮中:

    • 合成数据帮助客户端训练出更好的本地模型;

    • 改进后的本地模型提升伪标签质量,从而增强全局合成数据集;

    • PS 基于更可靠的合成数据提升全局模型;

    • 下一轮客户端再受益于更强的全局模型。

  • 这种 循环增强机制 是 SDA-FL 成功的关键。

4. SDA-FL 与传统 FL 的对比 (SDA-FL vs. Traditional FL)

  • 传统 FL

    • 客户端仅依赖本地数据训练,容易因非IID导致模型不一致,全局性能下降。

    • PS 只做简单参数聚合。

  • SDA-FL

    • 客户端利用合成数据增强本地数据分布,缓解非IID问题。

    • PS 不仅聚合模型,还利用合成数据直接训练全局模型。

  • 总结:通过共享合成数据集 + 迭代伪标签机制,SDA-FL 有效克服了数据异质性,提升了全局模型性能。

实验

设置

1. 数据集 (Datasets)

  • 使用 四个常见基准数据集:MNIST、FashionMNIST、CIFAR-10、SVHN。

  • 采用 非IID划分方式

    • 每个客户端只分到少数类别的数据。

    • 数据先划分为单类别子集,再随机分配给客户端。

  • 超参数设置

    • 伪标签阈值 τ = 0.95,γ = 10.0,λ₂ = 1.0。

    • 客户端数 = 10(每轮全选)。

    • 每个客户端上传 4000 条合成数据给服务器。

  • GAN预训练

    • CIFAR-10:36,000次迭代

    • 其他数据集:18,000次迭代

  • 训练设置

    • 总轮次 200

    • PS 更新:CIFAR-10 10次迭代 / 其他数据集 50次迭代

    • 客户端本地更新:监督学习 E=90 / 半监督学习 E=40

    • 优化器:SGD,学习率 0.03

    • 批量大小:监督学习 B=64 / 半监督学习 B=80(含16有标签样本+64无标签样本)

2. 现实数据验证 (COVID-19 dataset)

  • 为了检验 SDA-FL 的实际应用价值,还在 真实 COVID-19 医学影像数据集 上测试。

  • 特点:Pneumonia 样本稀缺 → 只设置 6 个客户端,每个客户端分到 2 类数据。

  • GAN 本地训练:4,500 次迭代。

  • 本地模型更新:30次迭代;全局模型更新:10次迭代。

3. 对比方法 (Baselines)

  • 监督学习对比

    • FedAvg、FedProx、SCAFFOLD、Naivemix、FedMix

    • 在 COVID-19 数据集上额外比较 FedDPGAN(面向医疗的全局GAN方法)。

  • 半监督学习对比

    • SemiFL、Local Fixmatch、Local Mixup。

  • 为公平比较,调优了 FedProx 的正则化参数 μ 和 FedMix 的混合比例 λ。

4. 模型结构 (Models)

  • 简单 CNN:用于 MNIST 和 FashionMNIST。

  • ResNet18:用于 CIFAR-10、SVHN、COVID-19。

  • 生成器与判别器

    • 生成器:4层反卷积。

    • 判别器:4层卷积 + 1层全连接。

结果

1. 监督式联邦学习的表现 (Federated Supervised Learning)

  • 在不同类别数量的客户端实验中,SDA-FL 显著优于基线方法

  • 在 CIFAR-10 上,SDA-FL 比 Naivemix 和 FedMix 至少高 5%(客户端有 3 个类别时)。

  • 在 COVID-19 数据集上:

    • SDA-FL 比 FedDPGAN 高 1.68%。

    • 甚至比 IID 情况下的 FedAvg 还高 1.14%,说明 SDA-FL 对医学场景尤其有优势。

2. 半监督式联邦学习的表现 (Federated Semi-Supervised Learning)

  • 收敛更快、精度更高,比其他算法更稳定。

  • 在 FashionMNIST 上,比 Semi-FL 高将近 10%。

  • 在 CIFAR-10 上,基线方法无法收敛到可用模型(准确率 < 40%),而 SDA-FL 能收敛并显著提升准确率。

  • 原因:伪标签机制能为合成数据和无标签数据提供高质量标签。

3. 隐私预算与性能的权衡 (Privacy vs. Performance)

  • 用差分隐私控制参数 ε 时:

    • 更严格的隐私预算(ε=5)导致 FID 分数上升(图像质量下降)。

    • 进而引起 MNIST 降低约 0.61%、FashionMNIST 降低约 2.59% 的准确率,CIFAR-10 也有类似趋势。

  • 但即便在隐私约束下,SDA-FL 依然保持了领先性能。

4. 服务器更新与伪标签更新的有效性

  • 服务器更新 (PS 用合成数据训练全局模型)

    • 如果不在服务器更新 → FashionMNIST 上准确率下降近 3%。

    • 如果只用合成数据更新(无限次) → 准确率只有 66%,说明合成数据要与真实数据平衡使用。

  • 伪标签更新

    • 每轮更新伪标签能逐步提升模型精度(随着伪标签置信度增加)。

    • 额外通信开销极小,因为只传输标签而非样本。

5. 与 AC-WGAN-GP 的对比

  • 尝试在 GAN 训练时直接加上标签(AC-WGAN-GP),与 WGAN-GP、SDA-FL 对比:

    • WGAN-GP 的合成数据质量更高(FID 更低)。

    • 虽然 AC-WGAN-GP 可以生成带标签的样本,但 SDA-FL 的 伪标签机制效果更好,最终分类准确率高于 AC-WGAN-GP。

    • 说明 “高质量数据 + 伪标签机制” > “低质量数据 + 自带标签”

http://www.dtcms.com/a/395272.html

相关文章:

  • 绕过百度网盘无限制下载
  • 【自记】PyCharm 更换阿里云国内源教程
  • 【Axure原型分享】区间K线图
  • javascript之Es6八股文
  • npm和pnpm命令大全
  • kali下安装beef-xss报错-启动失败-简单详细
  • 政策法规下的LLM安全:合规之路
  • 《第21课——C typedef:从Java的“实名制”到C的“马甲生成器”——类型伪装术与代码整容的艺术》
  • 【每天一个知识点】什么是知识库?
  • 豆包·Seedream 4.0深度测评:4K多模态时代的图像创作革命(图文增强版)
  • [新启航]发动机喷管推进剂输送孔光学 3D 轮廓测量 - 激光频率梳 3D 轮廓技术
  • 深入理解 TCP 协议:三次握手与四次挥手的底层原理
  • PyTorch 神经网络工具箱
  • 机器学习-多因子线性回归
  • 国产化Excel开发组件Spire.XLS教程:Python 写入 Excel 文件,数据写入自动化实用指南
  • 08 - spring security基于jdbc的账号密码
  • 解决SSL证书导致源站IP被泄露的问题
  • Worst Western Hotel: 1靶场渗透
  • 电子电气架构 --- 软件开发与产品系统集成流程(上)
  • 运维安全08,日志检测和 tcpdump (抓包) 的介绍以及使用
  • DSC 归档配置相关
  • 彭博社-BloombergGPT金融大模型
  • GPT5 Codex简单快速上手
  • Linux配置白名单限制访问_ipset+iptables
  • 多元化通证经济模型:DAO的神经和血液
  • 高系分十六:web应用
  • 【LeetCode热题100(27/100)】合并两个有序链表
  • 嵌入式(SOC+FreeRTOS)汽车仪表盘接口参数安全:规范遵循与防护实践
  • Maven 完整教程
  • 数据驱动下的用户画像系统:从0到1的技术实战与避坑指南