Model Compression Techniques
Master advanced model compression for efficient deployment: pruning, quantization, and knowledge distillation
45 min read • Advanced
🔬 Worked Example: Pruning a 175B-Parameter Model
Applying neural network pruning at 90% sparsity to a 175B-parameter model (removing redundant connections and neurons based on importance):
- Technique: Neural Network Pruning
- Compressed Size: 17.5B parameters
- Compression Ratio: 10x smaller
- Quality Retention: 70%
- Speedup Factor: 10x faster
- Memory Reduction: 90% (about 70 GB remaining, assuming 32-bit weights)
- Hardware Support: Limited (specialized sparse kernels)
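The numbers above follow from simple arithmetic. Here is a minimal sketch of that calculation; the helper name and the fp32 default are illustrative, not part of the original calculator:

def compression_estimate(params_b: float, sparsity: float, bytes_per_param: int = 4):
    """Estimate compressed size and memory after unstructured pruning.

    params_b: parameter count in billions; sparsity: fraction of weights removed;
    bytes_per_param: 4 for fp32, 2 for fp16.
    """
    remaining_b = params_b * (1 - sparsity)                # surviving parameters (billions)
    ratio = params_b / remaining_b                         # nominal compression ratio
    memory_gb = remaining_b * 1e9 * bytes_per_param / 1e9  # memory for surviving weights, in GB
    return remaining_b, ratio, memory_gb

# 175B parameters at 90% sparsity -> 17.5B params, 10x ratio, ~70 GB at fp32
print(compression_estimate(175, 0.90))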
🔧 Compression Techniques Overview
Neural Network Pruning
Remove redundant connections and neurons based on importance.
- Typical Compression: 10x smaller
- Quality Retention: 70%
- Speedup: 10.0x
- Hardware Support: Limited (specialized)
Model Quantization
Reduce the precision of weights and activations.
- Typical Compression: 4x smaller
- Quality Retention: 95%
- Speedup: 4.0x
- Hardware Support: Excellent (GPU/CPU)
Knowledge Distillation
Train a smaller student model to mimic a larger teacher.
- Typical Compression: 10x smaller
- Quality Retention: 80%
- Speedup: 10.0x
- Hardware Support: Universal
Low-Rank Factorization
Decompose weight matrices into smaller factors.
- Typical Compression: 3x smaller
- Quality Retention: 92%
- Speedup: 2.5x
- Hardware Support: Good (optimized BLAS)
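The implementation examples below cover the first three techniques. Low-rank factorization is not shown there, so here is a minimal sketch, assuming a plain nn.Linear layer and an illustrative rank, that replaces one layer with two smaller ones via truncated SVD:

import torch
import torch.nn as nn

def factorize_linear(layer: nn.Linear, rank: int) -> nn.Sequential:
    """Replace a Linear layer with two low-rank Linear layers via truncated SVD."""
    W = layer.weight.data                   # shape: (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    U_r = U[:, :rank] * S[:rank]            # (out_features, rank), singular values folded in
    V_r = Vh[:rank, :]                      # (rank, in_features)

    first = nn.Linear(layer.in_features, rank, bias=False)
    second = nn.Linear(rank, layer.out_features, bias=layer.bias is not None)
    first.weight.data = V_r
    second.weight.data = U_r
    if layer.bias is not None:
        second.bias.data = layer.bias.data
    return nn.Sequential(first, second)

# Example: a 1024x1024 layer (~1.05M params) shrinks to ~263K params at rank 128 (~4x smaller)
layer = nn.Linear(1024, 1024)
compressed = factorize_linear(layer, rank=128)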
💻 Implementation Examples
1. Neural Network Pruning with PyTorch
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class StructuredPruner:
    def __init__(self, model, sparsity=0.9):
        self.model = model
        self.sparsity = sparsity

    def magnitude_based_pruning(self):
        """Remove connections with the smallest weights."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                prune.l1_unstructured(module, name='weight', amount=self.sparsity)

    def structured_pruning(self):
        """Remove entire neurons/filters."""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                # Remove neurons with the lowest L2 norm
                prune.ln_structured(module, name='weight', amount=0.5, n=2, dim=0)
            elif isinstance(module, nn.Conv2d):
                # Remove filters with the lowest L2 norm
                prune.ln_structured(module, name='weight', amount=0.3, n=2, dim=0)

    def gradual_pruning(self, initial_sparsity=0.1, final_sparsity=0.9, steps=10):
        """Gradually increase sparsity during training."""
        sparsity_schedule = torch.linspace(initial_sparsity, final_sparsity, steps)
        for step, sparsity in enumerate(sparsity_schedule):
            self.apply_global_pruning(sparsity.item())
            # Train between pruning steps
            # (train_epoch is a placeholder for your own training loop)
            self.train_epoch()

    def apply_global_pruning(self, sparsity):
        """Global magnitude-based pruning across all layers."""
        parameters_to_prune = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                parameters_to_prune.append((module, 'weight'))
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=sparsity,
        )

def global_sparsity(model):
    """Fraction of zeroed weights across all prunable layers."""
    total, zeros = 0, 0
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            total += module.weight.nelement()
            zeros += int(torch.sum(module.weight == 0))
    return zeros / total

# Usage example
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
pruner = StructuredPruner(model, sparsity=0.9)

# Apply magnitude-based pruning and check the resulting sparsity
pruner.magnitude_based_pruning()
print(f"Model sparsity: {global_sparsity(model):.2%}")

# Fold the pruning masks into the weights to make pruning permanent
for module in model.modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        prune.remove(module, 'weight')

2. Advanced Quantization Techniques
import torch
import torch.quantization as quant
from transformers import AutoTokenizer

class AdvancedQuantizer:
    def __init__(self, model, calibration_data):
        self.model = model
        self.calibration_data = calibration_data

    def post_training_quantization(self):
        """PTQ: quantize after training, without retraining."""
        # Prepare the model for quantization
        self.model.eval()
        self.model.qconfig = quant.get_default_qconfig('fbgemm')
        quant.prepare(self.model, inplace=True)

        # Calibrate with representative data
        with torch.no_grad():
            for batch in self.calibration_data:
                self.model(batch)

        # Convert to a quantized model
        quantized_model = quant.convert(self.model)
        return quantized_model

    def quantization_aware_training(self):
        """QAT: fine-tune the model with simulated quantization."""
        self.model.train()
        self.model.qconfig = quant.get_default_qat_qconfig('fbgemm')

        # Prepare for QAT
        quant.prepare_qat(self.model, inplace=True)

        # Fine-tune for a few epochs with quantization simulation
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
        for epoch in range(3):  # a few epochs of fine-tuning
            for batch in self.calibration_data:
                optimizer.zero_grad()
                output = self.model(batch)
                # compute_loss is a placeholder for your task-specific loss
                loss = self.compute_loss(output, batch)
                loss.backward()
                optimizer.step()

        # Convert to a quantized inference model
        self.model.eval()
        quantized_model = quant.convert(self.model)
        return quantized_model

# Advanced LLM quantization with GPTQ/AWQ
class LLMQuantizer:
    def __init__(self, model_name):
        self.model_name = model_name

    def gptq_quantization(self, calibration_dataset, bits=4):
        """GPTQ: post-training weight quantization for generative models."""
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

        # Configure quantization
        quantize_config = BaseQuantizeConfig(
            bits=bits,
            group_size=128,
            desc_act=False,
        )

        # Load the model with the quantization config
        model = AutoGPTQForCausalLM.from_pretrained(
            self.model_name,
            quantize_config=quantize_config,
        )

        # Quantize using the calibration dataset
        model.quantize(calibration_dataset)
        return model

    def awq_quantization(self):
        """AWQ: activation-aware weight quantization."""
        from awq import AutoAWQForCausalLM

        model = AutoAWQForCausalLM.from_pretrained(self.model_name)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Quantize with the activation-aware method
        model.quantize(tokenizer, quant_config={
            "zero_point": True,
            "q_group_size": 128,
            "w_bit": 4,
            "version": "GEMM",
        })
        return model

# Usage (calibration_dataset: a list of tokenized examples prepared beforehand)
quantizer = LLMQuantizer('meta-llama/Llama-2-7b-hf')
gptq_model = quantizer.gptq_quantization(calibration_dataset, bits=4)
print(f"Model size reduced by ~{32 // 4}x with 4-bit GPTQ (relative to fp32 weights)")

3. Knowledge Distillation Framework
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForSequenceClassification

class KnowledgeDistiller:
    def __init__(self, teacher_model, student_config, alpha=0.7, temperature=4):
        self.teacher = teacher_model
        self.teacher.eval()
        # Create a smaller student model
        self.student = self.create_student_model(student_config)
        self.alpha = alpha              # weight for the distillation loss
        self.temperature = temperature  # softmax temperature

    def create_student_model(self, config):
        """Create a smaller student model in the same architecture family."""
        student_config = AutoConfig.from_pretrained(config['base_model'])
        # Reduce model size
        student_config.hidden_size = config.get('hidden_size', 384)
        student_config.num_attention_heads = config.get('num_heads', 6)
        student_config.num_hidden_layers = config.get('num_layers', 6)
        student_config.intermediate_size = config.get('intermediate_size', 1536)
        # Expose hidden states so feature distillation can use them
        student_config.output_hidden_states = True
        return AutoModelForSequenceClassification.from_config(student_config)

    def distillation_loss(self, student_logits, teacher_logits, labels):
        """Compute the combined distillation and task loss."""
        # Soft targets from the teacher
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)

        # KL divergence loss, scaled by T^2 to keep gradient magnitudes comparable
        distill_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Task-specific loss on the hard labels
        task_loss = F.cross_entropy(student_logits, labels)

        # Combined loss
        total_loss = self.alpha * distill_loss + (1 - self.alpha) * task_loss
        return total_loss, distill_loss, task_loss

    def feature_distillation(self, student_features, teacher_features):
        """Match intermediate layer representations."""
        feature_losses = []
        for s_feat, t_feat in zip(student_features, teacher_features):
            # Project student features to the teacher dimension if needed
            # (in practice this projection should be a registered, trained module;
            #  creating it per batch is only a sketch)
            if s_feat.size(-1) != t_feat.size(-1):
                projection = nn.Linear(s_feat.size(-1), t_feat.size(-1))
                s_feat = projection(s_feat)
            # MSE loss between features
            feat_loss = F.mse_loss(s_feat, t_feat.detach())
            feature_losses.append(feat_loss)
        return sum(feature_losses) / len(feature_losses)

    def train_student(self, train_loader, num_epochs=5):
        """Train the student model with knowledge distillation."""
        optimizer = torch.optim.AdamW(self.student.parameters(), lr=5e-4)

        for epoch in range(num_epochs):
            total_loss = 0
            for batch in train_loader:
                inputs, labels = batch

                # Teacher predictions (no gradients)
                with torch.no_grad():
                    teacher_outputs = self.teacher(inputs)
                    teacher_logits = teacher_outputs.logits

                # Student predictions
                student_outputs = self.student(inputs)
                student_logits = student_outputs.logits

                # Compute distillation loss
                loss, distill_loss, task_loss = self.distillation_loss(
                    student_logits, teacher_logits, labels
                )

                # Optionally add feature distillation
                if student_outputs.hidden_states is not None and teacher_outputs.hidden_states is not None:
                    feat_loss = self.feature_distillation(
                        student_outputs.hidden_states,
                        teacher_outputs.hidden_states
                    )
                    loss += 0.1 * feat_loss

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {epoch+1}: Loss = {total_loss/len(train_loader):.4f}")

        return self.student

# Progressive knowledge distillation
class ProgressiveDistiller:
    def __init__(self, teacher_model, compression_stages):
        self.teacher = teacher_model
        self.stages = compression_stages

    def multi_stage_distillation(self, train_loader):
        """Gradually compress the model through multiple stages."""
        current_teacher = self.teacher
        for stage, config in enumerate(self.stages):
            print(f"Stage {stage + 1}: Creating student with {config}")
            distiller = KnowledgeDistiller(current_teacher, config)
            student = distiller.train_student(train_loader)
            # Use the current student as the teacher for the next stage
            current_teacher = student
        return current_teacher

# Usage example (teacher_model and train_loader are assumed to be defined elsewhere;
# each stage config also needs a 'base_model' entry, e.g. 'bert-base-uncased')
stages = [
    {'base_model': 'bert-base-uncased', 'hidden_size': 512, 'num_layers': 8, 'num_heads': 8},  # Stage 1
    {'base_model': 'bert-base-uncased', 'hidden_size': 384, 'num_layers': 6, 'num_heads': 6},  # Stage 2
    {'base_model': 'bert-base-uncased', 'hidden_size': 256, 'num_layers': 4, 'num_heads': 4},  # Stage 3
]
progressive_distiller = ProgressiveDistiller(teacher_model, stages)
final_student = progressive_distiller.multi_stage_distillation(train_loader)

🏭 Production Compression Examples
MobileBERT
Google's compressed BERT
- Technique: Knowledge Distillation
- Size Reduction: 4x smaller (25M params)
- Speed: 5.5x faster inference
- Quality: 99.2% of BERT performance
DistilBERT
Hugging Face's distilled BERT
- Technique: Triple-Loss Distillation
- Size Reduction: ~2x smaller (66M params)
- Speed: 1.6x faster
- Quality: 97% of BERT performance
Llama 2 4-bit
GPTQ/AWQ-quantized Llama 2
- Technique: GPTQ/AWQ 4-bit quantization
- Size Reduction: 8x smaller memory footprint
- Speed: 2-3x faster inference
- Quality: 98.5% performance retained
EfficientNet
Found via neural architecture search
- Technique: Architecture search + pruning
- Size Reduction: 10x smaller than ResNet
- Speed: 6.1x faster
- Quality: Better accuracy
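To sanity-check size claims like these, here is a quick sketch that counts parameters with transformers; it assumes the bert-base-uncased and distilbert-base-uncased checkpoints can be downloaded:

from transformers import AutoModel

def param_count_m(model_name: str) -> float:
    """Return the parameter count of a Hugging Face checkpoint, in millions."""
    model = AutoModel.from_pretrained(model_name)
    return sum(p.numel() for p in model.parameters()) / 1e6

bert = param_count_m("bert-base-uncased")          # ~110M parameters
distil = param_count_m("distilbert-base-uncased")  # ~66M parameters
print(f"BERT: {bert:.0f}M, DistilBERT: {distil:.0f}M, ratio: {bert / distil:.2f}x")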
✅ Compression Best Practices
✅ Do's
- Profile before compressing - Identify actual bottlenecks and most important layers
- Use calibration data - Representative samples for quantization calibration
- Combine techniques carefully - Pruning + quantization can be complementary (see the sketch after this list)
- Validate on diverse datasets - Ensure compression doesn't hurt edge cases
- Consider hardware constraints - Match compression to target deployment platform
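As referenced in the "combine techniques" item above, here is a minimal sketch of pruning followed by quantization; the tiny model is hypothetical, and dynamic int8 quantization is used only because it needs no calibration data:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical small model used purely for illustration
model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10))

# Step 1: prune 50% of the weights in each Linear layer, then make it permanent
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.5)
        prune.remove(module, 'weight')

# Step 2: apply dynamic int8 quantization to the (now sparse) Linear layers
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
print(quantized)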
❌ Don'ts
- Don't compress blindly - Understand model structure and importance of different layers
- Avoid extreme compression ratios - >95% sparsity often hurts quality significantly
- Don't ignore inference framework - Some formats aren't supported on all hardware
- Avoid one-size-fits-all - Different layers may need different compression strategies
- Don't skip quality validation - Always benchmark compressed model performance (a minimal benchmark sketch follows this list)
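As noted in the last item above, here is a minimal benchmark sketch for validating a compressed model against the original; the eval_loader, the classification-style accuracy metric, and the model names are assumptions to adapt to your task:

import time
import torch

@torch.no_grad()
def benchmark(model, eval_loader, device="cpu"):
    """Measure accuracy and average per-batch latency for a classification model."""
    model.to(device).eval()
    correct, total, elapsed = 0, 0, 0.0
    for inputs, labels in eval_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        start = time.perf_counter()
        logits = model(inputs)
        elapsed += time.perf_counter() - start
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    return correct / total, elapsed / len(eval_loader)

# Compare the original and compressed models (both assumed to be defined elsewhere):
# acc_o, lat_o = benchmark(original_model, eval_loader)
# acc_c, lat_c = benchmark(compressed_model, eval_loader)
# print(f"Accuracy {acc_o:.3f} -> {acc_c:.3f}, latency {lat_o*1e3:.1f} ms -> {lat_c*1e3:.1f} ms")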