更新联邦学习
This commit is contained in:
parent
f43e21c09d
commit
617230e296
91
FED.py
91
FED.py
@ -12,7 +12,7 @@ from model.repvit import repvit_m1_1
|
||||
from model.mobilenetv3 import MobileNetV3
|
||||
|
||||
# 配置参数
|
||||
NUM_CLIENTS = 4
|
||||
NUM_CLIENTS = 2
|
||||
NUM_ROUNDS = 3
|
||||
CLIENT_EPOCHS = 5
|
||||
BATCH_SIZE = 32
|
||||
@ -22,25 +22,27 @@ TEMP = 2.0 # 蒸馏温度
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 数据准备
|
||||
def prepare_data(num_clients):
|
||||
import os
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
def prepare_data():
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||
transforms.Grayscale(num_output_channels=3),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
# 非IID数据划分(每个客户端2个类别)
|
||||
client_data = {i: [] for i in range(num_clients)}
|
||||
labels = train_set.targets.numpy()
|
||||
for label in range(10):
|
||||
label_idx = np.where(labels == label)[0]
|
||||
np.random.shuffle(label_idx)
|
||||
split = np.array_split(label_idx, num_clients//2)
|
||||
for i, idx in enumerate(split):
|
||||
client_data[i*2 + label%2].extend(idx)
|
||||
# Load datasets
|
||||
dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
|
||||
dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
|
||||
dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
|
||||
|
||||
return [Subset(train_set, ids) for ids in client_data.values()]
|
||||
# Assign datasets to clients
|
||||
client_datasets = [dataset_B, dataset_C]
|
||||
|
||||
# Server dataset (A) for public updates
|
||||
public_loader = DataLoader(dataset_A, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
return client_datasets, public_loader
|
||||
|
||||
# 客户端训练函数
|
||||
def client_train(client_model, server_model, dataset):
|
||||
@ -189,63 +191,47 @@ def test_model(model, test_loader):
|
||||
|
||||
# 主训练流程
|
||||
def main():
|
||||
# 初始化模型
|
||||
# Initialize models
|
||||
global_server_model = repvit_m1_1(num_classes=10).to(device)
|
||||
client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
|
||||
|
||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||
|
||||
# 准备数据
|
||||
client_datasets = prepare_data(NUM_CLIENTS)
|
||||
public_loader = DataLoader(
|
||||
datasets.MNIST("./data", train=False, download=True,
|
||||
transform= transforms.Compose([
|
||||
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||
transforms.Grayscale(num_output_channels=3),
|
||||
transforms.ToTensor() # 将图像转换为张量
|
||||
])),
|
||||
batch_size=100, shuffle=True)
|
||||
# Prepare data
|
||||
client_datasets, public_loader = prepare_data()
|
||||
|
||||
test_dataset = datasets.MNIST(
|
||||
"./data",
|
||||
train=False,
|
||||
transform= transforms.Compose([
|
||||
transforms.Resize((224, 224)), # 将图像调整为 224x224
|
||||
transforms.Grayscale(num_output_channels=3),
|
||||
transforms.ToTensor() # 将图像转换为张量
|
||||
])
|
||||
)
|
||||
# Test dataset (using dataset A's test set for simplicity)
|
||||
test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
|
||||
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
|
||||
|
||||
round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
|
||||
|
||||
for round in range(NUM_ROUNDS):
|
||||
print(f"\n{'#'*50}")
|
||||
print(f"Federated Round {round+1}/{NUM_ROUNDS}")
|
||||
print(f"{'#'*50}")
|
||||
|
||||
# 客户端选择
|
||||
# Client selection (only 2 clients)
|
||||
selected_clients = np.random.choice(NUM_CLIENTS, 2, replace=False)
|
||||
print(f"Selected Clients: {selected_clients}")
|
||||
|
||||
# 客户端本地训练
|
||||
# Client local training
|
||||
client_params = []
|
||||
for cid in selected_clients:
|
||||
print(f"\nTraining Client {cid}")
|
||||
local_model = copy.deepcopy(client_models[cid])
|
||||
local_model.load_state_dict(client_models[cid].state_dict())
|
||||
|
||||
updated_params = client_train(local_model, global_server_model, client_datasets[cid])
|
||||
client_params.append(updated_params)
|
||||
|
||||
# 模型聚合
|
||||
# Model aggregation
|
||||
global_client_params = aggregate(client_params)
|
||||
for model in client_models:
|
||||
model.load_state_dict(global_client_params)
|
||||
|
||||
# 服务器知识更新
|
||||
# Server knowledge update
|
||||
print("\nServer Updating...")
|
||||
server_update(global_server_model, client_models, public_loader)
|
||||
|
||||
# 测试模型性能
|
||||
# Test model performance
|
||||
server_acc = test_model(global_server_model, test_loader)
|
||||
client_acc = test_model(client_models[0], test_loader)
|
||||
print(f"\nRound {round+1} Performance:")
|
||||
@ -253,25 +239,22 @@ def main():
|
||||
print(f"Client Model Accuracy: {client_acc:.2f}%")
|
||||
|
||||
round_progress.update(1)
|
||||
|
||||
print(f"Round {round+1} completed")
|
||||
|
||||
print("Training completed!")
|
||||
|
||||
# 保存训练好的模型
|
||||
|
||||
# Save trained models
|
||||
torch.save(global_server_model.state_dict(), "server_model.pth")
|
||||
torch.save(client_models[0].state_dict(), "client_model.pth")
|
||||
print("Models saved successfully.")
|
||||
|
||||
# 创建测试数据加载器
|
||||
|
||||
# 测试服务器模型
|
||||
|
||||
# Test server model
|
||||
server_model = repvit_m1_1(num_classes=10).to(device)
|
||||
server_model.load_state_dict(torch.load("server_model.pth"))
|
||||
server_acc = test_model(server_model, test_loader)
|
||||
print(f"Server Model Test Accuracy: {server_acc:.2f}%")
|
||||
|
||||
# 测试客户端模型
|
||||
|
||||
# Test client model
|
||||
client_model = MobileNetV3(n_class=10).to(device)
|
||||
client_model.load_state_dict(torch.load("client_model.pth"))
|
||||
client_acc = test_model(client_model, test_loader)
|
||||
|
Loading…
Reference in New Issue
Block a user