更新联邦学习

This commit is contained in:
yoiannis 2025-03-11 23:12:34 +08:00
parent f43e21c09d
commit 617230e296

91
FED.py
View File

@ -12,7 +12,7 @@ from model.repvit import repvit_m1_1
from model.mobilenetv3 import MobileNetV3
# 配置参数
NUM_CLIENTS = 4
NUM_CLIENTS = 2
NUM_ROUNDS = 3
CLIENT_EPOCHS = 5
BATCH_SIZE = 32
@ -22,25 +22,27 @@ TEMP = 2.0 # 蒸馏温度
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据准备
def prepare_data(num_clients):
import os
from torchvision.datasets import ImageFolder
def prepare_data():
transform = transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor()
])
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# 非IID数据划分每个客户端2个类别
client_data = {i: [] for i in range(num_clients)}
labels = train_set.targets.numpy()
for label in range(10):
label_idx = np.where(labels == label)[0]
np.random.shuffle(label_idx)
split = np.array_split(label_idx, num_clients//2)
for i, idx in enumerate(split):
client_data[i*2 + label%2].extend(idx)
# Load datasets
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
return [Subset(train_set, ids) for ids in client_data.values()]
# Assign datasets to clients
client_datasets = [dataset_B, dataset_C]
# Server dataset (A) for public updates
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
return client_datasets, public_loader
# 客户端训练函数
def client_train(client_model, server_model, dataset):
@ -189,63 +191,47 @@ def test_model(model, test_loader):
# 主训练流程
def main():
# 初始化模型
# Initialize models
global_server_model = repvit_m1_1(num_classes=10).to(device)
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
# 准备数据
client_datasets = prepare_data(NUM_CLIENTS)
public_loader = DataLoader(
datasets.MNIST("./data", train=False, download=True,
transform= transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor() # 将图像转换为张量
])),
batch_size=100, shuffle=True)
# Prepare data
client_datasets, public_loader = prepare_data()
test_dataset = datasets.MNIST(
"./data",
train=False,
transform= transforms.Compose([
transforms.Resize((224, 224)), # 将图像调整为 224x224
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor() # 将图像转换为张量
])
)
# Test dataset (using dataset A's test set for simplicity)
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
for round in range(NUM_ROUNDS):
print(f"\n{'#'*50}")
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
print(f"{'#'*50}")
# 客户端选择
# Client selection (only 2 clients)
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
print(f"Selected Clients: {selected_clients}")
# 客户端本地训练
# Client local training
client_params = []
for cid in selected_clients:
print(f"\nTraining Client {cid}")
local_model = copy.deepcopy(client_models[cid])
local_model.load_state_dict(client_models[cid].state_dict())
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
client_params.append(updated_params)
# 模型聚合
# Model aggregation
global_client_params = aggregate(client_params)
for model in client_models:
model.load_state_dict(global_client_params)
# 服务器知识更新
# Server knowledge update
print("\nServer Updating...")
server_update(global_server_model, client_models, public_loader)
# 测试模型性能
# Test model performance
server_acc = test_model(global_server_model, test_loader)
client_acc = test_model(client_models[0], test_loader)
print(f"\nRound {round+1} Performance:")
@ -253,25 +239,22 @@ def main():
print(f"Client Model Accuracy: {client_acc:.2f}%")
round_progress.update(1)
print(f"Round {round+1} completed")
print("Training completed!")
# 保存训练好的模型
# Save trained models
torch.save(global_server_model.state_dict(), "server_model.pth")
torch.save(client_models[0].state_dict(), "client_model.pth")
print("Models saved successfully.")
# 创建测试数据加载器
# 测试服务器模型
# Test server model
server_model = repvit_m1_1(num_classes=10).to(device)
server_model.load_state_dict(torch.load("server_model.pth"))
server_acc = test_model(server_model, test_loader)
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
# 测试客户端模型
# Test client model
client_model = MobileNetV3(n_class=10).to(device)
client_model.load_state_dict(torch.load("client_model.pth"))
client_acc = test_model(client_model, test_loader)