Larger batches often improve training quality, but they also need more memory. Gradient accumulation runs K micro-batches sequentially, accumulating their gradients, and then takes a single optimizer step. This emulates a K× larger batch while activation memory stays at the size of one micro-batch. It is essential for CPU training, where RAM often limits the micro-batch to a handful of examples.
The basic loop
# Without accumulation:
for batch in dataloader:
    loss = forward(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# With accumulation:
for i, batch in enumerate(dataloader):
    loss = forward(batch) / K   # scale loss
    loss.backward()             # gradients accumulate
    if (i+1) % K == 0:
        optimizer.step()
        optimizer.zero_grad()

Each backward call adds gradients to the existing param.grad. Skipping the optimizer step and zero_grad lets K micro-batches contribute to one update. The loss must be scaled by 1/K so the accumulated gradient equals the gradient of the average loss.
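One detail the basic loop leaves open is what happens when the number of micro-batches is not a multiple of K: the trailing gradients stay in param.grad without an optimizer step. The following is a minimal runnable sketch of one way to handle that, assuming PyTorch; the toy data, the nn.Linear model, and the choice to scale the final partial group by its own size are illustrative assumptions, not something the text prescribes.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy regression data (illustrative): 20 examples, 4 features.
X = torch.randn(20, 4)
y = torch.randn(20, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=2)  # 10 micro-batches

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
K = 4  # accumulation steps; 10 micro-batches form groups of 4, 4, and 2

num_steps = 0
optimizer.zero_grad()
for i, (xb, yb) in enumerate(loader):
    is_last = (i + 1) == len(loader)
    # Size of the accumulation group this micro-batch belongs to,
    # so a short final group is still averaged correctly.
    group_start = (i // K) * K
    group_size = min(K, len(loader) - group_start)

    loss = nn.functional.mse_loss(model(xb), yb) / group_size
    loss.backward()  # adds into param.grad

    # Step on every K-th micro-batch, and also on the final one.
    if (i + 1) % K == 0 or is_last:
        optimizer.step()
        optimizer.zero_grad()
        num_steps += 1

print("optimizer steps:", num_steps)
```

With 10 micro-batches and K = 4, the optimizer steps after micro-batches 4, 8, and 10, so no gradient is left over when the loop ends.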
Math equivalence
# Accumulated gradient over K micro-batches:
g_total = (1/K) * sum over k of g_k
# Same as gradient of average loss:
g_total = grad((1/K) * sum over k of loss_k)

As long as the model is purely a per-example function (no batch statistics) and the micro-batches are the same size, accumulation matches a K× larger batch up to floating-point rounding. BatchNorm breaks this because it normalizes with the statistics of the current micro-batch. RMSNorm and LayerNorm do not, since they normalize each token independently.
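The equivalence can be checked numerically on a model small enough to differentiate by hand. The sketch below uses a hypothetical one-parameter model y_hat = w * x with mean squared error; the data values and the helper name grad_mse are made up for illustration.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of mean((w*x - y)^2) over the given examples."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 3.0, 7.0, 9.0, 9.0, 13.0, 15.0, 15.0]

# Gradient from one large batch of 8 examples.
g_large = grad_mse(w, xs, ys)

# Gradient accumulated over K = 4 micro-batches of 2 examples each,
# with every micro-batch gradient scaled by 1/K.
K = 4
micro = 2
g_accum = 0.0
for k in range(K):
    lo, hi = k * micro, (k + 1) * micro
    g_accum += grad_mse(w, xs[lo:hi], ys[lo:hi]) / K

# The two agree up to floating-point rounding.
assert abs(g_large - g_accum) < 1e-9
print(g_large, g_accum)
```

The check relies on every micro-batch having the same size. With unequal sizes, averaging the per-micro-batch means no longer equals the mean over all examples.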
Memory trade-off
Memory needed: activations for one micro-batch, plus full gradients and optimizer state (neither of which shrinks with accumulation). Compute trade-off: K forward and backward passes per optimizer step, which is the same total compute as one batch K times larger, but run sequentially. Effective batch size: K × micro-batch size.
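As a quick arithmetic aid, the bookkeeping above can be written as a small helper. This is only a sketch: the function name accumulation_plan and the 1000-example dataset are hypothetical, and it assumes a final partial group still triggers an optimizer step.

```python
import math

def accumulation_plan(num_examples, micro_batch, K):
    """Per-epoch counts for gradient accumulation (illustrative helper)."""
    micro_steps = math.ceil(num_examples / micro_batch)
    return {
        "effective_batch": micro_batch * K,
        # One forward+backward pass per micro-batch, regardless of K.
        "forward_backward_passes": micro_steps,
        # Assumes a trailing partial group still gets an optimizer step.
        "optimizer_steps": math.ceil(micro_steps / K),
    }

plan = accumulation_plan(num_examples=1000, micro_batch=1, K=32)
print(plan)
```

Raising K changes only the number of optimizer steps. The number of forward and backward passes per epoch stays fixed by the micro-batch size.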
CPU example
# 350M param model, training 32 examples per effective batch:
micro_batch = 1                      # all that fits in 16GB RAM
K = 32                               # accumulation steps
effective_batch = micro_batch * K    # 32

for i, ex in enumerate(loader):
    loss = model(ex) / K             # model(ex) is assumed to return the loss
    loss.backward()
    if (i+1) % K == 0:
        optim.step()
        optim.zero_grad()

With micro-batch = 1 (often the only choice on CPU for many SLMs), gradients accumulate across examples and the optimizer steps once every K examples. The resulting update matches what a batch of 32 would produce on a GPU.
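To see why so little room is left for activations, it helps to estimate the memory that does not depend on batch size. The sketch below is back-of-the-envelope arithmetic only, and it assumes fp32 values (4 bytes each) and an Adam-style optimizer that keeps two extra buffers per parameter. Neither assumption comes from the example above, and the helper name training_state_gb is hypothetical.

```python
def training_state_gb(num_params, bytes_per_value=4, optimizer_buffers=2):
    """Rough size of weights + gradients + optimizer state, in GB (1e9 bytes).

    Assumes every copy uses the same precision; activations are excluded.
    """
    copies = 1 + 1 + optimizer_buffers  # weights, gradients, optimizer buffers
    return num_params * bytes_per_value * copies / 1e9

fixed_gb = training_state_gb(350_000_000)
remaining_gb = 16 - fixed_gb  # what is left for activations and everything else

print(f"fixed state: {fixed_gb:.1f} GB, remaining of 16 GB: {remaining_gb:.1f} GB")
```

Under these assumptions, roughly a third of the 16GB is taken before any activation is stored, and accumulation does not reduce that part.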
Combining with LR scaling
A common rule of thumb is to scale the LR linearly with effective batch size, up to some limit. Strict linear scaling from batch=8 to batch=32 would mean 4× the peak LR; a more cautious starting point is to double or triple it. Either way, re-tune afterwards rather than applying the rule blindly.
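The linear rule with a cap can be sketched as a small helper. The function name scaled_lr, the base LR of 1e-4, and the cap value are hypothetical, chosen only to illustrate the batch=8 to batch=32 case; the result is a starting point for tuning, not a recommended setting.

```python
def scaled_lr(base_lr, base_batch, new_batch, max_lr=None):
    """Linear LR scaling with an optional upper limit (illustrative helper)."""
    lr = base_lr * new_batch / base_batch
    if max_lr is not None:
        lr = min(lr, max_lr)  # the "up to some limit" part of the rule
    return lr

base_lr = 1e-4  # assumed peak LR that worked at batch size 8

uncapped = scaled_lr(base_lr, base_batch=8, new_batch=32)
capped = scaled_lr(base_lr, base_batch=8, new_batch=32, max_lr=3e-4)

print(uncapped, capped)
```

Uncapped, the rule multiplies the LR by the full batch ratio of 4. The cap holds the increase to 3× here, which corresponds to the more conservative end of the advice.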