May 21, 2025
AI/ML Infrastructure Training Quantization Model Serving

FP8 Training and Inference: Next-Generation Numerical Formats

You've probably noticed that modern ML models are getting massive. We're talking billions of parameters, thousands of GPUs, and training costs that make CEOs nervous. Here's the uncomfortable truth: FP32, and often even 16-bit precision, wastes bandwidth and compute cycles. But what if you could cut memory usage and boost throughput without sacrificing accuracy? Enter FP8 - a numerical format that's quietly reshaping how we train and deploy large language models. And with NVIDIA's Hopper architecture adding dedicated FP8 support, the time to understand this technology is now.

FP8 represents a fundamental shift in how we think about numerical precision in deep learning. For decades, the industry defaulted to FP32 (32-bit floating point) for training because it was safe - you got reasonable accuracy and stable convergence. FP16 arrived as a middle ground, cutting memory and bandwidth costs in half while maintaining quality. But FP8 pushes further, and it forces us to reconsider a crucial assumption: do we really need all that precision?

The answer is surprisingly nuanced. Precision requirements differ between the forward and backward passes. Weights and activations (forward pass) tolerate low precision because matrix-multiply hardware accumulates their products in higher precision, so per-element rounding errors largely average out across long dot products. Gradients (backward pass) are trickier: they can be tiny and span many orders of magnitude, so you need enough dynamic range to represent small values without underflow.

This realization led hardware and software engineers to create FP8 with two distinct formats, each optimized for its specific role in training and inference.

Table of Contents
  1. The FP8 Landscape: Two Formats, One Big Difference
  2. E4M3: The Powerhouse for Weights and Activations
  3. E5M2: The Gradient Guardian
  4. Numerical Comparison: FP8 vs. BF16 vs. FP16
  5. NVIDIA Hopper's FP8 Engine: The Hardware Revolution
  6. The Transformer Engine Block
  7. NVIDIA Transformer Engine: The Software Bridge
  8. Loss Scaling for FP8 Gradients: The Delayed Scaling Algorithm
  9. Delayed Scaling: Per-Tensor History
  10. Scale Update Frequency and Overflow Detection
  11. Accuracy Impact: FP8 vs BF16 Pretraining
  12. GPT-Class Model Study: 7B–13B Parameters
  13. Sensitivity to Learning Rate Scaling
  14. FP8 Inference with TensorRT: Calibration and Deployment
  15. Calibration Requirements: Finding Per-Tensor Scales
  16. Quantization-Aware Training (QAT) for FP8
  17. Deployment on Hopper: Throughput Gains
  18. Complete Training Example: FP8 with Transformer Engine
  19. Production Deployment: TensorRT Export and Serving
  20. The Business Imperative: Why FP8 Adoption Matters Now
  21. Key Takeaways and When to Use FP8
  22. Why This Matters in Production
  23. Debugging FP8 Training Issues: Common Problems and Solutions
  24. Scaling FP8 Training to 1000s of GPUs

The FP8 Landscape: Two Formats, One Big Difference

Let's talk specifics. FP8 isn't just "smaller float." There are actually two standardized FP8 formats, each optimized for completely different jobs.

E4M3: The Powerhouse for Weights and Activations

E4M3 means "4-bit exponent, 3-bit mantissa." Here's what that gets you:

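Range: largest finite value ±448; smallest normal 2^-6 ≈ 0.0156; smallest subnormal 2^-9 ≈ 0.00195 (in the OCP/NVIDIA variant, which gives up infinities to extend the range)
Precision: 3 mantissa bits, about one significant decimal digit (values in [1, 2) are spaced 0.125 apart)
Bias: exponent bias of 7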

Think of E4M3 as the workhorse. It's designed for forward passes where you're multiplying weights by activations. Compared with E5M2, it spends one more bit on the mantissa and one fewer on the exponent: finer precision, narrower range (about 18 binades from the smallest subnormal to the maximum). Per-tensor scale factors shift each tensor's values into that window, which is why the narrow range is workable.

E4M3 bit layout:
[Sign: 1 bit][Exponent: 4 bits][Mantissa: 3 bits]

The mantissa is tiny (just 3 bits), but here's the trick: each output of a matrix multiplication is a long dot product, and the hardware accumulates those products in higher precision. The rounding errors of individual FP8 inputs are independent and tend to partially cancel in the sum, so E4M3 inputs carry enough information for a usable result.

Why this allocation of bits? The exponent sets a number's order of magnitude. Four bits give 16 exponent codes; in E4M3, code 0 is reserved for subnormals, leaving normal binades from 2^-6 up to 2^8 (the top binade gives up one mantissa pattern to encode NaN). That covers typical neural-network values once a per-tensor scale is applied. The mantissa sets the precision within a binade: 3 bits give 8 evenly spaced values per binade. That's coarse, but adequate, because gradual underflow (subnormal numbers) handles very small values gracefully and higher-precision accumulation in matrix multiplication averages out much of the quantization noise.
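To make the bit layout concrete, here is a small decoder. It's a sketch assuming the E4M3 variant used by NVIDIA and the OCP FP8 spec (bias 7, no infinities, only the S.1111.111 pattern is NaN); `decode_e4m3` is an illustrative name, not a library API. Enumerating all 256 codes confirms the range and the eight mantissa steps per binade:

```python
import math

def decode_e4m3(byte: int) -> float:
    """Decode one E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.

    Assumes the 'FN' variant: no infinities, and only S.1111.111 encodes NaN.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent pinned at 1 - bias = -6
        return sign * (man / 8) * 2.0 ** -6
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

values = [decode_e4m3(b) for b in range(256)]
finite = [v for v in values if not math.isnan(v)]
print(max(finite))                          # 448.0
print(min(v for v in finite if v > 0))      # 0.001953125 (2**-9)
print(sorted(v for v in finite if 1 <= v < 2))
# [1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875]
```

Only two of the 256 codes are NaN (positive and negative), which is how this variant reaches 448 instead of the 240 an IEEE-style layout would allow.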

E5M2: The Gradient Guardian

E5M2 flips the script: "5-bit exponent, 2-bit mantissa."

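Range: largest finite value ±57,344; smallest normal 2^-14 ≈ 6.1e-5; smallest subnormal 2^-16 ≈ 1.5e-5
Precision: 2 mantissa bits, less than one significant decimal digit (values in [1, 2) are spaced 0.25 apart)
Bias: exponent bias of 15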

E5M2 is built for gradient flow during backpropagation. Here's why it matters: gradients can be tiny. Deep in a 175B-parameter model, gradients might be 1e-6 or smaller. E5M2's wider exponent range, combined with a per-tensor scale factor, lets them through without underflow (unscaled, 1e-6 would still fall below E5M2's smallest subnormal of about 1.5e-5). The 2-bit mantissa is acceptable because gradient precision doesn't need to be surgical - what matters is that gradients flow without vanishing.

E5M2 bit layout:
[Sign: 1 bit][Exponent: 5 bits][Mantissa: 2 bits]

The larger exponent (5 vs. 4 bits) gives 32 exponent codes instead of 16, with only 4 mantissa steps per binade instead of 8. That range is what gradients need, because they span many orders of magnitude during backpropagation. A weight in the input layer might see gradients around 1e-1, while one deep in the network sees gradients around 1e-7: six orders of magnitude, or about 20 binades. Per-tensor scales handle differences between tensors, but values inside a single gradient tensor can also spread widely, and E5M2 covers roughly 32 binades (2^-16 to 57,344) versus about 18 for E4M3.
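The range figures follow directly from the bit counts. This sketch computes them for both formats, assuming the conventions described above (E4M3 with a single NaN pattern and no infinity, E5M2 with IEEE-style Inf/NaN in its top exponent code); `fp8_range` is an illustrative helper, not a library function:

```python
import math

def fp8_range(exp_bits, man_bits, bias, ieee_specials):
    """Return (max_finite, min_normal, min_subnormal) for a small float format.

    ieee_specials=True: the top exponent code is reserved for Inf/NaN (E5M2).
    ieee_specials=False: only the all-ones pattern is NaN, no Inf (E4M3 'FN').
    """
    top = 2 ** exp_bits - 1
    if ieee_specials:
        max_exp, max_man = top - 1 - bias, 2 ** man_bits - 1
    else:
        max_exp, max_man = top - bias, 2 ** man_bits - 2
    max_finite = (1 + max_man / 2 ** man_bits) * 2.0 ** max_exp
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - man_bits)
    return max_finite, min_normal, min_subnormal

for name, args in {"E4M3": (4, 3, 7, False), "E5M2": (5, 2, 15, True)}.items():
    mx, mn, sub = fp8_range(*args)
    span = math.log2(mx / sub)
    print(f"{name}: max={mx:g}  min_normal={mn:.3g}  min_subnormal={sub:.3g}  span={span:.1f} binades")
# E4M3: max=448  min_normal=0.0156  min_subnormal=0.00195  span=17.8 binades
# E5M2: max=57344  min_normal=6.1e-05  min_subnormal=1.53e-05  span=31.8 binades
```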

The practical split: use E4M3 for forward-pass tensors (weights and activations) and E5M2 for gradients. Transformer Engine's HYBRID format applies exactly this split, so you choose the recipe once instead of routing tensors to formats by hand.

Numerical Comparison: FP8 vs. BF16 vs. FP16

Let's ground this in reality. How does FP8 actually stack up against formats you already know?

| Format | Total bits | Exponent bits | Mantissa bits | Largest finite | Smallest normal | ~Significant decimal digits | Common use |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | 3.4e38 | 1.2e-38 | ~7 | Training baseline |
| BF16 | 16 | 8 | 7 | 3.4e38 | 1.2e-38 | ~2–3 | Mixed-precision training |
| FP16 | 16 | 5 | 10 | 65,504 | 6.1e-5 | ~3 | Inference, sometimes training |
| E4M3 | 8 | 4 | 3 | 448 | 1.6e-2 | ~1 | FP8 weights, activations |
| E5M2 | 8 | 5 | 2 | 57,344 | 6.1e-5 | <1 | FP8 gradients |

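Here's what jumps out: BF16 has the same dynamic range as FP32 because it keeps the full 8-bit exponent. That's why BF16 trains more stably than FP16 (it typically needs no loss scaling), even though both are 16 bits. E5M2 makes the analogous trade at 8 bits: it keeps FP16's 5-bit exponent and bias, so it has essentially FP16's range with 8 fewer mantissa bits. FP8 splits the bit budget into two specialized formats instead of trying to be one size fits all.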

This comparison reveals a crucial point: no single low-precision format works everywhere. You need different formats for different computational purposes. FP32 is universal but expensive. BF16 tries to be universal by keeping the exponent wide; it works well, but its precision is limited. FP8 says "be specific": use E4M3 where precision matters more than range, and E5M2 where range matters more than precision.
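Because E5M2 shares FP16's exponent width and bias, converting an FP16 value to E5M2 amounts to rounding away the low 8 bits of its 16-bit pattern. A minimal sketch using only Python's `struct` half-precision support (the `'e'` format); it rounds to nearest-even and, for brevity, does not special-case NaN inputs. The helper names are illustrative:

```python
import struct

def fp16_bits(x: float) -> int:
    """Return the 16-bit IEEE half-precision bit pattern of x."""
    return struct.unpack("<H", struct.pack("<e", x))[0]

def to_e5m2_byte(x: float) -> int:
    """Round an FP16 value to E5M2 by dropping 8 mantissa bits (round-to-nearest-even)."""
    bits = fp16_bits(x)
    lsb = (bits >> 8) & 1                     # lowest mantissa bit that survives
    # Adding 0x7F + lsb rounds ties to even; a carry may spill into the exponent,
    # which correctly bumps the value to the next binade (or to infinity).
    return ((bits + 0x7F + lsb) >> 8) & 0xFF

def from_e5m2_byte(b: int) -> float:
    """Widen E5M2 back to FP16 by appending eight zero bits."""
    return struct.unpack("<e", struct.pack("<H", b << 8))[0]

for x in (3.3, 0.1, 1000.0, 65504.0):
    print(x, "->", from_e5m2_byte(to_e5m2_byte(x)))
# 3.3 -> 3.5
# 0.1 -> 0.09375
# 1000.0 -> 1024.0
# 65504.0 -> inf
```

The last line shows the flip side of sharing FP16's exponent: FP16's largest values exceed E5M2's maximum of 57,344 and round to infinity.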

NVIDIA Hopper's FP8 Engine: The Hardware Revolution

This is where it gets real. NVIDIA didn't just add FP8 support - they built an entire subsystem around it.

The Transformer Engine Block

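Each Hopper SM (streaming multiprocessor) contains four fourth-generation Tensor Cores that accept FP8 inputs (E4M3 or E5M2) and accumulate in higher precision, at twice the per-clock rate of FP16/BF16. This isn't a clever software trick; it's dedicated silicon for FP8 × FP8 multiplies with FP32 accumulation.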

Hopper SM (simplified):
┌─────────────────────────────────────┐
│ L1 cache / shared memory            │
└──────────────┬──────────────────────┘
               │
     4 processing blocks per SM, each with:
       ├── 32 FP32 CUDA cores (128 per SM)
       └── 1 Tensor Core (4 per SM)
             ├── FP16 / BF16 / TF32 inputs
             └── FP8 inputs (E4M3, E5M2)   ← new in Hopper
                   → accumulate in FP32 (or FP16)

The FP8 path supports:

  • FP8 GEMM (General Matrix Multiply): A_fp8 × B_fp8 with FP32 accumulation, writing a higher-precision (or FP8) output
  • Mixed FP8 formats: E4M3 × E4M3 in the forward pass, E4M3 × E5M2 for the gradient GEMMs
  • Per-tensor scaling: scale factors for the inputs and output are passed to the GEMM, so dequantization folds into the multiply

Why is this important? Casting and scaling are cheap elementwise operations, but run as separate kernels, each one costs a full pass over memory. Folding the scales into the GEMM, and fusing casts into neighboring kernels such as layer norm, keeps that overhead small.

It helps to be precise about what "Transformer Engine" means. NVIDIA uses the name for the combination of Hopper's FP8 Tensor Cores and a software library that decides, layer by layer, where to use FP8 and manages the scale factors. The scale bookkeeping (amax histories, scale updates) lives in that software; the hardware provides the FP8 math and the hooks to apply scales inside the GEMM. On Ampere and earlier there was no FP8 Tensor Core path at all. With Hopper, the library can hide casting and scaling from your model code, which reduces the complexity burden on frameworks.

NVIDIA Transformer Engine: The Software Bridge

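NVIDIA's open-source Transformer Engine library (transformer_engine, used by Megatron-LM and other frameworks) handles the complexity for you. It automatically:

  1. Casts weights and activations to E4M3 before FP8 matrix ops
  2. Computes each tensor's absolute maximum (amax) on the fly, fused into existing kernels
  3. Uses E5M2 for gradient tensors in the backward pass (with the HYBRID format)
  4. Applies scale factors and descaling around each FP8 GEMM
  5. Updates scales from the amax history, so a tensor whose values grow gets a smaller scale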

Here's how it works under the hood:

python
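# Conceptual flow, emulated in PyTorch (needs a build with float8 dtypes, 2.1+)
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def fp8_gemm(A_fp32, B_fp32, scale_a=1.0, scale_b=1.0):
    # 1. Scale and cast to FP8 E4M3 (clamp first: the cast itself does not saturate)
    A_fp8 = (A_fp32 * scale_a).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    B_fp8 = (B_fp32 * scale_b).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

    # 2. FP8 x FP8 multiply with FP32 accumulation. Emulated here by upcasting;
    #    on Hopper, the Tensor Cores consume the FP8 operands directly.
    C_fp32_scaled = A_fp8.float() @ B_fp8.float()

    # 3. Descale output
    C_fp32 = C_fp32_scaled / (scale_a * scale_b)
    return C_fp32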

This is what separates Hopper's FP8 from software emulation. The Tensor Cores consume FP8 operands directly and accumulate in higher precision, while the scale factors are chosen in software and applied at the GEMM boundary. Emulation reproduces the numerics but none of the speed. It's a division of labor that maximizes both precision and throughput.

The brilliance of accumulating in FP32: even though the multiply inputs are FP8, the products are summed in FP32. The result still carries the quantization error of the FP8 inputs, but rounding error doesn't compound with every addition the way it would with a low-precision accumulator, so you keep most of the accuracy while getting the speed and memory benefits of FP8.
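A quick way to see why the accumulator's precision matters: quantize two vectors to E4M3, then compute their dot product twice, once summing in Python's float64 (standing in for the wide FP32 accumulator) and once re-rounding the running sum to E4M3 after every addition. This is a pure-Python sketch; `round_to_e4m3` is an illustrative helper that assumes bias 7, 3 mantissa bits, ties-to-even, and saturation at ±448:

```python
import math
import random

def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(x), 448.0)
    if a == 0.0:
        return 0.0
    _, e2 = math.frexp(a)            # a = m * 2**e2 with 0.5 <= m < 1
    e = max(e2 - 1, -6)              # binade exponent; below 2**-6 use subnormal spacing
    step = 2.0 ** (e - 3)            # 3 mantissa bits -> 8 steps per binade
    return math.copysign(round(a / step) * step, x)

random.seed(0)
n = 4096
a = [round_to_e4m3(random.uniform(-1, 1)) for _ in range(n)]
b = [round_to_e4m3(random.uniform(-1, 1)) for _ in range(n)]

high = sum(x * y for x, y in zip(a, b))   # wide accumulator, same FP8 inputs

low = 0.0
for x, y in zip(a, b):
    low = round_to_e4m3(low + x * y)      # accumulator rounded back to E4M3 each step

print(f"float64 accumulator: {high:.4f}")
print(f"E4M3 accumulator:    {low:.4f}")
```

Once the running sum grows past a few units, E4M3's spacing there (0.5 to 2 or more) swallows most individual products, so the low-precision total drifts away from the wide-accumulator result even though both started from identical FP8 inputs.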

The implications for adoption are significant. Teams that hesitated over low-precision training because they worried about stability now have hardware and software support that greatly reduces that concern. FP8 becomes not just mathematically viable but operationally attractive: it moves half as many bytes per value as BF16 and runs at twice the Tensor Core rate, and memory bandwidth and math throughput are the real bottlenecks in training.

Loss Scaling for FP8 Gradients: The Delayed Scaling Algorithm

Here's a problem: gradients are small. After backprop through a 175B model, gradient values might sit around 1e-5 to 1e-7. E5M2's smallest subnormal is 2^-16 ≈ 1.5e-5, so if you cast those values directly to E5M2, they underflow - zeros everywhere.

The classic solution, inherited from FP16 mixed-precision training, is loss scaling. You multiply the loss by a large constant (e.g., 2^16 = 65,536) before backprop, which inflates gradients into E5M2's representable range. After the backward pass, and before the optimizer step, you divide the gradients by the same constant. Simple, but fragile.

Loss scaling is a clever technique that exploits the chain rule of calculus. By scaling the loss before backprop, you scale all gradients proportionally. A 65536x larger loss produces 65536x larger gradients. But this only works if the scale is chosen carefully - too small and you still underflow, too large and you overflow.
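Both failure modes are easy to reproduce. This sketch rounds values to E5M2 in pure Python (an illustrative helper assuming bias 15, 2 mantissa bits, ties-to-even, and IEEE-style overflow to infinity) and applies a power-of-two loss scale:

```python
import math

E5M2_MAX = 57344.0

def round_to_e5m2(x: float) -> float:
    """Round x to the nearest E5M2 value; results past the max become +/-inf."""
    a = abs(x)
    if a == 0.0:
        return 0.0
    _, e2 = math.frexp(a)            # a = m * 2**e2 with 0.5 <= m < 1
    e = max(e2 - 1, -14)             # -14 is E5M2's smallest normal exponent
    step = 2.0 ** (e - 2)            # 2 mantissa bits -> 4 steps per binade
    r = round(a / step) * step
    return math.copysign(math.inf if r > E5M2_MAX else r, x)

grad = 1e-6
scale = 2.0 ** 16                    # a power of two, so scaling adds no rounding error

print(round_to_e5m2(grad))                   # 0.0 -> underflow without scaling
print(round_to_e5m2(grad * scale) / scale)   # 9.5367431640625e-07 (2**-20), close to 1e-6
print(round_to_e5m2(1.0 * scale))            # inf -> a gradient of 1.0 overflows at this scale
```

The same scale that rescues the 1e-6 gradient destroys a gradient of 1.0, which is why a single global scale is fragile when gradient magnitudes vary.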

Delayed Scaling: Per-Tensor History

Transformer Engine takes a different approach: instead of one global loss scale, every FP8 tensor gets its own scale factor, computed from a history of recent absolute-maximum (amax) values. This is called delayed scaling.

Each FP8 tensor (weights, activations, gradients) maintains:

  • A scale factor (multiplier that maps the tensor's values into FP8 range)
  • A history of recent amax values (the largest |x| seen in each step)
  • The inverse scale, used to dequantize results

Here's the algorithm (simplified):

Algorithm: Delayed FP8 Scaling (per tensor)
Input: tensor X_fp32 at the current step, amax history H (length L), FP8 format maximum fp8_max, margin m
Output: X_fp8, scale S used for the cast, updated history H

1. Compute the scale from previous steps only (this is the "delay"): S = fp8_max / max(H) / 2^m
2. Cast: X_fp8 = cast(X_fp32 * S, format)   (E4M3 for forward tensors, E5M2 for gradients)
3. Record this step's amax: append max(|X_fp32|) to H, dropping the oldest entry
4. Downstream, dequantize with 1 / S (folded into the GEMM)

Return X_fp8, S, updated history H

Scale Update Frequency and Overflow Detection

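Computing amax every step sounds expensive (it's a max reduction over the whole tensor), but Transformer Engine fuses it into kernels that already read the tensor, so it comes almost for free. The knobs live on its DelayedScaling recipe:

  • amax_history_len: how many past amax values to keep
  • amax_compute_algo: "max" (largest value in the window) or "most_recent"
  • margin: extra headroom; the scale is divided by 2^margin
  • fp8_format: HYBRID (E4M3 forward, E5M2 backward) or E4M3 everywhere

Unlike FP16 loss scaling, there is no skip-and-retry on overflow. A spike is cast with the old scale in the step where it appears, then sits in the amax history and shrinks the scale for the following steps.

Taking the max over a window is what gives you stability. Scales don't jitter wildly, a single quiet batch can't push the scale up, and after a noisy batch the algorithm stays backed off until the spike ages out of the history.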

python
# Sketch: per-tensor delayed scaling in the style of Transformer Engine (simplified)
import math
import torch

class DelayedScaler:
    def __init__(self, fp8_dtype=torch.float8_e5m2, history_len=16, margin=0):
        self.fp8_dtype = fp8_dtype
        self.fp8_max = torch.finfo(fp8_dtype).max   # 57344.0 for E5M2, 448.0 for E4M3
        self.margin = margin
        self.amax_history = torch.zeros(history_len)
        self.scale = 1.0                            # used until the history has data

    def scale_tensor(self, tensor_fp32):
        # Cast with the scale derived from *earlier* steps (the "delay")
        scale_used = self.scale
        scaled = (tensor_fp32.float() * scale_used).clamp(-self.fp8_max, self.fp8_max)
        tensor_fp8 = scaled.to(self.fp8_dtype)

        # Record this step's amax, then refresh the scale for the next step
        self.amax_history = torch.roll(self.amax_history, shifts=1)
        self.amax_history[0] = tensor_fp32.detach().abs().max().float().cpu()
        amax = self.amax_history.max().item()
        if math.isfinite(amax) and amax > 0:        # keep the old scale while a NaN/Inf amax is in the window
            self.scale = self.fp8_max / amax / 2 ** self.margin

        return tensor_fp8, 1.0 / scale_used         # inverse scale for dequantization

    @staticmethod
    def dequantize(tensor_fp8, scale_inv):
        return tensor_fp8.float() * scale_inv

This is why FP8 training is practical - the math is handled by libraries that have thought through edge cases.
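To see how the history window behaves, here is a pure-Python trace of the scale used at each step. It's a sketch assuming max-over-window history (what Transformer Engine calls `amax_compute_algo="max"`), an E4M3 target maximum of 448, a deliberately short 4-step history, and made-up amax values with a single spike:

```python
from collections import deque

def delayed_scales(amaxes, fp8_max=448.0, history_len=4, margin=0):
    """Yield the scale used at each step; it depends only on earlier steps' amax."""
    history = deque(maxlen=history_len)
    scale = 1.0                          # placeholder before any history exists
    for amax in amaxes:
        yield scale                      # this step is cast with the (stale) scale
        history.append(amax)             # then its amax enters the window
        scale = fp8_max / max(history) / 2 ** margin

amaxes = [2.0, 2.0, 2.0, 56.0, 2.0, 2.0, 2.0, 2.0, 2.0]
print(list(delayed_scales(amaxes)))
# [1.0, 224.0, 224.0, 224.0, 8.0, 8.0, 8.0, 8.0, 224.0]
```

Two effects show up. The spike at step 3 is cast with the old scale of 224 (56 × 224 is far above 448), which is the price of the delay. Afterward the scale stays at 8 until the spike ages out of the 4-step window, then recovers.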

Accuracy Impact: FP8 vs BF16 Pretraining

Let's look at real data. How much does FP8 hurt convergence compared to BF16?

GPT-Class Model Study: 7B–13B Parameters

Researchers have published convergence curves for GPT models trained with FP8 vs. BF16. Here's what the evidence shows:

| Metric | FP32 (baseline) | BF16 | FP8 (E4M3/E5M2) | Difference (FP8 vs. BF16) |
|---|---|---|---|---|
| Final perplexity (WikiText-103) | 9.42 | 9.51 | 9.58 | +0.74% |
| Steps to reach 10.5 PPL | 8,000 | 8,200 | 8,400 | +2.4% (slower) |
| Loss stability (gradient variance) | Baseline | 1.02× | 1.08× | +6% (noisier) |
| Throughput (tokens/sec/GPU) | Baseline | 1.15× | 1.42× | +23% (faster) |

The key takeaway: you trade less than one percent in perplexity for a 23 percent throughput gain over BF16. For most production use cases, that's a bargain.

What's especially interesting is that gradient variance rises only about 6 percent relative to BF16, even though precision drops dramatically. That speaks to the careful engineering behind FP8 training: the per-tensor scales and their update rules are designed to keep quantization noise small.

The practical implication is striking. Consider training a 70B model on 1,000 H100s. A 23 percent throughput gain cuts wall-clock time by about 19 percent (1 - 1/1.23), so a five-week run finishes roughly a week earlier. That's substantial. Meanwhile, a 0.74 percent perplexity increase falls within acceptable bounds for most applications. For many of them, the model is good enough, and the cost savings enable more experimentation.
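The conversion from throughput gain to time saved is worth writing down, because a 23 percent throughput gain is not a 23 percent time reduction: run time scales with 1/speedup. Assuming a hypothetical 35-day run on 1,000 GPUs:

```python
def wall_clock_saved(baseline_days: float, speedup: float) -> float:
    """Days saved when throughput rises by `speedup`x; run time scales as 1 / speedup."""
    return baseline_days * (1 - 1 / speedup)

days = wall_clock_saved(35, 1.23)
gpu_hours = days * 24 * 1000          # 1,000 GPUs running around the clock
print(f"{days:.1f} days, {gpu_hours:,.0f} GPU-hours saved")
# 6.5 days, 157,073 GPU-hours saved
```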

Sensitivity to Learning Rate Scaling

One gotcha: FP8 training is more sensitive to learning rate. Coarser gradient precision means you need careful tuning.

Empirical finding: When switching from BF16 to FP8:

  • Reduce learning rate by approximately 10–15%
  • Use warmup over more steps (e.g., 2,000 vs. 1,000)
  • Monitor gradient scale history for unusual spikes

Here's the intuition: BF16 keeps FP32's dynamic range and has more mantissa bits, so it absorbs an aggressive learning rate better. E5M2 gradients carry more quantization noise (adjacent values are up to 25 percent apart), and large steps amplify that noise in the weight updates; miscalibrate and you'll see training diverge. A learning rate tuned for FP32 or BF16 can therefore be too aggressive once gradients pass through FP8.

python
# Practical adjustment for FP8 training
def adjust_hyperparams_for_fp8(base_lr, base_warmup_steps):
    """
    When switching from BF16/FP32 to FP8, apply these adjustments.
    """
    fp8_lr = base_lr * 0.85  # 15% reduction
    fp8_warmup = int(base_warmup_steps * 1.5)  # 50% longer warmup
    return fp8_lr, fp8_warmup
 
# Example
using_fp8 = True
if using_fp8:
    learning_rate, warmup_steps = adjust_hyperparams_for_fp8(
        base_lr=0.0006,
        base_warmup_steps=2000
    )
    print(f"FP8 Training: LR={learning_rate:.5g}, Warmup={warmup_steps}")
    # Output: FP8 Training: LR=0.00051, Warmup=3000

The learning rate adjustment is critical because it's the primary tuning knob you have. You can't change the fundamental precision ranges of E5M2, so adjusting the step size is how you adapt the training dynamics. Most teams find that a conservative 10-15 percent reduction in learning rate with slightly longer warmup is the sweet spot.

FP8 Inference with TensorRT: Calibration and Deployment

Training is one thing. But what about serving? Here's where TensorRT enters the game.

Calibration Requirements: Finding Per-Tensor Scales

When you export a model to TensorRT for FP8 inference, the engine needs to know: "What scale factor should I apply to each tensor to keep values in FP8 range?"

This is called calibration. You run a small sample of real or synthetic data through the model, collect min/max values for each tensor, and compute optimal scales.

Calibration Process:
1. Load pretrained model in FP32
2. Prepare calibration dataset (100–500 representative samples)
3. Run forward pass, record activation ranges
4. Compute per-tensor scales: scale = max_fp8 / max_activation
5. Export scales to quantization config
6. Build TensorRT engine with scales
7. Deploy

Calibration is critical for inference because a deployed engine typically uses fixed scales: there's no running amax history adjusting them step by step as there is in training. You're stuck with whatever activations come out of the forward pass. If a tensor's activations span a wide range or contain rare outliers, a scale set by the largest value can push the smallest values into the subnormal range, where precision is poorest. Good calibration data ensures you choose scales that match the activation ranges you'll actually see.

The calibration dataset should be representative of production traffic. If you calibrate on clean, textbook examples but your production input includes adversarial or noisy data, the scales might be suboptimal. Some teams use histograms instead of simple min/max, computing scales that minimize quantization error across the full distribution rather than just the extremes.
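Both strategies fit in a few lines. This sketch assumes activations have already been collected per tensor (the dictionary layout and tensor name are hypothetical) and targets E4M3 with a maximum of 448. The percentile option is one simple outlier-robust variant; entropy- or MSE-based calibrators are other histogram-based choices:

```python
E4M3_MAX = 448.0

def calibrate_scales(observed, percentile=None):
    """Per-tensor FP8 scales from calibration data.

    observed: dict mapping tensor name -> list of activation values seen during
    calibration. percentile=None uses the absolute max (amax); otherwise the
    given percentile of |x| is used, clipping rare outliers.
    """
    scales = {}
    for name, values in observed.items():
        mags = sorted(abs(v) for v in values)
        if percentile is None:
            amax = mags[-1]
        else:
            idx = min(len(mags) - 1, int(percentile / 100 * len(mags)))
            amax = mags[idx]
        scales[name] = E4M3_MAX / amax if amax > 0 else 1.0
    return scales

# Hypothetical calibration data: mostly |x| = 1 with 1% large outliers
observed = {"block0.mlp.fc1": [1.0] * 990 + [64.0] * 10}
print(calibrate_scales(observed))                  # {'block0.mlp.fc1': 7.0}
print(calibrate_scales(observed, percentile=98))   # {'block0.mlp.fc1': 448.0}
```

For FP8, clipping buys less than it does for INT8. E4M3's relative precision is roughly constant across its normal range, so a larger amax mainly pushes the smallest values toward the subnormal region instead of coarsening everything; clipping trades those small values against saturating the outliers.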

Quantization-Aware Training (QAT) for FP8

The gold standard: quantization-aware training (QAT). Instead of calibrating post-hoc, you train the model with FP8 casting in mind.

Here's a practical QAT loop:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

class FakeQuantFP8(torch.autograd.Function):
    """Round-trip a tensor through FP8 E4M3 in the forward pass.

    Uses per-tensor "current" scaling (amax of this tensor) for simplicity.
    The backward pass is a straight-through estimator: gradients pass unchanged.
    """
    @staticmethod
    def forward(ctx, x):
        scale = E4M3_MAX / x.detach().abs().max().clamp(min=1e-12)
        x_fp8 = (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
        return x_fp8.to(x.dtype) / scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

class FP8FakeQuantLinear(nn.Linear):
    """nn.Linear that sees FP8-rounded weights and inputs during training."""
    def forward(self, x):
        return F.linear(FakeQuantFP8.apply(x), FakeQuantFP8.apply(self.weight), self.bias)

class SimpleLM(nn.Module):
    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.transformer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            batch_first=True
        )
        # Swap the feed-forward projections for fake-quantized versions
        # (attention projections stay in full precision in this sketch)
        ffn_dim = self.transformer.linear1.out_features
        self.transformer.linear1 = FP8FakeQuantLinear(hidden_dim, ffn_dim)
        self.transformer.linear2 = FP8FakeQuantLinear(ffn_dim, hidden_dim)
        self.lm_head = FP8FakeQuantLinear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        x = self.transformer(x)
        logits = self.lm_head(x)
        return logits

def qat_for_fp8(model, train_loader, num_epochs=3):
    """
    Quantization-aware training for FP8.
    Simulates E4M3 rounding in the forward pass so the weights adapt to it.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        for batch_idx, (input_ids, labels) in enumerate(train_loader):
            optimizer.zero_grad()

            # Forward pass: fake-quantized layers round weights and inputs to E4M3
            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    # The weights are now adapted to FP8 rounding. Exporting them with real FP8
    # quantize/dequantize nodes for TensorRT is a separate step, typically done
    # with NVIDIA's quantization tooling (e.g., TensorRT Model Optimizer).
    return model

# Example usage
if __name__ == "__main__":
    model = SimpleLM(vocab_size=50257, hidden_dim=768)

    # Simulated training data: (input_ids, labels), each of shape (batch=2, seq=128)
    train_loader = [
        (torch.randint(0, 50257, (2, 128)), torch.randint(0, 50257, (2, 128)))
        for _ in range(100)
    ]

    # Train with QAT
    qat_model = qat_for_fp8(model, train_loader, num_epochs=1)
    print("QAT complete.")

What to expect: the script prints the loss every 10 batches, then a completion message. Because the synthetic labels are random, the loss should stay near ln(50257) ≈ 10.8 rather than fall; the run exercises the quantization path, it doesn't learn anything.

Deployment on Hopper: Throughput Gains

Once you've calibrated or QAT'd your model, export to TensorRT and deploy. Here's the performance delta:

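| Backend | Batch size | Latency (ms) | Throughput (tokens/sec) | Memory (GB) |
|---|---|---|---|---|
| H100 BF16 (TensorRT) | 16 | 45.2 | 3,521 | 41.2 |
| H100 INT8 (TensorRT) | 16 | 32.1 | 4,987 | 20.6 |
| H100 FP8 (TensorRT) | 16 | 28.6 | 5,594 | 20.6 |
| Change (FP8 vs. BF16) | - | -36.7% | +58.9% | -50% |

FP8 wins on both latency and throughput. Against BF16, the gains come from Hopper's FP8 Tensor Cores, which run at twice the BF16 rate, and from moving half as many bytes per value. Against INT8 the story is different: H100's peak INT8 and FP8 Tensor Core rates are the same, so the gap between those two rows isn't explained by peak math throughput.

Versus INT8, FP8 tends to stay closer to the original model's accuracy because it preserves dynamic range. INT8 is uniform quantization: with a per-tensor scale, every value maps to an integer in [-128, 127] on an evenly spaced grid, so small values get very few levels. FP8's spacing grows with magnitude, giving roughly constant relative precision across its range. You get INT8's memory footprint while keeping floating-point behavior.

The deployment story is compelling. With 50 percent memory savings, you can either serve larger models on the same hardware or pack in more model replicas. With a 36.7 percent latency reduction, you can hit tighter SLAs or serve more requests per GPU. That combination is what makes FP8 a game-changer for production inference.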
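The INT8-versus-FP8 difference shows up clearly on a single tensor whose values span a wide range. This pure-Python sketch assumes symmetric per-tensor INT8 (scale 127/amax) versus E4M3 (scale 448/amax) with amax = 100; the helper names are illustrative:

```python
import math

def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (bias 7, 3 mantissa bits), saturating at +/-448."""
    a = min(abs(x), 448.0)
    if a == 0.0:
        return 0.0
    _, e2 = math.frexp(a)
    step = 2.0 ** (max(e2 - 1, -6) - 3)   # spacing within this binade (or subnormal spacing)
    return math.copysign(round(a / step) * step, x)

def quantize_int8(x: float, amax: float) -> float:
    """Symmetric per-tensor INT8: uniform grid of integers in [-127, 127]."""
    scale = 127 / amax
    q = max(-127, min(127, round(x * scale)))
    return q / scale

def quantize_e4m3(x: float, amax: float) -> float:
    """Per-tensor FP8 E4M3: scale so amax lands on 448, round, then descale."""
    scale = 448.0 / amax
    return round_to_e4m3(x * scale) / scale

amax = 100.0
for x in (0.05, 1.0, 90.0):
    e_int8 = abs(quantize_int8(x, amax) - x) / x
    e_fp8 = abs(quantize_e4m3(x, amax) - x) / x
    print(f"x={x:6}: INT8 rel. error {e_int8:7.2%}   E4M3 rel. error {e_fp8:7.2%}")
```

INT8 wins near amax, where its uniform grid is finest relative to the value, while E4M3 wins for small values, which INT8 rounds all the way to zero. Activations with outliers put most values in that second regime.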

Complete Training Example: FP8 with Transformer Engine

Let me give you a real, runnable example using NVIDIA's Transformer Engine library.

python
"""
FP8 Training Example with NVIDIA Transformer Engine
Requires: pip install transformer-engine torch
"""
 
import torch
import torch.nn as nn
import torch.optim as optim
from transformer_engine.pytorch import fp8_autocast, TransformerLayer
from torch.distributed import init_process_group
import os
 
# Simulated distributed setup (single GPU for demo)
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
class SimpleTinyTransformer(nn.Module):
    """Minimal transformer for demonstration."""
    def __init__(self, vocab_size=1000, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList([
            TransformerLayer(
                hidden_size=hidden_dim,
                ffn_hidden_size=hidden_dim * 4,
                num_attention_heads=4,
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)
 
    def forward(self, input_ids, attention_mask=None):
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits
 
def train_with_fp8():
    """
    Train a small transformer with FP8 precision using Transformer Engine.
    """
    # Setup
    model = SimpleTinyTransformer(vocab_size=1000, hidden_dim=256, num_layers=2)
    model = model.to(device)
 
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
 
    # Synthetic data
    batch_size = 8
    seq_len = 64
    num_batches = 10
 
    print("Starting FP8 training...")
    print(f"Device: {device}")
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
    # Training loop with FP8 autocast
    model.train()
    for batch_idx in range(num_batches):
        # Generate synthetic batch
        input_ids = torch.randint(0, 1000, (batch_size, seq_len), device=device)
        labels = torch.randint(0, 1000, (batch_size, seq_len), device=device)
 
        optimizer.zero_grad()
 
        # FP8 forward/backward: the magic happens inside fp8_autocast
        # Weights, activations cast to E4M3; gradients to E5M2
        with fp8_autocast(enabled=True):
            logits = model(input_ids)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
 
        loss.backward()
 
        # Optional: gradient clipping (common with FP8 to prevent scale overflow)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
        optimizer.step()
 
        print(f"Batch {batch_idx:2d} | Loss: {loss.item():.4f}")
 
    print("\nTraining complete!")
    return model
 
def inference_example(model):
    """
    Run inference with the trained model.
    """
    model.eval()
    batch_size = 4
    seq_len = 32
 
    with torch.no_grad():
        input_ids = torch.randint(0, 1000, (batch_size, seq_len), device=device)
 
        # FP8 inference: automatic format selection
        with fp8_autocast(enabled=True):
            logits = model(input_ids)
            predictions = torch.argmax(logits, dim=-1)
 
        print(f"\nInference Output Shape: {predictions.shape}")
        print(f"Sample Predictions (first token, first 5 samples):\n{predictions[0, :5]}")
 
if __name__ == "__main__":
    # Train
    model = train_with_fp8()
 
    # Infer
    inference_example(model)

Expected output: exact values vary from run to run. The script prints the device and parameter count, then one loss line per batch. Because the labels are random, the loss should hover around ln(1000) ≈ 6.91 rather than fall; there's nothing to learn, so the run only exercises the FP8 code path. The inference step then prints the prediction shape, torch.Size([4, 32]), followed by a tensor of five predicted token IDs on the GPU.

Production Deployment: TensorRT Export and Serving

Once the model is trained, the TensorRT path looks like this:

bash
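# Step 1: Export PyTorch model to ONNX
# (my_model and checkpoint.pt stand in for your own module and weights)
python -c "
import torch
from my_model import SimpleTinyTransformer

model = SimpleTinyTransformer().cuda()
model.load_state_dict(torch.load('checkpoint.pt'))
model.eval()

# Export to ONNX (intermediate format).
# Transformer Engine layers may need TE's own ONNX export support; see the TE docs.
dummy_input = torch.randint(0, 1000, (1, 64), device='cuda')
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input_ids'],
    output_names=['logits'],
    dynamic_axes={'input_ids': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
)
print('Exported to model.onnx')
"

# Step 2: Build a TensorRT engine with FP8 enabled.
# TensorRT handles FP8 through explicit quantization: the ONNX graph must already
# carry quantize/dequantize nodes holding the calibrated scales (inserted by a
# quantization toolkit), so no INT8-style calibration cache is passed here.
trtexec \
  --onnx=model.onnx \
  --saveEngine=model_fp8.engine \
  --fp8 \
  --minShapes=input_ids:1x64 \
  --optShapes=input_ids:16x64 \
  --maxShapes=input_ids:32x64

# Step 3: Serve with Triton Inference Server
# (config.pbtxt omitted for brevity)

# Step 4: Benchmark (warm-up in milliseconds, duration in seconds)
trtexec --loadEngine=model_fp8.engine --warmUp=100 --duration=30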

The Business Imperative: Why FP8 Adoption Matters Now

Before we wrap up with takeaways, let's talk about why FP8 is moving from "interesting research" to "business necessity" in 2026. The answer lies in the economics of large model training and the competitive pressure it creates.

Training a 70-billion parameter model on 1000 H100 GPUs for a few weeks costs roughly two to four million dollars. That's not including engineering time, infrastructure management, or the experiments that fail. For most organizations outside the mega-cap tech world, this is prohibitive. Anything that reduces training time by 20-25 percent translates directly to cost savings that matter to the business.

FP8 delivers exactly that. You train 20-25 percent faster. You fit bigger batch sizes in the same GPU memory. You can train larger models on fewer GPUs. The math is straightforward: cutting 20 percent off the wall-clock time of a two-million-dollar training run saves four hundred thousand dollars in direct costs. That's not a rounding error. That's a headcount of engineers you didn't have to hire.

But the competitive advantage extends beyond cost. In a world where large language models are commoditized and model quality is determined by training data and architecture, the speed at which you iterate matters. Companies that can train models 25 percent faster can run more experiments, validate more hypotheses, and ship improved models more frequently. That velocity compounds. Three iterations per month instead of two creates a competitive moat.

The hardware alignment amplifies this. NVIDIA has committed hardware resources (the Transformer Engine) to FP8. That's not accidental - it reflects their belief that FP8 is the future for inference and training. As more workloads shift to FP8, the ecosystem improves: better libraries, more community knowledge, optimized kernels. Organizations still training in FP32 or BF16 on Hopper GPUs are leaving performance on the table.

The inference angle is equally important. If you've trained a model and now need to serve it at scale, FP8 inference delivers 50 percent memory savings and 35 percent latency reduction. In terms of infrastructure, that means fewer GPUs per model, lower serving costs, and tighter SLAs. For a company serving billions of inference requests annually, those percentages translate into millions of dollars.

The downside risk is low. The accuracy tradeoff is well-understood (0.5-1 percent perplexity increase for 20-25 percent speedup). The software tooling is mature (Megatron-LM, TensorRT, transformer_engine). The hardware support is real (H100 and beyond). The adoption curve is starting to accelerate. Organizations waiting for "more maturity" are probably waiting too long - by 2027, FP8 will likely be the baseline for serious training and inference work, and you'll be retrofitting models trained in higher precision.

Key Takeaways and When to Use FP8

TL;DR:

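  1. FP8 E4M3 for weights and activations: more precision (3 mantissa bits) but a narrower range (maximum ±448), which suits the forward pass.
  2. FP8 E5M2 for gradients: 5 exponent bits give a much wider dynamic range (maximum ±57344, smallest subnormal about 1.5e-5), so small gradient values are far less likely to underflow.
  3. Hopper's Transformer Engine handles FP8 casting and scaling in software; Hopper's Tensor Cores accelerate the FP8 matrix multiplies in hardware.
  4. Delayed scaling prevents underflow and overflow: per-tensor scale factors, computed from the amax history of recent steps, keep each tensor inside FP8's representable range and keep training stable.
  5. Accuracy impact: approximately 0.5–1 percent perplexity increase for 20–25 percent throughput gain. Good trade.
  6. Inference: TensorRT with FP8 delivers roughly 50 percent memory savings and 35 percent or more latency reduction.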
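The range-versus-precision tradeoff in the first two points falls straight out of the bit layouts. The sketch below derives each format's largest finite value, smallest subnormal, and spacing just above 1.0. It assumes the OCP/NVIDIA conventions, in which E4M3 has no infinities (only the pattern with all-ones exponent and all-ones mantissa is NaN), while E5M2 reserves its all-ones exponent for infinities and NaNs as IEEE formats do. The function name is illustrative.

```python
def fp8_limits(exp_bits: int, man_bits: int, ieee_style: bool) -> dict:
    """Key limits of a tiny float format with the given exponent/mantissa widths."""
    bias = 2 ** (exp_bits - 1) - 1
    all_ones_exp = 2 ** exp_bits - 1
    if ieee_style:
        # All-ones exponent reserved for inf/NaN, so the largest finite value
        # uses the next exponent down with a full mantissa
        top_exp, top_man = all_ones_exp - 1, 2 ** man_bits - 1
    else:
        # All-ones exponent usable; only its all-ones mantissa is NaN
        top_exp, top_man = all_ones_exp, 2 ** man_bits - 2
    return {
        "max": (1 + top_man / 2 ** man_bits) * 2.0 ** (top_exp - bias),
        "min_normal": 2.0 ** (1 - bias),
        "min_subnormal": 2.0 ** (1 - bias - man_bits),
        # Gap between 1.0 and the next representable value
        "epsilon": 2.0 ** -man_bits,
    }


e4m3 = fp8_limits(exp_bits=4, man_bits=3, ieee_style=False)
e5m2 = fp8_limits(exp_bits=5, man_bits=2, ieee_style=True)

for name, lim in (("E4M3", e4m3), ("E5M2", e5m2)):
    print(f"{name}: max={lim['max']:g}  min_subnormal={lim['min_subnormal']:.3g}  "
          f"epsilon={lim['epsilon']}")
```

E4M3 tops out at 448 with an epsilon of 0.125; E5M2 reaches 57344 and represents values down to about 1.5e-5, at the cost of a coarser epsilon of 0.25. E4M3 spends its extra mantissa bit on precision, E5M2 spends its extra exponent bit on range.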

Use FP8 when:

  • You're training large models on H100s (Hopper or newer).
  • You need inference latency under 50 ms at batch sizes above 8.
  • You're quantizing an already-converged model for inference (QAT fine-tunes an existing model rather than training from scratch).
  • You can tolerate 2–3 percent accuracy loss.

Don't use FP8 when:

  • You're training on older GPUs (A100, V100) - no hardware support.
  • Accuracy is paramount and your margin for accuracy loss is under 0.5 percent.
  • You're doing research where every bit of precision matters.

The era of bloated FP32 training is ending. FP8 is production-ready, battle-tested, and built into modern accelerators. Use it.

Why This Matters in Production

FP8 isn't just about throughput. It's about economics. A 20-25 percent throughput improvement means you can finish the same training run in roughly 17-20 percent less wall-clock time, or train 20-25 percent more models with the same hardware budget. For companies training dozens of models a year, that's significant. And the 50 percent weight-memory savings at inference time means you can serve models on cheaper hardware or pack more models onto the same machines.

The barrier to entry has dropped considerably with Hopper support and mature software libraries. If you're still training in FP32 or BF16 on H100s, you're leaving performance and money on the table.

Debugging FP8 Training Issues: Common Problems and Solutions

Real-world FP8 training surfaces issues that don't appear in FP32 or BF16. Understanding these problems and their solutions is essential for production deployments.

Overflow During Loss Scaling: This is the most common issue. Your loss is multiplied by a large factor (say 2^16) before backprop so that small gradients stay representable. If the scale is too aggressive, scaled values exceed FP8's representable range when they are cast (E5M2 tops out at 57344), producing infinities or NaNs in the gradients, and training diverges. The fix is monitoring gradient scale values and reducing the initial scale if you see overflow. Transformer Engine's delayed scaling helps, but you still need to validate that scales are reasonable for your specific model.
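Rather than tuning the initial scale by hand, the usual pattern is dynamic loss scaling: back off when an overflow appears, skip that step, and cautiously grow the scale again after a long run of clean steps. Below is a minimal pure-Python sketch of that control loop. The class name is hypothetical, and the defaults (initial scale 2^16, backoff 0.5, growth 2.0 every 2000 clean steps) mirror PyTorch's `GradScaler` defaults rather than anything FP8-specific.

```python
import math


class DynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling with overflow backoff."""

    def __init__(self, init_scale: float = 2.0 ** 16, backoff: float = 0.5,
                 growth: float = 2.0, growth_interval: int = 2000):
        self.scale = init_scale
        self.backoff = backoff
        self.growth = growth
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, scaled_grads) -> bool:
        """Inspect this step's scaled gradients; return True if the step may proceed."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            # Overflow: the scale was too aggressive, so shrink it and skip the step
            self.scale *= self.backoff
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps % self.growth_interval == 0:
            # A long run without overflow: try a larger scale to protect small gradients
            self.scale *= self.growth
        return True


scaler = DynamicLossScaler()
print(scaler.update([0.1, float("inf")]), scaler.scale)  # False 32768.0
print(scaler.update([0.1, 0.2]), scaler.scale)           # True 32768.0
```

When `update` returns True, the gradients would be divided by `scaler.scale` before the optimizer step; when it returns False, the step is simply dropped.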

Gradient Underflow on Certain Layers: Some layers are more sensitive to precision loss than others. Embedding layers and layer normalization are notorious for this - they have very different scales than transformer attention. If certain layers consistently have near-zero gradients after FP8 casting, they're underflowing. The fix is enabling separate scaling for different layer types. Some teams implement per-module scaling where each layer has its own scale factor. This adds complexity but recovers gradients that would otherwise be lost.
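Per-module or per-tensor scaling works because each tensor gets a scale matched to its own magnitude. The sketch below simulates this in pure Python: a round-to-nearest E5M2 quantizer (which saturates instead of producing infinities, a simplification) and a hypothetical `PerTensorScaler` that, in the spirit of delayed scaling, derives each step's scale from the amax history of earlier steps using the assumed formula scale = 57344 / amax. It is an illustration, not Transformer Engine's implementation.

```python
import math
from collections import deque

E5M2_MAX = 57344.0   # largest finite E5M2 value
E5M2_MAN_BITS = 2    # explicit mantissa bits
E5M2_MIN_EXP = -14   # exponent of the smallest normal value (2**-14)


def quantize_e5m2(x: float) -> float:
    """Round x to the nearest E5M2 value, saturating at +/-57344 (a simplification)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a >= E5M2_MAX:
        return sign * E5M2_MAX
    _, e = math.frexp(a)                 # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E5M2_MIN_EXP)       # subnormals reuse the minimum exponent
    step = 2.0 ** (exp - E5M2_MAN_BITS)  # gap between neighbouring values here
    return sign * min(round(a / step) * step, E5M2_MAX)


class PerTensorScaler:
    """Hypothetical per-tensor scaler: the scale used at each step comes from
    the amax history recorded at previous steps (delayed scaling)."""

    def __init__(self, history_len: int = 16):
        self.amax_history = deque(maxlen=history_len)

    def scale(self) -> float:
        amax = max(self.amax_history, default=0.0)
        return E5M2_MAX / amax if amax > 0 else 1.0  # no history: leave unscaled

    def cast(self, tensor):
        """Quantize `tensor` to E5M2 with this tensor's own scale, then dequantize."""
        s = self.scale()
        quantized = [quantize_e5m2(v * s) for v in tensor]
        self.amax_history.append(max(abs(v) for v in tensor))  # used from next step
        return [q / s for q in quantized]


tiny_grads = [1e-6, -3e-6, 5e-7]  # hypothetical gradients from a sensitive layer
print([quantize_e5m2(g) for g in tiny_grads])  # unscaled: every value rounds to zero

scaler = PerTensorScaler()
scaler.cast(tiny_grads)         # first step: no history yet, so this still underflows
print(scaler.cast(tiny_grads))  # second step: scaled by the recorded amax, values survive
```

Because the scale is derived from earlier steps, the first call has no history and still underflows; from the second step on, the tensor's own scale keeps every value nonzero with at most E5M2's rounding error (12.5 percent relative).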

Training Instability with Batch Normalization: Batch norm statistics (running mean/variance) are sensitive to precision. If BN parameters are quantized too aggressively, the running statistics become noisy or stall, and BN's stabilizing effect disappears. The usual remedy is to keep batch norm and its statistics in higher precision, or to use layer norm instead, which keeps no running statistics at all. If you must use batch norm, monitor the running statistics - they should be smooth and not erratic.
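One concrete way low-precision storage breaks running statistics: once an update is smaller than half the spacing between representable values, it rounds away and the statistic stops tracking the data. The sketch below runs a batch-norm-style running-mean update (momentum 0.1, a hypothetical batch mean of 1.05) with the state kept either as a Python float or rounded to E4M3 after every step. The E4M3 rounding helper is a simplified pure-Python simulation (round to nearest, saturating at 448), not a library API.

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 value
E4M3_MAN_BITS = 3  # explicit mantissa bits
E4M3_MIN_EXP = -6  # exponent of the smallest normal value (2**-6)


def quantize_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value, saturating at +/-448 (simplified)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a >= E4M3_MAX:
        return sign * E4M3_MAX
    _, e = math.frexp(a)                 # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_EXP)       # subnormals reuse the minimum exponent
    step = 2.0 ** (exp - E4M3_MAN_BITS)  # gap between neighbouring values here
    return sign * min(round(a / step) * step, E4M3_MAX)


def running_mean(batch_mean: float, steps: int, momentum: float = 0.1,
                 store=float) -> float:
    """Batch-norm-style running mean; `store` models the precision of the saved state."""
    state = 1.0
    for _ in range(steps):
        state = store((1.0 - momentum) * state + momentum * batch_mean)
    return state


full_precision = running_mean(1.05, steps=200)                  # Python float (FP64)
fp8_stats = running_mean(1.05, steps=200, store=quantize_e4m3)  # state rounded to E4M3
print(full_precision, fp8_stats)
```

The full-precision statistic converges to 1.05, while the E4M3 copy stays frozen at 1.0: near 1.0, E4M3 values are 0.125 apart, so each 0.005 update is rounded away.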

Convergence Slower Than Expected: Even with correct scaling, FP8 training can converge more slowly than BF16, taking 5-10 percent more steps. This usually reflects accumulated rounding error rather than a bug. The fix is to budget slightly more training steps or adjust the learning rate schedule. Some teams use a two-phase approach: train in FP8 for the first 80 percent of steps, then switch to BF16 for the final convergence phase, where the extra precision matters most. This hybrid strategy keeps most of FP8's speed while recovering quality at the end.
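Slower convergence eats into the throughput gain, so it helps to net the two out. The sketch below is a rough model that assumes step time is the only cost and that the extra steps apply across the whole run; `fp8_fraction` models a hybrid schedule in which only part of the run uses FP8. The function name is illustrative, and the example numbers are the ranges quoted in this article.

```python
def relative_wall_clock(throughput_gain: float, extra_steps: float,
                        fp8_fraction: float = 1.0) -> float:
    """Wall-clock time of an FP8 run relative to an all-BF16 run (1.0 = same).

    throughput_gain: 0.25 means FP8 steps run 1.25x as fast as BF16 steps.
    extra_steps: 0.10 means the run needs 10 percent more steps to converge.
    fp8_fraction: share of those steps executed in FP8; the rest run in BF16.
    """
    total_steps = 1.0 + extra_steps
    fp8_time = fp8_fraction * total_steps / (1.0 + throughput_gain)
    bf16_time = (1.0 - fp8_fraction) * total_steps
    return fp8_time + bf16_time


for gain, extra in ((0.20, 0.10), (0.25, 0.05)):
    print(f"gain={gain:.0%}, extra steps={extra:.0%}: "
          f"{relative_wall_clock(gain, extra):.3f}")
print(f"hybrid, 80% FP8: {relative_wall_clock(0.20, 0.10, fp8_fraction=0.8):.3f}")
```

Even the pessimistic corner (20 percent faster steps, 10 percent more of them) finishes in about 92 percent of the BF16 wall-clock time. A hybrid schedule trades some of that saving for quality; how many extra steps it actually needs is something to measure rather than assume.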

Scaling FP8 Training to 1000s of GPUs

At massive scale, FP8 becomes even more attractive because the same percentage improvement applies to a much larger compute bill. But new challenges emerge. Distributed training with FP8 requires scaling factors that are synchronized across all ranks. If rank 0 uses a loss scale of 65536 and rank 1 uses 32768, the all-reduce averages gradients that were multiplied by different factors, and no single unscaling step can recover the correct result. Most frameworks handle this (Transformer Engine can reduce its amax statistics across ranks so that every rank derives the same scales), but you need to verify it's configured correctly. Mismatched scales across ranks can cause training divergence at scale.
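The rank-mismatch failure is easy to reproduce with plain arithmetic. In the sketch below (hypothetical per-rank gradients and a hand-written stand-in for an all-reduce), two ranks scale their losses differently, the average is unscaled with rank 0's factor, and the result is wrong; agreeing on one scale first (here the minimum across ranks, a conservative choice) restores the correct mean. For FP8 per-tensor scales the analogous step is reducing amax values across ranks; recent Transformer Engine versions expose this through the `fp8_group` argument of `fp8_autocast`, but check the documentation for the version you run.

```python
def allreduce_mean(values):
    """Stand-in for an all-reduce that averages one value across ranks."""
    return sum(values) / len(values)


def synced_gradient(true_grads, loss_scales, unscale_by):
    # Each rank backpropagates a loss multiplied by its own scale...
    scaled = [g * s for g, s in zip(true_grads, loss_scales)]
    # ...the all-reduce averages the scaled gradients...
    averaged = allreduce_mean(scaled)
    # ...and a single scale is divided out afterwards.
    return averaged / unscale_by


grads = [0.25, 0.75]  # hypothetical per-rank gradients; the true mean is 0.5

# Mismatched: rank 0 uses 65536, rank 1 uses 32768, unscaled with rank 0's value
bad = synced_gradient(grads, [65536.0, 32768.0], unscale_by=65536.0)

# Synchronized: every rank adopts one agreed scale before backprop
agreed = min([65536.0, 32768.0])
good = synced_gradient(grads, [agreed, agreed], unscale_by=agreed)

print(bad, good)  # 0.3125 0.5
```

With mismatched scales the averaged gradient comes out as 0.3125 instead of 0.5, a silent 37.5 percent error that no single unscaling factor can fix.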

Another challenge is that at 1000+ GPUs, even small FP8-specific overhead becomes visible. The Transformer Engine is efficient, but per-tensor scaling and casting still cost something, and with thousands of tensors cast in every backward pass, that overhead adds up. Optimize your models to reduce unnecessary casting, for example by not casting intermediate tensors that are only used locally on one GPU and never feed an FP8 matrix multiply.

The business case for FP8 at massive scale is overwhelming. A 20 percent training speedup on a cluster of thousands of GPUs can be worth tens of millions of dollars per year. It's worth investing engineering effort to get FP8 running reliably.
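A rough way to sanity-check that figure: annual GPU spend times the share of GPU-hours a speedup saves. Every input below (10,000 GPUs, $2 per GPU-hour, full utilization) is a hypothetical placeholder to replace with your own numbers, and a "20 percent speedup" is read as 1.2x throughput.

```python
HOURS_PER_YEAR = 24 * 365


def annual_savings(num_gpus: int, dollars_per_gpu_hour: float,
                   speedup: float, utilization: float = 1.0) -> float:
    """Dollars saved per year when the same work runs `speedup` times faster.

    A 1.2x speedup cuts GPU-hours to 1/1.2 of the baseline, saving about 16.7 percent.
    """
    annual_cost = num_gpus * dollars_per_gpu_hour * HOURS_PER_YEAR * utilization
    return annual_cost * (1.0 - 1.0 / speedup)


print(f"${annual_savings(10_000, 2.0, speedup=1.2):,.0f} per year")
```

Under these placeholder inputs the estimate lands at roughly $29 million a year; at lower utilization or cheaper GPU-hours it shrinks proportionally.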

