Federated Learning Systems
Master privacy-preserving distributed machine learning across decentralized data sources
What is Federated Learning?
Federated Learning (FL) is a machine learning paradigm that enables multiple parties to collaboratively train a shared model without sharing their raw data. Instead of centralizing data, the model travels to the data and trains locally; only model updates are shared, which preserves privacy and reduces communication costs.
Privacy-Preserving
Raw data never leaves client devices or organizations
Communication Efficient
Share model updates instead of large datasets
Personalization
Adapt global model to local data distributions
Federated Learning Simulator
Configure federated learning parameters and see the impact on training dynamics
Simulation Results
Federated Learning Strategies
FedAvg (Federated Averaging)
Weighted average of client model updates
Good
Low
Basic
Production Federated Learning Framework
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
@dataclass
class FLConfig:
    """Settings for one federated learning run.

    The field order below is part of the generated constructor signature,
    so it must not be rearranged.
    """

    # --- Federation size and schedule ---
    # Number of clients the example script creates and registers.
    num_clients: int = 100
    # Upper bound on how many registered clients are sampled each round.
    clients_per_round: int = 10
    # Number of federated rounds executed by FederatedServer.train().
    num_rounds: int = 50

    # --- Local training hyperparameters pushed to clients every round ---
    local_epochs: int = 5
    learning_rate: float = 0.01
    batch_size: int = 32

    # --- Server behaviour ---
    # Prefix of the checkpoint file name written by the server.
    model_name: str = "federated_model"
    # 'fedavg' or 'fedprox'; the server falls back to FedAvg for other values.
    aggregation_strategy: str = "fedavg"
    # Client selection raises ValueError when fewer clients are registered.
    min_clients: int = 1

    # --- Privacy ---
    # 'differential_privacy' makes the server configure client-side noise.
    privacy_mechanism: Optional[str] = None
    # Privacy parameters used to derive the client noise multiplier.
    differential_privacy_epsilon: float = 1.0
    differential_privacy_delta: float = 1e-5
class Client:
    """Federated learning client that trains a private copy of the global model.

    A client owns its datasets, receives a model through ``set_model``, runs
    local SGD in ``train`` and reports parameter deltas through
    ``get_model_updates``.  ``local_epochs``, ``batch_size`` and
    ``learning_rate`` are plain attributes so a server can overwrite them
    before each round.
    """

    def __init__(self, client_id: str, train_dataset: Dataset,
                 val_dataset: Optional[Dataset] = None, device: str = 'cpu'):
        """Store the client's data and default training hyperparameters.

        Args:
            client_id: Identifier echoed back in every training result.
            train_dataset: Local training samples; must support ``len()`` and
                yield ``(data, target)`` pairs.
            val_dataset: Optional held-out samples for local accuracy.
            device: Torch device string the model and batches are moved to.
        """
        self.client_id = client_id
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.device = device
        self.model = None
        self.optimizer = None
        self.local_epochs = 5
        self.batch_size = 32
        self.learning_rate = 0.01
        # Privacy and security: noise is disabled by default, clipping is on.
        self.noise_multiplier = 0.0
        self.max_grad_norm = 1.0
        # One result dict is appended per completed call to train().
        self.training_history = []

    def _new_optimizer(self):
        """Build the SGD optimizer for the current model and learning rate."""
        return optim.SGD(self.model.parameters(),
                         lr=self.learning_rate,
                         momentum=0.9)

    def set_model(self, model: nn.Module):
        """Adopt ``model`` for local training and create a fresh optimizer."""
        self.model = model.to(self.device)
        self.optimizer = self._new_optimizer()

    def get_data_size(self) -> int:
        """Return the number of samples in the training dataset."""
        return len(self.train_dataset)

    def train(self, global_round: int) -> Dict[str, Any]:
        """Run ``local_epochs`` of local SGD and return a result summary.

        Args:
            global_round: Round index supplied by the caller; only recorded.

        Returns:
            Dict with ``client_id``, ``global_round``, ``local_epochs``,
            ``final_loss`` (mean batch loss of the last epoch, NaN when no
            epoch ran), ``data_size``, ``val_accuracy`` (``None`` without a
            validation set) and ``training_time`` (elapsed seconds spent in
            the training loop).

        Raises:
            ValueError: If no model has been set.
        """
        if self.model is None:
            raise ValueError("Model not set for client")
        if self.optimizer is None:
            self.optimizer = self._new_optimizer()
        # The optimizer is created in set_model(); learning_rate may have been
        # reassigned since then, so push the current value into the optimizer.
        for group in self.optimizer.param_groups:
            group['lr'] = self.learning_rate

        self.model.train()
        dataloader = DataLoader(self.train_dataset,
                                batch_size=self.batch_size,
                                shuffle=True)
        criterion = nn.CrossEntropyLoss()
        epoch_losses = []
        started = time.perf_counter()
        for _ in range(self.local_epochs):
            batch_losses = []
            for data, targets in dataloader:
                data, targets = data.to(self.device), targets.to(self.device)
                self.optimizer.zero_grad()
                loss = criterion(self.model(data), targets)
                loss.backward()
                # Clip the norm of the whole mini-batch gradient.
                # NOTE(review): this is not per-example clipping, so it is
                # presumably weaker than what formal DP-SGD requires.
                if self.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_grad_norm)
                if self.noise_multiplier > 0:
                    self._add_dp_noise()
                self.optimizer.step()
                batch_losses.append(loss.item())
            epoch_losses.append(
                float(np.mean(batch_losses)) if batch_losses else float('nan')
            )
        elapsed = time.perf_counter() - started

        val_accuracy = self._evaluate() if self.val_dataset is not None else None
        training_result = {
            'client_id': self.client_id,
            'global_round': global_round,
            'local_epochs': self.local_epochs,
            'final_loss': epoch_losses[-1] if epoch_losses else float('nan'),
            'data_size': self.get_data_size(),
            'val_accuracy': val_accuracy,
            'training_time': elapsed
        }
        self.training_history.append(training_result)
        return training_result

    def _add_dp_noise(self):
        """Add zero-mean Gaussian noise to every existing parameter gradient.

        The standard deviation is ``noise_multiplier * max_grad_norm``.
        """
        noise_std = self.noise_multiplier * self.max_grad_norm
        with torch.no_grad():
            for param in self.model.parameters():
                if param.grad is None:
                    continue
                noise = torch.normal(
                    mean=0.0,
                    std=noise_std,
                    size=param.grad.shape,
                    device=param.grad.device
                )
                param.grad.add_(noise)

    def _evaluate(self) -> float:
        """Return top-1 accuracy on the validation set.

        Returns 0.0 when there is no validation set or it yields no samples.
        Leaves the model in eval mode; ``train`` switches it back.
        """
        if self.val_dataset is None:
            return 0.0
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            dataloader = DataLoader(self.val_dataset,
                                    batch_size=self.batch_size,
                                    shuffle=False)
            for data, targets in dataloader:
                data, targets = data.to(self.device), targets.to(self.device)
                _, predicted = torch.max(self.model(data), 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total if total > 0 else 0.0

    def get_model_updates(self, initial_model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return ``current - initial`` for every entry of the state dict.

        Args:
            initial_model: Model holding the weights local training started from.

        Raises:
            ValueError: If no model has been set.
        """
        if self.model is None:
            raise ValueError("Model not set for client")
        initial_state = initial_model.state_dict()
        return {
            # Move the reference tensor next to the local one so the
            # subtraction also works when the two models live on different devices.
            key: tensor - initial_state[key].to(tensor.device)
            for key, tensor in self.model.state_dict().items()
        }

    def set_privacy_params(self, noise_multiplier: float = 0.0,
                           max_grad_norm: float = 1.0):
        """Set the gradient noise multiplier and the clipping norm."""
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
class AggregationStrategy(ABC):
    """Interface implemented by every server-side aggregation rule."""

    @abstractmethod
    def aggregate(self, client_updates: List[Dict],
                  client_weights: List[float]) -> Dict[str, torch.Tensor]:
        """Combine per-client updates into a single update.

        Args:
            client_updates: One mapping of tensor name to update per client.
            client_weights: One weight per entry of ``client_updates``.

        Returns:
            Mapping of tensor name to the combined update.  This base
            implementation does nothing and returns ``None``.
        """
class FedAvgAggregation(AggregationStrategy):
    """Federated Averaging: weighted mean of the client updates."""

    def aggregate(self, client_updates: List[Dict],
                  client_weights: List[float]) -> Dict[str, torch.Tensor]:
        """Return the weighted average of the client model updates.

        Weights are normalised to sum to one.  Floating-point (and complex)
        tensors are averaged in their own dtype.  Integer and boolean tensors
        (for example a BatchNorm ``num_batches_tracked`` buffer) are averaged
        in the default float dtype, rounded, and cast back, because an
        in-place add of a float result into an integer tensor is rejected by
        PyTorch.

        Args:
            client_updates: One mapping of tensor name to update per client.
                The key set of the first entry defines the output.
            client_weights: One non-normalised weight per client.

        Returns:
            Mapping of tensor name to aggregated update; empty when
            ``client_updates`` is empty.

        Raises:
            ValueError: If the two lists differ in length or the weights sum
                to zero.
        """
        if not client_updates:
            return {}
        if len(client_updates) != len(client_weights):
            raise ValueError(
                f"Got {len(client_updates)} updates but "
                f"{len(client_weights)} weights"
            )
        total_weight = sum(client_weights)
        if total_weight == 0:
            raise ValueError("Client weights sum to zero; cannot normalise")
        weights = [w / total_weight for w in client_weights]

        aggregated = {}
        for key, reference in client_updates[0].items():
            native = reference.is_floating_point() or reference.is_complex()
            if native:
                accumulator = torch.zeros_like(reference)
            else:
                accumulator = torch.zeros_like(
                    reference, dtype=torch.get_default_dtype()
                )
            # Clients are visited in list order so float results match a
            # plain sequential weighted sum.
            for client_update, weight in zip(client_updates, weights):
                accumulator += weight * client_update[key]
            if native:
                aggregated[key] = accumulator
            else:
                aggregated[key] = accumulator.round().to(reference.dtype)
        return aggregated
class FedProxAggregation(AggregationStrategy):
    """FedProx strategy; server-side it currently behaves like FedAvg."""

    def __init__(self, mu: float = 0.01):
        """Remember the proximal term weight ``mu``.

        ``mu`` is only stored; ``aggregate`` does not read it.
        """
        self.mu = mu

    def aggregate(self, client_updates: List[Dict],
                  client_weights: List[float]) -> Dict[str, torch.Tensor]:
        """Delegate to a fresh ``FedAvgAggregation`` instance.

        FedProx differs from FedAvg in the clients' local objective, which
        is not implemented here, so the aggregation step is plain weighted
        averaging.
        """
        return FedAvgAggregation().aggregate(client_updates, client_weights)
class FederatedServer:
    """Central coordinator that runs synchronous federated training rounds.

    Each round samples registered clients, sends them a copy of the global
    model, trains them in a thread pool, aggregates their parameter deltas
    and adds the aggregate to the global model.
    """

    def __init__(self, model: nn.Module, config: FLConfig, device: str = 'cpu'):
        """Set up the global model, the aggregation strategy and empty state.

        Args:
            model: Global model; moved to ``device``.
            config: Run configuration.
            device: Torch device string for the global model.
        """
        self.model = model.to(device)
        self.config = config
        self.device = device
        # 'fedprox' selects FedProx; every other value (including unknown
        # names) falls back to FedAvg.
        if config.aggregation_strategy == 'fedprox':
            self.aggregator = FedProxAggregation()
        else:
            self.aggregator = FedAvgAggregation()
        # One statistics dict is appended per completed round.
        self.training_history = []
        self.global_model_history = []
        # Registered clients keyed by client_id.
        self.clients: Dict[str, Client] = {}

    def register_client(self, client: Client):
        """Register ``client``, hand it a model copy and configure privacy."""
        self.clients[client.client_id] = client
        client.set_model(self._get_model_copy())
        if self.config.privacy_mechanism == 'differential_privacy':
            noise_multiplier = self._calculate_noise_multiplier()
            client.set_privacy_params(noise_multiplier=noise_multiplier)

    def _get_model_copy(self) -> nn.Module:
        """Return an independent deep copy of the current global model.

        A deep copy is used instead of re-instantiating ``type(model)()`` so
        models built with constructor arguments are copied with the same
        architecture.
        """
        import copy  # local import keeps this block self-contained

        return copy.deepcopy(self.model)

    def _calculate_noise_multiplier(self) -> float:
        """Derive the noise multiplier from the configured epsilon and delta.

        Simplified closed-form calculation; use a privacy accounting tool
        for real deployments.
        """
        epsilon = self.config.differential_privacy_epsilon
        delta = self.config.differential_privacy_delta
        return float(np.sqrt(2 * np.log(1.25 / delta)) / epsilon)

    def select_clients(self, round_num: int) -> List[Client]:
        """Uniformly sample up to ``clients_per_round`` registered clients.

        Args:
            round_num: Current round index (not used by random selection).

        Raises:
            ValueError: If fewer than ``min_clients`` clients are registered.
        """
        available_clients = list(self.clients.values())
        if len(available_clients) < self.config.min_clients:
            raise ValueError(f"Not enough clients: {len(available_clients)} < {self.config.min_clients}")
        num_selected = min(self.config.clients_per_round, len(available_clients))
        return random.sample(available_clients, num_selected)

    def train_round(self, round_num: int) -> Dict[str, Any]:
        """Execute one federated round and return its statistics.

        Clients whose training raises are reported on stdout and skipped.

        Returns:
            Dict with ``round``, ``num_clients`` (selected, not necessarily
            successful), ``avg_local_loss`` (NaN when no client succeeded),
            ``global_accuracy``, ``total_data_points`` and ``timestamp``.
        """
        print(f"\nStarting federated round {round_num + 1}/{self.config.num_rounds}")
        selected_clients = self.select_clients(round_num)
        print(f"Selected {len(selected_clients)} clients for training")

        initial_model = self._get_model_copy()
        for client in selected_clients:
            # Hyperparameters go first: set_model() builds the optimizer
            # from the client's learning_rate at call time.
            client.local_epochs = self.config.local_epochs
            client.learning_rate = self.config.learning_rate
            client.batch_size = self.config.batch_size
            client.set_model(self._get_model_copy())

        client_results = []
        client_updates = []
        client_weights = []
        # ThreadPoolExecutor rejects max_workers=0, so keep at least one.
        worker_count = max(1, min(len(selected_clients), 10))
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            future_to_client = {
                executor.submit(client.train, round_num): client
                for client in selected_clients
            }
            for future in as_completed(future_to_client):
                client = future_to_client[future]
                try:
                    result = future.result()
                    updates = client.get_model_updates(initial_model)
                except Exception as exc:
                    # Best effort: one failing client must not abort the round.
                    print(f'Client {client.client_id} generated an exception: {exc}')
                else:
                    client_results.append(result)
                    client_updates.append(updates)
                    client_weights.append(client.get_data_size())

        if client_updates:
            aggregated_updates = self.aggregator.aggregate(client_updates, client_weights)
            global_state = self.model.state_dict()
            for key, delta in aggregated_updates.items():
                target = global_state[key]
                # A float delta cannot be added in place to an integer
                # buffer; round it before casting to the target dtype.
                if delta.is_floating_point() and not target.is_floating_point():
                    delta = delta.round()
                target += delta.to(device=target.device, dtype=target.dtype)
            self.model.load_state_dict(global_state)

        global_accuracy = self.evaluate_global_model()
        local_losses = [r['final_loss'] for r in client_results]
        round_stats = {
            'round': round_num,
            'num_clients': len(selected_clients),
            'avg_local_loss': float(np.mean(local_losses)) if local_losses else float('nan'),
            'global_accuracy': global_accuracy,
            'total_data_points': sum(client_weights),
            'timestamp': time.time()
        }
        self.training_history.append(round_stats)
        print(f"Round {round_num + 1} completed:")
        print(f" Average local loss: {round_stats['avg_local_loss']:.4f}")
        print(f" Global accuracy: {global_accuracy:.4f}")
        return round_stats

    def evaluate_global_model(self) -> float:
        """Return a placeholder accuracy.

        No evaluation is performed: the value is 0.85 plus Gaussian noise
        with standard deviation 0.02.  Replace with evaluation on a held-out
        server test set.
        """
        return 0.85 + np.random.normal(0, 0.02)

    def train(self) -> Dict[str, Any]:
        """Run ``num_rounds`` rounds, checkpointing after every tenth round.

        Returns:
            Dict with ``total_rounds``, ``final_accuracy`` (``None`` when no
            round has been recorded) and ``training_history``.
        """
        print(f"Starting federated learning with {len(self.clients)} clients")
        print(f"Configuration: {self.config}")
        for round_num in range(self.config.num_rounds):
            self.train_round(round_num)
            if (round_num + 1) % 10 == 0:
                self.save_checkpoint(round_num)
        if self.training_history:
            final_accuracy = self.training_history[-1]['global_accuracy']
        else:
            final_accuracy = None
        return {
            'total_rounds': self.config.num_rounds,
            'final_accuracy': final_accuracy,
            'training_history': self.training_history
        }

    def save_checkpoint(self, round_num: int):
        """Write model weights, config and history to ``<model_name>_round_<n>.pt``.

        NOTE(review): the config object is pickled as-is, so loading the
        file presumably needs the ``FLConfig`` class importable and may be
        refused by weights-only loading - confirm before relying on it.
        """
        checkpoint = {
            'round': round_num,
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
            'training_history': self.training_history
        }
        torch.save(checkpoint, f'{self.config.model_name}_round_{round_num}.pt')
# Example neural network for federated learning
class FederatedNet(nn.Module):
    """Small two-convolution CNN used as the example global model.

    ``fc1`` expects 9216 = 64 * 12 * 12 flattened features, which is what
    the two 3x3 convolutions followed by 2x2 max pooling produce for a
    1x28x28 input.
    """

    def __init__(self, num_classes: int = 10):
        """Create the layers; ``num_classes`` sets the output width."""
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs."""
        features = torch.relu(self.conv1(x))
        features = torch.relu(self.conv2(features))
        features = self.dropout1(torch.max_pool2d(features, 2))
        hidden = torch.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return torch.log_softmax(logits, dim=1)
# Usage example
def run_federated_learning():
    """Run the end-to-end demo on randomly generated data.

    Builds a configuration with differential privacy enabled, registers
    ``config.num_clients`` clients that each hold 100 random 1x28x28 samples
    with labels in 0..9, trains, and prints the final accuracy.

    Returns:
        Tuple ``(results, server)``: the dict returned by the server's
        ``train`` method and the server instance itself.
    """
    config = FLConfig(
        num_clients=50,
        clients_per_round=10,
        num_rounds=25,
        local_epochs=3,
        learning_rate=0.01,
        aggregation_strategy='fedavg',
        privacy_mechanism='differential_privacy',
        differential_privacy_epsilon=1.0
    )
    server = FederatedServer(FederatedNet(num_classes=10), config)

    # Stand-in for real data loading: every client gets its own random set.
    for index in range(config.num_clients):
        samples = []
        for _ in range(100):
            image = torch.randn(1, 28, 28)
            label = torch.randint(0, 10, (1,))[0]
            samples.append((image, label))
        server.register_client(
            Client(
                client_id=f"client_{index}",
                train_dataset=samples,
                device='cpu'
            )
        )

    results = server.train()
    print("\nFederated Learning Completed!")
    print(f"Final global accuracy: {results['final_accuracy']:.4f}")
    return results, server
if __name__ == "__main__":
    results, server = run_federated_learning()

Privacy-Preserving Mechanisms
Basic FL
Standard federated learning without additional privacy
Limited
None
import torch
import numpy as np
from typing import List, Tuple
import math
class DifferentialPrivacyMechanism:
    """Gradient clipping and noise addition for differential privacy.

    Supports the ``'gaussian'`` and ``'laplace'`` mechanisms.  Any other
    mechanism name is accepted by the constructor but rejected by
    ``add_noise_to_gradients``.
    """

    def __init__(self, epsilon: float = 1.0, delta: float = 1e-5,
                 sensitivity: float = 1.0, mechanism: str = 'gaussian'):
        """Validate the privacy parameters and derive the noise scale.

        Args:
            epsilon: Privacy budget; must be positive.
            delta: Failure probability; must lie in (0, 1) for the Gaussian
                mechanism (the Laplace mechanism does not use it).
            sensitivity: Sensitivity the noise is calibrated to; must not be
                negative.
            mechanism: ``'gaussian'`` or ``'laplace'``.

        Raises:
            ValueError: If a parameter is outside its valid range.
        """
        if epsilon <= 0:
            raise ValueError(f"epsilon must be positive, got {epsilon}")
        if sensitivity < 0:
            raise ValueError(f"sensitivity must be non-negative, got {sensitivity}")
        if mechanism == 'gaussian' and not 0 < delta < 1:
            raise ValueError(f"delta must be in (0, 1), got {delta}")
        self.epsilon = epsilon
        self.delta = delta
        self.sensitivity = sensitivity
        self.mechanism = mechanism
        # Only the attribute matching the chosen mechanism is created.
        if mechanism == 'gaussian':
            self.sigma = self._calculate_gaussian_noise_scale()
        elif mechanism == 'laplace':
            self.b = sensitivity / epsilon

    def _calculate_gaussian_noise_scale(self) -> float:
        """Return sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon.

        This is the classical Gaussian-mechanism calibration.
        NOTE(review): that bound is usually stated for epsilon < 1 - confirm
        it is acceptable for the epsilons used here.
        """
        return self.sensitivity * math.sqrt(2 * math.log(1.25 / self.delta)) / self.epsilon

    def add_noise_to_gradients(self, gradients: List[torch.Tensor],
                               clip_norm: float = 1.0) -> List[torch.Tensor]:
        """Clip each gradient tensor and add mechanism noise to it.

        Each tensor is clipped on its own to ``clip_norm``; the noise scale
        comes from the constructor's ``sensitivity``, not from ``clip_norm``.
        NOTE(review): the two presumably need to agree for the stated
        guarantee - confirm with callers that pass a non-default clip_norm.

        Args:
            gradients: Tensors to privatise; they are not modified.
            clip_norm: Maximum L2 norm allowed per tensor.

        Returns:
            New list with one noisy tensor per input tensor.

        Raises:
            ValueError: If the mechanism is unknown and ``gradients`` is
                not empty.
        """
        noisy_gradients = []
        for grad in gradients:
            clipped_grad = self._clip_gradient(grad, clip_norm)
            if self.mechanism == 'gaussian':
                noise = torch.normal(
                    mean=0.0,
                    std=self.sigma,
                    size=grad.shape,
                    device=grad.device
                )
            elif self.mechanism == 'laplace':
                noise = torch.distributions.Laplace(0, self.b).sample(grad.shape).to(grad.device)
            else:
                raise ValueError(f"Unknown mechanism: {self.mechanism}")
            noisy_gradients.append(clipped_grad + noise)
        return noisy_gradients

    def _clip_gradient(self, gradient: torch.Tensor, clip_norm: float) -> torch.Tensor:
        """Rescale ``gradient`` so its norm is at most ``clip_norm``.

        Tensors already within the bound are returned unchanged (same object).
        """
        grad_norm = torch.norm(gradient)
        if grad_norm > clip_norm:
            return gradient * (clip_norm / grad_norm)
        return gradient

    def compute_privacy_spent(self, steps: int, batch_size: int,
                              dataset_size: int) -> Tuple[float, float]:
        """Estimate the cumulative ``(epsilon, delta)`` after ``steps`` steps.

        Simplified accounting: the Gaussian branch evaluates one fixed
        Renyi order (alpha = 2); every other mechanism uses basic
        composition ``epsilon * steps`` with delta 0.  Use a dedicated
        accountant such as Opacus for production.

        Raises:
            ValueError: If ``dataset_size`` is not positive.
        """
        if dataset_size <= 0:
            raise ValueError(f"dataset_size must be positive, got {dataset_size}")
        if self.mechanism != 'gaussian':
            return self.epsilon * steps, 0.0
        q = batch_size / dataset_size  # sampling ratio
        alpha = 2.0  # Renyi order
        rdp = alpha * q**2 * steps / (2 * self.sigma**2)
        # Convert the Renyi bound to (epsilon, delta).
        epsilon_spent = rdp + math.log(1 / self.delta) / (alpha - 1)
        return epsilon_spent, self.delta
class SecureAggregation:
    """Toy mask-and-unmask aggregation used to illustrate secure aggregation.

    Each client adds a pseudo-random mask derived from its key; the
    aggregator regenerates the same masks and removes them before averaging.
    Because this object holds every client key it can unmask any single
    update, so this offers NO real confidentiality - use a proper
    cryptographic protocol in production.
    """

    def __init__(self, num_clients: int, threshold: int):
        """Store the population size and the minimum number of updates.

        Args:
            num_clients: Number of keys ``setup_keys`` generates.
            threshold: Minimum number of updates required to aggregate.
        """
        self.num_clients = num_clients
        self.threshold = threshold
        # Placeholders; real deployments would hold cryptographic keys.
        self.client_keys = {}
        self.server_key = None

    def setup_keys(self):
        """Draw an integer key in [0, 1000) per ``client_<i>`` and for the server.

        Simplified key setup - not cryptographically secure.
        """
        for i in range(self.num_clients):
            self.client_keys[f"client_{i}"] = torch.randint(0, 1000, (1,)).item()
        self.server_key = torch.randint(0, 1000, (1,)).item()

    def _client_masks(self, client_id: str, shapes: Dict[str, torch.Size]) -> Dict[str, torch.Tensor]:
        """Regenerate the deterministic per-layer masks of one client.

        A private generator seeded with the client key is consumed in sorted
        layer-name order, so encryption and aggregation obtain identical
        masks independent of dict ordering and of the global RNG state.
        """
        client_key = self.client_keys[client_id]
        noise_scale = client_key / 1000.0
        generator = torch.Generator()
        generator.manual_seed(int(client_key))
        return {
            layer_name: torch.randn(tuple(shapes[layer_name]), generator=generator) * noise_scale
            for layer_name in sorted(shapes)
        }

    def encrypt_model_update(self, client_id: str,
                             model_update: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a masked copy of ``model_update``; the input is not modified.

        Raises:
            ValueError: If ``client_id`` has no key.
        """
        if client_id not in self.client_keys:
            raise ValueError(f"No key found for client {client_id}")
        masks = self._client_masks(
            client_id, {name: weights.shape for name, weights in model_update.items()}
        )
        return {
            name: weights + masks[name].to(weights.device)
            for name, weights in model_update.items()
        }

    def aggregate_encrypted_updates(self,
                                    encrypted_updates: List[Dict[str, torch.Tensor]],
                                    client_ids: List[str]) -> Dict[str, torch.Tensor]:
        """Unmask and average the clients' masked updates.

        Args:
            encrypted_updates: Outputs of ``encrypt_model_update``.
            client_ids: Client id for each update, in the same order.

        Returns:
            Mapping of layer name to the mean of the unmasked updates; empty
            when there are no updates (only possible with ``threshold <= 0``).

        Raises:
            ValueError: If fewer than ``threshold`` updates are supplied or
                the two lists differ in length.
            KeyError: If a client id has no key.
        """
        if len(encrypted_updates) < self.threshold:
            raise ValueError(f"Not enough clients: {len(encrypted_updates)} < {self.threshold}")
        if len(client_ids) != len(encrypted_updates):
            raise ValueError(
                f"Got {len(encrypted_updates)} updates but {len(client_ids)} client ids"
            )
        if not encrypted_updates:
            return {}

        reference = encrypted_updates[0]
        shapes = {name: tensor.shape for name, tensor in reference.items()}
        aggregated = {name: torch.zeros_like(tensor) for name, tensor in reference.items()}
        for client_id, update in zip(client_ids, encrypted_updates):
            masks = self._client_masks(client_id, shapes)
            for layer_name in aggregated:
                # Remove this client's mask before its update enters the sum.
                masked = update[layer_name]
                aggregated[layer_name] += masked - masks[layer_name].to(masked.device)

        num_updates = len(encrypted_updates)
        for layer_name in aggregated:
            aggregated[layer_name] /= num_updates
        return aggregated
# Usage example for privacy-preserving FL
def privacy_preserving_federated_learning():
    """Demonstrate differential privacy combined with secure aggregation.

    Eight simulated clients each create a random four-tensor update, pass it
    through the differential privacy mechanism, mask it for secure
    aggregation and submit it; the masked updates are then aggregated and
    the privacy budget for a fixed training schedule is printed.

    Returns:
        Tuple ``(final_update, epsilon_spent)``.
    """
    dp_mechanism = DifferentialPrivacyMechanism(
        epsilon=1.0,
        delta=1e-5,
        sensitivity=1.0
    )
    secure_agg = SecureAggregation(num_clients=10, threshold=6)
    secure_agg.setup_keys()

    # Only 8 of the 10 provisioned clients take part.
    client_ids = [f"client_{i}" for i in range(8)]
    model_updates = []
    for client_id in client_ids:
        # Random stand-in for a real model update.
        update = {
            'layer1.weight': torch.randn(10, 5),
            'layer1.bias': torch.randn(10),
            'layer2.weight': torch.randn(1, 10),
            'layer2.bias': torch.randn(1)
        }
        noisy_tensors = dp_mechanism.add_noise_to_gradients(list(update.values()))
        # Output order matches input order, so names can be zipped back on.
        noisy_update = dict(zip(update, noisy_tensors))
        model_updates.append(
            secure_agg.encrypt_model_update(client_id, noisy_update)
        )

    final_update = secure_agg.aggregate_encrypted_updates(model_updates, client_ids)
    epsilon_spent, delta_spent = dp_mechanism.compute_privacy_spent(
        steps=50, batch_size=32, dataset_size=1000
    )
    print(f"Privacy budget spent: ε = {epsilon_spent:.3f}, δ = {delta_spent:.2e}")
    print("Secure aggregation completed successfully!")
    return final_update, epsilon_spent
if __name__ == "__main__":
    final_update, privacy_cost = privacy_preserving_federated_learning()

Real-World Federated Learning Applications
Google - Gboard
Google uses federated learning for improving Gboard's next-word prediction and query suggestions without accessing user typing data, training on millions of mobile devices.
- 100M+ participating devices
- Privacy-preserving language modeling
- Secure aggregation protocol
- Differential privacy integration
Apple - Health & Siri
Apple employs federated learning for health data analysis and Siri improvements, ensuring sensitive personal data remains on-device while benefiting from collective learning.
- On-device health pattern recognition
- Voice command optimization
- Local differential privacy
- Cross-device personalization
Healthcare Consortiums
Hospital networks use federated learning for medical imaging, drug discovery, and clinical decision support while maintaining patient privacy and regulatory compliance.
- Multi-hospital collaboration
- HIPAA-compliant training
- Rare disease research
- Clinical trial optimization
Financial Services
Banks collaborate on fraud detection and risk assessment models using federated learning, sharing insights without exposing sensitive customer transaction data.
- Cross-bank fraud detection
- Regulatory compliance (GDPR)
- Risk model improvements
- Anti-money laundering (AML)
Federated Learning Best Practices
✅ Do
Design for data heterogeneity and client diversity
Implement robust client selection strategies
Use differential privacy for strong privacy guarantees
Monitor client dropout and handle failures gracefully
Optimize communication efficiency with compression
Validate model quality across client populations
Implement secure aggregation for sensitive domains
❌ Don't
Assume all clients have similar data distributions
Ignore communication costs and bandwidth constraints
Rely on basic FL without additional privacy mechanisms
Use overly complex models that don't converge well
Forget to handle client computational heterogeneity
Skip validation on diverse client populations
Underestimate the impact of client incentives