Model Serving Patterns
Master production model serving: FastAPI deployment, GPU management, batch optimization, and distributed serving patterns
75 min read · Advanced
Model Serving Architecture Patterns
Production model serving requires careful selection of serving patterns based on latency requirements, throughput needs, and operational complexity. Each pattern offers different trade-offs between performance, scalability, and development complexity.
Key Considerations
- Latency Requirements: Real-time vs batch processing
- Throughput Needs: Requests per second targets
- Resource Utilization: CPU, GPU, memory efficiency
- Scalability: Auto-scaling and load handling
- Operational Complexity: Deployment and maintenance
Performance Factors
- Model Size: Parameter count and memory requirements
- Batch Processing: Dynamic batching capabilities
- GPU Management: Memory pooling and scheduling
- Caching: Model weights and inference results
- Optimization: Quantization, pruning, compilation (see the sketch after this list)
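To make the last point concrete, here is a minimal sketch, assuming a Hugging Face causal LM (gpt2 is used as a stand-in) and PyTorch 2.x: half precision plus torch.compile on GPU, or dynamic INT8 quantization of the linear layers on CPU. It illustrates the levers, not a complete optimization pipeline.
import torch
from transformers import AutoModelForCausalLM

# "gpt2" is a placeholder model; swap in your own checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

if torch.cuda.is_available():
    # GPU path: FP16 weights plus graph compilation (PyTorch 2.x)
    model = model.half().to("cuda")
    model = torch.compile(model)
else:
    # CPU path: dynamic INT8 quantization of nn.Linear layers
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )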
Implementation Deep Dive
FastAPI Model Serving
FastAPI provides a lightweight, high-performance framework for serving ML models with automatic API documentation, async support, and excellent Python ecosystem integration.
Advantages
- Quick development and deployment
- Automatic API documentation
- Excellent Python ecosystem support
- Built-in async/await support
- Type hints and validation
Limitations
- Manual batching implementation
- Limited built-in optimizations
- Requires custom GPU management
- Basic monitoring capabilities
- Scaling complexity
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
import asyncio
import time
from typing import List, Optional, Dict, Any
import logging
from contextlib import asynccontextmanager
# Request/Response models
class InferenceRequest(BaseModel):
text: str = Field(..., description="Input text for inference")
max_length: Optional[int] = Field(100, description="Maximum output length")
temperature: Optional[float] = Field(0.7, description="Sampling temperature")
top_p: Optional[float] = Field(0.9, description="Top-p sampling")
class InferenceResponse(BaseModel):
generated_text: str
inference_time: float
model_version: str
request_id: str
class BatchInferenceRequest(BaseModel):
requests: List[InferenceRequest]
batch_id: Optional[str] = None
class HealthResponse(BaseModel):
status: str
model_loaded: bool
gpu_available: bool
memory_usage: Dict[str, float]
# Global model instance
model_manager = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global model_manager
model_manager = ModelManager()
await model_manager.load_model()
yield
# Shutdown
await model_manager.cleanup()
# Initialize FastAPI app
app = FastAPI(
title="Production ML Model API",
description="High-performance model serving with FastAPI",
version="1.0.0",
lifespan=lifespan
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class ModelManager:
def __init__(self):
self.model = None
self.tokenizer = None
self.device = None
self.batch_processor = None
self.model_version = "1.0.0"
async def load_model(self):
"""Load model with optimizations"""
try:
# Check GPU availability
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2" # Replace with your model
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # GPT-2-style tokenizers ship without a pad token; reuse EOS and left-pad
            # so that batched generation in BatchProcessor works correctly
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None
)
# Optimize model
if torch.cuda.is_available():
self.model = torch.compile(self.model)
self.model.eval()
# Initialize batch processor
self.batch_processor = BatchProcessor(self.model, self.tokenizer, self.device)
logging.info(f"Model loaded successfully on {self.device}")
except Exception as e:
logging.error(f"Failed to load model: {e}")
raise
async def generate_text(self,
text: str,
max_length: int = 100,
temperature: float = 0.7,
top_p: float = 0.9) -> Dict[str, Any]:
"""Generate text with the model"""
start_time = time.time()
try:
# Tokenize input
inputs = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
# Generate
with torch.inference_mode():
outputs = self.model.generate(
inputs,
max_new_tokens=max_length,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
# Decode output
generated_text = self.tokenizer.decode(
outputs[0][inputs.shape[1]:],
skip_special_tokens=True
)
inference_time = time.time() - start_time
return {
"generated_text": generated_text,
"inference_time": inference_time,
"model_version": self.model_version
}
except Exception as e:
logging.error(f"Inference failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
async def batch_generate(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
"""Process batch of requests efficiently"""
return await self.batch_processor.process_batch(requests)
def get_health_status(self) -> HealthResponse:
"""Get system health status"""
gpu_available = torch.cuda.is_available()
model_loaded = self.model is not None
memory_usage = {}
if gpu_available:
memory_usage = {
"gpu_allocated_gb": torch.cuda.memory_allocated() / 1e9,
"gpu_reserved_gb": torch.cuda.memory_reserved() / 1e9,
}
status = "healthy" if (model_loaded and (gpu_available or not torch.cuda.is_available())) else "unhealthy"
return HealthResponse(
status=status,
model_loaded=model_loaded,
gpu_available=gpu_available,
memory_usage=memory_usage
)
async def cleanup(self):
"""Clean up resources"""
if self.model:
del self.model
if torch.cuda.is_available():
torch.cuda.empty_cache()
class BatchProcessor:
"""Efficient batch processing for multiple requests"""
def __init__(self, model, tokenizer, device):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.max_batch_size = 8 # Adjust based on GPU memory
async def process_batch(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
"""Process batch of requests"""
batch_size = min(len(requests), self.max_batch_size)
results = []
# Process in chunks
for i in range(0, len(requests), batch_size):
batch = requests[i:i + batch_size]
batch_results = await self._process_batch_chunk(batch)
results.extend(batch_results)
return results
async def _process_batch_chunk(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
"""Process a single batch chunk"""
start_time = time.time()
try:
# Prepare batch inputs
texts = [req.text for req in requests]
max_lengths = [req.max_length for req in requests]
# Tokenize batch
batch_inputs = self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512
).to(self.device)
# Generate for batch
with torch.inference_mode():
batch_outputs = self.model.generate(
**batch_inputs,
max_new_tokens=max(max_lengths),
temperature=requests[0].temperature, # Use first request's params
top_p=requests[0].top_p,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
# Process outputs
results = []
for i, (request, output) in enumerate(zip(requests, batch_outputs)):
# Extract new tokens
input_length = batch_inputs.input_ids[i].shape[0]
generated_tokens = output[input_length:]
generated_text = self.tokenizer.decode(
generated_tokens,
skip_special_tokens=True
)
results.append({
"generated_text": generated_text,
"inference_time": (time.time() - start_time) / len(requests),
"model_version": "1.0.0"
})
return results
except Exception as e:
logging.error(f"Batch processing failed: {e}")
# Return error for all requests in batch
return [{"error": str(e)} for _ in requests]
# API Endpoints
@app.post("/v1/generate", response_model=InferenceResponse)
async def generate_text(request: InferenceRequest, background_tasks: BackgroundTasks):
"""Generate text from input"""
request_id = f"req_{int(time.time() * 1000)}"
try:
result = await model_manager.generate_text(
text=request.text,
max_length=request.max_length,
temperature=request.temperature,
top_p=request.top_p
)
# Log request in background
background_tasks.add_task(log_request, request_id, request.text, result)
return InferenceResponse(
generated_text=result["generated_text"],
inference_time=result["inference_time"],
model_version=result["model_version"],
request_id=request_id
)
except Exception as e:
logging.error(f"Generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/batch-generate")
async def batch_generate_text(request: BatchInferenceRequest):
"""Generate text for batch of requests"""
try:
results = await model_manager.batch_generate(request.requests)
return {
"results": results,
"batch_id": request.batch_id,
"processed_count": len(results)
}
except Exception as e:
logging.error(f"Batch generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
return model_manager.get_health_status()
@app.get("/metrics")
async def get_metrics():
"""Prometheus metrics endpoint"""
# Return metrics in Prometheus format
return {"message": "Implement Prometheus metrics here"}
def log_request(request_id: str, input_text: str, result: Dict):
"""Log request for monitoring"""
logging.info(f"Request {request_id}: {len(input_text)} chars -> {result['inference_time']:.3f}s")
# Run with: uvicorn app:app --host 0.0.0.0 --port 8000 --workers 1
Production Optimization Strategies
Batching Strategies
- Dynamic Batching: Variable batch sizes (minimal sketch after this list)
- Continuous Batching: Streaming requests
- Priority Batching: Request prioritization
- Adaptive Batching: Load-based adjustment
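As a concrete sketch of dynamic batching, the DynamicBatcher below is an illustrative asyncio micro-batcher, not part of the FastAPI example above: concurrent requests are queued together with futures, and a background task flushes a batch once either max_batch_size is reached or the max_wait_ms budget expires. predict_fn stands in for an async callable that maps a list of inputs to a list of outputs.
import asyncio
from typing import Any, Callable

class DynamicBatcher:
    """Micro-batcher: groups concurrent requests by size and wait-time limits."""

    def __init__(self, predict_fn: Callable, max_batch_size: int = 8, max_wait_ms: float = 10.0):
        self.predict_fn = predict_fn            # async callable: list of inputs -> list of outputs
        self.max_batch_size = max_batch_size
        self.max_wait = max_wait_ms / 1000.0
        self.queue: asyncio.Queue = asyncio.Queue()

    async def submit(self, item: Any) -> Any:
        """Enqueue a single request and wait for its result."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def run(self) -> None:
        """Background task: drain the queue into batches and dispatch them."""
        while True:
            item, future = await self.queue.get()
            batch, futures = [item], [future]
            deadline = asyncio.get_running_loop().time() + self.max_wait
            # Keep collecting until the batch is full or the wait budget is spent
            while len(batch) < self.max_batch_size:
                timeout = deadline - asyncio.get_running_loop().time()
                if timeout <= 0:
                    break
                try:
                    item, future = await asyncio.wait_for(self.queue.get(), timeout)
                except asyncio.TimeoutError:
                    break
                batch.append(item)
                futures.append(future)
            try:
                results = await self.predict_fn(batch)   # one forward pass for the whole batch
            except Exception as exc:
                for fut in futures:
                    fut.set_exception(exc)
                continue
            for fut, res in zip(futures, results):
                fut.set_result(res)
In the FastAPI example, the background task could be started in the lifespan handler with asyncio.create_task(batcher.run()), and each endpoint would simply await batcher.submit(payload).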
GPU Optimization
- Memory Pooling: Efficient GPU memory reuse
- Model Sharding: Distributing large models across devices
- Pipeline Parallelism: Staged execution across devices
- Mixed Precision: FP16/INT8 optimization (see the sketch below)
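For the mixed-precision item above, a minimal sketch (the helper name generate_fp16 is ours, not part of the serving code) wraps generation in torch.autocast so most ops run in FP16 on CUDA while numerically sensitive ops stay in FP32.
import torch

@torch.inference_mode()
def generate_fp16(model, inputs: dict, **gen_kwargs):
    """Generate under automatic mixed precision when CUDA is available."""
    if torch.cuda.is_available():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return model.generate(**inputs, **gen_kwargs)
    # CPU fallback: plain FP32 generation
    return model.generate(**inputs, **gen_kwargs)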
System Architecture
- Load Balancing: Multi-instance serving
- Auto Scaling: Dynamic resource allocation
- Health Monitoring: System observability
- Circuit Breakers: Fault tolerance (sketch below)
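The circuit-breaker item is sketched below as a small, framework-agnostic class (the name CircuitBreaker and its default thresholds are illustrative): after failure_threshold consecutive failures it rejects calls until reset_timeout_s has elapsed, then lets a probe request through.
import time
from typing import Optional

class CircuitBreaker:
    """Reject calls after repeated failures, then probe again after a cooldown."""

    def __init__(self, failure_threshold: int = 5, reset_timeout_s: float = 30.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout_s = reset_timeout_s
        self.failures = 0
        self.opened_at: Optional[float] = None

    def allow(self) -> bool:
        """True if a request may proceed (circuit closed or cooldown elapsed)."""
        if self.opened_at is None:
            return True
        if time.monotonic() - self.opened_at >= self.reset_timeout_s:
            # Half-open: allow one probe request through
            self.opened_at = None
            self.failures = 0
            return True
        return False

    def record_success(self) -> None:
        self.failures = 0

    def record_failure(self) -> None:
        self.failures += 1
        if self.failures >= self.failure_threshold:
            self.opened_at = time.monotonic()
An endpoint could check allow() before invoking the model and return HTTP 503 while the circuit is open, calling record_failure() in its exception handler and record_success() on completion.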
Monitoring & Observability
Key Metrics
- Latency Metrics: P50, P95, P99 response times (see the percentile sketch below)
- Throughput Metrics: Requests per second, batch efficiency
- Resource Metrics: GPU utilization, memory usage
- Quality Metrics: Model accuracy, prediction confidence
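In production, latency percentiles usually come from the Prometheus histogram shown below (computed at query time), but a minimal in-process version can be built with the standard library; the helper name latency_percentiles is ours and assumes a list of recorded latencies.
import statistics
from typing import Dict, List

def latency_percentiles(samples_s: List[float]) -> Dict[str, float]:
    """Compute P50/P95/P99 from recorded request latencies in seconds."""
    # quantiles(n=100) returns the 1st..99th percentile cut points
    cuts = statistics.quantiles(samples_s, n=100)
    return {"p50": cuts[49], "p95": cuts[94], "p99": cuts[98]}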
Observability Stack
Monitoring Integration
from prometheus_client import Counter, Histogram, Gauge, generate_latest
import time
import logging
from typing import Optional
# Define metrics
REQUEST_COUNT = Counter(
'model_serving_requests_total',
'Total number of requests',
['model_name', 'status', 'endpoint']
)
REQUEST_DURATION = Histogram(
'model_serving_request_duration_seconds',
'Request duration in seconds',
['model_name', 'endpoint'],
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0]
)
GPU_UTILIZATION = Gauge(
'model_serving_gpu_utilization_percent',
'GPU utilization percentage',
['gpu_id']
)
ACTIVE_REQUESTS = Gauge(
'model_serving_active_requests',
'Number of active requests',
['model_name']
)
class ServingMetrics:
def __init__(self, model_name: str):
self.model_name = model_name
def record_request(self,
endpoint: str,
duration: float,
status: str = 'success'):
"""Record request metrics"""
REQUEST_COUNT.labels(
model_name=self.model_name,
status=status,
endpoint=endpoint
).inc()
REQUEST_DURATION.labels(
model_name=self.model_name,
endpoint=endpoint
).observe(duration)
def track_active_requests(self, delta: int):
"""Track active request count"""
ACTIVE_REQUESTS.labels(
model_name=self.model_name
).inc(delta)
def record_gpu_utilization(self, gpu_id: int, utilization: float):
"""Record GPU utilization"""
GPU_UTILIZATION.labels(gpu_id=str(gpu_id)).set(utilization)
class RequestTracker:
"""Context manager for automatic request tracking"""
def __init__(self, metrics: ServingMetrics, endpoint: str):
self.metrics = metrics
self.endpoint = endpoint
self.start_time = None
def __enter__(self):
self.start_time = time.time()
self.metrics.track_active_requests(1)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
duration = time.time() - self.start_time
status = 'error' if exc_type else 'success'
self.metrics.record_request(
endpoint=self.endpoint,
duration=duration,
status=status
)
self.metrics.track_active_requests(-1)
if exc_type:
logging.error(f"Request failed: {exc_val}")
# Usage in serving endpoint
async def inference_endpoint(request_data):
metrics = ServingMetrics("my-model")
with RequestTracker(metrics, "inference") as tracker:
# Your inference logic here
result = await model.predict(request_data)
        return result
Quiz ID "model-serving-patterns" not found