Distributed Muon: Custom Gradient Synchronization for Memory-Efficient Training

Standard DDP wastes 8× memory—every GPU stores everything
The DistMuon implementation is one of the most sophisticated parts of nanochat's codebase. Understanding how ZeRO-2 sharding works with Muon's matrix structure requirements taught me more about distributed training than any paper.
Training large language models requires distributing computation across multiple GPUs. While PyTorch's DDP makes this conceptually simple, it comes with significant memory overhead—every GPU stores a complete copy of the model parameters, gradients, and optimizer states.
TL;DR: For a 1 billion parameter model on 8 GPUs, standard DDP means 8× redundant copies of everything. ZeRO eliminates this redundancy by sharding optimizer states across GPUs—but Muon's Newton-Schulz orthogonalization requires preserving 2D matrix structure. DistMuon solves both problems.
The OOM that stopped a training run: Consider a scenario common in distributed training: 40% through training a 2.7B model on 8×A100s, GPU memory usage creeps up and crashes the run. Standard DDP means each GPU holds a full copy of optimizer states—18GB of redundant memory across 8 GPUs. Switching to DeepSpeed ZeRO-2 seems like the fix, but Muon's Newton-Schulz iterations fail silently because the 2D matrix structure is broken by naive sharding. Gradients look fine. Optimizer states are garbage. The issue: you can't shard a 2D weight matrix along arbitrary dimensions without breaking Muon's orthogonalization. DistMuon's "preserve matrix structure" sharding was designed specifically to avoid this failure mode.
⚠️ Note on Muon Optimizer
The DistMuon implementation referenced in this tutorial is based on the Muon optimizer approach from the nanochat research codebase. This post demonstrates distributed training techniques combining ZeRO-2 optimization with Newton-Schulz orthogonalization for educational purposes.
For production use:
- This serves as a learning resource for understanding custom distributed optimizer design
- Standard distributed optimizers (DeepSpeed ZeRO, FSDP) are recommended for production workloads
- The nanochat DistMuon implementation is experimental and designed for research/education
- Consider established frameworks (DeepSpeed, Megatron-LM) for production-scale training
NOTE
Prerequisites: Understanding of the Muon optimizer and basic distributed training concepts. Reading time: ~12 minutes.
DistMuon achieves:
- ~2-3× memory savings compared to standard DDP
- Seamless integration with existing training loops
- Custom reduce_scatter → compute → all_gather pattern optimized for Muon
Every rank stores full model, gradients, and optimizer states
Standard DDP's Synchronization Model
PyTorch DDP follows a simple but memory-inefficient pattern:
```python
# Pseudo-code for standard DDP
for step in training_loop:
    loss = model(x, y)
    loss.backward()      # Compute gradients locally
    # DDP hooks: all_reduce gradients (implicit, happens in backward)
    optimizer.step()     # Each rank updates the full model independently
```

During backward(), DDP's hooks automatically trigger an all_reduce operation that averages gradients across all ranks. This ensures every GPU has identical gradients before the optimizer step.
Memory Overhead Analysis
For each parameter, every rank stores:
- Parameters (model weights): P bytes
- Gradients: P bytes
- Optimizer states: depends on the optimizer
  - Adam/AdamW: 2 states (exp_avg, exp_avg_sq) = 2P bytes
  - Muon: 1 state (momentum_buffer) = P bytes
Total per rank (Muon + DDP): P + P + P = 3P bytes
Total across N ranks: 3P × N bytes
For a 1B parameter model (2GB in bfloat16) on 8 GPUs:
Memory per rank = 2GB (params) + 2GB (grads) + 2GB (momentum) = 6GB
Total memory = 6GB × 8 = 48GB
This redundancy is wasteful. Can we do better?
What ZeRO-2 Offers
ZeRO has three stages of optimization:
- Stage 1: Shard optimizer states across ranks
- Stage 2: Shard gradients + optimizer states
- Stage 3: Shard parameters + gradients + optimizer states
DistMuon implements ZeRO-2, keeping parameters replicated but sharding gradients and optimizer states. This strikes a balance between memory efficiency and implementation complexity.
Memory per rank (ZeRO-2): P + P/N + P/N = P(1 + 2/N) bytes
For our 1B parameter example with 8 GPUs:
Memory per rank = 2GB (params) + 0.25GB (grads/8) + 0.25GB (momentum/8) = 2.5GB
Total memory = 2.5GB × 8 = 20GB
Savings: 48GB - 20GB = 28GB (58% reduction!)
For your training cluster, this means: ZeRO-2 is the difference between "needs 8 GPUs" and "fits on 4." Same model, half the hardware. The communication overhead is negligible compared to the memory savings.
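If you want to rerun this arithmetic for your own model and GPU count, the sketch below encodes just the formulas above (parameters replicated; gradients and one Muon momentum state sharded under ZeRO-2). The helper name and defaults are ours, not nanochat's, and activations and temporary buffers are deliberately ignored.

```python
def per_rank_memory_gb(n_params: float, n_gpus: int, bytes_per_el: int = 2,
                       n_opt_states: int = 1, zero2: bool = True) -> float:
    """Rough per-rank memory for params + grads + optimizer states; activations excluded."""
    p = n_params * bytes_per_el / 1e9                       # parameter bytes, in GB
    grads = p / n_gpus if zero2 else p                      # ZeRO-2 shards gradients...
    states = n_opt_states * (p / n_gpus if zero2 else p)    # ...and optimizer states
    return p + grads + states                               # parameters stay replicated

# 1B params, bf16, Muon (one momentum state), 8 GPUs
print(per_rank_memory_gb(1e9, 8, zero2=False))  # 6.0 GB  -> standard DDP
print(per_rank_memory_gb(1e9, 8, zero2=True))   # 2.5 GB  -> DistMuon (ZeRO-2)
```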
Three design decisions make DistMuon work
Parameter Grouping by Shape
DistMuon groups all parameters by their shape before assigning ownership.
From the nanochat codebase (view on GitHub):
```python
rank = dist.get_rank()
shapes = sorted({p.shape for p in params})  # Unique shapes, sorted
param_groups = []
for shape in shapes:
    group_params = [p for p in params if p.shape == shape]
    device, dtype = group_params[0].device, group_params[0].dtype
    assert all(p.device == device for p in group_params)
    assert all(p.dtype == dtype for p in group_params)
    if rank == 0:
        print(f"Muon: Grouping {len(group_params)} params of shape {shape}")
    param_groups.append(dict(
        params=group_params,
        zero_buffer=torch.zeros_like(group_params[0]),
    ))
```

Why group by shape?
- Efficient batched operations: Newton-Schulz can process multiple matrices of the same shape simultaneously
- Simplified communication: reduce_scatter and all_gather require uniform tensor shapes
- Better GPU utilization: batched matrix operations maximize throughput
TIP
Example: A transformer model might have 100 parameters of shape [768, 768] (attention matrices), 50 of shape [3072, 768] (FFN matrices), and 50 of shape [768, 3072]. DistMuon creates 3 parameter groups, enabling efficient batched Newton-Schulz within each group.
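To make the grouping concrete, here is a tiny standalone version of that logic on a fake parameter list (small shapes so it runs instantly; the counts simply mirror the example above):

```python
import torch

# Toy parameter list standing in for attention / FFN weight matrices
params = (
    [torch.nn.Parameter(torch.randn(64, 64)) for _ in range(4)]     # "attention" matrices
    + [torch.nn.Parameter(torch.randn(256, 64)) for _ in range(2)]  # "FFN up" matrices
    + [torch.nn.Parameter(torch.randn(64, 256)) for _ in range(2)]  # "FFN down" matrices
)

shapes = sorted({p.shape for p in params})  # torch.Size compares like a tuple
for shape in shapes:
    group = [p for p in params if p.shape == shape]
    print(f"{len(group)} params of shape {tuple(shape)}")
# 3 groups -> each can use batched Newton-Schulz and uniform collectives
```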
Parameters map to GPUs in a round-robin pattern—here's why that matters
Within each shape group, parameters are assigned to ranks in a block-cyclic pattern:
```python
world_size = dist.get_world_size()
for base_i in range(0, len(params), world_size):
    owner_idx = base_i + rank  # Each rank owns the param at (base_i + rank)
```

Visual representation (4 GPUs, 10 parameters):
Param indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
└─────────┘ └─────────┘ └──┘
Block 0 Block 1 Block 2
Rank 0 owns: [0, 4, 8 ] ← indices 0, 4, 8
Rank 1 owns: [ 1, 5, 9 ] ← indices 1, 5, 9
Rank 2 owns: [ 2, 6 ] ← indices 2, 6
Rank 3 owns: [ 3, 7 ] ← indices 3, 7
Why block-cyclic?
- Load balancing: parameters are distributed roughly evenly across ranks
- Simplicity: each rank's ownership is a trivial calculation (base_i + rank)
- Graceful remainders: blocks that don't divide evenly across ranks are handled with zero-padding
For your debugging sessions, this means: when training stalls or a rank crashes, you can trace exactly which parameters each GPU owns. Block-cyclic assignment makes ownership deterministic—same layout every run.
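A few lines of standalone Python reproduce that ownership layout for any rank and parameter count (the function name is ours; the real optimizer derives the same indices on the fly inside step()):

```python
def ownership(num_params: int, world_size: int) -> dict[int, list[int]]:
    """Map rank -> indices of the parameters it owns under block-cyclic assignment."""
    owned = {r: [] for r in range(world_size)}
    for base_i in range(0, num_params, world_size):
        for r in range(world_size):
            owner_idx = base_i + r
            if owner_idx < num_params:
                owned[r].append(owner_idx)
    return owned

print(ownership(num_params=10, world_size=4))
# {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
```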
The Three-Phase Update Pattern
DistMuon orchestrates a custom communication pattern that shards computation while maintaining parameter replication:
Phase 1: Reduce-scatter averages gradients and distributes ownership
The reduce_scatter operation combines gradients from all ranks and distributes the averaged results.
From the nanochat codebase (view on GitHub):
```python
all_reduce_futures = []
for group in self.param_groups:
    params = group["params"]
    zero_buffer = group["zero_buffer"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Each rank collects gradients for world_size consecutive params
        rs_input = [p.grad for p in params[base_i:base_i + world_size]]
        # Pad with zeros if we don't have enough params
        rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
        # Output buffer: gradient for the param this rank owns
        rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
        # Launch async reduce_scatter
        work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
        all_reduce_futures.append(work)
```

What happens here:
- Input collection: each rank gathers gradients for a block of world_size parameters
- Padding: if the block is incomplete (e.g., the last block with fewer params), pad with zero_buffer
- Reduce-scatter: all ranks participate in averaging gradients
- Output: Each rank receives the averaged gradient for its owned parameter
Example (4 GPUs, block 0 with params [0,1,2,3]):
Rank 0: rs_input = [grad₀[0], grad₀[1], grad₀[2], grad₀[3]] → rs_output = avg(grad[0])
Rank 1: rs_input = [grad₁[0], grad₁[1], grad₁[2], grad₁[3]] → rs_output = avg(grad[1])
Rank 2: rs_input = [grad₂[0], grad₂[1], grad₂[2], grad₂[3]] → rs_output = avg(grad[2])
Rank 3: rs_input = [grad₃[0], grad₃[1], grad₃[2], grad₃[3]] → rs_output = avg(grad[3])
After reduce-scatter, each rank has the averaged gradient for its owned parameter, ready for computation.
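The only subtle part is the final, incomplete block. This plain-PyTorch toy (no process group required) shows how the zero-padding keeps every reduce_scatter call at exactly world_size inputs:

```python
import torch

world_size = 4
grads = [torch.randn(3, 3) for _ in range(10)]  # 10 params -> the last block has only 2
zero_buffer = torch.zeros(3, 3)

for base_i in range(0, len(grads), world_size):
    rs_input = grads[base_i:base_i + world_size]
    pads = world_size - len(rs_input)
    rs_input = rs_input + [zero_buffer] * pads
    print(f"block starting at {base_i}: {len(rs_input) - pads} real grads + {pads} zero pads "
          f"-> reduce_scatter input list of length {len(rs_input)}")
```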
Phase 2: Only the owner rank computes Newton-Schulz orthogonalization
Once gradients are averaged, each rank computes the Muon update for its owned parameters.
From the nanochat codebase (view on GitHub):
```python
future_idx = 0
all_gather_futures = []
for group in self.param_groups:
    params = group["params"]
    for base_i in range(0, len(params), world_size):
        owner_idx = base_i + rank
        # Wait for reduce_scatter to complete
        all_reduce_futures[future_idx].wait()
        future_idx += 1
        # Only the owner computes the update
        if owner_idx < len(params):
            p = params[owner_idx]
            g = p.grad  # Already averaged across ranks
            state = self.state[p]
            # Initialize momentum buffer if needed
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            buf = state["momentum_buffer"]
            # Momentum accumulation
            buf.lerp_(g, 1.0 - group["momentum"])
            # Nesterov momentum (if enabled)
            g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
            # Newton-Schulz orthogonalization
            g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
            # Aspect-ratio scaled step
            scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
            p.add_(g, alpha=-group["lr"] * scale)
```

Key points:
- Wait synchronization: wait() ensures the gradient is ready before computation
- Owner-only execution: non-owner ranks skip computation (idle during this phase)
- Standard Muon update: same as single-GPU Muon (see Post 1.1)
  - Momentum accumulation with lerp_
  - Optional Nesterov momentum
  - Newton-Schulz orthogonalization
  - Aspect-ratio scaling: sqrt(max(1, height/width))
NOTE
Memory efficiency: Each rank stores momentum_buffer only for its owned parameters, achieving 1/N sharding.
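If the Newton-Schulz step is a black box, the sketch below is a compact re-implementation of the quintic iteration with the coefficients published for Muon. Treat it as an illustration rather than nanochat's exact zeropower_via_newtonschulz5: the point is that the singular values of the update get pushed toward 1, i.e. the gradient is approximately orthogonalized.

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize G via the quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315   # coefficients from the published Muon write-up
    X = G.float()
    transposed = X.size(-2) > X.size(-1)
    if transposed:                       # work in the "wide" orientation
        X = X.mT
    X = X / (X.norm() + eps)             # scale so the spectral norm is at most 1
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * (A @ A)) @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)

G = torch.randn(256, 512)
O = newton_schulz_orthogonalize(G)
print(torch.linalg.svdvals(G)[:3])  # spread-out singular values
print(torch.linalg.svdvals(O)[:3])  # pushed toward ~1 (approximate, which is all Muon needs)
```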
Phase 3: All-gather broadcasts updated parameters to all GPUs
After computing updates, ranks replicate their updated parameters to all other ranks.
From the nanochat codebase (view on GitHub):
```python
# Replicate updated parameters to all ranks (still inside the per-block loop from Phase 2)
ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
ag_output = params[base_i:base_i + world_size]
ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))])
work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
all_gather_futures.append(work)

# Wait for all gathers to complete (outside the loops)
torch.futures.collect_all(all_gather_futures).wait()
```

What happens:
- Input: Each rank's owned parameter (or zero_buffer if padding)
- Output: List of tensors to populate with gathered parameters
- All-gather: Broadcast each rank's parameter to all other ranks
- Synchronization: collect_all().wait() ensures all communications complete
After all-gather, every rank has identical copies of all parameters, ready for the next forward pass.
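To watch all three phases interact without any GPUs or process groups, here is a single-process toy that models ranks as plain Python lists. A vanilla SGD step stands in for the Muon update, and none of the names come from nanochat; the point is only that after reduce-scatter, owner-compute, and all-gather, every rank holds identical parameters again.

```python
import torch

world_size, num_params, shape, lr = 4, 6, (8, 8), 0.1

# Every rank replicates all parameters (identical starting point).
base = [torch.randn(shape) for _ in range(num_params)]
params_per_rank = [[p.clone() for p in base] for _ in range(world_size)]
# Each rank sees different data, hence different local gradients.
grads_per_rank = [[torch.randn(shape) for _ in range(num_params)] for _ in range(world_size)]

for base_i in range(0, num_params, world_size):
    # Phase 1 (reduce-scatter semantics): the owner of param (base_i + r) gets the averaged gradient.
    avg_grads = {}
    for r in range(world_size):
        owner_idx = base_i + r
        if owner_idx < num_params:
            avg_grads[owner_idx] = torch.stack([g[owner_idx] for g in grads_per_rank]).mean(dim=0)

    # Phase 2: only the owner computes the update (plain SGD standing in for Muon).
    updated = {i: params_per_rank[i % world_size][i] - lr * g for i, g in avg_grads.items()}

    # Phase 3 (all-gather semantics): every rank receives the updated parameter.
    for i, new_p in updated.items():
        for r in range(world_size):
            params_per_rank[r][i] = new_p.clone()

# Every rank ends the step with identical parameters.
print(all(torch.equal(params_per_rank[0][i], params_per_rank[r][i])
          for r in range(world_size) for i in range(num_params)))  # True
```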
DistMuon vs DistAdamW: same pattern, different memory footprints
Both optimizers implement ZeRO-2, but their strategies differ due to algorithmic requirements.
Key Differences
| Feature | DistAdamW | DistMuon | Reason |
|---|---|---|---|
| Parameter Requirements | Any shape | 2D only | Newton-Schulz needs matrices |
| Sharding Strategy | Slice along dim 0 | Block-cyclic whole params | Preserve aspect ratio |
| State Storage | Slice-local (exp_avg, exp_avg_sq) | Param-local (momentum_buffer) | Matrix operations |
| Compute Pattern | All ranks on slices | Owner ranks only | Simplify NS batching |
| Reduce-scatter Input | Full tensor | List of tensors | Shape uniformity |
| Memory Efficiency | ~1/N states | ~1/N states | Similar overall |
| Load Balance | Perfect (slicing) | Imperfect (padding) | Trade-off for simplicity |
DistAdamW's Sharding Approach
From the nanochat codebase (view on GitHub):
```python
for base_i in range(len(params)):
    grad = params[base_i].grad
    rank_size = grad.shape[0] // world_size
    grad_slice = torch.empty_like(grad[:rank_size])
    reduce_scatter_futures.append(
        dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True)
    )
```

DistAdamW slices each parameter along the first dimension, distributing rows across ranks. This works for any parameter shape and achieves perfect load balancing.
WARNING
Why doesn't DistMuon do this? Newton-Schulz requires the full 2D matrix structure with its original aspect ratio. Slicing a [768, 768] matrix into [192, 768] slices would change the aspect ratio from 1:1 to 1:4, breaking the orthogonalization geometry. DistMuon preserves matrices intact by assigning whole parameters to owner ranks.
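You can check this numerically with exact orthogonalization (the U V^T map from an SVD, which Newton-Schulz approximates): orthogonalizing four row-slices independently produces a very different matrix than orthogonalizing the whole thing. The helper below is purely illustrative.

```python
import torch

def orthogonalize(G: torch.Tensor) -> torch.Tensor:
    # Exact orthogonalization (U @ V^T); Newton-Schulz approximates this map.
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

torch.manual_seed(0)
grad = torch.randn(768, 768)

whole = orthogonalize(grad)                                           # what Muon intends
sliced = torch.cat([orthogonalize(s) for s in grad.chunk(4, dim=0)])  # naive dim-0 sharding

print(torch.dist(whole, sliced))  # large: slice-wise orthogonalization is a different update
```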
Comparison: Sharding Granularity
Example: 4 GPUs, parameter shape [1024, 768]
DistAdamW:
┌────────────────────┐
│ Rank 0 (256 rows)  │   Each rank owns one row-slice and stores:
├────────────────────┤   - grad slice       [256, 768]
│ Rank 1 (256 rows)  │   - exp_avg slice    [256, 768]
├────────────────────┤   - exp_avg_sq slice [256, 768]
│ Rank 2 (256 rows)  │   (the full parameter itself stays replicated
├────────────────────┤    on every rank, as ZeRO-2 requires)
│ Rank 3 (256 rows)  │
└────────────────────┘
DistMuon (within a block of 4 params):
Rank 0: param[0] [1024,768] ← Full matrix
Rank 1: param[1] [1024,768] ← Full matrix
Rank 2: param[2] [1024,768] ← Full matrix
Rank 3: param[3] [1024,768] ← Full matrix
Each rank stores:
- Full param (replicated)
- momentum_buffer for owned param only
The math: 58-67% savings, depending on GPU count and model size
Memory Breakdown Per Rank
Standard DDP + Muon:
Parameters: P bytes (full model)
Gradients: P bytes (full model)
Momentum buffers: P bytes (full model)
─────────────────────────────
Total per rank: 3P bytes
Total across N: 3P × N bytes
DistMuon (ZeRO-2):
Parameters: P bytes (replicated)
Gradients: P/N bytes (sharded)
Momentum buffers: P/N bytes (sharded)
─────────────────────────────
Total per rank: P(1 + 2/N) bytes
Total across N: P(N + 2) bytes
Efficiency Calculations
| N ranks | DDP Total | DistMuon Total | Memory Savings | Savings % |
|---|---|---|---|---|
| 2 | 6P | 4P | 2P | 33% |
| 4 | 12P | 6P | 6P | 50% |
| 8 | 24P | 10P | 14P | 58% |
| 16 | 48P | 18P | 30P | 63% |
| 64 | 192P | 66P | 126P | 66% |
Asymptotic behavior: As N → ∞, savings approach 67% (2/3 reduction).
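The table follows directly from the two totals derived above (3P × N for DDP and P(N + 2) for DistMuon); a couple of lines regenerate it:

```python
# Totals in units of P bytes, per the formulas above
for n in (2, 4, 8, 16, 64):
    ddp, distmuon = 3 * n, n + 2
    saved = ddp - distmuon
    print(f"N={n:<3d} DDP={ddp}P  DistMuon={distmuon}P  savings={saved}P ({saved / ddp:.1%})")
```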
Practical Example: nanochat's 270M Model
nanochat's depth-20 model has ~270M parameters in bfloat16 (540 MB total).
8× H100 GPUs (80GB each):
| Metric | Standard DDP | DistMuon | Savings |
|---|---|---|---|
| Params | 540 MB | 540 MB | 0 MB |
| Grads | 540 MB | 67.5 MB | 472.5 MB |
| States | 540 MB | 67.5 MB | 472.5 MB |
| Total/rank | 1.62 GB | 675 MB | 945 MB (58%) |
| Total/cluster | 12.96 GB | 5.4 GB | 7.56 GB |
This 945 MB per-rank saving allows:
- Larger batch sizes (more memory for activations)
- Longer sequences (quadratic attention memory)
- Bigger models (per the formulas above, roughly 2.4× more parameter, gradient, and optimizer state fits in the same per-rank budget as 270M under DDP)
For your multi-GPU training, this means: sharded optimizer state unlocks training headroom you didn't know you had. That 945 MB per rank translates directly into longer context windows or deeper models without touching your hardware budget.
Async communication overlaps compute with network—here's how
Why Asynchronous Operations?
DistMuon uses async_op=True throughout to overlap communication with computation:
```python
# Launch all reduce-scatters without waiting
for group in self.param_groups:
    for base_i in range(0, len(params), world_size):
        work = dist.reduce_scatter(..., async_op=True).get_future()
        all_reduce_futures.append(work)

# Compute and gather (wait only when needed)
for group in self.param_groups:
    for base_i in range(0, len(params), world_size):
        all_reduce_futures[future_idx].wait()  # Wait for this block's gradient
        if owner_idx < len(params):
            # Compute Muon update
            ...
        work = dist.all_gather(..., async_op=True).get_future()
        all_gather_futures.append(work)

# Final synchronization
torch.futures.collect_all(all_gather_futures).wait()
```

Benefits:
- Communication-computation overlap: While GPU computes updates for earlier parameters, network transfers gradients for later parameters
- Pipelining: Reduce-scatter and all-gather operations can overlap across parameter groups
- Lower latency: Non-blocking calls prevent idle GPU time
Synchronization Pattern
Time ──────────────────────────────────────────────►
Rank 0:
[reduce_scatter₀] [wait] [compute₀] [all_gather₀]
[reduce_scatter₁] [wait] [compute₁] [all_gather₁]
[reduce_scatter₂] [wait] ...
Rank 1:
[reduce_scatter₀] [wait] [compute₀] [all_gather₀]
[reduce_scatter₁] [wait] [compute₁] [all_gather₁]
[reduce_scatter₂] [wait] ...
TIP
Key insight: By launching all reduce-scatters first, then processing them sequentially with compute + gather, we maximize overlap.
Drop-in replacement: your training loop stays the same
One of DistMuon's best features is its zero-friction integration.
From the training script (view on GitHub):
```python
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
adamw_optimizer, muon_optimizer = optimizers
```

The setup_optimizers() method automatically selects DistMuon when running distributed:
```python
def setup_optimizers(self, ...):
    # Separate parameters by dimensionality
    matrix_params = [p for p in self.parameters() if p.ndim == 2]
    vector_params = [p for p in self.parameters() if p.ndim < 2]
    # Use Dist* optimizers if distributed, else regular
    if dist.is_initialized():
        from nanochat.muon import DistMuon
        from nanochat.adamw import DistAdamW
        muon_opt = DistMuon(matrix_params, lr=matrix_lr, ...)
        adamw_opt = DistAdamW([{"params": vector_params}], lr=embedding_lr, ...)
    else:
        muon_opt = Muon(matrix_params, lr=matrix_lr, ...)
        adamw_opt = AdamW(vector_params, lr=embedding_lr, ...)
    return adamw_opt, muon_opt
```

Training loop (unchanged):
```python
lrm = get_lr_multiplier(step)
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["initial_lr"] * lrm
muon_momentum = get_muon_momentum(step)
for group in muon_optimizer.param_groups:
    group["momentum"] = muon_momentum
for opt in optimizers:
    opt.step()  # DistMuon.step() handles all communication!
model.zero_grad(set_to_none=True)
```

No special handling needed—just call opt.step() and DistMuon orchestrates all distributed operations internally.
Performance: near-linear scaling with GPU count
Communication Cost Analysis
For P parameters and N ranks:
| Operation | Data Volume (per rank) | Time Complexity |
|---|---|---|
| Reduce-scatter | Contributes P, keeps P/N (its owned, averaged grads) | O(P) |
| Compute (Muon) | Local only (owned params) | O(P/N) |
| All-gather | Contributes P/N, receives P | O(P) |
| Total per step | ≈2P moved per rank (same as DDP's all_reduce) | O(P) |
Key takeaway: Communication scales linearly with parameter count and, per rank, is roughly independent of N. In aggregate it matches standard DDP, since a reduce_scatter followed by an all_gather moves the same data as one all_reduce—so the memory savings cost essentially no extra bandwidth.
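For a rough absolute number, multiply by your parameter count and dtype width (a back-of-envelope helper, not nanochat code):

```python
def comm_volume_gb(n_params: float, bytes_per_el: int = 2) -> float:
    """Approximate data moved per rank per optimizer step (reduce_scatter + all_gather)."""
    return 2 * n_params * bytes_per_el / 1e9

print(comm_volume_gb(270e6))  # ~1.08 GB per rank per step for nanochat's 270M model
```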
Scaling Behavior
Weak scaling (increase model size proportionally with GPUs):
- ✅ Near-linear: Memory per rank stays constant
- ✅ Communication stays constant: Each rank sends/receives same amount
Strong scaling (fixed model size, increase GPUs):
- ⚠️ Sub-linear: Communication overhead increases relative to computation
- ⚠️ Sweet spot: 8-64 GPUs for typical Transformers
- ❌ Poor at scale: Beyond 128 GPUs, communication dominates
Example: Training nanochat's 270M model
- 8 GPUs: ~90% scaling efficiency
- 64 GPUs: ~75% scaling efficiency
- 512 GPUs: ~40% scaling efficiency (not recommended)
For your multi-GPU training: what this means
DistMuon demonstrates that domain-specific optimizations can significantly improve distributed training efficiency. By tailoring the ZeRO-2 pattern to Muon's unique needs—preserving matrix structure, enabling batched Newton-Schulz, and implementing block-cyclic assignment—nanochat achieves:
- 58-67% memory savings vs standard DDP (8-64 GPUs)
- Seamless integration with existing codebases
- Efficient scaling for typical Transformer training workloads
Key Takeaways
✅ When to use DistMuon:
- Training Transformers with 2D weight matrices
- Memory-constrained multi-GPU setups (8-64 GPUs)
- Want ZeRO-2 benefits without DeepSpeed dependency
❌ When to avoid:
- Single GPU training (use regular Muon)
- Models with mostly 1D parameters (use DistAdamW)
- Extreme scale (>128 GPUs, consider ZeRO-3 or model parallelism)
Design Principles Worth Remembering
- Group by shape: Enable batched operations by processing uniform tensors together
- Block-cyclic assignment: Balance load while maintaining simplicity
- Async communication: Overlap network transfers with computation
- Preserve algorithmic invariants: Don't break Newton-Schulz by slicing matrices
Distributed training isn't about more GPUs—it's about smarter coordination. DistMuon gives you both.
What's Next in This Series
💾 Post 1.3: KV Caching Deep-Dive (Coming Soon)
Memory-efficient Transformer inference with prefill-and-clone patterns and dynamic cache growth.
🏗️ Post 1.4: Modern Architecture Choices (Coming Soon)
RoPE, QK normalization, Multi-Query Attention, and design trade-offs explained.
Sources and References
Distributed Training Fundamentals
- Rajbhandari, S., et al. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC 2020. Original ZeRO optimization stages.
- Li, S., et al. (2020). PyTorch Distributed: Experiences on Accelerating Data Parallel Training. VLDB 2020. PyTorch DDP design principles.
- PyTorch DDP Tutorial. Understanding standard distributed training.
- PyTorch Distributed Collective Ops.
reduce_scatter,all_gatherdocumentation.
ZeRO Stages and Memory Optimization
- Ren, J., et al. (2021). ZeRO-Offload: Democratizing Billion-Scale Model Training. USENIX ATC 2021. CPU offloading for memory-constrained training.
- Rajbhandari, S., et al. (2022). DeepSpeed-Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale. SC 2022. Inference optimizations.
Muon Optimizer
- Jordan, K. (2024). Muon: Momentum + Newton-Schulz. Original Muon blog post.
- Liu, J., et al. (2025). Muon is Scalable for LLM Training. arXiv:2502.16982. Scalability analysis across GPU counts.
- Post 1.1: The Muon Optimizer Explained. Prerequisite reading in this series.
Gradient Communication
- Shoeybi, M., et al. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. Model parallelism and gradient sharding.
- Huang, Y., et al. (2019). GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism. NeurIPS 2019. Pipeline parallelism fundamentals.
- Paszke, A., et al. (2019). PyTorch: An Imperative Style, High-Performance Deep Learning Library. NeurIPS 2019. PyTorch distributed backend.
nanochat Implementation
- muon.py on GitHub. DistMuon source code.
- nanochat Repository. Full training pipeline.
Industry Research & Hardware (as of January 2025)
- NVIDIA NVLink Specifications: NVLink and NVSwitch. Up to 900 GB/s of GPU-to-GPU bandwidth on H100 (600 GB/s on A100); critical for distributed training efficiency.
- MLCommons MLPerf Training: Distributed Training Benchmarks. Industry-standard multi-GPU training benchmarks.
- Epoch AI Compute Analysis: Large-Scale Training Compute. Documents scaling efficiency across GPU counts for frontier model training.
GPU Cluster Pricing (as of January 2025)
| Configuration | Typical Cost | Provider Example |
|---|---|---|
| 8× H100 (single node) | ~$24/hr | Lambda Labs |
| 8× A100 (single node) | ~$12/hr | Lambda Labs |
| 64× H100 (8 nodes) | ~$200/hr | Enterprise cloud (AWS, GCP, Azure) |
Before you implement distributed Muon:
- Verify NCCL connectivity first. Run a simple all-reduce test—distributed gradient sync failures are brutal to debug mid-training.
- Measure single-GPU memory baseline. Know exactly how much memory AdamW uses before claiming DistMuon savings.
- Separate 2D parameters from 1D. Matrix parameters get Newton-Schulz; embedding layers stay with AdamW. Mix them wrong and training diverges.
- Log communication overhead. Profile all-reduce time vs compute time—if communication exceeds 20%, your network is the bottleneck.
The memory you save isn't just about fitting larger models. It's about running the experiments you couldn't afford to run before.