Knowledge Distillation & Model Compression
Master teacher-student training, pruning, quantization, and deployment optimization techniques
50 min read • Advanced
What is Knowledge Distillation?
Knowledge Distillation is a model compression technique where a smaller "student" model learns to mimic the behavior of a larger "teacher" model, achieving similar performance with significantly fewer parameters.
Model Compression
90%+ size reduction possible
Performance Retention
95%+ accuracy preservation
Deployment Ready
Mobile & edge optimization
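In its standard form (Hinton et al., 2015), the student minimizes a weighted sum of a temperature-scaled soft-target term and the ordinary hard-label cross-entropy; the DistillationLoss class later in this section implements exactly this objective:

\[ \mathcal{L} = \alpha \, T^{2} \, \mathrm{KL}\!\big(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\big) + (1 - \alpha)\, \mathrm{CE}(z_s, y) \]

Here z_t and z_s are the teacher and student logits, T is the softmax temperature, α weights the distillation term, and the T² factor keeps the soft-target gradients comparable in magnitude across temperatures.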
🧮 Knowledge Distillation Calculator
Example compression, performance-retention, and cost metrics for a distillation run in which the student keeps about 4% of the teacher's parameters (see the metrics sketch below):
Distillation Analysis
Compression Ratio: 4.00%
Teacher Accuracy: 92.0%
Student Accuracy: 89.1%
Accuracy Retention: 96.8%
Memory Reduction: 96.0%
Speed Improvement: 5.0x
Cost Savings: 80.0%
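A minimal sketch of how figures like these can be derived from parameter counts, measured accuracies, and measured latencies. The function name and the inputs below are illustrative assumptions, and the memory estimate simply assumes memory scales with parameter count.
Compression Metrics Sketch
def compression_metrics(teacher_params: int, student_params: int,
                        teacher_acc: float, student_acc: float,
                        teacher_latency_ms: float, student_latency_ms: float) -> dict:
    """Illustrative distillation metrics; assumes memory scales roughly with parameter count."""
    ratio = student_params / teacher_params
    return {
        "compression_ratio": f"{ratio * 100:.2f}%",
        "accuracy_retention": f"{student_acc / teacher_acc * 100:.1f}%",
        "memory_reduction": f"{(1 - ratio) * 100:.1f}%",
        "speed_improvement": f"{teacher_latency_ms / student_latency_ms:.1f}x",
    }

# Hypothetical numbers similar to the example analysis above
print(compression_metrics(340_000_000, 13_600_000, 0.92, 0.891, 50.0, 10.0))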
Knowledge Distillation Methods
Response Distillation
- • Learn from teacher's output probabilities
- • Soft targets with temperature scaling (see the temperature sketch after these method summaries)
- • Most common and straightforward approach
- • Good for classification tasks
- • Minimal architectural constraints
Feature Distillation
- • Match intermediate layer representations
- • Transfer rich feature knowledge
- • Better for complex tasks
- • Requires architectural alignment
- • Higher computational overhead
Attention Distillation
- • Transfer attention patterns
- • Preserve important relationships
- • Effective for transformer models
- • Better interpretability
- • Specialized for attention mechanisms
Progressive Distillation
- • Gradual knowledge transfer
- • Multiple intermediate models
- • Best compression ratios
- • Higher training complexity
- • Superior final performance
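To make "soft targets with temperature scaling" concrete, here is a minimal sketch with made-up logits showing how raising the temperature flattens the teacher's output distribution; this is also why the best-practices section later on this page warns against very high temperatures.
Temperature Scaling Demo
import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([4.0, 1.5, 0.5])  # hypothetical 3-class teacher logits
for T in (1.0, 4.0, 10.0):
    soft_targets = F.softmax(teacher_logits / T, dim=0)
    print(f"T={T}: {[round(p, 3) for p in soft_targets.tolist()]}")
# Higher T flattens the distribution, exposing the teacher's relative class
# preferences ("dark knowledge"); very high T approaches uniform and carries
# little signal, which is why T is usually kept in the 3-5 range.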
Knowledge Distillation Implementation
Basic Distillation Framework
Knowledge Distillation Core
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
import numpy as np
class DistillationLoss(nn.Module):
"""Knowledge distillation loss combining soft and hard targets"""
def __init__(self,
temperature: float = 4.0,
alpha: float = 0.7,
reduction: str = 'mean'):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.reduction = reduction
self.kl_div = nn.KLDivLoss(reduction='batchmean')  # 'batchmean' gives the correct per-sample KL scale for distillation
self.ce_loss = nn.CrossEntropyLoss(reduction=reduction)
def forward(self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
target_labels: torch.Tensor) -> Dict[str, torch.Tensor]:
# Soft targets from teacher (knowledge distillation loss)
student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
distillation_loss = self.kl_div(student_soft, teacher_soft) * (self.temperature ** 2)
# Hard targets (standard cross-entropy loss)
student_loss = self.ce_loss(student_logits, target_labels)
# Combined loss
total_loss = (
self.alpha * distillation_loss +
(1 - self.alpha) * student_loss
)
return {
'total_loss': total_loss,
'distillation_loss': distillation_loss,
'student_loss': student_loss
}
class FeatureDistillationLoss(nn.Module):
"""Feature-based knowledge distillation"""
def __init__(self, feature_dim: int, student_dim: Optional[int] = None):
super().__init__()
self.feature_dim = feature_dim
# Projection layer if dimensions don't match
if student_dim and student_dim != feature_dim:
self.projection = nn.Linear(student_dim, feature_dim)
else:
self.projection = nn.Identity()
self.mse_loss = nn.MSELoss()
def forward(self,
student_features: torch.Tensor,
teacher_features: torch.Tensor) -> torch.Tensor:
# Project student features if needed
projected_student = self.projection(student_features)
# Normalize features
student_norm = F.normalize(projected_student, p=2, dim=-1)
teacher_norm = F.normalize(teacher_features, p=2, dim=-1)
# Feature matching loss
return self.mse_loss(student_norm, teacher_norm)
class AttentionDistillationLoss(nn.Module):
"""Attention-based knowledge distillation for transformers"""
def __init__(self, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.mse_loss = nn.MSELoss()
def forward(self,
student_attentions: torch.Tensor,
teacher_attentions: torch.Tensor) -> torch.Tensor:
# Average over attention heads
student_attn_avg = student_attentions.mean(dim=1) # [batch, seq, seq]
teacher_attn_avg = teacher_attentions.mean(dim=1) # [batch, seq, seq]
# Attention pattern matching
return self.mse_loss(student_attn_avg, teacher_attn_avg)
class KnowledgeDistillationTrainer:
"""Complete knowledge distillation training framework"""
def __init__(self,
teacher_model: nn.Module,
student_model: nn.Module,
distillation_config: Dict):
self.teacher_model = teacher_model
self.student_model = student_model
self.config = distillation_config
# Loss functions
self.response_loss = DistillationLoss(
temperature=distillation_config.get('temperature', 4.0),
alpha=distillation_config.get('alpha', 0.7)
)
self.feature_loss = FeatureDistillationLoss(
feature_dim=distillation_config.get('feature_dim', 768),
student_dim=distillation_config.get('student_dim', None)
)
self.attention_loss = AttentionDistillationLoss(
num_heads=distillation_config.get('num_heads', 12)
)
# Freeze teacher model
self.teacher_model.eval()
for param in self.teacher_model.parameters():
param.requires_grad = False
def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Single training step with knowledge distillation"""
inputs = batch['inputs']
labels = batch['labels']
# Teacher forward pass (no gradients); request hidden states and attentions
# so feature/attention distillation has tensors to match (HF-style outputs)
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs, return_dict=True, output_hidden_states=True, output_attentions=True)
teacher_logits = teacher_outputs.logits
teacher_features = teacher_outputs.get('hidden_states', None)
teacher_attentions = teacher_outputs.get('attentions', None)
# Student forward pass
student_outputs = self.student_model(inputs, return_dict=True, output_hidden_states=True, output_attentions=True)
student_logits = student_outputs.logits
student_features = student_outputs.get('hidden_states', None)
student_attentions = student_outputs.get('attentions', None)
# Compute losses
losses = {}
# Response distillation (always computed)
response_losses = self.response_loss(student_logits, teacher_logits, labels)
losses.update(response_losses)
total_loss = response_losses['total_loss']
# Feature distillation (if features available)
if (student_features is not None and teacher_features is not None and
self.config.get('use_feature_distillation', False)):
# Use last hidden state
feature_loss = self.feature_loss(
student_features[-1],
teacher_features[-1]
)
losses['feature_loss'] = feature_loss
total_loss += self.config.get('feature_weight', 0.1) * feature_loss
# Attention distillation (if attentions available)
if (student_attentions is not None and teacher_attentions is not None and
self.config.get('use_attention_distillation', False)):
# Use last attention layer
attention_loss = self.attention_loss(
student_attentions[-1],
teacher_attentions[-1]
)
losses['attention_loss'] = attention_loss
total_loss += self.config.get('attention_weight', 0.1) * attention_loss
losses['total_loss'] = total_loss
# Return tensors (not floats) so the training loop can call backward() on total_loss
return losses
def evaluate(self, dataloader) -> Dict[str, float]:
"""Evaluate student model performance"""
self.student_model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch in dataloader:
inputs = batch['inputs']
labels = batch['labels']
# Student inference
student_outputs = self.student_model(inputs, return_dict=True)
student_logits = student_outputs.logits
# Accuracy calculation
predictions = torch.argmax(student_logits, dim=-1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
# Loss calculation
loss = F.cross_entropy(student_logits, labels)
total_loss += loss.item()
return {
'accuracy': correct / total,
'loss': total_loss / len(dataloader)
}
# Usage example
if __name__ == "__main__":
# Configuration
distillation_config = {
'temperature': 4.0,
'alpha': 0.7,
'use_feature_distillation': True,
'use_attention_distillation': True,
'feature_dim': 768,
'student_dim': 384,
'num_heads': 12,
'feature_weight': 0.1,
'attention_weight': 0.05
}
# Initialize trainer (models would be loaded separately)
# trainer = KnowledgeDistillationTrainer(teacher_model, student_model, distillation_config)
print("Knowledge Distillation Framework initialized")
print(f"Configuration: {distillation_config}")
Advanced Distillation Techniques
Progressive Knowledge Distillation
Progressive Distillation Implementation
import torch
import torch.nn as nn
from typing import List, Dict
import copy
class ProgressiveKnowledgeDistillation:
"""Progressive distillation with multiple intermediate teachers"""
def __init__(self,
teacher_model: nn.Module,
target_compression: float = 0.1,
num_stages: int = 3):
self.teacher_model = teacher_model
self.target_compression = target_compression
self.num_stages = num_stages
# Calculate intermediate model sizes
self.compression_stages = self._calculate_compression_stages()
# Store intermediate teachers
self.intermediate_teachers = []
def _calculate_compression_stages(self) -> List[float]:
"""Calculate compression ratio for each stage"""
stages = []
current_ratio = 1.0
# Geometric progression of compression ratios
stage_ratio = self.target_compression ** (1 / self.num_stages)
for i in range(self.num_stages):
current_ratio *= stage_ratio
stages.append(current_ratio)
return stages
def create_intermediate_student(self, compression_ratio: float) -> nn.Module:
"""Create intermediate student model with specified compression"""
# This is simplified - in practice, you'd architect smaller models
original_params = sum(p.numel() for p in self.teacher_model.parameters())
target_params = int(original_params * compression_ratio)
# Create smaller model (example with transformer)
if hasattr(self.teacher_model, 'config'):
config = copy.deepcopy(self.teacher_model.config)
# Scale down model dimensions (keep hidden_size divisible by the attention head count)
scale_factor = compression_ratio ** 0.5
config.num_attention_heads = max(1, int(config.num_attention_heads * scale_factor))
config.hidden_size = max(config.num_attention_heads, (int(config.hidden_size * scale_factor) // config.num_attention_heads) * config.num_attention_heads)
config.intermediate_size = max(256, int(config.intermediate_size * scale_factor))
config.num_hidden_layers = max(1, int(config.num_hidden_layers * scale_factor))
# Create new model with scaled config
student_model = type(self.teacher_model)(config)
return student_model
# Fallback for custom models
return self._create_custom_student(compression_ratio)
def _create_custom_student(self, compression_ratio: float) -> nn.Module:
"""Create custom student architecture"""
class CompressedModel(nn.Module):
def __init__(self, input_dim: int, output_dim: int, compression: float):
super().__init__()
hidden_dim = max(64, int(512 * compression))
self.layers = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim // 2, output_dim)
)
def forward(self, x):
return self.layers(x)
return CompressedModel(768, 1000, compression_ratio) # Example dimensions
def train_progressive_stages(self,
train_dataloader,
val_dataloader,
optimizer_fn,
num_epochs_per_stage: int = 10):
"""Train through progressive distillation stages"""
current_teacher = self.teacher_model
results = []
for stage_idx, compression_ratio in enumerate(self.compression_stages):
print(f"\n=== Stage {stage_idx + 1}/{self.num_stages} ===")
print(f"Compression ratio: {compression_ratio:.3f}")
# Create student for this stage
student_model = self.create_intermediate_student(compression_ratio)
# Setup distillation trainer
distillation_config = {
'temperature': 4.0 + stage_idx, # Increase temperature for harder stages
'alpha': 0.8 - stage_idx * 0.1, # Decrease alpha for later stages
'use_feature_distillation': True,
'feature_weight': 0.2
}
trainer = KnowledgeDistillationTrainer(
current_teacher,
student_model,
distillation_config
)
# Train student
optimizer = optimizer_fn(student_model.parameters())
stage_results = self._train_stage(
trainer,
train_dataloader,
val_dataloader,
optimizer,
num_epochs_per_stage
)
results.append({
'stage': stage_idx + 1,
'compression_ratio': compression_ratio,
'results': stage_results
})
# Current student becomes next teacher
current_teacher = student_model
self.intermediate_teachers.append(current_teacher)
return results
def _train_stage(self,
trainer,
train_dataloader,
val_dataloader,
optimizer,
num_epochs: int) -> Dict:
"""Train a single distillation stage"""
best_accuracy = 0
stage_results = []
for epoch in range(num_epochs):
# Training
trainer.student_model.train()
total_loss = 0
for batch in train_dataloader:
optimizer.zero_grad()
losses = trainer.train_step(batch)
loss = losses['total_loss']
loss.backward()
optimizer.step()
total_loss += loss.item()  # accumulate as a float so the computation graph is released
# Validation
val_results = trainer.evaluate(val_dataloader)
stage_results.append({
'epoch': epoch + 1,
'train_loss': total_loss / len(train_dataloader),
'val_accuracy': val_results['accuracy'],
'val_loss': val_results['loss']
})
if val_results['accuracy'] > best_accuracy:
best_accuracy = val_results['accuracy']
print(f"Epoch {epoch + 1}: Val Acc = {val_results['accuracy']:.3f}")
return {
'best_accuracy': best_accuracy,
'training_history': stage_results
}
# Usage example for progressive distillation
def run_progressive_distillation():
# Initialize progressive distillation
# (large_teacher_model, train_loader, and val_loader are assumed to be defined by the caller)
progressive_distiller = ProgressiveKnowledgeDistillation(
teacher_model=large_teacher_model,
target_compression=0.05, # 5% of original size
num_stages=4
)
# Define optimizer factory
def optimizer_fn(params):
return torch.optim.AdamW(params, lr=1e-4, weight_decay=0.01)
# Run progressive training
results = progressive_distiller.train_progressive_stages(
train_dataloader=train_loader,
val_dataloader=val_loader,
optimizer_fn=optimizer_fn,
num_epochs_per_stage=15
)
# Final compressed model
final_student = progressive_distiller.intermediate_teachers[-1]
return final_student, results
Production Knowledge Distillation Service
Distillation-as-a-Service Platform
production_distillation_service.py
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import BaseModel
from typing import Dict, List, Optional, Any
import torch
import asyncio
import uuid
import logging
from dataclasses import dataclass
from contextlib import asynccontextmanager
import json
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class DistillationJob:
job_id: str
teacher_model_path: str
target_compression: float
distillation_method: str
status: str = "pending"
progress: float = 0.0
results: Optional[Dict] = None
created_at: Optional[float] = None
completed_at: Optional[float] = None
class DistillationRequest(BaseModel):
teacher_model_name: str
target_compression: float = 0.1
distillation_method: str = "response_distillation"
training_config: Dict[str, Any] = {}
dataset_config: Dict[str, Any] = {}
class DistillationResponse(BaseModel):
job_id: str
status: str
progress: float
estimated_completion: Optional[str] = None
results: Optional[Dict] = None
class DistillationService:
"""Production knowledge distillation service"""
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.jobs: Dict[str, DistillationJob] = {}
self.active_jobs = set()
self.max_concurrent_jobs = 2
# Model registry
self.available_teachers = {
"bert-large": {
"path": "bert-large-uncased",
"type": "transformer",
"params": 340_000_000
},
"roberta-large": {
"path": "roberta-large",
"type": "transformer",
"params": 355_000_000
},
"gpt2-medium": {
"path": "gpt2-medium",
"type": "causal_lm",
"params": 345_000_000
}
}
async def create_distillation_job(self, request: DistillationRequest) -> str:
"""Create new knowledge distillation job"""
# Validate teacher model
if request.teacher_model_name not in self.available_teachers:
raise HTTPException(
status_code=400,
detail=f"Teacher model {request.teacher_model_name} not available"
)
# Generate job ID
job_id = str(uuid.uuid4())
# Create job
job = DistillationJob(
job_id=job_id,
teacher_model_path=self.available_teachers[request.teacher_model_name]["path"],
target_compression=request.target_compression,
distillation_method=request.distillation_method,
created_at=time.time()
)
self.jobs[job_id] = job
logger.info(f"Created distillation job {job_id}")
return job_id
async def start_distillation_job(self, job_id: str, background_tasks: BackgroundTasks):
"""Start distillation job in background"""
if job_id not in self.jobs:
raise HTTPException(status_code=404, detail="Job not found")
if len(self.active_jobs) >= self.max_concurrent_jobs:
raise HTTPException(
status_code=429,
detail="Maximum concurrent jobs reached"
)
self.active_jobs.add(job_id)
self.jobs[job_id].status = "running"
# Start background distillation
background_tasks.add_task(self._run_distillation, job_id)
return {"message": "Job started", "job_id": job_id}
async def _run_distillation(self, job_id: str):
"""Run knowledge distillation process"""
job = self.jobs[job_id]
try:
logger.info(f"Starting distillation for job {job_id}")
# Simulate distillation process
await self._simulate_distillation_training(job)
# Mark as completed
job.status = "completed"
job.progress = 100.0
job.completed_at = time.time()
logger.info(f"Distillation job {job_id} completed successfully")
except Exception as e:
logger.error(f"Distillation job {job_id} failed: {e}")
job.status = "failed"
job.results = {"error": str(e)}
finally:
self.active_jobs.discard(job_id)
async def _simulate_distillation_training(self, job: DistillationJob):
"""Simulate the distillation training process"""
total_epochs = 20
results = {
"teacher_accuracy": 0.95,
"initial_student_accuracy": 0.73,
"final_student_accuracy": 0.91,
"compression_ratio": job.target_compression,
"training_history": []
}
for epoch in range(total_epochs):
# Simulate training progress
await asyncio.sleep(1) # Simulate training time
job.progress = (epoch + 1) / total_epochs * 100
# Simulate improving accuracy
current_accuracy = (
results["initial_student_accuracy"] +
(results["final_student_accuracy"] - results["initial_student_accuracy"]) *
(epoch + 1) / total_epochs
)
results["training_history"].append({
"epoch": epoch + 1,
"student_accuracy": round(current_accuracy, 3),
"distillation_loss": round(2.5 - epoch * 0.1, 3),
"total_loss": round(1.8 - epoch * 0.07, 3)
})
# Calculate final metrics using the job's own teacher model (not just the first registry entry)
original_params = next(
info["params"] for info in self.available_teachers.values()
if info["path"] == job.teacher_model_path
)
student_params = int(original_params * job.target_compression)
results.update({
"original_parameters": original_params,
"compressed_parameters": student_params,
"parameter_reduction": f"{(1 - job.target_compression) * 100:.1f}%",
"memory_savings": f"{(1 - job.target_compression) * 100:.1f}%",
"inference_speedup": f"{1 / job.target_compression:.1f}x",
"accuracy_retention": f"{(results['final_student_accuracy'] / results['teacher_accuracy']) * 100:.1f}%"
})
job.results = results
async def get_job_status(self, job_id: str) -> DistillationResponse:
"""Get status of distillation job"""
if job_id not in self.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = self.jobs[job_id]
# Estimate completion time
estimated_completion = None
if job.status == "running" and job.progress > 0:
elapsed = time.time() - job.created_at
estimated_total = elapsed / (job.progress / 100)
remaining = estimated_total - elapsed
estimated_completion = f"{remaining / 60:.1f} minutes"
return DistillationResponse(
job_id=job_id,
status=job.status,
progress=job.progress,
estimated_completion=estimated_completion,
results=job.results
)
async def list_jobs(self) -> List[Dict]:
"""List all distillation jobs"""
return [
{
"job_id": job.job_id,
"status": job.status,
"progress": job.progress,
"teacher_model": job.teacher_model_path,
"compression": job.target_compression,
"created_at": job.created_at
}
for job in self.jobs.values()
]
async def get_available_teachers(self) -> Dict:
"""Get list of available teacher models"""
return self.available_teachers
async def delete_job(self, job_id: str) -> Dict[str, str]:
"""Delete a distillation job"""
if job_id not in self.jobs:
raise HTTPException(status_code=404, detail="Job not found")
if job_id in self.active_jobs:
raise HTTPException(
status_code=400,
detail="Cannot delete running job"
)
del self.jobs[job_id]
return {"message": "Job deleted", "job_id": job_id}
# Initialize service
distillation_service = DistillationService()
# FastAPI setup
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting Knowledge Distillation Service")
yield
logger.info("Shutting down Knowledge Distillation Service")
app = FastAPI(
title="Knowledge Distillation Service",
description="Production knowledge distillation and model compression API",
version="1.0.0",
lifespan=lifespan
)
@app.post("/distillation/jobs", response_model=Dict[str, str])
async def create_distillation_job(request: DistillationRequest):
"""Create new knowledge distillation job"""
job_id = await distillation_service.create_distillation_job(request)
return {"job_id": job_id, "status": "created"}
@app.post("/distillation/jobs/{job_id}/start")
async def start_distillation_job(job_id: str, background_tasks: BackgroundTasks):
"""Start distillation job"""
return await distillation_service.start_distillation_job(job_id, background_tasks)
@app.get("/distillation/jobs/{job_id}", response_model=DistillationResponse)
async def get_job_status(job_id: str):
"""Get distillation job status"""
return await distillation_service.get_job_status(job_id)
@app.get("/distillation/jobs")
async def list_jobs():
"""List all distillation jobs"""
return await distillation_service.list_jobs()
@app.get("/distillation/teachers")
async def get_available_teachers():
"""Get available teacher models"""
return await distillation_service.get_available_teachers()
@app.delete("/distillation/jobs/{job_id}")
async def delete_job(job_id: str):
"""Delete distillation job"""
return await distillation_service.delete_job(job_id)
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"active_jobs": len(distillation_service.active_jobs),
"total_jobs": len(distillation_service.jobs)
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"production_distillation_service:app",
host="0.0.0.0",
port=8000,
workers=1
)
Real-World Examples
Hugging Face DistilBERT
A distilled version of BERT with 40% fewer parameters that retains about 97% of BERT's language-understanding performance while running roughly 60% faster.
- • 6 layers vs 12 in BERT-base
- • 40% parameter reduction, ~60% faster inference
- • ~97% accuracy retention
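As a point of reference, a distilled checkpoint like this can be compared against its teacher directly; a small sketch assuming the Hugging Face transformers library is installed (both models are downloaded on first use):
Comparing DistilBERT with BERT-base
from transformers import AutoModel

# Load the teacher (BERT-base) and its distilled student (DistilBERT)
teacher = AutoModel.from_pretrained("bert-base-uncased")
student = AutoModel.from_pretrained("distilbert-base-uncased")

teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params / 1e6:.0f}M params, "
      f"Student: {student_params / 1e6:.0f}M params "
      f"({student_params / teacher_params * 100:.0f}% of teacher)")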
Microsoft MiniLM
Deep self-attention distillation achieving 5.3x speedup with minimal accuracy loss.
- • Self-attention transfer
- • 5.3x inference speedup
- • <2% accuracy drop
TinyBERT
Two-stage knowledge distillation achieving 7.5x smaller models with competitive performance.
- • Two-stage distillation
- • 7.5x parameter reduction
- • 9.4x inference speedup
Knowledge Distillation Best Practices
✅ Do's
- • Use temperature scaling (T=3-5) for soft targets
- • Balance soft and hard losses (α=0.7-0.9); see the configuration sketch after these lists
- • Use feature distillation for complex tasks
- • Consider progressive distillation for extreme compression
- • Validate on target deployment hardware
- • Monitor both accuracy and efficiency metrics
❌ Don'ts
- • Don't use very high temperatures (>10)
- • Don't ignore the importance of teacher quality
- • Don't over-compress without validation
- • Don't skip architectural considerations
- • Don't forget to optimize for target constraints
- • Don't assume distillation works for all tasks
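Putting these ranges into practice, the following is an illustrative starting configuration; the values are assumptions chosen inside the recommended ranges rather than prescriptions, and the dict uses the same keys as the KnowledgeDistillationTrainer defined earlier.
Starting Configuration Sketch
# Illustrative defaults inside the recommended ranges; tune per task and
# validate on the target deployment hardware before committing.
recommended_config = {
    "temperature": 4.0,                   # T in the 3-5 range
    "alpha": 0.8,                         # soft-loss weight in the 0.7-0.9 range
    "use_feature_distillation": True,     # worthwhile for complex tasks
    "feature_weight": 0.1,
    "use_attention_distillation": False,  # enable for transformer students that expose attentions
}

assert 3.0 <= recommended_config["temperature"] <= 5.0
assert 0.7 <= recommended_config["alpha"] <= 0.9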