August 1, 2025
AI/ML Infrastructure · Inference Batching · Model Serving

Batch Inference at Scale: Processing Millions of Records

You've trained a killer machine learning model. It works great on a few samples, but now you need to run predictions across millions of records in your data warehouse. That's when batch inference becomes your best friend - and your biggest architectural challenge.

Batch inference isn't just about running predictions in bulk. It's about designing systems that process massive datasets efficiently while maintaining reliability, minimizing costs, and actually finishing in a reasonable timeframe. Whether you're scoring customer records for risk assessment, generating embeddings for your entire product catalog, or running daily recommendation updates, getting batch inference right separates teams that move fast from teams buried under compute bills.

Let's dig into how to architect batch inference systems that actually scale.

Table of Contents
  1. Why Batch Inference Matters in Production
  2. The Batch Inference Problem
  3. Choosing Your Pattern: A Decision Framework
  4. Architecture Patterns for Large-Scale Batch Inference
  5. Pattern 1: Spark MLlib Map-Reduce
  6. Pattern 2: Ray Data Distributed Inference
  7. Pattern 3: Argo Workflows Parallelism
  8. Why Throughput Matters: The Economics of Scale
  9. Throughput Optimization: Squeezing Every Ounce
  10. Dynamic Batching
  11. Async Preprocessing
  12. Model Compilation for Batch Workloads
  13. Ray Data for ML Inference: Practical Deep Dive
  14. The Full Pipeline
  15. GPU Actor Pool Configuration
  16. Fault Tolerance with Task Retry
  17. Production Considerations: Beyond Just Processing
  18. Output Management at Scale
  19. Partitioned Parquet with Timestamp Prefixes
  20. Schema Evolution Handling
  21. Idempotent Writes with Delta Lake
  22. Downstream Notification
  23. Monitoring and Visibility: Your Window into Production
  24. Scheduling and Monitoring
  25. Airflow BatchSensor for Triggering Jobs
  26. Progress Tracking via Prometheus Metrics
  27. Cost-Per-Record Reporting
  28. Architecture Diagram: Full Stack
  29. Putting It All Together: Complete Example
  30. The Economics of Batch Inference at Scale
  31. Real-World Lessons: What We Learned the Hard Way
  32. Scaling Strategies: From Single Machine to Distributed Clusters
  33. Observability and Cost Control in Production
  34. Key Takeaways

Why Batch Inference Matters in Production

Before we jump into architecture, let's understand why batch inference deserves serious attention. When you move from prototype to production, the economic and operational realities change dramatically. A model that predicts beautifully on a single GPU suddenly needs to process millions of records cost-effectively, on schedule, and without melting your infrastructure.

Batch inference appears simple on the surface: load data, apply model, save results. But simplicity is deceptive. Real production systems face a constellation of challenges that naive implementations crumble under. Your dataset might not fit in memory. Your model might need GPU acceleration that's expensive to parallelize. You might run the same batch multiple times (due to failures or late-arriving data), and you can't afford duplicates. Network latency becomes a killer when transferring billions of data points. And somewhere in the middle of processing 100 million records, the job crashes - and you need to restart without redoing work you've already done.

The winners build infrastructure that anticipates these challenges from day one. They separate compute concerns (parallelization, GPU management) from data concerns (ingestion, output organization). They instrument everything so they can detect anomalies early. They design for idempotence so re-runs don't corrupt results. And they never, ever assume their first implementation will be fast enough.

The Batch Inference Problem

Here's the challenge: your model runs fast on individual samples, but running one prediction at a time across 100 million records means you're leaving your GPU idle between calls. You'll spend most of your time on overhead, not actual computation. Meanwhile, your cloud bill keeps climbing, and your batch job won't finish until next Tuesday.

Traditional approaches break down quickly. A simple Python loop? Forget it - you're looking at weeks of compute time. Moving data over the network row-by-row? Disaster. No parallelization? Your single machine becomes the entire bottleneck.
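The overhead gap is easy to demonstrate even without a GPU: scoring one record at a time pays interpreter and dispatch overhead per record, while one vectorized call pays it once. A minimal NumPy sketch (the 128-feature linear "model" here is an illustrative stand-in, not the article's model):

```python
import time
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20_000, 128)).astype(np.float32)  # 20k rows, 128 features
W = rng.standard_normal((128, 1)).astype(np.float32)       # toy linear "model"

# Per-row scoring: one tiny matmul per record, overhead paid 20,000 times
start = time.perf_counter()
row_scores = np.concatenate([x[None, :] @ W for x in X])
per_row = time.perf_counter() - start

# Batched scoring: one large matmul, overhead paid once
start = time.perf_counter()
batch_scores = X @ W
batched = time.perf_counter() - start

assert np.allclose(row_scores, batch_scores, atol=1e-4)
print(f"per-row: {per_row:.3f}s  batched: {batched:.3f}s")
```

On a GPU the gap widens further, because kernel-launch latency and host-device transfer dominate tiny batches.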

What you need is a system that:

  • Maximizes hardware utilization through batching and parallelism
  • Scales horizontally across multiple machines and GPUs
  • Tolerates failures without losing progress or duplicating work
  • Manages memory efficiently when datasets exceed available RAM
  • Produces organized output that downstream systems can consume
  • Provides visibility into what's happening and what it costs

Choosing Your Pattern: A Decision Framework

Before we dive into patterns, understand that there's no universal winner. Each pattern evolved for different problems and operating constraints. Spark emerged from the big data ecosystem where SQL integration was paramount. Ray arrived later, designed natively for ML with GPU scheduling baked in from the start. Argo represents the Kubernetes-native approach: maximum control, maximum verbosity.

The smart choice depends on your current infrastructure, team expertise, and specific constraints. If you're already running a data warehouse on Spark, moving away from it often costs more than working within its limitations. If you're a Python shop comfortable with distributed systems, Ray might feel more natural. If you're purely Kubernetes-based and need explicit control, Argo's DAG model gives you that.

The worst mistake is choosing based on hype or simplicity alone. Each pattern has operational costs that reveal themselves after deployment. Spark's JVM overhead bites you on memory-constrained clusters. Ray's memory management surprises you at scale. Argo's manual coordination creates maintenance burdens.

This section walks you through the trade-offs so you can make an informed choice for your specific constraints.

Architecture Patterns for Large-Scale Batch Inference

There are three proven patterns for handling batch inference at scale. Each solves different problems and has distinct trade-offs.

Pattern 1: Spark MLlib Map-Reduce

Spark's approach is battle-tested and operationally mature. You read data into RDDs or DataFrames, broadcast your model, and map predictions across partitions. Spark handles the distribution, shuffling, and failure recovery transparently.

python
from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel
 
spark = SparkSession.builder.appName("batch-inference").getOrCreate()
 
# Read your dataset
df = spark.read.parquet("s3://data-lake/customer_features/2024/")
print(f"Dataset shape: {df.count()} rows, {len(df.columns)} columns")
 
# Load pre-trained model
model = PipelineModel.load("/models/customer_model/v2")
 
# Map predictions across partitions
predictions = model.transform(df)
 
# Write output with partitioning
predictions.write \
    .partitionBy("date", "region") \
    .mode("overwrite") \
    .parquet("s3://outputs/batch-scores/2024-02-27/")
 
print("Batch inference complete. Output written to S3.")

Strengths:

  • Native DataFrame API integrates smoothly with data pipelines
  • Spark SQL for filtering and post-processing
  • Automatic failure recovery and task rescheduling
  • Massive ecosystem of tools (Delta Lake, Iceberg, Unity Catalog)
  • Strong partition pruning for selective re-runs

Weaknesses:

  • Spark has significant overhead for small jobs
  • GPU support is awkward; PyTorch/TensorFlow models require custom serialization
  • JVM memory overhead can be substantial on large clusters
  • Debugging distributed failures is painful

When to use: You're already in the Spark ecosystem (data warehouse uses Spark) and your models are in scikit-learn or standard MLlib format, or you're comfortable with serialization workarounds.

Pattern 2: Ray Data Distributed Inference

Ray Data is built specifically for this problem. It's a DataFrames library optimized for ML workloads, with native support for distributed Python, PyTorch, and TensorFlow.

python
import ray
import torch
from ray.data import read_parquet
 
# Initialize Ray cluster (local dev, on-prem, or cloud)
ray.init()
 
# Read parquet directly into Ray
dataset = read_parquet("s3://data-lake/customer_features/2024/")
print(f"Dataset: {dataset.count()} rows")
 
# Load model once, reuse across workers
class BatchInferenceActor:
    def __init__(self, model_path, device="cuda"):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.device = device
        self.model.to(device)
 
    def __call__(self, batch):
        # batch is a pyarrow.Table or pandas DataFrame
        batch_tensor = torch.tensor(
            batch["features"].to_numpy(),
            dtype=torch.float32,
            device=self.device
        )
 
        with torch.no_grad():
            logits = self.model(batch_tensor)
            predictions = torch.softmax(logits, dim=1)
 
        batch["prediction_score"] = predictions[:, 1].cpu().numpy()
        batch["prediction_class"] = predictions.argmax(dim=1).cpu().numpy()
        return batch
 
# Map inference across workers with a GPU actor pool
result_dataset = dataset.map_batches(
    BatchInferenceActor,                       # pass the class; Ray builds one per actor
    fn_constructor_args=("models/model.pt",),  # forwarded to __init__
    batch_size=512,  # Tune this for your GPU memory
    num_gpus=1,
    compute=ray.data.ActorPoolStrategy(min_size=2, max_size=8)
)
 
# Write partitioned output (filesystem is inferred from the s3:// scheme)
result_dataset.write_parquet("s3://outputs/batch-scores/2024-02-27/")
 
ray.shutdown()

Strengths:

  • Native Python: use PyTorch/TensorFlow without serialization headaches
  • GPU-aware scheduling: Ray knows which actors have GPUs
  • Auto-scaling: actors spawn on demand, scale down when idle
  • Fault tolerance with task retry semantics
  • Great for iterative development and exploration
  • Excellent observability through Ray Dashboard

Weaknesses:

  • Smaller operational maturity than Spark in enterprise environments
  • Fewer third-party integrations (though improving rapidly)
  • Ray's memory management can surprise you on very large clusters
  • Not ideal for SQL-heavy post-processing

When to use: You're using PyTorch or TensorFlow models, need fine-grained control over batching and GPU usage, and want the easiest integration with Python code.

Pattern 3: Argo Workflows Parallelism

For maximum control and flexibility, orchestrate the entire workflow with Argo. You define a DAG where each node is a containerized batch inference job, and Argo parallelizes them across your Kubernetes cluster.

yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  generateName: batch-inference-
spec:
  entrypoint: main-pipeline
  templates:
    - name: main-pipeline
      steps:
        - - name: prepare-data
            template: prepare-splits
        - - name: infer-shard-1
            template: infer-worker
            arguments:
              parameters:
                - name: shard-id
                  value: "1"
          - name: infer-shard-2
            template: infer-worker
            arguments:
              parameters:
                - name: shard-id
                  value: "2"
          - name: infer-shard-3
            template: infer-worker
            arguments:
              parameters:
                - name: shard-id
                  value: "3"
          - name: infer-shard-4
            template: infer-worker
            arguments:
              parameters:
                - name: shard-id
                  value: "4"
        - - name: merge-results
            template: consolidate-output
 
    - name: prepare-splits
      container:
        image: python:3.10  # in practice, an image with pandas + s3fs preinstalled
        command:
          - python
          - -c
          - |
            import pandas as pd
            df = pd.read_parquet('s3://data-lake/features/')
            n_shards = 4
            for i in range(n_shards):
                shard = df.iloc[i::n_shards]
                shard.to_parquet(f's3://temp/shard-{i}.parquet')
            print(f"Prepared {n_shards} shards")
 
    - name: infer-worker
      inputs:
        parameters:
          - name: shard-id
      container:
        image: myrepo/batch-inference:latest
        env:
          - name: SHARD_ID
            value: "{{inputs.parameters.shard-id}}"
          - name: AWS_ACCESS_KEY_ID
            valueFrom:
              secretKeyRef:
                name: aws-credentials
                key: access-key
          - name: AWS_SECRET_ACCESS_KEY
            valueFrom:
              secretKeyRef:
                name: aws-credentials
                key: secret-key
        resources:
          requests:
            nvidia.com/gpu: "1"
            memory: "8Gi"
          limits:
            nvidia.com/gpu: "1"
            memory: "8Gi"
      retryStrategy:
        limit: 2
        retryPolicy: "Always"
 
    - name: consolidate-output
      container:
        image: python:3.10  # in practice, an image with pandas + s3fs preinstalled
        command:
          - python
          - -c
          - |
            import pandas as pd
            import s3fs  # glob.glob can't list S3; s3fs provides glob over buckets
            fs = s3fs.S3FileSystem()
            shards = fs.glob('s3://temp/output-*.parquet')
            results = [pd.read_parquet(f's3://{s}') for s in sorted(shards)]
            final = pd.concat(results, ignore_index=True)
            final.to_parquet('s3://outputs/final-results.parquet')
            print(f"Consolidated {len(results)} shards into final output")

Strengths:

  • Explicit DAG structure: you control exactly what runs and when
  • Kubernetes-native: leverage existing infrastructure
  • Excellent parallelism: scale from 4 shards to 400 trivially
  • Built-in retry, timeout, and resource management
  • Works with any containerized workload
  • Strong audit trail for compliance

Weaknesses:

  • Requires Kubernetes cluster (setup cost)
  • Manual shard management and merging
  • No automatic data distribution; you handle coordination
  • More operational boilerplate than Ray or Spark

When to use: You're already on Kubernetes, need explicit control over parallelism, or your inference workload is heterogeneous (some jobs use GPUs, others use CPU).

Why Throughput Matters: The Economics of Scale

Here's a reality check: choosing a pattern gets you in the game, but throughput determines whether you win or lose. Consider the math. You need to process 100 million records. Your model takes 1ms per sample. Sequential processing: 1,000 samples/second, which means 100,000 seconds = roughly 28 hours of pure compute. Add overhead and you're looking at 40+ hours just to run the inference. On an expensive GPU instance ($3-5/hour), that's $120-200 in compute costs, before storage, networking, and engineering time.

Now optimize that same pipeline. With proper batching, async preprocessing, and model compilation, you might reach 100,000 samples/second. That's 100x faster. Your 40-hour job shrinks to under 20 minutes, and your $120-200 in compute drops to a dollar or two. The difference isn't just money - it's the difference between running inference daily and weekly.
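The arithmetic fits in a few lines. A back-of-envelope cost model (the $4/hour rate and the 1,000 vs. 100,000 samples/second figures are illustrative assumptions, within the ranges quoted above):

```python
# Back-of-envelope: runtime and compute cost for a batch inference job
def batch_cost(records: int, samples_per_sec: float, usd_per_hour: float = 4.0):
    seconds = records / samples_per_sec
    return seconds, seconds / 3600 * usd_per_hour

naive_s, naive_usd = batch_cost(100_000_000, 1_000)    # sequential, ~1ms/sample
fast_s, fast_usd = batch_cost(100_000_000, 100_000)    # batched + compiled pipeline

print(f"naive:     {naive_s / 3600:.1f} h, ${naive_usd:.0f}")
print(f"optimized: {fast_s / 60:.1f} min, ${fast_usd:.2f}")
```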

Throughput optimization isn't premature optimization. It's the difference between working solutions and production systems. The techniques we'll cover - dynamic batching, asynchronous preprocessing, model compilation - aren't exotic. They're standard practice at companies processing real volumes.

Throughput Optimization: Squeezing Every Ounce

Pattern choice is only half the battle. You also need to optimize throughput within each worker.

Dynamic Batching

GPUs are like assembly lines. One item at a time is wasteful. You want to process as many samples simultaneously as your VRAM allows.

python
import torch
 
class DynamicBatchInference:
    def __init__(self, model_path, target_batch_size=512, max_batch_size=2048):
        self.model = torch.jit.load(model_path).eval().cuda()
        self.target_batch_size = target_batch_size
        self.max_batch_size = max_batch_size
        self.warmup_done = False
 
    def find_max_batch_size(self):
        """Auto-tune batch size on first run."""
        if self.warmup_done:
            return
 
        # Probe increasing batch sizes until the GPU runs out of memory (128-D features assumed)
 
        for batch_size in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
            try:
                batch = torch.randn(batch_size, 128).cuda()
                with torch.no_grad():
                    _ = self.model(batch)
                torch.cuda.synchronize()
                self.max_batch_size = batch_size
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                break
 
        self.warmup_done = True
        print(f"Auto-tuned max batch size: {self.max_batch_size}")
 
    def infer_batch(self, features_array):
        """Process a batch with optimal sizing."""
        self.find_max_batch_size()
 
        input_tensor = torch.tensor(features_array, dtype=torch.float32).cuda()
        predictions_list = []
 
        # Process in chunks if batch exceeds max
        for i in range(0, len(input_tensor), self.max_batch_size):
            chunk = input_tensor[i:i+self.max_batch_size]
            with torch.no_grad():
                chunk_preds = self.model(chunk)
            predictions_list.append(chunk_preds.cpu())
 
        return torch.cat(predictions_list, dim=0).numpy()
 
# Usage in your pipeline
inferencer = DynamicBatchInference("model.pt")
predictions = inferencer.infer_batch(features_array)  # features_array: shape (N, 128)

The result: With proper batching, you might 10x your throughput. A model that processes 100 samples/sec in isolation could hit 1,000 samples/sec with optimized batching.

Async Preprocessing

Don't make your GPU wait for CPU work. Overlap data loading and preprocessing with inference.

python
import threading
import queue
import numpy as np
 
class AsyncPreprocessor:
    def __init__(self, batch_size=512, max_prefetch=10):
        self.batch_queue = queue.Queue(maxsize=max_prefetch)
        self.batch_size = batch_size

    def preprocess_worker(self, data_source):
        """Runs in a background thread, fills the queue."""
        buffer = []
        for row in data_source:
            # Normalize features
            features = (row['features'] - 0.5) / 2.0
            buffer.append(features)

            if len(buffer) >= self.batch_size:
                self.batch_queue.put(np.array(buffer))
                buffer = []

        if buffer:
            self.batch_queue.put(np.array(buffer))

        self.batch_queue.put(None)  # Sentinel: no more batches

    def start(self, data_source):
        """Launch a single producer thread. Python iterators aren't
        thread-safe, so one thread owns the source; raise max_prefetch
        (or shard the source) if the GPU outruns preprocessing."""
        t = threading.Thread(
            target=self.preprocess_worker,
            args=(data_source,),
            daemon=True
        )
        t.start()

    def get_batch(self):
        """Blocking call; returns batches as they're ready (None when done)."""
        return self.batch_queue.get()

# In your inference loop:
preprocessor = AsyncPreprocessor(batch_size=512)
preprocessor.start(parquet_file_iterator)
 
while True:
    batch = preprocessor.get_batch()
    if batch is None:
        break
    predictions = model.predict(batch)
    write_to_output(predictions)

GPU is busy while the CPU fills the next batch. No idle cycles.

Model Compilation for Batch Workloads

PyTorch 2.0+ includes torch.compile(), which fuses operations and reduces overhead. For batch workloads, the gains are substantial.

python
import torch
 
# Original model
model = torch.jit.load("model.pt").eval().cuda()
 
# Compile for batch inference
compiled_model = torch.compile(model, mode="reduce-overhead")
 
# Benchmark
import time
 
test_batch = torch.randn(512, 128).cuda()
 
# Warm up both models (the first compiled call triggers compilation)
for _ in range(3):
    with torch.no_grad():
        _ = model(test_batch)
        _ = compiled_model(test_batch)
 
# Time original
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = model(test_batch)
torch.cuda.synchronize()  # GPU kernels run async; wait before stopping the clock
original_time = time.time() - start

# Time compiled
start = time.time()
for _ in range(100):
    with torch.no_grad():
        _ = compiled_model(test_batch)
torch.cuda.synchronize()
compiled_time = time.time() - start
 
print(f"Original: {original_time:.2f}s | Compiled: {compiled_time:.2f}s")
print(f"Speedup: {original_time/compiled_time:.1f}x")

Typical result: 1.3-2.0x speedup with minimal code changes.

Ray Data for ML Inference: Practical Deep Dive

Ray Data deserves a closer look because it's purpose-built for this problem and getting rapidly adopted.

The Full Pipeline

python
import ray
import torch
import pandas as pd
from ray.data import read_parquet
from typing import Dict
 
# Initialize Ray on a cluster (or local)
ray.init()
 
# Configuration
BATCH_SIZE = 512
NUM_GPUS_PER_WORKER = 1
NUM_WORKERS = 4
MODEL_PATH = "/models/production/model-v3.pt"
 
# 1. READ: Load from S3 or local Parquet
print("Reading dataset...")
dataset = read_parquet("s3://data-lake/features/2024-02/")
print(f"Loaded {dataset.count()} rows")
 
# 2. PREPROCESS: Normalize and shape
def preprocess_fn(batch: pd.DataFrame) -> pd.DataFrame:
    """Runs on workers; operates on pandas batches.
    Note: normalizing per batch keeps the example short, but makes results
    depend on batch boundaries - in production, use precomputed global stats."""
    batch['norm_feature_1'] = (batch['feature_1'] - batch['feature_1'].mean()) / batch['feature_1'].std()
    batch['norm_feature_2'] = (batch['feature_2'] - batch['feature_2'].mean()) / batch['feature_2'].std()
    return batch[['id', 'norm_feature_1', 'norm_feature_2']]
 
dataset = dataset.map_batches(preprocess_fn, batch_size=1000)
 
# 3. INFERENCE: Call model on GPU
class InferenceModel:
    def __init__(self):
        self.model = torch.jit.load(MODEL_PATH).cuda().eval()
 
    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        # Convert to tensor
        features = torch.tensor(
            batch[['norm_feature_1', 'norm_feature_2']].values,
            dtype=torch.float32,
            device='cuda'
        )
 
        # Inference
        with torch.no_grad():
            logits = self.model(features)
            probs = torch.softmax(logits, dim=1)
 
        # Add predictions to batch
        batch['score'] = probs[:, 1].cpu().numpy()
        batch['class'] = probs.argmax(dim=1).cpu().numpy()
        return batch
 
# Apply inference with GPU actors
dataset = dataset.map_batches(
    InferenceModel,
    batch_size=BATCH_SIZE,
    num_gpus=NUM_GPUS_PER_WORKER,
    compute=ray.data.ActorPoolStrategy(min_size=2, max_size=NUM_WORKERS),
    zero_copy_batch=True  # Avoid extra memory copies
)
 
# 4. POSTPROCESS: Add timestamps and partition
import datetime
 
def add_metadata(batch: pd.DataFrame) -> pd.DataFrame:
    batch['inference_timestamp'] = datetime.datetime.utcnow().isoformat()
    batch['inference_date'] = datetime.date.today()
    return batch
 
dataset = dataset.map_batches(add_metadata, batch_size=1000)
 
# 5. WRITE: Partitioned Parquet (filesystem inferred from the s3:// scheme)
print("Writing output...")
dataset.write_parquet(
    "s3://outputs/batch-inference/2024-02-27/",
    partition_cols=["inference_date"],
    try_create_dir=True
)
 
print("✓ Batch inference complete")
ray.shutdown()

Output structure on S3:

s3://outputs/batch-inference/2024-02-27/
├── inference_date=2024-02-27/
│   ├── part-00000.parquet (1 GB)
│   ├── part-00001.parquet (1 GB)
│   ├── part-00002.parquet (1 GB)
│   └── ... (one file per worker)
└── _metadata

GPU Actor Pool Configuration

Ray's actor pool is the secret sauce. Here's how to tune it:

python
import ray
from ray.data import ActorPoolStrategy
 
# Strategy 1: Fixed pool
fixed_strategy = ActorPoolStrategy(min_size=4, max_size=4)
# Good when: workload is predictable, cluster is stable
 
# Strategy 2: Auto-scaling
autoscale_strategy = ActorPoolStrategy(min_size=2, max_size=16)
# Good when: workload spikes, want cost efficiency
 
# Example with monitoring
dataset = dataset.map_batches(
    InferenceModel,
    batch_size=512,
    num_gpus=1,
    compute=autoscale_strategy
)
 
# Ray creates actors on demand:
# - Starts with 2 GPU workers
# - As batches queue up, spawns more (up to 16)
# - Scales down idle workers automatically

Fault Tolerance with Task Retry

Ray's task-based execution gives you built-in fault tolerance:

python
import ray
from ray.data import ActorPoolStrategy
 
# Configure retry semantics
strategy = ActorPoolStrategy(
    min_size=2,
    max_size=8,
    # Ray handles failures transparently:
    # - Task fails on worker? Retried on another worker.
    # - Actor dies? New actor spawned, task retried.
    # - Network timeout? Task timeout, then retry.
)
 
# You can also use Ray's @ray.remote with explicit retry:
@ray.remote(max_retries=3, retry_exceptions=True)
def infer_shard(shard_id: int):
    """Retries up to 3 times, including on application-level exceptions."""
    dataset = ray.data.read_parquet(f"s3://shards/shard-{shard_id}.parquet")
    result = dataset.map_batches(InferenceModel, batch_size=512)
    result.write_parquet(f"s3://outputs/output-{shard_id}.parquet")
    return shard_id
 
# Launch shards in parallel with fault tolerance
futures = [infer_shard.remote(i) for i in range(10)]
results = ray.get(futures)  # Blocks until all complete (with retries)

The guarantee: Ray re-executes failed tasks, which gives you at-least-once execution - not exactly-once. Pair retries with idempotent writes (deterministic output paths per shard, or a transactional table format like Delta) so a retried task overwrites its own output instead of duplicating predictions.
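Because a retried shard may run more than once, one common recipe is a deterministic output path per shard plus a skip-if-exists check and an atomic rename. A minimal local-filesystem sketch (`write_shard_idempotent` is illustrative, not a Ray API):

```python
from pathlib import Path

def write_shard_idempotent(shard_id: int, rows: list, out_dir: str) -> bool:
    """Write a shard's output exactly once: deterministic path + existence check.
    Returns True if this call did the write, False if a prior run already had."""
    out = Path(out_dir) / f"output-{shard_id}.csv"
    if out.exists():            # a previous (possibly retried) run finished this shard
        return False
    tmp = out.with_suffix(".csv.tmp")
    tmp.write_text("\n".join(map(str, rows)))
    tmp.rename(out)             # atomic publish: readers never see a partial file
    return True
```

The same idea maps loosely to object stores (check for the object before uploading) or, more robustly, to a transactional table format that deduplicates on merge.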

Production Considerations: Beyond Just Processing

At this point, you have inference working and it's reasonably fast. Now comes the hard part: making it production-grade. This means considering what happens after the inference finishes. Your downstream systems need to find the results, understand the schema, recover from partial failures, and potentially re-process without duplicates. This is where many teams stumble - they optimize for speed, ship, and then spend weeks debugging data quality issues downstream.

Production batch systems think about the full lifecycle. What happens if the job crashes halfway through? If you don't have idempotency, you need to delete all output and start from scratch - throwing away a 100-million-record run is devastating. What if a consumer downstream tries to read results before the job finishes? You need metadata signals. What if your model schema changes? You need versioning and migration logic.

The patterns we'll show here - partitioned output, delta lakes for idempotent writes, manifest files for coordination - aren't luxuries. They're the difference between systems that work once and systems that work reliably, month after month, under real conditions.

Output Management at Scale

Producing results is only half the problem. You also need to store, organize, and discover them efficiently.

Partitioned Parquet with Timestamp Prefixes

python
import datetime
import pandas as pd
from pathlib import Path
 
# Structure: year/month/day/hour/
timestamp = datetime.datetime.utcnow()
output_prefix = (
    f"s3://outputs/batch-inference/"
    f"{timestamp.year:04d}/{timestamp.month:02d}/{timestamp.day:02d}/"
    f"{timestamp.hour:02d}-{timestamp.minute:02d}/"
)
 
# Write with partition columns
result_dataset.write_parquet(
    output_prefix,
    partition_cols=["inference_date", "model_version"]
)
 
# Later, query by date efficiently
import pyarrow.dataset as ds
pds = ds.dataset("s3://outputs/batch-inference/2024/02/27/", format="parquet", partitioning="hive")
# Hive-style partition directories let readers prune whole files by partition value

Benefit: When downstream jobs query a specific date, partition pruning skips unrelated directories entirely instead of scanning every file. Massive time and cost savings at scale.

Schema Evolution Handling

Your model output schema will change. Plan for it.

python
import pyarrow as pa
from pyarrow import parquet as pq
 
# Version 1: Simple scores
schema_v1 = pa.schema([
    ('id', pa.string()),
    ('score', pa.float32()),
])
 
# Version 2: Also output class and confidence
schema_v2 = pa.schema([
    ('id', pa.string()),
    ('score', pa.float32()),
    ('class', pa.int32()),
    ('confidence', pa.float32()),
])
 
# When writing, include version in output path
version = 2
output_dir = f"s3://outputs/batch-inference/v{version}/2024-02-27/"
 
# Downstream code reads version from path
def read_inference_results(path):
    # Extract version: v2/2024-02-27 -> v2
    version = int(path.split('/')[4][1:])
 
    table = pq.read_table(path)
    if version == 1:
        table = table.append_column('class', pa.array([None] * len(table), type=pa.int32()))
        table = table.append_column('confidence', pa.array([None] * len(table), type=pa.float32()))
 
    return table.to_pandas()

Idempotent Writes with Delta Lake

If your batch job re-runs (maybe it failed yesterday and you're catching up), Delta Lake prevents duplicate results:

python
from pyspark.sql import SparkSession
 
spark = SparkSession.builder \
    .appName("idempotent-batch") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .getOrCreate()
 
# Read and predict
df = spark.read.parquet("s3://features/")
predictions = model.transform(df)
 
# Note: a plain append is NOT idempotent - a re-run duplicates rows.
# mergeSchema handles column additions; use the MERGE below for dedup.
predictions.write \
    .format("delta") \
    .option("mergeSchema", "true") \
    .mode("append") \
    .save("s3://outputs/delta-table/")
 
# Or use SQL MERGE for idempotent, precise control
predictions.createOrReplaceTempView("predictions")
spark.sql("""
    MERGE INTO outputs.batch_results t
    USING predictions p
    ON t.id = p.id AND t.batch_date = p.batch_date
    WHEN NOT MATCHED THEN
        INSERT *
""")

Result: Re-run the same batch date multiple times; only new records are inserted.

Downstream Notification

How do downstream systems know results are ready?

python
import boto3
import json
import datetime
 
s3_client = boto3.client('s3')
sns_client = boto3.client('sns')
 
def notify_completion(bucket, prefix, record_count):
    """Emit SNS message when batch is done."""
    message = {
        'event': 'batch_inference_complete',
        'bucket': bucket,
        'prefix': prefix,
        'record_count': record_count,
        'timestamp': datetime.datetime.utcnow().isoformat(),
    }
 
    sns_client.publish(
        TopicArn='arn:aws:sns:us-east-1:123456789:batch-complete',
        Subject='Batch Inference Complete',
        Message=json.dumps(message, indent=2)
    )
 
    # Also write manifest to S3 for discovery
    s3_client.put_object(
        Bucket=bucket,
        Key=f"{prefix}/_batch_manifest.json",
        Body=json.dumps({
            'status': 'complete',
            'record_count': record_count,
            'timestamp': message['timestamp'],
        })
    )
 
# At end of batch job
notify_completion('outputs', 'batch-inference/2024-02-27/', 25_000_000)

Monitoring and Visibility: Your Window into Production

"It finished" is not enough information. You need to understand what happened, how long it took, what it cost, and whether the results are usable. This is where monitoring becomes critical. Without proper instrumentation, batch jobs become black boxes - they run for hours and either succeed or fail, giving you no clues about what went wrong or how to fix it.

Real production systems emit metrics at every stage. They track how much data was read (are you processing the right volume?), how long each stage took (where's the bottleneck?), how many errors occurred (is something broken?), and ultimately how much the job cost (is this economical?). They log summaries that help you diagnose issues without replaying the entire job.

Monitoring also helps you iterate. If you change your batch size from 256 to 512, how does that affect throughput and GPU memory usage? If you add a new preprocessing step, how much does it slow things down? Without metrics, you're flying blind. With them, you can make informed optimization decisions backed by data.

The examples here show how to instrument a batch pipeline with Prometheus, Airflow, and cost tracking. The specifics change based on your stack, but the principles don't: visibility enables improvement.

Scheduling and Monitoring

A successful batch system isn't just about processing speed - it's about visibility and reliability.

Airflow BatchSensor for Triggering Jobs

python
from airflow import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
 
default_args = {
    'owner': 'ml-platform',
    'retries': 2,
    'retry_delay': timedelta(minutes=10),
}
 
with DAG(
    'batch_inference_daily',
    default_args=default_args,
    schedule_interval='0 2 * * *',  # 2 AM UTC daily
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:
 
    # 1. Wait for feature generation to complete
    wait_for_features = ExternalTaskSensor(
        task_id='wait_features',
        external_dag_id='feature_generation_daily',
        external_task_id='finalize',
        mode='reschedule',
        poke_interval=60,
        timeout=3600,
    )
 
    # 2. Run batch inference
    inference_job = SparkSubmitOperator(
        task_id='run_inference',
        application='/jobs/batch_inference.py',
        conf={
            'spark.executor.memory': '16g',
            'spark.executor.cores': '8',
            'spark.executor.instances': '10',
        },
        num_executors=10,
        total_executor_cores=80,
    )
 
    # 3. Validate output
    def validate_outputs(**context):
        import boto3
        s3 = boto3.client('s3')
        batch_date = context['ds']  # use the DAG run's logical date, not a hardcoded one
        response = s3.list_objects_v2(
            Bucket='outputs',
            Prefix=f'batch-inference/{batch_date}/'
        )
        total_bytes = sum(obj['Size'] for obj in response.get('Contents', []))
        print(f"Output validation: {total_bytes:,} bytes written")

        if total_bytes < 1_000_000:
            raise ValueError("Output suspiciously small; possible failure")
 
    validate = PythonOperator(
        task_id='validate_output',
        python_callable=validate_outputs,
    )
 
    # 4. Notify downstream
    def trigger_downstream(**context):
        print("Batch inference complete. Triggering model scoring pipeline...")
 
    notify = PythonOperator(
        task_id='notify_downstream',
        python_callable=trigger_downstream,
    )
 
    # DAG dependency
    wait_for_features >> inference_job >> validate >> notify

Progress Tracking via Prometheus Metrics

python
from prometheus_client import Counter, Histogram, Gauge
import time
 
# Define metrics
inferences_total = Counter(
    'batch_inferences_total',
    'Total inferences processed',
    ['model_version', 'status']
)
 
inference_duration = Histogram(
    'batch_inference_duration_seconds',
    'Inference latency per batch',
    ['batch_size']
)
 
records_processed = Gauge(
    'batch_records_processed',
    'Current record count',
    ['job_id']
)
 
predictions_per_second = Gauge(
    'batch_predictions_per_second',
    'Throughput',
    ['job_id']
)
 
# Instrument your inference loop
import torch
 
class InstrumentedInference:
    def __init__(self, model_path, model_version='v3', job_id='daily-batch'):
        self.model = torch.jit.load(model_path).eval().cuda()
        self.model_version = model_version
        self.job_id = job_id
        self.total_records = 0
        self.start_time = time.time()
 
    def infer_batch(self, batch):
        batch_start = time.time()
        batch_size = len(batch)
 
        try:
            # Inference
            with torch.no_grad():
                predictions = self.model(batch)
 
            # Metrics
            elapsed = time.time() - batch_start
            inference_duration.labels(batch_size=batch_size).observe(elapsed)
            inferences_total.labels(
                model_version=self.model_version,
                status='success'
            ).inc()
 
            self.total_records += batch_size
            records_processed.labels(job_id=self.job_id).set(self.total_records)
 
            elapsed_total = time.time() - self.start_time
            tps = self.total_records / elapsed_total
            predictions_per_second.labels(job_id=self.job_id).set(tps)
 
            return predictions
 
        except Exception as e:
            inferences_total.labels(
                model_version=self.model_version,
                status='failure'
            ).inc()
            raise

Cost-Per-Record Reporting

python
import boto3
from datetime import datetime
 
def calculate_batch_costs(job_name, start_time, end_time, output_records):
    """Estimate costs for batch job."""
 
    # CloudWatch metrics
    cw = boto3.client('cloudwatch')
 
    # Average CPU utilization over the window -- useful context for tuning,
    # though the rough estimate below only uses wall-clock duration
    utilization = cw.get_metric_statistics(
        Namespace='AWS/EC2',
        MetricName='CPUUtilization',
        StartTime=start_time,
        EndTime=end_time,
        Period=60,
        Statistics=['Average']
    )
 
    # Rough calculation
    # (Assumes 4x p3.2xlarge GPU instances @ $3.06/hr on-demand)
    job_duration_hours = (end_time - start_time).total_seconds() / 3600
    instance_count = 4
    price_per_instance_per_hour = 3.06  # p3.2xlarge with GPU
 
    total_cost = instance_count * price_per_instance_per_hour * job_duration_hours
    cost_per_million = (total_cost / output_records) * 1_000_000
 
    print(f"""
    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    BATCH JOB COST REPORT: {job_name}
    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
    Duration:           {job_duration_hours:.2f} hours
    Instances:          {instance_count} × p3.2xlarge (GPU)
    Total Cost:         ${total_cost:.2f}
    Records Processed:  {output_records:,}
    Cost per Record:    ${total_cost / output_records:.6f}
    Cost per 1M:        ${cost_per_million:.2f}
 
    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    """)
 
    return {
        'total_cost': total_cost,
        'cost_per_record': total_cost / output_records,
        'cost_per_million': cost_per_million,
    }
 
# Usage
from datetime import datetime, timedelta
start = datetime.now() - timedelta(hours=3)
end = datetime.now()
costs = calculate_batch_costs('daily_inference', start, end, output_records=25_000_000)

Architecture Diagram: Full Stack

mermaid
graph TD
    A["Data Lake<br/>(S3 Parquet)"] -->|read_parquet| B["Ray Data<br/>Read"]
    B -->|map_batches| C["Preprocessing<br/>(CPU)"]
    C -->|map_batches| D["Inference<br/>(GPU)"]
    D -->|map_batches| E["Postprocessing<br/>(CPU)"]
 
    F["Model<br/>(PyTorch JIT)"] -->|broadcast| D
 
    E -->|write_parquet| G["S3 Output<br/>(Partitioned)"]
    E -->|write| H["Delta Lake<br/>(Idempotent)"]
 
    G -->|metadata| I["Prometheus<br/>Metrics"]
    H -->|notify| J["SNS<br/>Notification"]
 
    K["Airflow<br/>Orchestrator"] -->|schedule| B
    K -->|monitor| I
 
    L["Cost Monitor"] -->|query| I
    L -->|report| M["Dashboard<br/>(Grafana)"]
 
    style D fill:#ffcccc
    style F fill:#ccffcc
    style A fill:#ccccff
    style G fill:#ffffcc

Putting It All Together: Complete Example

Here's a production-ready batch inference pipeline:

python
#!/usr/bin/env python3
"""
Production batch inference pipeline.
Usage: python batch_inference.py --date 2024-02-27 --model-version v3
"""
 
import argparse
import logging
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
 
import ray
import torch
import pandas as pd
import numpy as np
from ray.data import read_parquet, ActorPoolStrategy
import boto3
 
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
 
# Configuration
BATCH_SIZE = 512
NUM_GPUS_PER_WORKER = 1
NUM_WORKERS = 8
FEATURE_COLUMNS = ['feature_1', 'feature_2', 'feature_3', 'feature_4']
 
class BatchInferenceModel:
    """Stateful model actor for distributed inference."""
 
    def __init__(self, model_path: str):
        self.model = torch.jit.load(model_path).eval().cuda()
        self.inference_count = 0
        logger.info(f"Loaded model from {model_path}")
 
    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Process a batch of records."""
        try:
            # Extract features and normalize
            features = batch[FEATURE_COLUMNS].values.astype(np.float32)
            features_tensor = torch.from_numpy(features).cuda()
 
            # Inference
            with torch.no_grad():
                logits = self.model(features_tensor)
                probabilities = torch.softmax(logits, dim=1)
 
            # Extract predictions
            predicted_class = probabilities.argmax(dim=1).cpu().numpy()
            predicted_score = probabilities.max(dim=1)[0].cpu().numpy()
 
            # Add to batch
            batch['predicted_class'] = predicted_class
            batch['predicted_score'] = predicted_score
            batch['inference_timestamp'] = datetime.utcnow().isoformat()
 
            self.inference_count += len(batch)
            return batch
 
        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise
 
def main():
    parser = argparse.ArgumentParser(description='Batch inference pipeline')
    parser.add_argument('--date', type=str, required=True, help='Date for batch (YYYY-MM-DD)')
    parser.add_argument('--model-version', type=str, default='v3', help='Model version')
    parser.add_argument('--input-path', type=str, default='s3://data-lake/features/', help='Input S3 path')
    parser.add_argument('--output-path', type=str, default='s3://outputs/batch-inference/', help='Output S3 path')
    parser.add_argument('--dry-run', action='store_true', help='Don\'t actually write output')
 
    args = parser.parse_args()
 
    logger.info(f"Starting batch inference for {args.date}")
    job_start = datetime.now()
 
    # Initialize Ray
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True)
    logger.info(f"Ray initialized. Dashboard: http://127.0.0.1:8265")
 
    try:
        # 1. READ
        logger.info("Reading feature data...")
        input_full_path = f"{args.input_path}/{args.date}/"
        dataset = read_parquet(input_full_path)
        record_count = dataset.count()
        logger.info(f"Loaded {record_count:,} records")
 
        # 2. PREPROCESS
        logger.info("Preprocessing...")
        def preprocess(batch):
            # Note: this normalizes per batch; production pipelines usually
            # apply global statistics computed ahead of time
            for col in FEATURE_COLUMNS:
                batch[col] = (batch[col] - batch[col].mean()) / (batch[col].std() + 1e-7)
            return batch

        dataset = dataset.map_batches(preprocess, batch_size=BATCH_SIZE, batch_format='pandas')
 
        # 3. INFERENCE
        logger.info("Running inference...")
        model_path = f"/models/model-{args.model_version}.pt"
 
        dataset = dataset.map_batches(
            BatchInferenceModel,
            fn_constructor_kwargs={'model_path': model_path},
            batch_size=BATCH_SIZE,
            batch_format='pandas',
            num_gpus=NUM_GPUS_PER_WORKER,
            compute=ActorPoolStrategy(min_size=2, max_size=NUM_WORKERS)
        )
 
        # 4. POSTPROCESS
        def add_metadata(batch):
            batch['batch_date'] = args.date
            batch['model_version'] = args.model_version
            return batch
 
        dataset = dataset.map_batches(add_metadata, batch_size=BATCH_SIZE, batch_format='pandas')
 
        # 5. WRITE
        if not args.dry_run:
            logger.info("Writing output...")
            output_path = f"{args.output_path}{args.date}/"
            # Ray infers the S3 filesystem from the s3:// scheme
            dataset.write_parquet(
                output_path,
                partition_cols=['batch_date', 'model_version']
            )
            logger.info(f"Output written to {output_path}")
        else:
            logger.info("Dry-run mode; skipping write")
 
        # Summary
        job_duration = datetime.now() - job_start
        records_per_sec = record_count / job_duration.total_seconds()
        logger.info(f"""
        ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
        BATCH INFERENCE COMPLETE
        ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
        Records processed: {record_count:,}
        Duration:          {job_duration.total_seconds():.1f}s
        Throughput:        {records_per_sec:.0f} records/sec
        ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
        """)
 
        return 0
 
    except Exception as e:
        logger.error(f"Pipeline failed: {e}", exc_info=True)
        return 1
 
    finally:
        ray.shutdown()
 
if __name__ == '__main__':
    sys.exit(main())

Run it:

bash
python batch_inference.py --date 2024-02-27 --model-version v3
# Output:
# INFO:__main__:Ray initialized. Dashboard: http://127.0.0.1:8265
# INFO:__main__:Loaded 25,000,000 records
# INFO:__main__:Running inference...
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# BATCH INFERENCE COMPLETE
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Records processed: 25,000,000
# Duration:          180.5s
# Throughput:        138,504 records/sec
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

The Economics of Batch Inference at Scale

Before we dive into architecture, let's talk about why this matters financially. When you're running inference on a million records, your approach determines whether your cost is five hundred dollars or fifty thousand dollars. That difference isn't hyperbole - it's the gap between careful engineering and naive approaches.

Consider a typical scenario: you need to run inference on your entire customer base once a week. You've got fifty million customers. Each prediction takes five milliseconds of GPU compute. Naively, you might think "fifty million records × five milliseconds = 250,000 seconds = 69 hours of GPU time." On an A100 GPU at spot pricing, that's roughly three hundred to four hundred dollars, depending on your cloud provider and region.

But where teams go wrong is thinking about it per-record instead of thinking about throughput. If you batch efficiently, you can do ten thousand records per second. Fifty million records divided by ten thousand records per second equals five thousand seconds, or roughly ninety minutes. You're done in the time it takes to watch a movie. More importantly, you might do this on a single GPU, or distribute across a few GPUs, rather than needing a massive fleet.

Now consider what happens when you don't batch well. Maybe you process one record at a time, spending overhead moving data to the GPU, computing, moving results back, repeating. Your throughput drops to one hundred records per second. Now that same job takes five hundred thousand seconds - nearly six days. You're paying for a week of GPU time for a job that should take ninety minutes. Your compute bill becomes a major operational concern. Multiply that by running the job daily or multiple times per day, and suddenly your infrastructure costs are a significant business expense.
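The arithmetic above is easy to sanity-check in a few lines (the throughput figures are the illustrative numbers from the text, not benchmarks):

```python
def job_hours(records, records_per_sec):
    """Wall-clock hours to score `records` at a given sustained throughput."""
    return records / records_per_sec / 3600

customers = 50_000_000

batched = job_hours(customers, 10_000)  # well-batched pipeline
naive = job_hours(customers, 100)       # one record at a time

# Well-batched: ~1.4 hours; naive: ~139 hours (nearly six days)
```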

This is why batch inference engineering matters. It's not about showing off your engineering prowess. It's about cost control and being responsible with company resources.

Real-World Lessons: What We Learned the Hard Way

Building batch systems teaches you lessons that theoretical understanding never captures. Here are patterns that emerge only after you've run billions of predictions, and they're often learned the hard way through production incidents that could have been prevented.

Idempotency is non-negotiable: Batch jobs fail. Networks hiccup. Disks fill up. A worker crashes mid-processing. Your job gets preempted by a higher-priority workload. Your job restarts mid-way through, and if you're not idempotent, you'll end up with duplicate records, partial computations, or data corruption that takes weeks to debug. The lesson: always design for re-runs. Use Delta Lake with its ACID guarantees, or implement explicit deduplication. Tag every output record with the batch run ID so you can delete an entire failed run without affecting successful batches. We've seen teams lose confidence in their inference results because they couldn't guarantee that a rerun wouldn't create duplicates. That's a serious operational debt.
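A minimal in-memory sketch of the run-ID pattern described above; `store` is a hypothetical stand-in for your output table, and a real pipeline would do the delete as a partition drop or a Delta Lake `DELETE`:

```python
def write_with_run_id(records, run_id, store):
    """Tag every output record with the batch run ID."""
    for rec in records:
        store.append({**rec, 'run_id': run_id})

def delete_run(store, run_id):
    """Idempotent cleanup: drop everything a failed run wrote, nothing else."""
    store[:] = [rec for rec in store if rec['run_id'] != run_id]

# A crashed run leaves partial output behind...
store = []
write_with_run_id([{'id': 1}, {'id': 2}], 'run-a', store)

# ...so on restart, delete that run's records wholesale, then re-run cleanly
delete_run(store, 'run-a')
write_with_run_id([{'id': 1}, {'id': 2}, {'id': 3}], 'run-b', store)
```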

Partitioning by date, not by record count: It's tempting to partition output into N equal-sized files for load balancing. Don't. Partition by meaningful business dimensions - date, region, model version, customer segment. This lets downstream systems query specific slices without scanning everything. As your data grows, this difference multiplies. A query that took 10 seconds at 1 billion records takes 100+ seconds at 10 billion if you didn't partition well. But more importantly, downstream teams expect to query "give me yesterday's predictions" not "give me files 47-82 in the output directory." Your partitioning strategy becomes an API contract. Get it right from day one, because changing it later means reprocessing enormous amounts of data.
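One concrete way to honor that contract is Hive-style `key=value` partitioning, which engines like Spark and Athena can prune on. A small sketch (the dimension names are illustrative):

```python
def partition_prefix(base, **dims):
    """Build a Hive-style partitioned prefix from business dimensions."""
    parts = '/'.join(f"{key}={value}" for key, value in dims.items())
    return f"{base}/{parts}/"

# Downstream can ask for "yesterday's predictions" by prefix -- no full scan
path = partition_prefix('s3://outputs/predictions',
                        batch_date='2024-02-27', region='us-east')
```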

Model versioning is separate from data versioning: Your batch pipeline processes data from date X with model version Y. Months later, you want to reprocess date X with a better model, or maybe someone discovers a bug in v3 and wants to regenerate predictions. You need to be able to do this without re-running the entire pipeline from scratch or corrupting existing results. Embed the model version in your output path: s3://outputs/v3/2024-02-27/ rather than s3://outputs/2024-02-27/. This lets you reprocess with v4 without confusion or data loss. Some teams make the mistake of treating model version as transient - like they'll only ever need the latest results. Then a data scientist discovers the v4 model has a subtle bug, and they need to roll back. If you don't have the v3 outputs preserved, you're reprocessing millions of records. Your infrastructure cost explodes.
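The version-first layout from the paragraph above, as a tiny helper; because the model version leads the path, a v4 reprocess can never clobber preserved v3 outputs:

```python
def output_prefix(base, model_version, batch_date):
    """Version-first layout: s3://outputs/v3/2024-02-27/ not s3://outputs/2024-02-27/."""
    return f"{base}/{model_version}/{batch_date}/"

v3_path = output_prefix('s3://outputs', 'v3', '2024-02-27')
v4_path = output_prefix('s3://outputs', 'v4', '2024-02-27')  # same date, separate outputs
```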

Monitoring isn't optional: Batch jobs run for hours or days. If something goes wrong at hour three, you don't find out until the job fails at hour six. By then you've wasted three hours of expensive compute resources. Emit metrics in real-time: how many records processed, how long it took, estimated time to completion, current throughput, any error rates. Build dashboards you can glance at to know your job is healthy. This catches problems early before you've burned through your budget. We've seen teams discover a data quality issue six hours into an eight-hour job because they weren't monitoring. Had they been watching the dashboard, they'd have seen prediction confidence scores anomalously low at hour one.
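A minimal progress/ETA tracker along these lines (names are illustrative); in practice you would feed `update()` from your batch loop and export the numbers as Prometheus gauges:

```python
import time

class ProgressTracker:
    """Track records processed, current throughput, and estimated time remaining."""

    def __init__(self, total_records):
        self.total = total_records
        self.done = 0
        self.start = time.monotonic()

    def update(self, n_records):
        self.done += n_records

    def throughput(self):
        elapsed = max(time.monotonic() - self.start, 1e-9)
        return self.done / elapsed

    def eta_seconds(self):
        rate = self.throughput()
        return float('inf') if rate == 0 else (self.total - self.done) / rate
```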

Distributed training and batch inference are different problems: Some teams try to reuse distributed training infrastructure for batch inference. This usually fails because training is synchronous - all workers march together through mini-batches, synchronizing gradients after each step. Batch inference is embarrassingly parallel - each record is independent, each worker can process at its own pace. The synchronization overhead that doesn't matter in training becomes a bottleneck in inference. Use the right tool for each job. Ray is excellent for inference. Spark is solid too. But PyTorch DistributedDataParallel training code won't efficiently scale to inference - the semantics are different.

Memory leaks accumulate silently: Your inference worker processes a batch, returns the result, clears the batch. But does it clear the intermediate tensors? Does it clear the attention weights cached in memory? We've seen workers that start at 4GB of memory usage and creep up to 40GB over the course of processing 10 million records. The job never crashes - it just gets slower and slower as the OS swaps to disk. Add a memory profiler to your pipeline. Check that your memory usage is flat over time, not growing. If it's growing, profile and fix it before pushing to production.
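A rough way to check for that flat-over-time property using the standard library's `tracemalloc` (this samples Python heap allocations only; native or GPU leaks need tools like `nvidia-smi` or a C-level profiler):

```python
import tracemalloc

def heap_growth(process_batch, batches, sample_every=100):
    """Run process_batch over batches, sampling traced heap size.
    Returns (first_sample_bytes, last_sample_bytes); steady growth suggests a leak."""
    tracemalloc.start()
    samples = []
    for i, batch in enumerate(batches):
        process_batch(batch)
        if i % sample_every == 0:
            current, _peak = tracemalloc.get_traced_memory()
            samples.append(current)
    tracemalloc.stop()
    return samples[0], samples[-1]

# A worker that accidentally retains every batch shows monotonic growth
retained = []
first, last = heap_growth(lambda b: retained.append(list(range(1000))), range(500))
```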

Checkpointing isn't just for fault tolerance: Many teams implement checkpointing to handle job failures - save every N records so that restarts don't re-process everything. But checkpointing has another huge benefit: visibility. When you checkpoint, you know exactly how many records you've processed. You can estimate time to completion. You can alert if the job gets stuck. Teams that don't checkpoint are flying blind. After two hours, they don't know if the job is on track or hung. Checkpoint aggressively. Every 1000 records, every 10 seconds, whatever keeps your visibility tight.
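A file-based sketch of the pattern (batch-level rather than record-level for brevity); on restart the loop skips work the checkpoint already covers, and the checkpoint file itself doubles as a live progress indicator:

```python
import json
from pathlib import Path

def load_checkpoint(path):
    p = Path(path)
    return json.loads(p.read_text())['batches_done'] if p.exists() else 0

def save_checkpoint(path, batches_done):
    Path(path).write_text(json.dumps({'batches_done': batches_done}))

def run_with_checkpoints(batches, process, ckpt_path, checkpoint_every=10):
    done = load_checkpoint(ckpt_path)
    for i, batch in enumerate(batches):
        if i < done:  # already processed before the restart
            continue
        process(batch)
        if (i + 1) % checkpoint_every == 0:
            save_checkpoint(ckpt_path, i + 1)
    save_checkpoint(ckpt_path, len(batches))
```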

Scaling Strategies: From Single Machine to Distributed Clusters

As your inference workloads grow, you'll reach points where a single machine becomes insufficient. These inflection points determine how you architect your system. Understanding when and how to scale is critical to avoiding over-engineering early and under-engineering when you need to scale urgently.

Single-machine inference works surprisingly well up to a certain point. Modern GPUs can process thousands of examples per second if you batch properly. A single A100 GPU can handle inference for reasonably large datasets in a few hours. The advantage of single-machine setups is simplicity: no distributed system complexity, no network coordination, minimal debugging headaches. You load your data, run inference, write results. Done.

But as your dataset grows past a few hundred million records, or as your latency requirements tighten (you need results within hours, not a day), single-machine approaches become impractical. This is where distributed batch inference enters the picture. You split your data across multiple workers, each processes its shard, and you aggregate results.

The key insight here is that batch inference distributes trivially compared to training. In training, workers must coordinate gradients - they're tightly coupled. In batch inference, workers are independent - each processes its shard, writes its results, done. This loose coupling is why you can scale batch inference to hundreds or thousands of workers without diminishing returns. The only synchronization points are the initial data split and the final result aggregation.
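The shard-process-aggregate shape, sketched with threads for portability (a real deployment would use processes or a Ray/Spark cluster, and `score_shard` would call the model):

```python
from concurrent.futures import ThreadPoolExecutor

def split_into_shards(records, num_shards):
    """Sync point 1: split the input into independent shards (strided here)."""
    return [records[i::num_shards] for i in range(num_shards)]

def score_shard(shard):
    """Each worker scores its shard at its own pace -- no coordination needed."""
    return [x * 2 for x in shard]  # stand-in for real inference

def batch_infer(records, num_shards=4):
    shards = split_into_shards(records, num_shards)
    with ThreadPoolExecutor(max_workers=num_shards) as pool:
        shard_outputs = list(pool.map(score_shard, shards))
    # Sync point 2: aggregate (order follows shard layout, not input order)
    return [y for out in shard_outputs for y in out]
```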

This looseness also means you can tolerate stragglers. In distributed training, one slow worker causes everyone to wait (synchronized mini-batches). In batch inference, if one worker is slow, it just finishes last - doesn't block others. This is why speculative execution works well in batch inference: you can spawn multiple workers on the same shard and use the first to finish, tolerating redundant work to reduce latency variance.
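Speculative execution in miniature, again with threads (illustrative only; schedulers like Spark implement this natively via `spark.speculation`):

```python
from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait

def speculative_run(task, replicas=2):
    """Run the same shard on several workers; take the first result that lands."""
    with ThreadPoolExecutor(max_workers=replicas) as pool:
        futures = [pool.submit(task) for _ in range(replicas)]
        done, pending = wait(futures, return_when=FIRST_COMPLETED)
        for f in pending:
            f.cancel()  # best-effort; redundant work is the price of low variance
        return next(iter(done)).result()
```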

Observability and Cost Control in Production

Running batch inference in production without visibility is like flying blind. You don't know if your jobs are actually working, whether they're efficient, or whether you're leaving money on the table.

Start with basic metrics: how many records per second are you processing? What's your GPU utilization? How much wall-clock time did the job take? These metrics seem obvious, but teams often skip instrumenting them - and then have no data when they're trying to debug slow jobs. You can't optimize what you don't measure.

As your system matures, extend to cost metrics: how much did this job cost in GPU hours? Is the cost per prediction stable over time, or drifting upward? Are spot instances delivering the savings you expected, or are you getting preempted constantly? These financial metrics matter because they connect engineering decisions to business impact. When someone asks "should we double our batch size to reduce latency," you can answer "yes, it reduces cost by thirty percent."

Implement alerting on anomalies. If your throughput drops suddenly, something's wrong - investigate before the job runs for hours being slow. If your output predictions have unusual distributions compared to baseline, there might be a data quality issue or model problem. If memory usage spikes, you might have a memory leak. Alerts catch problems before they become expensive failures.
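A rolling-baseline throughput alert is only a few lines; here's one sketch (window and threshold are illustrative) that you could wire to whatever paging system you use:

```python
from collections import deque

class ThroughputAlert:
    """Alert when throughput drops below a fraction of the recent rolling average."""

    def __init__(self, window=10, drop_threshold=0.5):
        self.window = deque(maxlen=window)
        self.drop_threshold = drop_threshold

    def observe(self, records_per_sec):
        alert = False
        if len(self.window) == self.window.maxlen:
            baseline = sum(self.window) / len(self.window)
            if records_per_sec < baseline * self.drop_threshold:
                alert = True  # sudden drop vs. recent baseline -- investigate
        self.window.append(records_per_sec)
        return alert
```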

Key Takeaways

Pattern choice matters: Spark for SQL-heavy workloads and data warehouses. Ray Data for flexible, Python-native pipelines. Argo for explicit parallelism and Kubernetes deployments.

Throughput is engineering: Dynamic batching, async preprocessing, and model compilation aren't optional luxuries - they're the difference between 10k records/sec and 100k records/sec.

Output structure pays dividends: Partitioned Parquet with version prefixes makes future queries fast and schema evolution manageable.

Observability is your lifeline: Prometheus metrics, logs, and cost reporting let you detect problems early and justify infrastructure spend.

Idempotency prevents pain: Batch jobs fail. They re-run. Delta Lake or explicit deduplication ensures you don't wake up with duplicate predictions.

At scale, batch inference isn't about moving fast once - it's about systems that move fast reliably, recover gracefully from failure, and cost less as you grow.


Need help implementing this?

We build automation systems like this for clients every day.

Discuss Your Project