Warning: First draft, read at your own peril
Background
Language models are probabilistic models $P_{\theta}(x)$, where $x$ is a string and $\theta$ are the model parameters. During gradient descent, we compute a loss (or error) function $L(\theta)$ and update $\theta$ accordingly.
In this article, I want to understand how the model changes when we update the parameters from $\theta_t$ to $\theta_{t+1}$ between timesteps. This article is partially motivated by Natural Gradients[1], and partly because I am driven by nightmares that I am not setting the learning rate correctly! xD
These derivations were shown by Andy[1]. I just want to re-write them for my understanding.
The first two terms are zero
Let $P_{\theta}$ refer to the probability measure induced by $\theta$ on some set of strings $X$. Taylor-expanding the KL-divergence around $\theta_t$ to second order gives

$$D_{KL}(P_{\theta_t} || P_{\theta}) \approx D_{KL}(P_{\theta_t} || P_{\theta_t}) + (\theta - \theta_t)^{\top} \nabla_{\theta} D_{KL}(P_{\theta_t} || P_{\theta}) \Big|_{\theta = \theta_t} + \frac{1}{2} (\theta - \theta_t)^{\top} H(\theta_t) (\theta - \theta_t)$$

where $H$ is the Hessian of the KL-divergence w.r.t $\theta$, evaluated at $\theta_t$.
The first term $D_{KL}(P_{\theta_t} || P_{\theta_t})$ is $0$ as the probability distributions are the same.
The second term is also $0$. This is quite interesting: the gradient is $-\mathbb{E}_{x \sim P_{\theta_t}}\left[\nabla_{\theta} \log P_{\theta}(x)\right]\big|_{\theta = \theta_t}$, and the expected score of a distribution under its own samples is always zero.
This means $D_{KL}(P_{\theta_t} || P_{\theta}) \approx \frac{1}{2} (\theta - \theta_t)^{\top} H(\theta_t) (\theta - \theta_t)$. At $\theta = \theta_t$, this Hessian is exactly the Fisher information matrix, which is the link back to natural gradients.
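As a sanity check, here is a minimal numerical sketch, assuming a toy model of my own choosing: a single categorical distribution $P_{\theta} = \mathrm{softmax}(\theta)$, whose KL Hessian at $\theta_t$ is the Fisher matrix $\mathrm{diag}(p) - p p^{\top}$.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

theta_t = rng.normal(size=5)
delta = 1e-2 * rng.normal(size=5)    # a small weight update
p, q = softmax(theta_t), softmax(theta_t + delta)

kl_exact = np.sum(p * np.log(p / q))

H = np.diag(p) - np.outer(p, p)      # Fisher = Hessian of the KL at theta_t
kl_quad = 0.5 * delta @ H @ delta

print(kl_exact, kl_quad)             # agree up to O(||delta||^3)
```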
Bounded weight updates bound behaviour
Let’s ask: what is the maximum of $D_{KL}(P_{\theta_t} || P_{\theta})$ given a finite weight update $\Delta \theta = \theta - \theta_t$? Since $H(\theta_t)$ is positive semi-definite, the quadratic form is maximized along the top eigenvector:

$$D_{KL}(P_{\theta_t} || P_{\theta}) \lesssim \frac{1}{2} \lambda_{max}(\theta_t) \, \|\Delta \theta\|_2^2$$

where $\lambda_{max}(\theta_t)$ is the largest eigenvalue of $H(\theta_t)$.
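A standalone sketch of this eigenvalue bound, again on a toy softmax model of my own construction:

```python
import numpy as np

rng = np.random.default_rng(1)
p = rng.dirichlet(np.ones(5))            # toy categorical distribution
H = np.diag(p) - np.outer(p, p)          # its PSD Fisher/KL Hessian
delta = 1e-2 * rng.normal(size=5)        # a finite weight update

lam_max = np.linalg.eigvalsh(H)[-1]      # eigvalsh returns eigenvalues in ascending order
quad = 0.5 * delta @ H @ delta
bound = 0.5 * lam_max * delta @ delta
assert quad <= bound + 1e-15
print(quad, bound)
```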
If we consider a weight update in SGD with l2 regularization:

$$\theta_{t+1} = \theta_t - \nu_t \left( \nabla_{\theta} L(\theta_t) + \gamma \theta_t \right)$$

where $\nu_t$ is the step multiplier (or learning rate at step $t$) and $\gamma$ is the weight decay coefficient. Furthermore, let’s presume that the gradient norm is clipped to $1$, which is a standard practice. Then $\|\Delta \theta\|_2 \le \nu_t \left(1 + \gamma \|\theta_t\|_2\right)$, and the bound becomes

$$D_{KL}(P_{\theta_t} || P_{\theta_{t+1}}) \lesssim \frac{1}{2} \lambda_{max}(\theta_t) \, \nu_t^2 \left(1 + \gamma \|\theta_t\|_2\right)^2.$$
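A quick numerical sketch of the update-norm step, assuming clipping to norm $1$ and illustrative values for $\nu_t$ and $\gamma$:

```python
import numpy as np

rng = np.random.default_rng(2)
theta_t = rng.normal(size=1000)

grad = rng.normal(size=1000)
grad /= max(1.0, np.linalg.norm(grad))    # clip gradient norm to 1

nu_t, gamma = 3e-4, 0.01                  # learning rate, weight decay
delta = -nu_t * (grad + gamma * theta_t)  # theta_{t+1} - theta_t

# triangle inequality: ||delta|| <= nu_t * (1 + gamma * ||theta_t||)
bound = nu_t * (1 + gamma * np.linalg.norm(theta_t))
assert np.linalg.norm(delta) <= bound + 1e-12
print(np.linalg.norm(delta), bound)
```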
We can conclude three things from the bound on $D_{KL}(P_{\theta_t} || P_{\theta_{t + 1}})$:

- The eigenvalues of the Hessian influence model behaviour.
- Weight regularization only increases the magnitude of the KL-divergence.
- A smaller learning rate limits the change in model behaviour, in accordance with expectation.
It would be interesting to see how problem formulation or model architecture can change the Hessian. This may suggest that certain training methodologies more readily change model behaviour.
In fine-tuning, practitioners use small learning rates. Despite this, model behaviour seems to change drastically. This may be because we measure metrics other than KL-divergence, but it may also hint that the eigenvalues of the Hessian are large.
Ghorbani et al., 2019[2] show the eigenvalues of the Hessian growing larger during training. They demonstrated this for ResNets on image datasets.
We could view pre-training through the lens of the eigenstructure of the Hessian. Pre-training neural networks helps with fine-tuning. Measuring the PSD-ness of the Hessian could serve as an indicator of how much behaviour could change during fine-tuning (even RL fine-tuning).
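As a rough sketch of how one might track $\lambda_{max}$ for a real network without materializing the Hessian (a standard power-iteration-on-Hessian-vector-products trick, not necessarily the method used in [2]; the tiny model below is a hypothetical stand-in):

```python
import torch

def top_hessian_eigenvalue(loss, params, iters=100):
    """Estimate the Hessian eigenvalue of largest magnitude by power
    iteration, applying H via double backprop (Hessian-vector products)."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    lam = 0.0
    for _ in range(iters):
        norm = torch.sqrt(sum((u * u).sum() for u in v))
        v = [u / norm for u in v]
        dot = sum((g * u).sum() for g, u in zip(grads, v))         # grad . v
        hv = torch.autograd.grad(dot, params, retain_graph=True)   # H v
        lam = sum((h * u).sum() for h, u in zip(hv, v)).item()     # Rayleigh quotient
        v = [h.detach() for h in hv]
    return lam

# Hypothetical usage on a toy classifier:
model = torch.nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
print(top_hessian_eigenvalue(loss, list(model.parameters())))
```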