Neural Architecture Search (NAS)
Master automated neural network design through advanced search techniques and optimization methods
What is Neural Architecture Search?
Neural Architecture Search (NAS) is an automated machine learning technique that searches for optimal neural network architectures. Instead of manually designing networks, NAS algorithms explore the space of possible architectures to find designs that maximize performance while considering constraints like latency, memory, and energy efficiency.
Automated Design
Reduces human bias and explores far larger design spaces than manual engineering
Multi-Objective
Optimizes for accuracy, efficiency, and deployment constraints
Domain Agnostic
Works across vision, NLP, speech, and other domains
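Regardless of the specific method, every NAS system runs the same basic loop: propose candidate architectures from a defined search space, estimate their quality (typically with a shortened training run or a cheap proxy), and feed those results back into the search strategy. The sketch below is a minimal, method-agnostic illustration; `search_space.propose` and `evaluate` are placeholder hooks, not the API of any particular library.

```python
# Minimal, method-agnostic NAS loop (illustrative only; propose/evaluate are placeholders)
def neural_architecture_search(search_space, evaluate, n_iterations=100):
    history = []                                  # (architecture, score) pairs seen so far
    best_arch, best_score = None, float('-inf')

    for _ in range(n_iterations):
        # 1. Propose a candidate (randomly, by mutation, by a learned controller, ...)
        candidate = search_space.propose(history)
        # 2. Score it, e.g. short-training accuracy minus a latency/memory penalty
        score = evaluate(candidate)
        # 3. Record the outcome so the search strategy can exploit it next round
        history.append((candidate, score))
        if score > best_score:
            best_arch, best_score = candidate, score

    return best_arch
```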
NAS Search Methods
- Evolutionary: maintains a population of architectures and improves it through mutation, crossover, and selection
- Differentiable (gradient-based): relaxes the discrete search space so architecture choices can be optimized with gradient descent (e.g., DARTS)
- Bayesian optimization: models the architecture–performance relationship with a surrogate and picks the most promising candidates to evaluate next
- Reinforcement learning: trains a controller that proposes architectures and is rewarded by their validation performance
Search Space Design
Macro Search Space
High-level architectural choices such as layer types, skip connections, and overall network depth. Search-space complexity: medium.
Evolutionary NAS Implementation
import torch
import torch.nn as nn
import numpy as np
import random
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from concurrent.futures import ProcessPoolExecutor
import json
@dataclass
class ArchitectureGene:
"""Genetic representation of neural architecture"""
layers: List[Dict] # Layer configurations
connections: List[Tuple[int, int]] # Skip connections
depth: int
width_multiplier: float
def __post_init__(self):
self.fitness: Optional[float] = None
self.latency: Optional[float] = None
self.memory: Optional[float] = None
class NASSearchSpace:
"""Defines the search space for neural architectures"""
def __init__(self):
self.layer_types = [
'conv3x3', 'conv1x1', 'depthwise_conv3x3',
'dilated_conv3x3', 'maxpool3x3', 'avgpool3x3',
'identity', 'bottleneck', 'inverted_residual'
]
self.activations = ['relu', 'swish', 'hardswish', 'gelu']
self.normalizations = ['batchnorm', 'layernorm', 'groupnorm']
self.channel_sizes = [16, 32, 64, 128, 256, 512, 1024]
self.kernel_sizes = [1, 3, 5, 7]
def random_architecture(self, max_depth: int = 20) -> ArchitectureGene:
"""Generate random architecture within search space"""
depth = random.randint(8, max_depth)
layers = []
for i in range(depth):
layer = {
'type': random.choice(self.layer_types),
'channels': random.choice(self.channel_sizes),
'kernel_size': random.choice(self.kernel_sizes),
'activation': random.choice(self.activations),
'normalization': random.choice(self.normalizations),
'dropout': random.uniform(0.0, 0.3)
}
layers.append(layer)
# Generate random skip connections
connections = []
for i in range(depth):
if random.random() < 0.3: # 30% chance of skip connection
source = random.randint(0, i)
connections.append((source, i))
return ArchitectureGene(
layers=layers,
connections=connections,
depth=depth,
width_multiplier=random.uniform(0.5, 2.0)
)
def mutate_architecture(self, arch: ArchitectureGene,
mutation_rate: float = 0.1) -> ArchitectureGene:
"""Apply mutations to architecture"""
mutated_layers = []
for layer in arch.layers:
new_layer = layer.copy()
if random.random() < mutation_rate:
# Mutate layer type
if random.random() < 0.3:
new_layer['type'] = random.choice(self.layer_types)
# Mutate channels
if random.random() < 0.3:
new_layer['channels'] = random.choice(self.channel_sizes)
# Mutate activation
if random.random() < 0.2:
new_layer['activation'] = random.choice(self.activations)
mutated_layers.append(new_layer)
# Mutate connections
mutated_connections = arch.connections.copy()
if random.random() < mutation_rate:
if mutated_connections and random.random() < 0.5:
# Remove a connection
mutated_connections.pop(random.randint(0, len(mutated_connections) - 1))
else:
# Add a connection
i = random.randint(1, arch.depth - 1)
j = random.randint(0, i - 1)
mutated_connections.append((j, i))
return ArchitectureGene(
layers=mutated_layers,
connections=mutated_connections,
depth=len(mutated_layers),
width_multiplier=arch.width_multiplier
)
def crossover_architectures(self, parent1: ArchitectureGene,
parent2: ArchitectureGene) -> Tuple[ArchitectureGene, ArchitectureGene]:
"""Create offspring through crossover"""
# Single-point crossover for layers
crossover_point = random.randint(1, min(len(parent1.layers), len(parent2.layers)) - 1)
child1_layers = parent1.layers[:crossover_point] + parent2.layers[crossover_point:]
child2_layers = parent2.layers[:crossover_point] + parent1.layers[crossover_point:]
        # Combine connections, dropping any that point beyond the child's depth
        child1_connections = [(s, d) for s, d in parent1.connections + parent2.connections
                              if d < len(child1_layers)]
        child2_connections = [(s, d) for s, d in parent2.connections + parent1.connections
                              if d < len(child2_layers)]
child1 = ArchitectureGene(
layers=child1_layers,
connections=child1_connections,
depth=len(child1_layers),
width_multiplier=(parent1.width_multiplier + parent2.width_multiplier) / 2
)
child2 = ArchitectureGene(
layers=child2_layers,
connections=child2_connections,
depth=len(child2_layers),
width_multiplier=(parent1.width_multiplier + parent2.width_multiplier) / 2
)
return child1, child2
class ArchitectureEvaluator:
"""Evaluates neural architectures"""
def __init__(self, dataset, device='cuda'):
self.dataset = dataset
self.device = device
    def build_model(self, arch: ArchitectureGene, num_classes: int = 10,
                    in_channels: int = 3) -> nn.Module:
        """Build a PyTorch model from an architecture gene"""

        class DynamicNet(nn.Module):
            def __init__(self, arch: ArchitectureGene):
                super().__init__()
                self.layers = nn.ModuleList()
                self.connections = arch.connections
                # Track channel counts so consecutive layers are wired consistently
                c_in = in_channels
                for layer_config in arch.layers:
                    layer, c_in = self._build_layer(layer_config, c_in)
                    self.layers.append(layer)
                # Classification head: global pooling followed by a linear layer
                self.pool = nn.AdaptiveAvgPool2d(1)
                self.classifier = nn.Linear(c_in, num_classes)

            def _build_layer(self, config: Dict, c_in: int) -> Tuple[nn.Module, int]:
                layer_type = config['type']
                c_out = config['channels']
                if layer_type == 'conv3x3':
                    return nn.Conv2d(c_in, c_out, 3, padding=1), c_out
                elif layer_type == 'conv1x1':
                    return nn.Conv2d(c_in, c_out, 1), c_out
                elif layer_type == 'depthwise_conv3x3':
                    # Depthwise convolution keeps the channel count unchanged
                    return nn.Conv2d(c_in, c_in, 3, groups=c_in, padding=1), c_in
                elif layer_type == 'maxpool3x3':
                    return nn.MaxPool2d(3, stride=1, padding=1), c_in
                elif layer_type == 'avgpool3x3':
                    return nn.AvgPool2d(3, stride=1, padding=1), c_in
                else:  # identity (unimplemented types fall back to identity)
                    return nn.Identity(), c_in

            def forward(self, x):
                layer_outputs = [x]
                for i, layer in enumerate(self.layers):
                    layer_input = layer_outputs[-1]
                    # Apply skip connections whose tensors have matching shapes
                    for src, dst in self.connections:
                        if dst == i and src < len(layer_outputs):
                            skip = layer_outputs[src]
                            if skip.shape == layer_input.shape:
                                layer_input = layer_input + skip
                    layer_outputs.append(layer(layer_input))
                out = self.pool(layer_outputs[-1])
                return self.classifier(out.flatten(1))

        return DynamicNet(arch)
def evaluate_architecture(self, arch: ArchitectureGene,
epochs: int = 5) -> Dict[str, float]:
"""Evaluate architecture performance"""
try:
model = self.build_model(arch)
model = model.to(self.device)
# Fast evaluation with limited training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
model.train()
total_loss = 0
num_batches = min(100, len(self.dataset)) # Limit for speed
for epoch in range(epochs):
for i, (inputs, targets) in enumerate(self.dataset):
if i >= num_batches:
break
inputs, targets = inputs.to(self.device), targets.to(self.device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
# Evaluate accuracy
model.eval()
correct = 0
total = 0
with torch.no_grad():
for i, (inputs, targets) in enumerate(self.dataset):
if i >= 50: # Quick evaluation
break
inputs, targets = inputs.to(self.device), targets.to(self.device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += (predicted == targets).sum().item()
accuracy = correct / total if total > 0 else 0.0
# Estimate model complexity
total_params = sum(p.numel() for p in model.parameters())
return {
'accuracy': accuracy,
'loss': total_loss / (epochs * num_batches),
'params': total_params,
'latency': self._estimate_latency(model),
'memory': self._estimate_memory(model)
}
except Exception as e:
print(f"Architecture evaluation failed: {e}")
return {
'accuracy': 0.0,
'loss': float('inf'),
'params': float('inf'),
'latency': float('inf'),
'memory': float('inf')
}
    def _estimate_latency(self, model: nn.Module) -> float:
        """Estimate inference latency in milliseconds per forward pass"""
        model.eval()
        dummy_input = torch.randn(1, 3, 224, 224).to(self.device)
        with torch.no_grad():
            # Warm up
            for _ in range(10):
                _ = model(dummy_input)
            if str(self.device).startswith('cuda') and torch.cuda.is_available():
                torch.cuda.synchronize()
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                for _ in range(100):
                    _ = model(dummy_input)
                end_event.record()
                torch.cuda.synchronize()
                return start_event.elapsed_time(end_event) / 100  # ms per inference
            # CPU fallback: wall-clock timing
            import time
            start = time.perf_counter()
            for _ in range(100):
                _ = model(dummy_input)
            return (time.perf_counter() - start) * 1000 / 100  # ms per inference
def _estimate_memory(self, model: nn.Module) -> float:
"""Estimate memory usage"""
total_memory = 0
for param in model.parameters():
total_memory += param.numel() * param.element_size()
return total_memory / (1024 * 1024) # MB
class EvolutionaryNAS:
"""Main evolutionary NAS algorithm"""
def __init__(self, search_space: NASSearchSpace,
evaluator: ArchitectureEvaluator,
population_size: int = 50,
generations: int = 100):
self.search_space = search_space
self.evaluator = evaluator
self.population_size = population_size
self.generations = generations
self.best_architectures = []
def initialize_population(self) -> List[ArchitectureGene]:
"""Create initial population"""
return [self.search_space.random_architecture() for _ in range(self.population_size)]
def evaluate_population(self, population: List[ArchitectureGene]) -> List[ArchitectureGene]:
"""Evaluate fitness of entire population"""
for arch in population:
if arch.fitness is None: # Only evaluate if not already evaluated
results = self.evaluator.evaluate_architecture(arch)
# Multi-objective fitness function
accuracy_weight = 0.7
efficiency_weight = 0.3
# Normalize and combine metrics
accuracy_score = results['accuracy']
efficiency_score = 1.0 / (1.0 + results['latency'] / 100.0) # Lower latency is better
arch.fitness = accuracy_weight * accuracy_score + efficiency_weight * efficiency_score
arch.latency = results['latency']
arch.memory = results['memory']
return sorted(population, key=lambda x: x.fitness, reverse=True)
def selection(self, population: List[ArchitectureGene], k: int) -> List[ArchitectureGene]:
"""Tournament selection"""
selected = []
for _ in range(k):
tournament = random.sample(population, min(3, len(population)))
winner = max(tournament, key=lambda x: x.fitness)
selected.append(winner)
return selected
def evolve(self) -> ArchitectureGene:
"""Run evolutionary algorithm"""
# Initialize population
population = self.initialize_population()
for generation in range(self.generations):
print(f"Generation {generation + 1}/{self.generations}")
# Evaluate population
population = self.evaluate_population(population)
# Track best architectures
best_arch = population[0]
self.best_architectures.append(best_arch)
print(f"Best fitness: {best_arch.fitness:.4f}, "
f"Latency: {best_arch.latency:.2f}ms, "
f"Memory: {best_arch.memory:.2f}MB")
# Selection and reproduction
elite_size = self.population_size // 10 # Keep top 10%
new_population = population[:elite_size]
# Generate offspring
while len(new_population) < self.population_size:
# Selection
parents = self.selection(population, 2)
# Crossover
child1, child2 = self.search_space.crossover_architectures(parents[0], parents[1])
# Mutation
child1 = self.search_space.mutate_architecture(child1)
child2 = self.search_space.mutate_architecture(child2)
new_population.extend([child1, child2])
# Trim to population size
population = new_population[:self.population_size]
# Final evaluation and return best
population = self.evaluate_population(population)
return population[0]
# Usage example
def run_evolutionary_nas():
"""Example of running evolutionary NAS"""
# Setup (replace with actual dataset)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# Load dataset (example with CIFAR-10)
dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# Initialize NAS components
search_space = NASSearchSpace()
evaluator = ArchitectureEvaluator(dataloader)
# Run evolutionary NAS
nas = EvolutionaryNAS(
search_space=search_space,
evaluator=evaluator,
population_size=20, # Smaller for demo
generations=10
)
best_architecture = nas.evolve()
print("\nBest Architecture Found:")
print(f"Fitness: {best_architecture.fitness:.4f}")
print(f"Depth: {best_architecture.depth}")
print(f"Latency: {best_architecture.latency:.2f}ms")
print(f"Memory: {best_architecture.memory:.2f}MB")
return best_architecture
if __name__ == "__main__":
    best_arch = run_evolutionary_nas()
Differentiable NAS (DARTS)
Continuous Relaxation
DARTS treats architecture search as a continuous optimization problem by representing the search space as a weighted combination of all possible operations, making it differentiable and efficient.
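In symbols, DARTS replaces the discrete choice of operation on each edge $(i, j)$ of a cell with a softmax-weighted mixture of all candidate operations, and then optimizes the mixing weights $\alpha$ on validation data while the network weights $w$ are trained on training data, a bilevel problem:

$$\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\big(\alpha_o^{(i,j)}\big)}{\sum_{o' \in \mathcal{O}} \exp\big(\alpha_{o'}^{(i,j)}\big)}\, o(x)$$

$$\min_{\alpha}\; \mathcal{L}_{\text{val}}\big(w^*(\alpha), \alpha\big) \quad \text{s.t.} \quad w^*(\alpha) = \arg\min_{w}\; \mathcal{L}_{\text{train}}(w, \alpha)$$

The MixedOperation and DARTSTrainer classes below implement these two pieces: the weighted sum over primitives, and the alternating updates of $\alpha$ and $w$.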
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
class MixedOperation(nn.Module):
"""Mixed operation for DARTS - weighted combination of primitives"""
def __init__(self, C, stride, primitives):
super().__init__()
self._ops = nn.ModuleList()
for primitive in primitives:
op = self._get_operation(primitive, C, stride)
self._ops.append(op)
def _get_operation(self, primitive, C, stride):
"""Get specific operation"""
if primitive == 'none':
return Zero(stride)
elif primitive == 'skip_connect':
return Identity() if stride == 1 else FactorizedReduce(C, C)
elif primitive == 'conv_1x1':
return nn.Sequential(
nn.Conv2d(C, C, 1, stride, 0, bias=False),
nn.BatchNorm2d(C)
)
elif primitive == 'conv_3x3':
return nn.Sequential(
nn.Conv2d(C, C, 3, stride, 1, bias=False),
nn.BatchNorm2d(C)
)
elif primitive == 'sep_conv_3x3':
return SeparableConv2d(C, C, 3, stride, 1)
elif primitive == 'sep_conv_5x5':
return SeparableConv2d(C, C, 5, stride, 2)
elif primitive == 'dil_conv_3x3':
return DilatedConv2d(C, C, 3, stride, 2, 2)
elif primitive == 'max_pool_3x3':
return nn.MaxPool2d(3, stride, 1)
elif primitive == 'avg_pool_3x3':
return nn.AvgPool2d(3, stride, 1)
else:
raise ValueError(f"Unknown primitive: {primitive}")
def forward(self, x, weights):
"""Forward pass with architecture weights"""
return sum(w * op(x) for w, op in zip(weights, self._ops))
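# ---------------------------------------------------------------------------
# Helper operations referenced by _get_operation above. They are not defined
# elsewhere in this snippet, so minimal sketches are provided here, assuming
# conventional DARTS-style definitions (Zero, Identity, FactorizedReduce,
# separable and dilated convolutions).
# ---------------------------------------------------------------------------
class Zero(nn.Module):
    """The 'none' operation: outputs zeros, optionally striding spatially."""
    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x * 0.0
        return x[:, :, ::self.stride, ::self.stride] * 0.0


class Identity(nn.Module):
    """Pass-through used for stride-1 skip connections."""
    def forward(self, x):
        return x


class FactorizedReduce(nn.Module):
    """Halve spatial resolution with two offset 1x1 convolutions."""
    def __init__(self, C_in, C_out):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out - C_out // 2, 1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)


class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution followed by batch norm."""
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super().__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding,
                      groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, bias=False),
            nn.BatchNorm2d(C_out),
        )

    def forward(self, x):
        return self.op(x)


class DilatedConv2d(nn.Module):
    """Dilated depthwise-separable convolution followed by batch norm."""
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding,
                      dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, bias=False),
            nn.BatchNorm2d(C_out),
        )

    def forward(self, x):
        return self.op(x)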
class Cell(nn.Module):
"""DARTS cell with learnable connections"""
    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C,
                 reduction, reduction_prev, primitives):
        super().__init__()
        self.reduction = reduction
        self.primitives = primitives
        # s0 comes from two cells back; if that cell reduced the spatial size,
        # it must be downsampled here so that s0 and s1 line up
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = nn.Sequential(
                nn.Conv2d(C_prev_prev, C, 1, bias=False),
                nn.BatchNorm2d(C)
            )
        self.preprocess1 = nn.Sequential(
            nn.Conv2d(C_prev, C, 1, bias=False),
            nn.BatchNorm2d(C)
        )
self._steps = steps
self._multiplier = multiplier
self._ops = nn.ModuleList()
for i in range(self._steps):
for j in range(2 + i):
stride = 2 if reduction and j < 2 else 1
op = MixedOperation(C, stride, primitives)
self._ops.append(op)
def forward(self, s0, s1, weights):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = sum(self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states))
offset += len(states)
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
class DARTSNetwork(nn.Module):
"""Complete DARTS network with learnable architecture"""
def __init__(self, C, num_classes, layers, steps=4, multiplier=4):
super().__init__()
self.primitives = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'conv_1x1'
]
self._C = C
self._num_classes = num_classes
self._layers = layers
self._steps = steps
self._multiplier = multiplier
# Initialize architecture parameters
k = sum(1 for i in range(self._steps) for n in range(2 + i))
num_ops = len(self.primitives)
self.alphas_normal = nn.Parameter(torch.randn(k, num_ops))
self.alphas_reduce = nn.Parameter(torch.randn(k, num_ops))
# Network stem
self.stem = nn.Sequential(
nn.Conv2d(3, C, 3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
# Build network layers
        # Track the output channels of the two previous cells so each cell is wired correctly
        C_prev_prev, C_prev, C_curr = C, C, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers // 3, 2 * layers // 3]:
                C_curr *= 2
                reduction = True
            else:
                reduction = False
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr,
                        reduction, reduction_prev, self.primitives)
            self.cells.append(cell)
            reduction_prev = reduction
            C_prev_prev, C_prev = C_prev, C_curr * multiplier
# Classification head
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def forward(self, input):
s0 = s1 = self.stem(input)
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = F.softmax(self.alphas_reduce, dim=-1)
else:
weights = F.softmax(self.alphas_normal, dim=-1)
s0, s1 = s1, cell(s0, s1, weights)
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
return logits
def arch_parameters(self):
return [self.alphas_normal, self.alphas_reduce]
def model_parameters(self):
return [p for n, p in self.named_parameters()
if 'alpha' not in n]
def genotype(self):
"""Extract discrete architecture from continuous weights"""
def _parse(weights):
gene = []
n = 2
start = 0
for i in range(self._steps):
end = start + n
W = weights[start:end].copy()
edges = sorted(range(i + 2),
key=lambda x: -max(W[x][k] for k in range(len(W[x]))
if k != self.primitives.index('none')))[:2]
for j in edges:
k_best = None
for k in range(len(W[j])):
if k != self.primitives.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
k_best = k
gene.append((self.primitives[k_best], j))
start = end
n += 1
return gene
weights_normal = F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy()
weights_reduce = F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy()
gene_normal = _parse(weights_normal)
gene_reduce = _parse(weights_reduce)
return {
'normal': gene_normal,
'reduce': gene_reduce
}
class DARTSTrainer:
"""DARTS training procedure"""
def __init__(self, model, train_loader, valid_loader,
device='cuda', learning_rate=0.025, arch_learning_rate=3e-4):
self.model = model.to(device)
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
# Separate optimizers for model weights and architecture weights
self.model_optimizer = torch.optim.SGD(
self.model.model_parameters(),
lr=learning_rate,
momentum=0.9,
weight_decay=3e-4
)
self.arch_optimizer = torch.optim.Adam(
self.model.arch_parameters(),
lr=arch_learning_rate,
betas=(0.5, 0.999),
weight_decay=1e-3
)
self.criterion = nn.CrossEntropyLoss()
def train_epoch(self):
self.model.train()
for step, (train_X, train_y) in enumerate(self.train_loader):
train_X, train_y = train_X.to(self.device), train_y.to(self.device)
# Get validation batch for architecture optimization
try:
valid_X, valid_y = next(self.valid_iter)
            except StopIteration:
self.valid_iter = iter(self.valid_loader)
valid_X, valid_y = next(self.valid_iter)
valid_X, valid_y = valid_X.to(self.device), valid_y.to(self.device)
# Step 1: Update architecture weights
self.arch_optimizer.zero_grad()
logits = self.model(valid_X)
arch_loss = self.criterion(logits, valid_y)
arch_loss.backward()
self.arch_optimizer.step()
# Step 2: Update model weights
self.model_optimizer.zero_grad()
logits = self.model(train_X)
model_loss = self.criterion(logits, train_y)
model_loss.backward()
nn.utils.clip_grad_norm_(self.model.model_parameters(), 5.0)
self.model_optimizer.step()
if step % 100 == 0:
print(f"Step {step}: Model Loss = {model_loss:.4f}, Arch Loss = {arch_loss:.4f}")
def search(self, epochs=25):
"""Run DARTS search"""
self.valid_iter = iter(self.valid_loader)
for epoch in range(epochs):
print(f"\nEpoch {epoch + 1}/{epochs}")
self.train_epoch()
# Print current genotype
genotype = self.model.genotype()
print(f"Current architecture: {genotype}")
return self.model.genotype()
# Usage example
def run_darts():
"""Example of running DARTS"""
# Create model
model = DARTSNetwork(C=16, num_classes=10, layers=8)
# Setup data loaders (replace with actual data)
# train_loader, valid_loader = get_data_loaders()
# Initialize trainer
# trainer = DARTSTrainer(model, train_loader, valid_loader)
# Run search
# final_architecture = trainer.search(epochs=25)
print("DARTS search completed!")
# return final_architecture
if __name__ == "__main__":
    run_darts()
Real-World NAS Applications
Google - EfficientNet
Google used NAS to discover EfficientNet architectures, achieving state-of-the-art accuracy with 10x fewer parameters. The compound scaling method balances network depth, width, and resolution.
- 84.3% ImageNet top-1 accuracy
- 10x parameter efficiency vs ResNet
- Compound scaling methodology
- Mobile-optimized variants (B0-B7)
Meta - RegNet
Meta's RegNet family was discovered through NAS, providing simple and efficient architectures that outperform EfficientNet on several benchmarks while maintaining interpretability.
- Linear design space exploration
- Better speed-accuracy trade-offs
- Consistent scaling across model sizes
- Open-source implementation
Microsoft - BingBERT
Microsoft applied NAS to transformer architectures, creating BingBERT with optimized attention patterns and layer configurations for production search scenarios.
- Transformer architecture optimization
- Production search deployment
- Hardware-aware optimization
- Multi-task learning support
NVIDIA - FasterTransformer
NVIDIA uses NAS techniques to optimize transformer architectures for their GPU hardware, achieving significant speedups in inference while maintaining model quality.
- Hardware-specific optimization
- Inference speed improvements
- Memory efficiency optimization
- Multi-GPU scaling support
NAS Best Practices
✅ Do
Define meaningful search spaces with domain knowledge
Use multi-objective optimization (accuracy + efficiency)
Implement early stopping to reduce search time
Validate discovered architectures thoroughly
Consider hardware constraints during search
Use surrogate models to speed up evaluation (see the sketch after this list)
Track and analyze search progress continuously
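To make the surrogate-model advice concrete, here is a small, hypothetical sketch: it encodes already-evaluated architectures as fixed-length feature vectors, fits a regressor on their measured fitness, and uses the predictions to rank new candidates so that only the most promising ones receive a full training run. The feature encoding and the choice of scikit-learn's RandomForestRegressor are illustrative assumptions, not part of the implementations above.

```python
# Hypothetical surrogate-model sketch: rank candidate architectures cheaply
# before spending GPU time on full evaluation. Assumes scikit-learn is installed
# and reuses ArchitectureGene from the evolutionary example above.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

def encode(arch) -> np.ndarray:
    """Fixed-length feature vector for an ArchitectureGene (illustrative encoding)."""
    channels = [layer['channels'] for layer in arch.layers]
    return np.array([
        arch.depth,
        arch.width_multiplier,
        len(arch.connections),
        float(np.mean(channels)),
        float(np.max(channels)),
    ])

class SurrogateRanker:
    def __init__(self):
        self.model = RandomForestRegressor(n_estimators=100)
        self.X, self.y = [], []

    def observe(self, arch, measured_fitness: float):
        """Record a fully evaluated architecture and refit once enough data exists."""
        self.X.append(encode(arch))
        self.y.append(measured_fitness)
        if len(self.y) >= 10:
            self.model.fit(np.array(self.X), np.array(self.y))

    def top_k(self, candidates, k: int):
        """Return the k candidates with the highest predicted fitness."""
        if len(self.y) < 10:                      # not enough data yet: keep everything
            return candidates[:k]
        preds = self.model.predict(np.array([encode(a) for a in candidates]))
        order = np.argsort(-preds)
        return [candidates[i] for i in order[:k]]
```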
❌ Don't
Use overly broad search spaces without constraints
Ignore computational cost and search efficiency
Rely solely on accuracy without considering deployment
Skip thorough evaluation on held-out test data
Use insufficient training for architecture evaluation
Forget to validate transferability across datasets
Ignore the interpretability of discovered architectures