Sequential Learning through Knowledge Distillation
In the last blog post, we saw that neural networks can be simplified (or pruned) to reduce their complexity and computational cost. Small improvements to a network's efficiency can have large effects downstream, and reducing a network's inference time is only one of many strategies for optimizing a machine learning pipeline. In this post, we will discuss another such strategy: sequential learning.
A Tale of Two Cities
Imagine two cities, City A and City B, both with a need for machine learning. In City A, Alice is the Director of Urban Development and needs to estimate car $\text{CO}_2$ emissions in her city. To do that, she needs to know the distribution of different makes and models of cars on the road. So she sets up a camera at a busy intersection and hires Striveworks to develop a convolutional neural network to determine the distribution of car makes and models in her video footage. The resulting model determines the make and model of a car in a photo with 90% accuracy. She is very pleased and puts it to work right away, hosted on her own Chariot inference server.
Meanwhile, over in City B, Bob is an inventory manager at a car manufacturing facility. He hears of Alice's success with machine learning and wants to replicate it for his own problem: he needs to estimate the distribution of car colors in his city so that he can purchase the right amount of paint for the factory. Bob knows that he can use a convolutional neural network like Alice did. Unfortunately, he doesn't have a budget for developing and hosting his own machine learning model, but he thinks he has a way around that. He reaches out to his friend Alice and asks to use her model, promising that he can retrain it to do his task without sacrificing its accuracy on classifying car makes and models. This way they can share one model for both of their tasks. Alice agrees, but lets him know that she no longer has any of her original training data. So Bob will have to retrain Alice's model using only his own data, without a significant loss in accuracy on her task. This problem, training an existing model on a new task while preserving its performance on the old one, is known as sequential learning.
The Problem: Catastrophic Forgetting
Bob's situation can be stated more precisely as follows. Suppose you are given a model $f$ that has been trained on task A, and you want to retrain it so that it performs adequately on a new task, task B, without significantly diminishing its original performance on task A. A common approach is to retrain $f$ on the datasets from both tasks together; however, this isn't always possible (as we saw with Bob), and even when it is, it can be much more computationally expensive, depending on what task A is. For this reason, the problem is often posed with the restriction that the dataset for task A is inaccessible or only minimally accessible. For simplicity, we will assume that $f$ is a classifier and that both tasks are classification tasks.
The naïve approach would be to simply train $f$ on task B with a standard loss function $\mathcal{L}_B$ (say, cross-entropy) until it reaches a desirable validation accuracy. Unfortunately, this often leads to "catastrophic forgetting", wherein the model quickly "forgets" task A while learning task B [2]. In terms of decision boundaries, this means that a new decision boundary is being learned for task B that doesn't necessarily respect the old decision boundary (see schematic below).
Two classifiers being combined into a single classifier. In general, the decision boundaries of the two classifiers cannot merely be superimposed; instead new boundaries must be learned that respect both tasks.
Catastrophic forgetting is a difficult problem in real-world scenarios. Often (as with Bob) it is economically preferable to host a single model that can perform multiple tasks rather than a separate model for each task. Additionally, if artificial neural networks are to behave more like mammalian brains, they should retain some notion of "memory" of old tasks without needing the old data to be re-introduced.
Some Solutions: Learning Rates, Knowledge Distillation and Elastic Weight Consolidation
In this post, we will study the efficacy of three common alleviations to catastrophic forgetting:
-
Lowering the learning rate (also called fine-tuning [3]).
-
Knowledge Distillation (KD): $\mathcal{L}_B + \kappa \mathcal{L}_{KD}$.
-
Elastic Weight Consolidation (EWC): $\mathcal{L}_B + \lambda \mathcal{L}_{EWC}$.
The rationale behind lowering the learning rate is that smaller step sizes may keep each update from moving the weights too far at once. Knowledge Distillation is due to Hinton, Vinyals and Dean [1] and was adapted to this setting in [3]; Elastic Weight Consolidation is due to Kirkpatrick et al. [2]. Both of these techniques add an extra loss term to the base loss $\mathcal{L}_B$ that discourages the model from deviating too far from its original state at time zero, when training on task B begins. We will see that the KD loss $\mathcal{L}_{KD}$ is an "extrinsic" measure of this deviation and the EWC loss $\mathcal{L}_{EWC}$ is an "intrinsic" measure of this deviation. In the next sections we give a short account of these two loss terms and their formulae.
Knowledge Distillation
Knowledge distillation (KD) was originally introduced by [1] and is part of the scheme used in [3]. It consists of adding a loss term that discourages the network's outputs from deviating too far from those of the original network. Let $\theta_A$ be the weights of $f$ after it has been trained on task A. An extrinsic proxy for the state of $f$ after task A is the collection of its softmaxed outputs ("responses") on the task B inputs $\mathcal{X}_B$:
\[ \{ \tilde{y}_i := \text{softmax}(f_{\theta_A}(x_i)) \mid x_i \in \mathcal{X}_B \} \]
Let $y_i = \text{softmax}(f_{\theta}(x_i))$ denote the softmaxed outputs of $f$ during training on task B. The knowledge distillation loss term $\mathcal{L}_{KD}$ is a cross-entropy-like function in which the recorded responses $\tilde{y}_i$ serve as targets:
\begin{equation} \mathcal{L}_{KD} = -\sum_i \sum_k \mathcal{F}_T(\tilde{y}_i)^{(k)} \log \mathcal{F}_T(y_i)^{(k)} \end{equation}
where $i$ runs over the task B samples, $k$ runs over the task A output classes, and $\mathcal{F}_T$ is defined componentwise by
\[ \mathcal{F}_T(u)^{(k)} = \frac{(u^{(k)})^{1/T}}{\sum_j (u^{(j)})^{1/T}}, \quad T \in \mathbb{R}_{> 0} \]
The purpose of $\mathcal{F}_T$, known as temperature scaling, is to artificially raise the lower probabilities in the softmax vectors; [1] found it to be effective for $T > 1$. Adding the KD loss term encourages the network to seek weights that recreate the original responses $\tilde{y}_i$. Assuming the data in dataset B is reasonably broadly distributed, the responses $\tilde{y}_i$ will ideally carry enough information about the old network to preserve its performance on task A.
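A minimal pure-Python sketch of $\mathcal{F}_T$ and $\mathcal{L}_{KD}$ on a few probability vectors follows. It uses the convention of [3], in which the recorded old responses act as the targets. The function names, the default $T = 2$, and the toy probability vectors are illustrative assumptions; in practice these quantities would be computed on batches of tensors inside a deep learning framework.

```python
import math


def temperature_scale(probs, T):
    """F_T: raise each probability to the power 1/T and renormalize (requires T > 0)."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    powered = [p ** (1.0 / T) for p in probs]
    total = sum(powered)
    return [p / total for p in powered]


def kd_loss(current_probs, old_probs, T=2.0, eps=1e-12):
    """L_KD summed over samples: cross-entropy between the temperature-scaled
    old responses (targets) and the temperature-scaled current outputs.

    current_probs[i]: softmax(f_theta(x_i)) over the task A classes (assumed layout)
    old_probs[i]:     softmax(f_theta_A(x_i)), recorded once before task B training
    T = 2.0 is an assumed default, not a value reported in this post.
    """
    loss = 0.0
    for y, y_tilde in zip(current_probs, old_probs):
        target = temperature_scale(y_tilde, T)
        current = temperature_scale(y, T)
        # eps guards against log(0) if a probability underflows to zero
        loss -= sum(t * math.log(max(c, eps)) for t, c in zip(target, current))
    return loss


def total_loss_kd(loss_b, current_probs, old_probs, kappa, T=2.0):
    """L_B + kappa * L_KD, where loss_b is the task B cross-entropy computed elsewhere."""
    return loss_b + kappa * kd_loss(current_probs, old_probs, T)


# Hypothetical responses on two task B images (3 task A classes for brevity)
old = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
drifted = [[0.2, 0.2, 0.6], [0.3, 0.3, 0.4]]
print(temperature_scale(old[0], T=2.0))           # the 0.1 entry is raised toward the others
print(kd_loss(old, old), kd_loss(drifted, old))   # drifting from the old responses costs more
```

Because cross-entropy is minimized when the two distributions coincide, $\mathcal{L}_{KD}$ is smallest when the current network reproduces the recorded responses exactly, which is the behavior the extra loss term is meant to reward.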
Elastic Weight Consolidation
Elastic Weight Consolidation (EWC) was introduced in [2] and is a more direct way to enforce that $\theta$ stays close to $\theta_A$. To do this, one might add a loss term of the form
\[\mathcal{L}(\theta) = \sum_j (\theta^{(j)} -\theta_A^{(j)})^2 \]
which punishes large deviations from $\theta_A$. However, this implicitly assumes that all parameters matter equally, which is generally not true for a neural network. To determine a weighting for each parameter component $\theta^{(j)}$, the authors of [2] use the Fisher information matrix $F_{ij}$, which is the Hessian of the K-L divergence between nearby parameter settings, evaluated at $\theta' = \theta$:
\begin{equation} F_{ij} = \frac{\partial^2}{\partial \theta'^{(i)} \partial \theta'^{(j)}} D_{KL}(f_{\theta} \,\|\, f_{\theta'}) \Big|_{\theta' = \theta} = -\mathbb{E} \left[ \frac{\partial^2}{\partial \theta^{(i)} \partial \theta^{(j)}} \log f_\theta(x) \ \big | \ \theta \right] \end{equation}
The diagonal entries $F_{ii}$, evaluated at $\theta = \theta_A$, measure how sensitive the output distribution of $f_{\theta_A}$ is to small changes in the parameter $\theta^{(i)}$: the higher the value, the more that parameter matters for task A. This leads to the Elastic Weight Consolidation loss:
\begin{equation} \mathcal{L}_{EWC}(\theta) = \sum_j F_{jj} (\theta^{(j)}-\theta_A^{(j)})^2 \end{equation}
This loss term encourages the neural network to stay parametrically near $\theta_A$ along the directions that matter most for task A, while leaving the less important parameters free to learn task B. Note that this loss term uses an intrinsic measure of the network's state, namely the Fisher information, to punish large deviations from that state.
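The Fisher-weighted penalty $\mathcal{L}_{EWC}$ can be sketched in a few lines, assuming the parameters and the Fisher diagonal have already been flattened into lists in the same order. The three-parameter model and its Fisher values are hypothetical, chosen only to show that moving an "important" parameter costs far more than moving an unimportant one by the same amount.

```python
def ewc_penalty(theta, theta_A, fisher_diag):
    """L_EWC(theta) = sum_j F_jj * (theta_j - theta_A_j)^2 over flattened parameter lists."""
    if not (len(theta) == len(theta_A) == len(fisher_diag)):
        raise ValueError("theta, theta_A and fisher_diag must have the same length")
    return sum(f * (t - ta) ** 2 for t, ta, f in zip(theta, theta_A, fisher_diag))


def total_loss_ewc(loss_b, theta, theta_A, fisher_diag, lam):
    """L_B + lambda * L_EWC, where loss_b is the task B cross-entropy computed elsewhere."""
    return loss_b + lam * ewc_penalty(theta, theta_A, fisher_diag)


# Hypothetical model: parameter 0 matters a lot for task A, parameter 1 barely matters.
theta_A = [1.0, -0.5, 2.0]
fisher = [10.0, 0.1, 1.0]
move_important = [1.1, -0.5, 2.0]    # shift parameter 0 by 0.1
move_unimportant = [1.0, -0.4, 2.0]  # shift parameter 1 by 0.1
print(ewc_penalty(move_important, theta_A, fisher))    # about 0.1
print(ewc_penalty(move_unimportant, theta_A, fisher))  # about 0.001
```

Both moves have the same unweighted squared distance from $\theta_A$, so the plain quadratic penalty would treat them identically; the Fisher weighting is what makes the penalty 100 times larger for the parameter that task A depends on.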
In order to actually compute the Fisher information, one needs a small but representative sample of the task A training data. While this technically violates our constraint of not having access to the task A training data, in practice such a sample is often obtainable.
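To show how such a sample is used, here is a sketch that estimates the Fisher diagonal for a hypothetical linear softmax classifier $p(y \mid x) = \text{softmax}(Wx)$, whose log-probability gradients have a closed form. For a real network the per-sample gradients would come from automatic differentiation instead. This version takes the expectation over labels exactly under the model's own predictive distribution; some implementations instead plug in the ground-truth labels (the so-called empirical Fisher), which is a different estimator.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def fisher_diagonal(W, samples):
    """Diagonal Fisher information of a linear softmax classifier p(y|x) = softmax(W x).

    W:       K x D weight matrix (list of K rows), playing the role of theta_A
    samples: small representative sample of task A inputs (each a length-D list)
    The expectation over labels is exact under p(y|x); the expectation over
    inputs is an average over `samples`.
    """
    K, D = len(W), len(W[0])
    F = [[0.0] * D for _ in range(K)]
    for x in samples:
        p = softmax([sum(W[k][d] * x[d] for d in range(D)) for k in range(K)])
        for y in range(K):
            for k in range(K):
                # d/dW[k][d] log p(y|x) = (1[k == y] - p[k]) * x[d]
                g = (1.0 if k == y else 0.0) - p[k]
                for d in range(D):
                    F[k][d] += p[y] * (g * x[d]) ** 2
    n = len(samples)
    return [[v / n for v in row] for row in F]


# Hypothetical 3-class, 2-feature model and a tiny task A sample
W_A = [[1.0, -0.5], [0.2, 0.3], [-0.7, 0.8]]
task_a_sample = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
F = fisher_diagonal(W_A, task_a_sample)
fisher_flat = [v for row in F for v in row]  # same flattening order as the parameters
```

The flattened values play the role of $F_{jj}$ in $\mathcal{L}_{EWC}$. For this particular model the sum over labels simplifies to $p_k(1 - p_k)\,x_d^2$, so parameters attached to confidently predicted classes or to near-zero input features receive small weights.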
Implementation and Experiments
Let's see if we can solve Bob's problem by implementing these techniques. First, we trained a ResNet-18 classifier on task A, car make and model classification. For this task we used VMMRdb, which we curated to have 182 different car make and model classes. This model reached 87% validation accuracy on task A after 10,000 steps. Let $f_{\theta_A}$ denote this model.
For task B, we used the car color dataset curated by Chen, Bai and Liu [5], which has 10 color labels. We added 10 output nodes, one per color, to the ResNet classifier trained on task A so that it can make these color predictions. We considered the following three arrangements for training on task B:
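The head expansion described above can be pictured with a pure-Python stand-in for the final linear layer, which in a standard ResNet-18 maps a 512-dimensional feature vector to the class logits. The sketch copies the 182 trained rows unchanged and appends 10 new rows, so the task A logits of $f_{\theta_A}$ are exactly preserved at the start of task B training. The small-Gaussian initialization of the new rows and the per-task split of the logits are assumptions for illustration; the post does not specify either.

```python
import random

FEATURE_DIM = 512  # width of ResNet-18's penultimate feature vector
NUM_A = 182        # car make/model classes (task A)
NUM_B = 10         # color classes (task B)


def expand_head(weight, bias, num_new, scale=0.01, seed=0):
    """Return a copy of a linear layer (weight rows, bias) with num_new extra outputs.

    Existing rows are copied unchanged; new rows get small Gaussian weights
    (an assumed initialization) and zero bias.
    """
    rng = random.Random(seed)
    in_features = len(weight[0])
    new_rows = [[rng.gauss(0.0, scale) for _ in range(in_features)] for _ in range(num_new)]
    return [row[:] for row in weight] + new_rows, bias[:] + [0.0] * num_new


def linear(weight, bias, features):
    """Logits = weight @ features + bias."""
    return [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weight, bias)]


# Hypothetical stand-ins for the trained task A head and one penultimate feature vector
rng = random.Random(42)
W_A = [[rng.gauss(0.0, 0.05) for _ in range(FEATURE_DIM)] for _ in range(NUM_A)]
b_A = [rng.gauss(0.0, 0.05) for _ in range(NUM_A)]
features = [rng.random() for _ in range(FEATURE_DIM)]

W, b = expand_head(W_A, b_A, NUM_B)
logits = linear(W, b, features)
# Assumed: each task applies its own softmax to its slice of the logits
logits_A, logits_B = logits[:NUM_A], logits[NUM_A:]
```

Keeping the old rows intact means any drop in task A accuracy during training comes from the updates made while learning task B, which is exactly the forgetting the experiments below measure.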
-
Baseline: Train $f_{\theta_A}$ with the loss $\mathcal{L}_B$, using both a low learning rate of 0.0001 and a standard learning rate of 0.001. The results of these experiments are shown in Figures 1 and 2.
-
Knowledge Distillation (KD): Train $f_{\theta_A}$ with the loss $\mathcal{L}_B + \kappa \mathcal{L}_{KD}$, using both a low learning rate of 0.0001 and a standard learning rate of 0.001.
-
Elastic Weight Consolidation (EWC): Train $f_{\theta_A}$ with the loss $\mathcal{L}_B + \lambda \mathcal{L}_{EWC}$, using both a low learning rate of 0.0001 and a standard learning rate of 0.001.
The values of $\kappa$ and $\lambda$ were chosen by exhaustive hyperparameter search, and $\mathcal{L}_B$ was the cross-entropy loss.
Baseline Performance
The baseline approach, which uses no extra loss term and a standard learning rate of 0.001, is shown in Figure 1. It performed about as badly as possible: the validation accuracy on task A dropped to zero after only a few steps.
Figure 1:
The baseline approach demonstrating catastrophic forgetting. We see the validation performance on the original task (red) decrease to practically zero as the new task (blue) is learned.
Now let's repeat this approach with the lower learning rate of 0.0001, shown in Figure 2.
Figure 2:
The baseline approach using a lower learning rate of 0.0001. The lower learning rate partially alleviates catastrophic forgetting: task A validation (red) starts to decrease as soon as training on task B starts, but less drastically than with a standard learning rate.
While lowering the learning rate does appear to alleviate some of the forgetting on task A, the validation performance on task A nevertheless decreases nontrivially over time.
KD and EWC Performance
Now that we understand the baselines, let's see how Knowledge Distillation and Elastic Weight Consolidation perform. The results are shown in Figure 3.
Figure 3:
Using Knowledge Distillation (KD) and Elastic Weight Consolidation (EWC) helps a model retain information and performance from task A.
Recall that each approach started from the same model trained on task A, which is why all models have identical performance during the task A training portion of Figure 3. Once task B training begins, each method's task B validation performance increases roughly monotonically (bottom plot). Task A performance (top plot) does not behave this way, and some methods stand out more clearly than others. We discuss these results in the next section.
Results & Discussion
To evaluate these three techniques, we use the peak average validation score between task A and task B; in other words:
\[ \text{Score} = \max_{\text{time steps}} \frac{\text{val}(\text{Task A}) + \text{val}(\text{Task B})}{2} \]
Let $\text{Score}_A$ and $\text{Score}_B$ be the task A and task B validation scores at the time step that attains this maximum. These scores are shown in the table below.
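The score definition can be made concrete with a short sketch that takes two validation curves sampled at the same time steps. The function name and the curves are hypothetical illustrations, not data from the experiments in this post.

```python
def peak_average_score(val_a, val_b):
    """Return (Score, Score_A, Score_B).

    val_a[t], val_b[t]: task A and task B validation accuracy at time step t.
    Score is the maximum over t of their mean; Score_A and Score_B are the
    per-task accuracies at the step attaining that maximum (the first such
    step if there is a tie).
    """
    if len(val_a) != len(val_b) or not val_a:
        raise ValueError("need two non-empty curves of equal length")
    best = max(range(len(val_a)), key=lambda t: (val_a[t] + val_b[t]) / 2)
    return (val_a[best] + val_b[best]) / 2, val_a[best], val_b[best]


# Hypothetical validation curves (percent), not taken from the experiments above
val_a = [87.0, 70.0, 60.0, 50.0]
val_b = [0.0, 60.0, 80.0, 85.0]
print(peak_average_score(val_a, val_b))  # (70.0, 60.0, 80.0)
```

Because $\text{Score}_A$ and $\text{Score}_B$ are read off a single time step, each row of the table satisfies $\text{Score} = (\text{Score}_A + \text{Score}_B)/2$ up to rounding, even though neither value need be that task's best accuracy over the whole run.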
Technique | $\text{Score}$ | $\text{Score}_A$ | $\text{Score}_B$ |
---|---|---|---|
Baseline | 48.4% | 0% | 96.8% |
Baseline (low lr) | 64.4% | 40% | 88.7% |
Elastic Weight Cons. | 74% | 55.1% | 92.8% |
Elastic Weight Cons. (low lr) | 77.8% | 66.9% | 88.7% |
Knowledge Distillation | 70.6% | 60.2% | 81% |
Knowledge Distillation (low lr) | 80.3% | 73.4% | 87.1% |
- As expected, the baselines were the worst performers. Even with a lower learning rate, the task A validation performance monotonically decreased as task B was learned.
- The EWC loss combined with a lower learning rate partially mitigated forgetting: the task A validation performance leveled out at around 66% while task B was learned. Interestingly, EWC with the higher learning rate shows strong oscillations, which is consistent with the "elasticity" of the loss function $\mathcal{L}_{EWC}$.
- The KD loss combined with a lower learning rate performed the best overall. At its peak, it achieved a task A validation performance of 73.4% and a task B validation performance of 87.1%, for a score of 80.3%. This method is further preferable because, unlike EWC, it does not require a small representative sample of task A training data.
Conclusion
Both Knowledge Distillation and Elastic Weight Consolidation proved to be effective at mitigating catastrophic forgetting. We also saw that regardless of which additional loss one chooses, lowering the learning rate is an effective way to boost a model's resilience to forgetting old tasks.
References
-
Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the Knowledge in a Neural Network. arXiv:1503.02531v1. March 2015.
-
James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, Demis Hassabis, Claudia Clopath, Dharshan Kumaran and Raia Hadsell. Overcoming Catastrophic Forgetting in Neural Networks. arXiv:1612.00796v2. January 2017.
-
Zhizhong Li and Derek Hoiem. Learning Without Forgetting. IEEE Transactions on Pattern Analysis and Machine Intelligence, Vol. 40, Issue 12. arXiv:1606.09282v3. December 2018.
-
Friedemann Zenke, Ben Poole and Surya Ganguli. Continual Learning through Synaptic Intelligence. arXiv:1703.04200v3. June 2017.
-
P. Chen, X. Bai and W. Liu. Vehicle Color Recognition on Urban Road by Feature Context. IEEE Transactions on Intelligent Transportation Systems, Vol. 15, No. 5, pp. 2340-2346. October 2014. doi: 10.1109/TITS.2014.2308897.