Multi-Modal AI Systems
Master vision-language models, cross-modal fusion, and production deployment of multi-modal applications
60 min read • Advanced
What are Multi-Modal AI Systems?
Multi-modal AI systems process and understand information across multiple modalities (vision, text, audio, etc.) to create more comprehensive and context-aware artificial intelligence applications.
Cross-Modal
Understanding across modalities
Unified Embedding
Shared representation space
Rich Understanding
Context beyond single modality
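In practice, the shared embedding space is what makes cross-modal reasoning possible: images and text are projected into the same vector space and compared directly. As a minimal illustration, the hedged sketch below runs zero-shot classification with a pre-trained Hugging Face CLIP checkpoint; the image path and candidate labels are placeholders rather than anything prescribed by this guide.
Zero-Shot CLIP Sketch
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Pre-trained CLIP checkpoint (assumed available via Hugging Face)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image: one score per candidate label; softmax gives probabilities
probs = outputs.logits_per_image.softmax(dim=-1)[0]
for label, p in zip(labels, probs):
    print(f"{label}: {p.item():.3f}")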
🧮 Multi-Modal System Calculator
Calculate memory usage, processing complexity, and performance metrics for multi-modal models.
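As a rough stand-in for such a calculator, the sketch below estimates activation memory, latency, and throughput for a ViT-B/32-style vision tower paired with a 77-token text encoder. Every constant (layer count, bytes per activation, per-sample latency) is an illustrative assumption, not a measured value.
Back-of-Envelope Cost Estimate
def estimate_multimodal_cost(
    batch_size: int = 32,
    image_size: int = 224,
    patch_size: int = 32,
    text_len: int = 77,
    hidden_dim: int = 768,
    num_layers: int = 12,
    ms_per_sample: float = 25.0,  # assumed per-sample GPU latency
):
    """Back-of-envelope memory/throughput estimate for a CLIP-style model."""
    num_patches = (image_size // patch_size) ** 2 + 1  # +1 for the [CLS] token
    bytes_per_float = 4  # fp32 activations

    # Very rough upper bound: one activation tensor kept per layer per tower
    image_mem = batch_size * num_patches * hidden_dim * num_layers * bytes_per_float
    text_mem = batch_size * text_len * hidden_dim * num_layers * bytes_per_float

    latency_ms = batch_size * ms_per_sample
    return {
        "image_memory_mb": image_mem / 1e6,
        "text_memory_mb": text_mem / 1e6,
        "total_memory_mb": (image_mem + text_mem) / 1e6,
        "estimated_latency_ms": latency_ms,
        "throughput_samples_per_sec": 1000.0 * batch_size / latency_ms,
        "vision_text_ratio": num_patches / text_len,
    }

print(estimate_multimodal_cost())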
Core Multi-Modal Architectures
Early Fusion
- • Combine inputs at feature level
- • Shared encoder for both modalities
- • Lower computational cost
- • May lose modality-specific patterns
- • Good for tightly coupled tasks
Late Fusion
- • Process modalities separately
- • Combine at decision level
- • Preserves modality-specific features
- • Higher computational requirements
- • Better for loosely coupled tasks (see the early-vs-late fusion sketch below)
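The sketch below contrasts these two fusion styles in a few lines of PyTorch; module names, feature dimensions, and the logit-averaging combiner are illustrative choices, not the only options.
Early vs. Late Fusion Sketch
import torch
import torch.nn as nn

class EarlyFusion(nn.Module):
    """Concatenate modality features, then run one shared encoder."""
    def __init__(self, img_dim=512, txt_dim=512, hidden=512, num_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(img_dim + txt_dim, hidden), nn.GELU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, img_feat, txt_feat):
        # Feature-level fusion: a single network sees both modalities at once
        return self.encoder(torch.cat([img_feat, txt_feat], dim=-1))

class LateFusion(nn.Module):
    """Process modalities separately and combine at the decision level."""
    def __init__(self, img_dim=512, txt_dim=512, num_classes=2):
        super().__init__()
        self.img_head = nn.Linear(img_dim, num_classes)
        self.txt_head = nn.Linear(txt_dim, num_classes)

    def forward(self, img_feat, txt_feat):
        # Decision-level fusion: average the per-modality logits
        return 0.5 * (self.img_head(img_feat) + self.txt_head(txt_feat))

img_feat, txt_feat = torch.randn(4, 512), torch.randn(4, 512)
print(EarlyFusion()(img_feat, txt_feat).shape, LateFusion()(img_feat, txt_feat).shape)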
Cross-Modal Attention
- • Dynamic interaction between modalities
- • Attention-based alignment
- • Context-aware fusion
- • State-of-the-art performance
- • Higher computational complexity
Unified Embeddings
- • Shared representation space
- • Cross-modal retrieval capability
- • Contrastive learning approaches
- • Zero-shot transfer capabilities
- • Requires large-scale pre-training
CLIP: Contrastive Language-Image Pre-training
CLIP Architecture Implementation
CLIP Model Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor
import numpy as np
class CustomCLIP(nn.Module):
"""Custom CLIP implementation for vision-language understanding"""
def __init__(self,
vision_model_name: str = "openai/clip-vit-base-patch32",
temperature: float = 0.07):
super().__init__()
# Load pre-trained CLIP
self.clip_model = CLIPModel.from_pretrained(vision_model_name)
self.processor = CLIPProcessor.from_pretrained(vision_model_name)
        # Learnable logit scale, stored as log(1 / temperature) as in CLIP;
        # .exp() recovers 1/temperature when scaling the similarity logits
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
# Additional projection layers for fine-tuning
embed_dim = self.clip_model.config.projection_dim
self.vision_projection = nn.Linear(embed_dim, embed_dim)
self.text_projection = nn.Linear(embed_dim, embed_dim)
# Dropout for regularization
self.dropout = nn.Dropout(0.1)
def encode_image(self, images):
"""Encode images to embeddings"""
vision_outputs = self.clip_model.vision_model(pixel_values=images)
image_embeds = self.clip_model.visual_projection(vision_outputs.pooler_output)
# Additional projection
image_embeds = self.vision_projection(self.dropout(image_embeds))
# L2 normalize
image_embeds = F.normalize(image_embeds, p=2, dim=-1)
return image_embeds
def encode_text(self, texts):
"""Encode text to embeddings"""
text_outputs = self.clip_model.text_model(
input_ids=texts['input_ids'],
attention_mask=texts['attention_mask']
)
text_embeds = self.clip_model.text_projection(text_outputs.pooler_output)
# Additional projection
text_embeds = self.text_projection(self.dropout(text_embeds))
# L2 normalize
text_embeds = F.normalize(text_embeds, p=2, dim=-1)
return text_embeds
def forward(self, images, texts):
"""Forward pass with contrastive loss"""
# Get embeddings
image_embeds = self.encode_image(images)
text_embeds = self.encode_text(texts)
# Compute similarity matrix
logits_per_image = torch.matmul(image_embeds, text_embeds.t()) * self.temperature.exp()
logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text
def compute_contrastive_loss(self, logits_per_image, logits_per_text):
"""Compute symmetric contrastive loss"""
batch_size = logits_per_image.shape[0]
labels = torch.arange(batch_size, device=logits_per_image.device)
# Image-to-text and text-to-image losses
loss_img = F.cross_entropy(logits_per_image, labels)
loss_txt = F.cross_entropy(logits_per_text, labels)
return (loss_img + loss_txt) / 2
class MultiModalSearchEngine:
"""Production-ready multi-modal search engine"""
def __init__(self, model_path: str = None):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model_path:
self.model = torch.load(model_path, map_location=self.device)
else:
self.model = CustomCLIP().to(self.device)
self.model.eval()
# Image and text embeddings storage
self.image_embeddings = {}
self.text_embeddings = {}
def add_images(self, images, image_ids):
"""Add images to the search index"""
with torch.no_grad():
# Process images
processed = self.model.processor(images=images, return_tensors="pt")
pixel_values = processed['pixel_values'].to(self.device)
# Get embeddings
embeddings = self.model.encode_image(pixel_values)
# Store embeddings
for i, img_id in enumerate(image_ids):
self.image_embeddings[img_id] = embeddings[i].cpu().numpy()
def add_texts(self, texts, text_ids):
"""Add texts to the search index"""
with torch.no_grad():
# Process texts
processed = self.model.processor(text=texts, return_tensors="pt",
padding=True, truncation=True)
# Move to device
for key in processed:
processed[key] = processed[key].to(self.device)
# Get embeddings
embeddings = self.model.encode_text(processed)
# Store embeddings
for i, txt_id in enumerate(text_ids):
self.text_embeddings[txt_id] = embeddings[i].cpu().numpy()
def search_by_text(self, query_text, top_k=5, search_images=True):
"""Search images using text query"""
with torch.no_grad():
# Process query
processed = self.model.processor(text=[query_text], return_tensors="pt",
padding=True, truncation=True)
# Move to device
for key in processed:
processed[key] = processed[key].to(self.device)
# Get query embedding
query_embed = self.model.encode_text(processed)[0].cpu().numpy()
# Search in image embeddings
if search_images and self.image_embeddings:
similarities = []
for img_id, img_embed in self.image_embeddings.items():
sim = np.dot(query_embed, img_embed)
similarities.append((img_id, sim))
# Sort by similarity
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
else:
# Search in text embeddings
similarities = []
for txt_id, txt_embed in self.text_embeddings.items():
sim = np.dot(query_embed, txt_embed)
similarities.append((txt_id, sim))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
def search_by_image(self, query_image, top_k=5, search_texts=True):
"""Search texts using image query"""
with torch.no_grad():
# Process image
processed = self.model.processor(images=[query_image], return_tensors="pt")
pixel_values = processed['pixel_values'].to(self.device)
# Get query embedding
query_embed = self.model.encode_image(pixel_values)[0].cpu().numpy()
# Search in text embeddings
if search_texts and self.text_embeddings:
similarities = []
for txt_id, txt_embed in self.text_embeddings.items():
sim = np.dot(query_embed, txt_embed)
similarities.append((txt_id, sim))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
else:
# Search in image embeddings
similarities = []
for img_id, img_embed in self.image_embeddings.items():
sim = np.dot(query_embed, img_embed)
similarities.append((img_id, sim))
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
# Usage example
if __name__ == "__main__":
# Initialize search engine
search_engine = MultiModalSearchEngine()
# Add sample data
sample_texts = [
"A cat sitting on a windowsill",
"Mountain landscape with snow",
"City skyline at sunset"
]
text_ids = ["text_1", "text_2", "text_3"]
search_engine.add_texts(sample_texts, text_ids)
# Search by text
results = search_engine.search_by_text("feline animal indoors", top_k=3)
print(f"Search results: {results}")Vision-Language Transformers
Cross-Modal Transformer Architecture
Vision-Language Transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, ViTModel
import math
class CrossModalTransformer(nn.Module):
"""Cross-modal transformer for vision-language understanding"""
def __init__(self,
vision_model_name: str = "google/vit-base-patch16-224",
text_model_name: str = "bert-base-uncased",
hidden_dim: int = 768,
num_cross_layers: int = 4,
num_attention_heads: int = 12):
super().__init__()
# Vision and text encoders
self.vision_encoder = ViTModel.from_pretrained(vision_model_name)
self.text_encoder = BertModel.from_pretrained(text_model_name)
# Cross-modal attention layers
self.cross_attention_layers = nn.ModuleList([
CrossModalLayer(hidden_dim, num_attention_heads)
for _ in range(num_cross_layers)
])
# Projection layers
self.vision_proj = nn.Linear(self.vision_encoder.config.hidden_size, hidden_dim)
self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, hidden_dim)
# Task-specific heads
self.classifier = nn.Linear(hidden_dim * 2, 1) # For similarity/matching
self.dropout = nn.Dropout(0.1)
def forward(self,
pixel_values,
input_ids,
attention_mask=None,
return_cross_attention=False):
# Encode vision
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
vision_features = self.vision_proj(vision_outputs.last_hidden_state)
# Encode text
text_outputs = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
text_features = self.text_proj(text_outputs.last_hidden_state)
# Cross-modal interaction
cross_attentions = []
for layer in self.cross_attention_layers:
vision_features, text_features, cross_attn = layer(
vision_features, text_features, attention_mask
)
if return_cross_attention:
cross_attentions.append(cross_attn)
# Global pooling
vision_pooled = vision_features.mean(dim=1) # Average pool
text_pooled = text_features[:, 0, :] # CLS token
# Combine for classification
combined = torch.cat([vision_pooled, text_pooled], dim=-1)
logits = self.classifier(self.dropout(combined))
outputs = {'logits': logits}
if return_cross_attention:
outputs['cross_attentions'] = cross_attentions
return outputs
class CrossModalLayer(nn.Module):
"""Single cross-modal attention layer"""
def __init__(self, hidden_dim: int, num_heads: int):
super().__init__()
# Vision-to-text attention
self.v2t_attention = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
batch_first=True
)
# Text-to-vision attention
self.t2v_attention = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
batch_first=True
)
# Feed-forward networks
self.vision_ffn = FeedForward(hidden_dim)
self.text_ffn = FeedForward(hidden_dim)
# Layer norms
self.vision_norm1 = nn.LayerNorm(hidden_dim)
self.vision_norm2 = nn.LayerNorm(hidden_dim)
self.text_norm1 = nn.LayerNorm(hidden_dim)
self.text_norm2 = nn.LayerNorm(hidden_dim)
def forward(self, vision_features, text_features, text_mask=None):
batch_size, seq_len, hidden_dim = vision_features.shape
# Vision-to-text cross-attention
v2t_out, v2t_weights = self.v2t_attention(
query=text_features,
key=vision_features,
value=vision_features
)
text_features = self.text_norm1(text_features + v2t_out)
# Text-to-vision cross-attention
t2v_out, t2v_weights = self.t2v_attention(
query=vision_features,
key=text_features,
value=text_features,
            # key_padding_mask expects True at padded positions
            key_padding_mask=(text_mask == 0) if text_mask is not None else None
)
vision_features = self.vision_norm1(vision_features + t2v_out)
# Feed-forward
vision_features = self.vision_norm2(
vision_features + self.vision_ffn(vision_features)
)
text_features = self.text_norm2(
text_features + self.text_ffn(text_features)
)
return vision_features, text_features, (v2t_weights, t2v_weights)
class FeedForward(nn.Module):
"""Feed-forward network with GELU activation"""
def __init__(self, hidden_dim: int, ff_dim: int = None):
super().__init__()
ff_dim = ff_dim or 4 * hidden_dim
self.linear1 = nn.Linear(hidden_dim, ff_dim)
self.linear2 = nn.Linear(ff_dim, hidden_dim)
self.dropout = nn.Dropout(0.1)
self.activation = nn.GELU()
def forward(self, x):
return self.linear2(self.dropout(self.activation(self.linear1(x))))
class MultiModalTrainer:
"""Training utilities for multi-modal models"""
def __init__(self, model, optimizer, device):
self.model = model
self.optimizer = optimizer
self.device = device
self.model.to(device)
def train_step(self, batch):
"""Single training step"""
self.model.train()
self.optimizer.zero_grad()
# Move batch to device
pixel_values = batch['pixel_values'].to(self.device)
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
# Forward pass
outputs = self.model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask
)
# Compute loss
loss = F.binary_cross_entropy_with_logits(
outputs['logits'].squeeze(-1),
labels.float()
)
# Backward pass
loss.backward()
self.optimizer.step()
return loss.item()
def evaluate(self, dataloader):
"""Evaluate model on validation set"""
self.model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch in dataloader:
pixel_values = batch['pixel_values'].to(self.device)
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
labels = batch['labels'].to(self.device)
outputs = self.model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask
)
loss = F.binary_cross_entropy_with_logits(
outputs['logits'].squeeze(-1),
labels.float()
)
total_loss += loss.item()
# Calculate accuracy
predictions = torch.sigmoid(outputs['logits']).squeeze(-1) > 0.5
correct += (predictions == labels).sum().item()
total += labels.size(0)
return {
'loss': total_loss / len(dataloader),
'accuracy': correct / total
        }
Production Multi-Modal Application
Multi-Modal Content Understanding Service
production_multimodal_service.py
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from typing import List, Optional, Dict, Any
import torch
from PIL import Image
import asyncio
import base64
from io import BytesIO
import logging
from dataclasses import dataclass
from contextlib import asynccontextmanager
import uvicorn
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class MultiModalRequest:
image: Optional[str] = None # base64 encoded
text: Optional[str] = None
task: str = "similarity" # similarity, classification, generation, search
    options: Optional[Dict[str, Any]] = None
@dataclass
class MultiModalResponse:
task: str
results: Dict[str, Any]
confidence: float
processing_time_ms: float
model_version: str
class MultiModalService:
"""Production multi-modal AI service"""
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.clip_model = None
self.vl_transformer = None
self.load_models()
# Performance monitoring
self.request_count = 0
self.total_processing_time = 0.0
def load_models(self):
"""Load pre-trained models"""
try:
# Load CLIP for general vision-language tasks
from transformers import CLIPModel, CLIPProcessor
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Load custom cross-modal transformer if available
# self.vl_transformer = CrossModalTransformer()
# self.vl_transformer.load_state_dict(torch.load("path/to/model.pth"))
self.clip_model.to(self.device)
# self.vl_transformer.to(self.device)
logger.info(f"Models loaded successfully on {self.device}")
except Exception as e:
logger.error(f"Failed to load models: {e}")
raise
async def process_similarity(self, image: Image.Image, text: str) -> Dict[str, Any]:
"""Compute image-text similarity"""
try:
with torch.no_grad():
# Process inputs
inputs = self.clip_processor(
text=[text],
images=[image],
return_tensors="pt",
padding=True
)
# Move to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get outputs
outputs = self.clip_model(**inputs)
                # Cosine similarity between the L2-normalized CLIP embeddings
                # (softmax over a single text candidate would always return 1.0)
                similarity = (outputs.image_embeds @ outputs.text_embeds.t())[0, 0].item()
return {
"similarity_score": similarity,
"interpretation": self._interpret_similarity(similarity)
}
except Exception as e:
logger.error(f"Similarity processing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
async def process_classification(self, image: Image.Image,
categories: List[str]) -> Dict[str, Any]:
"""Classify image into given categories"""
try:
with torch.no_grad():
# Create text prompts for categories
text_prompts = [f"a photo of {category}" for category in categories]
# Process inputs
inputs = self.clip_processor(
text=text_prompts,
images=[image],
return_tensors="pt",
padding=True
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get outputs
outputs = self.clip_model(**inputs)
# Get probabilities
logits = outputs.logits_per_image[0]
probs = torch.softmax(logits, dim=0)
# Create results
                # Build results, then rank by confidence
                results = [
                    {"category": category, "confidence": prob.item()}
                    for category, prob in zip(categories, probs)
                ]
                results.sort(key=lambda x: x["confidence"], reverse=True)
                for rank, item in enumerate(results, start=1):
                    item["rank"] = rank
return {
"predictions": results,
"top_prediction": results[0]["category"],
"confidence": results[0]["confidence"]
}
except Exception as e:
logger.error(f"Classification processing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
async def process_search(self, query_image: Image.Image,
text_candidates: List[str]) -> Dict[str, Any]:
"""Search for most relevant text given an image"""
try:
with torch.no_grad():
# Process all inputs
inputs = self.clip_processor(
text=text_candidates,
images=[query_image],
return_tensors="pt",
padding=True
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Get outputs
outputs = self.clip_model(**inputs)
# Get similarities
logits = outputs.logits_per_image[0]
similarities = torch.softmax(logits, dim=0)
# Create ranked results
                # Build results, then rank by similarity
                results = [
                    {"text": text, "similarity": sim.item()}
                    for text, sim in zip(text_candidates, similarities)
                ]
                results.sort(key=lambda x: x["similarity"], reverse=True)
                for rank, item in enumerate(results, start=1):
                    item["rank"] = rank
return {
"ranked_results": results,
"best_match": results[0]["text"],
"best_similarity": results[0]["similarity"]
}
except Exception as e:
logger.error(f"Search processing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
    def _interpret_similarity(self, score: float) -> str:
        """Interpret a CLIP cosine-similarity score.

        Raw CLIP image-text cosine similarities are typically small
        (well-matched pairs often land around 0.25-0.35), so these buckets
        are calibrated to that range rather than to a 0-1 scale.
        """
        if score > 0.35:
            return "Very high similarity"
        elif score > 0.30:
            return "High similarity"
        elif score > 0.25:
            return "Moderate similarity"
        elif score > 0.20:
            return "Low similarity"
        else:
            return "Very low similarity"
def _decode_base64_image(self, image_data: str) -> Image.Image:
"""Decode base64 image"""
try:
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))
return image.convert('RGB')
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image data: {e}")
async def process_request(self, request: MultiModalRequest) -> MultiModalResponse:
"""Main request processing function"""
import time
start_time = time.time()
try:
# Decode image if provided
image = None
if request.image:
image = self._decode_base64_image(request.image)
# Process based on task type
if request.task == "similarity":
if not image or not request.text:
raise HTTPException(
status_code=400,
detail="Both image and text required for similarity task"
)
results = await self.process_similarity(image, request.text)
elif request.task == "classification":
if not image:
raise HTTPException(
status_code=400,
detail="Image required for classification task"
)
                categories = (request.options or {}).get("categories", ["object", "animal", "person"])
results = await self.process_classification(image, categories)
elif request.task == "search":
if not image:
raise HTTPException(
status_code=400,
detail="Image required for search task"
)
                candidates = (request.options or {}).get("candidates", [])
                if not candidates:
                    raise HTTPException(
                        status_code=400,
                        detail="Non-empty candidate list required for search task"
                    )
                results = await self.process_search(image, candidates)
else:
raise HTTPException(
status_code=400,
detail=f"Unsupported task: {request.task}"
)
processing_time = (time.time() - start_time) * 1000
# Update monitoring metrics
self.request_count += 1
self.total_processing_time += processing_time
# Calculate confidence based on results
confidence = self._calculate_confidence(results, request.task)
return MultiModalResponse(
task=request.task,
results=results,
confidence=confidence,
processing_time_ms=processing_time,
model_version="clip-vit-base-patch32"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Request processing error: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _calculate_confidence(self, results: Dict[str, Any], task: str) -> float:
"""Calculate overall confidence for the response"""
if task == "similarity":
return results.get("similarity_score", 0.0)
elif task == "classification":
return results.get("confidence", 0.0)
elif task == "search":
return results.get("best_similarity", 0.0)
return 0.5
def get_stats(self) -> Dict[str, Any]:
"""Get service performance statistics"""
avg_time = (
self.total_processing_time / max(1, self.request_count)
)
return {
"total_requests": self.request_count,
"average_processing_time_ms": avg_time,
"device": str(self.device),
"models_loaded": ["CLIP"]
}
# Initialize service
multimodal_service = MultiModalService()
# FastAPI setup
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting Multi-Modal AI Service")
yield
logger.info("Shutting down Multi-Modal AI Service")
app = FastAPI(
title="Multi-Modal AI Service",
description="Production-ready multi-modal AI API",
version="1.0.0",
lifespan=lifespan
)
@app.post("/process", response_model=MultiModalResponse)
async def process_multimodal(
image: Optional[UploadFile] = File(None),
text: Optional[str] = Form(None),
task: str = Form("similarity"),
categories: Optional[str] = Form(None),
candidates: Optional[str] = Form(None)
):
"""Process multi-modal request"""
# Prepare request
image_data = None
if image:
image_bytes = await image.read()
image_data = base64.b64encode(image_bytes).decode()
options = {}
if categories:
options["categories"] = categories.split(",")
if candidates:
options["candidates"] = candidates.split("\n")
request = MultiModalRequest(
image=image_data,
text=text,
task=task,
options=options
)
return await multimodal_service.process_request(request)
@app.get("/stats")
async def get_service_stats():
"""Get service performance statistics"""
return multimodal_service.get_stats()
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"device": str(multimodal_service.device),
"models": "loaded"
}
if __name__ == "__main__":
uvicorn.run(
"production_multimodal_service:app",
host="0.0.0.0",
port=8000,
workers=1
    )
Real-World Examples
OpenAI GPT-4V
Large-scale vision-language model with advanced reasoning capabilities across image and text modalities.
- • 100B+ parameter multi-modal model
- • Chart/diagram understanding
- • Visual reasoning and QA
Google Bard/Gemini
Multi-modal conversational AI supporting text, image, and code understanding with real-time web access.
- • Real-time image analysis
- • Multi-turn conversations
- • Code + visual debugging
Meta LLaMA-V
Open-source vision-language model for research and production deployment with efficient architectures.
- • 7B-70B parameter models
- • Open-source availability
- • Efficient inference optimizations
Multi-Modal Best Practices
✅ Do's
- •Use pre-trained vision and language encoders
- •Implement cross-modal attention for rich interactions
- •Balance modality-specific and shared representations
- •Use contrastive learning for alignment
- •Implement proper data augmentation strategies
- •Monitor cross-modal alignment quality (e.g., retrieval recall@k; see the sketch after these lists)
❌ Don'ts
- •Don't ignore modality-specific preprocessing
- •Don't use simple concatenation for fusion
- •Don't neglect computational efficiency
- •Don't skip cross-modal evaluation metrics
- •Don't ignore domain adaptation challenges
- •Don't underestimate data quality importance
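One way to follow the "monitor cross-modal alignment quality" advice above is to track image-text retrieval recall@k on a held-out set of paired embeddings. The helper below is a minimal sketch; it assumes L2-normalized embeddings shaped [N, D] where row i of each matrix forms a matched pair.
Alignment Monitoring Sketch
import torch
import torch.nn.functional as F

def retrieval_recall_at_k(image_embeds, text_embeds, k=5):
    """Fraction of queries whose matching pair appears in the top-k results."""
    sims = image_embeds @ text_embeds.t()        # [N, N] cosine similarities
    targets = torch.arange(sims.size(0))
    topk_i2t = sims.topk(k, dim=1).indices       # image -> text retrieval
    topk_t2i = sims.t().topk(k, dim=1).indices   # text -> image retrieval
    r_i2t = (topk_i2t == targets[:, None]).any(dim=1).float().mean().item()
    r_t2i = (topk_t2i == targets[:, None]).any(dim=1).float().mean().item()
    return {"image_to_text_recall": r_i2t, "text_to_image_recall": r_t2i}

# Smoke test with random embeddings; in practice, use encoder outputs
# (e.g., CustomCLIP.encode_image / encode_text) on a validation set.
img = F.normalize(torch.randn(100, 512), dim=-1)
txt = F.normalize(torch.randn(100, 512), dim=-1)
print(retrieval_recall_at_k(img, txt, k=5))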