José David Baena

Distributed Muon: Custom Gradient Synchronization for Memory-Efficient Training


Standard DDP wastes 8× memory—every GPU stores everything

The DistMuon implementation is one of the most sophisticated parts of nanochat's codebase. Understanding how ZeRO-2 sharding works with Muon's matrix structure requirements taught me more about distributed training than any paper.

Training large language models requires distributing computation across multiple GPUs. While PyTorch's DDP makes this conceptually simple, it comes with significant memory overhead—every GPU stores a complete copy of the model parameters, gradients, and optimizer states.

TL;DR: For a 1 billion parameter model on 8 GPUs, standard DDP means 8× redundant copies of everything. ZeRO eliminates this redundancy by sharding optimizer states across GPUs—but Muon's Newton-Schulz orthogonalization requires preserving 2D matrix structure. DistMuon solves both problems.

The OOM that stopped a training run: Consider a scenario common in distributed training: 40% through training a 2.7B model on 8×A100s, GPU memory usage creeps up and crashes the run. Standard DDP means each GPU holds a full copy of optimizer states—18GB of redundant memory across 8 GPUs. Switching to DeepSpeed ZeRO-2 seems like the fix, but Muon's Newton-Schulz iterations fail silently because the 2D matrix structure is broken by naive sharding. Gradients look fine. Optimizer states are garbage. The issue: you can't shard a 2D weight matrix along arbitrary dimensions without breaking Muon's orthogonalization. DistMuon's "preserve matrix structure" sharding was designed specifically to avoid this failure mode.

⚠️ Note on Muon Optimizer

The DistMuon implementation referenced in this tutorial is based on the Muon optimizer approach from the nanochat research codebase. This post demonstrates distributed training techniques combining ZeRO-2 optimization with Newton-Schulz orthogonalization for educational purposes.

For production use:

  • This serves as a learning resource for understanding custom distributed optimizer design
  • Standard distributed optimizers (DeepSpeed ZeRO, FSDP) are recommended for production workloads
  • The nanochat DistMuon implementation is experimental and designed for research/education
  • Consider established frameworks (DeepSpeed, Megatron-LM) for production-scale training

NOTE

Prerequisites: Understanding of the Muon optimizer and basic distributed training concepts. Reading time: ~12 minutes.

DistMuon achieves:

  • ~2-3× memory savings compared to standard DDP
  • Seamless integration with existing training loops
  • Custom reduce_scatter → compute → all_gather pattern optimized for Muon

Every rank stores full model, gradients, and optimizer states

Standard DDP's Synchronization Model

PyTorch DDP follows a simple but memory-inefficient pattern:

Standard DDP Training Pattern
# Pseudo-code for standard DDP
for step in training_loop:
    loss = model(x, y)
    loss.backward()  # Compute gradients locally
    # DDP hooks: all_reduce gradients (implicit, happens in backward)
    optimizer.step()  # Each rank updates full model independently

During backward(), DDP's hooks automatically trigger an all_reduce operation that averages gradients across all ranks. This ensures every GPU has identical gradients before the optimizer step.
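
Conceptually, those hooks are doing the gradient averaging for you. The sketch below is illustrative, not DDP's actual bucketed implementation: it shows the equivalent manual step, while real DDP buckets gradients and overlaps the all-reduce with the backward pass.

Manual Gradient Averaging (what DDP does implicitly)
import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module) -> None:
    # Roughly what DDP's backward hooks do: average every gradient across ranks.
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)  # sum gradients from all ranks
            p.grad.div_(world_size)                        # divide to get the mean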

Memory Overhead Analysis

For each parameter, every rank stores:

  1. Parameters (model weights): P bytes
  2. Gradients: P bytes
  3. Optimizer states: Depends on optimizer
    • Adam/AdamW: 2 states (exp_avg, exp_avg_sq) = 2P bytes
    • Muon: 1 state (momentum_buffer) = P bytes

Total per rank (Muon + DDP): P + P + P = 3P bytes
Total across N ranks: 3P × N bytes

For a 1B parameter model (2GB in bfloat16) on 8 GPUs:

Memory per rank = 2GB (params) + 2GB (grads) + 2GB (momentum) = 6GB
Total memory = 6GB × 8 = 48GB

This redundancy is wasteful. Can we do better?

What ZeRO-2 Offers

ZeRO has three stages of optimization:

  • Stage 1: Shard optimizer states across ranks
  • Stage 2: Shard gradients + optimizer states
  • Stage 3: Shard parameters + gradients + optimizer states

DistMuon implements ZeRO-2, keeping parameters replicated but sharding gradients and optimizer states. This strikes a balance between memory efficiency and implementation complexity.

Memory per rank (ZeRO-2): P + P/N + P/N = P(1 + 2/N) bytes

For our 1B parameter example with 8 GPUs:

Memory per rank = 2GB (params) + 0.25GB (grads/8) + 0.25GB (momentum/8) = 2.5GB
Total memory = 2.5GB × 8 = 20GB
Savings: 48GB - 20GB = 28GB (58% reduction!)

For your training cluster, this means: ZeRO-2 can be the difference between "needs 8 GPUs" and "fits on 4". Same model, less than half the per-rank memory for gradients and optimizer state. The extra communication it introduces is a small price for the memory it frees.
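
To make the arithmetic above concrete, here is a small standalone sketch (not nanochat code) that reproduces the 48GB-vs-20GB totals:

Memory Arithmetic Check (illustrative)
def muon_memory_gb(params_billions: float, world_size: int, bytes_per_param: float = 2):
    # GB of parameters = billions of params × bytes per param (bfloat16 throughout)
    p_gb = params_billions * bytes_per_param
    ddp_per_rank = 3 * p_gb                           # params + grads + momentum, all replicated
    zero2_per_rank = p_gb * (1 + 2 / world_size)      # grads and momentum sharded 1/N
    return dict(
        ddp_per_rank=ddp_per_rank, ddp_total=ddp_per_rank * world_size,
        zero2_per_rank=zero2_per_rank, zero2_total=zero2_per_rank * world_size,
    )

print(muon_memory_gb(1.0, 8))
# {'ddp_per_rank': 6.0, 'ddp_total': 48.0, 'zero2_per_rank': 2.5, 'zero2_total': 20.0}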

Distributed Training Simulator

(Interactive widget: visualizes how DDP and FSDP distribute work across GPUs. In the DDP view, each GPU holds a full model copy and gradients are averaged after the backward pass, so the only communication is the gradient all-reduce. Example snapshot with 4 GPUs: 16.0 GB per GPU, 100% of the full model on every device, 64.0 GB total cluster memory, 25% memory efficiency.)

Memory Savings Calculator

Memory per GPU by strategy (interactive calculator; example shown: a 1.5B model on 8× 80 GB GPUs, with BF16 parameters, FP32 gradients, Adam optimizer states, and activations):

| Strategy | Memory per GPU | Savings |
| --- | --- | --- |
| No sharding (full model copy on each GPU) | 33.9 GB | – |
| ZeRO Stage 1 (optimizer states partitioned) | 23.4 GB | 31% |
| ZeRO Stage 2 (optimizer + gradients partitioned) | 18.1 GB | 46% |
| ZeRO Stage 3 / FSDP (everything partitioned across GPUs) | 15.5 GB | 54% |
| Tensor parallelism (layers split across GPUs) | 7.2 GB | 79% |

Every strategy fits under the 80 GB per-GPU limit here; the calculator's recommendation for this configuration is tensor parallelism, at 7.2 GB per GPU (79% savings).

Three design decisions make DistMuon work

Parameter Grouping by Shape

DistMuon groups all parameters by their shape before assigning ownership.

From the nanochat codebase (view on GitHub):

nanochat/muon.py - Parameter Grouping
rank = dist.get_rank()
shapes = sorted({p.shape for p in params})  # Unique shapes, sorted
param_groups = []
for shape in shapes:
    group_params = [p for p in params if p.shape == shape]
    device, dtype = group_params[0].device, group_params[0].dtype
    assert all(p.device == device for p in group_params)
    assert all(p.dtype == dtype for p in group_params)
    if rank == 0:
        print(f"Muon: Grouping {len(group_params)} params of shape {shape}")
    param_groups.append(dict(
        params=group_params,
        zero_buffer=torch.zeros_like(group_params[0])
    ))

Why group by shape?

  1. Efficient batched operations: Newton-Schulz can process multiple matrices of the same shape simultaneously
  2. Simplified communication: reduce_scatter and all_gather require uniform tensor shapes
  3. Better GPU utilization: Batched matrix operations maximize throughput

TIP

Example: A transformer model might have 100 parameters of shape [768, 768] (attention matrices), 50 of shape [3072, 768] (FFN matrices), and 50 of shape [768, 3072]. DistMuon creates 3 parameter groups, enabling efficient batched Newton-Schulz within each group.
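
As a toy illustration of that grouping, using only the hypothetical shape counts from the example above:

Shape Grouping Example (illustrative)
from collections import defaultdict

# Hypothetical parameter shapes from the example: attention matrices plus two FFN shapes
shapes = [(768, 768)] * 100 + [(3072, 768)] * 50 + [(768, 3072)] * 50

groups = defaultdict(list)
for i, shape in enumerate(shapes):
    groups[shape].append(i)   # param indices sharing this shape end up in one group

for shape in sorted(groups):
    print(f"group {shape}: {len(groups[shape])} params")
# group (768, 768): 100 params
# group (768, 3072): 50 params
# group (3072, 768): 50 params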

Parameters map to GPUs in a round-robin pattern—here's why that matters

Within each shape group, parameters are assigned to ranks in a block-cyclic pattern:

Block-Cyclic Assignment Pattern
world_size = dist.get_world_size()
for base_i in range(0, len(params), world_size):
    owner_idx = base_i + rank  # Each rank owns param at (base + rank)

Visual representation (4 GPUs, 10 parameters):

Param indices:  [0, 1, 2, 3,  4, 5, 6, 7,  8, 9]
                 └─────────┘  └─────────┘  └──┘
                  Block 0      Block 1    Block 2

Rank 0 owns:    [0,          4,          8    ]  ← indices 0, 4, 8
Rank 1 owns:    [   1,          5,          9 ]  ← indices 1, 5, 9
Rank 2 owns:    [      2,          6          ]  ← indices 2, 6
Rank 3 owns:    [         3,          7       ]  ← indices 3, 7
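
A quick way to sanity-check this layout is to print the ownership map yourself; the standalone sketch below (not nanochat code) reproduces the table above.

Block-Cyclic Ownership Map (illustrative)
def ownership_map(num_params: int, world_size: int) -> dict:
    # Within each block of `world_size` params, rank r owns index base_i + r (if it exists).
    owned = {r: [] for r in range(world_size)}
    for base_i in range(0, num_params, world_size):
        for r in range(world_size):
            owner_idx = base_i + r
            if owner_idx < num_params:
                owned[r].append(owner_idx)
    return owned

print(ownership_map(10, 4))
# {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}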

Why block-cyclic?

  • Load balancing: Distributes parameters roughly evenly across ranks
  • Simplicity: Each rank's ownership is a simple calculation (base_i + rank)
  • Graceful remainders: A short final block is zero-padded rather than special-cased, so uneven parameter counts just work

For your debugging sessions, this means: when training stalls or a rank crashes, you can trace exactly which parameters each GPU owns. Block-cyclic assignment makes ownership deterministic—same layout every run.

The Three-Phase Update Pattern

DistMuon orchestrates a custom communication pattern that shards computation while maintaining parameter replication:


Phase 1: Reduce-scatter averages gradients and distributes ownership

The reduce_scatter operation combines gradients from all ranks and distributes the averaged results.

From the nanochat codebase (view on GitHub):

nanochat/muon.py - Reduce-Scatter Phase
all_reduce_futures = []
for group in self.param_groups:
    params = group["params"]
    zero_buffer = group["zero_buffer"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Each rank collects gradients for world_size consecutive params
        rs_input = [p.grad for p in params[base_i:base_i + world_size]]
        # Pad with zeros if we don't have enough params
        rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
        # Output buffer: gradient for the param this rank owns
        rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
        # Launch async reduce_scatter
        work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
        all_reduce_futures.append(work)

What happens here:

  1. Input collection: Each rank gathers gradients for a block of world_size parameters
  2. Padding: If the block is incomplete (e.g., last block with fewer params), pad with zero_buffer
  3. Reduce-scatter: All ranks participate in averaging gradients
  4. Output: Each rank receives the averaged gradient for its owned parameter

Example (4 GPUs, block 0 with params [0,1,2,3]):

Rank 0: rs_input = [grad₀[0], grad₀[1], grad₀[2], grad₀[3]]  → rs_output = avg(grad[0])
Rank 1: rs_input = [grad₁[0], grad₁[1], grad₁[2], grad₁[3]]  → rs_output = avg(grad[1])
Rank 2: rs_input = [grad₂[0], grad₂[1], grad₂[2], grad₂[3]]  → rs_output = avg(grad[2])
Rank 3: rs_input = [grad₃[0], grad₃[1], grad₃[2], grad₃[3]]  → rs_output = avg(grad[3])

After reduce-scatter, each rank has the averaged gradient for its owned parameter, ready for computation.
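
To see the block-plus-padding logic without launching a process group, here is a single-process simulation (illustrative only) of what each rank ends up holding for one shape group:

Single-Process Reduce-Scatter Simulation (illustrative)
import torch

world_size, num_params = 4, 10
# grads[r][i] plays the role of rank r's local gradient for parameter i
grads = [[torch.randn(3, 3) for _ in range(num_params)] for _ in range(world_size)]

for base_i in range(0, num_params, world_size):
    for rank in range(world_size):
        owner_idx = base_i + rank
        if owner_idx >= num_params:
            # Padded slot: the real code feeds zero_buffer here and discards the result
            continue
        # ReduceOp.AVG delivers the cross-rank mean of the owner's gradient to that rank
        avg = torch.stack([grads[r][owner_idx] for r in range(world_size)]).mean(dim=0)
        print(f"rank {rank} owns param {owner_idx}: averaged grad shape {tuple(avg.shape)}")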


Phase 2: Only the owner rank computes Newton-Schulz orthogonalization

Once gradients are averaged, each rank computes the Muon update for its owned parameters.

From the nanochat codebase (view on GitHub):

nanochat/muon.py - Compute Phase
future_idx = 0
all_gather_futures = []
for group in self.param_groups:
    params = group["params"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Wait for reduce_scatter to complete
        all_reduce_futures[future_idx].wait()
        future_idx += 1
        
        # Only owner computes the update
        if owner_idx < len(params):
            p = params[owner_idx]
            g = p.grad  # Already averaged across ranks
            state = self.state[p]
            
            # Initialize momentum buffer if needed
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            
            buf = state["momentum_buffer"]
            # Momentum accumulation
            buf.lerp_(g, 1.0 - group["momentum"])
            # Nesterov momentum (if enabled)
            g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
            # Newton-Schulz orthogonalization
            g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
            # Aspect-ratio scaled step
            scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
            p.add_(g, alpha=-group["lr"] * scale)

Key points:

  1. Wait synchronization: wait() ensures the gradient is ready before computation
  2. Owner-only execution: Non-owner ranks skip computation (idle during this phase)
  3. Standard Muon update: Same as single-GPU Muon (see Post 1.1)
    • Momentum accumulation with lerp_
    • Optional Nesterov momentum
    • Newton-Schulz orthogonalization
    • Aspect-ratio scaling: sqrt(max(1, height/width))

NOTE

Memory efficiency: Each rank stores momentum_buffer only for its owned parameters, achieving 1/N sharding.
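
The compute phase calls zeropower_via_newtonschulz5 without defining it. For reference, here is a sketch along the lines of the widely shared Muon implementation; the exact coefficients, dtype handling, and batching in nanochat may differ.

Newton-Schulz Orthogonalization Sketch (reference-style, may differ from nanochat)
import torch

@torch.no_grad()
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Approximates the orthogonal polar factor U @ V^T of a 2D matrix G with a
    # quintic Newton-Schulz iteration (coefficients from the public Muon reference).
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()                       # production code typically runs this in bfloat16
    transposed = G.size(0) > G.size(1)
    if transposed:                      # keep the short side first for cheaper matmuls
        X = X.T
    X = X / (X.norm() + 1e-7)           # normalize so the iteration converges
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X.to(G.dtype)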


Phase 3: All-gather broadcasts updated parameters to all GPUs

After computing updates, ranks replicate their updated parameters to all other ranks.

From the nanochat codebase (view on GitHub):

nanochat/muon.py - All-Gather Phase
        # Replicate updated parameters to all ranks
        ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
        ag_output = params[base_i:base_i + world_size]
        ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))])
        work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
        all_gather_futures.append(work)
 
# Wait for all gathers to complete (outside the loops)
torch.futures.collect_all(all_gather_futures).wait()

What happens:

  1. Input: Each rank's owned parameter (or zero_buffer if padding)
  2. Output: List of tensors to populate with gathered parameters
  3. All-gather: Broadcast each rank's parameter to all other ranks
  4. Synchronization: collect_all().wait() ensures all communications complete

After all-gather, every rank has identical copies of all parameters, ready for the next forward pass.


DistMuon vs DistAdamW: same pattern, different memory footprints

Both optimizers implement ZeRO-2, but their strategies differ due to algorithmic requirements.

Key Differences

| Feature | DistAdamW | DistMuon | Reason |
| --- | --- | --- | --- |
| Parameter requirements | Any shape | 2D only | Newton-Schulz needs matrices |
| Sharding strategy | Slice along dim 0 | Block-cyclic whole params | Preserve aspect ratio |
| State storage | Slice-local (exp_avg, exp_avg_sq) | Param-local (momentum_buffer) | Matrix operations |
| Compute pattern | All ranks on slices | Owner ranks only | Simplify NS batching |
| Reduce-scatter input | Full tensor | List of tensors | Shape uniformity |
| Memory efficiency | ~1/N states | ~1/N states | Similar overall |
| Load balance | Perfect (slicing) | Imperfect (padding) | Trade-off for simplicity |

DistAdamW's Sharding Approach

From the nanochat codebase (view on GitHub):

nanochat/adamw.py - Tensor Slicing Strategy
for base_i in range(len(params)):
    grad = params[base_i].grad
    rank_size = grad.shape[0] // world_size          # rows per rank: dim 0 sliced evenly
    grad_slice = torch.empty_like(grad[:rank_size])  # this rank's slice of the averaged gradient
    reduce_scatter_futures.append(
        dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True)
    )

DistAdamW slices each parameter along the first dimension, distributing rows across ranks. This works for any parameter shape and achieves perfect load balancing.

WARNING

Why doesn't DistMuon do this? Newton-Schulz requires the full 2D matrix structure with its original aspect ratio. Slicing a [768, 768] matrix into [192, 768] slices would change the aspect ratio from 1:1 to 1:4, breaking the orthogonalization geometry. DistMuon preserves matrices intact by assigning whole parameters to owner ranks.

Comparison: Sharding Granularity

Example: 4 GPUs, parameter shape [1024, 768]

DistAdamW:
┌───────────────────┐
│ Rank 0 (256 rows) │  Each rank stores:
├───────────────────┤  - param slice [256, 768]
│ Rank 1 (256 rows) │  - grad slice [256, 768]
├───────────────────┤  - exp_avg slice [256, 768]
│ Rank 2 (256 rows) │  - exp_avg_sq slice [256, 768]
├───────────────────┤
│ Rank 3 (256 rows) │
└───────────────────┘

DistMuon (within a block of 4 params):
Rank 0: param[0] [1024,768]  ← Full matrix
Rank 1: param[1] [1024,768]  ← Full matrix
Rank 2: param[2] [1024,768]  ← Full matrix
Rank 3: param[3] [1024,768]  ← Full matrix

Each rank stores:
- Full param (replicated)
- momentum_buffer for owned param only
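
To see concretely why slicing breaks the orthogonalization geometry, here is a small standalone check. It uses an exact SVD-based polar factor as a stand-in for Newton-Schulz: orthogonalizing row-slices independently is not the same as orthogonalizing the full matrix.

Why Slicing Breaks Orthogonalization (illustrative check)
import torch

def polar_factor(M: torch.Tensor) -> torch.Tensor:
    # Exact orthogonal polar factor U @ V^T -- the quantity Newton-Schulz approximates.
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    return U @ Vh

torch.manual_seed(0)
G = torch.randn(1024, 768)

full = polar_factor(G)                                            # whole matrix
sliced = torch.cat([polar_factor(s) for s in G.chunk(4, dim=0)])  # four [256, 768] row-slices

print(torch.allclose(full, sliced, atol=1e-5))        # False: the two updates genuinely differ
print(((full - sliced).norm() / full.norm()).item())  # large relative difference, not rounding noise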

The math: 58-67% savings, depending on GPU count

Memory Breakdown Per Rank

Standard DDP + Muon:

Parameters:          P bytes (full model)
Gradients:           P bytes (full model)
Momentum buffers:    P bytes (full model)
─────────────────────────────
Total per rank:      3P bytes
Total across N:      3P × N bytes

DistMuon (ZeRO-2):

Parameters:          P bytes (replicated)
Gradients:           P/N bytes (sharded)
Momentum buffers:    P/N bytes (sharded)
─────────────────────────────
Total per rank:      P(1 + 2/N) bytes
Total across N:      P(N + 2) bytes

Efficiency Calculations

| N ranks | DDP Total | DistMuon Total | Memory Savings | Savings % |
| --- | --- | --- | --- | --- |
| 2 | 6P | 4P | 2P | 33% |
| 4 | 12P | 6P | 6P | 50% |
| 8 | 24P | 10P | 14P | 58% |
| 16 | 48P | 18P | 30P | 63% |
| 64 | 192P | 66P | 126P | 66% |

Asymptotic behavior: As N → ∞, savings approach 67% (2/3 reduction).
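
The savings column follows directly from the two totals; a short sketch to reproduce it:

Savings Formula Check (illustrative)
def savings_fraction(n_ranks: int) -> float:
    # 1 - DistMuon total / DDP total = 1 - (N + 2) / (3N) = 2(N - 1) / (3N), which tends to 2/3
    return 1 - (n_ranks + 2) / (3 * n_ranks)

for n in (2, 4, 8, 16, 64):
    print(n, f"{savings_fraction(n):.1%}")
# 33.3%, 50.0%, 58.3%, 62.5%, 65.6% -- the table values above, before rounding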

Practical Example: nanochat's 270M Model

nanochat's depth-20 model has ~270M parameters in bfloat16 (540 MB total).

8× H100 GPUs (80GB each):

| Metric | Standard DDP | DistMuon | Savings |
| --- | --- | --- | --- |
| Params | 540 MB | 540 MB | 0 MB |
| Grads | 540 MB | 67.5 MB | 472.5 MB |
| States | 540 MB | 67.5 MB | 472.5 MB |
| Total/rank | 1.62 GB | 675 MB | 945 MB (58%) |
| Total/cluster | 12.96 GB | 5.4 GB | 7.56 GB |

This 945 MB per-rank saving allows:

  • Larger batch sizes (more memory for activations)
  • Longer sequences (quadratic attention memory)
  • Bigger models (fit 470M params with same memory as 270M DDP)

For your multi-GPU training, this means: sharded optimizer state unlocks training headroom you didn't know you had. That 945 MB per rank translates directly into longer context windows or deeper models without touching your hardware budget.


Async communication overlaps compute with network—here's how

Why Asynchronous Operations?

DistMuon uses async_op=True throughout to overlap communication with computation:

Asynchronous Communication Pattern
# Launch all reduce-scatters without waiting
for group in self.param_groups:
    for base_i in range(0, len(params), world_size):
        work = dist.reduce_scatter(..., async_op=True).get_future()
        all_reduce_futures.append(work)
 
# Compute and gather (wait only when needed)
for group in self.param_groups:
    for base_i in range(0, len(params), world_size):
        all_reduce_futures[future_idx].wait()  # Wait for specific gradient
        if owner_idx < len(params):
            # Compute Muon update
            ...
        work = dist.all_gather(..., async_op=True).get_future()
        all_gather_futures.append(work)
 
# Final synchronization
torch.futures.collect_all(all_gather_futures).wait()

Benefits:

  1. Communication-computation overlap: While GPU computes updates for earlier parameters, network transfers gradients for later parameters
  2. Pipelining: Reduce-scatter and all-gather operations can overlap across parameter groups
  3. Lower latency: Non-blocking calls prevent idle GPU time

Synchronization Pattern

Time ──────────────────────────────────────────────►

Rank 0:
  [reduce_scatter₀] [wait] [compute₀] [all_gather₀]
                    [reduce_scatter₁] [wait] [compute₁] [all_gather₁]
                                      [reduce_scatter₂] [wait] ...

Rank 1:
  [reduce_scatter₀] [wait] [compute₀] [all_gather₀]
                    [reduce_scatter₁] [wait] [compute₁] [all_gather₁]
                                      [reduce_scatter₂] [wait] ...

TIP

Key insight: By launching all reduce-scatters first, then processing them sequentially with compute + gather, we maximize overlap.


Drop-in replacement: your training loop stays the same

One of DistMuon's best features is its zero-friction integration.

From the training script (view on GitHub):

scripts/base_train.py - Optimizer Setup
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay
)
adamw_optimizer, muon_optimizer = optimizers

The setup_optimizers() method automatically selects DistMuon when running distributed:

Automatic DistMuon Selection Pattern
def setup_optimizers(self, ...):
    # Separate parameters by dimensionality
    matrix_params = [p for p in self.parameters() if p.ndim == 2]
    vector_params = [p for p in self.parameters() if p.ndim < 2]
    
    # Use Dist* optimizers if distributed, else regular
    if dist.is_initialized():
        from nanochat.muon import DistMuon
        from nanochat.adamw import DistAdamW
        muon_opt = DistMuon(matrix_params, lr=matrix_lr, ...)
        adamw_opt = DistAdamW([{"params": vector_params}], lr=embedding_lr, ...)
    else:
        muon_opt = Muon(matrix_params, lr=matrix_lr, ...)
        adamw_opt = AdamW(vector_params, lr=embedding_lr, ...)
    
    return adamw_opt, muon_opt

Training loop (unchanged):

Training Loop - No Special Handling Needed
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum
for opt in optimizers:
    opt.step()  # DistMuon.step() handles all communication!
model.zero_grad(set_to_none=True)

No special handling needed—just call opt.step() and DistMuon orchestrates all distributed operations internally.
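
The selection above hinges on whether a process group has been initialized before the optimizers are built. A minimal launch sketch (the file name and exact flags are illustrative, not nanochat's CLI):

Distributed Launch Sketch (illustrative)
# launch with: torchrun --nproc_per_node=8 train.py
import os
import torch
import torch.distributed as dist

if int(os.environ.get("WORLD_SIZE", "1")) > 1:
    dist.init_process_group(backend="nccl")               # makes dist.is_initialized() return True
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# ... build the model, then:
# optimizers = model.setup_optimizers(...)  # now selects DistMuon / DistAdamW automatically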


Performance: near-linear scaling with GPU count

Communication Cost Analysis

For P parameters and N ranks:

| Operation | Data Volume (per rank) | Time Complexity |
| --- | --- | --- |
| Reduce-scatter | Send: P/N, Recv: P/N | O(P/N) |
| Compute (Muon) | Local only | O(P/N) |
| All-gather | Send: P/N, Recv: P | O(P) |
| Total per step | Send: 2P/N, Recv: P(1 + 1/N) | O(P) |

Key takeaway: Communication scales linearly with parameter count, independent of N for large N.

Scaling Behavior

Weak scaling (increase model size proportionally with GPUs):

  • Near-linear: Memory per rank stays constant
  • Communication stays constant: Each rank sends/receives same amount

Strong scaling (fixed model size, increase GPUs):

  • ⚠️ Sub-linear: Communication overhead increases relative to computation
  • ⚠️ Sweet spot: 8-64 GPUs for typical Transformers
  • Poor at scale: Beyond 128 GPUs, communication dominates

Example: Training nanochat's 270M model

  • 8 GPUs: ~90% scaling efficiency
  • 64 GPUs: ~75% scaling efficiency
  • 512 GPUs: ~40% scaling efficiency (not recommended)

For your multi-GPU training: what this means

DistMuon demonstrates that domain-specific optimizations can significantly improve distributed training efficiency. By tailoring the ZeRO-2 pattern to Muon's unique needs—preserving matrix structure, enabling batched Newton-Schulz, and implementing block-cyclic assignment—nanochat achieves:

  1. 58-67% memory savings vs standard DDP (8-64 GPUs)
  2. Seamless integration with existing codebases
  3. Efficient scaling for typical Transformer training workloads

Key Takeaways

When to use DistMuon:

  • Training Transformers with 2D weight matrices
  • Memory-constrained multi-GPU setups (8-64 GPUs)
  • You want ZeRO-2 benefits without a DeepSpeed dependency

When to avoid:

  • Single GPU training (use regular Muon)
  • Models with mostly 1D parameters (use DistAdamW)
  • Extreme scale (>128 GPUs, consider ZeRO-3 or model parallelism)

Design Principles Worth Remembering

  1. Group by shape: Enable batched operations by processing uniform tensors together
  2. Block-cyclic assignment: Balance load while maintaining simplicity
  3. Async communication: Overlap network transfers with computation
  4. Preserve algorithmic invariants: Don't break Newton-Schulz by slicing matrices

Distributed training isn't about more GPUs; it's about smarter coordination, and DistMuon shows what that coordination looks like in practice.

What's Next in This Series

💾 Post 1.3: KV Caching Deep-Dive (Coming Soon)

Memory-efficient Transformer inference with prefill-and-clone patterns and dynamic cache growth.

🏗️ Post 1.4: Modern Architecture Choices (Coming Soon)

RoPE, QK normalization, Multi-Query Attention, and design trade-offs explained.


Industry Research & Hardware (as of January 2025)

GPU Cluster Pricing (as of January 2025)

| Configuration | Typical Cost | Provider Example |
| --- | --- | --- |
| 8× H100 (single node) | ~$24/hr | Lambda Labs |
| 8× A100 (single node) | ~$12/hr | Lambda Labs |
| 64× H100 (8 nodes) | ~$200/hr | Enterprise cloud (AWS, GCP, Azure) |

Before you implement distributed Muon:

  1. Verify NCCL connectivity first. Run a simple all-reduce test (a sketch follows this list); distributed gradient sync failures are brutal to debug mid-training.
  2. Measure single-GPU memory baseline. Know exactly how much memory AdamW uses before claiming DistMuon savings.
  3. Separate 2D parameters from 1D. Matrix parameters get Newton-Schulz; embedding layers stay with AdamW. Mix them wrong and training diverges.
  4. Log communication overhead. Profile all-reduce time vs compute time—if communication exceeds 20%, your network is the bottleneck.
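
A minimal connectivity check along the lines of item 1 (illustrative; launched with torchrun, not part of nanochat):

NCCL All-Reduce Sanity Check (illustrative)
# save as ddp_sanity.py and run: torchrun --nproc_per_node=8 ddp_sanity.py
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Each rank contributes its rank id; after all_reduce every rank holds 0 + 1 + ... + (N-1)
x = torch.tensor([float(dist.get_rank())], device="cuda")
dist.all_reduce(x, op=dist.ReduceOp.SUM)
expected = dist.get_world_size() * (dist.get_world_size() - 1) / 2
assert x.item() == expected, f"all_reduce mismatch on rank {dist.get_rank()}"
print(f"rank {dist.get_rank()}: all_reduce OK (sum={x.item()})")
dist.destroy_process_group()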

The memory you save isn't just about fitting larger models. It's about running the experiments you couldn't afford to run before.