Training Nets with Large Batch Size
Research from early 2018 had Yann LeCun saying we shouldn’t use batch sizes larger than 32 because they hurt test error. The paper points out that if you scale the learning rate linearly with batch size, the variance of each weight update grows with the batch size, even though the averaged gradient itself becomes less noisy.
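Here is a back-of-the-envelope sketch of that effect. It assumes per-example gradients are independent with a single scalar variance `sigma2` (a simplification), and the `base_lr`/`base_batch` values are arbitrary placeholders. Averaging over a batch of size B divides the gradient variance by B, but the variance of the actual step also scales with the square of the learning rate. Under linear scaling the step variance therefore grows linearly with B; under square-root scaling it stays constant.

```python
import math

# Simplifying assumption: per-example gradients (one coordinate) are i.i.d.
# with variance sigma2.


def grad_variance(sigma2: float, batch_size: int) -> float:
    """Variance of the minibatch-averaged gradient: shrinks as 1/batch_size."""
    return sigma2 / batch_size


def update_variance(sigma2: float, batch_size: int, base_lr: float,
                    base_batch: int, rule: str) -> float:
    """Variance of one SGD step (lr * averaged gradient) when the learning
    rate is scaled up from (base_lr, base_batch) by the given rule."""
    k = batch_size / base_batch
    if rule == "linear":
        lr = base_lr * k
    elif rule == "sqrt":
        lr = base_lr * math.sqrt(k)
    else:
        raise ValueError(f"unknown rule: {rule}")
    return lr ** 2 * grad_variance(sigma2, batch_size)


# Compare the two scaling rules for a few batch sizes (placeholder values).
for b in (32, 64, 128, 256):
    print(b,
          grad_variance(1.0, b),
          update_variance(1.0, b, base_lr=0.1, base_batch=32, rule="linear"),
          update_variance(1.0, b, base_lr=0.1, base_batch=32, rule="sqrt"))
```

The printed columns show the gradient variance falling, the linear-rule step variance rising, and the square-root-rule step variance staying flat.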
This post summarizes recent research into using large batches for training.
A Common Heuristic
To review: when you’re training a deep net with minibatch SGD and want it to converge in fewer steps, you can increase the batch size. Since each gradient is computed over more examples, it is a less noisy estimate, so you can take larger steps toward a minimum of the loss (that is, you can increase your learning rate).
A common rule of thumb among practitioners has historically been the linear scaling rule: if you double your batch size, you should also double your learning rate. Proponents of the rule also recommend warming up the learning rate, starting small and increasing it over the first part of training.
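A minimal sketch of the linear scaling rule combined with a linear warmup. The reference point (`base_lr=0.1` at `base_batch=256`) and the warmup length are hypothetical placeholders to replace with your own; after warmup this sketch simply holds the scaled rate, whereas real schedules usually decay it afterwards.

```python
def scaled_lr(batch_size: int, base_lr: float = 0.1, base_batch: int = 256) -> float:
    """Linear scaling rule: learning rate proportional to batch size."""
    return base_lr * batch_size / base_batch


def lr_at_step(step: int, batch_size: int, warmup_steps: int,
               base_lr: float = 0.1, base_batch: int = 256) -> float:
    """Ramp linearly from base_lr to the scaled rate over warmup_steps,
    then hold the scaled rate (no decay in this sketch)."""
    target = scaled_lr(batch_size, base_lr, base_batch)
    if step < warmup_steps:
        return base_lr + (target - base_lr) * step / warmup_steps
    return target


# Batch size 1024 is 4x the reference batch, so the target rate is 4x base_lr.
for step in (0, 250, 500, 1000):
    print(step, lr_at_step(step, batch_size=1024, warmup_steps=500))
```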
The heuristic doesn’t work
As it turns out, this assumption doesn’t hold in general. Recent research from Google Brain (Nov 2018) shows that optimal effective learning rates don’t always follow linear or square-root scaling heuristics. Instead, as you increase the batch size, there is initially a period of perfect scaling, where doubling the batch size halves the number of training steps needed. Then comes a period of diminishing returns, and finally a point of maximal data parallelism, beyond which larger batches no longer reduce the number of steps. Where these three periods (or “scaling regimes”) begin and end varies across problems: for example, they found that different nets (e.g. a simple CNN vs. a Transformer vs. ResNet-50) have different scaling characteristics as batch size increases. So there’s no single heuristic you can follow.
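One way to make the three regimes concrete: given measurements of how many steps it took to reach a fixed target error at each batch size, compare the actual speedup between consecutive batch sizes with the ideal (linear) speedup. The sketch below labels each step up in batch size and picks the largest batch still in perfect scaling. The thresholds (0.9 and 0.05) and the `MEASUREMENTS` numbers are made up for illustration; they are not from the Google Brain paper.

```python
from typing import List, Tuple

# Hypothetical (batch_size, steps_to_target) pairs, invented for illustration.
MEASUREMENTS = [(32, 64000), (64, 32000), (128, 16000),
                (256, 10000), (512, 8000), (1024, 7900)]


def scaling_efficiency(b_prev: int, s_prev: float, b_next: int, s_next: float) -> float:
    """Fraction of the ideal speedup achieved: 1.0 means steps fell in exact
    proportion to the batch-size increase, 0.0 means no reduction at all."""
    ideal = b_next / b_prev
    actual = s_prev / s_next
    return (actual - 1) / (ideal - 1)


def label_regimes(measurements: List[Tuple[int, float]],
                  perfect: float = 0.9, saturated: float = 0.05) -> List[Tuple[int, int, str]]:
    """Label each consecutive pair of (distinct) batch sizes with a regime."""
    pts = sorted(measurements)
    labels = []
    for (b0, s0), (b1, s1) in zip(pts, pts[1:]):
        eff = scaling_efficiency(b0, s0, b1, s1)
        if eff >= perfect:
            regime = "perfect"
        elif eff <= saturated:
            regime = "maximal"
        else:
            regime = "diminishing"
        labels.append((b0, b1, regime))
    return labels


def largest_perfect_batch(measurements: List[Tuple[int, float]], perfect: float = 0.9) -> int:
    """Largest batch size reached before scaling first stops being near-linear."""
    pts = sorted(measurements)
    best = pts[0][0]
    for (b0, s0), (b1, s1) in zip(pts, pts[1:]):
        if scaling_efficiency(b0, s0, b1, s1) < perfect:
            break
        best = b1
    return best


print(label_regimes(MEASUREMENTS))
print(largest_perfect_batch(MEASUREMENTS))
```

Getting real numbers for `MEASUREMENTS` is the expensive part, since each batch size needs its own hyperparameter tuning.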
Also note that for every batch size they tested, all hyperparameters were tuned for optimal performance. This is done so that everything else is controlled and we see only how batch size affects training. If you want to find the optimal batch size for your task the way they did, you’ll also have to retune at every batch size. There are alternatives: for example, research from OpenAI (Dec 2018) describes a metric called the gradient noise scale that predicts the largest useful batch size.
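The OpenAI paper defines a “simple” noise scale, B_simple = tr(Σ) / |G|², where G is the true gradient and Σ is the covariance of per-example gradients. Below is a naive plug-in estimate from a handful of per-example gradient vectors (plain Python lists here, as an assumption). It is biased when the sample is small, and the paper discusses more practical estimators. Roughly, batch sizes well below the noise scale are where a bigger batch buys a near-linear speedup.

```python
from typing import Sequence


def simple_noise_scale(per_example_grads: Sequence[Sequence[float]]) -> float:
    """Naive plug-in estimate of B_simple = tr(Sigma) / |G|^2.

    G is estimated by the mean of the per-example gradients and tr(Sigma)
    by their average squared distance from that mean. Fails (division by
    zero) if the estimated mean gradient is exactly zero.
    """
    n = len(per_example_grads)
    dim = len(per_example_grads[0])
    mean = [sum(g[j] for g in per_example_grads) / n for j in range(dim)]
    grad_norm_sq = sum(m * m for m in mean)
    trace_cov = sum(
        (g[j] - mean[j]) ** 2 for g in per_example_grads for j in range(dim)
    ) / n
    return trace_cov / grad_norm_sq


# Toy 2-D example: mean gradient (2, 0), per-example spread of 1 around it.
grads = [[1.0, 0.0], [3.0, 0.0], [2.0, 1.0], [2.0, -1.0]]
print(simple_noise_scale(grads))
```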
Finding the ideal batch size
Since we know there are three regimes, the ideal batch size would be the largest one still within the “perfect scaling regime” (where increasing the batch size gives a linear reduction in the number of training steps needed to converge). Finding it is difficult, though: it essentially requires repeating the process the researchers used in the Google paper. To summarize their findings (if you’re short on time, the paper’s plots and section headers cover the main points):
- The model matters. Different architectures reach their maximum useful batch size at different points, and training well at large batch sizes means retuning hyperparameters extensively (for each batch size you try).
- The data also affects the ideal batch size, but not by much, and not in the way we’d expect. More data doesn’t mean we get to use a larger batch size (sometimes it actually seems like the opposite).
- The optimizer matters: SGD with momentum extends the perfect scaling regime compared with plain SGD, meaning you can use a slightly larger batch size.
Interestingly, going back in time from this paper, a 2017 paper introduced a new kind of layer-wise optimizer called LARS (Layer-wise Adaptive Rate Scaling), which turns out to work quite well with large batch sizes.
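A sketch of a single LARS step for one layer, under stated assumptions. Each layer gets a local learning rate λ = η·‖w‖ / (‖∇L(w)‖ + β‖w‖), where η is a “trust coefficient” and β the weight decay, and the scaled, weight-decayed gradient is fed into a momentum update. The default values and the fallback to a local rate of 1.0 when the weights or the denominator are zero are assumptions, and plain Python lists stand in for tensors.

```python
import math
from typing import List, Tuple


def lars_step(weights: List[float], grads: List[float], velocity: List[float],
              lr: float, trust_coef: float = 0.001, weight_decay: float = 5e-4,
              momentum: float = 0.9) -> Tuple[List[float], List[float]]:
    """One LARS-style momentum update for a single layer.

    Returns (new_weights, new_velocity). Defaults are placeholder values.
    """
    w_norm = math.sqrt(sum(w * w for w in weights))
    g_norm = math.sqrt(sum(g * g for g in grads))
    denom = g_norm + weight_decay * w_norm
    # Layer-wise local learning rate; assumed fallback of 1.0 for degenerate norms.
    local_lr = trust_coef * w_norm / denom if w_norm > 0 and denom > 0 else 1.0
    new_velocity = [
        momentum * v + lr * local_lr * (g + weight_decay * w)
        for v, g, w in zip(velocity, grads, weights)
    ]
    new_weights = [w - v for w, v in zip(weights, new_velocity)]
    return new_weights, new_velocity


# Same weights, gradients 100x larger: with no weight decay or momentum,
# the resulting update is the same.
w_small, _ = lars_step([3.0, 4.0], [0.6, 0.8], [0.0, 0.0], lr=1.0,
                       weight_decay=0.0, momentum=0.0)
w_large, _ = lars_step([3.0, 4.0], [60.0, 80.0], [0.0, 0.0], lr=1.0,
                       weight_decay=0.0, momentum=0.0)
print(w_small, w_large)
```

Because the step size is proportional to ‖w‖/‖∇L(w)‖, rescaling a layer’s gradient leaves its update essentially unchanged (exactly so with zero weight decay), which keeps each layer’s update in proportion to its weights.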
For most people right now it doesn’t seem to make much sense to use large batch sizes; you’d essentially need TPUs plus expensive hyperparameter tuning (perhaps through a service). Even if you do go this route, there are other challenges, such as how to load that much data into your hardware.
But it is really interesting to know that large-batch training can be and is done, particularly by organizations with the resources to do so, and that limiting yourself to a batch size of 32 in all cases isn’t always the smart thing to do.
Related paper: researchers trained ResNet-50 in 122 seconds using extremely large batch sizes.