FlashAttention & Memory Optimization
Master efficient attention mechanisms and GPU memory optimization for transformers
55 min read · Advanced
What is FlashAttention?
FlashAttention is an IO-aware, memory-efficient attention algorithm that avoids materializing the full attention matrix, reducing attention's extra memory footprint from O(N²) to O(N). Because it computes exactly the same output as standard attention, it enables training and inference on much longer sequences without sacrificing accuracy.
Memory Efficient
O(N) vs O(N²) complexity
IO Aware
Optimized for GPU hierarchy
Exact Attention
Same output as standard attention
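In practice you rarely need to hand-roll the kernel: PyTorch 2.x ships fused, memory-efficient attention behind torch.nn.functional.scaled_dot_product_attention, which dispatches to a FlashAttention-style backend when the GPU, dtype, and shapes allow it. A minimal sketch (the shapes and device below are illustrative assumptions, not requirements):
import torch
import torch.nn.functional as F

# [batch, heads, seq_len, head_dim] -- the layout the fused kernel expects
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)

# PyTorch selects a fused (FlashAttention-style) kernel when one is available,
# so the N×N attention matrix is never materialized in GPU memory.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 4096, 64])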
🧮 Attention Memory Calculator
[Interactive calculator: compares memory usage and performance between standard and FlashAttention implementations as you vary the settings. At the defaults it reports roughly 64.0 MB for standard attention vs 4.0 MB for FlashAttention (a ~93.8% memory reduction) and a maximum practical sequence length of 8192.]
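To see where figures like these come from, here is a back-of-the-envelope estimate. It is a rough sketch: the exact numbers depend on batch size, head count, dtype, and kernel details, so treat it as an approximation rather than a measurement.
def attention_memory_mb(batch, heads, seq_len, head_dim, bytes_per_el=2):
    # Standard attention materializes a [batch, heads, N, N] score matrix
    standard = batch * heads * seq_len * seq_len * bytes_per_el
    # FlashAttention keeps only O(N) extra state (output tiles plus running max/sum)
    flash = batch * heads * seq_len * (head_dim + 2) * bytes_per_el
    return standard / 1e6, flash / 1e6

std_mb, flash_mb = attention_memory_mb(batch=1, heads=8, seq_len=4096, head_dim=64)
print(f"standard ~{std_mb:.1f} MB vs flash ~{flash_mb:.1f} MB")  # ~268 MB vs ~4 MB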
FlashAttention Algorithm
❌ Standard Attention Problem
- Materializes the full N×N attention matrix
- O(N²) memory complexity
- Memory bottleneck for long sequences
- Multiple memory reads/writes
- GPU memory fragmentation
✅ FlashAttention Solution
- Tile-based computation
- O(N) memory complexity
- Online softmax computation (sketched below)
- Optimized memory access patterns
- No attention matrix materialization
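The piece that makes tiling work is the online (streaming) softmax: a softmax over a long row can be accumulated block by block using only a running maximum and a running sum, so no block ever needs to see the whole row. A small numerical sketch in plain PyTorch:
import torch

x = torch.randn(1024)            # one row of attention scores
m = torch.tensor(-float('inf'))  # running maximum
l = torch.tensor(0.0)            # running sum of exp(x - m)

for block in x.split(128):       # process the row in tiles
    m_new = torch.maximum(m, block.max())
    # rescale the old sum to the new maximum, then add this block's contribution
    l = l * torch.exp(m - m_new) + torch.exp(block - m_new).sum()
    m = m_new

blockwise_softmax = torch.exp(x - m) / l
print(torch.allclose(blockwise_softmax, torch.softmax(x, dim=0), atol=1e-6))  # True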
Core Algorithm
FlashAttention Implementation
import torch
import torch.nn as nn
import math
class FlashAttention(nn.Module):
def __init__(self, dropout=0.0):
super().__init__()
self.dropout = dropout
self.scale = None
    def forward(self, q, k, v, block_size=128, causal=False):
"""
FlashAttention forward pass
Args:
q, k, v: Query, Key, Value tensors [batch, heads, seq_len, head_dim]
            block_size: Tile size for memory efficiency
            causal: If True, apply an autoregressive (causal) mask
"""
batch_size, num_heads, seq_len, head_dim = q.shape
self.scale = 1.0 / math.sqrt(head_dim)
# Initialize output and statistics
O = torch.zeros_like(q)
l = torch.zeros(batch_size, num_heads, seq_len, 1, device=q.device)
m = torch.full((batch_size, num_heads, seq_len, 1), -float('inf'), device=q.device)
# Number of blocks
num_blocks = math.ceil(seq_len / block_size)
for i in range(num_blocks):
# Query block indices
q_start = i * block_size
q_end = min((i + 1) * block_size, seq_len)
# Extract query block
q_i = q[:, :, q_start:q_end, :] # [B, H, block_size, D]
O_i = O[:, :, q_start:q_end, :]
l_i = l[:, :, q_start:q_end, :]
m_i = m[:, :, q_start:q_end, :]
for j in range(num_blocks):
# Key-Value block indices
kv_start = j * block_size
kv_end = min((j + 1) * block_size, seq_len)
# Extract key-value blocks
k_j = k[:, :, kv_start:kv_end, :] # [B, H, block_size, D]
v_j = v[:, :, kv_start:kv_end, :] # [B, H, block_size, D]
# Compute attention scores for this block
S_ij = torch.matmul(q_i, k_j.transpose(-2, -1)) * self.scale
                # Apply causal mask if needed (for autoregressive models)
                if causal:
                    if kv_start > q_end - 1:
                        continue  # this key block lies entirely in the future
                    if kv_end > q_start + 1:
                        # Block straddles the diagonal: mask key positions ahead of each query
                        row_idx = torch.arange(q_start, q_end, device=q.device).view(-1, 1)
                        col_idx = torch.arange(kv_start, kv_end, device=q.device).view(1, -1)
                        S_ij = S_ij.masked_fill(col_idx > row_idx, -float('inf'))
# Online softmax computation
m_ij = torch.max(S_ij, dim=-1, keepdim=True)[0]
P_ij = torch.exp(S_ij - m_ij)
l_ij = torch.sum(P_ij, dim=-1, keepdim=True)
# Update running statistics
m_i_new = torch.maximum(m_i, m_ij)
alpha = torch.exp(m_i - m_i_new)
beta = torch.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
O_i_new = (alpha * l_i * O_i + beta * torch.matmul(P_ij, v_j)) / l_i_new
# Update for next iteration
O_i = O_i_new
l_i = l_i_new
m_i = m_i_new
# Store updated output block
O[:, :, q_start:q_end, :] = O_i
l[:, :, q_start:q_end, :] = l_i
m[:, :, q_start:q_end, :] = m_i
return O
# Usage example
flash_attn = FlashAttention()
B, H, N, D = 2, 8, 1024, 64
q = torch.randn(B, H, N, D, requires_grad=True)
k = torch.randn(B, H, N, D, requires_grad=True)
v = torch.randn(B, H, N, D, requires_grad=True)
# Memory-efficient attention computation
output = flash_attn(q, k, v, block_size=128)
print(f"Output shape: {output.shape}")
print(f"Memory efficient: O({N}) vs O({N}²)")Advanced Memory Optimization Techniques
Gradient Checkpointing
Memory-Efficient Training
import torch
import torch.nn as nn
import math
from torch.utils.checkpoint import checkpoint
class MemoryEfficientTransformerLayer(nn.Module):
def __init__(self, d_model, nhead, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
self.linear1 = nn.Linear(d_model, d_model * 4)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_model * 4, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward_chunk(self, x):
"""Forward pass for a chunk - used with gradient checkpointing"""
# Self-attention block
x2, _ = self.self_attn(x, x, x)
x = x + self.dropout1(x2)
x = self.norm1(x)
# Feed-forward block
x2 = self.linear2(self.dropout(torch.relu(self.linear1(x))))
x = x + self.dropout2(x2)
x = self.norm2(x)
return x
def forward(self, x, use_checkpoint=True):
if use_checkpoint and self.training:
# Use gradient checkpointing to save memory
            return checkpoint(self.forward_chunk, x, use_reentrant=False)
else:
return self.forward_chunk(x)
class MemoryEfficientTransformer(nn.Module):
def __init__(self, d_model, nhead, num_layers, max_seq_len, vocab_size):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = nn.Parameter(torch.randn(1, max_seq_len, d_model))
self.layers = nn.ModuleList([
MemoryEfficientTransformerLayer(d_model, nhead)
for _ in range(num_layers)
])
self.output_proj = nn.Linear(d_model, vocab_size)
def forward(self, x, use_checkpoint=True):
seq_len = x.size(1)
# Embedding + positional encoding
x = self.embedding(x) * math.sqrt(self.d_model)
x = x + self.pos_encoding[:, :seq_len, :]
# Apply transformer layers with optional checkpointing
for layer in self.layers:
x = layer(x, use_checkpoint=use_checkpoint)
return self.output_proj(x)
# Memory usage comparison
def compare_memory_usage():
model = MemoryEfficientTransformer(
d_model=768, nhead=12, num_layers=12,
max_seq_len=2048, vocab_size=50000
)
    x = torch.randint(0, 50000, (4, 1024))  # Batch of token-id sequences
    # The measurements below require a CUDA device
    model = model.cuda()
    x = x.cuda()
    # Measure memory with checkpointing
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.cuda.device(0):
output_checkpoint = model(x, use_checkpoint=True)
loss_checkpoint = output_checkpoint.sum()
loss_checkpoint.backward()
memory_checkpoint = torch.cuda.max_memory_allocated() / 1e9
# Measure memory without checkpointing
model.zero_grad()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
with torch.cuda.device(0):
output_normal = model(x, use_checkpoint=False)
loss_normal = output_normal.sum()
loss_normal.backward()
memory_normal = torch.cuda.max_memory_allocated() / 1e9
print(f"Memory with checkpointing: {memory_checkpoint:.2f} GB")
print(f"Memory without checkpointing: {memory_normal:.2f} GB")
print(f"Memory reduction: {((memory_normal - memory_checkpoint) / memory_normal * 100):.1f}%")
# compare_memory_usage()
Dynamic Sequence Packing
Efficient Batch Processing
import torch
import torch.nn.functional as F
from typing import List, Tuple
class SequencePacker:
"""Pack variable-length sequences efficiently for memory optimization"""
def __init__(self, max_seq_len: int = 2048):
self.max_seq_len = max_seq_len
    def pack_sequences(self, sequences: List[torch.Tensor],
                       attention_masks: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]]]:
        """
        Pack multiple sequences into efficient batches.
        Returns packed sequences, attention masks, and per-batch split indices.
        """
        total_length = sum(seq.size(0) for seq in sequences)
        if total_length <= self.max_seq_len:
            # Everything fits in a single packed batch
            packed_seq = torch.cat(sequences, dim=0)
            packed_mask = torch.cat(attention_masks, dim=0)
            split_indices = [[seq.size(0) for seq in sequences]]
            return packed_seq.unsqueeze(0), packed_mask.unsqueeze(0), split_indices
# Multiple packed sequences needed
packed_sequences = []
packed_masks = []
all_split_indices = []
current_batch = []
current_masks = []
current_length = 0
for seq, mask in zip(sequences, attention_masks):
seq_len = seq.size(0)
if current_length + seq_len > self.max_seq_len and current_batch:
# Finalize current batch
batch_seq = torch.cat(current_batch, dim=0)
batch_mask = torch.cat(current_masks, dim=0)
# Pad to max_seq_len
padding_len = self.max_seq_len - batch_seq.size(0)
if padding_len > 0:
batch_seq = F.pad(batch_seq, (0, 0, 0, padding_len))
batch_mask = F.pad(batch_mask, (0, padding_len))
packed_sequences.append(batch_seq)
packed_masks.append(batch_mask)
all_split_indices.append([s.size(0) for s in current_batch])
# Start new batch
current_batch = [seq]
current_masks = [mask]
current_length = seq_len
else:
current_batch.append(seq)
current_masks.append(mask)
current_length += seq_len
# Handle remaining sequences
if current_batch:
batch_seq = torch.cat(current_batch, dim=0)
batch_mask = torch.cat(current_masks, dim=0)
padding_len = self.max_seq_len - batch_seq.size(0)
if padding_len > 0:
batch_seq = F.pad(batch_seq, (0, 0, 0, padding_len))
batch_mask = F.pad(batch_mask, (0, padding_len))
packed_sequences.append(batch_seq)
packed_masks.append(batch_mask)
all_split_indices.append([s.size(0) for s in current_batch])
return (torch.stack(packed_sequences),
torch.stack(packed_masks),
all_split_indices)
def unpack_sequences(self, packed_output: torch.Tensor,
split_indices: List[List[int]]) -> List[torch.Tensor]:
"""Unpack the model outputs back to individual sequences"""
all_sequences = []
for batch_idx, batch_splits in enumerate(split_indices):
batch_output = packed_output[batch_idx]
start_idx = 0
for seq_len in batch_splits:
seq_output = batch_output[start_idx:start_idx + seq_len]
all_sequences.append(seq_output)
start_idx += seq_len
return all_sequences
# Memory-efficient training with sequence packing
class PackedDataLoader:
def __init__(self, sequences, batch_size=8, max_seq_len=2048):
self.sequences = sequences
self.batch_size = batch_size
self.packer = SequencePacker(max_seq_len)
def __iter__(self):
# Sort sequences by length for better packing
sorted_sequences = sorted(self.sequences, key=lambda x: x[0].size(0))
batch_sequences = []
batch_masks = []
for seq, mask in sorted_sequences:
batch_sequences.append(seq)
batch_masks.append(mask)
if len(batch_sequences) >= self.batch_size:
# Pack and yield batch
packed_seq, packed_mask, split_indices = self.packer.pack_sequences(
batch_sequences, batch_masks
)
yield packed_seq, packed_mask, split_indices
batch_sequences = []
batch_masks = []
# Handle remaining sequences
if batch_sequences:
packed_seq, packed_mask, split_indices = self.packer.pack_sequences(
batch_sequences, batch_masks
)
yield packed_seq, packed_mask, split_indices
# Usage example
sequences = [torch.randn(200, 768), torch.randn(150, 768), torch.randn(300, 768)]
masks = [torch.ones(200), torch.ones(150), torch.ones(300)]
packer = SequencePacker(max_seq_len=512)
packed_seq, packed_mask, split_indices = packer.pack_sequences(sequences, masks)
print(f"Original sequences: {[s.shape[0] for s in sequences]}")
print(f"Packed shape: {packed_seq.shape}")
print(f"Memory efficiency: {sum(s.size(0) for s in sequences) / (packed_seq.size(0) * packed_seq.size(1)):.2f}")Production FlashAttention Service
High-Performance Attention Server
production_flash_attention.py
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple, Dict, Any
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from contextlib import asynccontextmanager
import logging
import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AttentionRequest(BaseModel):
query: list # [batch_size, seq_len, d_model]
key: list # [batch_size, seq_len, d_model]
value: list # [batch_size, seq_len, d_model]
num_heads: int = 8
block_size: int = 128
use_flash: bool = True
causal_mask: bool = False
class AttentionResponse(BaseModel):
output: list
processing_time: float
memory_used: float
sequence_length: int
attention_type: str
class OptimizedFlashAttention(nn.Module):
"""Production-optimized FlashAttention implementation"""
def __init__(self):
super().__init__()
self.attention_cache = {}
self.max_cache_size = 100
def _get_cache_key(self, shape: Tuple[int, ...], dtype: torch.dtype, device: str) -> str:
return f"{shape}_{dtype}_{device}"
def _get_cached_tensors(self, q: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Get cached intermediate tensors to avoid reallocation"""
key = self._get_cache_key(q.shape, q.dtype, str(q.device))
if key not in self.attention_cache:
if len(self.attention_cache) >= self.max_cache_size:
# Remove oldest entry
oldest_key = next(iter(self.attention_cache))
del self.attention_cache[oldest_key]
batch_size, num_heads, seq_len, head_dim = q.shape
self.attention_cache[key] = {
'O': torch.zeros_like(q),
'l': torch.zeros(batch_size, num_heads, seq_len, 1, device=q.device, dtype=q.dtype),
'm': torch.full((batch_size, num_heads, seq_len, 1), -float('inf'),
device=q.device, dtype=q.dtype)
}
return self.attention_cache[key]
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
block_size: int = 128, causal_mask: bool = False) -> torch.Tensor:
"""
Optimized FlashAttention forward pass
"""
batch_size, num_heads, seq_len, head_dim = q.shape
scale = 1.0 / math.sqrt(head_dim)
# Get cached tensors
cached = self._get_cached_tensors(q)
O = cached['O'].zero_()
l = cached['l'].zero_()
m = cached['m'].fill_(-float('inf'))
num_blocks = (seq_len + block_size - 1) // block_size
        # Use torch.compile for better performance (PyTorch 2.0+). Defining the
        # compiled function inside forward() keeps the example readable; a real
        # service would compile it once at initialization instead.
        @torch.compile
        def compute_block_attention(q_block, k_block, v_block, m_prev, l_prev, O_prev, mask):
            # Compute attention scores
            scores = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale
            # Apply the intra-block causal mask when one is provided
            if mask is not None:
                scores = scores.masked_fill(mask, -float('inf'))
            # Numerically stable online softmax update
            m_curr = torch.max(scores, dim=-1, keepdim=True)[0]
            m_new = torch.maximum(m_prev, m_curr)
            alpha = torch.exp(m_prev - m_new)
            beta = torch.exp(m_curr - m_new)
            p_curr = torch.exp(scores - m_new)
            l_curr = torch.sum(p_curr, dim=-1, keepdim=True)
            l_new = alpha * l_prev + beta * l_curr
            O_new = (alpha * l_prev * O_prev + beta * torch.matmul(p_curr, v_block)) / l_new
            return O_new, l_new, m_new
        # Process blocks
        for i in range(num_blocks):
            q_start, q_end = i * block_size, min((i + 1) * block_size, seq_len)
            for j in range(num_blocks):
                if causal_mask and j > i:
                    continue  # Skip key blocks that lie entirely in the future
                kv_start, kv_end = j * block_size, min((j + 1) * block_size, seq_len)
                q_block = q[:, :, q_start:q_end, :]
                k_block = k[:, :, kv_start:kv_end, :]
                v_block = v[:, :, kv_start:kv_end, :]
                # Only the diagonal block (i == j) straddles the causal boundary;
                # blocks with j < i lie entirely in the past and need no mask
                block_mask = None
                if causal_mask and i == j:
                    block_mask = torch.triu(
                        torch.ones(q_end - q_start, kv_end - kv_start,
                                   dtype=torch.bool, device=q.device),
                        diagonal=1
                    )
                O_block, l_block, m_block = compute_block_attention(
                    q_block, k_block, v_block,
                    m[:, :, q_start:q_end, :],
                    l[:, :, q_start:q_end, :],
                    O[:, :, q_start:q_end, :],
                    block_mask
                )
O[:, :, q_start:q_end, :] = O_block
l[:, :, q_start:q_end, :] = l_block
m[:, :, q_start:q_end, :] = m_block
return O
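# Note: the block-wise Python loop above is written for clarity, not raw speed.
# In a real deployment you would typically call a fused kernel instead -- e.g.
# torch.nn.functional.scaled_dot_product_attention (which can dispatch to a
# FlashAttention backend on supported GPUs) or the flash-attn package -- and keep
# this implementation as a readable reference.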
class AttentionService:
"""Production attention service with monitoring and optimization"""
def __init__(self):
self.flash_attention = OptimizedFlashAttention()
self.standard_attention = nn.MultiheadAttention(
embed_dim=768, num_heads=8, batch_first=True
)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.flash_attention.to(self.device)
self.standard_attention.to(self.device)
# Performance monitoring
self.request_count = 0
self.total_processing_time = 0.0
self.memory_usage_history = []
async def process_attention(self, request: AttentionRequest) -> AttentionResponse:
"""Process attention request with monitoring"""
start_time = time.time()
try:
            # Convert inputs to tensors (fp16 on GPU for memory savings, fp32 on CPU)
            dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
            q = torch.tensor(request.query, device=self.device, dtype=dtype)
            k = torch.tensor(request.key, device=self.device, dtype=dtype)
            v = torch.tensor(request.value, device=self.device, dtype=dtype)
batch_size, seq_len, d_model = q.shape
# Reshape for multi-head attention
head_dim = d_model // request.num_heads
q = q.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)
# Memory monitoring
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Process attention
with torch.no_grad():
if request.use_flash:
output = self.flash_attention(q, k, v, request.block_size, request.causal_mask)
attention_type = "FlashAttention"
                else:
                    # Standard attention fallback. Note: this module was created with
                    # embed_dim=768 and fp32 weights, so cast back to fp32 here and
                    # only use this path for d_model == 768 requests.
                    q_std = q.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model).float()
                    k_std = k.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model).float()
                    v_std = v.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model).float()
                    output, _ = self.standard_attention(q_std, k_std, v_std)
                    output = output.view(batch_size, seq_len, request.num_heads, head_dim).transpose(1, 2)
                    attention_type = "StandardAttention"
# Reshape output back
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
# Memory usage calculation
memory_used = 0.0
if torch.cuda.is_available():
memory_used = torch.cuda.max_memory_allocated() / 1e6 # MB
processing_time = time.time() - start_time
# Update monitoring metrics
self.request_count += 1
self.total_processing_time += processing_time
self.memory_usage_history.append(memory_used)
if len(self.memory_usage_history) > 1000:
self.memory_usage_history.pop(0)
logger.info(f"Processed {attention_type} request: {seq_len} tokens, "
f"{processing_time:.3f}s, {memory_used:.1f}MB")
return AttentionResponse(
output=output.cpu().tolist(),
processing_time=processing_time,
memory_used=memory_used,
sequence_length=seq_len,
attention_type=attention_type
)
except Exception as e:
logger.error(f"Error processing attention request: {e}")
raise HTTPException(status_code=500, detail=str(e))
def get_stats(self) -> Dict[str, Any]:
"""Get service performance statistics"""
avg_time = self.total_processing_time / max(1, self.request_count)
avg_memory = sum(self.memory_usage_history) / max(1, len(self.memory_usage_history))
return {
"total_requests": self.request_count,
"average_processing_time": avg_time,
"average_memory_usage": avg_memory,
"device": str(self.device)
}
# Initialize service
attention_service = AttentionService()
# FastAPI setup
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting FlashAttention service")
logger.info(f"Using device: {attention_service.device}")
yield
logger.info("Shutting down FlashAttention service")
app = FastAPI(title="FlashAttention Service", lifespan=lifespan)
@app.post("/attention", response_model=AttentionResponse)
async def process_attention_endpoint(request: AttentionRequest):
"""Process attention computation request"""
return await attention_service.process_attention(request)
@app.get("/stats")
async def get_stats():
"""Get service performance statistics"""
return attention_service.get_stats()
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "device": str(attention_service.device)}
if __name__ == "__main__":
uvicorn.run(
"production_flash_attention:app",
host="0.0.0.0",
port=8000,
workers=1, # Single worker for GPU usage
log_level="info"
    )
Real-World Examples
OpenAI GPT-4
Reportedly uses FlashAttention-2 for efficient training of very large models (parameter counts are not public) with 32K context windows.
- 10x memory reduction vs standard attention
- 32K token context length
- 50% faster training throughput
Anthropic Claude
Leverages memory-optimized attention for 100K+ token context understanding with constitutional AI training.
- 100K+ token context windows
- Memory-efficient constitutional AI training
- Real-time inference optimization
Meta LLaMA-2
Implements FlashAttention for open-source LLM training with optimized memory usage on commodity hardware.
- 70B parameter model training
- Commodity GPU compatibility
- 4x memory efficiency improvement
FlashAttention Best Practices
✅ Do's
- Use block sizes that are multiples of 64 for GPU efficiency
- Enable gradient checkpointing for very long sequences
- Use FP16/BF16 precision for memory savings (see the sketch after this list)
- Implement sequence packing for variable-length inputs
- Monitor GPU memory utilization during training
- Use async preprocessing for better throughput
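For the mixed-precision item above, here is a minimal autocast sketch. It assumes model, optimizer, and dataloader already exist, and uses bf16; whether bf16 or fp16 (with a GradScaler) is the better choice depends on your GPU generation.
import torch

for batch in dataloader:
    optimizer.zero_grad()
    # Run matmuls and attention in bf16 to roughly halve activation memory
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(batch).sum()
    loss.backward()   # parameters and optimizer state remain in fp32
    optimizer.step()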
❌ Don'ts
- Don't use very small block sizes (<64); they are inefficient
- Don't mix FlashAttention with standard attention layers
- Don't ignore memory fragmentation in long sequences
- Don't use FlashAttention for very short sequences
- Don't forget to clear the CUDA cache between runs
- Don't skip benchmarking on your specific hardware (see the timing sketch after this list)
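For the benchmarking point above, a minimal timing harness (a sketch: attn_fn stands for whichever attention implementation you are measuring, and q, k, v are GPU tensors shaped [batch, heads, seq_len, head_dim]):
import time
import torch

def benchmark(attn_fn, q, k, v, warmup=3, iters=10):
    for _ in range(warmup):
        attn_fn(q, k, v)               # warm up kernels / compilation caches
    torch.cuda.synchronize()           # finish queued work before timing
    start = time.perf_counter()
    for _ in range(iters):
        attn_fn(q, k, v)
    torch.cuda.synchronize()           # wait for the timed kernels to complete
    return (time.perf_counter() - start) / iters

# Example (uncomment once q, k, v exist):
# ms = benchmark(torch.nn.functional.scaled_dot_product_attention, q, k, v) * 1e3
# print(f"{ms:.2f} ms per call")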