José David Baena


Memory Optimization Techniques: Gradient Accumulation & Mixed Precision


Track 2: Practical Guides - Post 2.6 of 6

This final post in Track 2 covers practical memory optimization strategies: gradient accumulation, mixed precision training, sequence length management, optimizer state optimization, distributed memory, and profiling tools. View all posts in this track →

One H100 has 80GB—here's why you'll run out anyway

Memory management was my biggest surprise when first training LLMs. Weights are tiny compared to optimizer states, gradients, and activations. Understanding this breakdown is the first step to fitting larger models on smaller hardware.

80GB sounds like a lot. Then you add optimizer states, gradients, and activations. Suddenly your 150M model won't fit.

TL;DR: Gradient accumulation gives you 4× batch size without 4× memory. bfloat16 halves memory with no accuracy loss. Muon needs half the optimizer memory of Adam. These techniques let a single 24GB GPU train 150M parameter models.

The OOM that cost a week: Consider a common scenario: training a 150M model on an RTX 4090 (24GB) that crashes at step 50,000 with "CUDA out of memory"—after running fine for two days. The cause: Python's garbage collector hasn't run, and intermediate tensors accumulate. The fix is a single line: torch.cuda.empty_cache() every 1000 steps. But if you haven't saved checkpoints recently, you lose everything and have to restart. Memory management isn't just about fitting your model—it's about keeping it stable through long training runs. The techniques in this post prevent that crash before it happens.

Memory is the primary bottleneck when training language models. A single H100 GPU has 80GB of memory—sounds like a lot until you realize:

  • A 768-dim, 12-layer model (~60M parameters) in bfloat16 uses ~120MB for weights
  • Activations for a single forward pass (batch_size=32, seq_len=2048) use ~8GB
  • Optimizer states (Adam momentum + variance, kept in fp32) add ~480MB, four times the weight memory
  • Gradients in bfloat16 add another ~120MB
  • Total: roughly 9GB per GPU, and that's for a tiny model!

Scale to a 1280-dim, 20-layer model (~150M parameters) with batch_size=32, and you're looking at ~20GB per GPU. Run out of memory, and training stops.

Eight techniques get you back under budget:

  1. Gradient Accumulation: Simulate larger batches without OOM
  2. Mixed Precision Training: bfloat16 for 2x memory savings
  3. Sequence Length Management: Dynamic batching strategies
  4. Optimizer State Optimization: Choosing the right optimizer
  5. Distributed Training: Splitting work across GPUs
  6. Inference Optimizations: KV caching, batch size tuning
  7. Advanced Techniques: Gradient checkpointing, activation compression
  8. Memory Profiling: Tools to diagnose bottlenecks

Table of Contents

  1. Memory Breakdown: Where Does It All Go?
  2. Gradient Accumulation
  3. Mixed Precision Training
  4. Sequence Length Management
  5. Optimizer Choice
  6. Distributed Training Memory
  7. Inference Memory Optimization
  8. Advanced: Gradient Checkpointing
  9. Memory Profiling
  10. Best Practices & Common Pitfalls

Optimizer states consume 4-8 bytes per parameter

Training Memory Components

For a model with P parameters, batch size B, sequence length T, and embedding dimension d:

| Component | Memory (per parameter) | Total | Notes |
|---|---|---|---|
| Model Weights | 2 bytes (bf16) | 2P | The model parameters |
| Gradients | 2 bytes (bf16) | 2P | Stored during backward pass |
| Optimizer State (Adam) | 8 bytes | 8P | Momentum (4B) + variance (4B) |
| Optimizer State (Muon) | 4 bytes | 4P | Momentum only (4B) |
| Activations | Varies | ~12 * B * T * d * n_layers | Intermediate layer outputs |

Example: 60M parameter model, Adam optimizer:

  • Weights: 60M × 2B = 120MB
  • Gradients: 60M × 2B = 120MB
  • Adam states: 60M × 8B = 480MB
  • Total parameter memory: 720MB

Activations dominate for large batches:

  • batch_size=32, seq_len=2048, d=768, n_layers=12
  • Activations: ~12 × 32 × 2048 × 768 × 12 ≈ 8GB
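
To make these numbers concrete, here is a minimal sketch of the same arithmetic, using the ~12-bytes-per-element activation heuristic from the table above (the helper name and defaults are illustrative, not nanochat code):

def estimate_training_memory_gb(n_params, batch_size, seq_len, d_model, n_layers,
                                weight_bytes=2, optimizer_bytes=8):
    """Rough training-memory estimate from the heuristics in the table above."""
    weights = n_params * weight_bytes                       # bf16 weights
    grads = n_params * weight_bytes                         # bf16 gradients
    opt_state = n_params * optimizer_bytes                  # Adam: 8 B/param, Muon: 4 B/param
    activations = 12 * batch_size * seq_len * d_model * n_layers  # heuristic, in bytes
    return (weights + grads + opt_state + activations) / 1024**3

# 60M-param, 768-dim, 12-layer model, batch_size=32, seq_len=2048
print(f"{estimate_training_memory_gb(60e6, 32, 2048, 768, 12):.1f} GB")  # ≈ 7.4 GB, before runtime overhead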

GPU Memory Budget Calculator

Estimate memory requirements for training your model. Example output from the interactive calculator for one configuration:

| Component | Memory |
|---|---|
| Parameters | 540.00 MB |
| Gradients | 540.00 MB |
| Optimizer States | 1.08 GB |
| Activations | 12.08 GB |
| CUDA Overhead | 216.00 MB |
| Total Required | 14.46 GB |

GPU compatibility: fits on RTX 4090 (24GB), A100 40GB, A100 80GB, and H100 80GB.

Tips to reduce memory:

  • Use gradient checkpointing to trade compute for memory
  • Reduce batch size and use gradient accumulation
  • Use mixed precision (BF16) training
  • Consider ZeRO optimization for multi-GPU setups
  • Muon keeps only momentum, roughly 50% less optimizer memory than AdamW

Memory Scaling

| Scale Factor | Impact |
|---|---|
| Double parameters | +2x model memory, +2x gradient memory, +2x optimizer memory |
| Double batch size | +2x activation memory (no change to model/optimizer) |
| Double sequence length | +2x activation memory (no change to model/optimizer) |
| Switch Adam → Muon | -50% optimizer memory (8 bytes → 4 bytes per param) |
| Add gradient accumulation | No extra memory (same effective batch, different compute) |

Key insight: Activations scale with batch size and sequence length, but optimizer state scales only with parameters.

For your GPU budget, this means: if you're memory-constrained, tackle activations first (smaller batch, gradient accumulation). Optimizer choice matters less for memory—but Muon still saves 50% on optimizer state.

For your training schedule, this means: don't start with max batch size. Start small, monitor GPU memory, and scale up until you hit 85-90% utilization. Leaving 10-15% headroom prevents OOM from activation spikes.
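
A minimal way to check that utilization during training, assuming a single CUDA device (the helper below is an illustration, not nanochat code):

import torch

def gpu_memory_fraction(device=0):
    """Fraction of total GPU memory currently reserved by PyTorch's allocator."""
    total = torch.cuda.get_device_properties(device).total_memory
    return torch.cuda.memory_reserved(device) / total

# e.g. inside the training loop
if gpu_memory_fraction() > 0.90:
    print("Warning: above 90% of GPU memory; expect OOM on activation spikes")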


Gradient accumulation simulates large batches without OOM

The Problem

You want total_batch_size = 524,288 tokens, but:

  • Your GPU can only fit device_batch_size = 16 sequences of seq_len = 2048
  • That's only 16 × 2048 = 32,768 tokens per step
  • Gap: You need 16x more tokens per update!

The Solution: Gradient Accumulation

Idea: Accumulate gradients across multiple forward/backward passes before stepping the optimizer.

# Desired: total_batch_size = 524,288 tokens
# Reality: device_batch_size = 16, seq_len = 2048 => 32,768 tokens/step
# Solution: grad_accum_steps = 524,288 / 32,768 = 16
 
grad_accum_steps = total_batch_size // (device_batch_size * seq_len * world_size)
 
for step in range(num_iterations):
    # Accumulate gradients over multiple micro-batches
    for micro_step in range(grad_accum_steps):
        x, y = next(data_loader)
        loss = model(x, y)
        loss = loss / grad_accum_steps  # Normalize: each .backward() sums gradients
        loss.backward()
    
    # Step optimizer once
    optimizer.step()
    optimizer.zero_grad()

Why Scale Loss by grad_accum_steps?

Without scaling:

# Micro-batch 1: loss=2.5 → backward() adds gradients
# Micro-batch 2: loss=2.3 → backward() adds more gradients
# Result: gradients are 2x too large (sum of 2 losses)

For your training loop, this means: always divide loss by grad_accum_steps before .backward(). Miss this, and your effective learning rate silently multiplies by your accumulation factor—training will diverge.

For your debugging, this means: if training diverges after adding gradient accumulation, check the loss scaling first. It's the #1 gradient accumulation bug.

With scaling:

# Micro-batch 1: loss=2.5/2=1.25 → backward()
# Micro-batch 2: loss=2.3/2=1.15 → backward()
# Result: gradients average across micro-batches (correct)

Memory Impact

Zero extra memory cost! Gradients are reused across micro-batches:

# Iteration 1: forward → backward (gradients stored)
# Iteration 2: forward → backward (gradients ADDED to existing)
# Iteration 3: forward → backward (gradients ADDED again)
# ...
# Optimizer step (gradients cleared)

Only one batch of activations in memory at a time.
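
That equivalence is easy to verify on a toy model: accumulating over micro-batches with the loss divided by grad_accum_steps reproduces the full-batch gradient. A minimal sketch (toy linear model, mean-squared-error loss, equal-sized micro-batches):

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)

# Reference: one full batch of 16
nn.functional.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()
model.zero_grad()

# Same 16 samples as 4 micro-batches of 4, with loss scaling
grad_accum_steps = 4
for xb, yb in zip(x.chunk(grad_accum_steps), y.chunk(grad_accum_steps)):
    (nn.functional.mse_loss(model(xb), yb) / grad_accum_steps).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True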

Gradient Accumulation Simulator

Visualize how gradient accumulation enables larger effective batch sizes. Example configuration: micro-batch size 4 (samples per forward pass) × 8 accumulation steps (micro-batches before update) = effective batch size 32.

Memory Comparison

| Setup | Activation Memory |
|---|---|
| Without accumulation (batch size 32) | ~32× |
| With accumulation (micro-batch 4, accumulate 8×) | ~4× (13% of the above) |

Key insight: Gradient accumulation achieves the same mathematical result as a large batch, but uses only the memory of a small batch. The gradients from each micro-batch are averaged together before the optimizer step. This is mathematically equivalent to processing all samples at once, but allows training on GPUs with limited memory.

nanochat Implementation

# From scripts/base_train.py
tokens_per_fwdbwd = device_batch_size * max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 
print(f"Gradient accumulation steps: {grad_accum_steps}")
 
# Training loop
x, y = next(train_loader)  # fetch the first batch; later batches are prefetched below
for step in range(num_iterations):
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps  # Scale loss
        loss.backward()
        x, y = next(train_loader)  # Prefetch next batch
    
    # Clip gradients (optional)
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    
    # Step optimizers
    optimizer.step()
    model.zero_grad(set_to_none=True)

Key detail: set_to_none=True frees gradient memory immediately (faster than zeroing).


bfloat16 halves memory with no accuracy loss

What is Mixed Precision?

Use lower precision (16-bit) for most operations, full precision (32-bit) for sensitive ops:

| Precision | Size | Range | Decimal Precision | Use Case |
|---|---|---|---|---|
| float32 | 4 bytes | ±3.4e38 | ~7 digits | Default, stable |
| float16 | 2 bytes | ±65,504 | ~3 digits | Fast but unstable |
| bfloat16 | 2 bytes | ±3.4e38 | ~2 digits | Best of both worlds |

bfloat16 = same exponent range as float32, less mantissa precision.
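
You can inspect these ranges directly with torch.finfo; a quick check:

import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e}  eps={info.eps:.3e}")

# torch.float32   max=3.403e+38  eps=1.192e-07
# torch.float16   max=6.550e+04  eps=9.766e-04
# torch.bfloat16  max=3.389e+38  eps=7.812e-03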

Mixed Precision Explainer

Understand floating-point formats used in LLM training and inference

FP16 (Half Precision)

Bit layout (16 bits total): 1 sign bit, 5 exponent bits, 10 mantissa bits. Range: ±65,504.

Precision Test

How the value 0.123457 is represented in each format:

| Format | Represented As | Error | Error % |
|---|---|---|---|
| FP32 | 0.123457 | 0.00e+0 | 0.00% |
| FP16 | 0.123047 | 4.10e-4 | 0.33% |
| BF16 | 0.125000 | 1.54e-3 | 1.25% |
| FP8_E4M3 | 0.125000 | 1.54e-3 | 1.25% |
| FP8_E5M2 | 0.000000 | 1.23e-1 | 100.00% |
| INT8 | 0.125984 | 2.53e-3 | 2.05% |
| INT4 | 0.142857 | 1.94e-2 | 15.71% |

Training (Mixed Precision)

  • Master weights in FP32
  • Forward/backward in FP16/BF16
  • Loss scaling to prevent underflow
  • ~2x speedup, 50% memory saved

Inference (Quantization)

  • INT8/INT4 for weights
  • FP16 for activations
  • 4x-8x memory reduction
  • Requires calibration

Why BF16 for Training?

BF16 has the same exponent range as FP32, preventing overflow/underflow issues that can occur with FP16 during training. While it has lower precision, the dynamic range is more important for gradient updates. Modern GPUs (A100, H100) have native BF16 tensor cores.

Why bfloat16?

  1. 2x memory savings: Weights, activations, gradients all use half the memory
  2. Faster compute: Tensor cores on A100/H100 are optimized for bf16
  3. Stable training: Wider range than fp16 → no loss scaling needed
  4. Drop-in replacement: Works out-of-the-box for LLM training

nanochat Implementation

# Create autocast context
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
 
# Use it during forward/backward
with autocast_ctx:
    loss = model(x, y)
loss.backward()

Under the hood:

  • Matrix multiplications → bfloat16 (fast, memory-efficient)
  • Reductions (sum, softmax) → float32 (numerical stability)
  • Optimizer step → float32 (precision for weight updates)

What Gets Cast?

# Forward pass
x = self.transformer.wte(idx)  # Embedding: bfloat16
x = norm(x)                     # RMSNorm: bfloat16
q, k, v = self.attn(x)          # Linear: bfloat16
attn = softmax(q @ k.T)         # Softmax: float32 (automatic upcast)
out = attn @ v                  # MatMul: bfloat16
logits = self.lm_head(out)      # Linear: bfloat16
loss = cross_entropy(logits, y) # Cross-entropy: float32 (automatic upcast)

Result: Most memory and compute in bf16, stability-critical ops in fp32.

Explicit Casting in nanochat

# From nanochat/gpt.py
 
# Cast embeddings to bf16 (save memory in embedding table)
self.transformer.wte.to(dtype=torch.bfloat16)
 
# Cast rotary embeddings to bf16
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos.bfloat16(), sin.bfloat16()
 
# During forward: cast logits to fp32 for stable loss
logits = self.lm_head(x)
logits = logits.float()  # fp32 for cross-entropy
loss = F.cross_entropy(logits, targets)

Memory Savings Example

60M parameter model, batch_size=32:

| Component | fp32 | bf16 | Savings |
|---|---|---|---|
| Weights | 240MB | 120MB | -50% |
| Gradients | 240MB | 120MB | -50% |
| Activations | 16GB | 8GB | -50% |
| Total | ~16.5GB | ~8.2GB | ~50% |

Enables: 2x larger models or 2x larger batches on same hardware.

For your next training run, this means: enable torch.bfloat16 on day one. It's free memory savings with zero quality loss. There's no reason not to use it on modern GPUs (A100, H100, RTX 4090).

For your production deployment, this means: bfloat16 inference is twice as fast and uses half the memory. If your serving infrastructure doesn't support it, you're leaving money on the table.


Shorter sequences free quadratic attention memory

The Challenge

Activation memory scales quadratically with sequence length:

Attention: Q @ K^T creates (seq_len × seq_len) matrix
Memory: O(batch_size * n_heads * seq_len^2 * sizeof(dtype))

Example: seq_len=2048 vs seq_len=4096:

  • Attention memory: 4x increase (2048² → 4096²)
  • Total activation memory: ~2-3x increase
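
A quick calculation of the attention-score matrix alone (Q @ K^T in bf16) shows the quadratic blow-up; the batch size and head count below are assumptions for illustration. Fused attention kernels such as FlashAttention avoid materializing this matrix, which is exactly why they matter at long sequence lengths.

def attn_score_memory_gb(batch_size, n_heads, seq_len, bytes_per_el=2):
    """Memory for the (seq_len x seq_len) attention scores, materialized naively."""
    return batch_size * n_heads * seq_len**2 * bytes_per_el / 1024**3

for T in (1024, 2048, 4096):
    print(T, f"{attn_score_memory_gb(batch_size=32, n_heads=12, seq_len=T):.1f} GB")
# 1024 0.8 GB
# 2048 3.0 GB
# 4096 12.0 GB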

Strategy 1: Start Short, Grow Gradually

Train on shorter sequences early, increase length later:

# Hypothetical staged training
stage_1_iters = 5000
stage_2_iters = 5000
 
if step < stage_1_iters:
    max_seq_len = 1024  # Start short
elif step < stage_1_iters + stage_2_iters:
    max_seq_len = 2048  # Grow
else:
    max_seq_len = 4096  # Full length

Trade-off: Early training sees less long-range context, but uses memory efficiently.

Strategy 2: Variable-Length Batching

nanochat uses fixed-length batches for simplicity:

# Every batch has exactly B × T tokens
batch = torch.randint(0, vocab_size, (B, T))

Alternative: Pack variable-length sequences into fixed token budget:

# Advanced (not in nanochat): pack sequences dynamically
# Batch 1: [seq_len=1024, seq_len=2048, seq_len=512] → 3584 tokens
# Batch 2: [seq_len=2048, seq_len=2048] → 4096 tokens
# Goal: maintain ~4096 tokens/batch, varying sequence counts

This is complex (requires padding/masking) but maximizes GPU utilization.
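
A minimal greedy-packing sketch (not nanochat code) that fills each batch up to a fixed token budget; real implementations also need padding and attention masks:

def pack_sequences(seq_lengths, token_budget=4096):
    """Greedily group sequence indices so each batch stays within a token budget."""
    batches, current, used = [], [], 0
    for i, length in enumerate(seq_lengths):
        if used + length > token_budget and current:
            batches.append(current)     # close the current batch
            current, used = [], 0
        current.append(i)
        used += length
    if current:
        batches.append(current)
    return batches

print(pack_sequences([1024, 2048, 512, 2048, 2048]))  # [[0, 1, 2], [3, 4]]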

Strategy 3: Truncation

nanochat truncates long sequences during tokenization:

def render_conversation(self, conversation, max_tokens=2048):
    ids, mask = [], []
    # ... render conversation ...
    
    # Truncate to max_tokens
    ids = ids[:max_tokens]
    mask = mask[:max_tokens]
    return ids, mask

Trade-off: Long conversations lose tail context, but avoids OOM.

nanochat's Choice

# Fixed sequence length throughout training
max_seq_len = 2048  # Constant
 
# Data loader yields batches of shape (batch_size, max_seq_len)
for x, y in data_loader:
    assert x.shape == (batch_size, max_seq_len)
    loss = model(x, y)

Why? Simplicity + compiled model performance (dynamic shapes are slower).


Muon uses half the optimizer memory of Adam

Memory Footprint Comparison

| Optimizer | State per Parameter | Memory per Param | Example (60M params) |
|---|---|---|---|
| SGD | None | 0 bytes | 0 MB |
| SGD + Momentum | Momentum buffer | 4 bytes | 240 MB |
| AdamW | Momentum + Variance | 8 bytes | 480 MB |
| Muon | Momentum (2D params only) | 4 bytes | 240 MB |

nanochat's Hybrid Approach

# From nanochat/gpt.py
def setup_optimizers(self):
    # Separate parameters by type
    matrix_params = list(self.transformer.h.parameters())      # 2D (linear layers)
    embedding_params = list(self.transformer.wte.parameters()) # 1D (embeddings)
    lm_head_params = list(self.lm_head.parameters())           # 2D (classifier)
    
    # AdamW for embeddings + lm_head (needs adaptive LR)
    adamw_optimizer = AdamW([
        {"params": lm_head_params, "lr": 0.004},
        {"params": embedding_params, "lr": 0.2},
    ], betas=(0.8, 0.95))
    
    # Muon for Transformer matrix params (memory-efficient)
    muon_optimizer = Muon(matrix_params, lr=0.02, momentum=0.95)
    
    return [adamw_optimizer, muon_optimizer]

Memory breakdown (60M param model):

  • Matrix params: ~50M params → Muon → 50M × 4B = 200MB
  • Embedding + lm_head: ~10M params → AdamW → 10M × 8B = 80MB
  • Total optimizer memory: 280MB (vs 480MB for full AdamW)

Savings: ~40% reduction in optimizer state memory.

When to Use What?

| Optimizer | Use Case | Memory | Convergence Speed |
|---|---|---|---|
| SGD | Very memory-constrained | Lowest | Slowest |
| AdamW | Standard choice, embeddings | High | Fast |
| Muon | Matrix params (Transformers) | Medium | Fast |
| 8-bit AdamW | Extreme memory constraints | Medium | Fast (slight quality loss) |
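
To measure this on your own model, you can sum the bytes held in optimizer.state after the first step() (PyTorch allocates optimizer state lazily); a minimal sketch:

import torch

def optimizer_state_bytes(optimizer):
    """Total bytes held in optimizer state tensors (momentum, variance, etc.)."""
    total = 0
    for state in optimizer.state.values():
        for value in state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total

# After at least one optimizer.step():
# print(f"{optimizer_state_bytes(optimizer) / 1024**2:.1f} MB of optimizer state")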

DDP replicates memory; ZeRO shards it across GPUs

Data Parallelism (DDP)

Each GPU holds:

  • Full copy of model weights
  • Full copy of optimizer state
  • 1/N of the batch (where N = number of GPUs)
# Example: 8 GPUs, batch_size=32, seq_len=2048
# Total batch: 32 × 2048 × 8 = 524,288 tokens
# Per-GPU batch: 32 × 2048 = 65,536 tokens
 
# Each GPU:
device_batch_size = 32  # Per-GPU
total_batch_size = 32 * 8 * 2048  # Across all GPUs
 
# Memory per GPU:
# - Model weights: Same on all GPUs
# - Activations: 1/8 of total (only local batch)
# - Gradients: Same on all GPUs (all-reduced after backward)

Memory savings: Activation memory is distributed, but model/optimizer memory is not.

ZeRO-2 (DistMuon/DistAdamW)

nanochat uses ZeRO-2 for optimizer state:

# From nanochat/muon.py
class DistMuon:
    """Distributed Muon with ZeRO-2: shard optimizer state across ranks"""
    
    def __init__(self, params, lr, momentum):
        # Each rank owns a shard of parameters
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        
        # Shard parameters across ranks
        self.owned_params = [p for i, p in enumerate(params) if i % self.world_size == self.rank]
        
        # Only allocate momentum for owned parameters
        self.momentum_buffers = [torch.zeros_like(p) for p in self.owned_params]

Memory savings:

  • Optimizer state: Divided by world_size
  • Gradients: Still all-reduced (not sharded in ZeRO-2)
  • Weights: Still replicated (not sharded in ZeRO-2)

Example (60M params, 8 GPUs):

  • Without ZeRO: 280MB optimizer state per GPU
  • With ZeRO-2: 280MB / 8 = 35MB optimizer state per GPU

Trade-off: Requires all-gather before optimizer step (communication overhead).

ZeRO-3 (Not in nanochat)

ZeRO-3 shards model weights too:

  • Weights: Divided by world_size
  • Gradients: Divided by world_size
  • Optimizer state: Divided by world_size

Memory per GPU = total_memory / world_size

Trade-off: Much more communication (all-gather on every forward pass).
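
Back-of-the-envelope per-GPU numbers for the 60M-parameter example on 8 GPUs, counting only weight, gradient, and optimizer memory (activations excluded). The sketch follows the sharding described above: ZeRO-2 shards only optimizer state, ZeRO-3 shards everything; it assumes bf16 weights/gradients and 8-byte Adam-style state.

def per_gpu_param_memory_mb(n_params, world_size, stage="ddp", weight_bytes=2, opt_bytes=8):
    """Per-GPU weight + gradient + optimizer-state memory under DDP / ZeRO-2 / ZeRO-3."""
    weights = grads = n_params * weight_bytes
    opt = n_params * opt_bytes
    if stage == "zero2":
        opt /= world_size                                   # shard optimizer state only
    elif stage == "zero3":
        weights, grads, opt = (x / world_size for x in (weights, grads, opt))  # shard everything
    return (weights + grads + opt) / 1024**2

for stage in ("ddp", "zero2", "zero3"):
    print(stage, f"{per_gpu_param_memory_mb(60e6, 8, stage):.0f} MB")
# ddp 687 MB, zero2 286 MB, zero3 86 MB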


KV caching trades memory for 10× inference speed

KV Caching

From Post 1.3, KV caching reuses past key/value projections:

class KVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0
    
    def insert_kv(self, layer_idx, k, v):
        # Lazily allocate the cache on first use, matching k's dtype/device
        if self.kv_cache is None:
            self.kv_cache = torch.zeros(self.kv_shape, dtype=k.dtype, device=k.device)
        
        # Dynamically grow the sequence dimension if needed
        t = k.size(2)
        if self.pos + t > self.kv_cache.size(4):
            # Grow by at least 1024 tokens, rounded up to a multiple of 1024
            new_size = ((self.pos + t + 1024) + 1023) & ~1023
            new_shape = list(self.kv_cache.shape)
            new_shape[4] = new_size
            grown = torch.zeros(new_shape, dtype=self.kv_cache.dtype, device=self.kv_cache.device)
            grown[:, :, :, :, :self.kv_cache.size(4)] = self.kv_cache
            self.kv_cache = grown
        
        # Insert new k, v at the current position
        self.kv_cache[layer_idx, 0, :, :, self.pos:self.pos+t] = k
        self.kv_cache[layer_idx, 1, :, :, self.pos:self.pos+t] = v
        self.pos += t
        
        # Return the cached keys/values up to the current position
        return self.kv_cache[layer_idx, 0, :, :, :self.pos], self.kv_cache[layer_idx, 1, :, :, :self.pos]

Memory cost:

# For a 12-layer, 6-head, 128-dim model:
# batch_size=1, max_seq_len=2048
kv_memory = 2 * 12 * 6 * 2048 * 128 * 2  # (K+V) * layers * heads * seq * dim * bytes(bf16)
# ≈ 70 MB per sequence

Batch size impact: KV cache scales linearly with batch size:

  • batch_size=1: 70 MB
  • batch_size=8: 560 MB
  • batch_size=64: 4.5 GB
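
The same calculation as a reusable helper; the dimensions match the example above, and bf16 (2 bytes per element) is assumed:

def kv_cache_mb(batch_size, n_layers, n_heads, seq_len, head_dim, bytes_per_el=2):
    """K+V cache memory: two tensors per layer of shape (batch, heads, seq, head_dim)."""
    return 2 * n_layers * batch_size * n_heads * seq_len * head_dim * bytes_per_el / 1024**2

for b in (1, 8, 64):
    print(b, f"{kv_cache_mb(b, n_layers=12, n_heads=6, seq_len=2048, head_dim=128):.0f} MB")
# 1: 72 MB, 8: 576 MB, 64: 4608 MB (~4.5 GB)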

Inference Batch Size Tuning

For generation, batch size trades throughput vs latency:

# Low batch size: low latency, low throughput
engine.generate_batch(tokens, num_samples=1)  # 1 sequence at a time
 
# High batch size: high latency, high throughput
engine.generate_batch(tokens, num_samples=64) # 64 sequences in parallel

Memory scaling:

memory_per_batch = weights + (batch_size × kv_cache_per_seq) + (batch_size × activations_per_seq)

For serving: Use largest batch size that fits in memory (maximizes throughput).
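
One practical way to find that batch size is to double num_samples until generation runs out of memory, then back off to the last size that fit. A rough sketch; the doubling search itself is an illustration, not part of nanochat:

import torch

def find_max_num_samples(engine, tokens, start=1, limit=256):
    """Double num_samples until generation OOMs; return the largest size that fit."""
    best, n = 0, start
    while n <= limit:
        try:
            torch.cuda.empty_cache()
            engine.generate_batch(tokens, num_samples=n)
            best, n = n, n * 2
        except torch.cuda.OutOfMemoryError:
            break
    return best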

Prefill vs Decode

nanochat separates prefill (process prompt) and decode (generate tokens):

# Prefill: batch_size=1, full prompt at once
kv_cache_prefill = KVCache(batch_size=1, seq_len=len(prompt))
logits = model.forward(prompt_tokens, kv_cache=kv_cache_prefill)
 
# Decode: batch_size=num_samples, one token at a time
kv_cache_decode = KVCache(batch_size=num_samples, seq_len=max_gen_len)
kv_cache_decode.prefill(kv_cache_prefill)  # Copy from prefill
 
# Generate multiple samples in parallel
for step in range(max_gen_len):
    logits = model.forward(next_tokens, kv_cache=kv_cache_decode)
    next_tokens = sample(logits)

Memory optimization: Prefill uses batch_size=1 (save KV cache memory), then replicate for decode.


Gradient checkpointing trades compute for 3× less activation memory

The Problem

Activation memory dominates for large models:

# Forward pass stores activations for backward pass
x = input
for layer in model.layers:
    x = layer(x)  # Activation stored in memory

For a 20-layer model, 20 sets of activations stored simultaneously.

Gradient Checkpointing

Idea: Don't store all activations. Recompute them during backward pass.

# Without checkpointing: store all activations
x = input
activations = []
for layer in layers:
    x = layer(x)
    activations.append(x)  # Store for backward
loss = criterion(x, target)
loss.backward()  # Uses stored activations
 
# With checkpointing: store only some activations
x = input
checkpoints = []
for i, layer in enumerate(layers):
    if i % checkpoint_every == 0:
        checkpoints.append(x)  # Checkpoint
    x = layer(x)  # Don't store
loss = criterion(x, target)
loss.backward()  # Recomputes missing activations on-the-fly

Memory/compute trade-off:

  • Memory saved: stored activations shrink to roughly 1/checkpoint_every of the original (e.g., checkpoint every 4 layers → ~4x savings)
  • Compute cost: ~33% increase (recompute during backward)

PyTorch Implementation

import torch.utils.checkpoint as checkpoint
 
class TransformerWithCheckpointing(nn.Module):
    def forward(self, x):
        for layer in self.layers:
            # Checkpoint this layer (no activations stored)
            x = checkpoint.checkpoint(layer, x, use_reentrant=False)
        return x

nanochat Status

Not currently implemented in nanochat (focus on simplicity), but would add:

# Hypothetical addition to nanochat/gpt.py
class GPT(nn.Module):
    def __init__(self, config, use_checkpointing=False):
        self.use_checkpointing = use_checkpointing
        # ...
    
    def forward(self, x):
        for block in self.transformer.h:
            if self.use_checkpointing:
                x = checkpoint.checkpoint(block, x, cos_sin, kv_cache, use_reentrant=False)
            else:
                x = block(x, cos_sin, kv_cache)
        return x

When to use:

  • Very large models (billions of parameters)
  • Long sequences (4K+ tokens)
  • Memory-constrained hardware

When NOT to use:

  • Small models (overhead dominates)
  • Short sequences (activation memory is small)
  • When training speed is critical

torch.cuda.memory_summary shows exactly where memory goes

PyTorch Memory Stats

import torch
 
# Track peak memory
torch.cuda.reset_peak_memory_stats()
 
# ... training code ...
 
peak_memory_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
print(f"Peak memory: {peak_memory_mb:.2f} MB")
 
# Detailed stats
print(torch.cuda.memory_summary())

Output example:

|===========================================================================|
|                  PyTorch CUDA memory summary                              |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |
|        Allocation retries: 0       |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   8.2 GB   |  12.5 GB   |  50.3 GB   |  42.1 GB   |
| Active memory         |   8.2 GB   |  12.5 GB   |            |            |
| ...

nanochat's Usage

# From scripts/base_train.py
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")

Profiling Individual Components

def profile_memory(model, batch):
    torch.cuda.reset_peak_memory_stats()
    
    # Forward pass
    loss = model(batch)
    fwd_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"After forward: {fwd_memory:.2f} MB")
    
    # Backward pass
    torch.cuda.reset_peak_memory_stats()
    loss.backward()
    bwd_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"After backward: {bwd_memory:.2f} MB")
    
    # Optimizer step
    torch.cuda.reset_peak_memory_stats()
    optimizer.step()
    opt_memory = torch.cuda.max_memory_allocated() / 1024**2
    print(f"After optimizer: {opt_memory:.2f} MB")

PyTorch Profiler (Advanced)

from torch.profiler import profile, ProfilerActivity
 
with profile(activities=[ProfilerActivity.CUDA], profile_memory=True) as prof:
    for step in range(10):
        loss = model(x, y)
        loss.backward()
        optimizer.step()
 
# Export to Chrome trace
prof.export_chrome_trace("trace.json")
 
# Or print table
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

Output:

---------------------------------  ------------  ------------  
Name                               Self CUDA Mem Peak CUDA Mem
---------------------------------  ------------  ------------  
aten::matmul                       2.50 GB      2.50 GB      
aten::linear                       1.20 GB      1.20 GB      
aten::softmax                      0.80 GB      0.80 GB      
...

Debugging OOM Errors

try:
    loss = model(x, y)
    loss.backward()
except RuntimeError as e:
    if "out of memory" in str(e):
        print(f"OOM! Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        print(torch.cuda.memory_summary())
        
        # Try to recover
        torch.cuda.empty_cache()
        # Reduce batch size and retry
        batch_size = batch_size // 2
    else:
        raise e

These patterns prevent OOM errors

Best Practices

1. Start Conservative, Scale Up

# Good: Start with small batch, measure memory, then scale
device_batch_size = 8   # Start small
# ... measure peak memory ...
# If memory allows, increase to 16, 32, etc.
 
# Bad: Start with huge batch, OOM immediately
device_batch_size = 128  # Instant OOM

2. Use Mixed Precision by Default

# Good: Always use bfloat16 for modern GPUs
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
 
# Bad: Stick to fp32 (uses 2x memory for no reason)
# ... no autocast, everything in fp32 ...

3. Gradient Accumulation for Large Batches

# Good: Simulate large batch with gradient accumulation
device_batch_size = 16  # Fits in memory
grad_accum_steps = 32   # Effective batch = 16 × 32 = 512
 
# Bad: Try to fit huge batch directly
device_batch_size = 512  # OOM

4. Clear Gradients Properly

# Good: Free gradient memory immediately
model.zero_grad(set_to_none=True)
 
# Okay: Zero gradients (slightly slower)
model.zero_grad()
 
# Bad: Don't clear gradients (memory leak)
# ... no zero_grad call ...

5. Monitor Memory Throughout Training

# Good: Log memory usage periodically
if step % 100 == 0:
    mem_mb = torch.cuda.memory_allocated() / 1024**2
    wandb.log({"memory_mb": mem_mb})
 
# Bad: Never check memory (discover OOM at step 5000)

Common Pitfalls

Pitfall 1: Forgetting to Scale Loss in Gradient Accumulation

# Bad: Gradients are grad_accum_steps × too large
for micro_step in range(grad_accum_steps):
    loss = model(x, y)
    loss.backward()  # Oops, no scaling!
 
# Good: Scale loss by grad_accum_steps
for micro_step in range(grad_accum_steps):
    loss = model(x, y)
    loss = loss / grad_accum_steps
    loss.backward()

Pitfall 2: Storing Unnecessary Tensors

# Bad: Accumulating losses keeps activation graphs in memory
losses = []
for x, y in data_loader:
    loss = model(x, y)
    losses.append(loss)  # Stores computation graph!
 
# Good: Detach scalars
losses = []
for x, y in data_loader:
    loss = model(x, y)
    losses.append(loss.detach().item())  # No graph, just float

Pitfall 3: Inefficient KV Cache Initialization

# Bad: Preallocate huge cache upfront
kv_cache = torch.zeros((layers, 2, batch, heads, 100000, dim))  # 100K seq len!
 
# Good: Grow dynamically as needed
kv_cache = torch.zeros((layers, 2, batch, heads, prompt_len, dim))
# ... grow as generation proceeds ...

Pitfall 4: Mixed Precision Without Autocast

# Bad: Manually cast everything (error-prone, verbose)
x = x.to(dtype=torch.bfloat16)
y = model(x)
y = y.to(dtype=torch.float32)
 
# Good: Use autocast context (automatic, correct)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = model(x)

Pitfall 5: Not Clearing CUDA Cache

# Bad: Cache fragmentation causes OOM over time
for step in range(10000):
    # ... training ...
    # Memory gets fragmented, eventual OOM
 
# Good: Periodically clear cache
for step in range(10000):
    # ... training ...
    if step % 1000 == 0:
        torch.cuda.empty_cache()

Pitfall 6: Large Batch Size on Small Models

# Bad: batch_size=256 on a 10M param model
# Problem: Activation memory >> model memory (wasteful)
 
# Good: Use batch size proportional to model size
# Small model (10M): batch_size=16-32
# Medium model (100M): batch_size=32-64
# Large model (1B+): batch_size=64-128

Memory optimization is the art of doing more with less

You can train larger models, longer sequences, and bigger batches on the hardware you already have.

Key techniques:

  1. Gradient Accumulation: Free larger effective batch sizes
  2. Mixed Precision (bfloat16): 2x memory savings with no quality loss
  3. Optimizer Choice: Muon saves 50% vs AdamW on matrix parameters
  4. Sequence Length Management: Start short, grow gradually
  5. ZeRO-2: Shard optimizer state across GPUs
  6. KV Caching: Reuse past keys/values during generation
  7. Gradient Checkpointing: Trade compute for memory (33% slower, 4x less memory)
  8. Memory Profiling: Measure before optimizing

Memory hierarchy:

Model Weights (fixed) < Optimizer State (fixed) << Activations (scales with batch)

Optimization priority:

  1. Enable mixed precision (bfloat16) → instant 2x savings
  2. Tune device_batch_size + grad_accum_steps → maximize GPU utilization
  3. Choose efficient optimizer (Muon for Transformers) → 40% optimizer memory savings
  4. Use distributed training (DDP + ZeRO-2) → linear scaling across GPUs
  5. Add gradient checkpointing (if desperate) → 4x activation memory savings

With these techniques, you can train 2-4x larger models on the same hardware—or train faster with larger batches and more aggressive settings.


Before you optimize your memory usage:

  1. Enable bfloat16 first. This is free 2× memory savings with zero quality loss—do this before anything else.
  2. Profile before optimizing. Use torch.cuda.memory_allocated() to identify whether activations, optimizer state, or model weights are your bottleneck.
  3. Start with small batch size, scale up. Begin at batch_size=8, measure peak memory, then double until you approach 90% utilization.
  4. Scale loss during gradient accumulation. Divide by grad_accum_steps—forgetting this makes gradients N× too large.
  5. Use set_to_none=True in zero_grad. Frees gradient memory immediately instead of just zeroing values.

The GPU you have is more capable than you think. Now you know how to unlock it.


Sources

Research Papers

  • ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (Rajbhandari et al., 2019) - The foundational paper introducing Zero Redundancy Optimizer, which partitions optimizer states, gradients, and parameters across data-parallel processes to dramatically reduce memory footprint. arXiv:1910.02054

  • Mixed Precision Training (Micikevicius et al., 2017) - Introduces techniques for training deep neural networks using half-precision floating point numbers while maintaining model accuracy, reducing memory consumption by nearly 2x. Published at ICLR 2018. arXiv:1710.03740

  • Training Deep Nets with Sublinear Memory Cost (Chen et al., 2016) - Proposes gradient checkpointing, a systematic approach to reduce memory consumption from O(n) to O(√n) for training n-layer networks, enabling deeper models on limited hardware. arXiv:1604.06174

  • ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning (Rajbhandari et al., 2021) - Extends ZeRO with CPU and NVMe offloading capabilities to train models with trillions of parameters. arXiv:2104.07857

  • Muon: An optimizer for hidden layers in neural networks (Kosson et al., 2025) - Describes the Muon optimizer that maintains significantly less state than Adam (only momentum vs. both first and second moments), reducing optimizer memory requirements. arXiv:2502.16982

Technical Documentation

  • PyTorch Automatic Mixed Precision (torch.amp) - Official PyTorch documentation for the AMP package, covering autocast context managers, GradScaler for gradient scaling, and op-specific behavior for mixed precision training. PyTorch AMP Docs

  • DeepSpeed ZeRO Tutorial - Comprehensive guide to implementing ZeRO stages 1-3, including configuration examples, memory savings calculations, and ZeRO-Infinity offloading to CPU and NVMe. DeepSpeed ZeRO Tutorial

  • DeepSpeed Configuration Reference - Complete documentation for DeepSpeed JSON configuration options including ZeRO optimizations, FP16/BFloat16 training, and optimizer settings. DeepSpeed Config JSON

Framework Resources

  • PyTorch Gradient Checkpointing - Documentation for torch.utils.checkpoint which implements activation checkpointing to trade compute for memory during training. PyTorch Checkpoint Utils

  • NVIDIA Apex - NVIDIA's PyTorch extension library providing optimized mixed precision and distributed training utilities. NVIDIA Apex GitHub

  • Microsoft DeepSpeed - Deep learning optimization library that implements ZeRO, mixed precision training, and various memory optimization techniques. DeepSpeed GitHub

GPU Hardware & Pricing (as of January 2025)

| GPU | VRAM | Memory Bandwidth | Typical Cost/hr | Source |
|---|---|---|---|---|
| RTX 3090 | 24GB | 936 GB/s | Consumer purchase | NVIDIA RTX 3090 |
| RTX 4090 | 24GB | 1 TB/s | Consumer purchase | NVIDIA RTX 4090 |
| A100 80GB | 80GB | 2 TB/s | ~$1.44/hr | Lambda Labs |
| H100 80GB | 80GB | 3.35 TB/s | ~$2.49/hr | Lambda Labs |

Industry Research (as of January 2025)

  • Epoch AI Training Compute: Compute Trends in Machine Learning. Tracks memory efficiency improvements; shows 3× memory efficiency gains in 2024 vs 2022 baseline.
  • MLCommons MLPerf Training: Training Benchmark Results. Industry-standard benchmarks showing memory efficiency across hardware configurations.


Exercises

  1. Measure activation memory: Profile your model and measure what percentage of memory is activations vs weights vs optimizer state.

  2. Gradient accumulation experiment: Train with batch_size=32, grad_accum=1 vs batch_size=16, grad_accum=2. Verify identical convergence.

  3. Mixed precision ablation: Train with fp32 vs bf16. Compare memory usage, training speed, and final validation loss.

  4. Sequence length scaling: Measure peak memory as you scale seq_len from 512 → 1024 → 2048 → 4096. Plot memory vs seq_len².

  5. Optimizer state comparison: Train identical model with AdamW vs Muon. Compare optimizer state memory (use torch.cuda.memory_allocated()).

  6. KV cache growth: Implement dynamic KV cache growth and measure memory usage during generation from 0 to 2048 tokens.


Series Complete! You've completed Track 2 (Practical Guides). Congratulations on mastering the practical implementation of nanochat!


Part of the nanochat Deep-Dive Series • Code: nanochat on GitHub