
Model Compression Techniques

Master advanced model compression: pruning, quantization, and distillation for efficient deployment

45 min read · Advanced

🔬 Model Compression Calculator

Worked example: pruning a 175B-parameter model at 90% sparsity.

Technique: Neural Network Pruning
Compressed Size: 17.5B params
Compression Ratio: 10x smaller
Quality Retention: 70%
Speedup Factor: 10x faster
Memory Reduction: 90% (≈70GB remaining, assuming FP32 weights)
Hardware Support: Limited (specialized)

Trade-off Analysis: Remove redundant connections and neurons based on importance
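
The same numbers fall out of a quick back-of-envelope calculation. The snippet below is a minimal sketch that reproduces the example above, assuming FP32 storage for the remaining weights:

params = 175e9          # original parameter count
sparsity = 0.9          # fraction of weights removed by pruning
bytes_per_param = 4     # FP32 storage assumed

remaining_params = params * (1 - sparsity)
remaining_memory_gb = remaining_params * bytes_per_param / 1e9

print(f"Compressed size: {remaining_params / 1e9:.1f}B params")  # 17.5B
print(f"Compression ratio: {1 / (1 - sparsity):.0f}x smaller")   # 10x
print(f"Remaining memory: {remaining_memory_gb:.0f}GB")          # 70GB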

🔧 Compression Techniques Overview

Neural Network Pruning

Remove redundant connections and neurons based on importance

Typical Compression: 10x smaller
Quality Retention: 70%
Speedup: 10.0x
Hardware Support: Limited (specialized)

Model Quantization

Reduce precision of weights and activations

Typical Compression: 4x smaller
Quality Retention: 95%
Speedup: 4.0x
Hardware Support: Excellent (GPU/CPU)

Knowledge Distillation

Train smaller student model to mimic larger teacher

Typical Compression: 10x smaller
Quality Retention: 80%
Speedup: 10.0x
Hardware Support: Universal

Low-Rank Factorization

Decompose weight matrices into smaller factors (see implementation example 4 below)

Typical Compression: 3x smaller
Quality Retention: 92%
Speedup: 2.5x
Hardware Support: Good (optimized BLAS)

💻 Implementation Examples

1. Neural Network Pruning with PyTorch

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class StructuredPruner:
    def __init__(self, model, sparsity=0.9):
        self.model = model
        self.sparsity = sparsity
        
    def magnitude_based_pruning(self):
        """Remove connections with smallest weights"""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                prune.l1_unstructured(module, name='weight', amount=self.sparsity)
                
    def structured_pruning(self):
        """Remove entire neurons/filters"""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                # Remove neurons with lowest L2 norm
                prune.ln_structured(
                    module, name='weight', amount=0.5, n=2, dim=0
                )
            elif isinstance(module, nn.Conv2d):
                # Remove filters with lowest L2 norm
                prune.ln_structured(
                    module, name='weight', amount=0.3, n=2, dim=0
                )
    
    def gradual_pruning(self, initial_sparsity=0.1, final_sparsity=0.9, steps=10):
        """Gradually increase sparsity during training"""
        sparsity_schedule = torch.linspace(
            initial_sparsity, final_sparsity, steps
        )
        
        for step, sparsity in enumerate(sparsity_schedule):
            self.apply_global_pruning(float(sparsity))
            # Placeholder: run your own training loop here so the network
            # can recover accuracy before the next pruning step
            self.train_epoch()
            
    def apply_global_pruning(self, sparsity):
        """Global magnitude-based pruning across all layers"""
        parameters_to_prune = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                parameters_to_prune.append((module, 'weight'))
        
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=sparsity,
        )

# Usage example
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
pruner = StructuredPruner(model, sparsity=0.9)

# Apply different pruning strategies
pruner.magnitude_based_pruning()
print(f"Model sparsity: {prune.global_sparsity(model):.2%}")

# Remove pruning masks to get actual compressed model
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        prune.remove(module, 'weight')
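
Note that zeroed weights only shrink memory if they are actually stored in a sparse format and the inference runtime supports it, which is why hardware support for pruning is listed as limited above. The following is a rough sketch (not part of the original example) comparing dense vs. COO sparse storage for the pruned layers:

def sparse_storage_bytes(weight):
    """Approximate bytes needed to store a tensor in COO sparse format."""
    sparse = weight.to_sparse()
    value_bytes = sparse.values().numel() * weight.element_size()
    index_bytes = sparse.indices().numel() * 8  # int64 indices
    return value_bytes + index_bytes

for name, module in model.named_modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        dense_bytes = module.weight.numel() * module.weight.element_size()
        print(f"{name}: dense {dense_bytes / 1e6:.2f}MB, "
              f"COO sparse {sparse_storage_bytes(module.weight) / 1e6:.2f}MB")

For 4-D convolution weights the COO index overhead can eat most of the savings, which is why practical deployments rely on structured sparsity or hardware-specific sparse formats rather than naive coordinate storage.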

2. Advanced Quantization Techniques

import torch
import torch.quantization as quant
from transformers import AutoModel, AutoTokenizer

class AdvancedQuantizer:
    def __init__(self, model, calibration_data):
        self.model = model
        self.calibration_data = calibration_data
        
    def post_training_quantization(self):
        """PTQ: Quantize after training without retraining"""
        # Prepare model for quantization (eager-mode PTQ assumes the model
        # wraps its inputs/outputs with QuantStub/DeQuantStub)
        self.model.eval()
        self.model.qconfig = quant.get_default_qconfig('fbgemm')
        quant.prepare(self.model, inplace=True)
        
        # Calibrate with representative data
        with torch.no_grad():
            for batch in self.calibration_data:
                self.model(batch)
        
        # Convert to quantized model
        quantized_model = quant.convert(self.model)
        return quantized_model
    
    def quantization_aware_training(self):
        """QAT: Fine-tune model during quantization"""
        self.model.train()
        self.model.qconfig = quant.get_default_qat_qconfig('fbgemm')
        
        # Prepare for QAT
        quant.prepare_qat(self.model, inplace=True)
        
        # Fine-tune for several epochs with quantization simulation
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
        
        for epoch in range(3):  # Few epochs of fine-tuning
            for batch in self.calibration_data:
                optimizer.zero_grad()
                output = self.model(batch)
                # compute_loss is a placeholder for your task-specific loss
                loss = self.compute_loss(output, batch)
                loss.backward()
                optimizer.step()
        
        # Convert to quantized inference model
        self.model.eval()
        quantized_model = quant.convert(self.model)
        return quantized_model
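
# Usage sketch for AdvancedQuantizer. `model` and `calibration_data` are
# placeholders you supply; eager-mode PTQ also expects the model to wrap its
# inputs/outputs with QuantStub/DeQuantStub.
quantizer = AdvancedQuantizer(model, calibration_data)
int8_model = quantizer.post_training_quantization()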

# Advanced LLM Quantization with GPTQ/AWQ
class LLMQuantizer:
    def __init__(self, model_name):
        self.model_name = model_name
        
    def gptq_quantization(self, bits=4, calibration_dataset=None):
        """GPTQ: Optimal quantization for generative models"""
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
        
        # Configure quantization
        quantize_config = BaseQuantizeConfig(
            bits=bits,
            group_size=128,
            desc_act=False,
        )
        
        # Load and quantize model
        model = AutoGPTQForCausalLM.from_pretrained(
            self.model_name,
            quantize_config=quantize_config
        )
        
        # Quantize with calibration dataset
        model.quantize(calibration_dataset)
        
        return model
    
    def awq_quantization(self):
        """AWQ: Activation-aware weight quantization"""
        from awq import AutoAWQForCausalLM
        
        model = AutoAWQForCausalLM.from_pretrained(self.model_name)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        
        # Quantize with activation-aware method
        model.quantize(tokenizer, quant_config={
            "zero_point": True, 
            "q_group_size": 128,
            "w_bit": 4,
            "version": "GEMM"
        })
        
        return model

# Usage (calibration_dataset is a list of tokenized calibration samples
# prepared beforehand)
quantizer = LLMQuantizer('meta-llama/Llama-2-7b-hf')
gptq_model = quantizer.gptq_quantization(bits=4, calibration_dataset=calibration_dataset)
print(f"Model size reduced by ~{32 // 4}x with 4-bit GPTQ (relative to FP32)")

3. Knowledge Distillation Framework

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig

class KnowledgeDistiller:
    def __init__(self, teacher_model, student_config, alpha=0.7, temperature=4):
        self.teacher = teacher_model
        self.teacher.eval()
        
        # Create smaller student model
        self.student = self.create_student_model(student_config)
        
        self.alpha = alpha  # Weight for distillation loss
        self.temperature = temperature  # Softmax temperature
        
        # Trainable projection layers for feature distillation,
        # created lazily (one per mismatched layer)
        self.projections = nn.ModuleDict()
        
    def create_student_model(self, config):
        """Create smaller student model with same architecture"""
        student_config = AutoConfig.from_pretrained(config['base_model'])
        
        # Reduce model size
        student_config.hidden_size = config.get('hidden_size', 384)
        student_config.num_attention_heads = config.get('num_heads', 6)
        student_config.num_hidden_layers = config.get('num_layers', 6)
        student_config.intermediate_size = config.get('intermediate_size', 1536)
        
        return AutoModel.from_config(student_config)
    
    def distillation_loss(self, student_logits, teacher_logits, labels):
        """Compute combined distillation and task loss"""
        # Soft targets from teacher
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        student_log_probs = F.log_softmax(
            student_logits / self.temperature, dim=-1
        )
        
        # KL divergence loss
        distill_loss = F.kl_div(
            student_log_probs, 
            teacher_probs, 
            reduction='batchmean'
        ) * (self.temperature ** 2)
        
        # Task-specific loss
        task_loss = F.cross_entropy(student_logits, labels)
        
        # Combined loss
        total_loss = (
            self.alpha * distill_loss + 
            (1 - self.alpha) * task_loss
        )
        
        return total_loss, distill_loss, task_loss
    
    def feature_distillation(self, student_features, teacher_features):
        """Match intermediate layer representations"""
        feature_losses = []
        
        # zip pairs layers in order and stops at the shallower model
        for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
            # Project student features to teacher dimension if needed, reusing
            # one cached, trainable projection per layer instead of creating a
            # fresh (untrained) nn.Linear on every call
            if s_feat.size(-1) != t_feat.size(-1):
                key = str(i)
                if key not in self.projections:
                    self.projections[key] = nn.Linear(
                        s_feat.size(-1), t_feat.size(-1)
                    ).to(s_feat.device)
                s_feat = self.projections[key](s_feat)
            
            # MSE loss between features (teacher side detached)
            feat_loss = F.mse_loss(s_feat, t_feat.detach())
            feature_losses.append(feat_loss)
        
        return sum(feature_losses) / len(feature_losses)
    
    def train_student(self, train_loader, num_epochs=5):
        """Train student model with knowledge distillation"""
        # Optimize the student plus any feature-projection layers; projections
        # created lazily after this point would need to be added as an extra
        # parameter group
        optimizer = torch.optim.AdamW(
            list(self.student.parameters()) + list(self.projections.parameters()),
            lr=5e-4
        )
        
        for epoch in range(num_epochs):
            total_loss = 0
            
            for batch in train_loader:
                inputs, labels = batch
                
                # Teacher predictions (no gradients). Assumes teacher and
                # student expose .logits (i.e. models with a task head);
                # hidden states are requested for optional feature distillation
                with torch.no_grad():
                    teacher_outputs = self.teacher(inputs, output_hidden_states=True)
                    teacher_logits = teacher_outputs.logits
                
                # Student predictions
                student_outputs = self.student(inputs, output_hidden_states=True)
                student_logits = student_outputs.logits
                
                # Compute distillation loss
                loss, distill_loss, task_loss = self.distillation_loss(
                    student_logits, teacher_logits, labels
                )
                
                # Optionally add feature distillation
                if getattr(student_outputs, 'hidden_states', None) is not None:
                    feat_loss = self.feature_distillation(
                        student_outputs.hidden_states,
                        teacher_outputs.hidden_states
                    )
                    loss += 0.1 * feat_loss
                
                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")
        
        return self.student

# Progressive Knowledge Distillation
class ProgressiveDistiller:
    def __init__(self, teacher_model, compression_stages):
        self.teacher = teacher_model
        self.stages = compression_stages
        
    def multi_stage_distillation(self, train_loader):
        """Gradually compress model through multiple stages"""
        current_teacher = self.teacher
        
        for stage, config in enumerate(self.stages):
            print(f"Stage {stage + 1}: Creating student with {config}")
            
            distiller = KnowledgeDistiller(current_teacher, config)
            student = distiller.train_student(train_loader)
            
            # Use current student as teacher for next stage
            current_teacher = student
            
        return current_teacher

# Usage Example ('base_model' is required by create_student_model;
# bert-base-uncased is used here purely as an illustrative base architecture)
stages = [
    {'base_model': 'bert-base-uncased', 'hidden_size': 512, 'num_layers': 8, 'num_heads': 8},  # Stage 1
    {'base_model': 'bert-base-uncased', 'hidden_size': 384, 'num_layers': 6, 'num_heads': 6},  # Stage 2
    {'base_model': 'bert-base-uncased', 'hidden_size': 256, 'num_layers': 4, 'num_heads': 4},  # Stage 3
]

# teacher_model and train_loader are assumed to be defined elsewhere
progressive_distiller = ProgressiveDistiller(teacher_model, stages)
final_student = progressive_distiller.multi_stage_distillation(train_loader)
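
4. Low-Rank Factorization with Truncated SVD

Low-rank factorization is listed in the overview but has no dedicated example above, so here is a minimal sketch. It replaces a Linear layer with two smaller layers obtained from a truncated SVD of the weight matrix; the layer sizes and the rank of 128 are illustrative assumptions, not values from the original text.

import torch
import torch.nn as nn

def factorize_linear(layer: nn.Linear, rank: int) -> nn.Sequential:
    """Approximate a Linear layer with two low-rank Linear layers via truncated SVD."""
    W = layer.weight.data                      # shape: (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    
    # Keep only the top-`rank` singular values and vectors
    U_r = U[:, :rank] * S[:rank]               # (out_features, rank)
    V_r = Vh[:rank, :]                         # (rank, in_features)
    
    first = nn.Linear(layer.in_features, rank, bias=False)
    second = nn.Linear(rank, layer.out_features, bias=layer.bias is not None)
    first.weight.data = V_r
    second.weight.data = U_r
    if layer.bias is not None:
        second.bias.data = layer.bias.data
    return nn.Sequential(first, second)

# Example: a 1024x1024 layer factorized to rank 128 drops from ~1.05M
# to ~0.26M parameters (about 4x smaller)
layer = nn.Linear(1024, 1024)
compressed = factorize_linear(layer, rank=128)

The compression ratio is controlled by the chosen rank: weight parameters go from out_features x in_features to rank x (in_features + out_features), so lower ranks compress more at the cost of a larger approximation error.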

🏭 Production Compression Examples

MobileBERT

Google's compressed BERT

  • Technique: Knowledge Distillation
  • Size Reduction: 4x smaller (25M params)
  • Speed: 5.5x faster inference
  • Quality: 99.2% of BERT performance

DistilBERT

Hugging Face distilled model

  • Technique: Triple Loss Distillation
  • Size Reduction: 2x smaller (66M params)
  • Speed: 1.6x faster
  • Quality: 97% of BERT performance

Llama 2 4-bit

GPTQ/AWQ quantized

  • Technique: GPTQ/AWQ 4-bit
  • Size Reduction: 8x smaller memory
  • Speed: 2-3x faster inference
  • Quality: 98.5% performance retained

EfficientNet

Neural architecture search

  • Technique: Architecture + Pruning
  • Size Reduction: 10x smaller than ResNet
  • Speed: 6.1x faster
  • Quality: Higher accuracy than the ResNet baseline

✅ Compression Best Practices

✅ Do's

  • Profile before compressing - Identify actual bottlenecks and most important layers
  • Use calibration data - Representative samples for quantization calibration
  • Combine techniques carefully - Pruning + quantization can be complementary
  • Validate on diverse datasets - Ensure compression doesn't hurt edge cases
  • Consider hardware constraints - Match compression to target deployment platform

❌ Don'ts

  • Don't compress blindly - Understand model structure and importance of different layers
  • Avoid extreme compression ratios - >95% sparsity often hurts quality significantly
  • Don't ignore inference framework - Some formats aren't supported on all hardware
  • Avoid one-size-fits-all - Different layers may need different compression strategies
  • Don't skip quality validation - Always benchmark compressed model performance (see the sketch below)
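
As a concrete version of that last point, the sketch below compares an original and a compressed model on the same validation loader. original_model, compressed_model, and val_loader are placeholders you would supply, and it assumes a classification-style model that returns logits:

import time
import torch

@torch.no_grad()
def benchmark(model, data_loader, device='cpu'):
    """Measure top-1 accuracy and average per-batch latency."""
    model.eval().to(device)
    correct, total, elapsed = 0, 0, 0.0
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        start = time.perf_counter()
        outputs = model(inputs)
        elapsed += time.perf_counter() - start
        correct += (outputs.argmax(dim=-1) == labels).sum().item()
        total += labels.numel()
    return correct / total, elapsed / len(data_loader)

# Compare before/after compression on identical data
acc_orig, lat_orig = benchmark(original_model, val_loader)
acc_comp, lat_comp = benchmark(compressed_model, val_loader)
print(f"Accuracy: {acc_orig:.3f} -> {acc_comp:.3f}")
print(f"Latency:  {lat_orig * 1e3:.1f}ms -> {lat_comp * 1e3:.1f}ms per batch")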