
Model Serving Patterns

Master production model serving: FastAPI deployment, GPU management, batch optimization, and distributed serving patterns

75 min read · Advanced

Model Serving Architecture Patterns

Production model serving requires careful selection of serving patterns based on latency requirements, throughput needs, and operational complexity. Each pattern offers different trade-offs between performance, scalability, and development complexity.

Key Considerations

  • Latency Requirements: Real-time vs batch processing
  • Throughput Needs: Requests per second targets
  • Resource Utilization: CPU, GPU, memory efficiency
  • Scalability: Auto-scaling and load handling
  • Operational Complexity: Deployment and maintenance

Performance Factors

  • Model Size: Parameter count and memory requirements
  • Batch Processing: Dynamic batching capabilities
  • GPU Management: Memory pooling and scheduling
  • Caching: Model weights and inference results
  • Optimization: Quantization, pruning, compilation
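
The last bullet is often where the biggest practical wins come from. As a minimal sketch (assuming PyTorch 2.x; the model here is a hypothetical stand-in for your own network), dynamic quantization and compilation each take only a couple of lines:

import torch
import torch.nn as nn

# Hypothetical stand-in model; in practice this is your trained network
model = nn.Sequential(
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
).eval()

# Post-training dynamic quantization: Linear weights stored as INT8,
# activations quantized on the fly at inference time (CPU-oriented)
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Compilation (PyTorch 2.x): traces the model and fuses ops to cut overhead
compiled_model = torch.compile(model)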

Implementation Deep Dive

FastAPI Model Serving

FastAPI provides a lightweight, high-performance framework for serving ML models with automatic API documentation, async support, and excellent Python ecosystem integration.

Advantages

  • Quick development and deployment
  • Automatic API documentation
  • Excellent Python ecosystem support
  • Built-in async/await support
  • Type hints and validation

Limitations

  • Manual batching implementation
  • Limited built-in optimizations
  • Requires custom GPU management
  • Basic monitoring capabilities
  • Scaling complexity
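
The implementation below sketches a complete FastAPI serving application: request/response schemas, a ModelManager that loads the model at startup via the lifespan hook, a simple BatchProcessor for grouped requests, and the HTTP endpoints that tie them together.
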
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
import asyncio
import time
from typing import List, Optional, Dict, Any
import logging
from contextlib import asynccontextmanager

# Request/Response models
class InferenceRequest(BaseModel):
    text: str = Field(..., description="Input text for inference")
    max_length: Optional[int] = Field(100, description="Maximum output length")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, description="Top-p sampling")
    
class InferenceResponse(BaseModel):
    generated_text: str
    inference_time: float
    model_version: str
    request_id: str

class BatchInferenceRequest(BaseModel):
    requests: List[InferenceRequest]
    batch_id: Optional[str] = None

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    memory_usage: Dict[str, float]

# Global model instance
model_manager = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    global model_manager
    model_manager = ModelManager()
    await model_manager.load_model()
    yield
    # Shutdown
    await model_manager.cleanup()

# Initialize FastAPI app
app = FastAPI(
    title="Production ML Model API",
    description="High-performance model serving with FastAPI",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.batch_processor = None
        self.model_version = "1.0.0"
        
    async def load_model(self):
        """Load model with optimizations"""
        try:
            # Check GPU availability
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            
            # Load model and tokenizer
            from transformers import AutoModelForCausalLM, AutoTokenizer
            
            model_name = "gpt2"  # Replace with your model
            
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # GPT-2 style tokenizers ship without a pad token; reuse EOS and
            # left-pad so batched generation can slice off prompts cleanly
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            
            # Optimize model
            if torch.cuda.is_available():
                self.model = torch.compile(self.model)
            
            self.model.eval()
            
            # Initialize batch processor
            self.batch_processor = BatchProcessor(self.model, self.tokenizer, self.device)
            
            logging.info(f"Model loaded successfully on {self.device}")
            
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
    
    async def generate_text(self, 
                          text: str, 
                          max_length: int = 100,
                          temperature: float = 0.7,
                          top_p: float = 0.9) -> Dict[str, Any]:
        """Generate text with the model"""
        
        start_time = time.time()
        
        try:
            # Tokenize input
            inputs = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
            
            # Generate
            with torch.inference_mode():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # Decode output
            generated_text = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:], 
                skip_special_tokens=True
            )
            
            inference_time = time.time() - start_time
            
            return {
                "generated_text": generated_text,
                "inference_time": inference_time,
                "model_version": self.model_version
            }
            
        except Exception as e:
            logging.error(f"Inference failed: {e}")
            raise HTTPException(status_code=500, detail=str(e))
    
    async def batch_generate(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
        """Process batch of requests efficiently"""
        
        return await self.batch_processor.process_batch(requests)
    
    def get_health_status(self) -> HealthResponse:
        """Get system health status"""
        
        gpu_available = torch.cuda.is_available()
        model_loaded = self.model is not None
        
        memory_usage = {}
        if gpu_available:
            memory_usage = {
                "gpu_allocated_gb": torch.cuda.memory_allocated() / 1e9,
                "gpu_reserved_gb": torch.cuda.memory_reserved() / 1e9,
            }
        
        status = "healthy" if model_loaded else "unhealthy"
        
        return HealthResponse(
            status=status,
            model_loaded=model_loaded,
            gpu_available=gpu_available,
            memory_usage=memory_usage
        )
    
    async def cleanup(self):
        """Clean up resources"""
        if self.model:
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

class BatchProcessor:
    """Efficient batch processing for multiple requests"""
    
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_batch_size = 8  # Adjust based on GPU memory
        
    async def process_batch(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
        """Process batch of requests"""
        
        batch_size = min(len(requests), self.max_batch_size)
        results = []
        
        # Process in chunks
        for i in range(0, len(requests), batch_size):
            batch = requests[i:i + batch_size]
            batch_results = await self._process_batch_chunk(batch)
            results.extend(batch_results)
        
        return results
    
    async def _process_batch_chunk(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
        """Process a single batch chunk"""
        
        start_time = time.time()
        
        try:
            # Prepare batch inputs
            texts = [req.text for req in requests]
            max_lengths = [req.max_length for req in requests]
            
            # Tokenize batch
            batch_inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)
            
            # Generate for batch
            with torch.inference_mode():
                batch_outputs = self.model.generate(
                    **batch_inputs,
                    max_new_tokens=max(max_lengths),
                    temperature=requests[0].temperature,  # Use first request's params
                    top_p=requests[0].top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # Process outputs
            results = []
            for i, (request, output) in enumerate(zip(requests, batch_outputs)):
                # Extract new tokens
                input_length = batch_inputs.input_ids[i].shape[0]
                generated_tokens = output[input_length:]
                
                generated_text = self.tokenizer.decode(
                    generated_tokens, 
                    skip_special_tokens=True
                )
                
                results.append({
                    "generated_text": generated_text,
                    "inference_time": (time.time() - start_time) / len(requests),
                    "model_version": "1.0.0"
                })
            
            return results
            
        except Exception as e:
            logging.error(f"Batch processing failed: {e}")
            # Return error for all requests in batch
            return [{"error": str(e)} for _ in requests]

# API Endpoints
@app.post("/v1/generate", response_model=InferenceResponse)
async def generate_text(request: InferenceRequest, background_tasks: BackgroundTasks):
    """Generate text from input"""
    
    request_id = f"req_{int(time.time() * 1000)}"
    
    try:
        result = await model_manager.generate_text(
            text=request.text,
            max_length=request.max_length,
            temperature=request.temperature,
            top_p=request.top_p
        )
        
        # Log request in background
        background_tasks.add_task(log_request, request_id, request.text, result)
        
        return InferenceResponse(
            generated_text=result["generated_text"],
            inference_time=result["inference_time"],
            model_version=result["model_version"],
            request_id=request_id
        )
        
    except Exception as e:
        logging.error(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/batch-generate")
async def batch_generate_text(request: BatchInferenceRequest):
    """Generate text for batch of requests"""
    
    try:
        results = await model_manager.batch_generate(request.requests)
        
        return {
            "results": results,
            "batch_id": request.batch_id,
            "processed_count": len(results)
        }
        
    except Exception as e:
        logging.error(f"Batch generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return model_manager.get_health_status()

@app.get("/metrics")
async def get_metrics():
    """Prometheus metrics endpoint"""
    # Return metrics in Prometheus format
    return {"message": "Implement Prometheus metrics here"}

def log_request(request_id: str, input_text: str, result: Dict):
    """Log request for monitoring"""
    logging.info(f"Request {request_id}: {len(input_text)} chars -> {result['inference_time']:.3f}s")

# Run with: uvicorn app:app --host 0.0.0.0 --port 8000 --workers 1
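
Once the server is running, the endpoints can be exercised with any HTTP client. A minimal sketch using the requests library (assuming the server above is reachable on localhost:8000):

import requests

# Single request against the /v1/generate endpoint defined above
resp = requests.post(
    "http://localhost:8000/v1/generate",
    json={"text": "Once upon a time", "max_length": 50, "temperature": 0.8},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])

# Batched request against /v1/batch-generate
batch = {
    "requests": [
        {"text": "The capital of France is"},
        {"text": "In machine learning, overfitting means"},
    ],
    "batch_id": "demo-batch-1",
}
print(requests.post("http://localhost:8000/v1/batch-generate", json=batch, timeout=120).json())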

Production Optimization Strategies

Batching Strategies

  • Dynamic Batching: Variable batch sizes (see the sketch after this list)
  • Continuous Batching: Streaming requests
  • Priority Batching: Request prioritization
  • Adaptive Batching: Load-based adjustment
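
A minimal sketch of dynamic batching with asyncio: incoming requests are queued, and a background worker flushes the queue when it reaches a maximum batch size or when a short wait window expires. The names (DynamicBatcher, run_model_batch) are illustrative, not from any particular library.

import asyncio
from typing import Any, Tuple

class DynamicBatcher:
    """Collects requests into batches by size or timeout (illustrative sketch)."""

    def __init__(self, run_model_batch, max_batch_size: int = 8, max_wait_ms: float = 10.0):
        self.run_model_batch = run_model_batch  # callable: list of inputs -> list of outputs
        self.max_batch_size = max_batch_size
        self.max_wait = max_wait_ms / 1000.0
        self.queue: "asyncio.Queue[Tuple[Any, asyncio.Future]]" = asyncio.Queue()

    async def submit(self, item: Any) -> Any:
        """Enqueue one request and wait for its result."""
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def worker(self):
        """Background task: drain the queue into batches and run the model."""
        while True:
            item, future = await self.queue.get()
            batch, futures = [item], [future]
            deadline = asyncio.get_running_loop().time() + self.max_wait
            while len(batch) < self.max_batch_size:
                remaining = deadline - asyncio.get_running_loop().time()
                if remaining <= 0:
                    break
                try:
                    item, future = await asyncio.wait_for(self.queue.get(), remaining)
                    batch.append(item)
                    futures.append(future)
                except asyncio.TimeoutError:
                    break
            # Run the (blocking) model call off the event loop, then resolve futures
            results = await asyncio.to_thread(self.run_model_batch, batch)
            for fut, res in zip(futures, results):
                fut.set_result(res)

An endpoint would then call await batcher.submit(payload), with batcher.worker() started as a background task at application startup.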

GPU Optimization

  • Memory Pooling: Efficient GPU memory use
  • Model Sharding: Large model distribution
  • Pipeline Parallelism: Sequential processing
  • Mixed Precision: FP16/INT8 optimization
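
Mixed precision is typically the cheapest of these to adopt. A minimal sketch using torch.autocast for FP16 inference, assuming a CUDA device and an already-loaded model whose inputs are a dict of tensors (e.g., tokenizer output):

import torch

@torch.inference_mode()
def fp16_forward(model, inputs):
    """Run a forward pass under automatic mixed precision (CUDA only)."""
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        return model(**inputs)

# Memory bookkeeping that informs pooling and batch-size decisions
if torch.cuda.is_available():
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"GPU memory: {allocated_gb:.2f} GB allocated, {reserved_gb:.2f} GB reserved")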

System Architecture

  • Load Balancing: Multi-instance serving
  • Auto Scaling: Dynamic resource allocation
  • Health Monitoring: System observability
  • Circuit Breakers: Fault tolerance
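
Circuit breakers are the piece most often hand-rolled around a model client. A minimal, illustrative sketch (thresholds and names are assumptions, not from a specific library):

import time

class CircuitBreaker:
    """Opens after repeated failures so callers fail fast instead of queueing behind a broken backend."""

    def __init__(self, failure_threshold: int = 5, reset_timeout_s: float = 30.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout_s = reset_timeout_s
        self.failures = 0
        self.opened_at = None  # None means the circuit is closed

    def allow_request(self) -> bool:
        if self.opened_at is None:
            return True
        # Half-open: after the cool-down, let a request through to probe recovery
        return (time.time() - self.opened_at) >= self.reset_timeout_s

    def record_success(self):
        self.failures = 0
        self.opened_at = None

    def record_failure(self):
        self.failures += 1
        if self.failures >= self.failure_threshold:
            self.opened_at = time.time()

The serving endpoint checks allow_request() before calling the model, returns a 503 immediately when the circuit is open, and records the outcome of each call afterwards.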

Monitoring & Observability

Key Metrics

  • Latency Metrics: P50, P95, P99 response times (see the sketch below)
  • Throughput Metrics: Requests per second, batch efficiency
  • Resource Metrics: GPU utilization, memory usage
  • Quality Metrics: Model accuracy, prediction confidence
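
As a quick illustration of the latency metrics, percentiles can be computed directly from recorded request durations (a sketch using NumPy; latencies_s is a hypothetical list of per-request times in seconds):

import numpy as np

latencies_s = [0.041, 0.037, 0.120, 0.045, 0.300, 0.052, 0.048]  # hypothetical samples
p50, p95, p99 = np.percentile(latencies_s, [50, 95, 99])
print(f"P50={p50 * 1000:.0f}ms  P95={p95 * 1000:.0f}ms  P99={p99 * 1000:.0f}ms")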

Observability Stack

Monitoring Integration
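
The snippet below wires Prometheus counters, histograms, and gauges into the serving path, with a small context manager that records duration, status, and in-flight request count for every call:
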
from prometheus_client import Counter, Histogram, Gauge, generate_latest
import time
import logging
from typing import Optional

# Define metrics
REQUEST_COUNT = Counter(
    'model_serving_requests_total',
    'Total number of requests',
    ['model_name', 'status', 'endpoint']
)

REQUEST_DURATION = Histogram(
    'model_serving_request_duration_seconds',
    'Request duration in seconds',
    ['model_name', 'endpoint'],
    buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0]
)

GPU_UTILIZATION = Gauge(
    'model_serving_gpu_utilization_percent',
    'GPU utilization percentage',
    ['gpu_id']
)

ACTIVE_REQUESTS = Gauge(
    'model_serving_active_requests',
    'Number of active requests',
    ['model_name']
)

class ServingMetrics:
    def __init__(self, model_name: str):
        self.model_name = model_name
        
    def record_request(self, 
                      endpoint: str,
                      duration: float,
                      status: str = 'success'):
        """Record request metrics"""
        REQUEST_COUNT.labels(
            model_name=self.model_name,
            status=status,
            endpoint=endpoint
        ).inc()
        
        REQUEST_DURATION.labels(
            model_name=self.model_name,
            endpoint=endpoint
        ).observe(duration)
    
    def track_active_requests(self, delta: int):
        """Track active request count"""
        ACTIVE_REQUESTS.labels(
            model_name=self.model_name
        ).inc(delta)
    
    def record_gpu_utilization(self, gpu_id: int, utilization: float):
        """Record GPU utilization"""
        GPU_UTILIZATION.labels(gpu_id=str(gpu_id)).set(utilization)

class RequestTracker:
    """Context manager for automatic request tracking"""
    
    def __init__(self, metrics: ServingMetrics, endpoint: str):
        self.metrics = metrics
        self.endpoint = endpoint
        self.start_time = None
        
    def __enter__(self):
        self.start_time = time.time()
        self.metrics.track_active_requests(1)
        return self
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = time.time() - self.start_time
        status = 'error' if exc_type else 'success'
        
        self.metrics.record_request(
            endpoint=self.endpoint,
            duration=duration,
            status=status
        )
        
        self.metrics.track_active_requests(-1)
        
        if exc_type:
            logging.error(f"Request failed: {exc_val}")

# Usage in serving endpoint
async def inference_endpoint(request_data):
    metrics = ServingMetrics("my-model")
    
    with RequestTracker(metrics, "inference") as tracker:
        # Your inference logic here
        result = await model.predict(request_data)
        return result
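
To expose these metrics from the FastAPI app shown earlier, replace the placeholder /metrics endpoint with one that returns the Prometheus text exposition format (a sketch using prometheus_client's generate_latest):

from fastapi import Response
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

@app.get("/metrics")
async def get_metrics():
    """Expose all registered Prometheus metrics in text exposition format."""
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)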