深入浅出横向联邦学习、纵向联邦学习、联邦迁移学习
深入浅出解析横向联邦学习(Horizontal Federated Learning)、纵向联邦学习(Vertical Federated Learning)和联邦迁移学习(Federated Transfer Learning)
有多个机构(比如几家不同的银行,或者几家医院)都拥有一些数据,他们希望联合起来训练一个更强大的机器学习模型,但出于隐私保护或法规要求,他们不能直接把数据共享出来。联邦学习就是为了解决这个问题而提出的。它的核心思想是“数据不动模型动”,即数据保留在本地,参与方共同训练模型。
在这个大框架下,根据数据分布的不同,主要可以分为横向联邦学习和纵向联邦学习。而联邦迁移学习则是在此基础上,解决数据特征或样本不足的问题。
横向联邦学习 (Horizontal Federated Learning, HFL)
- 核心特点:特征相似,样本不同。
- 打个比方: 想象一下,有几所不同地区的小学,他们都想联合训练一个更好的学生成绩预测模型。这些学校记录的学生信息(特征)是相似的,比如都有学生的年龄、性别、各科成绩、出勤率等。但是,每个学校的学生群体(样本)是完全不同的。
- 数据分布: 参与方的数据库结构(特征)相同,但数据记录(样本)不同。就像将数据表格水平切分,每个参与方持有一部分行。
- 如何工作:
- 初始化: 一个中心服务器(协调者)初始化一个全局模型,并将模型分发给各个参与方。
- 本地训练: 每个参与方利用本地数据训练这个模型,得到各自的模型更新(比如梯度或模型参数)。
- 模型聚合: 参与方将加密后的模型更新发送给中心服务器。服务器对收集到的模型更新进行聚合(例如,通过加权平均,如经典的FedAvg算法),形成一个新的、更优的全局模型。
- 模型更新: 服务器将聚合后的新模型再分发给各个参与方。
- 迭代: 重复步骤2-4,直到模型收敛或达到预设的训练轮次。
- 典型应用场景:
- 智能手机输入法预测: 不同用户手机上的输入习惯数据特征相似(都是文字序列),但具体输入的内容(样本)不同。Google的Gboard输入法就是横向联邦学习的典型应用。
- 多家分行的零售业务预测: 不同分行的客户特征相似,但客户群体不同。
- 优势: 算法设计相对灵活,可扩展性强。
- 挑战: 需要处理数据异构性(不同参与方数据分布可能不完全一致)、设备异构性(不同设备计算能力和网络状况不同)以及通信效率问题。
纵向联邦学习 (Vertical Federated Learning, VFL)
- 核心特点:样本重叠,特征不同。
- 打个比方: 想象一下,在同一个城市,有一家银行和一家电商平台,它们都想针对同一批用户(样本重叠)联合训练一个更精准的信用风险评估模型。银行拥有用户的收支行为、信用评级等特征,而电商平台拥有这些用户的购买历史、浏览行为等特征。它们的用户群体有交集,但掌握的数据维度(特征)不同。
- 数据分布: 参与方的数据库记录(样本)有交集,但数据结构(特征)不同。就像将数据表格垂直切分,每个参与方持有一部分列。
- 如何工作:
- 数据对齐: 首先,在加密状态下找出参与方之间共有的用户样本。这通常通过隐私集合求交(Private Set Intersection, PSI)等技术实现。
- 协同训练:
- 假设有两个参与方A和B,以及一个可选的协调者C(在两方场景下,协调者可能不是必需的)。
- 在训练过程中,模型被拆分到不同的参与方。每一方基于自己拥有的特征计算中间结果(比如梯度的一部分)。
- 这些中间结果在加密状态下进行交换和聚合,以计算总的梯度和损失,从而更新模型参数。
- 这个过程需要复杂的加密技术(如同态加密、安全多方计算)来保护各方数据的隐私,确保任何一方都无法获取对方的原始特征数据,也无法推断出对方的中间计算结果。
- 模型共享: 最终,各方都能获得一个更强大的联合模型的部分,或者是一个完整的模型(取决于具体实现)。
- 典型应用场景:
- 跨机构金融风控: 银行和保险公司针对共同客户进行风险评估。
- 智慧医疗: 不同医院拥有同一病人的不同类型的医疗数据(如一家有影像数据,另一家有基因数据)。
- 精准营销: 零售商和广告平台针对共同用户进行用户画像和广告推荐。
- 优势: 能够利用不同来源的互补特征,构建更全面的模型。
- 挑战: 系统复杂度较高,需要解决数据对齐、加密计算、多方协调等问题。对参与方的数量和网络通信要求也较高。
联邦迁移学习 (Federated Transfer Learning, FTL)
- 核心特点:特征和样本都可能只有少量重叠,甚至没有重叠,但需要利用已有知识。
- 打个比方: 想象一下,一家位于A国的银行已经基于其本地数据训练了一个还不错的信用评分模型。现在,一家位于B国的新银行,其客户数据(样本)和A国银行完全不同,甚至记录的客户信息(特征)也不完全一样(比如由于法规不同,可收集的特征有差异)。B国银行希望利用A国银行已有的模型知识,在保护数据隐私的前提下,快速训练一个适用于B国本地情况的信用评分模型。
- 数据分布: 参与方之间的数据特征和样本ID重叠都很少,或者一个参与方有丰富的标签数据,而另一个参与方数据虽多但标签不足。
- 如何工作:
- FTL 将联邦学习和迁移学习结合起来。迁移学习的核心思想是将一个领域(源领域)学习到的知识应用到另一个相关但不同的领域(目标领域)。
- 在FTL中,通常会利用一个在源数据上预训练好的模型(或其一部分知识),在联邦学习的框架下,帮助目标领域的参与方训练模型,即使目标领域数据量较小或标签稀疏。
- 例如,可以将源领域模型的参数作为目标领域模型的初始化,或者在联邦学习过程中,共享和迁移一部分能够泛化的特征表示。
- 为了保护隐私,FTL 同样需要加密技术(如同态加密)来保护模型参数或中间结果的交换。
- 典型应用场景:
- 解决冷启动问题: 新业务或新地区缺乏足够数据时,可以借助其他相关业务或地区的模型知识。
- 利用无标签数据: 当一方有大量有标签数据,另一方只有无标签数据时,可以通过FTL进行联合建模。
- 跨领域知识迁移: 例如,将在图像识别领域学到的知识迁移到医疗影像分析。
- 与HFL/VFL的关系:
- FTL 可以看作是HFL或VFL在特定场景下的扩展。当HFL或VFL的参与方面临数据不足、特征不完全匹配或需要利用外部知识时,就可以引入迁移学习的机制,演变成FTL。
- 例如,在纵向联邦学习的场景中,如果一方的特征非常稀疏,可以借助另一方更丰富的特征信息进行知识迁移。
- 优势: 能够克服数据或标签不足的限制,提升模型在小样本或新领域的学习效果。
- 挑战: 需要找到合适的迁移策略,确保迁移的知识是有效的,并避免负迁移(即损害模型性能)。隐私保护机制的设计也更为复杂。
总结与比较:
特性 | 横向联邦学习 (HFL) | 纵向联邦学习 (VFL) | 联邦迁移学习 (FTL) |
---|---|---|---|
数据划分 | 特征相同,样本不同 (按行划分) | 样本相同,特征不同 (按列划分) | 特征和样本可能都只有少量或无重叠 |
核心思想 | 聚合不同样本上的模型更新 | 聚合不同特征下的模型信息,利用样本交集 | 在联邦框架下进行知识迁移,解决数据或标签不足问题 |
隐私技术 | 加密模型更新 (如梯度) | 隐私集合求交,加密中间计算结果 (如同态加密, MPC) | 加密技术 (如同态加密等),迁移知识表示 |
协调者 | 通常需要中心服务器进行模型聚合 | 可能需要协调者,两方场景下可去中心化 | 取决于具体的HFL或VFL基础架构,可能需要协调者 |
主要挑战 | 数据异构性,通信开销,设备异构性 | 数据对齐,加密计算复杂度,多方协调 | 找到有效的迁移策略,避免负迁移,复杂的隐私保护机制 |
适用场景 | 用户群体不同但业务特征相似的场景 (如不同银行的同类业务) | 用户群体有交集但各方掌握用户不同维度特征的场景 (如银行与电商合作) | 数据稀疏、标签不足、需要跨领域知识共享的场景 |
- 横向联邦学习: 大家做的事情一样(特征相似),但服务的人不一样(样本不同)。
- 纵向联邦学习: 大家服务的人有交集(样本重叠),但每个人做的事情不一样(特征不同)。
- 联邦迁移学习: 我这里数据不够或者不完全对口,能不能借鉴一下别人(或别的场景)已经学到的经验,同时大家的数据都不泄露。
好的,下面我将为横向联邦学习(HFL)、纵向联邦学习(VFL)和联邦迁移学习(FTL)分别提供 PyTorch 风格的伪代码案例。这些案例旨在帮助初学者理解核心思想,会简化一些复杂的加密和通信细节。
代码案例
- 伪代码: 这不是可以直接运行的完整代码,而是为了阐释核心逻辑。旨在帮助初学者理解核心思想,简化了一些复杂的加密和通信细节。实际应用要复杂得多,尤其是在安全和效率方面。
- 简化处理: 实际的联邦学习系统会涉及更复杂的通信协议、加密算法(如安全多方计算MPC、同态加密HE)、梯度压缩、参与方管理等。这里我们主要关注数据和模型的交互流程。
- PyTorch 风格: 代码会采用 PyTorch 的常用模式,如
torch.nn.Module
,torch.optim
等。
1. 横向联邦学习 (Horizontal Federated Learning - HFL) 伪代码
场景: 多个客户端(如手机、医院)拥有结构相同但样本不同的数据,共同训练一个模型。
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple# --- 定义全局模型 (所有客户端和服务器使用相同的模型结构) ---
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1) # 假设输入特征维度为10,输出为1def forward(self, x):return self.fc(x)# --- 模拟客户端 ---
class HFLClient:def __init__(self, client_id: int, local_data: List[Tuple[torch.Tensor, torch.Tensor]], learning_rate: float):self.client_id = client_idself.local_data = local_data # (特征, 标签) 列表self.model = SimpleModel() # 每个客户端拥有一个本地模型副本self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate)self.criterion = nn.MSELoss() # 假设是回归任务def set_global_model_weights(self, global_weights):"""从服务器同步全局模型权重"""self.model.load_state_dict(global_weights)def local_train(self, epochs: int):"""在本地数据上训练模型"""self.model.train()for epoch in range(epochs):for features, labels in self.local_data:self.optimizer.zero_grad()outputs = self.model(features)loss = self.criterion(outputs, labels)loss.backward()self.optimizer.step()print(f"客户端 {self.client_id} 本地训练完成.")return self.model.state_dict() # 返回训练后的本地模型权重# --- 模拟服务器 ---
class HFLServer:def __init__(self):self.global_model = SimpleModel()def aggregate_models(self, client_model_weights: List[dict], client_data_sizes: List[int]) -> dict:"""聚合来自客户端的模型权重 (例如,使用联邦平均 FedAvg)client_model_weights: 列表,每个元素是一个客户端的模型 state_dictclient_data_sizes: 列表,每个元素是对应客户端的数据量大小,用于加权平均"""total_data_size = sum(client_data_sizes)aggregated_weights = self.global_model.state_dict() # 初始化为当前全局模型# 清零聚合权重for key in aggregated_weights.keys():aggregated_weights[key] = torch.zeros_like(aggregated_weights[key])# 加权平均for i, weights in enumerate(client_model_weights):weight_factor = client_data_sizes[i] / total_data_sizefor key in weights.keys():aggregated_weights[key] += weights[key] * weight_factorself.global_model.load_state_dict(aggregated_weights)print("服务器:模型聚合完成。")return aggregated_weightsdef get_global_model_weights(self) -> dict:return self.global_model.state_dict()# --- HFL 伪代码执行流程 ---
if __name__ == "__main__":# 0. 初始化NUM_CLIENTS = 3LOCAL_EPOCHS = 5NUM_ROUNDS = 10 # 联邦学习的轮次LEARNING_RATE = 0.01# 1. 模拟数据和客户端# 假设每个客户端有100个样本,每个样本10个特征clients_data = [[(torch.randn(1, 10), torch.randn(1, 1)) for _ in range(100)] for _ in range(NUM_CLIENTS)]client_data_sizes = [len(data) for data in clients_data]clients = [HFLClient(client_id=i, local_data=clients_data[i], learning_rate=LEARNING_RATE) for i in range(NUM_CLIENTS)]server = HFLServer()# 2. 联邦学习迭代for round_num in range(NUM_ROUNDS):print(f"\n--- 联邦学习轮次 {round_num + 1}/{NUM_ROUNDS} ---")current_global_weights = server.get_global_model_weights()local_model_weights_list = []# 2.1. 分发模型并进行本地训练for client in clients:client.set_global_model_weights(current_global_weights) # 同步全局模型local_weights = client.local_train(LOCAL_EPOCHS)local_model_weights_list.append(local_weights)# 2.2. 聚合模型new_global_weights = server.aggregate_models(local_model_weights_list, client_data_sizes)# new_global_weights 会在下一轮开始时分发print("\n--- 横向联邦学习完成 ---")final_model = server.global_model# 可以在这里评估 final_model 的性能
2. 纵向联邦学习 (Vertical Federated Learning - VFL) 伪代码
场景: 两个或多个参与方拥有相同样本ID的不同特征。他们需要协同训练一个模型。
简化假设:
- 我们假设只有两个参与方 A 和 B。
- 模型结构被逻辑上划分为两部分,一部分处理A的特征,另一部分处理B的特征,然后结果结合起来预测。
- 加密和安全计算被高度简化:实际VFL中,梯度和中间结果的交换需要同态加密或安全多方计算等技术保护。这里我们仅展示概念。
- 通常会有一个协调者(或其中一方扮演协调者角色)来同步和聚合加密的梯度/损失。
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Tuple# --- 假设的整体模型结构 (逻辑上) ---
# Input_A (来自参与方A的特征) -> Model_Part_A -> Intermediate_A ---
# | --- (结合) ---> Prediction_Layer -> Output
# Input_B (来自参与方B的特征) -> Model_Part_B -> Intermediate_B ---# --- 参与方 A 的模型部分 ---
class ModelPartA(nn.Module):def __init__(self, input_dim_a: int, output_dim_a: int):super(ModelPartA, self).__init__()self.fc_a = nn.Linear(input_dim_a, output_dim_a)# self.output_dim_a = output_dim_a # 中间表示的维度def forward(self, x_a: torch.Tensor) -> torch.Tensor:return torch.relu(self.fc_a(x_a)) # 假设的中间表示# --- 参与方 B 的模型部分 ---
class ModelPartB(nn.Module):def __init__(self, input_dim_b: int, output_dim_b: int):super(ModelPartB, self).__init__()self.fc_b = nn.Linear(input_dim_b, output_dim_b)# self.output_dim_b = output_dim_bdef forward(self, x_b: torch.Tensor) -> torch.Tensor:return torch.relu(self.fc_b(x_b))# --- 顶层模型 (通常由协调者或其中一方持有,用于结合和预测) ---
class TopModel(nn.Module):def __init__(self, intermediate_dim_a: int, intermediate_dim_b: int, output_dim: int):super(TopModel, self).__init__()self.fc_top = nn.Linear(intermediate_dim_a + intermediate_dim_b, output_dim)def forward(self, intermediate_a: torch.Tensor, intermediate_b: torch.Tensor) -> torch.Tensor:combined_intermediate = torch.cat((intermediate_a, intermediate_b), dim=1)return self.fc_top(combined_intermediate)# --- 模拟参与方 A ---
class VFLPartyA:def __init__(self, data_a: Dict[str, torch.Tensor], input_dim_a: int, output_dim_a: int, learning_rate: float):self.data_a = data_a # 字典,key为样本ID,value为A的特征self.model_part_a = ModelPartA(input_dim_a, output_dim_a)self.optimizer_a = optim.SGD(self.model_part_a.parameters(), lr=learning_rate)# 在实际VFL中,A通常不会直接看到标签和计算完整的损失def forward_a(self, sample_ids: List[str]) -> Dict[str, torch.Tensor]:"""计算A部分的中间输出"""self.model_part_a.train() # 或者 eval() 取决于阶段intermediate_outputs_a = {}for sample_id in sample_ids:features_a = self.data_a[sample_id]intermediate_outputs_a[sample_id] = self.model_part_a(features_a)return intermediate_outputs_a # {sample_id: intermediate_tensor_a}def backward_a(self, gradients_on_intermediate_a: Dict[str, torch.Tensor], intermediate_outputs_a_for_grad: Dict[str, torch.Tensor]):"""根据从协调者(或B)处获得的关于A中间输出的梯度,来更新A的模型部分。gradients_on_intermediate_a: {sample_id: grad_tensor}intermediate_outputs_a_for_grad: {sample_id: intermediate_tensor_a} - 这些是前向传播时产生的输出,需要它们来反向传播"""self.optimizer_a.zero_grad()total_loss_surrogate = torch.tensor(0.0, requires_grad=True) # 代理损失for sample_id in gradients_on_intermediate_a.keys():# 实际中,这里需要复杂的加密和安全计算# 简化:直接使用梯度和中间输出进行反向传播intermediate_a = intermediate_outputs_a_for_grad[sample_id]grad = gradients_on_intermediate_a[sample_id]# .backward() 需要一个标量,或者对每个输出元素提供梯度# 这里我们假设 grad 是对应 intermediate_a 的梯度intermediate_a.backward(gradient=grad)self.optimizer_a.step()print("参与方 A:模型部分已更新。")# --- 模拟参与方 B (通常持有标签,并与协调者一起计算损失和梯度) ---
class VFLPartyB_and_Coordinator: # 简化,将B和协调者功能合并def __init__(self, data_b: Dict[str, torch.Tensor], labels: Dict[str, torch.Tensor],input_dim_b: int, output_dim_b: int,intermediate_dim_a: int, top_model_output_dim: int, learning_rate: float):self.data_b = data_b # key为样本ID,value为B的特征self.labels = labels # key为样本ID,value为标签self.model_part_b = ModelPartB(input_dim_b, output_dim_b)self.top_model = TopModel(intermediate_dim_a, output_dim_b, top_model_output_dim)self.optimizer_b = optim.SGD(self.model_part_b.parameters(), lr=learning_rate)self.optimizer_top = optim.SGD(self.top_model.parameters(), lr=learning_rate)self.criterion = nn.MSELoss() # 假设回归任务def forward_b_and_top(self, intermediate_outputs_a: Dict[str, torch.Tensor], sample_ids: List[str]) \-> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:"""计算B部分的中间输出,并结合A的中间输出通过顶层模型得到预测"""self.model_part_b.train()self.top_model.train()predictions = {}intermediate_outputs_b = {}final_outputs_for_loss = {}for sample_id in sample_ids:features_b = self.data_b[sample_id]inter_a = intermediate_outputs_a[sample_id]inter_b = self.model_part_b(features_b)intermediate_outputs_b[sample_id] = inter_b# detach inter_a because Party B should not compute gradients for Party A's model directly# Gradients for inter_a will be computed and sent back securely.final_pred = self.top_model(inter_a.detach().requires_grad_(), inter_b)predictions[sample_id] = final_predfinal_outputs_for_loss[sample_id] = final_predreturn predictions, intermediate_outputs_b, final_outputs_for_lossdef compute_loss_and_gradients(self, predictions: Dict[str, torch.Tensor], sample_ids: List[str]) \-> Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:"""计算损失,并为B和顶层模型计算梯度,同时计算传递给A的(加密)梯度"""self.optimizer_b.zero_grad()self.optimizer_top.zero_grad()total_loss = torch.tensor(0.0)# 存储传递给A的梯度 (对A中间输出的梯度)gradients_for_a_intermediate = {}# 存储传递给B的梯度 (对B中间输出的梯度)gradients_for_b_intermediate = {} # 实际上会直接用于B的优化器batch_loss = torch.tensor(0.0, requires_grad=True)intermediate_a_list_for_grad = []intermediate_b_list_for_grad = [] # for Party B's model_part_b update# 在VFL中,通常是逐样本或小批量处理,并安全地聚合梯度# 这里简化为在一个循环中累积损失processed_predictions = []processed_labels = []for sample_id in sample_ids:pred = predictions[sample_id] # This tensor was (inter_a.detach().requires_grad_(), inter_b) -> top_modellabel = self.labels[sample_id]processed_predictions.append(pred)processed_labels.append(label)# Stack for batch loss computationif not processed_predictions:return torch.tensor(0.0), {}, {}batch_predictions_tensor = torch.cat(processed_predictions)batch_labels_tensor = torch.cat(processed_labels)loss = self.criterion(batch_predictions_tensor, batch_labels_tensor)loss.backward() # 这会计算顶层模型参数的梯度,以及 inter_a.grad 和 inter_b.grad# 提取梯度# 注意:实际中这里的梯度交换需要加密!# .grad 属性是在执行 backward() 后填充的# 需要追溯到创建这些tensor的地方idx = 0for sample_id in sample_ids:# This part is tricky in pseudocode as requires_grad_() was on a detached tensor.# For simplicity, assume we can get grads w.r.t. inputs of top_model# A more accurate way would be to re-run parts of the forward pass or use hooks.# Conceptual: Get gradient w.r.t. inter_a that went into top_model# This is a simplification. Real VFL uses secure protocols.# We'll simulate by getting the .grad of the detached inter_a if PyTorch allows,# or more practically, by re-evaluating a small graph or using hooks.# For pseudocode, let's assume Party B can securely compute and send this.# Suppose top_model.fc_top.weight has shape [output_dim, intermediate_dim_a + intermediate_dim_b]# Grad w.r.t. inter_a would be related to loss's grad w.r.t. top_model's output,# then backpropagated through fc_top.# Simplification: We assume these gradients are somehow securely computed and made available.# Let's imagine `predictions[sample_id].grad_fn` allows access to input grads.# Or, more realistically, Party B calculates dL/d(inter_a) and dL/d(inter_b)# and sends dL/d(inter_a) (encrypted) to Party A.# Conceptual placeholder for secure gradient computation for A:# grad_for_a = compute_encrypted_gradient_for_A(loss_details, inter_a_used_in_top_model)# For pseudocode, let's assume we can derive it:# This is highly abstract:grad_for_a_placeholder = torch.randn_like(intermediate_outputs_a[sample_id]) # Placeholder!gradients_for_a_intermediate[sample_id] = grad_for_a_placeholderidx +=1self.optimizer_b.step()self.optimizer_top.step()print(f"参与方 B/协调者:损失计算完毕,模型B和Top已更新。梯度已准备好发往A。损失: {loss.item()}")return loss, gradients_for_a_intermediate# --- VFL 伪代码执行流程 ---
if __name__ == "__main__":# 0. 定义维度INPUT_DIM_A = 5INPUT_DIM_B = 7INTERMEDIATE_DIM_A = 10INTERMEDIATE_DIM_B = 12TOP_MODEL_OUTPUT_DIM = 1LEARNING_RATE = 0.01NUM_SAMPLES = 50NUM_ROUNDS = 10# 1. 模拟数据 (样本ID对齐)sample_ids = [f"sample_{i}" for i in range(NUM_SAMPLES)]data_a = {sid: torch.randn(1, INPUT_DIM_A) for sid in sample_ids}data_b = {sid: torch.randn(1, INPUT_DIM_B) for sid in sample_ids}labels = {sid: torch.randn(1, TOP_MODEL_OUTPUT_DIM) for sid in sample_ids}# 2. 初始化参与方party_a = VFLPartyA(data_a, INPUT_DIM_A, INTERMEDIATE_DIM_A, LEARNING_RATE)party_b_and_coordinator = VFLPartyB_and_Coordinator(data_b, labels, INPUT_DIM_B, INTERMEDIATE_DIM_B,INTERMEDIATE_DIM_A, TOP_MODEL_OUTPUT_DIM, LEARNING_RATE)# 3. VFL 迭代训练for round_num in range(NUM_ROUNDS):print(f"\n--- 纵向联邦学习轮次 {round_num + 1}/{NUM_ROUNDS} ---")current_sample_ids_batch = sample_ids # 在实际中可能是小批量# 3.1 参与方A进行前向传播intermediate_a_outputs = party_a.forward_a(current_sample_ids_batch)# 实际中: intermediate_a_outputs 会被加密发送给 B/协调者# 3.2 参与方B/协调者进行前向传播,计算损失和梯度# Party B receives (encrypted) intermediate_a_outputspredictions, intermediate_b_outputs, final_outputs_for_loss_dict = \party_b_and_coordinator.forward_b_and_top(intermediate_a_outputs, current_sample_ids_batch)# Party B/Coordinator computes loss and gradients for its parts and for A's intermediate output# The gradients for A's intermediate output (gradients_for_a) would be encrypted.loss, gradients_for_a = party_b_and_coordinator.compute_loss_and_gradients(final_outputs_for_loss_dict, current_sample_ids_batch)# 3.3 参与方A进行反向传播# Party A receives (encrypted) gradients_for_a and its own intermediate_a_outputs used in that forward passparty_a.backward_a(gradients_for_a, intermediate_a_outputs)# Note: intermediate_a_outputs are needed again for the backward pass in PyTorch# if they were not retained with requires_grad=True during Party A's forward pass.# For simplicity, we pass them again.print("\n--- 纵向联邦学习完成 ---")# 最终模型由 party_a.model_part_a, party_b_and_coordinator.model_part_b,# 和 party_b_and_coordinator.top_model 共同组成。
VFL 伪代码的关键点说明:
- 数据对齐:
sample_ids
的使用强调了纵向联邦中样本是对齐的。 - 模型拆分:
ModelPartA
,ModelPartB
,TopModel
体现了模型被逻辑拆分到不同参与方。 - 中间结果交换:
intermediate_a_outputs
从A传递到B(协调者)。 - 梯度交换:
gradients_for_a
从B(协调者)传递回A。 - 隐私保护(高度简化): 伪代码中没有实现加密,但实际操作中,所有交换的中间结果和梯度都必须加密。
gradients_for_a_placeholder
明确指出了这是一个需要安全计算的占位符。.detach().requires_grad_()
是一个尝试在概念上分离计算图但仍允许后续计算梯度的技巧,但实际的VFL梯度传递更为复杂。
3. 联邦迁移学习 (Federated Transfer Learning - FTL) 伪代码
场景: 存在一个预训练好的模型(源领域知识)。多个客户端(目标领域)拥有少量或特征不完全匹配的数据,他们希望利用预训练模型,在联邦学习的框架下进行微调。
简化假设:
- 我们采用类似横向联邦学习的架构。
- 源模型和目标模型的结构相似,或者目标模型使用了源模型的部分层。
- 迁移方式:微调预训练模型。
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple# --- 定义基础模型结构 (源模型和目标模型可以共享此结构或部分结构) ---
class BaseModel(nn.Module):def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):super(BaseModel, self).__init__()self.feature_extractor = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),# nn.Linear(hidden_dim, hidden_dim), # 更多层# nn.ReLU())self.classifier = nn.Linear(hidden_dim, output_dim)def forward(self, x):features = self.feature_extractor(x)output = self.classifier(features)return output# --- 模拟FTL客户端 ---
class FTLClient:def __init__(self, client_id: int, local_data: List[Tuple[torch.Tensor, torch.Tensor]],base_model_weights: dict, learning_rate: float,input_dim: int, hidden_dim: int, output_dim: int):self.client_id = client_idself.local_data = local_data # (特征, 标签)self.model = BaseModel(input_dim, hidden_dim, output_dim)# 1. 加载预训练权重 (迁移学习的关键步骤)self.model.load_state_dict(base_model_weights, strict=False) # strict=False 允许部分加载print(f"客户端 {self.client_id}: 已加载预训练模型权重。")# 2. (可选) 冻结部分层 - 例如,只微调分类器# for param in self.model.feature_extractor.parameters():# param.requires_grad = False# print(f"客户端 {self.client_id}: 特征提取层已冻结。")self.optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=learning_rate)self.criterion = nn.CrossEntropyLoss() # 假设是分类任务def set_global_model_weights(self, global_weights: dict):"""从服务器同步全局模型权重 (在联邦微调过程中)"""self.model.load_state_dict(global_weights) # 严格加载,因为结构应该匹配def local_finetune(self, epochs: int) -> dict:"""在本地数据上微调模型"""self.model.train()for epoch in range(epochs):for features, labels in self.local_data:self.optimizer.zero_grad()outputs = self.model(features)loss = self.criterion(outputs, labels.long().squeeze()) # CrossEntropyLoss期望long类型的标签loss.backward()self.optimizer.step()print(f"客户端 {self.client_id} 本地微调完成.")return self.model.state_dict() # 返回微调后的本地模型权重# --- 模拟FTL服务器 (与HFL服务器类似,但初始模型是预训练的) ---
class FTLServer:def __init__(self, pretrained_global_model_weights: dict):# 服务器持有的全局模型,初始状态为预训练模型self.global_model = BaseModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM) # 确保维度一致self.global_model.load_state_dict(pretrained_global_model_weights)print("服务器:已初始化全局模型为预训练模型。")def aggregate_models(self, client_model_weights: List[dict], client_data_sizes: List[int]) -> dict:"""聚合来自客户端的模型权重 (FedAvg)"""total_data_size = sum(client_data_sizes)aggregated_weights = self.global_model.state_dict()for key in aggregated_weights.keys():# 只聚合那些参与训练的参数 (例如,如果部分层被冻结,它们不应改变)# 但通常在联邦平均中,我们还是聚合所有参数,本地冻结是客户端策略aggregated_weights[key] = torch.zeros_like(aggregated_weights[key])for i, weights in enumerate(client_model_weights):weight_factor = client_data_sizes[i] / total_data_sizefor key in weights.keys():if key in aggregated_weights: # 确保key存在aggregated_weights[key] += weights[key] * weight_factorself.global_model.load_state_dict(aggregated_weights)print("服务器:模型聚合完成。")return aggregated_weightsdef get_global_model_weights(self) -> dict:return self.global_model.state_dict()# --- FTL 伪代码执行流程 ---
if __name__ == "__main__":# 0. 定义参数和模型维度INPUT_DIM = 20HIDDEN_DIM = 50OUTPUT_DIM = 5 # 假设目标任务有5个类别NUM_CLIENTS = 2LOCAL_EPOCHS_FTL = 3NUM_ROUNDS_FTL = 5LEARNING_RATE_FTL = 0.001# 1. 模拟一个预训练好的源模型 (通常在大的通用数据集上训练得到)# 假设这是我们从别处加载的预训练模型权重print("正在加载/模拟预训练模型...")source_model_for_pretraining = BaseModel(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM) # 结构可能与目标任务输出维度不同,或特征维度不同# 实际中,这里的权重是训练好的# 为了伪代码,我们只取其初始权重作为“预训练”# 或者,如果特征维度/输出维度不同,需要更复杂的迁移策略(如只迁移特征提取器)# 为简单起见,假设源和目标任务的BaseModel结构兼容进行权重加载pretrained_weights = source_model_for_pretraining.state_dict()print("预训练模型权重已准备好。")# 2. 模拟目标领域的客户端数据 (可能数据量较小)# 假设每个客户端有20个样本,每个样本特征维度为INPUT_DIMclients_target_data = [[(torch.randn(1, INPUT_DIM), torch.randint(0, OUTPUT_DIM, (1,))) for _ in range(20)]for _ in range(NUM_CLIENTS)]client_target_data_sizes = [len(data) for data in clients_target_data]# 3. 初始化服务器和客户端ftl_server = FTLServer(pretrained_global_model_weights=pretrained_weights)ftl_clients = [FTLClient(client_id=i,local_data=clients_target_data[i],base_model_weights=ftl_server.get_global_model_weights(), # 初始时,客户端也从服务器获取预训练模型learning_rate=LEARNING_RATE_FTL,input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM) for i in range(NUM_CLIENTS)]# 4. 联邦迁移学习迭代 (微调过程)for round_num in range(NUM_ROUNDS_FTL):print(f"\n--- 联邦迁移学习轮次 {round_num + 1}/{NUM_ROUNDS_FTL} ---")current_global_weights = ftl_server.get_global_model_weights()local_model_weights_list = []# 4.1 分发全局模型并进行本地微调for client in ftl_clients:client.set_global_model_weights(current_global_weights) # 同步聚合后的模型local_weights = client.local_finetune(LOCAL_EPOCHS_FTL)local_model_weights_list.append(local_weights)# 4.2 聚合模型new_global_weights = ftl_server.aggregate_models(local_model_weights_list, client_target_data_sizes)print("\n--- 联邦迁移学习完成 ---")final_ftl_model = ftl_server.global_model# 可以在这里评估 final_ftl_model 在目标任务上的性能
FTL 伪代码的关键点说明:
- 预训练模型:
pretrained_weights
代表了从源领域获取的知识。 - 权重加载: 客户端在初始化时加载预训练权重 (
self.model.load_state_dict(base_model_weights, strict=False)
)。strict=False
允许在源模型和目标模型结构不完全一致时加载匹配的部分(例如,只加载特征提取器)。 - 本地微调:
local_finetune
函数执行在本地小数据集上的训练。 - 选择性冻结: 注释中提到了可以冻结预训练模型的部分层,只微调顶层,这是迁移学习中常用的技巧。
- 联邦平均: 服务器端的聚合过程与HFL类似,但聚合的是微调后的模型。