Training Data Pipeline: Streaming Tokenization at Scale

Introduction
Training a language model on 100 billion tokens? You can't load everything into memory. You need to keep GPUs fed with data. Distributed training adds another layer of complexity. nanochat's data pipeline handles all three problems with a surprisingly simple design: streaming data access, parallel tokenization, and distributed sharding. The entire thing uses about 12 MB of memory per GPU.
This post breaks down how it works: the dataset format, the tokenization strategy, the distributed loading pattern, and the optimizations that keep the GPUs fed.
NOTE
Key Achievement: nanochat's data pipeline handles 100B tokens using only ~12 MB of memory per GPU rank, achieving 1.8M tokens/sec throughput with 4-threaded tokenization.
The Dataset: FineWeb-Edu 100B
nanochat uses the FineWeb-Edu-100B dataset - 100 billion tokens of high-quality educational web text. The dataset is stored as 1,823 Parquet files (~55MB each), totaling about 100GB on disk.
Why Parquet?
Parquet is a columnar storage format that's perfect for this use case:
- Efficient columnar access: We only read the text column, ignoring metadata
- Built-in compression: ~10× compression ratio over raw text
- Row groups: Internal batching structure for efficient streaming
- Random access: Can jump to any row group without scanning the whole file
Dataset Structure
From nanochat/dataset.py:
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822  # 1,823 files total (0-indexed)
DATA_DIR = os.path.join(base_dir, "base_data")

def list_parquet_files(data_dir=None):
    """List all parquet files in the data directory."""
    data_dir = DATA_DIR if data_dir is None else data_dir
    parquet_files = sorted([
        f for f in os.listdir(data_dir)
        if f.endswith('.parquet') and not f.endswith('.tmp')
    ])
    return [os.path.join(data_dir, f) for f in parquet_files]

The dataset is shuffled at the document level before sharding, which is crucial for training stability. Documents within each shard maintain their shuffled order.
Train/Val Split
From nanochat/dataset.py:
def parquets_iter_batched(split, start=0, step=1):
    """Iterate through dataset in batches of row groups."""
    assert split in ["train", "val"]
    parquet_paths = list_parquet_files()
    # Last file = validation, rest = training
    parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(start, pf.num_row_groups, step):
            rg = pf.read_row_group(rg_idx)
            texts = rg.column('text').to_pylist()
            yield texts

Key design choices:
- Simple split: Last shard = validation (about 55M tokens)
- Row group granularity: Iterate at the row group level (~1024 documents each)
- Distributed sharding: start and step parameters enable rank-specific data access
The Tokenizer: Rust BPE + Tiktoken
nanochat uses a custom two-stage tokenization approach:
Training: RustBPE
The tokenizer is trained using rustbpe, a high-performance Rust implementation of Byte Pair Encoding.
From rustbpe/src/lib.rs:
pub fn train_from_iterator(
    &mut self,
    iterator: &PyAny,
    vocab_size: u32,
    buffer_size: usize,
    pattern: Option<String>,
) -> PyResult<()> {
    // Use GPT-4 style regex pattern
    let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
    // Excerpt: compiling pattern_str into the regex called `pattern` below, and the
    // setup of `py`, `refill`, `words`, and `cvec`, are omitted here.
    // Global chunk counts
    let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
    let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
    // Stream ingestion: refill under GIL, process without GIL (parallel)
    loop {
        let exhausted = refill(&mut buf)?;
        // Release GIL and process in parallel with rayon
        let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
            buf.par_iter()
                .map(|s| {
                    let mut m: AHashMap<CompactString, i32> = AHashMap::new();
                    for mat in pattern.find_iter(s) {
                        let piece = mat.expect("regex match failed").as_str();
                        *m.entry(CompactString::from(piece)).or_default() += 1;
                    }
                    m
                })
                .reduce(|| AHashMap::new(), |mut a, b| {
                    for (k, v) in b {
                        *a.entry(k).or_default() += v;
                    }
                    a
                })
        });
        // Merge local into global
        for (k, v) in local {
            *counts.entry(k).or_default() += v;
        }
        if exhausted { break; }
    }
    // Train BPE on the collected statistics
    self.train_core_incremental(words, cvec, vocab_size);
}

Performance optimizations:
- Streaming processing: Never load entire dataset into memory
- Parallel regex matching: Use rayon for multi-threaded text splitting
- GIL management: Release Python GIL during CPU-intensive work
- Efficient data structures: CompactString and AHashMap for low memory overhead
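The refill, count, and merge loop above can be sketched in a few lines of Python. This is a single-threaded illustration under stated assumptions, not rustbpe's implementation: `count_chunks` and `SIMPLE_PATTERN` are hypothetical names, and the regex is a simplified stand-in because the standard `re` module does not support `\p{L}` or `\p{N}`.

```python
import re
from collections import Counter
from itertools import islice

# Simplified stand-in for the GPT-4 split pattern (assumption, see lead-in).
SIMPLE_PATTERN = re.compile(r"\w+|\s+|[^\w\s]+")


def count_chunks(text_iter, buffer_size=2):
    """Count regex chunks over a text stream, one small buffer at a time."""
    text_iter = iter(text_iter)
    global_counts = Counter()
    while True:
        buf = list(islice(text_iter, buffer_size))  # refill the buffer
        if not buf:
            break  # stream exhausted
        local = Counter()
        for text in buf:  # rustbpe runs this step in parallel with rayon
            local.update(SIMPLE_PATTERN.findall(text))
        global_counts.update(local)  # merge local counts into global counts
    return global_counts


counts = count_chunks(["the cat", "the dog", "a cat!"])
print(counts.most_common(3))
```

Only one buffer of text is held at a time, so memory use depends on `buffer_size` and the number of distinct chunks rather than on the size of the corpus.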
Inference: Tiktoken
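For encoding text during model training, nanochat loads the trained vocabulary into tiktoken (OpenAI's production tokenizer library). "Inference" here refers to running the tokenizer, not the model.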
From nanochat/tokenizer.py:
class RustBPETokenizer:
    def __init__(self, enc, bos_token):
        self.enc = enc  # tiktoken.Encoding
        self.bos_token_id = self.encode_special(bos_token)

    def encode(self, text, prepend=None, append=None, num_threads=8):
        """Encode text using tiktoken's optimized Rust implementation."""
        if prepend is not None:
            prepend_id = prepend if isinstance(prepend, int) else self.encode_special(prepend)
        if append is not None:
            append_id = append if isinstance(append, int) else self.encode_special(append)
        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if prepend is not None:
                ids.insert(0, prepend_id)
            if append is not None:
                ids.append(append_id)
        elif isinstance(text, list):
            # Batch encoding with multi-threading
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if prepend is not None:
                for ids_row in ids:
                    ids_row.insert(0, prepend_id)
            if append is not None:
                for ids_row in ids:
                    ids_row.append(append_id)
        else:
            # Fail loudly instead of hitting an undefined `ids` below
            raise ValueError(f"Invalid input type: {type(text)}")
        return ids

TIP
Why tiktoken?
- Speed: 5-10× faster than pure Python implementations
- Batching: encode_ordinary_batch processes multiple documents in parallel
- Production-tested: Battle-tested in OpenAI's systems
GPT-4 Style Tokenization
nanochat uses GPT-4's regex pattern for text splitting:
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

Pattern breakdown:
- '(?i:[sdmt]|ll|ve|re) - Contractions ('s, 'll, etc.)
- [^\r\n\p{L}\p{N}]?+\p{L}+ - Words with optional leading punctuation
- \p{N}{1,2} - Numbers (1-2 digits). Note: different from GPT-4's 1-3
-  ?[^\s\p{L}\p{N}]++[\r\n]* - Punctuation sequences, with an optional leading space
- \s*[\r\n] - Newlines with optional whitespace
- \s+(?!\S)|\s+ - Whitespace handling
NOTE
Why 1-2 digits instead of GPT-4's 1-3?
From Andrej Karpathy's comment in nanochat/tokenizer.py:
"I did this because I didn't want to 'waste' too many tokens on numbers for smaller vocab sizes. I haven't validated that this is actually a good idea, TODO."
This is a pragmatic, explicitly unvalidated trade-off: with smaller models and vocab sizes, spending fewer vocabulary entries on numbers may leave more for linguistic content.
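To make the digit rule concrete, here is a small sketch of how the two limits split a long number. It uses `\d` from the standard `re` module as a stand-in for `\p{N}`, and `split_digits` and `count_digit_chunks` are hypothetical helpers for illustration only.

```python
import re


def split_digits(number_text, max_digits):
    """Greedily split a digit string into groups of at most max_digits."""
    return re.findall(rf"\d{{1,{max_digits}}}", number_text)


def count_digit_chunks(max_digits):
    """Number of distinct ASCII digit chunks of length 1..max_digits."""
    return sum(10 ** k for k in range(1, max_digits + 1))


print(split_digits("1234567", 2))  # ['12', '34', '56', '7']
print(split_digits("1234567", 3))  # ['123', '456', '7']
print(count_digit_chunks(2), count_digit_chunks(3))  # 110 1110
```

With a two-digit limit there are at most 110 distinct ASCII digit chunks competing for vocabulary entries, versus 1,110 with three. The cost is that long numbers take more tokens to encode.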
The Data Loader: Streaming + Distributed
Here's where everything comes together.
From nanochat/dataloader.py:
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"]
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1  # +1 for target at last position
    # Get tokenizer and BOS token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    # Token buffer streams tokens on the right, pops from the left
    token_buffer = deque()
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)

    # Infinite iterator over document batches
    def document_batches():
        while True:
            # Distributed sharding: each rank processes every Nth row group
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # Further sub-batch for tokenizer
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size]

    batches = document_batches()
    batch_index = 0
    while True:
        # Accumulate enough tokens for one iteration
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
            batch_index += 1
        # Move tokens from deque into scratch buffer
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
        # Create inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        yield inputs, targets

How It Works
The Key Design Patterns
1. Distributed Sharding via Strided Access
# Each rank processes every Nth row group
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
    ...

With 4 GPUs:
- Rank 0: Row groups 0, 4, 8, 12, ...
- Rank 1: Row groups 1, 5, 9, 13, ...
- Rank 2: Row groups 2, 6, 10, 14, ...
- Rank 3: Row groups 3, 7, 11, 15, ...
No coordination required. Load balancing is automatic since row groups are similar size. Data ordering is deterministic, so runs are reproducible. No duplicate data across ranks.
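The strided pattern can be checked on a toy layout. The sketch below assumes plain integers stand in for each Parquet file's row-group count; `rank_row_groups` and the layout are hypothetical, chosen only to mirror `range(start, num_row_groups, step)` applied to each file in turn.

```python
def rank_row_groups(row_groups_per_file, rank, world_size):
    """Return the (file_idx, row_group_idx) pairs one rank would read."""
    assigned = []
    for file_idx, num_row_groups in enumerate(row_groups_per_file):
        # The stride restarts at `rank` for every file
        for rg_idx in range(rank, num_row_groups, world_size):
            assigned.append((file_idx, rg_idx))
    return assigned


# Hypothetical layout: three files with 10, 9, and 12 row groups.
layout = [10, 9, 12]
world_size = 4
per_rank = [rank_row_groups(layout, r, world_size) for r in range(world_size)]

all_pairs = [pair for pairs in per_rank for pair in pairs]
# Every row group is read exactly once across all ranks
assert len(all_pairs) == len(set(all_pairs)) == sum(layout)
print([len(pairs) for pairs in per_rank])  # [9, 8, 7, 7]
```

The ranks never overlap and together cover everything. Because the stride restarts at each file, lower ranks read slightly more row groups per pass when a file's count is not a multiple of the world size.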
2. Token Buffer: Document Boundaries Don't Align with Batches
token_buffer = deque()  # Stream tokens on the right, pop from the left
while len(token_buffer) < needed_tokens:
    doc_batch = next(batches)
    token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
    for tokens in token_lists:
        token_buffer.extend(tokens)  # Concatenate all tokens

NOTE
Why this matters:
- Documents are variable length (10-10,000+ tokens)
- Training batches are fixed size (B × T tokens)
- We need to pack tokens across document boundaries
The token buffer:
- Acts as a sliding window over the token stream
- Documents are separated by <|bos|> tokens
- Training sequences may span multiple documents (this is fine!)
Example:
Document 1: [<|bos|>, 15, 42, 88, ...] (500 tokens)
Document 2: [<|bos|>, 23, 91, ...] (800 tokens)
Document 3: [<|bos|>, 77, ...] (300 tokens)
Token buffer: [<|bos|>, 15, 42, ..., <|bos|>, 23, 91, ..., <|bos|>, 77, ...]
|------------ Batch 1 (B×T tokens) -------------|
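The packing step can be tried without PyTorch. This is a minimal sketch assuming nested Python lists in place of tensors; `pack_batches` and the `BOS = 0` token id are hypothetical stand-ins, not nanochat code.

```python
from collections import deque

BOS = 0  # hypothetical BOS token id


def pack_batches(token_docs, B, T):
    """Yield (inputs, targets) as B x T nested lists from tokenized documents."""
    needed = B * T + 1  # one extra token so the last input has a target
    buffer = deque()
    for doc in token_docs:
        buffer.append(BOS)  # mark the document boundary
        buffer.extend(doc)
        while len(buffer) >= needed:
            flat = [buffer.popleft() for _ in range(needed)]
            inputs, targets = flat[:-1], flat[1:]  # targets are shifted by one
            yield (
                [inputs[r * T:(r + 1) * T] for r in range(B)],
                [targets[r * T:(r + 1) * T] for r in range(B)],
            )


docs = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
batches = list(pack_batches(docs, B=2, T=3))
print(batches[0][0])  # [[0, 11, 12], [13, 0, 21]]
print(batches[0][1])  # [[11, 12, 13], [0, 21, 22]]
```

The second row of the inputs starts mid-document and crosses a BOS token, which is exactly the "sequences may span multiple documents" behavior described above. The third document stays in the buffer until more tokens arrive.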
3. Two-Stage Batching
# Stage 1: Parquet row groups (~1024 documents)
for batch in parquets_iter_batched(...):
    # Stage 2: Tokenizer sub-batches (128 documents)
    for i in range(0, len(batch), tokenizer_batch_size):
        yield batch[i:i+tokenizer_batch_size]

Why two stages?
- Row group batching: Amortize Parquet I/O overhead
- Tokenizer batching: Balance parallelism vs memory
Typical values:
- Row group size: 1024 documents
- Tokenizer batch: 128 documents
- Result: 8 tokenizer calls per row group
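The "8 tokenizer calls per row group" figure follows directly from the slicing. A quick sketch, with `sub_batches` as a hypothetical helper and integers standing in for documents:

```python
def sub_batches(row_group_docs, tokenizer_batch_size=128):
    """Split one row group's documents into tokenizer-sized slices."""
    for i in range(0, len(row_group_docs), tokenizer_batch_size):
        yield row_group_docs[i:i + tokenizer_batch_size]


full_group = list(range(1024))   # a typical row group
short_group = list(range(1000))  # a row group that is not a multiple of 128

print([len(b) for b in sub_batches(full_group)])   # eight slices of 128
print([len(b) for b in sub_batches(short_group)])  # seven of 128, then 104
```

Slicing handles a short final sub-batch without any special casing, so row groups of uneven size need no extra logic.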
4. Pinned Memory + Async GPU Transfer
# Pinned CPU memory for fast GPU transfer
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# ... fill scratch buffer ...
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Async GPU transfer (non-blocking)
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)

pin_memory=True allocates the scratch buffer in page-locked host memory, which speeds up host-to-device copies (roughly 2-3×). non_blocking=True lets the copy overlap with CPU work such as tokenizing the next batch. One caveat: the int32 cast on inputs_cpu creates a new, unpinned tensor, so only the targets copy reads straight from the pinned buffer. The result is overlapped I/O and compute.
5. Infinite Data Stream
def document_batches():
    while True:  # Infinite loop
        for batch in parquets_iter_batched(...):
            ...

Training loop (from scripts/base_train.py):

train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
x, y = next(train_loader)  # Prefetch first batch
for step in range(num_iterations):
    loss = model(x, y)
    loss.backward()
    # ... optimizer steps ...
    x, y = next(train_loader)  # Prefetch next batch while GPU is busy

TIP
Key insight: The data loader never terminates. It keeps cycling through the dataset, and the training loop controls how many steps to run based on compute budget (FLOPs) or data budget (tokens).
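The same shape can be reproduced with a plain generator. In this sketch `infinite_batches` is a hypothetical helper and the three-item list is a stand-in for one pass over the dataset:

```python
from itertools import islice


def infinite_batches(make_epoch):
    """Cycle forever over a finite dataset by restarting its iterator."""
    while True:
        yield from make_epoch()


loader = infinite_batches(lambda: iter(["a", "b", "c"]))

# The consumer, not the loader, decides when to stop.
steps = list(islice(loader, 7))
print(steps)  # ['a', 'b', 'c', 'a', 'b', 'c', 'a']
```

Keeping the stopping condition in the training loop means the loader needs no notion of epochs or step counts. The one thing to guard against in a design like this is an empty dataset, which would make the generator spin without ever yielding.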
Memory Efficiency Analysis
The memory footprint is surprisingly small:
Per-Rank Memory Usage
Token buffer:
token_buffer = deque()  # Typical size: ~100K tokens × 8 bytes = 800 KB

Scratch buffer:
needed_tokens = B * T + 1  # e.g., 32 × 2048 + 1 = 65,537
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# Size: 65,537 × 8 bytes = 524 KB

Tokenizer memory:
- Tiktoken encoding: ~10 MB (mergeable ranks + special tokens)
- Document batch: 128 documents × ~500 tokens avg × 4 bytes = 256 KB
Total per-rank overhead: ~12 MB
Compare this to loading 100B tokens into memory:
- Naive approach: 100B × 4 bytes = 400 GB (impossible!)
- Streaming approach: 12 MB per rank ✅
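The ~12 MB total can be reproduced by adding up the figures above. This sketch only restates the post's estimates as arithmetic; the 8 bytes per buffered token is an approximation (a CPython deque of int objects costs more per element), and the 10 MB tiktoken figure is taken as given rather than measured here.

```python
B, T = 32, 2048
needed_tokens = B * T + 1

token_buffer_bytes = 100_000 * 8   # ~100K buffered tokens at 8 bytes each
scratch_bytes = needed_tokens * 8  # int64 scratch tensor
doc_batch_bytes = 128 * 500 * 4    # 128 docs x ~500 tokens x 4 bytes
tiktoken_bytes = 10 * 1000 ** 2    # ~10 MB estimate for the encoding tables

total_bytes = token_buffer_bytes + scratch_bytes + doc_batch_bytes + tiktoken_bytes
naive_bytes = 100_000_000_000 * 4  # all 100B tokens at 4 bytes each

print(f"Streaming: {total_bytes / 1e6:.2f} MB per rank")
print(f"Naive:     {naive_bytes / 1e9:.0f} GB")
```

The encoding tables dominate the streaming total, so changing B or T moves the footprint by well under a megabyte at these sizes.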
Training Configuration Example
From scripts/base_train.py:
# Model
depth = 20
max_seq_len = 2048
device_batch_size = 32
total_batch_size = 524288  # tokens

# Data loading
train_loader = tokenizing_distributed_data_loader(
    device_batch_size,
    max_seq_len,
    split="train",
    tokenizer_threads=4,
    tokenizer_batch_size=128
)

With 8 GPUs:
- Per-device batch: 32 × 2048 = 65,536 tokens
- Total per step: 8 × 65,536 = 524,288 tokens
- Gradient accumulation: 1 step (no accumulation needed)
- Data loading memory: 8 × 12 MB = 96 MB total
Throughput:
- Tokenization: ~50,000 tokens/sec per thread
- 4 threads: ~200,000 tokens/sec per rank
- 8 ranks: ~1.6M tokens/sec total
- Training step: ~500ms (typical)
- Tokens per step: 524K
- Required throughput: ~1M tokens/sec (524K tokens every ~0.5 s)
Tokenization is NOT the bottleneck. ✅
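The comparison behind that conclusion is a two-line calculation. The sketch below plugs in the estimates listed above; `tokenizer_headroom` is a hypothetical helper, and the per-thread rate and step time are the post's rough figures, not new measurements.

```python
def tokenizer_headroom(tokens_per_step, step_seconds, per_thread_rate, threads, ranks):
    """Return (required, available, ratio) tokenization throughput in tokens/sec."""
    required = tokens_per_step / step_seconds
    available = per_thread_rate * threads * ranks
    return required, available, available / required


required, available, headroom = tokenizer_headroom(
    tokens_per_step=524_288,
    step_seconds=0.5,
    per_thread_rate=50_000,
    threads=4,
    ranks=8,
)
print(f"required={required:,.0f}  available={available:,.0f}  headroom={headroom:.2f}x")
```

With these inputs the tokenizers can supply roughly 1.5 times what the training step consumes. If the step time drops or the per-thread rate is lower on a given machine, rerunning the calculation shows how quickly that margin shrinks.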
Integration with Training Loop
The data loader plugs into the training loop like this:
From scripts/base_train.py:
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
x, y = next(train_loader)  # Kick off load of the very first batch
for step in range(num_iterations + 1):
    # ... evaluation logic ...
    # Single training step
    torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        loss = loss / grad_accum_steps
        loss.backward()
        # Prefetch next batch while GPU is busy with forward/backward
        x, y = next(train_loader)
    # Gradient clipping
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
    # Optimizer steps
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0

Key optimization:

loss.backward()
x, y = next(train_loader)  # Prefetch DURING backward pass

CUDA kernels are launched asynchronously, so loss.backward() returns control to Python before the GPU has finished. The CPU can then tokenize the next batch in the meantime, which overlaps data loading with GPU computation and keeps utilization high.
Gradient Accumulation with Data Loading
If gradient accumulation is needed (when total_batch_size is large):
tokens_per_fwdbwd = device_batch_size * max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd

Example:
- device_batch_size = 16
- max_seq_len = 2048
- ddp_world_size = 4
- total_batch_size = 524288
Calculation:
- tokens_per_fwdbwd = 16 × 2048 = 32,768
- world_tokens_per_fwdbwd = 32,768 × 4 = 131,072
- grad_accum_steps = 524,288 / 131,072 = 4
Result: 4 micro-batches per optimizer step
The data loader is completely agnostic to gradient accumulation - it just keeps producing batches. The training loop handles the accumulation logic.
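The calculation above can be wrapped in a small function. This is a sketch: `grad_accum_steps_for` is a hypothetical name, and the divisibility check is an assumption added here so that a mismatched configuration fails loudly instead of silently training on a smaller batch.

```python
def grad_accum_steps_for(total_batch_size, device_batch_size, max_seq_len, world_size):
    """Number of micro-batches each rank runs per optimizer step."""
    tokens_per_fwdbwd = device_batch_size * max_seq_len
    world_tokens_per_fwdbwd = tokens_per_fwdbwd * world_size
    if total_batch_size % world_tokens_per_fwdbwd != 0:
        raise ValueError(
            f"total_batch_size={total_batch_size} is not a multiple of "
            f"{world_tokens_per_fwdbwd} tokens per forward/backward pass"
        )
    return total_batch_size // world_tokens_per_fwdbwd


print(grad_accum_steps_for(524_288, 16, 2048, 4))  # 4
print(grad_accum_steps_for(524_288, 32, 2048, 8))  # 1
```

The second call matches the 8-GPU configuration shown earlier, where no accumulation is needed.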
Validation Data Loading
Validation uses the same data loader with a different split:
def build_val_loader():
    return tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")

# During evaluation
val_loader = build_val_loader()
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
with autocast_ctx:
    val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)

Key differences:
- Only uses the last parquet shard (~55M tokens)
- Evaluation runs for a fixed number of steps (eval_steps)
- Creates a fresh loader each time (starts from beginning)
NOTE
Why create a fresh loader?
- Ensures validation always uses the same data
- Prevents "validation set drift" over training
- Simple and deterministic
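The drift that a fresh loader avoids is easy to see with a toy generator. In this sketch `toy_val_loader`, `evaluate_once`, and `eval_steps_for` are hypothetical helpers, the five labelled batches stand in for the validation shard, and the `eval_tokens` value in the last line is an assumed example setting.

```python
def toy_val_loader(batches):
    """Infinite generator over a fixed list of validation batches."""
    while True:
        yield from batches


def evaluate_once(batches, steps):
    loader = toy_val_loader(batches)  # fresh loader: starts from the first batch
    return [next(loader) for _ in range(steps)]


def eval_steps_for(eval_tokens, device_batch_size, max_seq_len, world_size):
    return eval_tokens // (device_batch_size * max_seq_len * world_size)


val_batches = ["b0", "b1", "b2", "b3", "b4"]

# Fresh loader per evaluation: both runs see identical data.
first, second = evaluate_once(val_batches, 3), evaluate_once(val_batches, 3)

# One shared loader: the second evaluation picks up where the first stopped.
shared = toy_val_loader(val_batches)
run_a = [next(shared) for _ in range(3)]
run_b = [next(shared) for _ in range(3)]

print(first == second, run_a == run_b)  # True False
print(eval_steps_for(20 * 524_288, 32, 2048, 8))  # 20
```

Evaluating on identical batches every time keeps validation numbers comparable across checkpoints.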
On-Demand Dataset Download
nanochat includes a clever on-demand download system.
From nanochat/dataset.py:
def download_single_file(index):
    """Downloads a single file with exponential backoff retry."""
    filename = index_to_filename(index)
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        print(f"Skipping {filepath} (already exists)")
        return True
    url = f"{BASE_URL}/{filename}"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(url, stream=True, timeout=30)
            response.raise_for_status()
            # Write to temporary file first
            temp_path = filepath + ".tmp"
            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            # Move temp file to final location (atomic)
            os.rename(temp_path, filepath)
            return True
        except (requests.RequestException, IOError) as e:
            print(f"Attempt {attempt}/{max_attempts} failed: {e}")
            # Exponential backoff
            if attempt < max_attempts:
                wait_time = 2 ** attempt
                time.sleep(wait_time)
    return False

Usage:
# Download first 10 shards with 4 parallel workers
python -m nanochat.dataset -n 10 -w 4
# Download entire dataset (1,823 shards)
python -m nanochat.dataset -n -1 -w 8

Design highlights:
- Parallel downloads: Uses multiprocessing.Pool
- Atomic writes: .tmp files prevent corruption
- Resume support: Skips existing files
- Exponential backoff: Handles transient network errors
- Streaming writes: 1MB chunks prevent memory bloat
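The retry and atomic-write pattern can be exercised without a network. The sketch below is an illustration under stated assumptions rather than nanochat's code: `fetch_with_retry` is a hypothetical helper that takes the download step and the sleep function as parameters, so a simulated failure can stand in for a real HTTP error.

```python
import os
import tempfile
import time


def fetch_with_retry(fetch, dest_path, max_attempts=5, sleep=time.sleep):
    """Write fetch()'s bytes to dest_path atomically, retrying on OSError."""
    if os.path.exists(dest_path):
        return True  # resume support: already downloaded
    tmp_path = dest_path + ".tmp"
    for attempt in range(1, max_attempts + 1):
        try:
            data = fetch()
            with open(tmp_path, "wb") as f:
                f.write(data)
            os.replace(tmp_path, dest_path)  # atomic on the same filesystem
            return True
        except OSError as exc:
            print(f"Attempt {attempt}/{max_attempts} failed: {exc}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)  # never leave a partial file behind
            if attempt < max_attempts:
                sleep(2 ** attempt)  # exponential backoff: 2, 4, 8, ...
    return False


# Simulated source that fails twice, then succeeds.
calls = {"n": 0}

def flaky_fetch():
    calls["n"] += 1
    if calls["n"] < 3:
        raise OSError("simulated network error")
    return b"parquet bytes"

waits = []  # record backoff delays instead of actually sleeping
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example.parquet")
    ok = fetch_with_retry(flaky_fetch, path, sleep=waits.append)
    with open(path, "rb") as f:
        content = f.read()
    tmp_left_behind = os.path.exists(path + ".tmp")

print(ok, waits)  # True [2, 4]
```

Because the final file only appears after a complete write, a reader that lists `.parquet` files never sees a half-downloaded shard.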
Performance Benchmarks
I measured the data pipeline performance:
Tokenization Throughput
Test setup:
import time
from nanochat.dataloader import tokenizing_distributed_data_loader

B, T = 32, 2048
loader = tokenizing_distributed_data_loader(B, T, "train", tokenizer_threads=4)

# Warmup
for _ in range(10):
    next(loader)

# Benchmark
t0 = time.time()
num_batches = 100
for _ in range(num_batches):
    x, y = next(loader)
t1 = time.time()

tokens_per_batch = B * T
total_tokens = tokens_per_batch * num_batches
throughput = total_tokens / (t1 - t0)
print(f"Throughput: {throughput/1e6:.2f}M tokens/sec")

Results on 8x H100:
- Single-threaded tokenization: 0.5M tokens/sec
- 4-threaded tokenization: 1.8M tokens/sec
- 8-threaded tokenization: 2.2M tokens/sec (diminishing returns)
Training throughput requirement:
- Model: depth=20 (~561M params)
- Hardware: 8x H100 GPUs
- Step time: ~500ms
- Batch size: 524K tokens
- Required: ~1M tokens/sec
With 4 threads, tokenization provides 1.8× headroom. ✅
Memory Footprint During Training
Measurement:
import os
import psutil
import torch
from nanochat.dataloader import tokenizing_distributed_data_loader

process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / 1024**2  # MB

# Create loader and run 100 steps
loader = tokenizing_distributed_data_loader(32, 2048, "train")
for _ in range(100):
    x, y = next(loader)

mem_after = process.memory_info().rss / 1024**2  # MB
print(f"Memory increase: {mem_after - mem_before:.2f} MB")

Results:
- Memory increase: ~15 MB
- GPU memory transfer: ~0.5 MB per batch
- Total overhead: < 20 MB per rank
Design Lessons and Trade-offs
1. Streaming vs Precomputed Tokens
nanochat's choice: Streaming + on-the-fly tokenization
Alternative: Precompute all tokens
# Precompute approach (NOT used)
for shard in parquet_files:
    tokens = tokenize(shard)
    save(tokens, f"tokens_{shard}.pt")

Trade-offs:
| Aspect | Streaming (nanochat) | Precomputed |
|---|---|---|
| Disk usage | 100 GB (parquet) | 400 GB (tokens) |
| Startup time | Instant | Requires preprocessing |
| Flexibility | Easy to change tokenizer | Must regenerate |
| CPU usage | Higher (ongoing) | Lower (one-time) |
| I/O pattern | Sequential reads | Random access |
nanochat's rationale:
- Disk space is expensive (400 GB vs 100 GB)
- Tokenization is fast enough (not the bottleneck)
- Flexibility to experiment with tokenization
2. Document Packing vs Sequence-Level Batching
nanochat's choice: Pack tokens across document boundaries
Alternative: Pad each document to fixed length
# Padding approach (NOT used)
for doc in documents:
    tokens = tokenize(doc)
    if len(tokens) < T:
        tokens += [PAD] * (T - len(tokens))
    yield tokens[:T]

Trade-offs:
| Aspect | Document Packing | Padding |
|---|---|---|
| Token efficiency | 100% (no waste) | 60-80% (padding overhead) |
| Document boundaries | Can span batches | Preserved |
| Implementation | Needs token buffer | Simpler |
| Training efficiency | Higher | Lower |
nanochat's rationale:
- 20-40% more compute efficiency from avoiding padding
- <|bos|> tokens mark document boundaries
- The model learns to handle document transitions
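The padding overhead depends entirely on how document lengths compare with T. A small sketch makes the calculation explicit; `padding_efficiency` is a hypothetical helper and the document lengths are made-up values chosen for illustration, not statistics from FineWeb-Edu.

```python
def padding_efficiency(doc_lengths, T):
    """Fraction of sequence slots holding real tokens under pad-or-truncate."""
    real_tokens = sum(min(n, T) for n in doc_lengths)
    return real_tokens / (len(doc_lengths) * T)


def truncated_tokens(doc_lengths, T):
    """Tokens thrown away because a document is longer than T."""
    return sum(max(n - T, 0) for n in doc_lengths)


lengths = [1200, 3000, 900, 1800, 1500]  # hypothetical token counts
T = 2048

print(f"efficiency: {padding_efficiency(lengths, T):.1%}")  # 72.7%
print(f"truncated:  {truncated_tokens(lengths, T)} tokens")  # 952 tokens
```

Padding wastes slots on short documents and also drops the tail of long ones. Packing avoids both, since every slot holds a real token and long documents simply continue into the next sequence.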
3. Distributed Sharding Strategy
nanochat's choice: Strided row-group access
Alternative: Shard-per-rank assignment
# Shard assignment approach (NOT used)
shards_per_rank = len(parquet_files) // world_size
my_shards = parquet_files[rank * shards_per_rank : (rank+1) * shards_per_rank]

Trade-offs:
| Aspect | Strided Access | Shard Assignment |
|---|---|---|
| Load balancing | Automatic | Manual |
| Data distribution | Fine-grained | Coarse-grained |
| Synchronization | None needed | None needed |
| Scalability | Limited by row groups | Limited by shards |
nanochat's rationale:
- Better load balancing (every rank sees all shards)
- Works well with any number of GPUs
- No coordination overhead
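One concrete cost of the block assignment shown above is that floor division leaves remainder shards unread. The sketch below counts them for the 1,822 training shards; `shard_assignment` is a hypothetical helper that mirrors the slicing in that snippet.

```python
def shard_assignment(num_shards, world_size):
    """Return (shards_per_rank, leftover) for contiguous block assignment."""
    shards_per_rank = num_shards // world_size
    leftover = num_shards - shards_per_rank * world_size
    return shards_per_rank, leftover


train_shards = 1822  # 1,823 files minus the validation shard
for world_size in (3, 7, 8):
    per_rank, leftover = shard_assignment(train_shards, world_size)
    print(f"{world_size} ranks: {per_rank} shards each, {leftover} never read")
```

With 8 ranks, 6 shards would go unused unless extra logic redistributes them. Striding over row groups sidesteps this because every rank visits every file.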
Conclusion
nanochat's data pipeline shows what thoughtful systems design looks like. Streaming I/O, parallel tokenization, distributed sharding, careful memory management. The result:
- ~12 MB per rank memory footprint
- 1.8M tokens/sec throughput (4 threads)
- Zero coordination overhead (each rank operates independently)
- Works with any number of GPUs
- Deterministic and reproducible
The key insights? Stream everything and never load more than you need. Pack tokens to eliminate padding waste. Overlap I/O and compute by prefetching during the backward pass. Shard at the row-group level. Use the right tool for each job: rustbpe to train the tokenizer, tiktoken to encode text.
You don't need complex distributed filesystems or elaborate data loading frameworks. Just careful attention to the fundamentals.
Further Reading
- FineWeb-Edu dataset: HuggingFace
- Tiktoken library: GitHub
- Apache Parquet format: Documentation
- nanochat source code: GitHub
NOTE
About Experiments: The original source includes performance benchmarks and experiments. These are available in the nanochat repository. If there's interest from readers, I'll create a companion Jupyter notebook with interactive experiments.
This post is part of a comprehensive deep-dive series on nanochat, exploring the technical innovations that make training ChatGPT-style models accessible. Follow along as we build understanding from first principles.



