Batch Size Trade-offs in Training
Larger batch sizes improve GPU utilization but can harm generalization. This pattern covers the trade-offs among GPU utilization, gradient noise, convergence speed, and generalization.
Intent & Description
🎯 Intent
Balance GPU utilization against model generalization by choosing an appropriate batch size. Large batches improve hardware utilization but can degrade generalization performance.
📋 Context
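Very small batches (1-8) have low GPU utilization, high gradient noise, and slow convergence, but tend to generalize well because the gradient noise acts as an implicit regularizer. Large batches (1024-4096) have excellent GPU utilization, low gradient noise, and fast wall-clock convergence, but can generalize worse, an effect often attributed to convergence toward sharp minima. The linear scaling rule keeps the ratio of learning rate to batch size (η/B) constant, which approximately preserves both the parameter change per epoch and the effective noise level of SGD relative to the base configuration.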
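The claim that gradient noise shrinks with batch size can be checked numerically. The sketch below is a pure-Python illustration on a synthetic 1-D least-squares problem; the data model (`y = 3x + noise`), the helper names, and evaluating the gradient at `w = 0` are assumptions chosen for illustration only. Because a mini-batch gradient is the average of B per-sample gradients, its variance falls roughly as 1/B, so `variance * batch` should stay roughly constant. Scaling the learning rate by K together with the batch size keeps η/B fixed, which is the sense in which the linear scaling rule holds the noise level steady.

```python
import random
import statistics


def minibatch_gradient(data, w, batch_size, rng):
    """Gradient of the mean of 0.5 * (w*x - y)**2 over a random mini-batch."""
    batch = rng.sample(data, batch_size)
    return sum((w * x - y) * x for x, y in batch) / batch_size


def gradient_variance(data, w, batch_size, trials=2000, seed=0):
    """Empirical variance of the mini-batch gradient across repeated draws."""
    rng = random.Random(seed)
    grads = [minibatch_gradient(data, w, batch_size, rng) for _ in range(trials)]
    return statistics.variance(grads)


# Hypothetical synthetic data: y = 3x + Gaussian noise
data_rng = random.Random(42)
xs = [data_rng.gauss(0.0, 1.0) for _ in range(10_000)]
data = [(x, 3.0 * x + data_rng.gauss(0.0, 1.0)) for x in xs]

for batch_size in (1, 8, 64, 512):
    var = gradient_variance(data, w=0.0, batch_size=batch_size)
    print(f"batch={batch_size:4d}  variance={var:.4f}  variance*batch={var * batch_size:.2f}")
```

Sampling is done without replacement, so for batches that are a sizeable fraction of the dataset the variance drops slightly faster than 1/B.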
💡 Solution
Maximize the batch size to fill GPU memory, then apply the linear scaling rule (multiply the LR by K when multiplying the batch size by K). Use LR warmup for large batches. If generalization degrades, try gradient noise injection or sharpness-aware minimization (SAM). Measure samples/second and cost/sample, not just steps/second, because a larger batch does more work per step. For LLM pre-training, ramp the batch size up over the course of training.
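The scaling, warmup, and batch-ramp rules above can be written as small schedule functions. This is a framework-free sketch: the function names, the 0.1 LR at batch 256 used in the example, and especially the linear token-based batch ramp are illustrative assumptions (a linear ramp is one simple choice, not a prescribed schedule).

```python
def scaled_lr(base_lr, batch_size, base_batch_size):
    """Linear scaling rule: grow the LR by the same factor k as the batch size."""
    return base_lr * batch_size / base_batch_size


def warmup_lr(step, target_lr, warmup_steps):
    """Linear warmup from target_lr / warmup_steps up to target_lr, then constant."""
    return target_lr * min(1.0, (step + 1) / warmup_steps)


def ramped_batch_size(tokens_seen, start_batch, final_batch, ramp_tokens, multiple=8):
    """Hypothetical linear batch-size ramp, rounded down to a multiple of `multiple`."""
    fraction = min(1.0, tokens_seen / ramp_tokens)
    size = start_batch + fraction * (final_batch - start_batch)
    return max(start_batch, int(size) // multiple * multiple)


lr = scaled_lr(base_lr=0.1, batch_size=1024, base_batch_size=256)
print(lr)  # 0.4
print([round(warmup_lr(s, lr, warmup_steps=5), 2) for s in range(7)])
# [0.08, 0.16, 0.24, 0.32, 0.4, 0.4, 0.4]
print([ramped_batch_size(t, 256, 4096, ramp_tokens=1_000_000) for t in (0, 500_000, 2_000_000)])
# [256, 2176, 4096]
```

Warmup starts at `target_lr / warmup_steps` rather than 0 so the very first update is not wasted.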
Real-world Use Case
📌 TL;DR
Batch size: small (1-8) = good generalization, low utilization. Large (1024-4096) = high utilization, potential generalization loss. Linear scaling rule: new_LR = base_LR × (batch_size / base_batch). Use warmup; consider SAM or gradient noise for large batches.
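Gradient noise can be added at a fixed scale, but a well-known variant anneals it so the noise is strongest early in training and fades later, with variance σ_t² = η / (1 + t)^γ (γ = 0.55 is a commonly cited value). The pure-Python sketch below shows only the schedule; the function names and the default η are assumptions, and in a real training loop the noise would be added to each parameter's gradient between the backward pass and the optimizer step.

```python
import random


def gradient_noise_std(step, eta=0.01, gamma=0.55):
    """Std of annealed gradient noise whose variance is eta / (1 + step) ** gamma."""
    return (eta / (1 + step) ** gamma) ** 0.5


def add_annealed_noise(grads, step, rng, eta=0.01, gamma=0.55):
    """Return a noisy copy of a list of gradient values (hypothetical helper)."""
    std = gradient_noise_std(step, eta, gamma)
    return [g + rng.gauss(0.0, std) for g in grads]


rng = random.Random(0)
for step in (0, 100, 10_000):
    print(step, round(gradient_noise_std(step), 4), add_annealed_noise([0.5, -0.2], step, rng))
```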
Advantages
- Systematic approach to batch size optimization
- Linear scaling rule approximately preserves training dynamics, up to a task-dependent batch size
- Large batches improve GPU utilization and wall-clock time
- Small batches provide implicit regularization
Disadvantages
- Large batches can harm generalization
- Optimal batch size varies by task and model
- Requires hyperparameter tuning (learning rate, warmup)
- Memory constraints limit maximum batch size
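Filling GPU memory means finding the largest batch that fits, which is a search problem. A minimal sketch, assuming a caller-supplied `fits(batch_size)` predicate that is monotone (if a size fits, every smaller size fits): it doubles until a size fails, then binary-searches the remaining gap. In a real PyTorch setup `fits` might run one forward/backward pass and return False on an out-of-memory error (recent PyTorch versions raise `torch.cuda.OutOfMemoryError`); here the memory limit is simulated.

```python
def largest_fitting_batch(fits, low=1, high=4096):
    """Largest batch size in [low, high] for which fits(batch_size) is True.

    Assumes fits() is monotone. Returns None if even `low` does not fit.
    """
    if not fits(low):
        return None
    good = low  # invariant: `good` is known to fit
    # Phase 1: double until a size fails or the upper bound is reached
    while good < high:
        candidate = min(good * 2, high)
        if not fits(candidate):
            high = candidate - 1
            break
        good = candidate
    # Phase 2: binary search in (good, high]
    while good < high:
        mid = (good + high + 1) // 2
        if fits(mid):
            good = mid
        else:
            high = mid - 1
    return good


# Simulated limit (assumption): anything above 300 samples "runs out of memory"
print(largest_fitting_batch(lambda b: b <= 300))  # 300
```

Doubling first keeps the number of expensive trial passes logarithmic even when the limit is far above the starting size.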
```python
# Batch Size Trade-offs in Training
import copy
import time

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader


class BatchSizeOptimizer:
    def __init__(self, model, train_dataset, val_dataset, max_batch_size=512,
                 base_batch_size=32):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.max_batch_size = max_batch_size
        self.base_batch_size = base_batch_size
        # Snapshot the initial weights so every trial starts from the same point
        self.initial_state = copy.deepcopy(self.model.state_dict())

    def find_optimal_batch_size(self, learning_rate=0.001):
        """Try several batch sizes, scaling the LR with the linear scaling rule."""
        results = []
        for batch_size in [16, 32, 64, 128, 256, 512]:
            if batch_size > self.max_batch_size:
                continue
            # Linear scaling rule relative to base_batch_size. The rule was
            # established for SGD; with AdamW treat it as a starting heuristic.
            scaled_lr = learning_rate * (batch_size / self.base_batch_size)
            # Reset the weights so trials do not keep training the previous model
            self.model.load_state_dict(self.initial_state)
            results.append(self.train_with_batch_size(batch_size, scaled_lr))
        if not results:
            raise ValueError("max_batch_size is below every candidate batch size")
        # Rank by validation accuracy; throughput only breaks ties
        return max(results, key=lambda r: (r['val_accuracy'], r['samples_per_second']))

    def train_with_batch_size(self, batch_size, learning_rate, epochs=5, warmup_steps=100):
        """Train with a specific batch size and (already scaled) learning rate."""
        dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        scheduler = self.get_lr_scheduler(optimizer, warmup_steps=warmup_steps)
        samples_seen, train_seconds = 0, 0.0
        for _ in range(epochs):
            self.model.train()
            start = time.perf_counter()
            for data, target in dataloader:
                data, target = data.to(self.device), target.to(self.device)
                # Forward pass
                loss = nn.functional.cross_entropy(self.model(data), target)
                # Backward pass and parameter update
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Advance the warmup schedule once per optimizer step (not per epoch)
                scheduler.step()
                samples_seen += data.size(0)
            if self.device.type == "cuda":
                torch.cuda.synchronize()  # finish queued GPU work before reading the clock
            train_seconds += time.perf_counter() - start
        return {
            'batch_size': batch_size,
            'learning_rate': learning_rate,
            'val_accuracy': self.validate(),
            'samples_per_second': samples_seen / train_seconds,
        }

    @torch.no_grad()
    def validate(self, batch_size=256):
        """Top-1 accuracy on the validation set."""
        self.model.eval()
        correct, total = 0, 0
        for data, target in DataLoader(self.val_dataset, batch_size=batch_size):
            data, target = data.to(self.device), target.to(self.device)
            correct += (self.model(data).argmax(dim=1) == target).sum().item()
            total += target.size(0)
        return correct / total

    def get_lr_scheduler(self, optimizer, warmup_steps=100):
        """Linear warmup to the target LR over warmup_steps, then constant."""
        def lr_lambda(current_step):
            # +1 so the very first update does not use a learning rate of 0
            return min(1.0, (current_step + 1) / warmup_steps)
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    def add_gradient_noise(self, model, noise_level=0.01):
        """Add Gaussian gradient noise; call between loss.backward() and optimizer.step()."""
        for param in model.parameters():
            if param.grad is not None:
                param.grad.add_(torch.randn_like(param.grad) * noise_level)

    def sharpness_aware_minimization(self, model, optimizer, data, target, rho=0.05):
        """One SAM step: perturb the weights toward higher loss, then descend from there."""
        # 1) Gradient at the current weights w
        optimizer.zero_grad()
        nn.functional.cross_entropy(model(data), target).backward()
        params = [p for p in model.parameters() if p.grad is not None]
        grad_norm = torch.stack([p.grad.norm() for p in params]).norm()
        # 2) Move the weights to w + e, with e = rho * g / ||g||
        perturbations = []
        with torch.no_grad():
            for p in params:
                e = p.grad * (rho / (grad_norm + 1e-12))
                p.add_(e)
                perturbations.append(e)
        # 3) Gradient of the loss at the perturbed weights
        optimizer.zero_grad()
        loss_perturbed = nn.functional.cross_entropy(model(data), target)
        loss_perturbed.backward()
        # 4) Restore w, then update it using the gradient computed at w + e
        with torch.no_grad():
            for p, e in zip(params, perturbations):
                p.sub_(e)
        optimizer.step()
        return loss_perturbed.item()


# Usage example: ResNet-50 with a 10-class head on CIFAR-10
transform = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
model = torchvision.models.resnet50(weights=None, num_classes=10)
tuner = BatchSizeOptimizer(model, train_set, val_set, max_batch_size=256)
best_batch_config = tuner.find_optimal_batch_size(learning_rate=0.001)
print("Optimal batch size:", best_batch_config['batch_size'])
print("Optimal learning rate:", best_batch_config['learning_rate'])
print("Validation accuracy:", best_batch_config['val_accuracy'])
print("Throughput (samples/sec):", best_batch_config['samples_per_second'])
```
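The training code reports `samples_per_second`; the advice to compare samples/second and cost/sample rather than steps/second can be made concrete with a little arithmetic. The step rates and the $2.00/hour GPU price below are invented for illustration, not measurements, and the helper names are hypothetical.

```python
def samples_per_second(steps_per_second, batch_size):
    """Throughput in samples/second from a steps/second measurement."""
    return steps_per_second * batch_size


def cost_per_million_samples(samples_per_sec, price_per_hour):
    """Dollar cost of pushing one million samples through training."""
    seconds = 1_000_000 / samples_per_sec
    return price_per_hour * seconds / 3600


# Hypothetical measurements: steps/second at two batch sizes on a $2.00/hour GPU
measured = {32: 40.0, 256: 10.0}
for batch_size, steps_per_sec in measured.items():
    throughput = samples_per_second(steps_per_sec, batch_size)
    cost = cost_per_million_samples(throughput, price_per_hour=2.0)
    print(f"batch={batch_size}: {steps_per_sec} steps/s, {throughput:.0f} samples/s, "
          f"${cost:.3f} per 1M samples")
```

With these numbers batch 32 looks four times faster in steps/second, yet batch 256 processes 2560 samples/second versus 1280 and costs half as much per sample.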