当前位置: 首页 > news >正文

联邦学习研读笔记

联邦学习的基本原理

1.1 分布式学习和集中式学习的对比

1)分布式学习:

定义:分布式学习是一种通过将数据和计算任务分散在多个参与方之间进行模型训练的方法。每个参与方都拥有本地数据,并在本地进行模型训练,然后将更新的模型参数进行聚合。

数据分布:每个参与方拥有自己的本地数据集,这些数据可以是私有的,不需要共享给其他参与方。

模型训练:每个参与方在本地使用自己的数据进行模型训练。模型训练可以采用各种机器学习算法和优化方法。

模型聚合:参与方将本地训练得到的模型参数传输给中央服务器或协调者,然后在中央服务器上进行模型聚合,通常采用加权平均或其他聚合算法来获得全局模型。

隐私保护:由于数据不需要共享,参与方可以在本地保持数据隐私,并采取差分隐私技术等方法来保护隐私。

2)集中式学习:

定义:集中式学习是一种通过集中收集和存储所有数据,并在中央服务器上进行模型训练的方法。数据集中存储在一个位置,并且所有模型的训练都在该位置进行。

数据集中:所有参与方的数据被收集到一个中央服务器或数据中心,并且数据需要共享给其他参与方或中央服务器。

模型训练:在中央服务器上进行模型训练,使用整个数据集进行模型参数的更新和优化。

隐私保护:由于数据需要集中存储和共享,可能存在隐私泄露的风险。必须采取适当的隐私保护措施,如数据加密、安全传输和访问控制等。

分布式学习和集中式学习在数据分布、模型训练和隐私保护方面存在差异。分布式学习允许参与方在本地训练模型,保护数据隐私,并通过模型聚合获得全局模型。而集中式学习需要集中存储和处理数据,并需要采取额外的隐私保护措施来保护数据隐私。选择使用哪种学习范式取决于具体应用场景和数据隐私要求。

1.2 隐私保护和数据安全性

1)隐私保护:

隐私保护是指在联邦学习过程中,确保参与方的个人数据不被泄露或暴露给其他参与方或第三方

隐私保护的目标是保护数据所有者的隐私权利,同时使其能够参与联邦学习任务,共同构建模型,而无需共享原始数据。
在联邦学习中,常用的隐私保护技术包括差分隐私和加密方法。差分隐私通过向模型训练过程中添加噪声来保护数据隐私,而加密方法使用密码学技术对数据进行加密,使得只有具有相应解密密钥的参与方能够解密数据。

2)数据安全性:

数据安全性是指在联邦学习中,确保数据在传输和存储过程中的安全性,以防止未经授权的访问、篡改或数据泄露。

数据安全性的保护涉及到传输通道的安全加密,防止中间人攻击和数据被窃取。同时,在数据存储过程中,采取适当的安全措施,如数据加密、访问控制和备份策略,以确保数据的完整性和可用性。


在联邦学习中,参与方之间的通信需要采用安全的通信协议和加密算法,确保数据在传输过程中不被窃取或篡改。此外,参与方需要采取相应的安全措施来保护其本地存储的数据,以防止未经授权的访问。

 隐私保护和数据安全性在联邦学习中起着关键的作用,它们确保了参与方的隐私权益和数据的安全性,为参与方提供了信心和动力来积极参与联邦学习任务,并促进了数据共享与模型构建的可持续发展。

1.3 本地更新和全局模型聚合

1)本地更新(Local Update):

在联邦学习中,参与方(例如设备、终端用户)在本地进行模型训练,使用其本地数据集进行迭代优化。每个参与方在本地独立地执行多个训练轮次,以提升其本地模型的性能。这意味着每个参与方可以根据自身的数据特点和隐私要求,采用特定的优化算法、超参数设置和训练策略来更新其本地模型。本地更新阶段通常使用的是传统的机器学习或深度学习算法,类似于集中式学习中的模型训练过程。

2)全局模型聚合(Global Model Aggregation):

在本地更新完成后,参与方需要将其本地模型的更新结果汇总到一个全局模型中。全局模型是联邦学习的核心组件,它代表了整个联邦学习系统的共享知识。全局模型的聚合通常通过一种协调机制来实现,例如联邦平均算法(Federated Averaging)。在全局模型聚合阶段,各参与方将本地模型的参数或梯度发送给中央服务器或协调方,并根据一定的规则进行模型参数的融合。常见的融合方法包括简单的平均、加权平均等。全局模型聚合的目标是将各个参与方的更新结果整合起来,形成一个更加全局性和综合性的模型,以达到更好的性能和泛化能力。

1.4 参与方角色和协作流程

中央服务器(Central Server):中央服务器负责协调和管理整个联邦学习过程。它通常负责模型的初始化、更新和聚合,并提供全局模型的指导方针。

参与方(Participating Parties):参与方是数据拥有者,如个人用户、组织或设备。每个参与方在本地维护着自己的数据集,并负责在本地进行模型训练和更新。

初始化阶段:在联邦学习开始之前,中央服务器初始化全局模型,并将其分发给参与方。

本地训练阶段:参与方在本地使用自己的数据集进行模型训练。他们可以选择不同的机器学习算法和优化策略,以适应自己的数据特点。

模型更新阶段:参与方将本地训练得到的模型参数发送给中央服务器,以便进行模型的更新。这可以通过梯度传输或模型参数的加密方式实现。

模型聚合:服务端对局部上传模型进行聚合。

2 基于Google Cloab的联邦学习案例实现

2.1 数据集介绍:

数据集包含20个受试者进行各种活动的记录。每个受试者使用2个加速度计和1个陀螺仪进行记录。该数据集描述了以下6种活动:

步行活动、上楼梯活动、下楼梯活动、坐姿活动、站立活动、躺卧活动。

训练集包含14个受试者(train.csv),其余用于测试(test.csv)。CSV文件的列描述如下:

列1:受试者编号

列2:加速度计1在x轴上的加速度值

列3:加速度计1在y轴上的加速度值

列4:加速度计1在z轴上的加速度值

列5:陀螺仪1在x轴上的角速度值

列6:陀螺仪1在y轴上的角速度值

列7:陀螺仪1在z轴上的角速度值

列8:加速度计2在x轴上的加速度值

列9:加速度计2在y轴上的加速度值

列10:加速度计2在z轴上的加速度值

列11:标签

2.2  综合指标提取:

将多个轴的数据合成为一个综合指标,用于表示物体的整体运动状态或特征。具体来说,对原始数据集中的加速度和陀螺仪数据进行了平方和开根号的操作,以获得合成的结果。

删除指定的列,在处理数据特征阶段,删除包含对于后续分析或模型建立来说不重要或不相关的列。

首先获取数据中所有不重复的主体(subject)标签,并进行存储,然后根据当前主体筛选出对应的数据,存储在临时变量tempdat中。接着从临时数据中删除主体和标签列,将剩余的特征数据转换为NumPy数组,存储在变量xtemp中(只保留了和特征相关的数据)。再将临时数据中的标签列进行独热编码(one-hot encoding将分类变量转换为机器学习算法可以处理的格式,常采用二进制的向量),将结果转换为NumPy数组,存储在变量ytemp中。最后将特征数据和标签数据按照样本一一对应的方式进行组合,函数返回一个字典finaldict,其中包含按主体分类的数据。每个键对应一个主体的名称,值为该主体下的特征数据和标签数据组合(可以通过主体的名称来访问该主体下的特征数据和标签数据)。

Subject唯一值:

2.3  数据集展示:

2.4 数据集划分:

根据"subject"列的特定值将数据集划分为测试集和训练集,其中测试集包含"subject"列值为 20、22、23、27 和 29 的行,而训练集则包含其他"subject"列值的行。

2.5  将训练数据进行批处理

将训练数据按照subject进行分组,其中每个subject对应一组数据。定义了一个名为batch_data的函数,用于将完整的数据集进行批处理。它接受一个完整的数据集,将数据和标签分别提取出来,并使用tf.data.Dataset.from_tensor_slices方法将它们转化为TensorFlow数据集对象。最后,使用dataset.batch(25)将数据集分批,每批包含25个样本。

通过循环遍历traindata中的每个subject和对应的数据,将每组数据使用batch_data函数进行批处理,并将处理后的数据保存在data_batched字典中。字典的键是主题,值是对应主题的批处理后的数据集。

将测试数据集进行预处理(删除对应列,独热编码、批处理),并将其转换为 TensorFlow 的数据集对象。

 定义一个简单的SimpleMLP网络:

定义计算权重缩放因子函数,即计算当前客户端的权重缩放因子,该因子用于在联邦学习中调整模型更新的权重。具体做法是

  1. 1)通过获取所有客户端的键(key)以及一个批次的样本数量;
  2. 2)计算全局训练数据的批次数:通过对所有客户端的数据进行计数,乘以批次大小,得到联邦学习中的全局训练数据的批次数。
  3. 3)计算当前客户端的数据点数量:对当前客户端的数据进行计数,乘以批次大小,得到当前客户端的数据点数量。
  4. 4)返回当前客户端的权重缩放因子:它是当前客户端数据点数量与全局训练数据批次数的比值,用于调整当前客户端的模型更新权重。

在联邦学习中,不同客户端的数据集可能具有不同的大小,因此需要通过权重缩放因子来平衡客户端的贡献,确保模型更新的公平性和一致性。

 处理模型权重的缩放和加权求和:对每个权重进行缩放,并将缩放后的权重存储在weight_final列表中,再对传入的缩放后的权重列表进行逐层的加权求和操作。对于每一层的权重,它将多个副本(来自不同客户端的缩放权重)按元组进行组合,然后使用tf.math.reduce_sum函数对每个元组进行求和操作。最终,函数返回加权平均后的权重列表avg_grad。

联邦学习中,通常对来自不同客户端的模型权重进行缩放和加权求和,以生成全局模型的更新。通过缩放权重和加权求和,可以确保每个客户端的贡献相对平衡,并生成具有整体代表性的全局模型更新。

 测试联邦学习模型的性能,使用训练好的模型对测试数据进行预测,得到预测结果pred,再计算预测结果pred与真实标签Y_test之间的分类交叉熵损失,并将结果赋给loss变量,利用accuracy_score函数计算预测结果pred和真实标签Y_test之间的分类准确率,并将结果赋给acc变量。

 2.6 具体实现流程:

  1. 创建全局模型:使用SimpleMLP(6,6)初始化一个全局模型,输入维度为6,输出维度为6。

  2. 开始全局训练循环:使用一个循环来执行多个通信轮次(communication round),类似于训练的多个epoch。

  3. 定义变量和初始权重:定义comms_round表示通信轮次的总数,subject_names表示客户端名称列表。然后,获取全局模型的初始权重。

  4. 随机化客户端数据:对客户端名称列表进行随机化,以模拟在真实环境中的客户端选择过程。

  5. 遍历每个客户端:对每个客户端进行以下操作:

          1)创建一个本地模型:使用SimpleMLP(6,6)初始化一个本地模型,与全局模型具有相同的架构。

          2)设置本地模型的权重:将本地模型的权重设置为全局模型的权重,以使本地模型与全局模型初始权重相同。

          3)使用客户端数据训练本地模型:使用客户端的数据(data_batched[sub])对本地模型进行训练,训练2个时期(epochs)。

          4)对模型权重进行缩放并添加到列表:通过计算权重缩放因子(scaling_factor),对本地模型的权重进行缩放,并将缩放后的权重添加到列表scaled_local_weight_list中。

          5)清理会话以释放内存:在每次通信轮次之后,使用K.clear_session()清理会话,释放内存资源。

    6.计算平均权重:通过将缩放后的权重进行累加,计算得到所有本地模型的平均权重average_weights。

      7.更新全局模型:将全局模型的权重更新为平均权重。

      8.测试全局模型并打印指标:使用测试数据集(test_batched)对全局模型进行测试,并在每次通信轮次之后打印出准确率和损失等指标。

 通过这个程序,联邦学习通过迭代地从客户端收集本地模型并在全局模型上进行聚合和更新,以实现在分布式环境中进行模型训练和评估的目的。每个客户端使用自己的本地数据训练本地模型,然后将缩放后的模型权重进行聚合以获得全局模型的更新。最终,可以使用全局模型在测试数据集上进行评估和打印指标。

训练过程:

全局模型性能在 70% 左右。

验证过程:

使用验证数据对联邦学习模型进行测试,并通过循环遍历批次数据来调用相应的测试函数:

 结果:

http://www.dtcms.com/a/122985.html

相关文章:

  • printf
  • 【NLP 面经 9、逐层分解Transformer】
  • 第十一章 Python语言-高阶技巧(终章)
  • Dubbo(44)如何排查Dubbo的服务依赖问题?
  • 17. git pull
  • 6、nRF52xx蓝牙学习(nrf_gpiote.c库函数学习)
  • 基于 AI智能体、大模型、RAG、Agent 等技术构建公司内部闭环智能问答系统的详细方案,结合 Spring Boot + Vue 管理系统 的改造思路
  • Http代理服务器选型与搭建
  • Starrocks的Bitmap索引和Bloom filter索引以及全局字典
  • 基于微信小程序的志愿服务系统的设计与实现
  • 数字图像处理作业3
  • fuse-python使用fuse来挂载fs
  • 汽车软件开发常用的建模工具汇总
  • Joomla 常用模块 - 在线用户与Joomla 常用模块 - 自定义HTML模块
  • [leetcode]判断质数
  • 关于C++日志库spdlog
  • JS 函数提升
  • 蓝桥杯十一届C++B组真题题解
  • 革新电销流程,数企云外呼开启便捷 “直通车”
  • 各种场景的ARP攻击描述笔记(超详细)
  • stream流Collectors.toMap(),key值重复问题
  • Bootstrap Table动态修改列标题
  • C++中命名空间namespace|头文件h文件|源文件cpp文件详解
  • pyecharts常用图形
  • Mysql索引(二)
  • 8.第二阶段x64游戏实战-string类
  • UE学习记录part15
  • ffpyplayer+Qt,制作一个视频播放器
  • 玩转Docker | 使用Docker安装FileDrop文件共享工具
  • 如何解【决泛型作为运行时参数】时类型擦除问题