ML Model Distillation
Some key things to know about distilling ML models:
- If you have a big neural net, it can have something like O(10,000,000) parameters and O(1,500,000,000) MultAdds. Parameters ~= how much memory your model uses, and MultAdds ~= how much CPU your model uses, which determines how fast it can deliver a response (latency). (See the counting sketch after this list.)
- If you want to put a neural net on mobile, you should reduce both params and MultAdds. The model above, for example, might get reduced to something like O(2,000,000) params (1/5 the size) and O(100,000,000) MultAdds (roughly 1/15 as many).
- You can reduce MultAdds in a CNN a few ways (see the depthwise separable sketch after this list):
  - Shrink the number of convolutional filters per layer (depth shrinking).
  - Successively decrease the feature-map resolution between layers (internal resolution shrinking). This requires a large enough input image to work well, and it's OK to blow up a small image to make that possible.
  - Factorize your convolutions into what's called depthwise separable convolutions: the standard 3D convolution in a vision CNN gets replaced by a depthwise convolution (one k x k filter per input channel) followed by a 1x1 pointwise convolution that mixes channels. This cuts out a lot of the redundant calculation.
- You can reduce parameters by making your network shallower and giving it fewer nodes per layer. "But isn't that going to make my network less accurate and less great?" Yeah, unless you do this cool thing called distillation, where you teach your small network to mimic a state-of-the-art network. Instead of training it like a normal model (hard targets for classification: is it a bird or not a bird), you train it to match what the complex model predicts (soft targets: how did the complex model come to 0.7662 confidence that this image contains a bird?). See the distillation-loss sketch after this list.
- Another key thing: it's not necessary for all target classes to be represented in the distillation training set to get high accuracy. For example, if your complex model recognizes handwritten digits (MNIST) and your small model never sees a single 3 during training (but the big model did), the small model can still learn to classify 3s with ~98% accuracy, because it learns how your big model generalizes.
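To make the parameters-vs-MultAdds distinction concrete, here is a minimal counting sketch in plain Python for a single standard conv layer. The layer sizes and helper names (`conv2d_params`, `conv2d_multadds`) are made up for illustration, not taken from any library or any specific model:

```python
def conv2d_params(c_in, c_out, k, bias=True):
    # One k x k x c_in filter per output channel, plus optional biases.
    return k * k * c_in * c_out + (c_out if bias else 0)

def conv2d_multadds(c_in, c_out, k, h_out, w_out):
    # Each output pixel of each output channel costs k*k*c_in multiply-adds.
    return k * k * c_in * c_out * h_out * w_out

# Example: a 3x3 conv from 256 to 256 channels on a 28x28 feature map.
print(f"params:   {conv2d_params(256, 256, 3):,}")            # ~590K  -> memory
print(f"multadds: {conv2d_multadds(256, 256, 3, 28, 28):,}")  # ~462M  -> compute / latency
```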
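Here is the depthwise separable factorization from the MultAdds item, sketched with PyTorch (an assumption; the original doesn't name a framework). It swaps a standard 3x3 conv for a depthwise conv (`groups=c_in`) plus a 1x1 pointwise conv and compares parameter counts; MultAdds shrink by the same factor for a fixed feature-map size:

```python
import torch
import torch.nn as nn

c_in, c_out, k = 256, 256, 3

standard = nn.Conv2d(c_in, c_out, kernel_size=k, padding=1)

depthwise_separable = nn.Sequential(
    # Depthwise: each input channel gets its own k x k filter.
    nn.Conv2d(c_in, c_in, kernel_size=k, padding=1, groups=c_in),
    # Pointwise: a 1x1 conv recombines the channels.
    nn.Conv2d(c_in, c_out, kernel_size=1),
)

def count_params(m):
    return sum(p.numel() for p in m.parameters())

x = torch.randn(1, c_in, 28, 28)
assert standard(x).shape == depthwise_separable(x).shape  # same output shape

print("standard:           ", count_params(standard))             # ~590K
print("depthwise separable:", count_params(depthwise_separable))  # ~68K, roughly 8-9x fewer
```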
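And here is a minimal distillation-loss sketch (again assuming PyTorch) for the soft-target training described above, in the spirit of Hinton et al.'s knowledge distillation: the student is trained to match the teacher's temperature-softened probabilities, mixed with the usual hard-label cross-entropy. The temperature `T=4.0` and weight `alpha=0.9` are illustrative values, not taken from the original text:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    # Soft targets: how confident the teacher is in every class, softened by T.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradients stay comparable across temperatures
    # Hard targets: the usual "is it a bird or not" cross-entropy on true labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

# Toy usage with random logits for a 10-class problem (e.g. MNIST digits).
student_logits = torch.randn(32, 10, requires_grad=True)
teacher_logits = torch.randn(32, 10)   # would come from the big, frozen teacher model
labels = torch.randint(0, 10, (32,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
```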