Quantization Trade-offs
Reduce numerical precision to shrink memory use and speed up inference, at some cost in accuracy. Covers the trade-offs of FP32, BF16, FP16, FP8, INT8, INT4, and 1-bit formats.
Intent & Description
🎯 Intent
Balance model accuracy against memory footprint and inference speed by reducing numerical precision. Different quantization formats offer different accuracy-vs-efficiency trade-offs.
📋 Context
FP32 is the accuracy baseline but uses the most memory (4 bytes/parameter). BF16 and FP16 (2 bytes) halve memory with minimal accuracy loss. FP8 (1 byte) cuts memory to 1/4 of FP32 with small accuracy loss. INT8 (1 byte) gives a 2× reduction over FP16 with very small accuracy loss. INT4 (0.5 bytes, e.g. GPTQ/AWQ) gives a 4× reduction over FP16 with small-to-moderate accuracy loss. 1-bit quantization gives up to a 16× reduction over FP16, but with significant accuracy loss, and is still at the research stage.
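These reductions follow directly from bits per weight. A minimal sketch of the arithmetic, assuming weights dominate memory (KV cache and activations are ignored) and, for grouped INT4 formats such as GPTQ/AWQ, one FP16 scale per group of 128 weights with zero-points ignored. The helper names are hypothetical:

```python
def weight_memory_gb(num_params: float, bits_per_weight: float) -> float:
    """Weight storage in GB (1 GB = 1e9 bytes) at a given bit width."""
    return num_params * bits_per_weight / 8 / 1e9


def grouped_bits(bits: int, group_size: int, scale_bits: int = 16) -> float:
    """Effective bits per weight when each group of weights shares one scale.

    Assumption: one scale of `scale_bits` bits per group, zero-points ignored.
    """
    return bits + scale_bits / group_size


params_7b = 7e9
for name, bits in [("FP32", 32), ("BF16/FP16", 16), ("FP8/INT8", 8),
                   ("INT4", 4), ("INT4, group 128", grouped_bits(4, 128)),
                   ("1-bit", 1)]:
    print(f"{name:>16}: {weight_memory_gb(params_7b, bits):6.3f} GB")
```

For a 7B model this gives 28 GB at FP32, 14 GB at 16-bit, 7 GB at 8-bit, 3.5 GB at plain INT4 (about 3.6 GB once per-group scales are counted) and under 1 GB at 1 bit.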
💡 Solution
Use BF16 for training: it keeps FP32's 8-bit exponent, so its dynamic range is far wider than FP16's. Use INT8 as the default for inference. For INT4, prefer AWQ over GPTQ when possible. Use mixed-precision quantization, keeping sensitive layers in higher precision. Always benchmark on your specific tasks, because accuracy loss is task-dependent.
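The dynamic-range argument for BF16 can be checked without a GPU. BF16 keeps FP32's 8-bit exponent but only 7 mantissa bits, while FP16 has a 5-bit exponent (largest finite value 65504) and 10 mantissa bits. The sketch below emulates both formats with the standard library: FP16 through `struct`'s half-precision `"e"` format, BF16 by rounding a float32 bit pattern to its upper 16 bits. The helper names are hypothetical, and NaN handling is omitted:

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE half precision (FP16); values beyond its range become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round x to bfloat16 (round-to-nearest-even), returned as a Python float.

    BF16 is the upper 16 bits of an IEEE float32: same 8-bit exponent,
    mantissa cut from 23 to 7 bits. Assumes |x| fits in float32.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so the dropped low 16 bits round to nearest, ties to even
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


print(to_fp16(70000.0), to_bf16(70000.0))  # inf 70144.0
print(to_fp16(1.001), to_bf16(1.001))      # 1.0009765625 1.0
```

FP16 overflows at 70000 while BF16 stores it, coarsely, as 70144; in exchange FP16 resolves 1.001 more finely. This is why FP16 training usually relies on loss scaling to avoid overflow and underflow, while BF16 training usually does not.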
Real-world Use Case
📌 TL;DR
Quantization (memory relative to FP32): BF16/FP16 (2× smaller, minimal loss), INT8 (4× smaller, very small loss), INT4 AWQ (8× smaller, small loss). Use BF16 for training, INT8 as the inference default, and INT4 AWQ when cost matters most. Keep sensitive layers in higher precision (mixed precision).
Advantages
- Significant memory and cost savings (2-8× vs FP32)
- Minimal accuracy loss for moderate quantization (INT8)
- Different formats for different use cases
- Modern quantization methods preserve accuracy well
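Why INT8 loses less than INT4 can be seen from the rounding step alone. A minimal sketch of per-tensor symmetric (absmax) quantization in pure Python, assuming synthetic Gaussian weights as a stand-in for a real weight row. Production methods such as GPTQ and AWQ add per-group scales and calibration designed to reduce this error, so treat this naive version as an illustration, not a measurement of those methods:

```python
import random


def quantize_symmetric(weights, bits):
    """Per-tensor symmetric (absmax) quantization to signed integers."""
    qmax = 2 ** (bits - 1) - 1  # 127 for INT8, 7 for INT4
    scale = max(abs(w) for w in weights) / qmax
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Map integer codes back to approximate real values."""
    return [v * scale for v in q]


def mean_abs_error(weights, bits):
    """Mean absolute error of a quantize -> dequantize round trip."""
    q, scale = quantize_symmetric(weights, bits)
    restored = dequantize(q, scale)
    return sum(abs(w - r) for w, r in zip(weights, restored)) / len(weights)


rng = random.Random(0)
weights = [rng.gauss(0.0, 0.02) for _ in range(4096)]  # synthetic weight row

err_int8 = mean_abs_error(weights, 8)
err_int4 = mean_abs_error(weights, 4)
print(f"INT8 mean abs error: {err_int8:.2e}")
print(f"INT4 mean abs error: {err_int4:.2e}")
```

With 127 positive levels instead of 7, the INT8 step is about 18× finer than the INT4 step, and the rounding error shrinks accordingly.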
Disadvantages
- Accuracy loss varies by task and model
- Some layers more sensitive than others
- Quantization adds deployment complexity
- Extreme quantization (1-bit) still experimental
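One common reason some layers are more sensitive is outliers: with absmax scaling, a single large value stretches the scale for every other weight in the tensor. The sketch below injects one hypothetical outlier into synthetic weights and compares quantizing everything together against keeping the outlier in full precision. That is the idea behind mixed-precision schemes; LLM.int8(), for example, keeps outlier activation features in FP16. Helper names are hypothetical:

```python
import random


def int8_roundtrip(weights):
    """Symmetric per-tensor INT8 quantize -> dequantize (absmax scaling)."""
    scale = max(abs(w) for w in weights) / 127
    return [max(-127, min(127, round(w / scale))) * scale for w in weights]


def mean_abs_error(original, restored):
    return sum(abs(a - b) for a, b in zip(original, restored)) / len(original)


rng = random.Random(0)
weights = [rng.gauss(0.0, 0.02) for _ in range(4096)]
weights[0] = 2.0  # hypothetical outlier, ~100x the typical magnitude

# 1) Quantize the whole tensor: the outlier sets the scale for every weight
err_uniform = mean_abs_error(weights, int8_roundtrip(weights))

# 2) Mixed precision: keep the outlier unquantized, quantize the rest
mixed = [weights[0]] + int8_roundtrip(weights[1:])
err_mixed = mean_abs_error(weights, mixed)

print(f"all INT8:        {err_uniform:.2e}")
print(f"outlier kept FP: {err_mixed:.2e}")
```

Protecting a handful of values removes most of the error for the remaining weights, which is why a small amount of extra precision in the right place can be worth its memory cost.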
```python
# Quantization Trade-offs
# Requires torch, transformers and accelerate; the INT8/INT4 loaders also need bitsandbytes.
import gc
import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class QuantizationManager:
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def load_fp32(self):
        """Baseline FP32 model (4 bytes/parameter)."""
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float32,
            device_map="auto",
        )

    def load_bf16(self):
        """BF16: half the memory of FP32, minimal accuracy loss."""
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def load_int8(self):
        """INT8 via bitsandbytes (LLM.int8()): ~4x smaller weights than FP32."""
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )

    def load_int4(self):
        """INT4 via bitsandbytes NF4: ~8x smaller weights than FP32.

        This is NF4, not AWQ. An AWQ model is normally loaded from a checkpoint
        that was already quantized with AWQ (e.g. by AutoAWQ); from_pretrained
        then picks up the quantization config stored in that checkpoint.
        """
        config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=config,
            device_map="auto",
        )

    def load_mixed_precision(self, extra_skip_modules=()):
        """Mixed precision: INT8 for most layers, FP16 for sensitive modules.

        Names in extra_skip_modules are matched against module names, e.g.
        "model.layers.0" for the first decoder layer of a Llama-style model.
        """
        config = BitsAndBytesConfig(
            load_in_8bit=True,
            # Modules listed here are not quantized. lm_head is kept in higher
            # precision by default, so keep it in the list when extending it.
            llm_int8_skip_modules=["lm_head", *extra_skip_modules],
        )
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=config,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    @staticmethod
    def memory_gb(model):
        """Measured size of the loaded parameters and buffers, in GB.

        Unquantized parts (embeddings, norms, lm_head) stay in 16-bit, so the
        measured ratios are somewhat below the ideal 4x/8x.
        """
        return model.get_memory_footprint() / 1e9

    @torch.no_grad()
    def perplexity(self, model, texts, max_length=512):
        """Perplexity on held-out texts: a task-agnostic proxy for quality loss."""
        model.eval()
        total_nll, total_tokens = 0.0, 0
        for text in texts:
            enc = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=max_length
            )
            input_ids = enc.input_ids.to(model.device)
            n_predicted = input_ids.size(1) - 1
            if n_predicted < 1:
                continue  # need at least two tokens to score a next-token prediction
            # With labels=input_ids the model returns the mean next-token NLL
            loss = model(input_ids, labels=input_ids).loss
            total_nll += loss.item() * n_predicted
            total_tokens += n_predicted
        return math.exp(total_nll / total_tokens)


# Usage example
quant_manager = QuantizationManager("meta-llama/Llama-2-7b-hf")

# Placeholder evaluation data: replace with held-out samples from your own task,
# since accuracy loss is task-dependent.
eval_texts = [
    "Quantization reduces the numerical precision of model weights.",
    "The quick brown fox jumps over the lazy dog.",
]

variants = {
    "FP32": quant_manager.load_fp32,  # ~27 GB of weights for a 7B model
    "BF16": quant_manager.load_bf16,
    "INT8": quant_manager.load_int8,
    "INT4 (NF4)": quant_manager.load_int4,
    # Llama-2-7B has 32 decoder layers: keep the first and last in FP16
    "Mixed INT8/FP16": lambda: quant_manager.load_mixed_precision(
        extra_skip_modules=["model.layers.0", "model.layers.31"]
    ),
}

results = {}
for name, load in variants.items():
    model = load()
    results[name] = {
        "memory_gb": quant_manager.memory_gb(model),
        "perplexity": quant_manager.perplexity(model, eval_texts),
    }
    # Free each model before loading the next one
    del model
    gc.collect()
    torch.cuda.empty_cache()

# Benchmark quality impact relative to the FP32 baseline
baseline_ppl = results["FP32"]["perplexity"]
for name, r in results.items():
    increase = 100 * (r["perplexity"] - baseline_ppl) / baseline_ppl
    print(f"{name}: {r['memory_gb']:.1f} GB, perplexity {r['perplexity']:.2f} "
          f"({increase:+.1f}% vs FP32)")
```