Mixture of Experts (MoE)
Master sparse model architectures with expert routing for efficient large-scale neural networks
What is a Mixture of Experts?
Mixture of Experts (MoE) is a sparse neural network architecture in which only a subset of the model's parameters, the experts, is activated for each input. A routing mechanism decides which experts process each token, so the model can scale to a massive total parameter count while the per-token computational cost stays roughly constant.
Sparse Activation
Only 10-20% of parameters active per forward pass
Scalable Training
Scale to trillions of parameters with constant compute
Expert Specialization
Different experts learn specialized sub-tasks
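To make the routing idea concrete, here is a minimal, self-contained sketch (toy dimensions and hypothetical module names, not tied to any production model) showing a gate picking the top-k experts per token so that only those experts do work:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup: 8 small experts, top-2 routing over a batch of token vectors
hidden_dim, num_experts, top_k = 16, 8, 2
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
gate = nn.Linear(hidden_dim, num_experts, bias=False)

tokens = torch.randn(32, hidden_dim)                            # 32 token vectors
scores = F.softmax(gate(tokens), dim=-1)                        # routing probabilities
topk_scores, topk_idx = scores.topk(top_k, dim=-1)              # pick 2 experts per token
topk_scores = topk_scores / topk_scores.sum(-1, keepdim=True)   # renormalize over selected experts

output = torch.zeros_like(tokens)
for e in range(num_experts):                                    # only selected experts run
    for k in range(top_k):
        mask = topk_idx[:, k] == e
        if mask.any():
            output[mask] += topk_scores[mask, k:k+1] * experts[e](tokens[mask])
print(output.shape)  # torch.Size([32, 16])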
MoE Efficiency Calculator
Analyze computational efficiency and resource usage of Mixture of Experts models
Efficiency Metrics
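As a rough, illustrative sketch (simplified formulas and hypothetical configuration values), the helper below estimates the kind of metrics such a calculator reports: total versus active expert parameters and the resulting active fraction for an MoE feed-forward stack.

def moe_efficiency(num_experts: int, top_k: int, hidden_dim: int,
                   expert_hidden_dim: int, num_layers: int) -> dict:
    """Rough estimate of MoE FFN parameter efficiency (ignores attention and embeddings)."""
    params_per_expert = 2 * hidden_dim * expert_hidden_dim       # up + down projection
    total_expert_params = num_layers * num_experts * params_per_expert
    active_expert_params = num_layers * top_k * params_per_expert
    return {
        "total_expert_params": total_expert_params,
        "active_expert_params": active_expert_params,
        "active_fraction": active_expert_params / total_expert_params,
    }

# Example: 12 layers, 8 experts, top-2 routing, 768-dim model with 4x expert FFN
print(moe_efficiency(num_experts=8, top_k=2, hidden_dim=768,
                     expert_hidden_dim=3072, num_layers=12))
# active_fraction = top_k / num_experts = 0.25 of the expert FFN parameters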
Expert Routing Strategies
Top-K Routing
Route tokens to the top-K highest scoring experts
Production MoE Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
from typing import Tuple, List, Optional, Dict
import math
from dataclasses import dataclass
@dataclass
class MoEConfig:
"""Configuration for Mixture of Experts"""
num_experts: int = 8
top_k: int = 2
capacity_factor: float = 1.25
expert_dropout: float = 0.0
gate_noise: float = 0.1
load_balancing_loss_weight: float = 0.01
router_z_loss_weight: float = 0.001
class Expert(nn.Module):
"""Individual expert network"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
super().__init__()
self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, output_dim, bias=False)
self.w3 = nn.Linear(input_dim, hidden_dim, bias=False) # GLU gate
self.dropout = nn.Dropout(dropout)
# Expert-specific initialization
self._init_weights()
def _init_weights(self):
"""Initialize expert weights with different variance for specialization"""
for module in [self.w1, self.w2, self.w3]:
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""GLU-based expert forward pass"""
gate = torch.sigmoid(self.w3(x))
hidden = F.silu(self.w1(x)) * gate
hidden = self.dropout(hidden)
return self.w2(hidden)
class TopKRouter(nn.Module):
"""Top-K routing with load balancing"""
def __init__(self, input_dim: int, num_experts: int, top_k: int,
gate_noise: float = 0.1, capacity_factor: float = 1.25):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gate_noise = gate_noise
self.capacity_factor = capacity_factor
# Gating network
self.gate = nn.Linear(input_dim, num_experts, bias=False)
self.gate_noise_scale = gate_noise
# For load balancing
self.register_buffer('expert_counts', torch.zeros(num_experts))
self.register_buffer('total_tokens', torch.tensor(0))
def _add_noise(self, logits: torch.Tensor) -> torch.Tensor:
"""Add noise during training for exploration"""
if self.training and self.gate_noise_scale > 0:
noise = torch.randn_like(logits) * self.gate_noise_scale
return logits + noise
return logits
def _compute_load_balancing_loss(self, gates: torch.Tensor,
expert_indices: torch.Tensor) -> torch.Tensor:
"""Compute load balancing loss to encourage even expert usage"""
# Count tokens per expert
expert_counts = torch.bincount(expert_indices.flatten(), minlength=self.num_experts)
# Compute mean gate values per expert
gates_mean = torch.mean(gates, dim=0)
# Load balancing loss (encourage uniform distribution)
load_loss = torch.sum(gates_mean * expert_counts.float()) * self.num_experts
load_loss = load_loss / (torch.sum(expert_counts) + 1e-6)
return load_loss
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route tokens to experts
Args:
x: Input tensor [batch_size, seq_len, hidden_dim]
Returns:
expert_weights: Gating weights [batch_size, seq_len, top_k]
expert_indices: Selected expert indices [batch_size, seq_len, top_k]
load_balancing_loss: Loss to encourage uniform expert usage
router_z_loss: Loss to prevent router collapse
"""
batch_size, seq_len, hidden_dim = x.shape
# Reshape for routing
x_flat = x.view(-1, hidden_dim) # [batch_size * seq_len, hidden_dim]
# Compute gate logits
gate_logits = self.gate(x_flat) # [batch_size * seq_len, num_experts]
gate_logits = self._add_noise(gate_logits)
# Apply softmax to get probabilities
gate_probs = F.softmax(gate_logits, dim=-1)
# Select top-k experts
top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
# Normalize top-k probabilities
top_k_probs = top_k_probs / (torch.sum(top_k_probs, dim=-1, keepdim=True) + 1e-6)
# Compute losses
load_balancing_loss = self._compute_load_balancing_loss(gate_probs, top_k_indices)
        # Router z-loss (keeps gate logits small to prevent router collapse)
        router_z_loss = torch.mean(torch.logsumexp(gate_logits, dim=-1) ** 2)
# Reshape back to original dimensions
expert_weights = top_k_probs.view(batch_size, seq_len, self.top_k)
expert_indices = top_k_indices.view(batch_size, seq_len, self.top_k)
return expert_weights, expert_indices, load_balancing_loss, router_z_loss
class SwitchRouter(nn.Module):
"""Switch Transformer routing with capacity constraints"""
def __init__(self, input_dim: int, num_experts: int, capacity_factor: float = 1.25):
super().__init__()
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.gate = nn.Linear(input_dim, num_experts, bias=False)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Switch routing with capacity constraints
Args:
x: Input tensor [batch_size, seq_len, hidden_dim]
Returns:
dispatch_tensor: Binary tensor indicating token-expert assignment
combine_tensor: Tensor for combining expert outputs
load_balancing_loss: Loss for load balancing
"""
batch_size, seq_len, hidden_dim = x.shape
num_tokens = batch_size * seq_len
# Compute capacity per expert
capacity = int(self.capacity_factor * num_tokens / self.num_experts)
# Flatten input
x_flat = x.view(-1, hidden_dim)
# Compute gate probabilities
gate_logits = self.gate(x_flat)
gate_probs = F.softmax(gate_logits, dim=-1)
# Select top expert for each token
expert_probs, expert_indices = torch.max(gate_probs, dim=-1)
# Create dispatch tensor with capacity constraints
dispatch_tensor = torch.zeros(num_tokens, self.num_experts, device=x.device)
combine_tensor = torch.zeros(num_tokens, self.num_experts, device=x.device)
# Track capacity usage per expert
expert_counts = torch.zeros(self.num_experts, device=x.device)
for token_idx in range(num_tokens):
expert_idx = expert_indices[token_idx].item()
if expert_counts[expert_idx] < capacity:
dispatch_tensor[token_idx, expert_idx] = 1.0
combine_tensor[token_idx, expert_idx] = expert_probs[token_idx]
expert_counts[expert_idx] += 1
# Compute load balancing loss
gate_mean = torch.mean(gate_probs, dim=0)
expert_fraction = expert_counts / num_tokens
load_balancing_loss = self.num_experts * torch.sum(gate_mean * expert_fraction)
return dispatch_tensor, combine_tensor, load_balancing_loss
class MixtureOfExperts(nn.Module):
"""Complete Mixture of Experts layer"""
def __init__(self, config: MoEConfig, input_dim: int, expert_hidden_dim: int):
super().__init__()
self.config = config
self.input_dim = input_dim
self.expert_hidden_dim = expert_hidden_dim
# Create experts
self.experts = nn.ModuleList([
Expert(input_dim, expert_hidden_dim, input_dim, config.expert_dropout)
for _ in range(config.num_experts)
])
# Create router
if config.top_k > 1:
self.router = TopKRouter(
input_dim, config.num_experts, config.top_k,
config.gate_noise, config.capacity_factor
)
self.routing_strategy = 'top_k'
else:
self.router = SwitchRouter(
input_dim, config.num_experts, config.capacity_factor
)
self.routing_strategy = 'switch'
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass through MoE layer
Args:
x: Input tensor [batch_size, seq_len, hidden_dim]
Returns:
output: Mixed expert outputs [batch_size, seq_len, hidden_dim]
aux_losses: Dictionary of auxiliary losses for training
"""
batch_size, seq_len, hidden_dim = x.shape
original_shape = x.shape
x_flat = x.view(-1, hidden_dim)
if self.routing_strategy == 'top_k':
return self._forward_top_k(x, x_flat, original_shape)
else:
return self._forward_switch(x, x_flat, original_shape)
def _forward_top_k(self, x: torch.Tensor, x_flat: torch.Tensor,
original_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Top-K routing forward pass"""
expert_weights, expert_indices, load_loss, router_z_loss = self.router(x)
# Flatten for processing
expert_weights_flat = expert_weights.view(-1, self.config.top_k)
expert_indices_flat = expert_indices.view(-1, self.config.top_k)
# Process tokens through selected experts
output = torch.zeros_like(x_flat)
for k in range(self.config.top_k):
# Get tokens and their assigned experts for this k
token_expert_indices = expert_indices_flat[:, k]
token_weights = expert_weights_flat[:, k]
# Process each expert
for expert_idx in range(self.config.num_experts):
# Find tokens assigned to this expert
expert_mask = (token_expert_indices == expert_idx)
if not expert_mask.any():
continue
# Extract tokens for this expert
expert_tokens = x_flat[expert_mask]
# Process through expert
expert_output = self.experts[expert_idx](expert_tokens)
# Weighted contribution back to output
weights = token_weights[expert_mask].unsqueeze(-1)
output[expert_mask] += weights * expert_output
# Reshape back to original dimensions
output = output.view(original_shape)
# Auxiliary losses
aux_losses = {
'load_balancing_loss': load_loss * self.config.load_balancing_loss_weight,
'router_z_loss': router_z_loss * self.config.router_z_loss_weight
}
return output, aux_losses
    def _forward_switch(self, x: torch.Tensor, x_flat: torch.Tensor,
                        original_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Switch (top-1) routing forward pass"""
        dispatch_tensor, combine_tensor, load_loss = self.router(x)
        # Process each expert's assigned tokens individually; the router has
        # already enforced capacity by zeroing dropped tokens in dispatch_tensor
        output = torch.zeros_like(x_flat)
        for expert_idx in range(self.config.num_experts):
            token_mask = dispatch_tensor[:, expert_idx].bool()
            if not token_mask.any():
                continue
            expert_output = self.experts[expert_idx](x_flat[token_mask])
            # Scale each token's expert output by its gate probability
            weights = combine_tensor[token_mask, expert_idx].unsqueeze(-1)
            output[token_mask] += weights * expert_output
        # Reshape back to original dimensions
        output = output.view(original_shape)
        aux_losses = {
            'load_balancing_loss': load_loss * self.config.load_balancing_loss_weight
        }
        return output, aux_losses
class MoETransformerLayer(nn.Module):
"""Transformer layer with MoE feedforward"""
def __init__(self, hidden_dim: int, num_heads: int, moe_config: MoEConfig,
expert_hidden_dim: int = None):
super().__init__()
if expert_hidden_dim is None:
expert_hidden_dim = hidden_dim * 4
# Self-attention
self.self_attn = nn.MultiheadAttention(
hidden_dim, num_heads, batch_first=True
)
# Layer norms
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
# MoE feedforward
self.moe = MixtureOfExperts(moe_config, hidden_dim, expert_hidden_dim)
# Dropout
self.dropout = nn.Dropout(0.1)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass through MoE transformer layer
Args:
x: Input tensor [batch_size, seq_len, hidden_dim]
mask: Optional attention mask
Returns:
output: Layer output [batch_size, seq_len, hidden_dim]
aux_losses: MoE auxiliary losses
"""
# Self-attention with residual connection
attn_out, _ = self.self_attn(x, x, x, attn_mask=mask)
x = self.norm1(x + self.dropout(attn_out))
# MoE feedforward with residual connection
moe_out, aux_losses = self.moe(x)
x = self.norm2(x + self.dropout(moe_out))
return x, aux_losses
# Usage example
def create_moe_model():
"""Example of creating a MoE transformer model"""
# MoE configuration
moe_config = MoEConfig(
num_experts=32,
top_k=2,
capacity_factor=1.25,
gate_noise=0.1,
load_balancing_loss_weight=0.01
)
# Model dimensions
hidden_dim = 768
num_heads = 12
num_layers = 12
vocab_size = 50000
# Create model layers
embedding = nn.Embedding(vocab_size, hidden_dim)
layers = nn.ModuleList([
MoETransformerLayer(hidden_dim, num_heads, moe_config)
for _ in range(num_layers)
])
output_projection = nn.Linear(hidden_dim, vocab_size)
print(f"Created MoE model with {moe_config.num_experts} experts per layer")
print(f"Total experts: {num_layers * moe_config.num_experts}")
print(f"Active experts per token: {moe_config.top_k}")
return {
'embedding': embedding,
'layers': layers,
'output_projection': output_projection,
'config': moe_config
}
def moe_training_step(model_components, input_ids, target_ids, optimizer):
"""Example training step with MoE losses"""
# Forward pass
x = model_components['embedding'](input_ids)
total_aux_loss = 0
for layer in model_components['layers']:
x, aux_losses = layer(x)
total_aux_loss += sum(aux_losses.values())
logits = model_components['output_projection'](x)
# Main loss (e.g., cross-entropy)
main_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1))
# Total loss with auxiliary losses
total_loss = main_loss + total_aux_loss
# Backward pass
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
return {
'main_loss': main_loss.item(),
'aux_loss': total_aux_loss.item(),
'total_loss': total_loss.item()
}
if __name__ == "__main__":
# Create and demonstrate MoE model
model_components = create_moe_model()
# Example forward pass
batch_size, seq_len = 4, 256
input_ids = torch.randint(0, 50000, (batch_size, seq_len))
print("\nRunning example forward pass...")
x = model_components['embedding'](input_ids)
print(f"Input shape: {x.shape}")
# Process through first MoE layer
layer = model_components['layers'][0]
output, aux_losses = layer(x)
print(f"Output shape: {output.shape}")
print(f"Auxiliary losses: {aux_losses}")
print("\nMoE model demonstration completed!")Scaling Strategies Comparison
Dense Transformer
Traditional transformer with all parameters active
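To ground the comparison, here is a small, illustrative sketch (hypothetical dimensions, router cost ignored) contrasting per-token FLOPs of a dense FFN with a top-2 MoE FFN that stores many more parameters at similar active compute:

def ffn_flops_per_token(hidden_dim: int, ffn_dim: int) -> int:
    """Approximate multiply-add FLOPs for one FFN (up + down projection) on one token."""
    return 2 * (hidden_dim * ffn_dim + ffn_dim * hidden_dim)

hidden_dim, ffn_dim = 1024, 4096
num_experts, top_k = 64, 2

dense_flops = ffn_flops_per_token(hidden_dim, ffn_dim)
moe_flops = top_k * ffn_flops_per_token(hidden_dim, ffn_dim)    # only top_k experts run
dense_params = 2 * hidden_dim * ffn_dim
moe_params = num_experts * 2 * hidden_dim * ffn_dim             # all experts must be stored

print(f"Dense FFN: {dense_params/1e6:.0f}M params, {dense_flops/1e6:.0f}M FLOPs/token")
print(f"MoE FFN:   {moe_params/1e6:.0f}M params, {moe_flops/1e6:.0f}M FLOPs/token")
# 64x the parameters for roughly 2x the per-token compute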
Switch Transformer Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Optional
class SwitchTransformer(nn.Module):
"""
Google's Switch Transformer implementation
Paper: "Switch Transformer: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity"
"""
def __init__(self, vocab_size: int, hidden_dim: int, num_layers: int,
num_heads: int, num_experts: int, capacity_factor: float = 1.25,
expert_capacity: Optional[int] = None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.expert_capacity = expert_capacity
# Token and position embeddings
self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
self.position_embedding = nn.Embedding(2048, hidden_dim) # Max sequence length
# Transformer layers with Switch routing
self.layers = nn.ModuleList([
SwitchTransformerLayer(
hidden_dim=hidden_dim,
num_heads=num_heads,
num_experts=num_experts,
capacity_factor=capacity_factor,
expert_capacity=expert_capacity
) for _ in range(num_layers)
])
# Output layer
self.layer_norm = nn.LayerNorm(hidden_dim)
self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
# Tie input and output embeddings (common in large LMs)
self.lm_head.weight = self.token_embedding.weight
def forward(self, input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass through Switch Transformer
Args:
input_ids: Token IDs [batch_size, seq_len]
attention_mask: Attention mask [batch_size, seq_len]
Returns:
logits: Output logits [batch_size, seq_len, vocab_size]
aux_losses: Auxiliary losses for routing
"""
batch_size, seq_len = input_ids.shape
# Embeddings
token_embeds = self.token_embedding(input_ids)
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
position_embeds = self.position_embedding(position_ids)
hidden_states = token_embeds + position_embeds
# Accumulate auxiliary losses
total_aux_loss = 0.0
router_z_losses = []
load_balancing_losses = []
# Process through Switch layers
for layer in self.layers:
hidden_states, layer_aux_losses = layer(hidden_states, attention_mask)
if 'load_balancing_loss' in layer_aux_losses:
load_balancing_losses.append(layer_aux_losses['load_balancing_loss'])
if 'router_z_loss' in layer_aux_losses:
router_z_losses.append(layer_aux_losses['router_z_loss'])
# Final layer norm
hidden_states = self.layer_norm(hidden_states)
# Language modeling head
logits = self.lm_head(hidden_states)
# Aggregate auxiliary losses
aux_losses = {}
if load_balancing_losses:
aux_losses['load_balancing_loss'] = torch.stack(load_balancing_losses).mean()
if router_z_losses:
aux_losses['router_z_loss'] = torch.stack(router_z_losses).mean()
return logits, aux_losses
def get_num_params(self) -> Dict[str, int]:
"""Calculate parameter counts"""
total_params = sum(p.numel() for p in self.parameters())
# Calculate expert parameters
expert_params = 0
for layer in self.layers:
if hasattr(layer, 'switch_ffn'):
expert_params += sum(p.numel() for p in layer.switch_ffn.experts.parameters())
shared_params = total_params - expert_params
# Calculate active parameters (assuming top-1 routing)
active_expert_params = expert_params // self.num_experts
active_params = shared_params + active_expert_params
return {
'total_params': total_params,
'shared_params': shared_params,
'expert_params': expert_params,
'active_params': active_params,
'sparsity_ratio': 1 - (active_params / total_params)
}
class SwitchTransformerLayer(nn.Module):
"""Individual Switch Transformer layer"""
def __init__(self, hidden_dim: int, num_heads: int, num_experts: int,
capacity_factor: float = 1.25, expert_capacity: Optional[int] = None):
super().__init__()
# Self-attention
self.self_attn = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=0.1, batch_first=True
)
self.attn_layer_norm = nn.LayerNorm(hidden_dim)
# Switch Feed-Forward Network
self.switch_ffn = SwitchFeedForward(
hidden_dim=hidden_dim,
num_experts=num_experts,
capacity_factor=capacity_factor,
expert_capacity=expert_capacity
)
self.ffn_layer_norm = nn.LayerNorm(hidden_dim)
self.dropout = nn.Dropout(0.1)
def forward(self, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Layer forward pass with Switch FFN"""
# Self-attention with residual connection
attn_output, _ = self.self_attn(
hidden_states, hidden_states, hidden_states,
attn_mask=attention_mask
)
hidden_states = self.attn_layer_norm(hidden_states + self.dropout(attn_output))
# Switch FFN with residual connection
ffn_output, aux_losses = self.switch_ffn(hidden_states)
hidden_states = self.ffn_layer_norm(hidden_states + self.dropout(ffn_output))
return hidden_states, aux_losses
class SwitchFeedForward(nn.Module):
"""Switch Feed-Forward with expert routing"""
def __init__(self, hidden_dim: int, num_experts: int,
capacity_factor: float = 1.25, expert_capacity: Optional[int] = None,
expert_hidden_dim: Optional[int] = None):
super().__init__()
if expert_hidden_dim is None:
expert_hidden_dim = hidden_dim * 4
self.hidden_dim = hidden_dim
self.num_experts = num_experts
self.capacity_factor = capacity_factor
self.expert_capacity = expert_capacity
# Router (gating network)
self.router = nn.Linear(hidden_dim, num_experts, bias=False)
# Expert networks
self.experts = nn.ModuleList([
SwitchExpert(hidden_dim, expert_hidden_dim)
for _ in range(num_experts)
])
# For load balancing
self.register_buffer('expert_counts', torch.zeros(num_experts))
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Switch routing forward pass
Args:
hidden_states: Input [batch_size, seq_len, hidden_dim]
Returns:
output: Routed expert outputs [batch_size, seq_len, hidden_dim]
aux_losses: Routing auxiliary losses
"""
batch_size, seq_len, hidden_dim = hidden_states.shape
# Flatten for routing
hidden_flat = hidden_states.view(-1, hidden_dim) # [num_tokens, hidden_dim]
num_tokens = hidden_flat.shape[0]
# Calculate expert capacity
if self.expert_capacity is not None:
capacity = self.expert_capacity
else:
capacity = int(self.capacity_factor * num_tokens / self.num_experts)
# Router forward pass
router_logits = self.router(hidden_flat) # [num_tokens, num_experts]
router_probs = F.softmax(router_logits, dim=-1)
# Select top expert for each token
expert_gate, expert_index = torch.max(router_probs, dim=-1)
# Create routing tensors
expert_mask = F.one_hot(expert_index, num_classes=self.num_experts) # [num_tokens, num_experts]
# Apply capacity constraints
position_in_expert = torch.cumsum(expert_mask, dim=0) * expert_mask - expert_mask
capacity_mask = position_in_expert < capacity
expert_mask = expert_mask * capacity_mask.float()
        # Tokens dropped by the capacity constraint keep a zero output here;
        # the residual connection in the surrounding layer carries them through unchanged
# Route tokens to experts and process
output = torch.zeros_like(hidden_flat)
for expert_idx in range(self.num_experts):
# Get tokens assigned to this expert
expert_tokens_mask = expert_mask[:, expert_idx].bool()
if expert_tokens_mask.sum() == 0:
continue
# Extract tokens for this expert
expert_input = hidden_flat[expert_tokens_mask] # [num_expert_tokens, hidden_dim]
# Process through expert
expert_output = self.experts[expert_idx](expert_input) # [num_expert_tokens, hidden_dim]
# Route back to output positions
output[expert_tokens_mask] = expert_output
# Apply combine weights
output = output * expert_gate.unsqueeze(-1)
# Reshape back to original dimensions
output = output.view(batch_size, seq_len, hidden_dim)
# Compute auxiliary losses
aux_losses = self._compute_auxiliary_losses(router_probs, expert_mask, router_logits)
return output, aux_losses
def _compute_auxiliary_losses(self, router_probs: torch.Tensor,
expert_mask: torch.Tensor,
router_logits: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Compute routing auxiliary losses"""
# Load balancing loss
# Encourage equal expert utilization
gates_mean = torch.mean(router_probs, dim=0) # Average gate probability per expert
expert_counts = torch.mean(expert_mask.float(), dim=0) # Fraction of tokens per expert
load_balancing_loss = self.num_experts * torch.sum(gates_mean * expert_counts)
# Router z-loss (prevent logits from growing too large)
router_z_loss = torch.mean(torch.logsumexp(router_logits, dim=-1) ** 2)
return {
'load_balancing_loss': load_balancing_loss,
'router_z_loss': router_z_loss * 1e-3 # Small weight for z-loss
}
class SwitchExpert(nn.Module):
"""Individual expert in Switch Transformer"""
def __init__(self, hidden_dim: int, expert_hidden_dim: int):
super().__init__()
# GLU-style expert with gating
self.w1 = nn.Linear(hidden_dim, expert_hidden_dim, bias=False) # Gate
self.w2 = nn.Linear(expert_hidden_dim, hidden_dim, bias=False) # Down projection
self.w3 = nn.Linear(hidden_dim, expert_hidden_dim, bias=False) # Up projection
self._init_weights()
def _init_weights(self):
"""Initialize expert weights"""
        # Small-std normal initialization for all expert projections
for linear in [self.w1, self.w2, self.w3]:
nn.init.normal_(linear.weight, mean=0.0, std=0.02)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Expert forward pass with GLU activation"""
gate = F.silu(self.w1(x)) # SiLU activation for gate
up = self.w3(x) # Linear projection
return self.w2(gate * up) # GLU combination
# Usage and training example
def train_switch_transformer():
"""Example training setup for Switch Transformer"""
# Model configuration
vocab_size = 32000
hidden_dim = 1024
num_layers = 24
num_heads = 16
num_experts = 128
capacity_factor = 1.25
# Create model
model = SwitchTransformer(
vocab_size=vocab_size,
hidden_dim=hidden_dim,
num_layers=num_layers,
num_heads=num_heads,
num_experts=num_experts,
capacity_factor=capacity_factor
)
# Print model statistics
param_stats = model.get_num_params()
print("Switch Transformer Model Statistics:")
for key, value in param_stats.items():
if 'params' in key:
print(f"{key}: {value:,} ({value/1e9:.2f}B)")
else:
print(f"{key}: {value:.3f}")
# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
# Example training step
batch_size, seq_len = 8, 512
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
target_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# Forward pass
logits, aux_losses = model(input_ids)
# Compute losses
lm_loss = F.cross_entropy(logits.view(-1, vocab_size), target_ids.view(-1))
# Add auxiliary losses with weights
total_loss = lm_loss
if 'load_balancing_loss' in aux_losses:
total_loss += 0.01 * aux_losses['load_balancing_loss'] # Load balancing weight
if 'router_z_loss' in aux_losses:
total_loss += aux_losses['router_z_loss'] # Z-loss already weighted
# Backward pass
optimizer.zero_grad()
total_loss.backward()
# Gradient clipping (important for large models)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(f"\nTraining step completed:")
print(f"Language modeling loss: {lm_loss.item():.4f}")
if aux_losses:
for loss_name, loss_value in aux_losses.items():
print(f"{loss_name}: {loss_value.item():.6f}")
print(f"Total loss: {total_loss.item():.4f}")
return model, param_stats
if __name__ == "__main__":
    model, stats = train_switch_transformer()
Real-World MoE Applications
Google - Switch Transformer
Google's Switch Transformer scales to 1.6 trillion parameters while keeping per-token compute close to that of a much smaller dense model, reporting up to a 4x pre-training speedup over the dense T5-XXL baseline.
- • 1.6T parameters with 2048 experts
- • 4x faster pre-training than dense models
- • Top-1 routing with capacity constraints
- • Load balancing and z-loss regularization
Google - Expert Choice Routing
Expert Choice routing (from Google Brain) improves on Switch-style routing by letting each expert choose which tokens to process, which balances load by construction and improves training efficiency.
- • Expert-chooses-token paradigm
- • Improved load balancing
- • Better scaling to more experts
- • Reduced communication overhead
Google - GLaM
Google's GLaM (Generalist Language Model) uses MoE to reach GPT-3-level quality while using roughly one third of the energy for training.
- • 1.2T parameters, roughly 8% (~97B) activated per token
- • 3x more energy efficient than GPT-3
- • Top-2 expert routing
- • 64 experts per MoE layer
Microsoft - Z-code
Microsoft's Z-code models apply MoE to multilingual translation and language understanding, with sparse experts sharing capacity across many languages in production systems such as Microsoft Translator.
- • Sparse experts shared across many languages
- • Deployed in Microsoft Translator production systems
- • Quality gains, especially for low-resource languages
- • One multilingual MoE model replaces many bilingual models
MoE Best Practices
✅ Do
Use load balancing losses to prevent expert collapse
Implement capacity constraints to manage memory
Monitor expert utilization and routing patterns (see the monitoring sketch after these lists)
Use gradient clipping for training stability
Scale experts gradually during training
Optimize data parallel and model parallel strategies
Use router z-loss to prevent router collapse
❌ Don't
Ignore expert load imbalance during training
Use too many experts without proper capacity planning
Forget to account for communication overhead
Skip auxiliary loss terms in total loss calculation
Use MoE in every layer without analysis
Ignore expert specialization patterns during inference
Underestimate memory requirements for expert storage
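As a sketch of the monitoring point above (a hypothetical helper, not part of the implementations shown earlier), expert utilization can be tracked by counting routed tokens per expert and checking the entropy of that distribution; entropy near log(num_experts) indicates balanced routing, while values near zero indicate expert collapse.

import torch

def expert_utilization_stats(expert_indices: torch.Tensor, num_experts: int) -> dict:
    """Summarize routing decisions; expert_indices holds the chosen expert id per token slot."""
    counts = torch.bincount(expert_indices.flatten(), minlength=num_experts).float()
    fractions = counts / counts.sum().clamp_min(1.0)
    # Entropy of the token-to-expert distribution
    entropy = -(fractions.clamp_min(1e-9) * fractions.clamp_min(1e-9).log()).sum()
    return {
        "tokens_per_expert": counts.tolist(),
        "max_fraction": fractions.max().item(),
        "entropy": entropy.item(),
        "max_entropy": float(torch.log(torch.tensor(float(num_experts)))),
    }

# Example: log these stats every few hundred steps from the router's expert_indices output
stats = expert_utilization_stats(torch.randint(0, 8, (4, 256, 2)), num_experts=8)
print(stats["max_fraction"], stats["entropy"], "/", stats["max_entropy"])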