Federated Learning Systems

Master privacy-preserving distributed machine learning across decentralized data sources

What is Federated Learning?

Federated Learning (FL) is a machine learning paradigm that enables multiple parties to collaboratively train a shared model without sharing their raw data. Instead of centralizing data, the model travels to the data, trains locally, and only shares model updates - preserving privacy and reducing communication costs.

Privacy-Preserving

Raw data never leaves client devices or organizations

Communication Efficient

Share model updates instead of large datasets

Personalization

Adapt global model to local data distributions

Federated Learning Strategies

FedAvg (Federated Averaging)

Weighted average of client model updates

Convergence: Good
Communication Cost: Low
Privacy Level: Basic
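
The weighted average behind FedAvg can be sketched in a few lines. This is a minimal illustration in plain Python, not the page's implementation: the function name `fedavg`, the list-based parameter format, and the toy client values are all assumptions made for the example. Each client's update is weighted by its share of the total training samples.

```python
from typing import Dict, List


def fedavg(updates: List[Dict[str, List[float]]],
           num_samples: List[int]) -> Dict[str, List[float]]:
    """Average client updates, weighting each client by its local sample count."""
    if not updates or len(updates) != len(num_samples):
        raise ValueError("need one sample count per client update")
    total = sum(num_samples)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    averaged: Dict[str, List[float]] = {}
    for name in updates[0]:
        size = len(updates[0][name])
        averaged[name] = [
            sum(n / total * u[name][i] for u, n in zip(updates, num_samples))
            for i in range(size)
        ]
    return averaged


# Two toy clients (hypothetical values): the second holds three times as much
# data, so its update counts three times as much in the average.
client_updates = [{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]
global_update = fedavg(client_updates, num_samples=[100, 300])
print(global_update)  # {'w': [2.5, 5.0]}
```

Weighting by sample count means a client with little data cannot dominate the global model, which is also why skewed client sizes can bias the result toward large clients.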

Production Federated Learning Framework

Complete Federated Learning System
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

@dataclass
class FLConfig:
    """Federated learning configuration"""
    num_clients: int = 100
    clients_per_round: int = 10
    num_rounds: int = 50
    local_epochs: int = 5
    learning_rate: float = 0.01
    batch_size: int = 32
    model_name: str = "federated_model"
    aggregation_strategy: str = "fedavg"
    min_clients: int = 1
    privacy_mechanism: Optional[str] = None
    differential_privacy_epsilon: float = 1.0
    differential_privacy_delta: float = 1e-5

class Client:
    """Federated learning client"""
    
    def __init__(self, client_id: str, train_dataset: Dataset, 
                 val_dataset: Optional[Dataset] = None, device: str = 'cpu'):
        self.client_id = client_id
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.device = device
        self.model = None
        self.optimizer = None
        self.local_epochs = 5
        self.batch_size = 32
        self.learning_rate = 0.01
        
        # Privacy and security
        self.noise_multiplier = 0.0
        self.max_grad_norm = 1.0
        
        # Training history
        self.training_history = []
        
    def set_model(self, model: nn.Module):
        """Set the global model for local training"""
        self.model = model.to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), 
                                 lr=self.learning_rate, 
                                 momentum=0.9)
    
    def get_data_size(self) -> int:
        """Get size of client's training data"""
        return len(self.train_dataset)
    
    def train(self, global_round: int) -> Dict[str, Any]:
        """Perform local training and return model updates"""
        if self.model is None:
            raise ValueError("Model not set for client")
        
        self.model.train()
        dataloader = DataLoader(self.train_dataset, 
                              batch_size=self.batch_size, 
                              shuffle=True)
        
        criterion = nn.CrossEntropyLoss()
        epoch_losses = []
        
        for epoch in range(self.local_epochs):
            batch_losses = []
            
            for batch_idx, (data, targets) in enumerate(dataloader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                self.optimizer.zero_grad()
                outputs = self.model(data)
                loss = criterion(outputs, targets)
                loss.backward()
                
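                # Clip the batch gradient norm (applied whenever max_grad_norm > 0).
                # Note: formal DP-SGD clips per-sample gradients; clipping the batch
                # gradient and adding noise is only a rough approximation. Use a
                # library such as Opacus when real guarantees are required.
                if self.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 
                                                 self.max_grad_norm)
                
                # Add Gaussian noise scaled to the clipping bound
                if self.noise_multiplier > 0:
                    self._add_dp_noise()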
                
                self.optimizer.step()
                batch_losses.append(loss.item())
            
            epoch_loss = np.mean(batch_losses)
            epoch_losses.append(epoch_loss)
        
        # Evaluate on validation set if available
        val_accuracy = self._evaluate() if self.val_dataset else None
        
        # Prepare training results
        training_result = {
            'client_id': self.client_id,
            'global_round': global_round,
            'local_epochs': self.local_epochs,
            'final_loss': epoch_losses[-1],
            'data_size': self.get_data_size(),
            'val_accuracy': val_accuracy,
            'timestamp': time.time()  # completion time in epoch seconds (not a duration)
        }
        
        self.training_history.append(training_result)
        return training_result
    
    def _add_dp_noise(self):
        """Add differential privacy noise to gradients"""
        with torch.no_grad():
            for param in self.model.parameters():
                if param.grad is not None:
                    noise = torch.normal(
                        mean=0.0,
                        std=self.noise_multiplier * self.max_grad_norm,
                        size=param.grad.shape,
                        device=param.grad.device
                    )
                    param.grad.add_(noise)
    
    def _evaluate(self) -> float:
        """Evaluate model on validation set"""
        if self.val_dataset is None:
            return 0.0
        
        self.model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            dataloader = DataLoader(self.val_dataset, 
                                  batch_size=self.batch_size, 
                                  shuffle=False)
            
            for data, targets in dataloader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        return correct / total if total > 0 else 0.0
    
    def get_model_updates(self, initial_model: nn.Module) -> Dict[str, torch.Tensor]:
        """Get model parameter updates (difference from initial model)"""
        updates = {}
        initial_state = initial_model.state_dict()
        current_state = self.model.state_dict()
        
        for key in current_state:
            updates[key] = current_state[key] - initial_state[key]
        
        return updates
    
    def set_privacy_params(self, noise_multiplier: float = 0.0, 
                          max_grad_norm: float = 1.0):
        """Configure differential privacy parameters"""
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm

class AggregationStrategy(ABC):
    """Abstract base class for aggregation strategies"""
    
    @abstractmethod
    def aggregate(self, client_updates: List[Dict], 
                 client_weights: List[float]) -> Dict[str, torch.Tensor]:
        pass

class FedAvgAggregation(AggregationStrategy):
    """Federated Averaging aggregation strategy"""
    
    def aggregate(self, client_updates: List[Dict], 
                 client_weights: List[float]) -> Dict[str, torch.Tensor]:
        """Weighted average of client model updates"""
        if not client_updates:
            return {}
        
        # Normalize weights
        total_weight = sum(client_weights)
        weights = [w / total_weight for w in client_weights]
        
        # Initialize aggregated updates
        aggregated = {}
        for key in client_updates[0]:
            aggregated[key] = torch.zeros_like(client_updates[0][key])
        
        # Weighted averaging
        for client_update, weight in zip(client_updates, weights):
            for key in aggregated:
                aggregated[key] += weight * client_update[key]
        
        return aggregated

class FedProxAggregation(AggregationStrategy):
    """FedProx aggregation with proximal term"""
    
    def __init__(self, mu: float = 0.01):
        self.mu = mu  # Proximal term weight
    
    def aggregate(self, client_updates: List[Dict], 
                 client_weights: List[float]) -> Dict[str, torch.Tensor]:
        """FedProx aggregation with regularization"""
        # For simplicity, this implements standard FedAvg
        # In practice, FedProx modifies the local training objective
        fedavg = FedAvgAggregation()
        return fedavg.aggregate(client_updates, client_weights)

class FederatedServer:
    """Centralized federated learning server"""
    
    def __init__(self, model: nn.Module, config: FLConfig, device: str = 'cpu'):
        self.model = model.to(device)
        self.config = config
        self.device = device
        
        # Aggregation strategy
        if config.aggregation_strategy == 'fedavg':
            self.aggregator = FedAvgAggregation()
        elif config.aggregation_strategy == 'fedprox':
            self.aggregator = FedProxAggregation()
        else:
            self.aggregator = FedAvgAggregation()
        
        # Training history
        self.training_history = []
        self.global_model_history = []
        
        # Clients
        self.clients: Dict[str, Client] = {}
        
    def register_client(self, client: Client):
        """Register a client with the server"""
        self.clients[client.client_id] = client
        client.set_model(self._get_model_copy())
        
        # Set privacy parameters if specified
        if self.config.privacy_mechanism == 'differential_privacy':
            # Calculate noise multiplier from epsilon and delta
            noise_multiplier = self._calculate_noise_multiplier()
            client.set_privacy_params(noise_multiplier=noise_multiplier)
    
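    def _get_model_copy(self) -> nn.Module:
        """Get an independent copy of the current global model"""
        import copy  # local import keeps this method self-contained
        # deepcopy preserves constructor arguments (e.g. num_classes) and buffers;
        # type(self.model)() would fail for models with required arguments
        return copy.deepcopy(self.model)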
    
    def _calculate_noise_multiplier(self) -> float:
        """Calculate noise multiplier for differential privacy"""
        # Simplified calculation - in practice, use privacy accounting tools
        epsilon = self.config.differential_privacy_epsilon
        delta = self.config.differential_privacy_delta
        return np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    
    def select_clients(self, round_num: int) -> List[Client]:
        """Select clients for training round"""
        available_clients = list(self.clients.values())
        
        if len(available_clients) < self.config.min_clients:
            raise ValueError(f"Not enough clients: {len(available_clients)} < {self.config.min_clients}")
        
        num_selected = min(self.config.clients_per_round, len(available_clients))
        
        # Random selection (could implement more sophisticated strategies)
        selected = random.sample(available_clients, num_selected)
        
        return selected
    
    def train_round(self, round_num: int) -> Dict[str, Any]:
        """Execute one round of federated training"""
        print(f"\nStarting federated round {round_num + 1}/{self.config.num_rounds}")
        
        # Select clients
        selected_clients = self.select_clients(round_num)
        print(f"Selected {len(selected_clients)} clients for training")
        
        # Distribute current global model to selected clients
        initial_model = self._get_model_copy()
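        for client in selected_clients:
            # Set hyperparameters first: set_model() builds the optimizer
            # from client.learning_rate, so the order matters
            client.local_epochs = self.config.local_epochs
            client.learning_rate = self.config.learning_rate
            client.batch_size = self.config.batch_size
            client.set_model(self._get_model_copy())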
        
        # Parallel local training
        client_results = []
        client_updates = []
        client_weights = []
        
        with ThreadPoolExecutor(max_workers=min(len(selected_clients), 10)) as executor:
            # Submit training tasks
            future_to_client = {
                executor.submit(client.train, round_num): client 
                for client in selected_clients
            }
            
            # Collect results
            for future in as_completed(future_to_client):
                client = future_to_client[future]
                try:
                    result = future.result()
                    client_results.append(result)
                    
                    # Get model updates
                    updates = client.get_model_updates(initial_model)
                    client_updates.append(updates)
                    client_weights.append(client.get_data_size())
                    
                except Exception as exc:
                    print(f'Client {client.client_id} generated an exception: {exc}')
        
        # Aggregate model updates
        if client_updates:
            aggregated_updates = self.aggregator.aggregate(client_updates, client_weights)
            
            # Apply aggregated updates to global model
            global_state = self.model.state_dict()
            for key in aggregated_updates:
                global_state[key] += aggregated_updates[key]
            self.model.load_state_dict(global_state)
        
        # Evaluate global model
        global_accuracy = self.evaluate_global_model()
        
        # Record round statistics
        round_stats = {
            'round': round_num,
            'num_clients': len(selected_clients),
            'avg_local_loss': np.mean([r['final_loss'] for r in client_results]),
            'global_accuracy': global_accuracy,
            'total_data_points': sum(client_weights),
            'timestamp': time.time()
        }
        
        self.training_history.append(round_stats)
        
        print(f"Round {round_num + 1} completed:")
        print(f"  Average local loss: {round_stats['avg_local_loss']:.4f}")
        print(f"  Global accuracy: {global_accuracy:.4f}")
        
        return round_stats
    
    def evaluate_global_model(self) -> float:
        """Evaluate global model on server test set (placeholder)"""
        # A real implementation would score self.model on a held-out server test set.
        # This placeholder returns noise around 0.85 and does not reflect training progress.
        return 0.85 + np.random.normal(0, 0.02)
    
    def train(self) -> Dict[str, Any]:
        """Run complete federated learning training"""
        print(f"Starting federated learning with {len(self.clients)} clients")
        print(f"Configuration: {self.config}")
        
        # Training loop
        for round_num in range(self.config.num_rounds):
            round_stats = self.train_round(round_num)
            
            # Save model checkpoint periodically
            if (round_num + 1) % 10 == 0:
                self.save_checkpoint(round_num)
        
        final_stats = {
            'total_rounds': self.config.num_rounds,
            'final_accuracy': self.training_history[-1]['global_accuracy'],
            'training_history': self.training_history
        }
        
        return final_stats
    
    def save_checkpoint(self, round_num: int):
        """Save model and training state"""
        checkpoint = {
            'round': round_num,
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
            'training_history': self.training_history
        }
        
        torch.save(checkpoint, f'{self.config.model_name}_round_{round_num}.pt')

# Example neural network for federated learning
class FederatedNet(nn.Module):
    """Simple CNN for federated learning"""
    
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)
    
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = torch.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x  # raw logits: nn.CrossEntropyLoss applies log-softmax internally

# Usage example
def run_federated_learning():
    """Example of running federated learning"""
    
    # Configuration
    config = FLConfig(
        num_clients=50,
        clients_per_round=10,
        num_rounds=25,
        local_epochs=3,
        learning_rate=0.01,
        aggregation_strategy='fedavg',
        privacy_mechanism='differential_privacy',
        differential_privacy_epsilon=1.0
    )
    
    # Create global model
    global_model = FederatedNet(num_classes=10)
    
    # Initialize federated server
    server = FederatedServer(global_model, config)
    
    # Create and register clients (replace with actual data loading)
    for i in range(config.num_clients):
        # In practice, each client would have their own dataset
        # Here we simulate with random data
        dummy_dataset = [(torch.randn(1, 28, 28), torch.randint(0, 10, (1,))[0]) 
                        for _ in range(100)]
        
        client = Client(
            client_id=f"client_{i}",
            train_dataset=dummy_dataset,
            device='cpu'
        )
        
        server.register_client(client)
    
    # Run federated training
    results = server.train()
    
    print("\nFederated Learning Completed!")
    print(f"Final global accuracy: {results['final_accuracy']:.4f}")
    
    return results, server

if __name__ == "__main__":
    results, server = run_federated_learning()
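
The `FedProxAggregation` class above notes that FedProx changes the local training objective rather than the aggregation step. A minimal sketch of that local step follows, assuming plain-Python parameter lists and a toy quadratic loss; `fedprox_local_step` and `train_locally` are hypothetical names introduced here. The client minimizes its local loss plus a proximal penalty (mu / 2) * ||w - w_global||^2, which adds mu * (w - w_global) to each gradient.

```python
from typing import List


def fedprox_local_step(weights: List[float], global_weights: List[float],
                       grads: List[float], lr: float = 0.1,
                       mu: float = 0.01) -> List[float]:
    """One SGD step on local_loss(w) + (mu / 2) * ||w - w_global||^2."""
    return [
        w - lr * (g + mu * (w - wg))  # proximal gradient term: mu * (w - w_global)
        for w, wg, g in zip(weights, global_weights, grads)
    ]


def train_locally(mu: float, steps: int = 300) -> List[float]:
    """Toy client whose local loss 0.5 * (w - 5)^2 pulls w toward 5.0."""
    global_weights = [0.0]
    weights = list(global_weights)
    for _ in range(steps):
        grads = [w - 5.0 for w in weights]  # gradient of the toy local loss
        weights = fedprox_local_step(weights, global_weights, grads, mu=mu)
    return weights


w_plain = train_locally(mu=0.0)  # drifts all the way to the local optimum
w_prox = train_locally(mu=1.0)   # held closer to the global model
print(w_plain, w_prox)
```

With mu = 0 the step reduces to ordinary local SGD. Larger values of mu limit how far a client with unusual data can drift from the global model during its local epochs.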

Privacy-Preserving Mechanisms

Basic FL

Standard federated learning without additional privacy

Privacy Protection: Limited
Computational Overhead: None

Differential Privacy Implementation
import torch
import numpy as np
from typing import Dict, List, Tuple
import math

class DifferentialPrivacyMechanism:
    """Differential privacy for federated learning"""
    
    def __init__(self, epsilon: float = 1.0, delta: float = 1e-5, 
                 sensitivity: float = 1.0, mechanism: str = 'gaussian'):
        self.epsilon = epsilon
        self.delta = delta
        self.sensitivity = sensitivity
        self.mechanism = mechanism
        
        # Calculate noise parameters
        if mechanism == 'gaussian':
            self.sigma = self._calculate_gaussian_noise_scale()
        elif mechanism == 'laplace':
            self.b = sensitivity / epsilon
    
    def _calculate_gaussian_noise_scale(self) -> float:
        """Calculate noise scale for Gaussian mechanism"""
        # Classical Gaussian mechanism bound (not the analytic Gaussian mechanism)
        return self.sensitivity * math.sqrt(2 * math.log(1.25 / self.delta)) / self.epsilon
    
    def add_noise_to_gradients(self, gradients: List[torch.Tensor], 
                             clip_norm: float = 1.0) -> List[torch.Tensor]:
        """Add differential privacy noise to gradients"""
        noisy_gradients = []
        
        for grad in gradients:
            # Clip gradients
            clipped_grad = self._clip_gradient(grad, clip_norm)
            
            # Add noise
            if self.mechanism == 'gaussian':
                noise = torch.normal(
                    mean=0.0,
                    std=self.sigma,
                    size=grad.shape,
                    device=grad.device
                )
            elif self.mechanism == 'laplace':
                noise = torch.distributions.Laplace(0, self.b).sample(grad.shape).to(grad.device)
            else:
                raise ValueError(f"Unknown mechanism: {self.mechanism}")
            
            noisy_grad = clipped_grad + noise
            noisy_gradients.append(noisy_grad)
        
        return noisy_gradients
    
    def _clip_gradient(self, gradient: torch.Tensor, clip_norm: float) -> torch.Tensor:
        """Clip gradient to bound sensitivity"""
        grad_norm = torch.norm(gradient)
        if grad_norm > clip_norm:
            return gradient * (clip_norm / grad_norm)
        return gradient
    
    def compute_privacy_spent(self, steps: int, batch_size: int, 
                            dataset_size: int) -> Tuple[float, float]:
        """Compute total privacy spent using RDP accounting"""
        # Simplified privacy accounting - use opacus or similar for production
        q = batch_size / dataset_size  # Sampling ratio
        
        if self.mechanism == 'gaussian':
            # RDP calculation for Gaussian mechanism
            alpha = 2.0  # Renyi parameter
            rdp = alpha * q**2 * steps / (2 * self.sigma**2)
            
            # Convert RDP to (epsilon, delta)
            epsilon_spent = rdp + math.log(1 / self.delta) / (alpha - 1)
            
            return epsilon_spent, self.delta
        else:
            # Simple composition for Laplace
            epsilon_spent = self.epsilon * steps
            return epsilon_spent, 0.0

class SecureAggregation:
    """Secure aggregation protocol for federated learning"""
    
    def __init__(self, num_clients: int, threshold: int):
        self.num_clients = num_clients
        self.threshold = threshold  # Minimum clients needed
        
        # In practice, these would be proper cryptographic keys
        self.client_keys = {}
        self.server_key = None
    
    def setup_keys(self):
        """Setup cryptographic keys for secure aggregation"""
        # Simplified key setup - use proper cryptography in production
        for i in range(self.num_clients):
            self.client_keys[f"client_{i}"] = torch.randint(0, 1000, (1,)).item()
        
        self.server_key = torch.randint(0, 1000, (1,)).item()
    
    def encrypt_model_update(self, client_id: str, 
                           model_update: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Encrypt model update for secure aggregation"""
        if client_id not in self.client_keys:
            raise ValueError(f"No key found for client {client_id}")
        
        encrypted_update = {}
        client_key = self.client_keys[client_id]
        
        for layer_name, weights in model_update.items():
            # Simplified encryption - add noise based on key
            noise_scale = client_key / 1000.0
            noise = torch.normal(0, noise_scale, size=weights.shape)
            encrypted_update[layer_name] = weights + noise
        
        return encrypted_update
    
    def aggregate_encrypted_updates(self, 
                                  encrypted_updates: List[Dict[str, torch.Tensor]],
                                  client_ids: List[str]) -> Dict[str, torch.Tensor]:
        """Aggregate encrypted model updates"""
        if len(encrypted_updates) < self.threshold:
            raise ValueError(f"Not enough clients: {len(encrypted_updates)} < {self.threshold}")
        
        # Initialize aggregated result
        aggregated = {}
        for layer_name in encrypted_updates[0]:
            aggregated[layer_name] = torch.zeros_like(encrypted_updates[0][layer_name])
        
        # Sum encrypted updates
        for update in encrypted_updates:
            for layer_name in aggregated:
                aggregated[layer_name] += update[layer_name]
        
        # Average
        num_updates = len(encrypted_updates)
        for layer_name in aggregated:
            aggregated[layer_name] /= num_updates
        
        # NOTE: this simplified scheme cannot remove the per-client noise added in
        # encrypt_model_update; the zero-mean noise only shrinks through averaging.
        # A real secure aggregation protocol uses pairwise masks that cancel
        # exactly in the sum. client_ids is kept in the signature but unused here.
        
        return aggregated

# Usage example for privacy-preserving FL
def privacy_preserving_federated_learning():
    """Example with differential privacy and secure aggregation"""
    
    # Setup differential privacy
    dp_mechanism = DifferentialPrivacyMechanism(
        epsilon=1.0,
        delta=1e-5,
        sensitivity=1.0
    )
    
    # Setup secure aggregation
    secure_agg = SecureAggregation(num_clients=10, threshold=6)
    secure_agg.setup_keys()
    
    # Simulate model updates from clients
    model_updates = []
    client_ids = []
    
    for i in range(8):  # 8 clients participate
        client_id = f"client_{i}"
        client_ids.append(client_id)
        
        # Simulate model update
        update = {
            'layer1.weight': torch.randn(10, 5),
            'layer1.bias': torch.randn(10),
            'layer2.weight': torch.randn(1, 10),
            'layer2.bias': torch.randn(1)
        }
        
        # Apply differential privacy
        gradients = list(update.values())
        noisy_gradients = dp_mechanism.add_noise_to_gradients(gradients)
        
        # Reconstruct update dict
        noisy_update = {}
        for j, (layer_name, _) in enumerate(update.items()):
            noisy_update[layer_name] = noisy_gradients[j]
        
        # Encrypt for secure aggregation
        encrypted_update = secure_agg.encrypt_model_update(client_id, noisy_update)
        model_updates.append(encrypted_update)
    
    # Aggregate encrypted updates
    final_update = secure_agg.aggregate_encrypted_updates(model_updates, client_ids)
    
    # Compute privacy spent
    epsilon_spent, delta_spent = dp_mechanism.compute_privacy_spent(
        steps=50, batch_size=32, dataset_size=1000
    )
    
    print(f"Privacy budget spent: ε = {epsilon_spent:.3f}, δ = {delta_spent:.2e}")
    print("Secure aggregation completed successfully!")
    
    return final_update, epsilon_spent

if __name__ == "__main__":
    final_update, privacy_cost = privacy_preserving_federated_learning()
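
The noise scale used by `_calculate_gaussian_noise_scale` above can be checked with a small worked example. This sketch assumes the classical Gaussian mechanism bound, sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon, which is proved for 0 < epsilon < 1; the helper name `gaussian_sigma` is hypothetical.

```python
import math


def gaussian_sigma(epsilon: float, delta: float, sensitivity: float = 1.0) -> float:
    """Noise standard deviation for the classical Gaussian mechanism."""
    if epsilon <= 0 or not 0 < delta < 1:
        raise ValueError("require epsilon > 0 and 0 < delta < 1")
    return sensitivity * math.sqrt(2 * math.log(1.25 / delta)) / epsilon


sigma_loose = gaussian_sigma(epsilon=1.0, delta=1e-5)  # about 4.84
sigma_tight = gaussian_sigma(epsilon=0.5, delta=1e-5)  # twice sigma_loose
print(f"{sigma_loose:.2f} {sigma_tight:.2f}")
```

Halving epsilon doubles the noise, so a tighter privacy budget is paid for directly in signal-to-noise ratio. Over many training steps the total privacy cost has to be tracked with a proper accountant rather than this single-release formula.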

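The `SecureAggregation` class above only perturbs updates with noise. The core idea of real secure aggregation, pairwise masks that cancel in the sum, can be sketched as follows. This is a simplified illustration under stated assumptions: a single seeded generator stands in for pairwise key agreement, no client drops out, and updates are encoded as fixed-point integers so that cancellation is exact. The names `encode`, `decode`, `pairwise_masks`, `MODULUS`, and `SCALE` are hypothetical.

```python
import random
from typing import Dict, List

MODULUS = 2 ** 32   # all arithmetic is done modulo this value
SCALE = 10 ** 4     # fixed-point scale: 4 decimal places


def encode(values: List[float]) -> List[int]:
    return [round(v * SCALE) % MODULUS for v in values]


def decode(values: List[int]) -> List[float]:
    # Map residues in the upper half of the range back to negative numbers.
    return [(v - MODULUS if v >= MODULUS // 2 else v) / SCALE for v in values]


def pairwise_masks(client_ids: List[str], dim: int, seed: int = 0) -> Dict[str, List[int]]:
    """For each pair (a, b), a adds a shared random mask and b subtracts it."""
    rng = random.Random(seed)  # stand-in for pairwise key agreement
    masks = {cid: [0] * dim for cid in client_ids}
    for i, a in enumerate(client_ids):
        for b in client_ids[i + 1:]:
            shared = [rng.randrange(MODULUS) for _ in range(dim)]
            masks[a] = [(m + s) % MODULUS for m, s in zip(masks[a], shared)]
            masks[b] = [(m - s) % MODULUS for m, s in zip(masks[b], shared)]
    return masks


updates = {"c0": [0.5, -1.0], "c1": [1.5, 2.0], "c2": [-0.25, 0.75]}
masks = pairwise_masks(list(updates), dim=2)

# Each client uploads only its masked, fixed-point update.
masked = {
    cid: [(x + m) % MODULUS for x, m in zip(encode(vals), masks[cid])]
    for cid, vals in updates.items()
}

# The server sums the uploads; the pairwise masks cancel modulo MODULUS.
total = [sum(col) % MODULUS for col in zip(*masked.values())]
print(decode(total))  # [1.75, 1.75]
```

The server recovers the exact sum without seeing any individual update in the clear. Production protocols additionally handle client dropout, typically by secret-sharing the mask seeds, which this sketch omits.
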
Real-World Federated Learning Applications

Google - Gboard

Google uses federated learning to improve Gboard's next-word prediction and query suggestions, training on millions of mobile devices without collecting users' raw typing data.

  • 100M+ participating devices
  • Privacy-preserving language modeling
  • Secure aggregation protocol
  • Differential privacy integration

Apple - Health & Siri

Apple employs federated learning for health data analysis and Siri improvements, ensuring sensitive personal data remains on-device while benefiting from collective learning.

  • On-device health pattern recognition
  • Voice command optimization
  • Local differential privacy
  • Cross-device personalization

Healthcare Consortiums

Hospital networks use federated learning for medical imaging, drug discovery, and clinical decision support while maintaining patient privacy and regulatory compliance.

  • Multi-hospital collaboration
  • HIPAA-compliant training
  • Rare disease research
  • Clinical trial optimization

Financial Services

Banks collaborate on fraud detection and risk assessment models using federated learning, sharing insights without exposing sensitive customer transaction data.

  • Cross-bank fraud detection
  • Regulatory compliance (GDPR)
  • Risk model improvements
  • Anti-money laundering (AML)

Federated Learning Best Practices

✅ Do

Design for data heterogeneity and client diversity

Implement robust client selection strategies

Use differential privacy for strong privacy guarantees

Monitor client dropout and handle failures gracefully

Optimize communication efficiency with compression

Validate model quality across client populations

Implement secure aggregation for sensitive domains
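
One common way to optimize communication with compression is top-k sparsification: each client sends only the largest-magnitude entries of its update as index and value pairs. The sketch below is a minimal illustration using plain Python lists; `top_k_sparsify` and `densify` are hypothetical helper names, and the sample update is made up.

```python
from typing import List, Tuple


def top_k_sparsify(update: List[float], k: int) -> List[Tuple[int, float]]:
    """Keep the k largest-magnitude entries as (index, value) pairs."""
    if k <= 0:
        return []
    largest = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)[:k]
    return [(i, update[i]) for i in sorted(largest)]


def densify(sparse: List[Tuple[int, float]], length: int) -> List[float]:
    """Rebuild a dense update on the server; dropped entries become zero."""
    dense = [0.0] * length
    for index, value in sparse:
        dense[index] = value
    return dense


update = [0.01, -0.9, 0.05, 0.7, -0.02]
sparse = top_k_sparsify(update, k=2)
print(sparse)                        # [(1, -0.9), (3, 0.7)]
print(densify(sparse, len(update)))  # [0.0, -0.9, 0.0, 0.7, 0.0]
```

Here two of five values are transmitted. Dropping small entries is lossy, so implementations often accumulate the discarded remainder locally and add it to the next round's update.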

❌ Don't

Assume all clients have similar data distributions

Ignore communication costs and bandwidth constraints

Rely on basic FL without additional privacy mechanisms

Use overly complex models that don't converge well

Forget to handle client computational heterogeneity

Skip validation on diverse client populations

Underestimate the impact of client incentives
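
To test against non-identical client distributions before deployment, a simulation can split a labeled dataset unevenly across clients. The sketch below assumes a Dirichlet-based label skew, a widely used simulation approach; `dirichlet_partition` is a hypothetical helper, and the balanced 10-class label list is synthetic. Smaller `alpha` values produce more skewed clients.

```python
import random
from typing import Dict, List


def dirichlet_partition(labels: List[int], num_clients: int,
                        alpha: float, seed: int = 0) -> List[List[int]]:
    """Split sample indices across clients with per-class Dirichlet(alpha) proportions."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = {}
    for index, label in enumerate(labels):
        by_class.setdefault(label, []).append(index)

    clients: List[List[int]] = [[] for _ in range(num_clients)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # A Dirichlet draw is a vector of Gamma(alpha, 1) draws, normalized.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws) or 1.0
        start, cumulative = 0, 0.0
        for c in range(num_clients):
            cumulative += draws[c] / total
            if c == num_clients - 1:
                end = len(indices)  # last client takes whatever remains
            else:
                end = min(len(indices), round(cumulative * len(indices)))
            end = max(end, start)
            clients[c].extend(indices[start:end])
            start = end
    return clients


labels = [i % 10 for i in range(1000)]  # synthetic: 10 balanced classes
skewed = dirichlet_partition(labels, num_clients=5, alpha=0.1)
print([len(c) for c in skewed])
```

Evaluating the global model separately on each simulated client, rather than only on pooled data, shows whether accuracy holds up for clients whose label mix differs from the average.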
