
FlashAttention & Memory Optimization

Master efficient attention mechanisms and GPU memory optimization for transformers

55 min read · Advanced

What is FlashAttention?

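FlashAttention is an exact, IO-aware attention algorithm that reduces the extra memory needed by attention from O(N²) to O(N) in the sequence length N. Instead of materializing the full N×N attention matrix in GPU high-bandwidth memory (HBM), it computes attention tile by tile in fast on-chip SRAM. The arithmetic is still O(N²), but the memory savings and reduced memory traffic make much longer sequences practical for training and inference, with no approximation.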

  • Memory efficient: O(N) instead of O(N²) extra memory
  • IO-aware: optimized for the GPU memory hierarchy (HBM vs. on-chip SRAM)
  • Exact attention: same output as standard attention, up to floating-point rounding

🧮 Attention Memory Calculator

Compare memory usage and performance between standard and FlashAttention implementations.

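The comparison comes down to simple arithmetic. A minimal sketch, assuming FP32 scores and counting only the attention-specific buffers: the N×N score matrix for standard attention versus the two per-row statistics (running max and running sum) that FlashAttention keeps. Q, K, V, and the output are the same size in both cases and are excluded. The function name `attention_memory_bytes` is hypothetical.

```python
def attention_memory_bytes(seq_len: int, batch: int = 1, heads: int = 1,
                           bytes_per_elem: int = 4) -> dict:
    """Estimate attention-specific memory for standard vs. FlashAttention.

    Standard attention stores the full seq_len x seq_len score matrix per head.
    FlashAttention stores only two statistics per row (running max m and
    running sum l), so its extra memory grows linearly with seq_len.
    Q, K, V and the output have the same size in both cases and are excluded.
    """
    standard = batch * heads * seq_len * seq_len * bytes_per_elem
    flash = batch * heads * seq_len * 2 * bytes_per_elem
    return {
        "standard_bytes": standard,
        "flash_bytes": flash,
        "reduction": 1.0 - flash / standard,
    }


# Per head, FP32: the score matrix grows quadratically, the statistics linearly
for n in (1024, 4096, 16384):
    est = attention_memory_bytes(n)
    print(f"N={n:>6}: standard {est['standard_bytes'] / 2**20:8.1f} MiB, "
          f"flash {est['flash_bytes'] / 2**10:6.1f} KiB")
```

Note that the number of multiply-adds is unchanged: both approaches perform O(N²·d) arithmetic, so FlashAttention saves memory and memory traffic, not FLOPs.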

FlashAttention Algorithm

❌ Standard Attention Problem

  • Materializes the full N×N attention matrix
  • O(N²) memory complexity
  • Memory bottleneck for long sequences
  • Repeated reads and writes of the N×N matrix to and from HBM
  • GPU memory fragmentation

✅ FlashAttention Solution

  • Tile-based computation in on-chip SRAM
  • O(N) memory complexity
  • Online softmax computation
  • Fewer HBM reads and writes
  • No attention matrix materialization
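The online softmax is the key trick: a row's softmax-weighted sum can be accumulated tile by tile if you track a running maximum and a running denominator, rescaling earlier partial results whenever the maximum grows. A minimal pure-Python sketch for a single query row, with one scalar value per key to keep it short; the function names are illustrative, not from any library:

```python
import math
from typing import List


def softmax_weighted_sum(scores: List[float], values: List[float]) -> float:
    """Reference: softmax(scores) . values over the whole row at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def online_softmax_weighted_sum(scores: List[float], values: List[float],
                                block_size: int) -> float:
    """Same result, processing the row one tile at a time.

    Keeps only a running max m, a running denominator l and a running
    normalized output o: the per-row statistics FlashAttention maintains.
    """
    m, l, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]

        # Tile-local softmax pieces, relative to the tile's own max
        m_blk = max(s_blk)
        p_blk = [math.exp(s - m_blk) for s in s_blk]
        l_blk = sum(p_blk)
        pv_blk = sum(p * v for p, v in zip(p_blk, v_blk))

        # Merge: rescale old and new terms to the common max m_new
        m_new = max(m, m_blk)
        alpha = math.exp(m - m_new)     # 0.0 on the first tile, since m = -inf
        beta = math.exp(m_blk - m_new)
        l_new = alpha * l + beta * l_blk
        o = (alpha * l * o + beta * pv_blk) / l_new
        m, l = m_new, l_new
    return o


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.2, -0.3]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(softmax_weighted_sum(scores, values))
print(online_softmax_weighted_sum(scores, values, block_size=3))
```

The two printed values agree up to floating-point rounding for any block size, which is why FlashAttention is exact rather than approximate.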

Core Algorithm

FlashAttention Implementation
import torch
import torch.nn as nn
import math

class FlashAttention(nn.Module):
    """Pure-PyTorch FlashAttention forward pass, for illustration.

    It reproduces the tiling and online-softmax math of FlashAttention, but it
    runs as ordinary PyTorch ops rather than as a fused kernel that keeps tiles
    in on-chip SRAM, so it does not deliver the real kernel's speed or memory
    savings (autograd would also store every tile's intermediates).
    """
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = dropout  # attention dropout is not implemented in this sketch

    def forward(self, q, k, v, block_size=128, causal=False):
        """
        FlashAttention forward pass
        Args:
            q, k, v: Query, Key, Value tensors [batch, heads, seq_len, head_dim]
            block_size: Tile size along the sequence dimension
            causal: If True, each position attends only to itself and earlier positions
        """
        batch_size, num_heads, seq_len, head_dim = q.shape
        scale = 1.0 / math.sqrt(head_dim)
        num_blocks = math.ceil(seq_len / block_size)
        output_blocks = []

        for i in range(num_blocks):
            # Query block indices
            q_start = i * block_size
            q_end = min((i + 1) * block_size, seq_len)
            q_i = q[:, :, q_start:q_end, :]  # [B, H, Bq, D]

            # Running statistics for this query block:
            # m_i = row-wise running max of the scores, l_i = running softmax denominator
            O_i = torch.zeros_like(q_i)
            l_i = torch.zeros(*q_i.shape[:-1], 1, device=q.device, dtype=q.dtype)
            m_i = torch.full((*q_i.shape[:-1], 1), -float('inf'), device=q.device, dtype=q.dtype)

            for j in range(num_blocks):
                # Key-value block indices
                kv_start = j * block_size
                kv_end = min((j + 1) * block_size, seq_len)

                # Causal: this and every later key block lies entirely in the future
                if causal and kv_start >= q_end:
                    break

                k_j = k[:, :, kv_start:kv_end, :]  # [B, H, Bk, D]
                v_j = v[:, :, kv_start:kv_end, :]  # [B, H, Bk, D]

                # Attention scores for this tile only: [B, H, Bq, Bk]
                S_ij = torch.matmul(q_i, k_j.transpose(-2, -1)) * scale

                # Causal: hide keys whose absolute position is after the query's
                if causal:
                    q_pos = torch.arange(q_start, q_end, device=q.device).unsqueeze(1)
                    k_pos = torch.arange(kv_start, kv_end, device=q.device).unsqueeze(0)
                    S_ij = S_ij.masked_fill(k_pos > q_pos, -float('inf'))

                # Online softmax: tile-local max, exponentials, and row sums
                m_ij = torch.max(S_ij, dim=-1, keepdim=True)[0]
                P_ij = torch.exp(S_ij - m_ij)
                l_ij = torch.sum(P_ij, dim=-1, keepdim=True)

                # Merge with the running statistics, rescaling old and new terms
                m_i_new = torch.maximum(m_i, m_ij)
                alpha = torch.exp(m_i - m_i_new)
                beta = torch.exp(m_ij - m_i_new)

                l_i_new = alpha * l_i + beta * l_ij
                O_i = (alpha * l_i * O_i + beta * torch.matmul(P_ij, v_j)) / l_i_new
                l_i = l_i_new
                m_i = m_i_new

            output_blocks.append(O_i)

        # Concatenate the finished query blocks along the sequence dimension
        return torch.cat(output_blocks, dim=2)

# Usage example
torch.manual_seed(0)
flash_attn = FlashAttention()
B, H, N, D = 2, 8, 1024, 64

q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)

# Tiled computation: score tiles are at most block_size x block_size
output = flash_attn(q, k, v, block_size=128, causal=True)
print(f"Output shape: {output.shape}")

# Check against standard attention, which materializes the full N x N matrix
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)
future = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
reference = torch.matmul(torch.softmax(scores.masked_fill(future, -float('inf')), dim=-1), v)
print(f"Max abs difference vs. standard attention: {(output - reference).abs().max().item():.2e}")
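The loop above is for understanding the algorithm. In practice, PyTorch 2.x exposes fused attention through `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a FlashAttention backend on supported CUDA GPUs with FP16/BF16 inputs and falls back to other backends otherwise. A minimal sketch comparing it with explicit attention, assuming PyTorch 2.0 or later; which kernel actually runs depends on your hardware, dtype, and PyTorch version:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 2, 4, 256, 64
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Fused attention; is_causal=True applies the lower-triangular mask internally
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit attention that materializes the full N x N score matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
future = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
explicit = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1) @ v

print(f"max abs difference: {(fused - explicit).abs().max().item():.2e}")
```

Recent PyTorch releases also provide `torch.nn.attention.sdpa_kernel` to restrict which backends may run, which is useful when benchmarking one backend against another.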

Advanced Memory Optimization Techniques

Gradient Checkpointing

Memory-Efficient Training
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MemoryEfficientTransformerLayer(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model * 4, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward_chunk(self, x):
        """Forward pass for one layer; recomputed during backward when checkpointed"""
        # Self-attention block (post-norm); skip returning attention weights
        x2, _ = self.self_attn(x, x, x, need_weights=False)
        x = x + self.dropout1(x2)
        x = self.norm1(x)

        # Feed-forward block
        x2 = self.linear2(self.dropout(torch.relu(self.linear1(x))))
        x = x + self.dropout2(x2)
        x = self.norm2(x)
        return x

    def forward(self, x, use_checkpoint=True):
        if use_checkpoint and self.training:
            # Discard this layer's activations after the forward pass and
            # recompute them during backward (non-reentrant variant)
            return checkpoint(self.forward_chunk, x, use_reentrant=False)
        return self.forward_chunk(x)

class MemoryEfficientTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, max_seq_len, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(1, max_seq_len, d_model))

        self.layers = nn.ModuleList([
            MemoryEfficientTransformerLayer(d_model, nhead)
            for _ in range(num_layers)
        ])
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, x, use_checkpoint=True):
        seq_len = x.size(1)

        # Embedding + learned positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :]

        # Apply transformer layers with optional checkpointing
        for layer in self.layers:
            x = layer(x, use_checkpoint=use_checkpoint)

        return self.output_proj(x)

# Memory usage comparison (requires a CUDA GPU)
def compare_memory_usage():
    if not torch.cuda.is_available():
        print("A CUDA GPU is required to measure GPU memory")
        return

    device = torch.device('cuda')
    model = MemoryEfficientTransformer(
        d_model=768, nhead=12, num_layers=12,
        max_seq_len=2048, vocab_size=50000
    ).to(device)
    model.train()  # checkpointing is only active in training mode

    x = torch.randint(0, 50000, (4, 1024), device=device)  # batch of token IDs

    def peak_memory_gb(use_checkpoint):
        # Free gradients and cached blocks from any previous run, then reset the counter
        model.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # The output and loss are released when this function returns,
        # so they do not inflate the next measurement
        loss = model(x, use_checkpoint=use_checkpoint).sum()
        loss.backward()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / 1e9

    memory_checkpoint = peak_memory_gb(use_checkpoint=True)
    memory_normal = peak_memory_gb(use_checkpoint=False)

    print(f"Memory with checkpointing: {memory_checkpoint:.2f} GB")
    print(f"Memory without checkpointing: {memory_normal:.2f} GB")
    print(f"Memory reduction: {((memory_normal - memory_checkpoint) / memory_normal * 100):.1f}%")

# compare_memory_usage()
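Checkpointing trades compute for memory: activations inside the checkpointed function are discarded after the forward pass and recomputed during backward, so the gradients themselves should not change. A small CPU sketch that checks this, assuming a recent PyTorch with the `use_reentrant` argument; the `block` module is a stand-in, not part of the model above, and has no dropout so both runs are deterministic:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Stand-in block without dropout, so the recomputation is deterministic
block = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)


def grads(use_checkpoint: bool):
    """Run forward/backward once and return copies of all gradients."""
    block.zero_grad(set_to_none=True)
    x.grad = None
    out = checkpoint(block, x, use_reentrant=False) if use_checkpoint else block(x)
    out.pow(2).sum().backward()
    return [x.grad.clone()] + [p.grad.clone() for p in block.parameters()]


plain = grads(use_checkpoint=False)
ckpt = grads(use_checkpoint=True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))
```

The cost is roughly one extra forward pass per checkpointed segment during backward, which is why checkpointing is usually reserved for cases where activations, not parameters, dominate memory.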

Dynamic Sequence Packing

Efficient Batch Processing
import torch
import torch.nn.functional as F
from typing import List, Tuple

class SequencePacker:
    """Pack variable-length sequences into fixed-length rows to reduce padding.

    The returned mask only marks padding: sequences packed into the same row
    can still attend to each other unless the model also applies a
    block-diagonal attention mask.
    """

    def __init__(self, max_seq_len: int = 2048):
        self.max_seq_len = max_seq_len

    def _finalize(self, seqs: List[torch.Tensor], masks: List[torch.Tensor]):
        """Concatenate the sequences of one row and pad it to max_seq_len"""
        row_seq = torch.cat(seqs, dim=0)
        row_mask = torch.cat(masks, dim=0)
        padding_len = self.max_seq_len - row_seq.size(0)
        if padding_len > 0:
            row_seq = F.pad(row_seq, (0, 0, 0, padding_len))  # pad the sequence dimension
            row_mask = F.pad(row_mask, (0, padding_len))      # padding positions get mask 0
        return row_seq, row_mask

    def pack_sequences(self, sequences: List[torch.Tensor],
                       attention_masks: List[torch.Tensor]
                       ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]]]:
        """
        Greedily pack sequences (each [seq_len, d_model]) into rows of max_seq_len.
        Returns packed sequences [num_rows, max_seq_len, d_model],
        padding masks [num_rows, max_seq_len], and the sequence lengths in each row.
        """
        packed_sequences = []
        packed_masks = []
        all_split_indices = []

        current_batch = []
        current_masks = []
        current_length = 0

        for seq, mask in zip(sequences, attention_masks):
            seq_len = seq.size(0)
            if seq_len > self.max_seq_len:
                raise ValueError(f"Sequence of length {seq_len} exceeds max_seq_len={self.max_seq_len}")

            if current_length + seq_len > self.max_seq_len and current_batch:
                # Current row is full: finalize it and start a new one
                row_seq, row_mask = self._finalize(current_batch, current_masks)
                packed_sequences.append(row_seq)
                packed_masks.append(row_mask)
                all_split_indices.append([s.size(0) for s in current_batch])

                current_batch = []
                current_masks = []
                current_length = 0

            current_batch.append(seq)
            current_masks.append(mask)
            current_length += seq_len

        # Handle remaining sequences
        if current_batch:
            row_seq, row_mask = self._finalize(current_batch, current_masks)
            packed_sequences.append(row_seq)
            packed_masks.append(row_mask)
            all_split_indices.append([s.size(0) for s in current_batch])

        return (torch.stack(packed_sequences),
                torch.stack(packed_masks),
                all_split_indices)

    def unpack_sequences(self, packed_output: torch.Tensor,
                         split_indices: List[List[int]]) -> List[torch.Tensor]:
        """Unpack the model outputs back to individual sequences"""
        all_sequences = []

        for batch_idx, batch_splits in enumerate(split_indices):
            batch_output = packed_output[batch_idx]
            start_idx = 0

            for seq_len in batch_splits:
                seq_output = batch_output[start_idx:start_idx + seq_len]
                all_sequences.append(seq_output)
                start_idx += seq_len

        return all_sequences

# Memory-efficient training with sequence packing
class PackedDataLoader:
    """Yields packed batches from a list of (sequence, mask) pairs"""

    def __init__(self, sequences, batch_size=8, max_seq_len=2048):
        self.sequences = sequences    # list of (sequence, mask) tuples
        self.batch_size = batch_size  # number of sequences handed to the packer at once
        self.packer = SequencePacker(max_seq_len)

    def __iter__(self):
        # Sort by length so similarly sized sequences are packed together
        sorted_sequences = sorted(self.sequences, key=lambda pair: pair[0].size(0))

        batch_sequences = []
        batch_masks = []

        for seq, mask in sorted_sequences:
            batch_sequences.append(seq)
            batch_masks.append(mask)

            if len(batch_sequences) >= self.batch_size:
                # Pack and yield batch
                yield self.packer.pack_sequences(batch_sequences, batch_masks)

                batch_sequences = []
                batch_masks = []

        # Handle remaining sequences
        if batch_sequences:
            yield self.packer.pack_sequences(batch_sequences, batch_masks)

# Usage example
sequences = [torch.randn(200, 768), torch.randn(150, 768), torch.randn(300, 768)]
masks = [torch.ones(200), torch.ones(150), torch.ones(300)]

packer = SequencePacker(max_seq_len=512)
packed_seq, packed_mask, split_indices = packer.pack_sequences(sequences, masks)

print(f"Original sequences: {[s.shape[0] for s in sequences]}")
print(f"Packed shape: {packed_seq.shape}")  # two rows: [200, 150] and [300]
print(f"Split indices: {split_indices}")
real_tokens = sum(s.size(0) for s in sequences)
print(f"Packing efficiency (non-padding fraction): {real_tokens / (packed_seq.size(0) * packed_seq.size(1)):.2f}")
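One detail the packer leaves to the model: the packed mask only marks padding, so tokens from different original sequences can still attend to each other inside a packed row. A common remedy is a block-diagonal attention mask built from the per-row sequence lengths. A minimal sketch, assuming the boolean convention where True means "may attend" (the convention `scaled_dot_product_attention` uses for boolean masks); `block_diagonal_mask` is a hypothetical helper:

```python
from typing import List

import torch


def block_diagonal_mask(lengths: List[int], total_len: int,
                        causal: bool = False) -> torch.Tensor:
    """Boolean [total_len, total_len] mask; True where attention is allowed.

    Position i may attend to position j only if both belong to the same
    packed sequence. Padding positions (beyond sum(lengths)) attend to nothing.
    """
    mask = torch.zeros(total_len, total_len, dtype=torch.bool)
    start = 0
    for n in lengths:
        mask[start:start + n, start:start + n] = True
        start += n
    if causal:
        # Additionally forbid attending to later positions
        mask &= torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    return mask


m = block_diagonal_mask([2, 3], total_len=6)
print(m.int())
```

For a packed row from the packer, pass that row's entry of `split_indices` as `lengths` and `max_seq_len` as `total_len`. Softmax over a fully masked row is undefined, so outputs at padding positions should be discarded (for example, excluded from the loss).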

Production FlashAttention Service

High-Performance Attention Server

production_flash_attention.py
import math
import time
import logging
from contextlib import asynccontextmanager
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AttentionRequest(BaseModel):
    query: list  # [batch_size, seq_len, d_model]
    key: list    # [batch_size, seq_len, d_model]
    value: list  # [batch_size, seq_len, d_model]
    num_heads: int = 8
    block_size: int = 128
    use_flash: bool = True
    causal_mask: bool = False

class AttentionResponse(BaseModel):
    output: list
    processing_time: float
    memory_used: float
    sequence_length: int
    attention_type: str

def _block_update(q_block, k_block, v_block, m_prev, l_prev, O_prev,
                  scale: float, mask: Optional[torch.Tensor]):
    """One online-softmax step for a (query tile, key/value tile) pair.

    Statistics and the output accumulator are kept in float32, as
    FlashAttention does, even when q, k and v are float16.
    """
    scores = torch.matmul(q_block, k_block.transpose(-2, -1)).float() * scale
    if mask is not None:
        scores = scores.masked_fill(mask, -float('inf'))

    m_curr = scores.amax(dim=-1, keepdim=True)
    m_new = torch.maximum(m_prev, m_curr)
    alpha = torch.exp(m_prev - m_new)  # rescales the previous statistics
    beta = torch.exp(m_curr - m_new)   # rescales this tile

    # Exponentials relative to this tile's own max; beta brings them to m_new
    p_curr = torch.exp(scores - m_curr)
    l_curr = p_curr.sum(dim=-1, keepdim=True)
    l_new = alpha * l_prev + beta * l_curr

    O_new = (alpha * l_prev * O_prev + beta * torch.matmul(p_curr, v_block.float())) / l_new
    return O_new, l_new, m_new

class OptimizedFlashAttention(nn.Module):
    """FlashAttention-style tiled attention for the service (still pure PyTorch)"""

    def __init__(self, use_compile: bool = False, max_cache_size: int = 100):
        super().__init__()
        self.attention_cache = {}
        self.max_cache_size = max_cache_size
        # torch.compile (PyTorch 2.0+) is applied once here, not on every forward call
        self.block_fn = torch.compile(_block_update) if use_compile else _block_update

    def _get_cache_key(self, shape, device) -> str:
        return f"{tuple(shape)}_{device}"

    def _get_cached_tensors(self, q: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Reuse float32 accumulators across requests of the same shape"""
        key = self._get_cache_key(q.shape, q.device)

        if key not in self.attention_cache:
            if len(self.attention_cache) >= self.max_cache_size:
                # Evict the oldest entry (dicts preserve insertion order)
                oldest_key = next(iter(self.attention_cache))
                del self.attention_cache[oldest_key]

            batch_size, num_heads, seq_len, head_dim = q.shape
            self.attention_cache[key] = {
                'O': torch.zeros(q.shape, device=q.device, dtype=torch.float32),
                'l': torch.zeros(batch_size, num_heads, seq_len, 1, device=q.device, dtype=torch.float32),
                'm': torch.full((batch_size, num_heads, seq_len, 1), -float('inf'),
                                device=q.device, dtype=torch.float32)
            }

        return self.attention_cache[key]

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                block_size: int = 128, causal_mask: bool = False) -> torch.Tensor:
        """
        Tiled attention forward pass; q, k, v are [batch, heads, seq_len, head_dim]
        """
        batch_size, num_heads, seq_len, head_dim = q.shape
        scale = 1.0 / math.sqrt(head_dim)

        # Reset the cached accumulators
        cached = self._get_cached_tensors(q)
        O = cached['O'].zero_()
        l = cached['l'].zero_()
        m = cached['m'].fill_(-float('inf'))

        num_blocks = (seq_len + block_size - 1) // block_size

        # Process blocks
        for i in range(num_blocks):
            q_start, q_end = i * block_size, min((i + 1) * block_size, seq_len)
            q_block = q[:, :, q_start:q_end, :]

            for j in range(num_blocks):
                if causal_mask and j > i:
                    break  # later key blocks lie entirely in the future

                kv_start, kv_end = j * block_size, min((j + 1) * block_size, seq_len)

                # Only the diagonal tile straddles the causal boundary;
                # earlier tiles are fully visible and need no mask
                mask = None
                if causal_mask and j == i:
                    q_pos = torch.arange(q_start, q_end, device=q.device).unsqueeze(1)
                    k_pos = torch.arange(kv_start, kv_end, device=q.device).unsqueeze(0)
                    mask = k_pos > q_pos

                O_block, l_block, m_block = self.block_fn(
                    q_block,
                    k[:, :, kv_start:kv_end, :],
                    v[:, :, kv_start:kv_end, :],
                    m[:, :, q_start:q_end, :],
                    l[:, :, q_start:q_end, :],
                    O[:, :, q_start:q_end, :],
                    scale,
                    mask,
                )

                O[:, :, q_start:q_end, :] = O_block
                l[:, :, q_start:q_end, :] = l_block
                m[:, :, q_start:q_end, :] = m_block

        # Copy out of the cache so a later request cannot overwrite this result
        return O.to(dtype=q.dtype, copy=True)

class AttentionService:
    """Attention service with basic monitoring"""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # FP16 on GPU; FP32 on CPU, where half-precision matmul support is limited
        self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
        self.flash_attention = OptimizedFlashAttention().to(self.device)

        # Performance monitoring
        self.request_count = 0
        self.total_processing_time = 0.0
        self.memory_usage_history = []

    @staticmethod
    def standard_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           causal_mask: bool) -> torch.Tensor:
        """Baseline attention that materializes the full seq_len x seq_len score matrix"""
        scale = 1.0 / math.sqrt(q.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)).float() * scale
        if causal_mask:
            n = q.size(-2)
            future = torch.triu(torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1)
            scores = scores.masked_fill(future, -float('inf'))
        return torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(q.dtype)

    async def process_attention(self, request: AttentionRequest) -> AttentionResponse:
        """Process an attention request with monitoring.

        The tensor work runs synchronously and blocks the event loop while it runs.
        """
        start_time = time.perf_counter()

        try:
            # Convert inputs to tensors
            q = torch.tensor(request.query, device=self.device, dtype=self.dtype)
            k = torch.tensor(request.key, device=self.device, dtype=self.dtype)
            v = torch.tensor(request.value, device=self.device, dtype=self.dtype)

            if q.dim() != 3 or k.shape != q.shape or v.shape != q.shape:
                raise HTTPException(status_code=400,
                                    detail="query, key and value must all have shape [batch, seq_len, d_model]")

            batch_size, seq_len, d_model = q.shape
            if request.num_heads <= 0 or d_model % request.num_heads != 0 or request.block_size <= 0:
                raise HTTPException(status_code=400,
                                    detail="num_heads must divide d_model and block_size must be positive")

            # Split into heads: [batch, heads, seq_len, head_dim]
            head_dim = d_model // request.num_heads
            q = q.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)

            # Memory monitoring
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            # Process attention
            with torch.no_grad():
                if request.use_flash:
                    output = self.flash_attention(q, k, v, request.block_size, request.causal_mask)
                    attention_type = "FlashAttention"
                else:
                    output = self.standard_attention(q, k, v, request.causal_mask)
                    attention_type = "StandardAttention"

            # Merge heads back to [batch, seq_len, d_model];
            # .cpu() also waits for pending GPU work, so the timing below is accurate
            output_list = output.transpose(1, 2).reshape(batch_size, seq_len, d_model).cpu().tolist()

            # Peak memory since the reset above
            memory_used = 0.0
            if torch.cuda.is_available():
                memory_used = torch.cuda.max_memory_allocated() / 1e6  # MB

            processing_time = time.perf_counter() - start_time

            # Update monitoring metrics
            self.request_count += 1
            self.total_processing_time += processing_time
            self.memory_usage_history.append(memory_used)

            if len(self.memory_usage_history) > 1000:
                self.memory_usage_history.pop(0)

            logger.info(f"Processed {attention_type} request: {seq_len} tokens, "
                        f"{processing_time:.3f}s, {memory_used:.1f}MB")

            return AttentionResponse(
                output=output_list,
                processing_time=processing_time,
                memory_used=memory_used,
                sequence_length=seq_len,
                attention_type=attention_type
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error processing attention request: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Get service performance statistics"""
        avg_time = self.total_processing_time / max(1, self.request_count)
        avg_memory = sum(self.memory_usage_history) / max(1, len(self.memory_usage_history))

        return {
            "total_requests": self.request_count,
            "average_processing_time": avg_time,
            "average_memory_usage": avg_memory,
            "device": str(self.device)
        }

# Initialize service
attention_service = AttentionService()

# FastAPI setup
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Starting FlashAttention service")
    logger.info(f"Using device: {attention_service.device}")
    yield
    logger.info("Shutting down FlashAttention service")

app = FastAPI(title="FlashAttention Service", lifespan=lifespan)

@app.post("/attention", response_model=AttentionResponse)
async def process_attention_endpoint(request: AttentionRequest):
    """Process attention computation request"""
    return await attention_service.process_attention(request)

@app.get("/stats")
async def get_stats():
    """Get service performance statistics"""
    return attention_service.get_stats()

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "device": str(attention_service.device)}

if __name__ == "__main__":
    uvicorn.run(
        "production_flash_attention:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # single worker so one process owns the GPU
        log_level="info"
    )
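A minimal client sketch using only the Python standard library, assuming the service is running locally on port 8000 as configured above. `make_payload` and `post_json` are hypothetical helpers; the payload fields mirror `AttentionRequest`.

```python
import json
import random
import urllib.request
from typing import Any, Dict


def make_payload(batch: int, seq_len: int, d_model: int, num_heads: int = 8,
                 causal: bool = False) -> Dict[str, Any]:
    """Random query/key/value tensors as nested lists [batch, seq_len, d_model]."""
    def rand_tensor():
        return [[[random.gauss(0.0, 1.0) for _ in range(d_model)]
                 for _ in range(seq_len)] for _ in range(batch)]

    return {
        "query": rand_tensor(),
        "key": rand_tensor(),
        "value": rand_tensor(),
        "num_heads": num_heads,
        "block_size": 64,
        "use_flash": True,
        "causal_mask": causal,
    }


def post_json(url: str, payload: Dict[str, Any], timeout: float = 60.0) -> Dict[str, Any]:
    """POST a JSON body and decode the JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Example (requires the server to be running):
# result = post_json("http://localhost:8000/attention",
#                    make_payload(batch=1, seq_len=256, d_model=512, causal=True))
# print(result["attention_type"], result["processing_time"])
```

Sending tensors as JSON lists is simple but slow for large inputs; it is adequate for testing the endpoint, not for bulk traffic.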

Real-World Examples

OpenAI GPT-4

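OpenAI has not published GPT-4's architecture, parameter count, or attention implementation, so specific memory or speed figures for it are not public. The model did launch with a 32K-token context variant, the kind of sequence length where exact memory-efficient attention such as FlashAttention-2 matters most.

  • 32K-token context window (gpt-4-32k)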

Anthropic Claude

Anthropic's Claude models offer context windows of 100K tokens and more. Anthropic has not published the attention implementation behind them, but contexts of this length are impractical with attention that materializes the full N×N matrix.

  • 100K+ token context windows

Meta LLaMA-2

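Llama 2 was released with open weights at 7B, 13B, and 70B parameters and a 4K-token context, and open-source training and inference stacks for it commonly use FlashAttention kernels to reduce memory use.

  • 70B parameter model
  • 4K-token context window
  • Grouped-query attention in the 70B model to shrink the KV cache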

FlashAttention Best Practices

✅ Do's

  • Use block sizes that are multiples of 64 for GPU efficiency
  • Enable gradient checkpointing for very long sequences
  • Use FP16/BF16 precision for memory savings
  • Implement sequence packing for variable length inputs
  • Monitor GPU memory utilization during training
  • Use async preprocessing for better throughput

❌ Don'ts

  • Don't use very small block sizes (<64); per-tile overhead makes them inefficient
  • Don't mix attention implementations in one model without checking that their outputs agree (masking and dtype handling can differ)
  • Don't ignore memory fragmentation in long sequences
  • Don't expect large gains on very short sequences, where the N×N matrix is already small
  • Don't forget to clear the CUDA cache between benchmark runs
  • Don't skip benchmarking on your specific hardware
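For the benchmarking advice, a minimal timing sketch, assuming PyTorch 2.x. It warms up first, so one-time costs such as kernel selection or compilation are excluded, and synchronizes CUDA before reading the clock, since GPU work is launched asynchronously. `benchmark` is a hypothetical helper name.

```python
import time
from typing import Callable

import torch
import torch.nn.functional as F


def benchmark(fn: Callable[[], torch.Tensor], warmup: int = 3, iters: int = 10) -> float:
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the timed work before stopping the clock
    return (time.perf_counter() - start) * 1000.0 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q, k, v = (torch.randn(1, 8, 1024, 64, device=device, dtype=dtype) for _ in range(3))
ms = benchmark(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))
print(f"{device} SDPA, N=1024: {ms:.2f} ms per call")
```

Run the same harness over the sequence lengths, batch sizes, and dtypes you actually deploy; results on one GPU generation do not transfer reliably to another.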