diff --git a/FED.py b/FED.py
index b1b0d13..a727ca0 100644
--- a/FED.py
+++ b/FED.py
@@ -13,10 +13,11 @@ from model.mobilenetv3 import MobileNetV3
 
 # 配置参数
 NUM_CLIENTS = 2
-NUM_ROUNDS = 3
-CLIENT_EPOCHS = 5
+NUM_ROUNDS = 10
+CLIENT_EPOCHS = 2
 BATCH_SIZE = 32
 TEMP = 2.0  # 蒸馏温度
+CLASS_NUM = [3, 3, 3]
 
 # 设备配置
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -32,9 +33,9 @@ def prepare_data():
     ])
     
     # Load datasets
-    dataset_A = ImageFolder(root='./dataset_A/train', transform=transform)
-    dataset_B = ImageFolder(root='./dataset_B/train', transform=transform)
-    dataset_C = ImageFolder(root='./dataset_C/train', transform=transform)
+    dataset_A = ImageFolder(root='G:/testdata/JY_A/train', transform=transform)
+    dataset_B = ImageFolder(root='G:/testdata/ZY_A/train', transform=transform)
+    dataset_C = ImageFolder(root='G:/testdata/ZY_B/train', transform=transform)
     
     # Assign datasets to clients
     client_datasets = [dataset_B, dataset_C]
@@ -105,13 +106,6 @@ def client_train(client_model, server_model, dataset):
             })
             progress_bar.update(1)
                         
-            # 每10个batch打印详细信息
-            if (batch_idx + 1) % 10 == 0:
-                progress_bar.write(f"\nEpoch {epoch+1} | Batch {batch_idx+1}")
-                progress_bar.write(f"Task Loss: {loss_task:.4f}")
-                progress_bar.write(f"Distill Loss: {loss_distill:.4f}")
-                progress_bar.write(f"Total Loss: {total_loss:.4f}")
-                progress_bar.write(f"Batch Accuracy: {100*correct/total:.2f}%\n")
         # 每个epoch结束打印汇总信息
         avg_loss = epoch_loss / len(loader)
         avg_task = task_loss / len(loader)
@@ -135,6 +129,37 @@ def aggregate(client_params):
         global_params[key] = torch.stack([param[key].float() for param in client_params]).mean(dim=0)
     return global_params
 
+def server_aggregate(server_model, client_models, public_loader):
+    server_model.train()
+    optimizer = torch.optim.Adam(server_model.parameters(), lr=0.001)
+    
+    for data, _ in public_loader:
+        data = data.to(device)
+        
+        # 获取客户端模型特征
+        client_features = []
+        with torch.no_grad():
+            for model in client_models:
+                features = model.extract_features(data)  # 需要实现特征提取方法
+                client_features.append(features)
+        
+        # 计算特征蒸馏目标
+        target_features = torch.stack(client_features).mean(dim=0)
+        
+        # 服务器前向
+        server_features = server_model.extract_features(data)
+        
+        # 特征对齐损失
+        loss = F.mse_loss(server_features, target_features)
+        
+        # 反向传播
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # 更新统计信息
+        total_loss += loss.item()
+
 # 服务器知识更新
 def server_update(server_model, client_models, public_loader):
     server_model.train()
@@ -191,15 +216,19 @@ def test_model(model, test_loader):
 
 # 主训练流程
 def main():
+    transform = transforms.Compose([
+        transforms.Resize((224, 224)),
+        transforms.ToTensor()
+    ])
     # Initialize models
-    global_server_model = repvit_m1_1(num_classes=10).to(device)
-    client_models = [MobileNetV3(n_class=10).to(device) for _ in range(NUM_CLIENTS)]
+    global_server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
+    client_models = [MobileNetV3(n_class=CLASS_NUM[i+1]).to(device) for i in range(NUM_CLIENTS)]
     
     # Prepare data
     client_datasets, public_loader = prepare_data()
     
     # Test dataset (using dataset A's test set for simplicity)
-    test_dataset = ImageFolder(root='./dataset_A/test', transform=transform)
+    test_dataset = ImageFolder(root='G:/testdata/JY_A/test', transform=transform)
     test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
     
     round_progress = tqdm(total=NUM_ROUNDS, desc="Federated Rounds", unit="round")
@@ -245,20 +274,22 @@ def main():
     
     # Save trained models
     torch.save(global_server_model.state_dict(), "server_model.pth")
-    torch.save(client_models[0].state_dict(), "client_model.pth")
+    for i in range(NUM_CLIENTS):
+        torch.save(client_models[i].state_dict(), "client"+str(i)+"_model.pth")
     print("Models saved successfully.")
     
     # Test server model
-    server_model = repvit_m1_1(num_classes=10).to(device)
-    server_model.load_state_dict(torch.load("server_model.pth"))
+    server_model = repvit_m1_1(num_classes=CLASS_NUM[0]).to(device)
+    server_model.load_state_dict(torch.load("server_model.pth",weights_only=True))
     server_acc = test_model(server_model, test_loader)
     print(f"Server Model Test Accuracy: {server_acc:.2f}%")
     
     # Test client model
-    client_model = MobileNetV3(n_class=10).to(device)
-    client_model.load_state_dict(torch.load("client_model.pth"))
-    client_acc = test_model(client_model, test_loader)
-    print(f"Client Model Test Accuracy: {client_acc:.2f}%")
+    for i in range(NUM_CLIENTS):
+        client_model = MobileNetV3(n_class=CLASS_NUM[i+1]).to(device)
+        client_model.load_state_dict(torch.load("client"+str(i)+"_model.pth",weights_only=True))
+        client_acc = test_model(client_model, test_loader)
+        print(f"Client->{i} Model Test Accuracy: {client_acc:.2f}%")
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py
index 4692cf9..2de909a 100644
--- a/model/mobilenetv3.py
+++ b/model/mobilenetv3.py
@@ -200,6 +200,11 @@ class MobileNetV3(nn.Module):
 
         self._initialize_weights()
 
+
+    def extract_features(self, x):
+        x = self.features(x)
+        return x
+
     def forward(self, x):
         x = self.features(x)
         x = x.mean(3).mean(2)
diff --git a/model/repvit.py b/model/repvit.py
index 78197f8..27b07b9 100644
--- a/model/repvit.py
+++ b/model/repvit.py
@@ -236,6 +236,10 @@ class RepViT(nn.Module):
         self.features = nn.ModuleList(layers)
         self.classifier = Classfier(output_channel, num_classes, distillation)
         
+    def extract_features(self, x):
+        for f in self.features:
+            x = f(x)
+        return x
     def forward(self, x):
         # x = self.features(x)
         for f in self.features: