José David Baena

Knowledge Distillation: How to Train a 1.5B Model That Matches Your 7B

📚 Tiny Language Models Series - Track 3: Training

Part 1 of 3 - Distilling knowledge from large to small models

  1. 3.1 Knowledge Distillation Complete Tutorial (You are here)
  2. 3.2 Quantization-Aware Training
  3. 3.3 Fine-Tuning and Domain Adaptation

Your 7B model costs $5K/month to run. Your 1.5B student costs $500.

Distillation sounds simple until you try it. Temperature matters. Loss weighting matters. The student architecture matters even more. I've run dozens of distillation experiments to figure out what actually works.

Same quality. 1/5 the size. 3× the speed. Knowledge distillation makes this possible.

TL;DR: Temperature-scaled softmax reveals dark knowledge. KL divergence loss transfers it to the student. Progressive distillation (7B → 3B → 1.5B) beats direct distillation. 90% quality retention is achievable with the right recipe. Recent small production models, such as Llama 3.2 1B/3B and Gemma 2 2B, were trained with distillation from larger teachers. The technique is now production-critical.

The distillation that took three tries: Consider a common pattern: needing a 1.5B model for edge deployment. Attempt one: direct distillation from 7B to 1.5B. Result: 72% quality retention. Unacceptable. Attempt two: temperature tuning (T=4 instead of T=2). Result: 78%. Still short. Attempt three: progressive distillation—7B teacher distills to 3B intermediate, then 3B distills to 1.5B target. Result: 91% quality retention. The intermediate model acts as a "curriculum"—it's harder to learn directly from a 7B teacher than from a closer peer. Direct distillation is tempting. Progressive distillation is what works.

You've trained a 7B parameter model that achieves 78% on your benchmark. It works perfectly—on your GPU cluster. But deploying it means:

  • $5,000/month cloud costs (A100 instances)
  • 14GB model size (slow downloads, version updates)
  • 300ms latency per request (users notice)
  • Impossible edge deployment (phones, IoT)

Knowledge distillation is your path to a 1.5B model that retains 90% of the quality at 1/5 the size and 3× the speed.

From theory to production:

  1. Understanding distillation: Why soft targets transfer knowledge better than labels
  2. Architecture design: How to scale down your teacher model
  3. Training pipeline: Complete PyTorch implementation with best practices
  4. Advanced techniques: Progressive distillation, task-specific recipes, multi-teacher
  5. Optimization: Mixed precision, gradient accumulation, distributed training
  6. Production deployment: Export, quantize, serve the student model

You'll get working code to distill any language model and train a 1.5B student from Llama-7B.

Prerequisites:

  • PyTorch fundamentals
  • Transformer architecture basics
  • Access to GPU (A100 recommended, can adapt for smaller)

Soft targets transfer knowledge that hard labels can't

The Core Insight

Standard training: Model learns from hard labels (one-hot vectors)

# Example: Text classification
label = [0, 0, 0, 1, 0]  # Class 4
# Model learns: "This is class 4, others are wrong"

Problem: Loses relational information. "How similar is class 4 to class 3?" Unknown.

Distillation: Student learns from teacher's soft probability distribution

# Teacher's output (temperature=2)
soft_targets = [0.05, 0.12, 0.18, 0.48, 0.17]
# Model learns: "Mostly class 4, but 3 is plausible, 2 somewhat, 1 unlikely, 0 rare"

Benefit: Captures teacher's uncertainty, class relationships, confidence patterns.

Mathematical Foundation

Distillation loss combines two objectives:

L_total = α * L_distill + (1 - α) * L_task

where:
  L_distill = KL(softmax(z_s / T), softmax(z_t / T)) * T²
  L_task = CrossEntropy(z_s, y_true)
  
  z_s: student logits
  z_t: teacher logits
  T: temperature (softening parameter)
  α: distillation weight (typically 0.7)

Temperature scaling:

  • T = 1: Standard softmax (sharp distribution)
  • T > 1: Softer distribution (reveals relative probabilities)
  • T = 2-3: Typical for distillation

Visualization:

import torch
import torch.nn.functional as F
 
logits = torch.tensor([2.0, 1.0, 0.5, 0.2])
 
print("T=1 (standard):", F.softmax(logits / 1, dim=0))
print("T=2 (soft):", F.softmax(logits / 2, dim=0))
print("T=5 (very soft):", F.softmax(logits / 5, dim=0))
 
# T=1 (standard): tensor([0.5694, 0.2095, 0.1270, 0.0941])
# T=2 (soft): tensor([0.4023, 0.2440, 0.1900, 0.1636])
# T=5 (very soft): tensor([0.3070, 0.2514, 0.2274, 0.2142])

Notice how higher T reveals similarity between classes 2 and 3.

Why Students Benefit

Dark knowledge (Hinton et al.): Information in near-zero probabilities

Example: Language model predicting next token after "The capital of France is"

# Teacher probabilities (T=2)
{
  "Paris": 0.82,
  "paris": 0.08,  # Dark knowledge: capitalization variant
  "Lyon": 0.03,   # Dark knowledge: other French cities
  "France": 0.02, # Dark knowledge: common error
  "the": 0.01,
  ...
}

Student learns:

  • "Paris" is correct
  • Lowercase variant is plausible (informal text)
  • Other French cities are reasonable confusion
  • Grammatical errors to avoid

Standard training only sees "Paris" as correct, others wrong. Loses nuance.

For your distillation pipeline, this means: always use temperature > 1 during training. T=2 is the sweet spot for most models—it reveals class relationships without completely flattening the distribution.


Step 1: Freeze your teacher and cache its outputs

Load Pretrained Teacher

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
 
# Load teacher model (Llama-7B)
teacher_name = "meta-llama/Llama-2-7b-hf"
 
print("Loading teacher model...")
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_name,
    torch_dtype=torch.float16,
    device_map="auto",  # Automatic device placement
)
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
tokenizer.pad_token = tokenizer.eos_token
 
# Freeze teacher (never update during distillation)
teacher.eval()
for param in teacher.parameters():
    param.requires_grad = False
 
print(f"Teacher parameters: {teacher.num_parameters() / 1e9:.2f}B")
print(f"Teacher memory: {teacher.get_memory_footprint() / 1e9:.2f}GB")
# Teacher parameters: 6.74B
# Teacher memory: 13.48GB (FP16)
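
The step title also mentions caching the teacher's outputs. For smaller corpora you can precompute the teacher's top-k logits once and reuse them every epoch instead of re-running the 7B forward pass at every step. A minimal sketch, assuming a dataloader that yields input_ids/attention_mask batches (like the one built in Step 4); the cache_teacher_topk name and the k=64 cutoff are illustrative choices, not part of the original pipeline:

def cache_teacher_topk(teacher, dataloader, k=64, path="teacher_topk.pt"):
    """Precompute and store the teacher's top-k logits for each batch."""
    cached = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(teacher.device)
            attention_mask = batch["attention_mask"].to(teacher.device)
            logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits
            # Storing the full 32K-token distribution is too large; keep only the top-k entries
            values, indices = logits.topk(k, dim=-1)
            cached.append({"values": values.cpu(), "indices": indices.cpu()})
    torch.save(cached, path)
    return path

During training you would renormalize over the cached top-k entries instead of the full vocabulary. For large or streaming corpora, the on-the-fly teacher forward pass used in the training loop below is usually simpler.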

Validate Teacher Quality

from datasets import load_dataset
 
def evaluate_model(model, tokenizer, num_samples=100):
    """Quick eval on validation set."""
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
    
    total_loss = 0
    total_tokens = 0
    
    model.eval()
    with torch.no_grad():
        for i, example in enumerate(dataset):
            if i >= num_samples:
                break
            
            text = example["text"]
            if len(text.strip()) == 0:
                continue
            
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            outputs = model(**inputs, labels=inputs["input_ids"])
            total_loss += outputs.loss.item() * inputs["input_ids"].numel()
            total_tokens += inputs["input_ids"].numel()
    
    perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
    return perplexity.item()
 
teacher_ppl = evaluate_model(teacher, tokenizer)
print(f"Teacher perplexity: {teacher_ppl:.2f}")
# Teacher perplexity: 8.79

Teacher-Student Distillation (interactive visualization): temperature softens the probability distribution of the large, accurate teacher, and the small, efficient student is trained over several epochs to match it.

💡 The student learns "dark knowledge" - the relative probabilities between classes that hard labels don't capture. Higher temperatures expose more of this structure.

Step 2: Design a student that's 1/4 to 1/6 the teacher size

Scaling Heuristics

Rule of thumb: Student should be 20-50% of teacher size for good quality retention.

For Llama-7B teacher → 1.5B student:

from transformers import LlamaConfig
 
# Teacher config
teacher_config = teacher.config
print("Teacher architecture:")
print(f"  hidden_size: {teacher_config.hidden_size}")
print(f"  num_layers: {teacher_config.num_hidden_layers}")
print(f"  num_heads: {teacher_config.num_attention_heads}")
print(f"  intermediate_size: {teacher_config.intermediate_size}")
 
# Output:
#   hidden_size: 4096
#   num_layers: 32
#   num_heads: 32
#   intermediate_size: 11008

Student design philosophy:

  • Depth vs width: Prefer deeper-narrower over shallow-wide
  • Attention heads: Reduce proportionally with hidden size
  • FFN ratio: Llama-7B uses ~2.7× (11008/4096); sub-2B students often widen this to 4-6× to recover capacity lost to the narrower hidden size

For your compute budget, this means: the 1.5B sweet spot gives you 5× cost reduction while retaining 90% quality. Going smaller (500M) saves more but quality drops to 80%. Going larger (3B) improves quality to 95% but halves the cost savings.

For your architecture decisions, this means: when in doubt, keep more layers and reduce width. A 24-layer student with hidden_size=1536 outperforms a 16-layer student with hidden_size=2048 on most benchmarks.
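
Before committing to a design, it helps to estimate the parameter count straight from the config numbers. A rough estimator for Llama-style decoders (gated MLP, GQA); it ignores norms and biases, so treat the output as approximate:

def estimate_llama_params(vocab, hidden, layers, heads, kv_heads, intermediate, tied=False):
    """Approximate parameter count for a Llama-style decoder (norms/biases ignored)."""
    head_dim = hidden // heads
    kv_dim = kv_heads * head_dim
    attn = 2 * hidden * hidden + 2 * hidden * kv_dim  # q/o projections + k/v projections
    mlp = 3 * hidden * intermediate                   # gate, up, down projections
    embeddings = vocab * hidden * (1 if tied else 2)  # input embeddings + LM head
    return layers * (attn + mlp) + embeddings

print(f"{estimate_llama_params(32000, 1536, 28, 12, 4, 8960) / 1e9:.2f}B")
# 1.43B (matches the student config below)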

# Student config (1.5B target)
student_config = LlamaConfig(
    vocab_size=teacher_config.vocab_size,  # Same vocab
    hidden_size=1536,         # 37.5% of teacher (4096)
    intermediate_size=8960,   # ~5.8× FFN ratio (wider than the teacher's 2.7×, typical for sub-2B models)
    num_hidden_layers=28,     # 87.5% of teacher depth (32)
    num_attention_heads=12,   # Proportional reduction (head_dim stays 128)
    num_key_value_heads=4,    # GQA for efficiency (3× KV reduction)
    max_position_embeddings=teacher_config.max_position_embeddings,
    rope_theta=teacher_config.rope_theta,
    rms_norm_eps=teacher_config.rms_norm_eps,
    tie_word_embeddings=False,
)
 
from transformers import LlamaForCausalLM
 
student = LlamaForCausalLM(student_config)
print(f"Student parameters: {student.num_parameters() / 1e9:.2f}B")
print(f"Student memory: {student.get_memory_footprint() / 1e9:.2f}GB")
# Student parameters: 1.43B
# Student memory: 2.86GB (FP16)

Alternative Architectures

Tiny (500M-1B):

tiny_config = LlamaConfig(
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=28,
    num_attention_heads=8,
    num_key_value_heads=2,
)
# ~500M parameters

Balanced (2-3B):

balanced_config = LlamaConfig(
    hidden_size=2560,
    intermediate_size=6912,
    num_hidden_layers=32,
    num_attention_heads=20,
    num_key_value_heads=8,
)
# ~2.5B parameters
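
Both alternatives check out against the rough estimator from Step 2 above (32,000-token vocab, untied embeddings):

print(f"Tiny: {estimate_llama_params(32000, 1024, 28, 8, 2, 4096) / 1e9:.2f}B")       # ~0.49B
print(f"Balanced: {estimate_llama_params(32000, 2560, 32, 20, 8, 6912) / 1e9:.2f}B")  # ~2.45B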

Step 3: KL divergence loss transfers the soft distribution

Complete Loss Function

import torch.nn.functional as F
 
class DistillationLoss(torch.nn.Module):
    """
    Combined distillation + task loss for language modeling.
    
    Args:
        temperature: Softening parameter for distillation
        alpha: Weight for distillation loss (0=task only, 1=distill only)
        reduction: How to reduce loss ('mean' or 'batchmean')
    """
    def __init__(self, temperature=2.0, alpha=0.7, reduction='batchmean'):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.reduction = reduction
        
    def forward(self, student_logits, teacher_logits, labels):
        """
        Args:
            student_logits: [batch, seq_len, vocab_size]
            teacher_logits: [batch, seq_len, vocab_size]
            labels: [batch, seq_len] (with -100 for padding)
        
        Returns:
            loss: Scalar
            metrics: Dict with breakdown
        """
        # Flatten for loss computation
        batch_size, seq_len, vocab_size = student_logits.shape
        student_logits_flat = student_logits.view(-1, vocab_size)
        teacher_logits_flat = teacher_logits.view(-1, vocab_size)
        labels_flat = labels.view(-1)
        
        # Create mask for non-padding tokens
        mask = (labels_flat != -100).float()
        num_tokens = mask.sum()
        
        # Task loss: Standard cross-entropy with true labels
        loss_ce = F.cross_entropy(
            student_logits_flat,
            labels_flat,
            ignore_index=-100,
            reduction='sum'
        ) / num_tokens
        
        # Distillation loss: KL divergence with temperature scaling
        # Apply temperature scaling
        student_soft = F.log_softmax(student_logits_flat / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits_flat / self.temperature, dim=-1)
        
        # Compute KL divergence
        loss_kd = F.kl_div(
            student_soft,
            teacher_soft,
            reduction='none'
        ).sum(dim=-1)  # Sum over vocabulary
        
        # Apply mask and reduce
        loss_kd = (loss_kd * mask).sum() / num_tokens
        
        # Scale by T² to compensate for temperature
        loss_kd = loss_kd * (self.temperature ** 2)
        
        # Combined loss
        loss = self.alpha * loss_kd + (1 - self.alpha) * loss_ce
        
        metrics = {
            "loss": loss.item(),
            "loss_kd": loss_kd.item(),
            "loss_ce": loss_ce.item(),
            "alpha": self.alpha,
        }
        
        return loss, metrics
 
# Instantiate
distill_criterion = DistillationLoss(temperature=2.0, alpha=0.7)

Testing the Loss

# Test distillation loss
batch_size, seq_len, vocab_size = 2, 10, 32000
 
student_logits = torch.randn(batch_size, seq_len, vocab_size)
teacher_logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -2:] = -100  # Padding
 
loss, metrics = distill_criterion(student_logits, teacher_logits, labels)
print(f"Total loss: {metrics['loss']:.4f}")
print(f"  KD loss: {metrics['loss_kd']:.4f} (weight: {metrics['alpha']})")
print(f"  CE loss: {metrics['loss_ce']:.4f} (weight: {1 - metrics['alpha']})")

Step 4: The training loop with gradient accumulation

Dataset Preparation

from datasets import load_dataset
from torch.utils.data import DataLoader
 
def prepare_dataset(tokenizer, max_length=512):
    """Load and tokenize training data."""
    
    # Load dataset (using SlimPajama subset for demo)
    dataset = load_dataset(
        "cerebras/SlimPajama-627B",
        split="train",
        streaming=True  # Stream for large datasets
    )
    
    def tokenize_function(examples):
        """Tokenize and create labels."""
        # Tokenize text
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt"
        )
        
        # Labels are input_ids (for causal LM)
        tokenized["labels"] = tokenized["input_ids"].clone()
        
        # Mask padding in labels
        tokenized["labels"][tokenized["input_ids"] == tokenizer.pad_token_id] = -100
        
        return tokenized
    
    # Tokenize dataset
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )
    
    return tokenized_dataset
 
# Prepare data
train_dataset = prepare_dataset(tokenizer)
 
# Create dataloader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,  # Adjust based on GPU memory
    shuffle=False,  # Streaming dataset
    num_workers=4,
)
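
Before launching a multi-day run, pull a single batch and confirm the shapes and label masking look right. The expected shape assumes batch_size=4 and max_length=512 as configured above:

# Quick sanity check on one batch
batch = next(iter(train_dataloader))
print(batch["input_ids"].shape)                 # e.g. torch.Size([4, 512])
print((batch["labels"] == -100).sum().item())   # number of masked padding positions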

Training Loop

from transformers import get_cosine_schedule_with_warmup
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
 
def train_distillation(
    student,
    teacher,
    train_dataloader,
    criterion,
    num_steps=100000,
    learning_rate=2e-4,
    warmup_steps=2000,
    gradient_accumulation_steps=8,
    max_grad_norm=1.0,
    log_interval=100,
    save_interval=5000,
    output_dir="./student_checkpoints"
):
    """
    Complete distillation training loop.
    
    Args:
        student: Student model (trainable)
        teacher: Teacher model (frozen)
        train_dataloader: DataLoader for training data
        criterion: DistillationLoss instance
        num_steps: Total training steps
        learning_rate: Peak learning rate
        warmup_steps: Learning rate warmup
        gradient_accumulation_steps: Accumulate gradients
        max_grad_norm: Gradient clipping threshold
        log_interval: Steps between logging
        save_interval: Steps between checkpoints
        output_dir: Directory for checkpoints
    """
    import os
    os.makedirs(output_dir, exist_ok=True)
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        eps=1e-8
    )
    
    # Learning rate scheduler
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_steps
    )
    
    # Mixed precision training
    scaler = GradScaler()
    
    # Move models to device (don't call .to() on a teacher dispatched with device_map="auto")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    student = student.to(device)
    if not hasattr(teacher, "hf_device_map"):
        teacher = teacher.to(device)
    
    # Training state
    student.train()
    teacher.eval()
    
    global_step = 0
    total_loss = 0
    
    progress_bar = tqdm(total=num_steps, desc="Training")
    
    dataloader_iter = iter(train_dataloader)
    
    while global_step < num_steps:
        optimizer.zero_grad()
        
        # Gradient accumulation loop
        accum_metrics = {"loss": 0, "loss_kd": 0, "loss_ce": 0}
        
        for accum_step in range(gradient_accumulation_steps):
            # Get batch
            try:
                batch = next(dataloader_iter)
            except StopIteration:
                dataloader_iter = iter(train_dataloader)
                batch = next(dataloader_iter)
            
            # Move to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            
            # Mixed precision forward pass
            try:
                with autocast():
                    # Teacher forward (no gradients)
                    with torch.no_grad():
                        teacher_outputs = teacher(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                        )
                        teacher_logits = teacher_outputs.logits
                    
                    # Student forward
                    student_outputs = student(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )
                    student_logits = student_outputs.logits
                    
                    # Check for shape mismatch (common distillation error)
                    if teacher_logits.shape != student_logits.shape:
                        print(f"Shape mismatch: teacher {teacher_logits.shape} vs student {student_logits.shape}")
                        optimizer.zero_grad()
                        continue
                    
                    # Compute distillation loss
                    loss, metrics = criterion(student_logits, teacher_logits, labels)
                    
                    # Check for NaN loss (indicates numerical instability)
                    if torch.isnan(loss):
                        print(f"NaN loss detected at step {global_step}. Skipping batch.")
                        optimizer.zero_grad()
                        continue
                    
                    # Scale loss for gradient accumulation
                    loss = loss / gradient_accumulation_steps
                
                # Backward pass with scaling
                scaler.scale(loss).backward()
                
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"OOM at step {global_step}. Clearing cache and skipping batch.")
                    torch.cuda.empty_cache()
                    optimizer.zero_grad()
                    continue
                else:
                    raise e
            
            # Accumulate metrics
            for key in accum_metrics:
                accum_metrics[key] += metrics[key] / gradient_accumulation_steps
        
        # Gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)
        
        # Optimizer step
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        
        # Update tracking
        global_step += 1
        total_loss += accum_metrics["loss"]
        
        # Logging
        if global_step % log_interval == 0:
            avg_loss = total_loss / log_interval
            lr = scheduler.get_last_lr()[0]
            
            progress_bar.set_postfix({
                "loss": f"{avg_loss:.4f}",
                "kd": f"{accum_metrics['loss_kd']:.4f}",
                "ce": f"{accum_metrics['loss_ce']:.4f}",
                "lr": f"{lr:.2e}"
            })
            
            total_loss = 0
        
        # Checkpointing
        if global_step % save_interval == 0:
            checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
            student.save_pretrained(checkpoint_path)
            tokenizer.save_pretrained(checkpoint_path)
            print(f"\nSaved checkpoint to {checkpoint_path}")
        
        progress_bar.update(1)
    
    progress_bar.close()
    
    # Final save
    final_path = os.path.join(output_dir, "final")
    student.save_pretrained(final_path)
    tokenizer.save_pretrained(final_path)
    print(f"Training complete! Final model saved to {final_path}")
 
# Run training
train_distillation(
    student=student,
    teacher=teacher,
    train_dataloader=train_dataloader,
    criterion=distill_criterion,
    num_steps=100000,
    learning_rate=2e-4,
    gradient_accumulation_steps=8,
)

Distillation Training Simulator (interactive): explore how the loss weight α, temperature T, and learning rate shape the combined objective

L_total = α * T² * KL(student_soft, teacher_soft) + (1 - α) * CE(student, labels)

💡 Higher α emphasizes learning from the teacher's soft targets. Higher temperature reveals more inter-class relationships. Adjust the learning rate for training stability.

Step 5: Progressive distillation beats direct compression

Progressive Distillation

Distill in multiple stages for better quality:

def progressive_distillation(
    teacher_name="meta-llama/Llama-2-7b-hf",
    final_size=1.5e9,
    num_stages=2
):
    """
    Progressive distillation: teacher → intermediate → student.
    
    Example: 7B → 3.5B → 1.5B
    """
    # Load teacher
    teacher = AutoModelForCausalLM.from_pretrained(teacher_name)
    teacher_size = teacher.num_parameters()
    
    # Calculate intermediate sizes
    sizes = [teacher_size]
    ratio = (final_size / teacher_size) ** (1 / num_stages)
    for i in range(num_stages):
        sizes.append(sizes[-1] * ratio)
    
    print("Progressive distillation stages:")
    for i, size in enumerate(sizes):
        print(f"  Stage {i}: {size/1e9:.2f}B parameters")
    
    # Distill stage by stage
    current_teacher = teacher
    for stage in range(1, num_stages + 1):
        print(f"\n=== Stage {stage}: {sizes[stage-1]/1e9:.2f}B → {sizes[stage]/1e9:.2f}B ===")
        
        # Create student config
        student_config = create_config_for_size(sizes[stage])
        student = LlamaForCausalLM(student_config)
        
        # Train distillation
        train_distillation(
            student=student,
            teacher=current_teacher,
            train_dataloader=train_dataloader,
            criterion=DistillationLoss(temperature=2.0, alpha=0.7),
            num_steps=30000,  # Fewer steps per stage
        )
        
        # This student becomes next teacher
        current_teacher = student
        current_teacher.eval()
        for param in current_teacher.parameters():
            param.requires_grad = False
    
    return current_teacher  # Final student
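
progressive_distillation calls a create_config_for_size helper that isn't defined in this post. One possible sketch, mapping a parameter budget onto Llama-style shapes consistent with the configs above; the breakpoints are illustrative, not a prescription:

def create_config_for_size(target_params):
    """Map a rough parameter budget to a Llama-style config (illustrative breakpoints)."""
    if target_params >= 2.8e9:
        hidden, ffn, layers, heads, kv = 3072, 8192, 28, 24, 8   # ~3.0B
    elif target_params >= 1.8e9:
        hidden, ffn, layers, heads, kv = 2560, 6912, 32, 20, 8   # ~2.5B
    elif target_params >= 1.0e9:
        hidden, ffn, layers, heads, kv = 1536, 8960, 28, 12, 4   # ~1.4B
    else:
        hidden, ffn, layers, heads, kv = 1024, 4096, 28, 8, 2    # ~0.5B
    return LlamaConfig(
        vocab_size=32000,
        hidden_size=hidden,
        intermediate_size=ffn,
        num_hidden_layers=layers,
        num_attention_heads=heads,
        num_key_value_heads=kv,
    )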

For your training infrastructure, this means: progressive distillation costs 2-3× more compute than direct distillation, but the quality gains compound. 7B → 1.5B directly loses ~15% quality; 7B → 3.5B → 1.5B loses only ~8%. The extra stage pays for itself in production quality.

Multi-Teacher Distillation

Ensemble multiple teachers for a better student:

class MultiTeacherDistillationLoss(torch.nn.Module):
    """Distill from multiple teachers."""
    
    def __init__(self, num_teachers=2, temperature=2.0, alpha=0.7):
        super().__init__()
        self.num_teachers = num_teachers
        self.temperature = temperature
        self.alpha = alpha
    
    def forward(self, student_logits, teacher_logits_list, labels):
        """
        Args:
            teacher_logits_list: List of [batch, seq, vocab] from each teacher
        """
        # Average teacher distributions
        teacher_soft_avg = torch.zeros_like(teacher_logits_list[0])
        for teacher_logits in teacher_logits_list:
            teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
            teacher_soft_avg += teacher_soft / self.num_teachers
        
        # Student distribution
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        
        # KL divergence
        loss_kd = F.kl_div(student_soft, teacher_soft_avg, reduction='batchmean')
        loss_kd *= (self.temperature ** 2)
        
        # Task loss
        loss_ce = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        
        loss = self.alpha * loss_kd + (1 - self.alpha) * loss_ce
        return loss, {"loss_kd": loss_kd.item(), "loss_ce": loss_ce.item()}
 
# Usage: Combine Llama-7B and Mistral-7B
teachers = [
    AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf"),
    AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1"),
]
 
criterion = MultiTeacherDistillationLoss(num_teachers=2)
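
A single accumulation step with two teachers might look like the sketch below. It assumes the input_ids, attention_mask, and labels tensors come from the Step 4 dataloader, and that all teachers share the student's tokenizer so logit indices line up; verify this before mixing model families:

# One training step with two teachers (sketch)
with torch.no_grad():
    teacher_logits_list = [
        t(input_ids=input_ids, attention_mask=attention_mask).logits for t in teachers
    ]
student_logits = student(input_ids=input_ids, attention_mask=attention_mask).logits
loss, kd_metrics = criterion(student_logits, teacher_logits_list, labels)
loss.backward()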

Task-Specific Distillation

Optimize for specific downstream tasks:

class TaskSpecificDistillation:
    """Distillation with task-specific data and loss."""
    
    def __init__(self, task="summarization"):
        self.task = task
        self.distill_criterion = DistillationLoss(temperature=2.0, alpha=0.7)
        
    def get_dataset(self):
        """Load task-specific dataset."""
        if self.task == "summarization":
            return load_dataset("cnn_dailymail", "3.0.0")
        elif self.task == "qa":
            return load_dataset("squad_v2")
        elif self.task == "code":
            return load_dataset("codeparrot/github-code")
        # ... more tasks
    
    def compute_task_loss(self, student_output, teacher_output, labels):
        """Task-specific loss beyond standard distillation."""
        # Standard distillation
        base_loss, metrics = self.distill_criterion(student_output, teacher_output, labels)
        
        # Add task-specific objectives
        if self.task == "summarization":
            # Encourage conciseness (illustrative placeholder penalty)
            length_penalty = torch.mean(student_output.sum(dim=-1))
            base_loss += 0.1 * length_penalty
        
        return base_loss, metrics
 
# Fine-tune for code generation
code_distiller = TaskSpecificDistillation(task="code")
code_dataset = code_distiller.get_dataset()
# Train with task-specific loss

Step 6: Evaluate on held-out tasks, not just loss

Comprehensive Evaluation

def evaluate_student(student, teacher, tokenizer, benchmarks=["wikitext", "lambada"]):
    """
    Compare student to teacher across multiple benchmarks
    (via EleutherAI's lm-evaluation-harness; assumes the v0.4+ API).
    """
    from lm_eval import evaluator
    from lm_eval.models.huggingface import HFLM
    
    # Wrap the raw HF models so the harness can drive them
    teacher_lm = HFLM(pretrained=teacher, tokenizer=tokenizer)
    student_lm = HFLM(pretrained=student, tokenizer=tokenizer)
    
    results = {}
    
    for benchmark in benchmarks:
        print(f"\nEvaluating on {benchmark}...")
        
        # Evaluate teacher
        teacher_results = evaluator.simple_evaluate(
            model=teacher_lm,
            tasks=[benchmark],
            num_fewshot=0,
        )
        
        # Evaluate student
        student_results = evaluator.simple_evaluate(
            model=student_lm,
            tasks=[benchmark],
            num_fewshot=0,
        )
        
        # Compare (metric key may vary by task and harness version, e.g. "acc" vs "acc,none")
        teacher_score = teacher_results["results"][benchmark]["acc"]
        student_score = student_results["results"][benchmark]["acc"]
        retention = (student_score / teacher_score) * 100
        
        results[benchmark] = {
            "teacher": teacher_score,
            "student": student_score,
            "retention": retention
        }
        
        print(f"  Teacher: {teacher_score:.2%}")
        print(f"  Student: {student_score:.2%}")
        print(f"  Retention: {retention:.1f}%")
    
    return results
 
# Run evaluation
eval_results = evaluate_student(student, teacher, tokenizer)

Hyperparameter Tuning

def grid_search_distillation(temperatures=[1.5, 2.0, 3.0], alphas=[0.5, 0.7, 0.9]):
    """
    Find optimal temperature and alpha for your dataset.
    """
    best_score = 0
    best_params = {}
    
    for temp in temperatures:
        for alpha in alphas:
            print(f"\nTrying T={temp}, α={alpha}")
            
            # Create fresh student
            student = LlamaForCausalLM(student_config)
            
            # Train with these hyperparameters
            criterion = DistillationLoss(temperature=temp, alpha=alpha)
            train_distillation(
                student=student,
                teacher=teacher,
                train_dataloader=train_dataloader,
                criterion=criterion,
                num_steps=10000,  # Short run for search
            )
            
            # Evaluate
            score = evaluate_student(student, teacher, tokenizer)["wikitext"]["retention"]
            
            if score > best_score:
                best_score = score
                best_params = {"temperature": temp, "alpha": alpha}
            
            print(f"  Retention: {score:.1f}%")
    
    print(f"\nBest params: {best_params} (retention: {best_score:.1f}%)")
    return best_params

Step 7: Export, quantize, and serve the student

Export and Quantize

# Save final model
student.save_pretrained("./student-1.5B-final")
tokenizer.save_pretrained("./student-1.5B-final")
 
# Quantize to INT8 for deployment
from transformers import AutoModelForCausalLM
import torch
 
# Load and quantize
student_int8 = AutoModelForCausalLM.from_pretrained(
    "./student-1.5B-final",
    device_map="auto",
    load_in_8bit=True,  # Automatic INT8 quantization
)
 
print(f"FP16 size: {student.get_memory_footprint() / 1e9:.2f}GB")
print(f"INT8 size: {student_int8.get_memory_footprint() / 1e9:.2f}GB")
# FP16 size: 2.86GB
# INT8 size: ~1.5GB (roughly 2× reduction)
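
Note that recent transformers releases prefer passing a BitsAndBytesConfig rather than the bare load_in_8bit flag (both paths require the bitsandbytes package). The equivalent call:

from transformers import BitsAndBytesConfig
 
student_int8 = AutoModelForCausalLM.from_pretrained(
    "./student-1.5B-final",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)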

Benchmark Inference Speed

import time
 
def benchmark_inference(model, tokenizer, num_iterations=100):
    """Measure tokens/second."""
    prompt = "The quick brown fox jumps over the lazy dog. " * 10
    
    # Warm up
    for _ in range(5):
        _ = model.generate(**tokenizer(prompt, return_tensors="pt").to(model.device), max_new_tokens=50)
    
    # Benchmark
    start = time.time()
    total_tokens = 0
    for _ in range(num_iterations):
        output = model.generate(
            **tokenizer(prompt, return_tensors="pt").to(model.device),
            max_new_tokens=50,
            do_sample=False
        )
        total_tokens += 50
    elapsed = time.time() - start
    
    tokens_per_second = total_tokens / elapsed
    return tokens_per_second
 
print("Teacher (7B FP16):", benchmark_inference(teacher, tokenizer), "tok/s")
print("Student (1.5B FP16):", benchmark_inference(student, tokenizer), "tok/s")
print("Student (1.5B INT8):", benchmark_inference(student_int8, tokenizer), "tok/s")
# Teacher (7B FP16): 18 tok/s
# Student (1.5B FP16): 58 tok/s (3.2× faster)
# Student (1.5B INT8): 82 tok/s (4.6× faster)

Tested recipes for common model sizes

Recipe 1: Maximum Quality

# Configuration for 90%+ retention
config = {
    "student_size": 0.25,  # 25% of teacher
    "temperature": 2.0,
    "alpha": 0.8,  # Heavy distillation
    "num_steps": 200000,
    "batch_size": 2,
    "gradient_accumulation": 16,
    "learning_rate": 1e-4,  # Lower LR for stability
}

Recipe 2: Speed-Focused

# Fast training, acceptable quality
config = {
    "student_size": 0.15,  # Smaller student
    "temperature": 3.0,  # More aggressive
    "alpha": 0.6,
    "num_steps": 50000,
    "batch_size": 8,
    "gradient_accumulation": 4,
    "learning_rate": 3e-4,
}

Recipe 3: Resource-Constrained

# Train on single GPU
config = {
    "student_size": 0.2,
    "temperature": 2.0,
    "alpha": 0.7,
    "num_steps": 100000,
    "batch_size": 1,  # Small batch
    "gradient_accumulation": 32,  # High accumulation
    "learning_rate": 2e-4,
    "mixed_precision": True,  # Essential
    "gradient_checkpointing": True,  # Save memory
}
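
These dicts are documentation of hyperparameters; wiring one into the pipeline is straightforward. A sketch applying the resource-constrained recipe with the DistillationLoss and train_distillation helpers defined earlier (batch size lives on the DataLoader, and mixed precision is already built into the training loop):

# Apply Recipe 3 (config dict above) to the existing pipeline
student.gradient_checkpointing_enable()  # activation checkpointing to save memory
criterion = DistillationLoss(temperature=config["temperature"], alpha=config["alpha"])
train_distillation(
    student=student,
    teacher=teacher,
    train_dataloader=train_dataloader,  # built with batch_size=config["batch_size"]
    criterion=criterion,
    num_steps=config["num_steps"],
    learning_rate=config["learning_rate"],
    gradient_accumulation_steps=config["gradient_accumulation"],
)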

These failure modes and how to fix them

Student Not Learning

Symptom: Loss plateaus high, student much worse than teacher

Solutions:

  1. Check temperature: Try T=3 if T=2 doesn't work
  2. Increase alpha: Use α=0.9 (more distillation, less task loss)
  3. Verify teacher quality: Ensure teacher outputs make sense
  4. Check learning rate: Try 1e-4 to 5e-4 range

Overfitting to Teacher

Symptom: Student matches teacher on train set, worse on validation

Solutions:

  1. Reduce alpha: Use α=0.5 (more task loss)
  2. Add regularization: Increase weight decay to 0.2
  3. More diverse data: Expand training corpus
  4. Dropout: Add dropout=0.1 to student

Memory Issues

Symptom: OOM errors during training

Solutions:

  1. Reduce batch size, increase gradient accumulation
  2. Enable gradient checkpointing (see the sketch after this list)
  3. Use smaller student (reduce hidden size)
  4. Distill in stages (progressive distillation)
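
A minimal sketch of the first two fixes, assuming the student model and dataset from earlier:

# Shrink the per-step footprint: tiny batches, more accumulation, activation checkpointing
train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4)
student.gradient_checkpointing_enable()
student.config.use_cache = False  # the KV cache is incompatible with checkpointing during training
# ...then raise gradient_accumulation_steps (e.g. 32) when calling train_distillation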

Start with temperature=2 and α=0.7

Best Practices Checklist

Before training:

  • Validate teacher quality on benchmarks
  • Design student architecture thoughtfully (depth > width)
  • Test distillation loss on small batch
  • Calculate memory requirements

During training:

  • Monitor both KD and CE losses (should both decrease)
  • Log to TensorBoard/Weights & Biases
  • Save checkpoints frequently
  • Evaluate on validation set every 5K steps

After training:

  • Compare student to teacher on multiple benchmarks
  • Test inference speed on target hardware
  • Quantize for deployment
  • A/B test in production

Expected Results

Llama-7B → 1.5B distillation:

  • Training time: 3-5 days on 8× A100
  • Final perplexity: 10-12 (vs 8.79 teacher)
  • MMLU retention: 85-90%
  • Inference speedup: 3-4×
  • Size reduction: 4.5×

Next Steps

Next up in this track: 3.2 Quantization-Aware Training, then 3.3 Fine-Tuning and Domain Adaptation (see the series list at the top of this post).

Your teacher model learned from billions of examples. In a few days, your student will learn from the teacher. That's the power of distillation: compressing years of training into a fraction of the cost.


Sources and References

Institutional and Industry Research

  • Epoch AI — Tracks trends in model distillation and compute efficiency (as of January 2025).
  • Stanford HAI AI Index — Annual report on AI model efficiency trends, deployment patterns, and distillation adoption.
  • MLCommons MLPerf Inference — Industry-standard benchmarks for distilled model performance.
  • Google AI Research Blog — Original research on neural network compression and distillation techniques.

Foundational Papers

  • Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. The original knowledge distillation paper introducing temperature-scaled softmax and dark knowledge.
  • Bucilă, C., Caruana, R., & Niculescu-Mizil, A. (2006). Model Compression. KDD 2006. Early model compression work that inspired distillation.



Before you distill your model:

  1. Validate teacher quality first. Distillation amplifies teacher weaknesses—benchmark the teacher before investing in distillation.
  2. Start with temperature=2 and α=0.7. These defaults work for most language models—tune only after baseline results.
  3. Design student depth over width. A 24-layer student at hidden_size=1536 beats a 16-layer student at hidden_size=2048 of similar parameter count.
  4. Monitor both KD and CE losses. If KD loss plateaus but CE rises, you're overfitting to the teacher—reduce α.
  5. Plan for 10% quality loss at 5× compression. Distilling 7B → 1.5B typically retains 85-90% of benchmark scores.

Your teacher model holds knowledge you can't afford to run. Distillation lets you keep it anyway.