
Mixture of Experts (MoE)

Master sparse model architectures with expert routing for efficient large-scale neural networks

60 min read · Advanced

What is a Mixture of Experts?

Mixture of Experts (MoE) is a sparse neural network architecture in which only a subset of the model's parameters (the experts) is activated for each input. A routing mechanism decides which experts process each token, so the model can scale to very large parameter counts while keeping per-token computational cost roughly constant.

Sparse Activation

Only 10-20% of parameters active per forward pass

Scalable Training

Scale to trillions of parameters with constant compute

Expert Specialization

Different experts learn specialized sub-tasks

MoE Efficiency Calculator

Analyze computational efficiency and resource usage of Mixture of Experts models

Efficiency Metrics

[Interactive calculator output: active parameters, memory usage, total FLOPs, expert utilization, sparsity ratio, and routing overhead for the selected configuration.]
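
As a rough back-of-the-envelope sketch of where figures like these come from (the configuration below is hypothetical, not any specific model), active parameters and sparsity follow directly from the expert count and routing fan-out:

# Illustrative estimate of MoE sparsity (hypothetical configuration)
shared_params = 10e9        # embeddings, attention, norms touched by every token
expert_params = 200e9       # parameters spread across all experts
num_experts = 64
top_k = 2                   # experts activated per token

# Each token only runs through top_k of the num_experts expert blocks
active_expert_params = expert_params * top_k / num_experts
active_params = shared_params + active_expert_params
total_params = shared_params + expert_params

sparsity_ratio = 1 - active_params / total_params
print(f"Active parameters: {active_params/1e9:.1f}B of {total_params/1e9:.0f}B "
      f"({sparsity_ratio:.0%} of parameters skipped per token)")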

Expert Routing Strategies

Top-K Routing

Route tokens to the top-K highest-scoring experts.

Complexity: Low · Load Balance: Poor · Performance: Good
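
As a standalone illustration (a toy sketch, separate from the full framework below), top-K gating boils down to a softmax over expert scores, a top-k selection, and a renormalization of the kept weights:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden_dim, num_experts, top_k = 6, 16, 4, 2

x = torch.randn(num_tokens, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

probs = F.softmax(gate(x), dim=-1)                              # [tokens, experts]
topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)         # keep the 2 best experts
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)  # renormalize kept weights

print("expert indices per token:\n", topk_idx)
print("normalized weights per token:\n", topk_probs)

Each token ends up with K expert indices and K mixing weights; the full framework below adds noise, load-balancing losses, and the actual expert computation on top of this core step.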

Production MoE Implementation

Complete Mixture of Experts Framework
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
from typing import Tuple, List, Optional, Dict
import math
from dataclasses import dataclass

@dataclass
class MoEConfig:
    """Configuration for Mixture of Experts"""
    num_experts: int = 8
    top_k: int = 2
    capacity_factor: float = 1.25
    expert_dropout: float = 0.0
    gate_noise: float = 0.1
    load_balancing_loss_weight: float = 0.01
    router_z_loss_weight: float = 0.001
    
class Expert(nn.Module):
    """Individual expert network"""
    
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False)
        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False)  # GLU gate
        self.dropout = nn.Dropout(dropout)
        
        # Expert-specific initialization
        self._init_weights()
    
    def _init_weights(self):
        """Initialize expert weights with different variance for specialization"""
        for module in [self.w1, self.w2, self.w3]:
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """GLU-based expert forward pass"""
        gate = torch.sigmoid(self.w3(x))
        hidden = F.silu(self.w1(x)) * gate
        hidden = self.dropout(hidden)
        return self.w2(hidden)

class TopKRouter(nn.Module):
    """Top-K routing with load balancing"""
    
    def __init__(self, input_dim: int, num_experts: int, top_k: int, 
                 gate_noise: float = 0.1, capacity_factor: float = 1.25):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate_noise = gate_noise
        self.capacity_factor = capacity_factor
        
        # Gating network
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.gate_noise_scale = gate_noise
        
        # For load balancing
        self.register_buffer('expert_counts', torch.zeros(num_experts))
        self.register_buffer('total_tokens', torch.tensor(0))
        
    def _add_noise(self, logits: torch.Tensor) -> torch.Tensor:
        """Add noise during training for exploration"""
        if self.training and self.gate_noise_scale > 0:
            noise = torch.randn_like(logits) * self.gate_noise_scale
            return logits + noise
        return logits
    
    def _compute_load_balancing_loss(self, gates: torch.Tensor, 
                                   expert_indices: torch.Tensor) -> torch.Tensor:
        """Compute load balancing loss to encourage even expert usage"""
        # Count tokens per expert
        expert_counts = torch.bincount(expert_indices.flatten(), minlength=self.num_experts)
        
        # Compute mean gate values per expert
        gates_mean = torch.mean(gates, dim=0)
        
        # Load balancing loss (encourage uniform distribution)
        load_loss = torch.sum(gates_mean * expert_counts.float()) * self.num_experts
        load_loss = load_loss / (torch.sum(expert_counts) + 1e-6)
        
        return load_loss
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Route tokens to experts
        
        Args:
            x: Input tensor [batch_size, seq_len, hidden_dim]
            
        Returns:
            expert_weights: Gating weights [batch_size, seq_len, top_k]
            expert_indices: Selected expert indices [batch_size, seq_len, top_k]
            load_balancing_loss: Loss to encourage uniform expert usage
            router_z_loss: Loss that keeps router logits small (z-loss)
        """
        batch_size, seq_len, hidden_dim = x.shape
        
        # Reshape for routing
        x_flat = x.view(-1, hidden_dim)  # [batch_size * seq_len, hidden_dim]
        
        # Compute gate logits
        gate_logits = self.gate(x_flat)  # [batch_size * seq_len, num_experts]
        gate_logits = self._add_noise(gate_logits)
        
        # Apply softmax to get probabilities
        gate_probs = F.softmax(gate_logits, dim=-1)
        
        # Select top-k experts
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        
        # Normalize top-k probabilities
        top_k_probs = top_k_probs / (torch.sum(top_k_probs, dim=-1, keepdim=True) + 1e-6)
        
        # Compute losses
        load_balancing_loss = self._compute_load_balancing_loss(gate_probs, top_k_indices)
        
        # Router z-loss (keep gate logits from growing too large)
        router_z_loss = torch.mean(torch.logsumexp(gate_logits, dim=-1) ** 2)
        
        # Reshape back to original dimensions
        expert_weights = top_k_probs.view(batch_size, seq_len, self.top_k)
        expert_indices = top_k_indices.view(batch_size, seq_len, self.top_k)
        
        return expert_weights, expert_indices, load_balancing_loss, router_z_loss

class SwitchRouter(nn.Module):
    """Switch Transformer routing with capacity constraints"""
    
    def __init__(self, input_dim: int, num_experts: int, capacity_factor: float = 1.25):
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Switch routing with capacity constraints
        
        Args:
            x: Input tensor [batch_size, seq_len, hidden_dim]
            
        Returns:
            dispatch_tensor: Binary tensor indicating token-expert assignment
            combine_tensor: Tensor for combining expert outputs  
            load_balancing_loss: Loss for load balancing
        """
        batch_size, seq_len, hidden_dim = x.shape
        num_tokens = batch_size * seq_len
        
        # Compute capacity per expert
        capacity = int(self.capacity_factor * num_tokens / self.num_experts)
        
        # Flatten input
        x_flat = x.view(-1, hidden_dim)
        
        # Compute gate probabilities
        gate_logits = self.gate(x_flat)
        gate_probs = F.softmax(gate_logits, dim=-1)
        
        # Select top expert for each token
        expert_probs, expert_indices = torch.max(gate_probs, dim=-1)
        
        # Create dispatch tensor with capacity constraints
        dispatch_tensor = torch.zeros(num_tokens, self.num_experts, device=x.device)
        combine_tensor = torch.zeros(num_tokens, self.num_experts, device=x.device)
        
        # Track capacity usage per expert
        expert_counts = torch.zeros(self.num_experts, device=x.device)
        
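        # Note: this per-token Python loop is kept for readability; production implementations vectorize the dispatch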
        for token_idx in range(num_tokens):
            expert_idx = expert_indices[token_idx].item()
            if expert_counts[expert_idx] < capacity:
                dispatch_tensor[token_idx, expert_idx] = 1.0
                combine_tensor[token_idx, expert_idx] = expert_probs[token_idx]
                expert_counts[expert_idx] += 1
        
        # Compute load balancing loss
        gate_mean = torch.mean(gate_probs, dim=0)
        expert_fraction = expert_counts / num_tokens
        load_balancing_loss = self.num_experts * torch.sum(gate_mean * expert_fraction)
        
        return dispatch_tensor, combine_tensor, load_balancing_loss

class MixtureOfExperts(nn.Module):
    """Complete Mixture of Experts layer"""
    
    def __init__(self, config: MoEConfig, input_dim: int, expert_hidden_dim: int):
        super().__init__()
        self.config = config
        self.input_dim = input_dim
        self.expert_hidden_dim = expert_hidden_dim
        
        # Create experts
        self.experts = nn.ModuleList([
            Expert(input_dim, expert_hidden_dim, input_dim, config.expert_dropout)
            for _ in range(config.num_experts)
        ])
        
        # Create router
        if config.top_k > 1:
            self.router = TopKRouter(
                input_dim, config.num_experts, config.top_k,
                config.gate_noise, config.capacity_factor
            )
            self.routing_strategy = 'top_k'
        else:
            self.router = SwitchRouter(
                input_dim, config.num_experts, config.capacity_factor
            )
            self.routing_strategy = 'switch'
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through MoE layer
        
        Args:
            x: Input tensor [batch_size, seq_len, hidden_dim]
            
        Returns:
            output: Mixed expert outputs [batch_size, seq_len, hidden_dim]
            aux_losses: Dictionary of auxiliary losses for training
        """
        batch_size, seq_len, hidden_dim = x.shape
        original_shape = x.shape
        x_flat = x.view(-1, hidden_dim)
        
        if self.routing_strategy == 'top_k':
            return self._forward_top_k(x, x_flat, original_shape)
        else:
            return self._forward_switch(x, x_flat, original_shape)
    
    def _forward_top_k(self, x: torch.Tensor, x_flat: torch.Tensor, 
                      original_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Top-K routing forward pass"""
        expert_weights, expert_indices, load_loss, router_z_loss = self.router(x)
        
        # Flatten for processing
        expert_weights_flat = expert_weights.view(-1, self.config.top_k)
        expert_indices_flat = expert_indices.view(-1, self.config.top_k)
        
        # Process tokens through selected experts
        output = torch.zeros_like(x_flat)
        
        for k in range(self.config.top_k):
            # Get tokens and their assigned experts for this k
            token_expert_indices = expert_indices_flat[:, k]
            token_weights = expert_weights_flat[:, k]
            
            # Process each expert
            for expert_idx in range(self.config.num_experts):
                # Find tokens assigned to this expert
                expert_mask = (token_expert_indices == expert_idx)
                if not expert_mask.any():
                    continue
                
                # Extract tokens for this expert
                expert_tokens = x_flat[expert_mask]
                
                # Process through expert
                expert_output = self.experts[expert_idx](expert_tokens)
                
                # Weighted contribution back to output
                weights = token_weights[expert_mask].unsqueeze(-1)
                output[expert_mask] += weights * expert_output
        
        # Reshape back to original dimensions
        output = output.view(original_shape)
        
        # Auxiliary losses
        aux_losses = {
            'load_balancing_loss': load_loss * self.config.load_balancing_loss_weight,
            'router_z_loss': router_z_loss * self.config.router_z_loss_weight
        }
        
        return output, aux_losses
    
    def _forward_switch(self, x: torch.Tensor, x_flat: torch.Tensor,
                       original_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Switch routing forward pass"""
        dispatch_tensor, combine_tensor, load_loss = self.router(x)
        
        # Process each expert on the tokens that were dispatched to it
        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            token_mask = dispatch_tensor[:, expert_idx].bool()
            if not token_mask.any():
                continue
            
            # Run the expert only on its assigned tokens
            expert_output = expert(x_flat[token_mask])
            
            # Scale each token's output by its gate probability from the combine tensor
            gate = combine_tensor[token_mask, expert_idx].unsqueeze(-1)
            output[token_mask] = gate * expert_output
        
        # Tokens dropped by the capacity constraint keep a zero output (the residual path handles them)
        output = output.view(original_shape)
        
        aux_losses = {
            'load_balancing_loss': load_loss * self.config.load_balancing_loss_weight
        }
        
        return output, aux_losses

class MoETransformerLayer(nn.Module):
    """Transformer layer with MoE feedforward"""
    
    def __init__(self, hidden_dim: int, num_heads: int, moe_config: MoEConfig,
                 expert_hidden_dim: int = None):
        super().__init__()
        
        if expert_hidden_dim is None:
            expert_hidden_dim = hidden_dim * 4
        
        # Self-attention
        self.self_attn = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True
        )
        
        # Layer norms
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        
        # MoE feedforward
        self.moe = MixtureOfExperts(moe_config, hidden_dim, expert_hidden_dim)
        
        # Dropout
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through MoE transformer layer
        
        Args:
            x: Input tensor [batch_size, seq_len, hidden_dim]
            mask: Optional attention mask
            
        Returns:
            output: Layer output [batch_size, seq_len, hidden_dim]  
            aux_losses: MoE auxiliary losses
        """
        # Self-attention with residual connection
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        # MoE feedforward with residual connection
        moe_out, aux_losses = self.moe(x)
        x = self.norm2(x + self.dropout(moe_out))
        
        return x, aux_losses

# Usage example
def create_moe_model():
    """Example of creating a MoE transformer model"""
    
    # MoE configuration
    moe_config = MoEConfig(
        num_experts=32,
        top_k=2,
        capacity_factor=1.25,
        gate_noise=0.1,
        load_balancing_loss_weight=0.01
    )
    
    # Model dimensions
    hidden_dim = 768
    num_heads = 12
    num_layers = 12
    vocab_size = 50000
    
    # Create model layers
    embedding = nn.Embedding(vocab_size, hidden_dim)
    layers = nn.ModuleList([
        MoETransformerLayer(hidden_dim, num_heads, moe_config)
        for _ in range(num_layers)
    ])
    output_projection = nn.Linear(hidden_dim, vocab_size)
    
    print(f"Created MoE model with {moe_config.num_experts} experts per layer")
    print(f"Total experts: {num_layers * moe_config.num_experts}")
    print(f"Active experts per token: {moe_config.top_k}")
    
    return {
        'embedding': embedding,
        'layers': layers, 
        'output_projection': output_projection,
        'config': moe_config
    }

def moe_training_step(model_components, input_ids, target_ids, optimizer):
    """Example training step with MoE losses"""
    
    # Forward pass
    x = model_components['embedding'](input_ids)
    
    total_aux_loss = 0
    for layer in model_components['layers']:
        x, aux_losses = layer(x)
        total_aux_loss += sum(aux_losses.values())
    
    logits = model_components['output_projection'](x)
    
    # Main loss (e.g., cross-entropy)
    main_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1))
    
    # Total loss with auxiliary losses
    total_loss = main_loss + total_aux_loss
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    return {
        'main_loss': main_loss.item(),
        'aux_loss': total_aux_loss.item(), 
        'total_loss': total_loss.item()
    }

if __name__ == "__main__":
    # Create and demonstrate MoE model
    model_components = create_moe_model()
    
    # Example forward pass
    batch_size, seq_len = 4, 256
    input_ids = torch.randint(0, 50000, (batch_size, seq_len))
    
    print("\nRunning example forward pass...")
    
    x = model_components['embedding'](input_ids)
    print(f"Input shape: {x.shape}")
    
    # Process through first MoE layer
    layer = model_components['layers'][0]
    output, aux_losses = layer(x)
    
    print(f"Output shape: {output.shape}")
    print(f"Auxiliary losses: {aux_losses}")
    
    print("\nMoE model demonstration completed!")

Scaling Strategies Comparison

Dense Transformer

Traditional transformer with all parameters active.

Computational Efficiency: Low · Scalability: Poor
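
To make the comparison concrete, here is an illustrative per-token estimate for the feed-forward block only, under the common assumption that the FFN dominates compute; the dimensions below are hypothetical:

# Illustrative per-token FFN FLOPs: dense vs. MoE (hypothetical dimensions)
hidden_dim = 4096
ffn_mult = 4
num_experts = 64
top_k = 2

# A standard 2-matmul FFN costs roughly 2 * 2 * d * (ffn_mult * d) FLOPs per token
dense_ffn_flops = 2 * 2 * hidden_dim * (ffn_mult * hidden_dim)

# An MoE layer stores num_experts such FFNs, but each token only runs through
# top_k of them, plus a small router matmul
moe_ffn_flops = top_k * dense_ffn_flops + 2 * hidden_dim * num_experts

print(f"Dense FFN:              {dense_ffn_flops/1e6:.1f} MFLOPs/token")
print(f"MoE (top-{top_k} of {num_experts}):     {moe_ffn_flops/1e6:.1f} MFLOPs/token")
print(f"Parameter ratio (MoE/dense FFN): {num_experts}x, "
      f"compute ratio: {moe_ffn_flops/dense_ffn_flops:.2f}x")

The point of the comparison: the MoE layer holds many times more parameters than the dense FFN, yet its per-token compute grows only with top_k, not with the number of experts.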

Switch Transformer Architecture

Google's Switch Transformer Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Optional

class SwitchTransformer(nn.Module):
    """
    Google's Switch Transformer implementation
    Paper: "Switch Transformer: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity"
    """
    
    def __init__(self, vocab_size: int, hidden_dim: int, num_layers: int, 
                 num_heads: int, num_experts: int, capacity_factor: float = 1.25,
                 expert_capacity: Optional[int] = None):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.expert_capacity = expert_capacity
        
        # Token and position embeddings
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(2048, hidden_dim)  # Max sequence length
        
        # Transformer layers with Switch routing
        self.layers = nn.ModuleList([
            SwitchTransformerLayer(
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                num_experts=num_experts,
                capacity_factor=capacity_factor,
                expert_capacity=expert_capacity
            ) for _ in range(num_layers)
        ])
        
        # Output layer
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        
        # Tie input and output embeddings (common in large LMs)
        self.lm_head.weight = self.token_embedding.weight
        
    def forward(self, input_ids: torch.Tensor, 
                attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through Switch Transformer
        
        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            logits: Output logits [batch_size, seq_len, vocab_size]
            aux_losses: Auxiliary losses for routing
        """
        batch_size, seq_len = input_ids.shape
        
        # Embeddings
        token_embeds = self.token_embedding(input_ids)
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        position_embeds = self.position_embedding(position_ids)
        
        hidden_states = token_embeds + position_embeds
        
        # Accumulate auxiliary losses
        total_aux_loss = 0.0
        router_z_losses = []
        load_balancing_losses = []
        
        # Process through Switch layers
        for layer in self.layers:
            hidden_states, layer_aux_losses = layer(hidden_states, attention_mask)
            
            if 'load_balancing_loss' in layer_aux_losses:
                load_balancing_losses.append(layer_aux_losses['load_balancing_loss'])
            if 'router_z_loss' in layer_aux_losses:
                router_z_losses.append(layer_aux_losses['router_z_loss'])
        
        # Final layer norm
        hidden_states = self.layer_norm(hidden_states)
        
        # Language modeling head
        logits = self.lm_head(hidden_states)
        
        # Aggregate auxiliary losses
        aux_losses = {}
        if load_balancing_losses:
            aux_losses['load_balancing_loss'] = torch.stack(load_balancing_losses).mean()
        if router_z_losses:
            aux_losses['router_z_loss'] = torch.stack(router_z_losses).mean()
        
        return logits, aux_losses
    
    def get_num_params(self) -> Dict[str, int]:
        """Calculate parameter counts"""
        total_params = sum(p.numel() for p in self.parameters())
        
        # Calculate expert parameters
        expert_params = 0
        for layer in self.layers:
            if hasattr(layer, 'switch_ffn'):
                expert_params += sum(p.numel() for p in layer.switch_ffn.experts.parameters())
        
        shared_params = total_params - expert_params
        
        # Calculate active parameters (assuming top-1 routing)
        active_expert_params = expert_params // self.num_experts
        active_params = shared_params + active_expert_params
        
        return {
            'total_params': total_params,
            'shared_params': shared_params,
            'expert_params': expert_params,
            'active_params': active_params,
            'sparsity_ratio': 1 - (active_params / total_params)
        }

class SwitchTransformerLayer(nn.Module):
    """Individual Switch Transformer layer"""
    
    def __init__(self, hidden_dim: int, num_heads: int, num_experts: int,
                 capacity_factor: float = 1.25, expert_capacity: Optional[int] = None):
        super().__init__()
        
        # Self-attention
        self.self_attn = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=0.1, batch_first=True
        )
        self.attn_layer_norm = nn.LayerNorm(hidden_dim)
        
        # Switch Feed-Forward Network
        self.switch_ffn = SwitchFeedForward(
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            capacity_factor=capacity_factor,
            expert_capacity=expert_capacity
        )
        self.ffn_layer_norm = nn.LayerNorm(hidden_dim)
        
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Layer forward pass with Switch FFN"""
        
        # Self-attention with residual connection
        attn_output, _ = self.self_attn(
            hidden_states, hidden_states, hidden_states, 
            attn_mask=attention_mask
        )
        hidden_states = self.attn_layer_norm(hidden_states + self.dropout(attn_output))
        
        # Switch FFN with residual connection
        ffn_output, aux_losses = self.switch_ffn(hidden_states)
        hidden_states = self.ffn_layer_norm(hidden_states + self.dropout(ffn_output))
        
        return hidden_states, aux_losses

class SwitchFeedForward(nn.Module):
    """Switch Feed-Forward with expert routing"""
    
    def __init__(self, hidden_dim: int, num_experts: int, 
                 capacity_factor: float = 1.25, expert_capacity: Optional[int] = None,
                 expert_hidden_dim: Optional[int] = None):
        super().__init__()
        
        if expert_hidden_dim is None:
            expert_hidden_dim = hidden_dim * 4
            
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.expert_capacity = expert_capacity
        
        # Router (gating network)
        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
        
        # Expert networks
        self.experts = nn.ModuleList([
            SwitchExpert(hidden_dim, expert_hidden_dim) 
            for _ in range(num_experts)
        ])
        
        # For load balancing
        self.register_buffer('expert_counts', torch.zeros(num_experts))
        
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Switch routing forward pass
        
        Args:
            hidden_states: Input [batch_size, seq_len, hidden_dim]
            
        Returns:
            output: Routed expert outputs [batch_size, seq_len, hidden_dim]
            aux_losses: Routing auxiliary losses
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        
        # Flatten for routing
        hidden_flat = hidden_states.view(-1, hidden_dim)  # [num_tokens, hidden_dim]
        num_tokens = hidden_flat.shape[0]
        
        # Calculate expert capacity
        if self.expert_capacity is not None:
            capacity = self.expert_capacity
        else:
            capacity = int(self.capacity_factor * num_tokens / self.num_experts)
        
        # Router forward pass
        router_logits = self.router(hidden_flat)  # [num_tokens, num_experts]
        router_probs = F.softmax(router_logits, dim=-1)
        
        # Select top expert for each token
        expert_gate, expert_index = torch.max(router_probs, dim=-1)
        
        # Create routing tensors
        expert_mask = F.one_hot(expert_index, num_classes=self.num_experts)  # [num_tokens, num_experts]
        
        # Apply capacity constraints
        position_in_expert = torch.cumsum(expert_mask, dim=0) * expert_mask - expert_mask
        capacity_mask = position_in_expert < capacity
        expert_mask = expert_mask * capacity_mask.float()
        
        # Combine weights for output combination  
        combine_weights = expert_gate.unsqueeze(-1) * expert_mask  # [num_tokens, num_experts]
        
        # Route tokens to experts and process
        output = torch.zeros_like(hidden_flat)
        
        for expert_idx in range(self.num_experts):
            # Get tokens assigned to this expert
            expert_tokens_mask = expert_mask[:, expert_idx].bool()
            if expert_tokens_mask.sum() == 0:
                continue
                
            # Extract tokens for this expert
            expert_input = hidden_flat[expert_tokens_mask]  # [num_expert_tokens, hidden_dim]
            
            # Process through expert
            expert_output = self.experts[expert_idx](expert_input)  # [num_expert_tokens, hidden_dim]
            
            # Route back to output positions
            output[expert_tokens_mask] = expert_output
        
        # Scale each token's expert output by its combine weight (capacity-dropped tokens stay zero)
        output = output * combine_weights.sum(dim=-1, keepdim=True)
        
        # Reshape back to original dimensions
        output = output.view(batch_size, seq_len, hidden_dim)
        
        # Compute auxiliary losses
        aux_losses = self._compute_auxiliary_losses(router_probs, expert_mask, router_logits)
        
        return output, aux_losses
    
    def _compute_auxiliary_losses(self, router_probs: torch.Tensor, 
                                expert_mask: torch.Tensor, 
                                router_logits: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute routing auxiliary losses"""
        
        # Load balancing loss
        # Encourage equal expert utilization
        gates_mean = torch.mean(router_probs, dim=0)  # Average gate probability per expert
        expert_counts = torch.mean(expert_mask.float(), dim=0)  # Fraction of tokens per expert
        
        load_balancing_loss = self.num_experts * torch.sum(gates_mean * expert_counts)
        
        # Router z-loss (prevent logits from growing too large)
        router_z_loss = torch.mean(torch.logsumexp(router_logits, dim=-1) ** 2)
        
        return {
            'load_balancing_loss': load_balancing_loss,
            'router_z_loss': router_z_loss * 1e-3  # Small weight for z-loss
        }

class SwitchExpert(nn.Module):
    """Individual expert in Switch Transformer"""
    
    def __init__(self, hidden_dim: int, expert_hidden_dim: int):
        super().__init__()
        
        # GLU-style expert with gating
        self.w1 = nn.Linear(hidden_dim, expert_hidden_dim, bias=False)  # Gate
        self.w2 = nn.Linear(expert_hidden_dim, hidden_dim, bias=False)  # Down projection  
        self.w3 = nn.Linear(hidden_dim, expert_hidden_dim, bias=False)  # Up projection
        
        self._init_weights()
        
    def _init_weights(self):
        """Initialize expert weights"""
        # Use different initialization for expert specialization
        for linear in [self.w1, self.w2, self.w3]:
            nn.init.normal_(linear.weight, mean=0.0, std=0.02)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expert forward pass with GLU activation"""
        gate = F.silu(self.w1(x))  # SiLU activation for gate
        up = self.w3(x)           # Linear projection
        return self.w2(gate * up)  # GLU combination

# Usage and training example
def train_switch_transformer():
    """Example training setup for Switch Transformer"""
    
    # Model configuration
    vocab_size = 32000
    hidden_dim = 1024
    num_layers = 24
    num_heads = 16
    num_experts = 128
    capacity_factor = 1.25
    
    # Create model
    model = SwitchTransformer(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        num_experts=num_experts,
        capacity_factor=capacity_factor
    )
    
    # Print model statistics
    param_stats = model.get_num_params()
    print("Switch Transformer Model Statistics:")
    for key, value in param_stats.items():
        if 'params' in key:
            print(f"{key}: {value:,} ({value/1e9:.2f}B)")
        else:
            print(f"{key}: {value:.3f}")
    
    # Training setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
    
    # Example training step
    batch_size, seq_len = 8, 512
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    target_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    
    # Forward pass
    logits, aux_losses = model(input_ids)
    
    # Compute losses
    lm_loss = F.cross_entropy(logits.view(-1, vocab_size), target_ids.view(-1))
    
    # Add auxiliary losses with weights
    total_loss = lm_loss
    if 'load_balancing_loss' in aux_losses:
        total_loss += 0.01 * aux_losses['load_balancing_loss']  # Load balancing weight
    if 'router_z_loss' in aux_losses:
        total_loss += aux_losses['router_z_loss']  # Z-loss already weighted
    
    # Backward pass
    optimizer.zero_grad()
    total_loss.backward()
    
    # Gradient clipping (important for large models)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    
    print(f"\nTraining step completed:")
    print(f"Language modeling loss: {lm_loss.item():.4f}")
    if aux_losses:
        for loss_name, loss_value in aux_losses.items():
            print(f"{loss_name}: {loss_value.item():.6f}")
    print(f"Total loss: {total_loss.item():.4f}")
    
    return model, param_stats

if __name__ == "__main__":
    model, stats = train_switch_transformer()

Real-World MoE Applications


Google - Switch Transformer

Google's Switch Transformer scales to 1.6 trillion parameters while using the same computational cost as a 175B dense model, achieving 4x speedup in pre-training.

  • 1.6T parameters with 2048 experts
  • 4x faster pre-training than dense models
  • Top-1 routing with capacity constraints
  • Load balancing and z-loss regularization

Meta - Expert Choice

Meta's Expert Choice routing improves upon Switch Transformer by letting experts choose which tokens to process, achieving better load balancing and performance (a minimal sketch follows the list below).

  • Expert-chooses-token paradigm
  • Improved load balancing
  • Better scaling to more experts
  • Reduced communication overhead
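
To make the expert-chooses-token idea concrete, here is a minimal sketch (a simplified illustration, not the published algorithm): affinity scores are computed per token, but each expert then selects a fixed number of tokens, so the load is perfectly balanced by construction.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden_dim, num_experts = 32, 64, 4
capacity = num_tokens // num_experts           # tokens each expert accepts

x = torch.randn(num_tokens, hidden_dim)
router = torch.nn.Linear(hidden_dim, num_experts, bias=False)

# Token-to-expert affinity scores
scores = F.softmax(router(x), dim=-1)          # [num_tokens, num_experts]

# Each expert selects its `capacity` highest-affinity tokens -> balanced load by construction
topk_scores, topk_tokens = torch.topk(scores.t(), capacity, dim=-1)  # [num_experts, capacity]

print("tokens chosen by each expert:\n", topk_tokens)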

Google - GLaM

Google's GLaM (Generalist Language Model) uses MoE to reach GPT-3-level performance while consuming roughly one third of the energy GPT-3 required during training.

  • 1.2T parameters, ~97B (8%) activated per token
  • About 3x more energy efficient to train than GPT-3
  • Top-2 expert routing
  • 64 experts per MoE layer

Microsoft - Z-Code++

Microsoft's Z-Code++ applies MoE to code generation tasks, with experts specializing in different programming languages and paradigms.

  • Language-specific expert specialization
  • 15B parameters with 64 experts
  • Code completion and generation
  • Multi-language programming support

MoE Best Practices

✅ Do

Use load balancing losses to prevent expert collapse

Implement capacity constraints to manage memory

Monitor expert utilization and routing patterns (see the utilization sketch after these lists)

Use gradient clipping for training stability

Scale experts gradually during training

Optimize data parallel and model parallel strategies

Use router z-loss to prevent router collapse

❌ Don't

Ignore expert load imbalance during training

Use too many experts without proper capacity planning

Forget to account for communication overhead

Skip auxiliary loss terms in total loss calculation

Use MoE in every layer without analysis

Ignore expert specialization patterns during inference

Underestimate memory requirements for expert storage
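
One way to act on the "monitor expert utilization" advice above, as a minimal sketch that assumes you already have the expert indices returned by a router such as the TopKRouter defined earlier:

import torch

def expert_utilization(expert_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Fraction of routed token slots handled by each expert.

    expert_indices: tensor of selected expert ids, e.g. [batch, seq, top_k].
    """
    counts = torch.bincount(expert_indices.flatten(), minlength=num_experts).float()
    return counts / counts.sum().clamp_min(1.0)

# Toy example: 1,000 routing decisions spread over 8 experts
fake_indices = torch.randint(0, 8, (1000,))
util = expert_utilization(fake_indices, num_experts=8)
print("per-expert utilization:", [f"{u:.2%}" for u in util.tolist()])
print("max/min imbalance:", (util.max() / util.min().clamp_min(1e-9)).item())

Logging this distribution (and the max/min imbalance) during training makes expert collapse visible early, before the auxiliary losses have to fight it.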
