December 30, 2025
Python PyTorch Deep Learning Neural Networks

Training Loops: Loss Functions, Optimizers, and Learning Rate Scheduling

You've built your neural network. Your data's loaded. Now comes the part that actually teaches your model: the training loop. This is where theory meets practice, where loss decreases (hopefully), and where understanding the mechanics separates "I got something to run" from "I built something that actually learns."

Here's the thing most tutorials gloss over: the training loop isn't just boilerplate you copy and paste. Every line in that loop makes a deliberate design decision. The loss function you choose tells the model what "wrong" means. The optimizer determines how aggressively and intelligently it corrects course. The learning rate schedule controls whether you explore broadly early on or zoom in carefully once you've found a promising region. Getting any one of these wrong doesn't just slow training; it can cause training to fail entirely, produce a model that memorizes instead of generalizes, or leave you chasing NaN losses with no idea why.

This article unpacks all of it. We'll walk through each component of the canonical PyTorch training loop and explain not just what it does but why it's designed that way. We'll look at how to pick loss functions for different task types, compare the major optimizer families, discuss how learning rate schedules influence convergence, and give you a complete, production-ready template you can adapt to your own projects. By the end, you'll understand the reasoning behind every decision, which means you'll be able to debug problems when they arise and tune your setup when off-the-shelf defaults aren't cutting it.

Let's dig into the canonical PyTorch training loop, dissect every decision you need to make, and show you why each piece matters.

Table of Contents
  1. The Canonical Training Loop: Four Steps That Matter
  2. Loss Functions: Choosing the Right One
  3. Loss Function Intuition
  4. MSELoss: Regression's Workhorse
  5. CrossEntropyLoss: Classification King
  6. BCEWithLogitsLoss: Multi-Label Classification
  7. Optimizers: How Your Model Actually Learns
  8. Optimizer Comparison
  9. SGD with Momentum: The Classic
  10. Adam: The Adaptive Default
  11. AdamW: Adam, But Fixed
  12. DataLoader and Batching: Why Size Matters
  13. Validation Loop: Measuring Real Performance
  14. Learning Rate Scheduling: Adapting as You Train
  15. Learning Rate Strategies
  16. StepLR: Simple Decay
  17. CosineAnnealingLR: Smooth Decay
  18. OneCycleLR: Cyclical Scheduling
  19. Gradient Clipping: Preventing Explosions
  20. PyTorch Lightning: Condensing the Boilerplate
  21. Putting It All Together: A Complete Example
  22. Debugging Your Training Loop
  23. Key Takeaways

The Canonical Training Loop: Four Steps That Matter

Nearly every neural network training run follows the same four-step rhythm:

  1. Forward pass: Push data through the model, get predictions
  2. Loss calculation: Measure how wrong we are
  3. Backward pass: Compute gradients (how to improve)
  4. Optimizer step: Actually improve the weights

Before looking at code, it helps to think about what's actually happening mathematically. Your model is a function that maps inputs to outputs, parameterized by millions of weights. Training is an optimization problem: find the weights that minimize the loss function across your training data. Gradient descent is the algorithm we use to solve that optimization: it computes the slope of the loss with respect to each weight, then moves each weight a small step in the downhill direction. The four-step loop is just that process, batched and repeated.

Here's what it looks like in code:

python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
 
# Setup
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
 
# Create dummy data for illustration
X_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
train_loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=32,
    shuffle=True
)
 
# The training loop
for epoch in range(10):
    epoch_loss = 0.0
    for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
        # Move to device (GPU/CPU)
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
 
        # 1. FORWARD PASS: predictions = model(input)
        # WHY: This generates model outputs given current weights
        logits = model(X_batch)
 
        # 2. LOSS: measure prediction error
        # WHY: Loss quantifies "how wrong are we", gradient descent minimizes this
        loss = criterion(logits, y_batch)
 
        # 3. BACKWARD PASS: compute gradients
        # WHY: .backward() traces the chain rule back through every weight
        # .zero_grad() first because PyTorch accumulates gradients by default
        optimizer.zero_grad()
        loss.backward()
 
        # 4. OPTIMIZER STEP: update weights using gradients
        # WHY: optimizer.step() moves weights in the direction that reduces loss
        optimizer.step()
 
        epoch_loss += loss.item()
 
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1}/10, Loss: {avg_loss:.4f}")

Read through that and you'll notice the loop body is surprisingly compact given everything happening under the hood. Each call (model(X_batch), criterion(...), optimizer.zero_grad(), loss.backward(), optimizer.step()) triggers a cascade of tensor operations that would take hundreds of lines to write manually. PyTorch's autograd system handles all of that for you, but only if you call things in the right order.

Why these four steps in this order?

  • zero_grad() before backward(), because PyTorch accumulates gradients. If you don't zero them, old gradients pile up and corrupt the update direction.
  • backward() after loss, because the loss function defines what "wrong" means; gradients point back through that error signal.
  • step() after backward, because we need the gradients before we can use them to update weights.

This order is sacred. Breaking it breaks training.

Understanding the gradient flow: When you call loss.backward(), PyTorch traces the computational graph backward from the loss to every parameter that contributed to it. It computes the derivative of the loss with respect to each parameter using the chain rule. These derivatives tell us which direction to move each weight to reduce the loss. The optimizer uses these directions, scaled by the learning rate.

Common mistakes here: forgetting zero_grad() causes gradients to accumulate across batches (usually a bug, but sometimes intentional for gradient accumulation). Calling step() before backward() updates the weights from stale gradients, or does nothing if no gradients exist yet. Leaving a batch on the CPU while the model lives on the GPU raises a device-mismatch RuntimeError, so move each batch to the model's device before the forward pass.
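The accumulation behavior is easy to see on a toy example. The sketch below uses a single hypothetical scalar weight `w` (not part of the model above) and calls backward() twice without zeroing in between:

```python
import torch

# A single toy weight; the names here are illustrative only.
w = torch.tensor(2.0, requires_grad=True)

# First backward pass: d(3*w)/dw = 3
(3 * w).backward()
grad_after_first = w.grad.item()

# Second backward pass WITHOUT zeroing: the new gradient is added to the old one
(3 * w).backward()
grad_after_second = w.grad.item()

# Zero the stored gradient, then run backward again: back to the true value
w.grad.zero_()
(3 * w).backward()
grad_after_zeroing = w.grad.item()

print(grad_after_first, grad_after_second, grad_after_zeroing)  # 3.0 6.0 3.0
```

Deliberate gradient accumulation relies on exactly this behavior: several backward() calls are summed before a single optimizer.step().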

Loss Functions: Choosing the Right One

The loss function is your model's "report card." It translates the task into a number that optimization can chase. Pick the wrong one, and your model will optimize for the wrong thing.

Loss Function Intuition

Before diving into specific functions, it's worth building an intuition for what we're actually asking the loss to do. A loss function must be differentiable (at least almost everywhere); that's non-negotiable, because we need gradients to flow backward through it. But beyond that constraint, the choice of loss function is essentially a statement about what kinds of errors you care about and how much.

Think of it this way: if you're predicting house prices and you use Mean Squared Error, you're telling the model that being off by $100,000 is four times as bad as being off by $50,000 (because the error is squared). That quadratic penalty makes the model very sensitive to large errors: it will sacrifice accuracy on typical predictions to avoid catastrophic failures on outliers. If instead you used Mean Absolute Error, you're saying all errors scale linearly; a $100,000 mistake is exactly twice as bad as a $50,000 one. Neither is objectively better; the right choice depends on your task.

For classification, the logic is different. Cross-entropy loss is grounded in information theory: it measures how many bits of information it would take to encode the true label under the model's predicted probability distribution. A model that confidently predicts the wrong class wastes a lot of bits, so it gets heavily penalized. A model that assigns even probability to all classes loses some information but not catastrophically. This matches intuition: confident wrong answers are worse than uncertain ones. The key insight is that the loss function shapes the entire optimization landscape your model will navigate, so it pays to understand exactly what you're asking it to minimize.
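The house-price comparison can be checked numerically. This is a small sketch with made-up prices; nn.L1Loss is PyTorch's mean absolute error:

```python
import torch
import torch.nn as nn

# Made-up numbers: one true price and two predictions of differing quality
target = torch.tensor([300_000.0])
off_by_50k = torch.tensor([350_000.0])
off_by_100k = torch.tensor([400_000.0])

mse, mae = nn.MSELoss(), nn.L1Loss()

# How much worse is the 100k miss than the 50k miss under each loss?
mse_ratio = (mse(off_by_100k, target) / mse(off_by_50k, target)).item()
mae_ratio = (mae(off_by_100k, target) / mae(off_by_50k, target)).item()

print(mse_ratio, mae_ratio)  # 4.0 2.0
```

Doubling the error quadruples the squared loss but only doubles the absolute loss, which is the whole difference in how the two treat outliers.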

MSELoss: Regression's Workhorse

When to use: Predicting continuous values (price, temperature, or any other regression target).

MSELoss is the default choice for regression because it's smooth, well-understood, and works well when your target values don't have extreme outliers. Before reaching for anything else, start here and see if it gets the job done.

python
criterion = nn.MSELoss()
# Use in training loop
loss = criterion(predictions, targets)

Why MSE? It penalizes large errors quadratically (squared error). A prediction that's off by 2 costs 4x more than one off by 1. This makes the model very careful about outliers. The formula is simple: for each sample, subtract prediction from target, square it, and average across the batch.

When MSE is good: Regression tasks where large errors should be penalized disproportionately and the targets are reasonably free of outliers. Predicting stock prices, temperature, house values.

When MSE might be bad: If you have outliers, the squared penalty can dominate the loss and produce very large gradients. Consider nn.L1Loss (mean absolute error) for robustness, or nn.SmoothL1Loss / nn.HuberLoss for a hybrid that is quadratic for small errors and linear for large ones.

One practical pattern: always normalize your regression targets before training. If your targets span a range of 100,000, your raw MSE loss will be enormous compared to what you'd see after normalization. This doesn't break training mathematically, but it makes choosing a good learning rate much harder.
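One way to do that normalization is plain standardization, sketched below with made-up prices. The helper name `denormalize` is hypothetical, and the statistics should come from the training targets only:

```python
import torch

# Hypothetical house prices in dollars (illustrative values only)
y_raw = torch.tensor([250_000.0, 320_000.0, 180_000.0, 410_000.0])

# Compute the statistics on the TRAINING targets only
y_mean, y_std = y_raw.mean(), y_raw.std()
y_norm = (y_raw - y_mean) / y_std  # train the model against these values

def denormalize(pred_norm):
    """Map a normalized prediction back to dollars."""
    return pred_norm * y_std + y_mean

restored = denormalize(y_norm)
print(y_norm)    # roughly zero mean, unit standard deviation
print(restored)  # recovers the original prices
```

Keep `y_mean` and `y_std` alongside the model checkpoint, since predictions are meaningless without them.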

Practical example:

python
# Predicting house prices
model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 1)
)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
# In your loop:
predictions = model(X).squeeze(1)  # Shape: (batch_size, 1) -> (batch_size,)
loss = criterion(predictions, y)  # y has shape (batch_size,), real values like [250000, 320000, ...]

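Notice that the model's final layer outputs a single value (no activation), and the targets are real numbers. MSELoss doesn't care about the range; it compares whatever the model outputs against whatever the target is, element by element. The shapes must match, though: if the model outputs (batch_size, 1) while the targets have shape (batch_size,), PyTorch broadcasts the pair to (batch_size, batch_size), emits a warning, and silently computes the wrong loss. Calling predictions.squeeze(1) (or reshaping the targets) avoids this common gotcha.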
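The shape mismatch is worth seeing once, because nothing crashes. In this sketch the predictions are deliberately perfect, so the correct loss is zero; the tensors are toy values, not output from the model above:

```python
import warnings
import torch
import torch.nn as nn

criterion = nn.MSELoss()
predictions = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1), like a Linear(..., 1) output
targets = torch.tensor([1.0, 2.0, 3.0])            # shape (3,)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # PyTorch warns about the size mismatch here
    # (3, 1) vs (3,) broadcasts to (3, 3): every prediction is compared to every target
    broadcast_loss = criterion(predictions, targets).item()

# Matching shapes: each prediction is compared to its own target
correct_loss = criterion(predictions.squeeze(1), targets).item()

print(broadcast_loss, correct_loss)  # about 1.333 vs 0.0
```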

CrossEntropyLoss: Classification King

When to use: Multi-class classification (is this a cat, dog, or bird?).

CrossEntropyLoss handles the overwhelmingly common case where each sample belongs to exactly one of several mutually exclusive categories. It's what you should reach for any time you're predicting a single class label.

python
criterion = nn.CrossEntropyLoss()
# Use in training loop
loss = criterion(logits, class_indices)

Why CrossEntropy? It combines softmax (convert logits to probabilities) + negative log-likelihood. It punishes confident wrong predictions harder than uncertain ones. The per-sample loss is -log(p), where p is the probability the model assigns to the true class. If the model puts 0.9 on a wrong class, p is at most 0.1 and the loss is at least 2.30. If it puts 0.9 on the right class, the loss is only about 0.105.

Mathematical intuition: Softmax converts raw scores (logits) into a probability distribution. Log-likelihood measures how likely the true class is under the model's distribution. CrossEntropy is the negative log-likelihood, so minimizing it makes the true class more probable.

Critical detail: Pass logits (raw model outputs), not probabilities. CrossEntropyLoss does the softmax for you. This is more numerically stable. If you softmax yourself and then log, floating-point errors accumulate. PyTorch combines these operations safely.

python
# CORRECT: Pass logits
logits = model(X)  # Shape: (batch_size, num_classes)
loss = nn.CrossEntropyLoss()(logits, targets)  # targets are class indices [0, 1, 2, ...]
 
# WRONG: Don't softmax manually
probs = torch.softmax(logits, dim=1)
loss = nn.CrossEntropyLoss()(probs, targets)  # Softmax gets applied twice: wrong loss, weak gradients

After the correct version runs, the gradients from the combined softmax + log + negative operation flow cleanly back through the model. If you insert a manual softmax first, the code still runs and gradients still flow, but the loss applies softmax a second time to values already squashed into [0, 1]. The resulting distribution is nearly flat, so the loss can never approach zero and the gradients shrink; in practice training is slow or converges to worse solutions.

Practical tip: If your model is outputting probabilities (e.g., using softmax in the final layer), remove that softmax and pass logits directly. CrossEntropyLoss expects unbounded values.
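Both points can be verified on random logits. The sketch below checks that cross-entropy equals log-softmax followed by negative log-likelihood, and shows the floor that a double softmax puts under the loss; the tensors are random stand-ins for real model outputs:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)             # 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])   # class indices

# cross_entropy is log_softmax followed by negative log-likelihood
ce = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# The mistake: softmax by hand, then a second softmax inside cross_entropy
double_softmax = F.cross_entropy(torch.softmax(logits, dim=1), targets)

# Probabilities lie in [0, 1], so after the second softmax the true class gets
# at most e / (e + C - 1); with C = 3 classes the loss cannot drop below this:
loss_floor = math.log(1 + 2 / math.e)

print(ce.item(), manual.item())           # identical up to rounding
print(double_softmax.item(), loss_floor)  # double-softmax loss never goes below the floor
```

Even a perfect classifier would be stuck at roughly 0.55 here, which is why the manual softmax hurts more than numerical precision alone would suggest.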

BCEWithLogitsLoss: Multi-Label Classification

When to use: Each sample can have multiple labels (image contains both a dog and a cat).

This is the loss that trips people up most often. The key distinction from CrossEntropyLoss is that here the classes are not mutually exclusive: a news article can be tagged as both "politics" and "economy," and a movie can belong to both "drama" and "thriller." Each output neuron makes an independent binary decision.

python
criterion = nn.BCEWithLogitsLoss()
# Use in training loop
loss = criterion(logits, binary_targets)

Why? Like CrossEntropyLoss, it expects logits and handles sigmoid internally for numerical stability. Each output neuron is independent (not competing like in softmax).

python
# Example: predicting movie genres for a film
# Output shape: (batch_size, num_genres)
# Each genre gets its own logit, evaluated independently
logits = model(X)  # Shape: (batch_size, 10) for 10 genres
targets = torch.randint(0, 2, logits.shape)  # Stand-in multi-hot labels: 1 = has genre, 0 = doesn't
loss = criterion(logits, targets.float())  # Convert to float

A common mistake here is forgetting to call .float() on the targets. PyTorch's BCEWithLogitsLoss expects float tensors, not integer labels. The explicit type conversion prevents cryptic runtime errors that can be hard to trace back to their source.
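A short sketch of the equivalence, assuming four hypothetical genres and random logits in place of a real model. It also shows one common way to turn logits into per-label predictions, with a 0.5 threshold chosen purely for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 4)              # 2 samples, 4 hypothetical genres
targets = torch.tensor([[1, 0, 1, 0],
                        [0, 0, 1, 1]])  # integer multi-hot labels

# Fused version: sigmoid + binary cross-entropy in one numerically stable op
fused = nn.BCEWithLogitsLoss()(logits, targets.float())

# Same quantity computed in two steps (fine here, less stable for extreme logits)
two_step = nn.BCELoss()(torch.sigmoid(logits), targets.float())

# At inference time each label is thresholded independently
predicted = (torch.sigmoid(logits) > 0.5).int()

print(fused.item(), two_step.item())
print(predicted)
```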

Quick reference table:

Task | Loss Function | Model Output | Target Format
Regression | MSELoss | Float (any range) | Real numbers
Multi-class | CrossEntropyLoss | Logits (unbounded) | Class indices
Multi-label | BCEWithLogitsLoss | Logits (unbounded) | Binary vectors

Optimizers: How Your Model Actually Learns

The optimizer takes gradients and decides how to move the weights. Different optimizers move differently: some are faster, some are steadier, some handle certain landscapes better. Think of an optimizer as a strategy for rolling a ball down a complex landscape: basic gradient descent rolls straight downhill (but gets stuck), momentum smooths the path by considering previous direction, and adaptive methods like Adam adjust step size per dimension.

Optimizer Comparison

Choosing an optimizer shouldn't be a mystery. Here's the practical breakdown of the three you'll actually use.

SGD with momentum is the oldest and in some ways most principled approach. You compute the gradient, blend it with your previous update direction, and take a step. The advantage is transparency: you know exactly what's happening to your weights at each step. The disadvantage is that it's sensitive to the learning rate. Too high, and training oscillates; too low, and it crawls. For large-scale vision tasks (think ResNet on ImageNet), SGD with momentum and a carefully tuned schedule often achieves the best final accuracy, but it takes more effort to tune.

Adam changed the game by adapting the effective learning rate per parameter. Parameters that get large, noisy gradients have their updates dampened automatically. Parameters with small but consistent gradients get amplified updates. This makes Adam remarkably robust across a huge variety of tasks and architectures without extensive hyperparameter tuning. The tradeoff is that it can sometimes generalize slightly worse than SGD, because the adaptive scaling can allow the model to "overfit" to the optimization landscape in ways that don't transfer to new data.

AdamW fixes a subtle bug in Adam's implementation of weight decay. In standard Adam, L2 regularization gets entangled with the adaptive scaling, effectively making weight decay weaker for parameters that receive large gradients. AdamW decouples weight decay from the gradient update, applying it directly to the weights before the gradient step. The result is that weight decay actually does what you intend, which usually means better regularization and better generalization. When in doubt, use AdamW over Adam.

SGD with Momentum: The Classic

What it does: Update weights in the direction of the gradient, but with "momentum": previous updates influence current ones.

python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

Why momentum? Gradients are noisy. Momentum smooths the path, accelerating in consistent directions and damping oscillations. Think of rolling a ball downhill: it gains speed but doesn't bounce around wildly. Mathematically, momentum keeps an exponentially decaying accumulation of past gradients (a velocity) and uses that for updates. The momentum parameter (0.9 is standard) controls how much history to remember.

Performance note: Momentum SGD often converges faster than vanilla SGD, especially in narrow valleys where gradients oscillate. It's also more resistant to noisy gradients.
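The update rule is short enough to write by hand. This sketch runs PyTorch's SGD and a hand-rolled copy side by side on a toy quadratic loss; the variable names `q` and `velocity` are illustrative, and the rule shown assumes the default settings (no dampening, no Nesterov):

```python
import torch

lr, momentum = 0.1, 0.9

# Reference: PyTorch's SGD on the toy loss sum(p**2)
p = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([p], lr=lr, momentum=momentum)

# Hand-rolled copy of the same update rule
q = torch.tensor([1.0, -2.0])
velocity = torch.zeros_like(q)

for _ in range(5):
    opt.zero_grad()
    (p ** 2).sum().backward()
    opt.step()

    grad_q = 2 * q                           # gradient of sum(q**2)
    velocity = momentum * velocity + grad_q  # accumulate past gradients
    q = q - lr * velocity                    # step along the velocity

print(p.detach(), q)  # the two trajectories match
```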

When to use: When you have intuition about learning rate and want maximal control. Used in many production systems. If you're fine-tuning a pre-trained vision model, SGD + momentum is often competitive with Adam and sometimes generalizes better.

python
# Typical usage
for epoch in range(100):
    for X, y in train_loader:
        logits = model(X)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The code is identical to what you'd write for any other optimizer; that's the beauty of PyTorch's optimizer API. The difference is entirely in how optimizer.step() computes the weight update internally. You can swap between SGD, Adam, and AdamW by changing a single line, which makes it easy to compare them empirically on your specific problem.

Adam: The Adaptive Default

What it does: Adapts the learning rate per parameter, considering both first moments (a running mean of the gradient) and second moments (a running mean of the squared gradient, i.e. its uncentered variance).

python
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Why Adam? It "just works" across a wide range of problems. It doesn't require careful learning rate tuning. Less fiddling, faster convergence in practice. Adam keeps exponential moving averages of both gradients and squared gradients, then uses these to compute an adaptive per-parameter learning rate.

How it helps: Parameters with small, consistent gradients get larger updates. Parameters with large, volatile gradients get smaller updates. This handles both smooth and steep parts of the loss landscape gracefully.
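The per-parameter scaling is easiest to see on the very first step. In the sketch below three parameters receive hand-set gradients that differ by four orders of magnitude, yet Adam moves each of them by almost exactly the learning rate:

```python
import torch

lr = 0.001
p = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
before = p.detach().clone()

opt = torch.optim.Adam([p], lr=lr)

# Hand-set gradients of very different magnitude (no real loss involved)
grads = torch.tensor([100.0, 1.0, -0.01])
p.grad = grads.clone()
opt.step()

step_taken = p.detach() - before
print(step_taken)  # each entry is close to -lr * sign(grad)
```

On this first step the bias-corrected update reduces to roughly lr * g / |g|, so only the sign of the gradient survives. On later steps the moving averages make the picture less extreme, but the normalization idea is the same.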

When to use: When you want good defaults. Start here unless you have a reason not to. Especially great for transformer models, where Adam and its variants dominate.

Caveat: Adam can generalize poorly on some tasks compared to SGD. If your validation loss plateaus at a suboptimal point, consider switching to SGD for fine-tuning. This is rare, but it happens.

python
# Adam for initial training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
# ... train for a while ...
 
# Switch to SGD for fine-tuning (sometimes helps)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

AdamW: Adam, But Fixed

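What it does: Adam with decoupled weight decay. Standard Adam implements weight_decay as an L2 penalty added to the gradient, so the decay term is rescaled by the adaptive learning rate along with everything else; AdamW instead shrinks the weights directly.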

python
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

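Why AdamW over Adam? Weight decay behaves as intended: every weight shrinks by the same relative amount each step, regardless of its gradient history, which keeps weights from growing unbounded without distorting the gradient update.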

Modern recommendation: Use AdamW, not Adam.

python
# Current best practice
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=0.01  # Prevents overfitting
)

The weight_decay=0.01 parameter is a good default starting point. If your model is severely overfitting (large gap between training and validation loss), try increasing it to 0.1. If training is sluggish or underfitting, reduce it to 0.001 or 0.0001. Weight decay essentially shrinks weights toward zero at each step, which acts as a form of regularization that prevents any single parameter from dominating the model's predictions.
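To isolate what the decay term does, the sketch below takes one optimizer step with the loss gradient set to exactly zero, so any movement comes from weight decay alone. The helper `one_step` and the starting values are made up for illustration:

```python
import torch

lr, wd = 0.1, 0.01
start = torch.tensor([1.0, -2.0, 10.0])

def one_step(optimizer_cls):
    """Take a single step with a zero loss gradient, so only weight decay acts."""
    p = start.clone().requires_grad_(True)
    opt = optimizer_cls([p], lr=lr, weight_decay=wd)
    p.grad = torch.zeros_like(p)  # pretend the loss gradient is exactly zero
    opt.step()
    return p.detach()

after_adamw = one_step(torch.optim.AdamW)  # decoupled: p * (1 - lr * wd)
after_adam = one_step(torch.optim.Adam)    # L2 term passes through Adam's normalization

print(after_adamw)  # every weight shrinks by the same 0.1%
print(after_adam)   # every weight moves by about lr, regardless of its size
```

With AdamW the shrinkage is proportional to each weight, as weight decay is meant to be. With Adam the decay term is normalized like any other gradient, so on this first step the setting of `wd` barely affects the size of the move.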

DataLoader and Batching: Why Size Matters

The batch size (how many samples you process before updating) shapes everything: memory, convergence speed, and generalization. Batch size affects how noisy your gradient estimates are. Larger batches give smoother gradients but require more memory. Smaller batches are noisier but often generalize better.

Think about why this noise can actually be helpful. When you estimate the gradient from 32 samples instead of the full dataset, you get a noisy estimate of the true gradient. That noise acts like a random perturbation that can knock the model out of sharp local minima (the kind of minima that have low training loss but poor generalization). Large-batch training tends to converge to "sharper" minima that overfit more. This is sometimes called the "generalization gap" of large-batch training, and it's why many practitioners stick to batch sizes of 32 or 64 even when GPU memory could handle much more.

python
from torch.utils.data import DataLoader, TensorDataset
 
dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True  # If using GPU
)
 
for X_batch, y_batch in train_loader:
    # X_batch shape: (32, feature_dim)
    # y_batch shape: (32,)
    pass

Why shuffle=True? If batches have sequential patterns (all class 0, then all class 1), gradients will be skewed. Shuffling breaks these patterns and makes gradient estimates more representative of the full dataset.

Why num_workers? Loading data is a bottleneck. Without it, your GPU sits idle waiting for data. Multiple workers load the next batch in parallel while your GPU trains on the current batch. Set this to 4-8 depending on your CPU and disk speed. Too many workers can cause excessive memory usage.

Why pin_memory? If using GPU, pinned (page-locked) memory is faster to transfer to GPU. It stays in RAM instead of being paged to disk, enabling faster GPU transfer. Minimal downside for typical batch sizes.

python
# Small dataset: batch_size=32
# Large dataset: batch_size=256
# Memory-constrained: batch_size=8
# Speed matters: batch_size=64-128
 
# Rule of thumb: larger batches (up to GPU memory) = faster training,
# but smaller batches (32-64) sometimes generalize better
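Batch size also fixes how many optimizer steps make up one epoch, which matters later for per-batch schedulers. A quick sketch on a dummy 1,000-sample dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy dataset: 1000 samples, 20 features, 10 classes
dataset = TensorDataset(torch.randn(1000, 20), torch.randint(0, 10, (1000,)))

steps_per_epoch = {}
for batch_size in (8, 32, 256):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    steps_per_epoch[batch_size] = len(loader)  # ceil(1000 / batch_size)

# drop_last=True discards the final, smaller batch instead of keeping it
steps_drop_last = len(DataLoader(dataset, batch_size=32, drop_last=True))

print(steps_per_epoch)   # {8: 125, 32: 32, 256: 4}
print(steps_drop_last)   # 31
```

Going from batch size 8 to 256 cuts the number of weight updates per epoch from 125 to 4, which is one reason learning rates usually need retuning when the batch size changes.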

Validation Loop: Measuring Real Performance

Training loss will almost always decrease, since that's what you're optimizing. But does the model generalize to unseen data? That's what validation measures. This is critical: a model that memorizes the training set can reach near-zero training loss and still fail on new data.

The validation loop is where you find out whether your model is actually learning the underlying patterns in the data or just memorizing the training examples. These are fundamentally different things, and training loss alone cannot tell them apart. Overfitting happens when a model is complex enough to fit noise in the training data: it gets every training example right but generalizes poorly because it's encoding the quirks of that particular dataset rather than the underlying signal. Validation loss exposes this immediately.

python
model.eval()  # Disable dropout, batch norm updates
eval_loss = 0.0
correct = 0
total = 0
 
with torch.no_grad():  # No gradient computation
    for X_val, y_val in val_loader:
        X_val, y_val = X_val.to(device), y_val.to(device)
 
        logits = model(X_val)
        loss = criterion(logits, y_val)
        eval_loss += loss.item()
 
        # Accuracy
        preds = logits.argmax(dim=1)
        correct += (preds == y_val).sum().item()
        total += y_val.size(0)
 
avg_val_loss = eval_loss / len(val_loader)
accuracy = correct / total
print(f"Val Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}")
 
model.train()  # Re-enable dropout, batch norm

Notice the symmetry: the validation loop uses the exact same forward pass and loss computation as training, but without the gradient computation or weight update. The only thing we're doing differently is measuring, not learning.

Three critical things:

  1. model.eval(): Switch dropout and batch norm to inference behavior. In training mode, dropout randomly zeros activations, and batch norm normalizes with per-batch statistics while updating its running averages. In eval mode, dropout is a no-op and batch norm uses the stored running statistics, so predictions on validation data are deterministic.
  2. torch.no_grad(): Skip gradient tracking. Validation doesn't update weights, so recording the computation graph wastes memory and computation. This context manager tells PyTorch not to build the graph at all, so there is nothing to backpropagate through.
  3. model.train(): Re-enable training mode before the next epoch.

Why separate validation? Training loss is a poor proxy for generalization. Validation loss tells you if the model is overfitting. If training loss drops but validation loss increases, your model is memorizing. Early stopping (stop training when validation loss stops improving) prevents overfitting.

python
# Full epoch loop with validation
for epoch in range(num_epochs):
    # TRAINING
    model.train()
    train_loss = 0.0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        logits = model(X)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
 
    # VALIDATION
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for X_val, y_val in val_loader:
            X_val, y_val = X_val.to(device), y_val.to(device)
            logits = model(X_val)
            loss = criterion(logits, y_val)
            val_loss += loss.item()
 
    print(f"Epoch {epoch+1}: Train Loss={train_loss/len(train_loader):.4f}, Val Loss={val_loss/len(val_loader):.4f}")

Watch the gap between training loss and validation loss as training progresses. Early in training, both should decrease together. If they diverge (training loss keeps falling while validation loss plateaus or rises), that's your overfitting signal. Save the model checkpoint at the point of best validation loss, not at the end of training.
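One way to keep the best checkpoint and stop early is a small tracker object. The `EarlyStopping` class below is a hypothetical helper, not a PyTorch API, and the validation losses are made up to stand in for a real validation loop:

```python
import copy
import torch.nn as nn

class EarlyStopping:
    """Hypothetical helper: tracks the best validation loss and its weights."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def update(self, val_loss, model):
        """Record this epoch's result; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())  # snapshot best weights
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

model = nn.Linear(4, 2)
stopper = EarlyStopping(patience=2)

# Made-up validation losses standing in for a real validation loop
fake_val_losses = [0.90, 0.70, 0.65, 0.68, 0.71, 0.60]
stopped_at = None
for epoch, val_loss in enumerate(fake_val_losses):
    if stopper.update(val_loss, model):
        stopped_at = epoch
        break

model.load_state_dict(stopper.best_state)  # restore the best weights
print(stopped_at, stopper.best_loss)  # 4 0.65
```

Note the tradeoff the made-up sequence exposes: with patience=2 the run stops at epoch 4 and never sees the 0.60 that would have followed. A larger patience costs more compute but is less likely to stop on a temporary bump.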

Learning Rate Scheduling: Adapting as You Train

A fixed learning rate rarely works for the entire training. Early on, you want aggressive updates to explore the loss landscape. Later, you want smaller updates for fine-tuning near the minimum. Schedulers automate this learning rate decay.

Learning Rate Strategies

There's a deeper principle behind learning rate scheduling that's worth understanding before looking at specific implementations. The loss landscape of a neural network is not a simple bowl; it's a high-dimensional surface with ridges, valleys, plateaus, and local minima. Early in training, you're somewhere random in this landscape. A large learning rate helps you move quickly and explore broadly, jumping over small local minima that might be traps.

As training progresses and loss decreases, you're presumably in a promising region of the landscape. Now a large learning rate becomes a liability: it causes you to overshoot the good regions and bounce around without settling. A smaller learning rate allows the optimizer to make precise adjustments, navigating the fine structure of the loss surface near a good minimum. The practical consequence is clear: you almost always want to train with a learning rate that starts moderate, optionally warms up briefly, then decays over time. The question is just which decay shape to use.

Warmup (starting with a very small learning rate and increasing it during the first few epochs or steps) is particularly useful when training from scratch or fine-tuning large models. It gives the optimizer time to establish good gradient statistics before committing to large weight updates. Many modern training recipes for transformers combine warmup with cosine decay, spending 5-10% of training on the warmup phase.
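A warmup-plus-cosine schedule can be sketched with LambdaLR, which multiplies the base learning rate by whatever a user-supplied function returns. The function `warmup_cosine`, the step counts, and the linear ramp shape are illustrative choices, not a standard recipe:

```python
import math
import torch

total_steps, warmup_steps = 100, 10  # 10% warmup; both values are illustrative

def warmup_cosine(step):
    """Multiplier applied to the optimizer's base learning rate."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # linear ramp up to 1.0
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay toward 0

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)

lrs = []
for _ in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # a real loop would run forward/backward first
    scheduler.step()   # stepped once per batch in this sketch

print(lrs[0], max(lrs), lrs[-1])  # starts at 10% of base, peaks at base, ends near 0
```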

StepLR: Simple Decay

python
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,  # Every 10 epochs
    gamma=0.1  # Multiply LR by 0.1
)
 
for epoch in range(100):
    # Training loop
    for X, y in train_loader:
        # ... training code ...
        pass
 
    scheduler.step()  # Decay LR after each epoch

StepLR is the simplest scheduler you'll encounter, and its simplicity is a virtue when you're first getting started. Set it up once, let it decay the learning rate at fixed intervals, and move on. There's nothing subtle about the implementation or behavior.

Why? Empirically, dividing the learning rate by 10 every N epochs often works well. Simple, interpretable. After 10 epochs, LR drops to 10% of original. After 20, to 1%. This encourages coarse exploration early, fine-tuning late.

Practical tip: Choose step_size based on your dataset. If training 100 epochs, decay every 10-30 epochs. If training 10 epochs, decay every 3-5 epochs.
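To confirm the decay pattern, the sketch below records the learning rate over 30 epochs using a dummy parameter and no actual training:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter, no real model
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lr_by_epoch = []
for epoch in range(30):
    lr_by_epoch.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # stands in for a full epoch of training
    scheduler.step()

print(lr_by_epoch[0], lr_by_epoch[10], lr_by_epoch[20])  # about 0.1, 0.01, 0.001
```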

CosineAnnealingLR: Smooth Decay

python
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,  # Total epochs
    eta_min=1e-6  # Minimum LR
)

The smooth cosine decay is what most practitioners default to today. Unlike StepLR's abrupt drops, the gradual change of a cosine curve means there's never a sudden shift in optimizer behavior that might destabilize training. The learning rate flows naturally from its initial value down to eta_min, changing slowly near the start and end of the schedule and fastest in the middle.

Why? Learning rate decays smoothly following a cosine curve, starting high and gradually reducing to eta_min. Often converges to better minima than abrupt steps. Modern favorite for vision and NLP tasks. The motivation is largely empirical: it avoids the sharp drops of StepLR that can disrupt training.
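The curve itself has a simple closed form, eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max)). The sketch below compares that formula against what the scheduler produces; `cosine_lr` is a helper written for this comparison, not a PyTorch function:

```python
import math
import torch

base_lr, eta_min, T_max = 0.1, 1e-6, 100

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter
optimizer = torch.optim.SGD(params, lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=T_max, eta_min=eta_min
)

def cosine_lr(epoch):
    """Closed-form cosine schedule for comparison."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / T_max))

scheduled = []
for epoch in range(T_max + 1):
    scheduled.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # stands in for one epoch of training
    scheduler.step()

max_gap = max(abs(lr - cosine_lr(t)) for t, lr in enumerate(scheduled))
print(scheduled[0], scheduled[50], scheduled[100], max_gap)
```

The halfway point lands at the midpoint between the initial rate and eta_min, and the final value equals eta_min.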

OneCycleLR: Cyclical Scheduling

python
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    total_steps=len(train_loader) * num_epochs,
    pct_start=0.3  # Spend 30% increasing LR, 70% decreasing
)
 
for epoch in range(num_epochs):
    for X, y in train_loader:
        # Training loop
        logits = model(X)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        scheduler.step()  # Step EVERY BATCH, not every epoch

OneCycleLR is the scheduler that surprised researchers when it was first proposed: the idea of deliberately increasing the learning rate partway through training seemed counterintuitive. But it works, often achieving results in half the training time of more conservative schedules.

Why? Start low, ramp to max, then decay. The cycle helps escape local minima and often achieves better final loss. Leslie Smith's research (Super-Convergence paper) shows this works surprisingly well, often matching multi-phase schedules in half the training time.

The mechanism: The low starting learning rate acts as a warmup. The high learning rates around the peak help explore the landscape, escape poor local minima, and act as a form of regularization. Gradual decay fine-tunes near the final minimum. The single cycle often outperforms constant or multi-step schedules.

Critical detail: OneCycleLR steps every batch, not every epoch. Most others step every epoch. This means you need to calculate total_steps = number_of_batches * number_of_epochs.
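Tracing the schedule makes the shape concrete. This sketch uses a dummy parameter and illustrative sizes (20 batches per epoch, 5 epochs) and relies on OneCycleLR's default div_factor of 25, which sets the starting rate to max_lr / 25:

```python
import torch

steps_per_epoch, num_epochs = 20, 5  # illustrative sizes
total_steps = steps_per_epoch * num_epochs

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter
# The lr given here is overwritten by the scheduler
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, total_steps=total_steps, pct_start=0.3
)

lrs = []
for _ in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # a real loop would run forward/backward first
    scheduler.step()   # once per batch

peak_step = max(range(total_steps), key=lambda i: lrs[i])
print(lrs[0], max(lrs), peak_step, lrs[-1])
```

The rate starts at 0.004, reaches 0.1 roughly 30% of the way through, and finishes far below where it started.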

python
# Quick reference for scheduler.step() timing
# StepLR, CosineAnnealingLR: call after each epoch
# OneCycleLR: call after each batch
# ReduceLROnPlateau: call after each validation pass, as scheduler.step(val_loss)
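
ReduceLROnPlateau is the odd one out because step() takes the monitored metric as an argument, and the scheduler itself decides when a plateau has occurred. The following sketch feeds it a made-up sequence of validation losses; the factor, patience, and loss values are assumptions for illustration only.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))  # stand-in for model.parameters()
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

# Made-up validation losses: improvement stalls after the third epoch
fake_val_losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7]

for epoch, val_loss in enumerate(fake_val_losses):
    optimizer.step()          # training for the epoch would happen here
    scheduler.step(val_loss)  # called every epoch, with the monitored metric
    lr = optimizer.param_groups[0]["lr"]
    print(f"epoch {epoch}: val_loss={val_loss:.2f}  lr={lr}")

final_lr = optimizer.param_groups[0]["lr"]
```

With patience=2, the learning rate is halved only once the loss has failed to improve for more than two consecutive epochs.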

Gradient Clipping: Preventing Explosions

In some architectures (RNNs, especially), gradients can explode during backprop, sending weights to infinity and breaking training. Gradient clipping caps the gradient norm, preventing these explosions without losing gradient direction.

python
for epoch in range(num_epochs):
    for X, y in train_loader:
        logits = model(X)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
 
        # Clip gradients to max norm of 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
        optimizer.step()

The key thing to notice in the code above is placement: gradient clipping happens after loss.backward() but before optimizer.step(). By the time we clip, the gradients are computed and stored in each parameter's .grad attribute. The clip rescales those gradient tensors so their combined norm doesn't exceed max_norm, then optimizer.step() uses the clipped gradients for the weight update. If you called clipping before backward(), there would be nothing to clip.
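
The rescaling itself is easy to see on a toy gradient. This is a minimal sketch with a single hand-set gradient rather than a real backward pass; the parameter w and its values are made up for illustration.

```python
import torch

# One parameter with a hand-set gradient of norm 5 (a 3-4-5 triangle)
w = torch.nn.Parameter(torch.zeros(2))
w.grad = torch.tensor([3.0, 4.0])

# Returns the total norm measured BEFORE clipping; rescales .grad in place
norm_before = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)

print(f"norm before clipping: {norm_before.item():.2f}")
print(f"clipped gradient:     {w.grad.tolist()}")
print(f"norm after clipping:  {w.grad.norm().item():.2f}")
```

The ratio between the two components is unchanged, which is what "preserving direction" means in practice. The return value is also a convenient thing to log, since it reports the gradient norm before any clipping was applied.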

Why? Some architectures (RNNs, LSTMs, Transformers early in training) produce unstable gradients. In RNNs, errors from the loss can be backpropagated over many time steps, and the chain rule multiplies the gradient by the recurrent weights at every step. If the recurrent weight matrix has eigenvalues with magnitude greater than 1, gradients can grow exponentially with sequence length. Clipping rescales the gradient vector to have norm at most max_norm, preserving direction while capping magnitude.
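
A deliberately simplified scalar picture shows how quickly repeated multiplication gets out of hand. Real RNNs multiply by Jacobian matrices rather than a single number, so treat the helper below (a hypothetical name, backprop_scale) as an intuition aid, not a model of any particular network.

```python
def backprop_scale(factor, steps):
    """Size of a gradient after being multiplied by `factor` once per time step."""
    scale = 1.0
    for _ in range(steps):
        scale *= factor
    return scale

# Slightly below 1 vanishes, slightly above 1 explodes
for factor in (0.9, 1.0, 1.1):
    scale = backprop_scale(factor, 100)
    print(f"factor {factor}: after 100 steps the gradient is scaled by {scale:.3g}")
```

A factor only 10% above 1 inflates the gradient by more than four orders of magnitude over 100 steps, while a factor 10% below 1 shrinks it toward zero.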

When to use: RNNs, LSTMs, Transformers. Less critical for CNNs. If training is unstable (loss goes to NaN, weights explode), try clipping.

python
# Check if you need it: run right after loss.backward(), once .grad is populated
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()  # per-parameter L2 norm
        if grad_norm > 10:
            print(f"{name}: gradient norm = {grad_norm:.2f}")
# If you see values >100, consider clipping

Run this diagnostic code for a few batches when you first start training a new architecture. If you see gradient norms consistently above 10 or 100, clipping with max_norm=1.0 is a sensible intervention. If norms are in the 0.1 to 5 range, you probably don't need it, though it rarely hurts to include as a safety measure.

PyTorch Lightning: Condensing the Boilerplate

Writing training loops is verbose. PyTorch Lightning abstracts the loop while keeping flexibility.

Once you've written the raw training loop enough times to understand every line, Lightning starts to look attractive. The boilerplate (moving tensors to device, switching between train and eval modes, zeroing gradients, tracking metrics) is repetitive and error-prone. Lightning handles all of it, letting you focus on what's unique to your model.

python
import torch
import torch.nn as nn
import pytorch_lightning as pl
 
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        self.criterion = nn.CrossEntropyLoss()
 
    def forward(self, x):
        return self.model(x)
 
    def training_step(self, batch, batch_idx):
        X, y = batch
        logits = self(X)
        loss = self.criterion(logits, y)
        return loss
 
    def validation_step(self, batch, batch_idx):
        X, y = batch
        logits = self(X)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)
 
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]
 
# Training
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
model = ImageClassifier()
trainer.fit(model, train_loader, val_loader)

Comparing this Lightning code to the raw loop above is instructive. The same logical steps are present (forward pass, loss computation, backward pass, optimizer step), but Lightning takes the loss returned by training_step and calls backward(), optimizer.step(), and zero_grad() for you. The configure_optimizers method replaces the manual optimizer and scheduler setup. What you lose is visibility; what you gain is clean, maintainable code that's hard to get wrong.

Why Lightning?

  • Handles device placement, mixed precision, multi-GPU automatically
  • Cleaner code structure
  • Built-in logging, checkpointing, early stopping
  • Still lets you write custom training logic in training_step()

When to use: Once you're comfortable with raw loops, Lightning is great for production code. For learning, stick with raw loops so you understand what's happening.

Putting It All Together: A Complete Example

python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
 
# Model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(64, 10)
)
 
# Loss, optimizer, scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
 
# Create dummy data
X_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
X_val = torch.randn(200, 784)
y_val = torch.randint(0, 10, (200,))
 
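# num_workers > 0 starts worker processes; on Windows/macOS, run this script
# inside an `if __name__ == "__main__":` guard
train_loader = DataLoader(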
    TensorDataset(X_train, y_train),
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
 
val_loader = DataLoader(
    TensorDataset(X_val, y_val),
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)
 
# Training
num_epochs = 20
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
 
        # Forward, loss, backward, step
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
 
        train_loss += loss.item()
 
    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            val_loss += loss.item()
 
            preds = logits.argmax(dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)
 
    scheduler.step()
 
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss/len(train_loader):.4f}")
    print(f"  Val Loss: {val_loss/len(val_loader):.4f}")
    print(f"  Val Acc: {correct/total:.4f}")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")

Every line in this complete example is there for a reason. The model uses Dropout to regularize. AdamW handles adaptive optimization with proper weight decay. CosineAnnealingLR decays the learning rate smoothly. Gradient clipping acts as a safety net against instability. The DataLoader is tuned for GPU training with worker processes and pinned memory. The epoch loop tracks both training and validation metrics so you can detect overfitting early.

What this does:

  1. AdamW optimizer with weight decay for stable, general-purpose training
  2. Cosine annealing scheduler to smoothly decay learning rate
  3. Dropout in the model to prevent overfitting
  4. Gradient clipping (1.0 norm) as a safety measure
  5. DataLoader with shuffling, num_workers, and pin_memory for efficient batching
  6. Separate train/eval modes with validation loop
  7. Printed metrics to track convergence

This is the template you'll use (with variations) for most classification tasks.

Debugging Your Training Loop

Training broken? Loss stuck or NaN? Here's how to diagnose:

Loss goes to NaN immediately:

  • Learning rate too high (gradients explode)
  • Try dividing the learning rate by 10, or enable gradient clipping
  • Check for numerical instabilities (log of zero or a negative number, division by zero)
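
The numerical-instability case is worth seeing once. Here is a minimal sketch of how a single zero probability poisons a log, plus a guard that fails loudly instead of letting training continue on a non-finite loss. The check_finite helper and the epsilon of 1e-8 are illustrative assumptions, not PyTorch built-ins or recommended constants.

```python
import torch

probs = torch.tensor([0.0, 0.5, 1.0])  # a probability of exactly 0 is the problem case

unsafe = torch.log(probs)                # log(0) = -inf
safe = torch.log(probs.clamp(min=1e-8))  # epsilon is an illustrative choice

print("unsafe finite?", torch.isfinite(unsafe).all().item())
print("safe finite?  ", torch.isfinite(safe).all().item())

def check_finite(loss, step):
    """Hypothetical guard: stop with a clear message instead of training on NaN/inf."""
    if not torch.isfinite(loss):
        raise RuntimeError(f"Non-finite loss {loss.item()} at step {step}")

check_finite(safe.mean(), step=0)  # passes silently
```

Calling a guard like this right after computing the loss pinpoints the first bad batch, which is usually far easier to debug than a run that has been producing NaN for hundreds of steps.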

Training loss decreases but validation loss plateaus:

  • Overfitting. Add dropout, regularization, or reduce model size
  • Try data augmentation to inject noise
  • Enable early stopping to halt before overfitting
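
Early stopping needs only a few lines of bookkeeping. The class below is a minimal sketch, not a library API: the name EarlyStopping, the patience of 3, and the validation losses are all made up for illustration.

```python
class EarlyStopping:
    """Minimal sketch: signal a stop after `patience` epochs without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record this epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses that bottom out at epoch 2 and then drift upward
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate([0.9, 0.7, 0.6, 0.61, 0.63, 0.66, 0.70]):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best val loss {stopper.best}")
```

In a real loop you would call stopper.step(val_loss) at the end of each validation pass and typically save a checkpoint whenever the best value improves.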

Loss oscillates wildly:

  • Learning rate too high (still)
  • Batch size too small (noisy gradients)
  • Try larger batch size or smaller learning rate

Model trains fast but validation is random:

  • Possible bug in validation logic (dataset mismatch, wrong labels)
  • Verify validation data is truly separate from training
  • Check that model.eval() and torch.no_grad() are set

GPU out of memory:

  • Reduce batch size
  • Reduce model size (fewer parameters)
  • Enable gradient accumulation (update every N batches, not every batch)
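
Gradient accumulation can be sketched in a few lines. The example below uses a tiny linear model and random data (both assumptions for illustration) and checks that four micro-batches of 2 produce the same gradient as one batch of 8.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

X, y = torch.randn(8, 4), torch.randn(8, 1)  # dummy data
accum_steps = 4                              # 4 micro-batches of 2 = effective batch of 8

# Reference: gradient from the full batch of 8 in one pass
criterion(model(X), y).backward()
full_batch_grad = model.weight.grad.clone()
optimizer.zero_grad()

for i, (xb, yb) in enumerate(zip(X.chunk(accum_steps), y.chunk(accum_steps))):
    loss = criterion(model(xb), yb) / accum_steps  # scale so the summed gradients form an average
    loss.backward()                                # gradients accumulate in .grad
    if (i + 1) % accum_steps == 0:
        accumulated = model.weight.grad.clone()    # kept only for the comparison below
        optimizer.step()
        optimizer.zero_grad()

print("matches full-batch gradient:", torch.allclose(accumulated, full_batch_grad, atol=1e-6))
```

Only one micro-batch of activations is held in memory at a time, which is where the memory saving comes from. The division by accum_steps matters: without it, the accumulated gradient would be accum_steps times larger than the full-batch one.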

Key Takeaways

  • Loss function defines your task: MSELoss for regression, CrossEntropyLoss for multi-class classification, BCEWithLogitsLoss for binary and multi-label classification.
  • Optimizer is how you move: Adam/AdamW for most cases, SGD with momentum if you want fine-grained control.
  • Learning rate scheduling adapts the step size: StepLR is simple, CosineAnnealingLR is modern, OneCycleLR is research-backed.
  • Batch size trades off speed and generalization: larger = faster, smaller = often better regularization.
  • Validation loop measures generalization: use model.eval(), torch.no_grad(), and separate metrics.
  • Gradient clipping prevents explosions: essential for RNNs, optional for CNNs.
  • PyTorch Lightning abstracts the loop: great once you understand the mechanics.

The training loop is where all your design choices come together. What we've covered here is not just API documentation. It's the reasoning that lets you adapt when things go wrong, tune when defaults aren't good enough, and understand what your model is actually doing as it learns. Start with the complete template above, watch your training and validation curves, and adjust based on what you see. Loss going NaN? Check learning rate and gradient norms. Validation plateauing early? Add regularization or lower the learning rate. Converging slowly? Consider OneCycleLR or a larger initial learning rate. Most problems have a systematic fix, and now you know where to look.

Master this, and you can train anything.
