
Knowledge Distillation & Model Compression

Master teacher-student training, pruning, quantization, and deployment optimization techniques

50 min read · Advanced

What is Knowledge Distillation?

Knowledge Distillation is a model compression technique where a smaller "student" model learns to mimic the behavior of a larger "teacher" model, achieving similar performance with significantly fewer parameters.

  • Model Compression: 90%+ size reduction possible
  • Performance Retention: 95%+ accuracy preservation
  • Deployment Ready: mobile and edge optimization

🧮 Knowledge Distillation Calculator

Calculate compression ratios, performance retention, and optimization metrics for knowledge distillation.

Distillation Analysis

Compression Ratio: 4.00%
Teacher Accuracy: 92.0%
Student Accuracy: 89.1%
Accuracy Retention: 96.8%
Memory Reduction: 96.0%
Speed Improvement: 5.0x
Cost Savings: 80.0%
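
These figures follow directly from the teacher/student parameter counts and accuracies. Below is a minimal sketch of how they can be derived; the helper name, the example parameter counts, and the formulas are assumptions for illustration, and the speedup line is only a theoretical upper bound (measured speedups depend on hardware and architecture).

Distillation Metrics Sketch
def distillation_metrics(teacher_params: int, student_params: int,
                         teacher_acc: float, student_acc: float) -> dict:
    """Illustrative helper mirroring the calculator above (not part of any library)."""
    compression_ratio = student_params / teacher_params      # e.g. 0.04 -> "4.00%"
    return {
        "compression_ratio": f"{compression_ratio * 100:.2f}%",
        "accuracy_retention": f"{student_acc / teacher_acc * 100:.1f}%",
        "memory_reduction": f"{(1 - compression_ratio) * 100:.1f}%",
        # 1/ratio is an upper bound; measured speedups (e.g. the 5.0x above) are usually lower
        "theoretical_speedup": f"{1 / compression_ratio:.1f}x",
    }

print(distillation_metrics(340_000_000, 13_600_000, 0.92, 0.891))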

Knowledge Distillation Methods

Response Distillation

  • Learn from teacher's output probabilities
  • Soft targets with temperature scaling (see the loss formulation below)
  • Most common and straightforward approach
  • Good for classification tasks
  • Minimal architectural constraints
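
In the standard formulation from Hinton et al. (2015), the response-distillation objective blends a temperature-softened KL term with ordinary cross-entropy on the ground-truth labels; this is what the DistillationLoss class further down this page implements:

L = alpha * T^2 * KL( softmax(z_teacher / T) || softmax(z_student / T) ) + (1 - alpha) * CE(z_student, y)

Here z_teacher and z_student are the logits, T is the temperature, y the hard labels, and alpha weights the soft-target term; the T^2 factor keeps the soft-target gradients on a comparable scale as T grows.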

Feature Distillation

  • Match intermediate layer representations
  • Transfer rich feature knowledge
  • Better for complex tasks
  • Requires architectural alignment
  • Higher computational overhead

Attention Distillation

  • Transfer attention patterns
  • Preserve important relationships
  • Effective for transformer models
  • Better interpretability
  • Specialized for attention mechanisms

Progressive Distillation

  • Gradual knowledge transfer
  • Multiple intermediate models (see the schedule sketch below)
  • Best compression ratios
  • Higher training complexity
  • Superior final performance
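
As a quick worked example of the staged schedule (the 10% target and three stages below are assumptions; the geometric split mirrors the _calculate_compression_stages helper in the progressive implementation later on this page): compressing to 10% of the teacher in three stages uses a per-stage ratio of 0.1^(1/3) ≈ 0.464, so the intermediate models sit at roughly 46%, 22%, and 10% of the original size.

Compression Schedule Sketch
target_compression, num_stages = 0.10, 3
per_stage = target_compression ** (1 / num_stages)          # ≈ 0.464
stages = [per_stage ** (i + 1) for i in range(num_stages)]  # cumulative ratio after each stage
print([round(s, 3) for s in stages])                        # [0.464, 0.215, 0.1]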

Knowledge Distillation Implementation

Basic Distillation Framework

Knowledge Distillation Core
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
import numpy as np

class DistillationLoss(nn.Module):
    """Knowledge distillation loss combining soft and hard targets"""
    
    def __init__(self, 
                 temperature: float = 4.0,
                 alpha: float = 0.7,
                 reduction: str = 'mean'):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.reduction = reduction
        # 'batchmean' matches the mathematical definition of KL divergence;
        # plain 'mean' also averages over classes and under-weights the soft-target term
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss(reduction=reduction)
    
    def forward(self, 
                student_logits: torch.Tensor,
                teacher_logits: torch.Tensor, 
                target_labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        
        # Soft targets from teacher (knowledge distillation loss)
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        
        distillation_loss = self.kl_div(student_soft, teacher_soft) * (self.temperature ** 2)
        
        # Hard targets (standard cross-entropy loss)  
        student_loss = self.ce_loss(student_logits, target_labels)
        
        # Combined loss
        total_loss = (
            self.alpha * distillation_loss + 
            (1 - self.alpha) * student_loss
        )
        
        return {
            'total_loss': total_loss,
            'distillation_loss': distillation_loss,
            'student_loss': student_loss
        }

class FeatureDistillationLoss(nn.Module):
    """Feature-based knowledge distillation"""
    
    def __init__(self, feature_dim: int, student_dim: Optional[int] = None):
        super().__init__()
        self.feature_dim = feature_dim
        
        # Projection layer if dimensions don't match
        if student_dim and student_dim != feature_dim:
            self.projection = nn.Linear(student_dim, feature_dim)
        else:
            self.projection = nn.Identity()
        
        self.mse_loss = nn.MSELoss()
    
    def forward(self, 
                student_features: torch.Tensor,
                teacher_features: torch.Tensor) -> torch.Tensor:
        
        # Project student features if needed
        projected_student = self.projection(student_features)
        
        # Normalize features
        student_norm = F.normalize(projected_student, p=2, dim=-1)
        teacher_norm = F.normalize(teacher_features, p=2, dim=-1)
        
        # Feature matching loss
        return self.mse_loss(student_norm, teacher_norm)

class AttentionDistillationLoss(nn.Module):
    """Attention-based knowledge distillation for transformers"""
    
    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.mse_loss = nn.MSELoss()
    
    def forward(self, 
                student_attentions: torch.Tensor,
                teacher_attentions: torch.Tensor) -> torch.Tensor:
        
        # Average over attention heads
        student_attn_avg = student_attentions.mean(dim=1)  # [batch, seq, seq]
        teacher_attn_avg = teacher_attentions.mean(dim=1)  # [batch, seq, seq]
        
        # Attention pattern matching
        return self.mse_loss(student_attn_avg, teacher_attn_avg)

class KnowledgeDistillationTrainer:
    """Complete knowledge distillation training framework"""
    
    def __init__(self,
                 teacher_model: nn.Module,
                 student_model: nn.Module,
                 distillation_config: Dict):
        
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.config = distillation_config
        
        # Loss functions
        self.response_loss = DistillationLoss(
            temperature=distillation_config.get('temperature', 4.0),
            alpha=distillation_config.get('alpha', 0.7)
        )
        
        self.feature_loss = FeatureDistillationLoss(
            feature_dim=distillation_config.get('feature_dim', 768),
            student_dim=distillation_config.get('student_dim', None)
        )
        
        self.attention_loss = AttentionDistillationLoss(
            num_heads=distillation_config.get('num_heads', 12)
        )
        
        # Freeze teacher model
        self.teacher_model.eval()
        for param in self.teacher_model.parameters():
            param.requires_grad = False
    
    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Single training step with knowledge distillation"""
        
        inputs = batch['inputs']
        labels = batch['labels']
        
        # Hugging Face-style models only populate hidden states / attentions when asked
        extra_outputs = {}
        if self.config.get('use_feature_distillation', False):
            extra_outputs['output_hidden_states'] = True
        if self.config.get('use_attention_distillation', False):
            extra_outputs['output_attentions'] = True
        
        # Teacher forward pass (no gradients)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs, return_dict=True, **extra_outputs)
            teacher_logits = teacher_outputs.logits
            teacher_features = teacher_outputs.get('hidden_states', None)
            teacher_attentions = teacher_outputs.get('attentions', None)
        
        # Student forward pass
        student_outputs = self.student_model(inputs, return_dict=True, **extra_outputs)
        student_logits = student_outputs.logits
        student_features = student_outputs.get('hidden_states', None)
        student_attentions = student_outputs.get('attentions', None)
        
        # Compute losses
        losses = {}
        
        # Response distillation (always computed)
        response_losses = self.response_loss(student_logits, teacher_logits, labels)
        losses.update(response_losses)
        
        total_loss = response_losses['total_loss']
        
        # Feature distillation (if features available)
        if (student_features is not None and teacher_features is not None and 
            self.config.get('use_feature_distillation', False)):
            
            # Use last hidden state
            feature_loss = self.feature_loss(
                student_features[-1], 
                teacher_features[-1]
            )
            losses['feature_loss'] = feature_loss
            total_loss += self.config.get('feature_weight', 0.1) * feature_loss
        
        # Attention distillation (if attentions available)
        if (student_attentions is not None and teacher_attentions is not None and
            self.config.get('use_attention_distillation', False)):
            
            # Use last attention layer
            attention_loss = self.attention_loss(
                student_attentions[-1],
                teacher_attentions[-1]
            )
            losses['attention_loss'] = attention_loss
            total_loss += self.config.get('attention_weight', 0.1) * attention_loss
        
        losses['total_loss'] = total_loss
        
        # Return the loss tensors so the caller can call backward() on 'total_loss'
        return losses
    
    def evaluate(self, dataloader) -> Dict[str, float]:
        """Evaluate student model performance"""
        self.student_model.eval()
        
        total_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in dataloader:
                inputs = batch['inputs']
                labels = batch['labels']
                
                # Student inference
                student_outputs = self.student_model(inputs, return_dict=True)
                student_logits = student_outputs.logits
                
                # Accuracy calculation
                predictions = torch.argmax(student_logits, dim=-1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
                
                # Loss calculation
                loss = F.cross_entropy(student_logits, labels)
                total_loss += loss.item()
        
        return {
            'accuracy': correct / total,
            'loss': total_loss / len(dataloader)
        }

# Usage example
if __name__ == "__main__":
    # Configuration
    distillation_config = {
        'temperature': 4.0,
        'alpha': 0.7,
        'use_feature_distillation': True,
        'use_attention_distillation': True,
        'feature_dim': 768,
        'student_dim': 384,
        'num_heads': 12,
        'feature_weight': 0.1,
        'attention_weight': 0.05
    }
    
    # Initialize trainer (models would be loaded separately)
    # trainer = KnowledgeDistillationTrainer(teacher_model, student_model, distillation_config)
    
    print("Knowledge Distillation Framework initialized")
    print(f"Configuration: {distillation_config}")

Advanced Distillation Techniques

Progressive Knowledge Distillation

Progressive Distillation Implementation
import torch
import torch.nn as nn
from typing import List, Dict
import copy

# Reuses KnowledgeDistillationTrainer from the basic distillation framework above

class ProgressiveKnowledgeDistillation:
    """Progressive distillation with multiple intermediate teachers"""
    
    def __init__(self, 
                 teacher_model: nn.Module,
                 target_compression: float = 0.1,
                 num_stages: int = 3):
        
        self.teacher_model = teacher_model
        self.target_compression = target_compression
        self.num_stages = num_stages
        
        # Calculate intermediate model sizes
        self.compression_stages = self._calculate_compression_stages()
        
        # Store intermediate teachers
        self.intermediate_teachers = []
    
    def _calculate_compression_stages(self) -> List[float]:
        """Calculate compression ratio for each stage"""
        stages = []
        current_ratio = 1.0
        
        # Geometric progression of compression ratios
        stage_ratio = self.target_compression ** (1 / self.num_stages)
        
        for i in range(self.num_stages):
            current_ratio *= stage_ratio
            stages.append(current_ratio)
        
        return stages
    
    def create_intermediate_student(self, compression_ratio: float) -> nn.Module:
        """Create intermediate student model with specified compression"""
        # This is simplified - in practice, you'd architect smaller models
        original_params = sum(p.numel() for p in self.teacher_model.parameters())
        target_params = int(original_params * compression_ratio)
        
        # Create smaller model (example with transformer)
        if hasattr(self.teacher_model, 'config'):
            config = copy.deepcopy(self.teacher_model.config)
            
            # Scale down model dimensions
            scale_factor = compression_ratio ** 0.5
            config.hidden_size = max(64, int(config.hidden_size * scale_factor))
            config.intermediate_size = max(256, int(config.intermediate_size * scale_factor))
            config.num_attention_heads = max(1, int(config.num_attention_heads * scale_factor))
            config.num_hidden_layers = max(1, int(config.num_hidden_layers * scale_factor))
            # Note: hidden_size must remain divisible by num_attention_heads for most transformer configs
            
            # Create new model with scaled config
            student_model = type(self.teacher_model)(config)
            return student_model
        
        # Fallback for custom models
        return self._create_custom_student(compression_ratio)
    
    def _create_custom_student(self, compression_ratio: float) -> nn.Module:
        """Create custom student architecture"""
        class CompressedModel(nn.Module):
            def __init__(self, input_dim: int, output_dim: int, compression: float):
                super().__init__()
                hidden_dim = max(64, int(512 * compression))
                
                self.layers = nn.Sequential(
                    nn.Linear(input_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(hidden_dim // 2, output_dim)
                )
            
            def forward(self, x):
                return self.layers(x)
        
        return CompressedModel(768, 1000, compression_ratio)  # Example dimensions
    
    def train_progressive_stages(self, 
                                train_dataloader,
                                val_dataloader,
                                optimizer_fn,
                                num_epochs_per_stage: int = 10):
        """Train through progressive distillation stages"""
        
        current_teacher = self.teacher_model
        results = []
        
        for stage_idx, compression_ratio in enumerate(self.compression_stages):
            print(f"\n=== Stage {stage_idx + 1}/{self.num_stages} ===")
            print(f"Compression ratio: {compression_ratio:.3f}")
            
            # Create student for this stage
            student_model = self.create_intermediate_student(compression_ratio)
            
            # Setup distillation trainer
            distillation_config = {
                'temperature': 4.0 + stage_idx,  # Increase temperature for harder stages
                'alpha': 0.8 - stage_idx * 0.1,  # Decrease alpha for later stages
                'use_feature_distillation': True,
                'feature_weight': 0.2
            }
            
            trainer = KnowledgeDistillationTrainer(
                current_teacher, 
                student_model, 
                distillation_config
            )
            
            # Train student
            optimizer = optimizer_fn(student_model.parameters())
            stage_results = self._train_stage(
                trainer, 
                train_dataloader, 
                val_dataloader,
                optimizer,
                num_epochs_per_stage
            )
            
            results.append({
                'stage': stage_idx + 1,
                'compression_ratio': compression_ratio,
                'results': stage_results
            })
            
            # Current student becomes next teacher
            current_teacher = student_model
            self.intermediate_teachers.append(current_teacher)
        
        return results
    
    def _train_stage(self, 
                    trainer, 
                    train_dataloader, 
                    val_dataloader,
                    optimizer, 
                    num_epochs: int) -> Dict:
        """Train a single distillation stage"""
        
        best_accuracy = 0
        stage_results = []
        
        for epoch in range(num_epochs):
            # Training
            trainer.student_model.train()
            total_loss = 0
            
            for batch in train_dataloader:
                optimizer.zero_grad()
                
                losses = trainer.train_step(batch)
                loss = losses['total_loss']
                
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()  # accumulate as a plain float, not a graph-attached tensor
            
            # Validation
            val_results = trainer.evaluate(val_dataloader)
            
            stage_results.append({
                'epoch': epoch + 1,
                'train_loss': total_loss / len(train_dataloader),
                'val_accuracy': val_results['accuracy'],
                'val_loss': val_results['loss']
            })
            
            if val_results['accuracy'] > best_accuracy:
                best_accuracy = val_results['accuracy']
            
            print(f"Epoch {epoch + 1}: Val Acc = {val_results['accuracy']:.3f}")
        
        return {
            'best_accuracy': best_accuracy,
            'training_history': stage_results
        }

# Usage example for progressive distillation
# (assumes `large_teacher_model`, `train_loader`, and `val_loader` are defined elsewhere)
def run_progressive_distillation():
    # Initialize progressive distillation
    progressive_distiller = ProgressiveKnowledgeDistillation(
        teacher_model=large_teacher_model,
        target_compression=0.05,  # 5% of original size
        num_stages=4
    )
    
    # Define optimizer factory
    def optimizer_fn(params):
        return torch.optim.AdamW(params, lr=1e-4, weight_decay=0.01)
    
    # Run progressive training
    results = progressive_distiller.train_progressive_stages(
        train_dataloader=train_loader,
        val_dataloader=val_loader, 
        optimizer_fn=optimizer_fn,
        num_epochs_per_stage=15
    )
    
    # Final compressed model
    final_student = progressive_distiller.intermediate_teachers[-1]
    
    return final_student, results

Production Knowledge Distillation Service

Distillation-as-a-Service Platform

production_distillation_service.py
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import BaseModel
from typing import Dict, List, Optional, Any
import torch
import asyncio
import uuid
import logging
from dataclasses import dataclass
from contextlib import asynccontextmanager
import json
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class DistillationJob:
    job_id: str
    teacher_model_path: str
    target_compression: float
    distillation_method: str
    status: str = "pending"
    progress: float = 0.0
    results: Optional[Dict] = None
    created_at: Optional[float] = None
    completed_at: Optional[float] = None

class DistillationRequest(BaseModel):
    teacher_model_name: str
    target_compression: float = 0.1
    distillation_method: str = "response_distillation"
    training_config: Dict[str, Any] = {}
    dataset_config: Dict[str, Any] = {}

class DistillationResponse(BaseModel):
    job_id: str
    status: str
    progress: float
    estimated_completion: Optional[str] = None
    results: Optional[Dict] = None

class DistillationService:
    """Production knowledge distillation service"""
    
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.jobs: Dict[str, DistillationJob] = {}
        self.active_jobs = set()
        self.max_concurrent_jobs = 2
        
        # Model registry
        self.available_teachers = {
            "bert-large": {
                "path": "bert-large-uncased",
                "type": "transformer",
                "params": 340_000_000
            },
            "roberta-large": {
                "path": "roberta-large", 
                "type": "transformer",
                "params": 355_000_000
            },
            "gpt2-medium": {
                "path": "gpt2-medium",
                "type": "causal_lm", 
                "params": 345_000_000
            }
        }
    
    async def create_distillation_job(self, request: DistillationRequest) -> str:
        """Create new knowledge distillation job"""
        
        # Validate teacher model
        if request.teacher_model_name not in self.available_teachers:
            raise HTTPException(
                status_code=400,
                detail=f"Teacher model {request.teacher_model_name} not available"
            )
        
        # Generate job ID
        job_id = str(uuid.uuid4())
        
        # Create job
        job = DistillationJob(
            job_id=job_id,
            teacher_model_path=self.available_teachers[request.teacher_model_name]["path"],
            target_compression=request.target_compression,
            distillation_method=request.distillation_method,
            created_at=time.time()
        )
        
        self.jobs[job_id] = job
        
        logger.info(f"Created distillation job {job_id}")
        return job_id
    
    async def start_distillation_job(self, job_id: str, background_tasks: BackgroundTasks):
        """Start distillation job in background"""
        
        if job_id not in self.jobs:
            raise HTTPException(status_code=404, detail="Job not found")
        
        if len(self.active_jobs) >= self.max_concurrent_jobs:
            raise HTTPException(
                status_code=429, 
                detail="Maximum concurrent jobs reached"
            )
        
        self.active_jobs.add(job_id)
        self.jobs[job_id].status = "running"
        
        # Start background distillation
        background_tasks.add_task(self._run_distillation, job_id)
        
        return {"message": "Job started", "job_id": job_id}
    
    async def _run_distillation(self, job_id: str):
        """Run knowledge distillation process"""
        
        job = self.jobs[job_id]
        
        try:
            logger.info(f"Starting distillation for job {job_id}")
            
            # Simulate distillation process
            await self._simulate_distillation_training(job)
            
            # Mark as completed
            job.status = "completed"
            job.progress = 100.0
            job.completed_at = time.time()
            
            logger.info(f"Distillation job {job_id} completed successfully")
            
        except Exception as e:
            logger.error(f"Distillation job {job_id} failed: {e}")
            job.status = "failed"
            job.results = {"error": str(e)}
        
        finally:
            self.active_jobs.discard(job_id)
    
    async def _simulate_distillation_training(self, job: DistillationJob):
        """Simulate the distillation training process"""
        
        total_epochs = 20
        results = {
            "teacher_accuracy": 0.95,
            "initial_student_accuracy": 0.73,
            "final_student_accuracy": 0.91,
            "compression_ratio": job.target_compression,
            "training_history": []
        }
        
        for epoch in range(total_epochs):
            # Simulate training progress
            await asyncio.sleep(1)  # Simulate training time
            
            job.progress = (epoch + 1) / total_epochs * 100
            
            # Simulate improving accuracy
            current_accuracy = (
                results["initial_student_accuracy"] + 
                (results["final_student_accuracy"] - results["initial_student_accuracy"]) * 
                (epoch + 1) / total_epochs
            )
            
            results["training_history"].append({
                "epoch": epoch + 1,
                "student_accuracy": round(current_accuracy, 3),
                "distillation_loss": round(2.5 - epoch * 0.1, 3),
                "total_loss": round(1.8 - epoch * 0.07, 3)
            })
        
        # Calculate final metrics from the job's own teacher (not just the first registered model)
        teacher_info = next(
            (t for t in self.available_teachers.values() if t["path"] == job.teacher_model_path),
            next(iter(self.available_teachers.values()))
        )
        original_params = teacher_info["params"]
        student_params = int(original_params * job.target_compression)
        
        results.update({
            "original_parameters": original_params,
            "compressed_parameters": student_params,
            "parameter_reduction": f"{(1 - job.target_compression) * 100:.1f}%",
            "memory_savings": f"{(1 - job.target_compression) * 100:.1f}%",
            "inference_speedup": f"{1 / job.target_compression:.1f}x",
            "accuracy_retention": f"{(results['final_student_accuracy'] / results['teacher_accuracy']) * 100:.1f}%"
        })
        
        job.results = results
    
    async def get_job_status(self, job_id: str) -> DistillationResponse:
        """Get status of distillation job"""
        
        if job_id not in self.jobs:
            raise HTTPException(status_code=404, detail="Job not found")
        
        job = self.jobs[job_id]
        
        # Estimate completion time
        estimated_completion = None
        if job.status == "running" and job.progress > 0:
            elapsed = time.time() - job.created_at
            estimated_total = elapsed / (job.progress / 100)
            remaining = estimated_total - elapsed
            estimated_completion = f"{remaining / 60:.1f} minutes"
        
        return DistillationResponse(
            job_id=job_id,
            status=job.status,
            progress=job.progress,
            estimated_completion=estimated_completion,
            results=job.results
        )
    
    async def list_jobs(self) -> List[Dict]:
        """List all distillation jobs"""
        return [
            {
                "job_id": job.job_id,
                "status": job.status,
                "progress": job.progress,
                "teacher_model": job.teacher_model_path,
                "compression": job.target_compression,
                "created_at": job.created_at
            }
            for job in self.jobs.values()
        ]
    
    async def get_available_teachers(self) -> Dict:
        """Get list of available teacher models"""
        return self.available_teachers
    
    async def delete_job(self, job_id: str) -> Dict[str, str]:
        """Delete a distillation job"""
        
        if job_id not in self.jobs:
            raise HTTPException(status_code=404, detail="Job not found")
        
        if job_id in self.active_jobs:
            raise HTTPException(
                status_code=400, 
                detail="Cannot delete running job"
            )
        
        del self.jobs[job_id]
        return {"message": "Job deleted", "job_id": job_id}

# Initialize service
distillation_service = DistillationService()

# FastAPI setup
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Starting Knowledge Distillation Service")
    yield
    logger.info("Shutting down Knowledge Distillation Service")

app = FastAPI(
    title="Knowledge Distillation Service",
    description="Production knowledge distillation and model compression API",
    version="1.0.0",
    lifespan=lifespan
)

@app.post("/distillation/jobs", response_model=Dict[str, str])
async def create_distillation_job(request: DistillationRequest):
    """Create new knowledge distillation job"""
    job_id = await distillation_service.create_distillation_job(request)
    return {"job_id": job_id, "status": "created"}

@app.post("/distillation/jobs/{job_id}/start")
async def start_distillation_job(job_id: str, background_tasks: BackgroundTasks):
    """Start distillation job"""
    return await distillation_service.start_distillation_job(job_id, background_tasks)

@app.get("/distillation/jobs/{job_id}", response_model=DistillationResponse)
async def get_job_status(job_id: str):
    """Get distillation job status"""
    return await distillation_service.get_job_status(job_id)

@app.get("/distillation/jobs")
async def list_jobs():
    """List all distillation jobs"""
    return await distillation_service.list_jobs()

@app.get("/distillation/teachers")
async def get_available_teachers():
    """Get available teacher models"""
    return await distillation_service.get_available_teachers()

@app.delete("/distillation/jobs/{job_id}")
async def delete_job(job_id: str):
    """Delete distillation job"""
    return await distillation_service.delete_job(job_id)

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "active_jobs": len(distillation_service.active_jobs),
        "total_jobs": len(distillation_service.jobs)
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "production_distillation_service:app",
        host="0.0.0.0",
        port=8000,
        workers=1
    )
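
For completeness, here is a hedged example of driving the service from a client, assuming it is running locally on port 8000; the requests-based calls and the polling loop are illustrative and not part of the service itself.

distillation_client.py (example)
import time
import requests

BASE = "http://localhost:8000"

# Create a distillation job for one of the registered teacher models
job = requests.post(f"{BASE}/distillation/jobs", json={
    "teacher_model_name": "bert-large",
    "target_compression": 0.1,
    "distillation_method": "response_distillation",
}).json()
job_id = job["job_id"]

# Start the job, then poll its status until it finishes
requests.post(f"{BASE}/distillation/jobs/{job_id}/start")
while True:
    status = requests.get(f"{BASE}/distillation/jobs/{job_id}").json()
    print(f"{status['status']}: {status['progress']:.0f}%")
    if status["status"] in ("completed", "failed"):
        break
    time.sleep(5)

print(status.get("results"))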

Real-World Examples

Hugging Face DistilBERT

Compressed BERT model with roughly 40% fewer parameters and 60% faster inference, retaining about 97% of BERT's performance.

  • 6 layers vs 12 in BERT-base
  • ~40% parameter reduction
  • ~97% accuracy retention

Microsoft MiniLM

Deep self-attention distillation achieving 5.3x speedup with minimal accuracy loss.

  • Self-attention transfer
  • 5.3x inference speedup
  • <2% accuracy drop

TinyBERT

Two-stage knowledge distillation achieving 7.5x smaller models with competitive performance.

  • Two-stage distillation
  • 7.5x parameter reduction
  • 9.4x inference speedup

Knowledge Distillation Best Practices

✅ Do's

  • Use temperature scaling (T=3-5) for soft targets
  • Balance soft and hard losses (α=0.7-0.9)
  • Use feature distillation for complex tasks
  • Consider progressive distillation for extreme compression
  • Validate on target deployment hardware
  • Monitor both accuracy and efficiency metrics

❌ Don'ts

  • Don't use very high temperatures (>10)
  • Don't ignore the importance of teacher quality
  • Don't over-compress without validation
  • Don't skip architectural considerations
  • Don't forget to optimize for target constraints
  • Don't assume distillation works for all tasks