当前位置: 首页 > news >正文

利用“Flower”实现联邦机器学习的实战指南

一个很尴尬的现状就是我们用于训练 AI 模型的数据快要用完了。所以我们在大量的使用合成数据!

据估计,目前公开可用的高质量训练标记大约有 40 万亿到 90 万亿个,其中流行的 FineWeb 数据集包含 15 万亿个标记,仅限于英语。

作为参考,最近发布的 Llama 4 在文本、图像和视频数据集上进行了预训练,使用的标记数量超过 30 万亿个,是 Llama 3 的两倍多。

这让我们意识到,我们距离训练数据达到极限可能只有几年的时间了。

但那真的是极限吗?私人数据集呢?

这些数据集的规模可能是公开数据集的 10 到 20 倍(甚至更多),所有存储的消息中大约有 650 万亿个标记,电子邮件中大约有 1200 万亿个标记。

令人惊讶的是,许多公司收集的大量数据从未被分析过,因此被称为暗数据(Dark data)。

再想想政府机构、医院、律师事务所、金融机构、用户设备等存储的数据。

我同意这些数据是敏感的,而且有严格的数据保护法规来规范其处理方式。

其中大部分数据可能确实不适合用于训练机器学习模型,但肯定有一部分数据可以为人类和组织带来巨大价值。

如果有一种方法可以在不共享数据本身的情况下,使用多个组织的敏感合规数据来训练机器学习模型,那该多好啊!

这就是联邦机器学习(Federated Machine Learning)的用武之地!

接下来,我们将深入探讨联邦机器学习是什么,它是如何工作的,然后编写一个联邦学习流程,使用多个医疗机构的数据安全地训练一个可以检测眼部疾病的机器学习模型。

让我们开始吧!

但是,联邦机器学习到底是什么?

为了理解联邦机器学习是什么,我们先来看看传统的机器学习模型训练方法。

举个例子,我们想训练一个可以检测 CT 扫描图像中癌症的机器学习模型。

第一步是收集来自不同地理位置的多家医院的正常和癌症患者的 CT 扫描图像。

None

选择多样化数据源的原因是:

  • 增加样本量;
  • 减少由于不同因素(包括人口统计、专家和机构因素)导致的偏差。

这使得我们的模型即使对于训练数据集中未充分代表的群体也能具有泛化能力。

一旦这些数据被收集到一个中央的强大服务器上,我们就可以使用这些数据来训练模型并对其进行评估。

None

你能发现这种方法有哪些问题,使得执行起来几乎不可能吗?

首先,敏感的医疗数据受到法律(如GDPRHIPAA)的严格监管,这使得将这些数据传输到中央服务器变得非常困难。

其次,中央服务器必须有足够的计算和存储资源来处理这些数据和训练,这使得这种方法非常昂贵。

如果我们反过来,不是把数据移动到训练中,而是把训练移动到数据那里呢?

这就是联邦机器学习做的事情。

联邦机器学习是一种机器学习技术,多个组织可以在去中心化的方式下协作训练机器学习模型,而无需共享他们的数据集。

以下是使用这种方法的步骤:

  1. 在中央服务器上初始化一个基础/全局模型。

None

  1. 将该模型的参数发送到参与组织的服务器(称为客户端/节点),这些服务器包含本地数据。

None

  1. 每个客户端在其本地数据上训练模型一段时间(不是直到模型收敛,而是进行几步/一到几个周期)。

None

  1. 在本地训练完成后,每个客户端将其模型参数或累积的梯度发送回中央服务器。

None

  1. 由于每个客户端的参数因在不同的本地数据集上训练而与其他客户端不同,因此需要通过一个称为聚合的过程将它们结合起来。聚合的结果用于更新基础/全局模型的参数。

可以使用多种技术进行聚合,其中一种流行的方法是联邦平均(Federated Averaging)。

更新后的全局模型参数 = ∑ i = 1 N 客户端  i 的更新参数 × 客户端  i 的数据量 ∑ i = 1 N 客户端  i 的数据量 \text{更新后的全局模型参数} = \frac{\sum_{i=1}^{N} \text{客户端 } i \text{ 的更新参数} \times \text{客户端 } i \text{ 的数据量}}{\sum_{i=1}^{N} \text{客户端 } i \text{ 的数据量}} 更新后的全局模型参数=i=1N客户端 i 的数据量i=1N客户端 i 的更新参数×客户端 i 的数据量

在这种方法中,不同客户端的更新会被平均,并根据每个客户端用于训练的数据点数量进行加权

None

  1. 更新后的基础模型参数被发送回客户端,然后重复上述训练过程,直到获得一个完全训练好的模型。

None

你有没有注意到联邦机器学习带来的优势?

首先,数据保留在其生成的地方,从未被传输到一个中央位置,这使得这种方法是去中心化的。

其次,减少了对单一强大基础设施的需求,因为计算是在所有参与服务器之间共享的。

最后,我们有一种称为差分隐私(Differential Privacy)的技术,可以保护客户端数据的隐私。

这是一种技术,通过它,无法从联邦学习过程中共享的模型更新中识别出关于单个数据点的敏感信息。

为了实现差分隐私,使用了两种过程:

  • 裁剪客户端模型更新,以限制单个数据点的影响;
  • 加噪,即在裁剪后的更新中添加校准后的噪声。

根据这些过程发生的位置,我们有:

  • 中央差分隐私:中央服务器在全局参数上进行加噪,这些全局参数是通过接收客户端的裁剪更新进行聚合的,或者是由中央服务器进行裁剪的。

None

  • 本地差分隐私:每个客户端在将模型更新发送到中央服务器之前,本地应用裁剪和加噪。

None

训练你的第一个联邦学习机器学习模型

现在你已经了解了联邦机器学习的基础知识,是时候动手实践并编写一些代码了。

视网膜疾病影响着全球数亿人,是导致视力丧失和失明的主要原因之一。

**光学相干断层扫描(OCT)**可以为我们提供视网膜及其他眼层的详细横截面图像。

利用这些图像,我们的目标是训练一个机器学习模型,能够区分健康视网膜和受疾病影响的视网膜。

本教程中的所有代码都是使用 PyTorch 框架在 Jupyter 笔记本中编写的。

下载并探索数据集

我们将使用的数据集名为 OCTMNIST。

它是 MedMNIST 数据集的一个子集,包含 109,309 张大小为 28 × 28 像素的灰度、居中裁剪的视网膜 OCT 图像。

OCTMNIST 是一个多分类数据集,包含以下类别/标签:

  1. 脉络膜新生血管(CNV)
  2. 糖尿病性黄斑水肿(DME)
  3. 玻璃膜疣(Drusen)
  4. 正常

我们先在 Jupyter 笔记本中安装 medmnist 包,并获取有关 OCTMNIST 数据集的一些信息。

!uv pip install medmnist
from medmnist import INFO# 获取 OCTMNIST 数据集信息
info = INFO["octmnist"]print("数据集类型: ", info["task"])
print("数据集标签: ", info["label"])
print("图像通道数: ", info["n_channels"])
print("训练样本数量: ", info["n_samples"]["train"])
print("验证样本数量: ", info["n_samples"]["val"])
print("测试样本数量: ", info["n_samples"]["test"])

输出结果如下:

数据集类型:  多分类
数据集标签:  {'0': '脉络膜新生血管','1': '糖尿病性黄斑水肿', '2': '玻璃膜疣', '3': '正常'}
图像通道数:  1
训练样本数量:  97477
验证样本数量:  10832
测试样本数量:  1000

接下来,我们下载这个数据集,对其应用转换,并绘制其中的一些图像。

!uv pip install torch torchvision
import torch# 如果可用,使用 GPU
if torch.backends.mps.is_available():device = torch.device("mps")
elif torch.cuda.is_available():device = torch.device("cuda")
else:device = torch.device("cpu")print(f"使用设备: {device}")
from torchvision import transforms
from medmnist import OCTMNIST# 定义转换
transform = transforms.ToTensor()# 下载数据集,大小为 64 x 64
train_dataset = OCTMNIST(split='train', transform=transform, download=True, size=64)
val_dataset = OCTMNIST(split='val', transform=transform, download=True, size=64)
test_dataset = OCTMNIST(split='test', transform=transform, download=True, size=64)
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 定义标签映射
label_map = {0: '脉络膜新生血管',1: '糖尿病性黄斑水肿',2: '玻璃膜疣',3: '正常'
}# 从数据加载器中获取一批数据
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
images, labels = next(iter(train_loader))# 在 3 x 3 网格中绘制
rows, cols = 3, 3
fig, axes = plt.subplots(rows, cols, figsize=(5, 5))for i in range(rows * cols):ax = axes[i // cols, i % cols]ax.imshow(images[i][0], cmap='gray')ax.set_title(label_map[int(labels[i].item())], fontsize=6)ax.axis('off')plt.tight_layout()
plt.show()

None

将数据集拆分为子集

现实世界中的医疗数据通常存在类别不平衡和偏差。

为了模拟这种情况,我们将 OCTMNIST 数据集拆分为三个子集。

可以将这些子集视为属于三家不同医院的数据集,每个子集都排除了一种眼部疾病标签。

from torch.utils.data import Subset# 创建子数据集
def create_sub_datasets(full_dataset):targets = torch.tensor([label.item() for _, label in full_dataset])mask_A = (targets == 0) | (targets == 2) | (targets == 3)mask_B = (targets == 0) | (targets == 1) | (targets == 3)mask_C = (targets == 1) | (targets == 2) | (targets == 3)indices_A = mask_A.nonzero(as_tuple=True)[0]indices_B = mask_B.nonzero(as_tuple=True)[0]indices_C = mask_C.nonzero(as_tuple=True)[0]dataset_A = Subset(train_dataset, indices_A)  # 包含:CNV、DRUSEN、NORMAL(排除 DME)dataset_B = Subset(train_dataset, indices_B)  # 包含:CNV、DME、NORMAL(排除 DRUSEN)dataset_C = Subset(train_dataset, indices_C)  # 包含:DME、DRUSEN、NORMAL(排除 CNV)return [dataset_A, dataset_B, dataset_C]dataset_A, dataset_B, dataset_C = create_sub_datasets(train_dataset)

接下来,我们定义一个 ResNet-18 模型,用于将图像分类到相应的类别中。

from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn# ResNet-18
def get_resnet_model(num_classes=4):model = resnet18(weights=ResNet18_Weights.DEFAULT)# 修改第一层卷积层以接受 1 通道输入model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)# 替换最终的全连接层model.fc = nn.Linear(model.fc.in_features, num_classes)return model.to(device)
训练与评估

为了模拟在每个医院使用本地数据进行训练,我们在之前定义的子数据集上分别训练三个 ResNet 模型。

以下是训练和评估的函数。

!uv pip install tqdm
import torch.optim as optim
from tqdm import tqdm# 训练函数
def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=10):for epoch in range(epochs):model.train()running_correct, running_total = 0, 0loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", leave=False)for images, labels in loop:images = images.to(device)labels = labels.squeeze().long().to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()preds = torch.argmax(outputs, dim=1)running_correct += (preds == labels).sum().item()running_total += labels.size(0)loop.set_postfix(loss=loss.item(), acc=running_correct / running_total)train_acc = running_correct / running_totalval_acc = evaluate_model(model, val_loader)print(f"Epoch [{epoch+1}/{epochs}]  Train Acc: {train_acc:.4f}  Val Acc: {val_acc:.4f}")# 评估函数
def evaluate_model(model, test_loader):model.eval()correct, total = 0, 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.squeeze().to(device)outputs = model(images)preds = torch.argmax(outputs, dim=1)correct += (preds == labels).sum().item()total += labels.size(0)return correct / total
# 在子数据集上训练的函数
def train_on_subset(subset_dataset, val_loader, epochs=10):loader = DataLoader(subset_dataset, batch_size=64, shuffle=True)model = get_resnet_model()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)train_model(model, criterion, optimizer, loader, val_loader, epochs)return model
# 在子数据集上评估的函数
def evaluate_on_test(model, test_loader):model.eval()all_preds = []all_labels = []with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.squeeze().to(device)outputs = model(images)preds = torch.argmax(outputs, dim=1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())acc = sum([p == t for p, t in zip(all_preds, all_labels)]) / len(all_labels)return acc, all_preds, all_labels

是时候训练这些模型了!

# 在子数据集上训练模型
val_loader = DataLoader(val_dataset, batch_size=64)model_A = train_on_subset(dataset_A, val_loader)
model_B = train_on_subset(dataset_B, val_loader)
model_C = train_on_subset(dataset_C, val_loader)

训练过程的输出如下:

Epoch [1/10]  Train Acc: 0.9162  Val Acc: 0.8073
Epoch [2/10]  Train Acc: 0.9454  Val Acc: 0.8477
Epoch [3/10]  Train Acc: 0.9526  Val Acc: 0.8588
Epoch [4/10]  Train Acc: 0.9587  Val Acc: 0.8509
Epoch [5/10]  Train Acc: 0.9597  Val Acc: 0.8574
Epoch [6/10]  Train Acc: 0.9671  Val Acc: 0.8619
Epoch [7/10]  Train Acc: 0.9700  Val Acc: 0.8629
Epoch [8/10]  Train Acc: 0.9747  Val Acc: 0.8623
Epoch [9/10]  Train Acc: 0.9774  Val Acc: 0.8541
Epoch [10/10]  Train Acc: 0.9787  Val Acc: 0.8647
Epoch [1/10]  Train Acc: 0.9466  Val Acc: 0.8498
Epoch [2/10]  Train Acc: 0.9725  Val Acc: 0.8988
Epoch [3/10]  Train Acc: 0.9780  Val Acc: 0.8967
Epoch [4/10]  Train Acc: 0.9816  Val Acc: 0.9027
Epoch [5/10]  Train Acc: 0.9841  Val Acc: 0.9031
Epoch [6/10]  Train Acc: 0.9854  Val Acc: 0.8917
Epoch [7/10]  Train Acc: 0.9881  Val Acc: 0.9060
Epoch [8/10]  Train Acc: 0.9899  Val Acc: 0.9060
Epoch [9/10]  Train Acc: 0.9911  Val Acc: 0.9053
Epoch [10/10]  Train Acc: 0.9930  Val Acc: 0.9005
Epoch [1/10]  Train Acc: 0.9001  Val Acc: 0.6071
Epoch [2/10]  Train Acc: 0.9429  Val Acc: 0.6188
Epoch [3/10]  Train Acc: 0.9531  Val Acc: 0.6117
Epoch [4/10]  Train Acc: 0.9509  Val Acc: 0.6280
Epoch [5/10]  Train Acc: 0.9610  Val Acc: 0.6289
Epoch [6/10]  Train Acc: 0.9649  Val Acc: 0.6283
Epoch [7/10]  Train Acc: 0.9675  Val Acc: 0.6265
Epoch [8/10]  Train Acc: 0.9699  Val Acc: 0.6321
Epoch [9/10]  Train Acc: 0.9750  Val Acc: 0.6330
Epoch [10/10]  Train Acc: 0.9768  Val Acc: 0.6363

然后,我们在完整的测试集上测试这些模型的性能(这将是我们在实际应用中运行模型的情况)。

# 在完整测试集上评估
test_loader = DataLoader(test_dataset, batch_size=64)acc_A, preds_A, labels_A = evaluate_on_test(model_A, test_loader)
acc_B, preds_B, labels_B = evaluate_on_test(model_B, test_loader)
acc_C, preds_C, labels_C = evaluate_on_test(model_C, test_loader)# 报告准确率
print(f"测试准确率 | 在排除 DME 的数据集上训练的模型: {acc_A:.4f}")
print(f"测试准确率 | 在排除 DRUSEN 的数据集上训练的模型: {acc_B:.4f}")
print(f"测试准确率 | 在排除 CNV 的数据集上训练的模型: {acc_C:.4f}")

输出结果如下:

测试准确率 | 在排除 DME 的数据集上训练的模型: 0.6420
测试准确率 | 在排除 DRUSEN 的数据集上训练的模型: 0.7080
测试准确率 | 在排除 CNV 的数据集上训练的模型: 0.7030

我们可以看到,这些模型在测试数据集中未见过的类别上表现不佳。

当我们绘制混淆矩阵并可视化结果时,这一点更加明显。

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplaydef plot_confusion_matrix(y_true, y_pred, title):cm = confusion_matrix(y_true, y_pred, labels=[0,1,2,3])disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["CNV", "DME", "DRUSEN", "NORMAL"])disp.plot(cmap=plt.cm.Blues)plt.title(title)plt.show()# 绘制混淆矩阵
plot_confusion_matrix(labels_A, preds_A, "混淆矩阵 - 排除 DME 的模型")
plot_confusion_matrix(labels_B, preds_B, "混淆矩阵 - 排除 DRUSEN 的模型")
plot_confusion_matrix(labels_C, preds_C, "混淆矩阵 - 排除 CNV 的模型")

None

None

None

在子数据集上进行联邦学习

现在轮到联邦学习大显身手了。

我们使用 Flower 框架,它允许我们使用任何机器学习框架和任何编程语言进行联邦学习、分析和评估。

我仍然使用 PyTorch 框架进行本教程,以使其对大多数人更易于理解。

我使用的是 MacBook M4 Max 来运行以下代码,但如果你在 Google Colab 上使用 T4 GPU,这将报错。

这是因为 Colab 只将一个 GPU 暴露给主笔记本进程。当 Flower(使用 Ray 作为其默认后端运行模拟)为每个模拟客户端启动额外的 Python 工作进程时,这些工作进程无法访问 GPU,程序就会崩溃。

如果你想在 Google Colab 上运行代码,建议使用 CPU 作为设备。不过,这会使训练变得非常缓慢。

安装 Flower 包
!uv pip install "flwr[simulation]"

回顾一下我们之前学到的内容,在联邦学习过程中,客户端和中央服务器之间会交换模型参数/权重。

当客户端从中央服务器接收到模型参数时,它将使用这些参数/权重更新其本地模型。

训练完成后,它将这些本地模型参数/权重发送回中央服务器。

定义客户端函数

两个函数可以帮助我们执行这些操作:

  • get_weights:此函数用于训练完成后获取客户端模型的更新权重,并将其发送回中央服务器。

它接受一个机器学习模型的引用,迭代其 state_dict 中的项,将每个项转换为 Numpy ndarray,并返回这些 ndarray 的列表。

# 获取客户端模型的更新权重
def get_weights(net):return [val.cpu().numpy() for _, val in net.state_dict().items()]
  • set_weights:此函数用于在训练开始之前,使用从中央服务器收到的新权重更新客户端模型的权重。

它接受一个机器学习模型的引用和一个 ndarray 列表。使用这个列表,它更新模型 state_dict 中的所有项。

from collections import OrderedDict# 更新客户端模型的权重
def set_weights(net, parameters):params_dict = zip(net.state_dict().keys(), parameters)state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})net.load_state_dict(state_dict, strict=True)

接下来,我们定义一个 FlowerClient 类,它将帮助我们在客户端上训练和评估模型。

from flwr.client import NumPyClient
from typing import Dict
from flwr.common import NDArrays, Scalar# Flower 客户端
class FlowerClient(NumPyClient):def __init__(self, net, trainset, valset, testset):self.net = netself.trainset = trainsetself.valset = valsetself.testset = testset# 本地训练def fit(self, parameters, config):set_weights(self.net, parameters)# 数据加载器train_loader = DataLoader(self.trainset, batch_size=64, shuffle=True)val_loader = DataLoader(self.valset, batch_size=64)# 损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(self.net.parameters(), lr=0.001)train_model(self.net,criterion,optimizer,train_loader,val_loader,epochs= 1,)return get_weights(self.net), len(self.trainset), {}# 本地评估def evaluate(self, parameters, config):set_weights(self.net, parameters)loss, acc = evaluate_model(self.net, DataLoader(self.testset, batch_size=64))return loss, len(self.testset), {"accuracy": acc}

client_fn 函数帮助我们根据需要创建此类的实例。

ClientApp 作为客户端逻辑的入口点,当 Flower 客户端从中央服务器接收任务时运行。

from flwr.client import Client, ClientApp
from flwr.common import Contexttrain_sets = [dataset_A, dataset_B, dataset_C]# 创建客户端的函数
def client_fn(context: Context) -> Client:cid = int(context.node_config["partition-id"])trainset = train_sets[cid]return FlowerClient(get_resnet_model(),trainset,val_dataset,test_dataset,).to_client()client = ClientApp(client_fn)

这就是客户端需要的所有内容。

定义服务器函数

接下来,我们定义一个 evaluate 函数,中央服务器在每轮联邦学习之后使用它来评估全局模型。

我们还定义了一个名为 filter_by_classes 的函数,它返回一个测试集的子集,其中只包含指定类别列表中的样本。

这有助于我们在每个客户端可用的类别子集上测试模型。

def filter_by_classes(dataset, class_list):indices = [i for i, (_, label) in enumerate(dataset) if label.item() in class_list]return Subset(dataset, indices)# 包含:CNV、DRUSEN、NORMAL - 排除:DME
testset_no_dme = filter_by_classes(test_dataset, [0, 2, 3])# 包含:CNV、DME、NORMAL - 排除:DRUSEN
testset_no_drusen = filter_by_classes(test_dataset, [0, 1, 3])# 包含:DME、DRUSEN、NORMAL - 排除:CNV
testset_no_cnv = filter_by_classes(test_dataset, [1, 2, 3])
# 评估全局模型
def evaluate(server_round, parameters, config, num_rounds = 20):net = get_resnet_model()set_weights(net, parameters)batch_size = 64acc_tot = evaluate_model(net, DataLoader(test_dataset, batch_size=batch_size))acc_A = evaluate_model(net, DataLoader(testset_no_dme, batch_size=batch_size))acc_B = evaluate_model(net, DataLoader(testset_no_drusen, batch_size=batch_size))acc_C = evaluate_model(net, DataLoader(testset_no_cnv, batch_size=batch_size))print(f"[Round {server_round}] 全局准确率: {acc_tot:.4f}")print(f"[Round {server_round}] (CNV,DRUSEN,NORMAL) 准确率: {acc_A:.4f}")print(f"[Round {server_round}] (CNV,DME,NORMAL)    准确率: {acc_B:.4f}")print(f"[Round {server_round}] (DME,DRUSEN,NORMAL) 准确率: {acc_C:.4f}")# 在最后一轮绘制混淆矩阵if server_round == num_rounds:acc_final, preds_final, labels_final = evaluate_on_test(net, DataLoader(test_dataset, batch_size=64))plot_confusion_matrix(labels_final, preds_final, "最终全局混淆矩阵")

接下来,使用 server_fn 函数,我们设置中央服务器,它使用联邦平均聚合策略。

from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvgnet = get_resnet_model()
params = ndarrays_to_parameters(get_weights(net))# 设置全局服务器的函数
def server_fn(context: Context, num_rounds = 5):# 联邦平均策略strategy = FedAvg(fraction_fit=1.0,fraction_evaluate=0.0,initial_parameters=params,evaluate_fn=evaluate,)config=ServerConfig(num_rounds)return ServerAppComponents(strategy=strategy,config=config,)server = ServerApp(server_fn=server_fn)

现在我们已经准备好训练我们的机器学习模型了。

训练与评估

为了模拟在三个客户端上的训练,我们使用 run_simulation 函数如下:

from flwr.simulation import run_simulation
from logging import ERROR# 为了保持日志输出简洁
backend_setup = {"init_args": {"logging_level": ERROR, "log_to_driver": False}}# 运行训练模拟
run_simulation(server_app=server,client_app=client,num_supernodes=3,backend_config=backend_setup,
)

以下是经过 20 轮联邦学习后的结果。

INFO : aggregate_fit: received 3 results and 0 failures
[Round 20] 全局准确率: 0.7710
[Round 20] (CNV,DRUSEN,NORMAL) 准确率: 0.7933
[Round 20] (CNV,DME,NORMAL)    准确率: 0.8947
[Round 20] (DME,DRUSEN,NORMAL) 准确率: 0.6960

在这里插入图片描述

模型在每个客户端的标签分布过滤后的测试数据集上表现良好(如三个测试子集的准确率所示)。

最棒的是,尽管每个客户端在其本地数据集中缺少一个疾病标签,全局模型仍然能够很好地识别所有标签(如完整测试集上的全局准确率所示)。

请注意,训练并不完美,还需要进一步优化和调整超参数以获得更好的结果。

数据集本身也存在类别不平衡问题,正常 OCT 图像的样本最多,而玻璃膜疣(Drusen)的样本最少,这可能解释了在最终全局混淆矩阵中对这一标签的误分类。

我们可以通过绘制 OCTMNIST 训练集中的类别分布来观察类别不平衡。

# 检查类别不平衡import matplotlib.pyplot as plt
from collections import Counter# 统计类别出现次数
labels = [label.item() for _, label in train_dataset]
class_counts = Counter(labels)
print(class_counts)# 准备绘图数据
class_names = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
counts = [class_counts[i] for i in range(4)]# 绘图
plt.figure(figsize=(8, 5))
plt.bar(class_names, counts)
plt.title("OCTMNIST 训练集中的类别分布")
plt.xlabel("类别")
plt.ylabel("样本数量")
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

None

阅读参考

  • Flower 框架文档
  • DeepLearning.ai 上的“联邦学习入门”课程
  • 具有差分隐私的 Gboard 语言模型的联邦学习

相关文章:

  • vector的大小
  • redis数据结构-05 (LPUSH、RPUSH、LPOP、RPOP)
  • 【今日三题】素数回文(模拟) / 活动安排(区间贪心) / 合唱团(动态规划)
  • 特励达力科LeCroy推出Xena Freya Z800 800GE高性能的800G以太网测试平台
  • 【英语笔记(一)】概述词类的作用与语义:名词、代词、数词、代词、动词.....,副词、不定式、分词、形容词等语义在句子中的作用;讲解表语、定语等
  • Linux网络基础 -- 局域网,广域网,网络协议,网络传输的基本流程,端口号,网络字节序
  • python打卡day22@浙大疏锦行
  • Java SE(11)——内部类
  • 无锁秒杀系统设计:基于Java的高效实现
  • VMware安装CentOS Stream10
  • Three.js + React 实战系列 - 联系方式提交表单区域 Contact 组件✨(表单绑定 + 表单验证)
  • Yocto 项目中的 glibc 编译失败全解析:原因、原理与修复策略
  • 深入剖析 MyBatis 位运算查询:从原理到最佳实践
  • RabbitMQ的工作队列模式和路由模式有什么区别?
  • BGP联盟
  • 无侵入式弹窗体验_探索 Chrome 的 Close Watcher API
  • 什么是中央税
  • 基于Flask、Bootstrap及深度学习的水库智能监测分析平台
  • c++ 如何写类(不带指针版)
  • 24、TypeScript:预言家之书——React 19 类型系统
  • 韩国总统选战打响:7人角逐李在明领跑,执政党临阵换将陷入分裂
  • 湛江霞山通报渔船火灾:起火船舶共8艘,无人员伤亡或被困
  • 重庆荣昌区委区政府再设“答谢宴”,邀请800余名志愿者机关食堂用餐
  • 上海“电子支付费率成本为0”背后:金融服务不仅“快”和“省”,更有“稳”和“准”
  • 近4小时会谈、3项联合声明、20多份双边合作文本,中俄元首今年首次面对面会晤成果颇丰
  • 上海启动万兆光网试点建设,助力“模速空间”跑出发展加速度