Neural Tangent Kernels
Apr 11, 2025
A rough derivation of the neural tangent kernel
What Are They and Why Do We Care?
It's long been known that overparameterized neural networks are easily able to achieve near-zero training loss, while still maintaining decent generalization performance. Even though models are initialized randomly, different initializations often converge to similarly good performance, especially when models are overparameterized (i.e. far more parameters than training data points).
Neural Tangent Kernels provide a way to describe how deep neural networks evolve during training. By studying the NTK of a given architecture, we can understand why sufficiently wide networks trained with gradient descent consistently converge to global minima of the training loss.
Some Preliminaries
Kernels
A kernel \(K\) acts as a similarity measure between data points. Formally, a kernel is a function \(K: \mathcal{X} \times \mathcal{X} \to \mathbb{R}\), and intuitively, \(K(x, x^\prime)\) tells us how much the prediction at one data point (say \(x\)) depends on another (\(x^\prime\)). Typically, a kernel is defined through a feature map \(\varphi\): the kernel is the inner product of the two inputs after transformation, \(K(x, x^\prime) = \langle\varphi(x), \varphi(x^\prime)\rangle\). Kernels provide a simple way to make predictions: interpreting the kernel as a measure of similarity between two inputs, we can define a prediction for an unseen input \(x^\prime\) as \begin{equation} \hat{y} = \sum_{i=1}^n K(x_i, x^\prime) y_i, \end{equation}
for training data \(\{(x_i, y_i)\}_{i=1}^n\). It's worth noting that a (positive semidefinite) kernel is always symmetric, \(K(x, x^\prime) = K(x^\prime, x)\), and its Gram matrix \(K_{ij} = K(x_i, x_j)\) on any finite set of points is symmetric positive semidefinite (SPSD).
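As a quick illustration, here is a minimal NumPy sketch that builds a kernel from an explicit feature map, checks that its Gram matrix is symmetric with nonnegative eigenvalues, and applies the unnormalized prediction rule above. The feature map `phi` and the toy data are made up for this example.

```python
import numpy as np

def phi(x):
    # Hypothetical explicit feature map for scalar inputs: phi(x) = (1, sqrt(2) x, x^2)
    return np.array([1.0, np.sqrt(2.0) * x, x**2])

def kernel(x, xp):
    # Inner product of the features; algebraically equal to (1 + x * xp)^2
    return phi(x) @ phi(xp)

x_train = np.array([-1.0, 0.0, 0.5, 2.0])
y_train = np.array([1.0, -1.0, 0.5, 2.0])

# Gram matrix G_ij = K(x_i, x_j): symmetric, and SPSD up to floating-point error
G = np.array([[kernel(a, b) for b in x_train] for a in x_train])
eigenvalues = np.linalg.eigvalsh(G)
print("symmetric:", np.allclose(G, G.T), "min eigenvalue:", eigenvalues.min())

# Unnormalized kernel prediction: y_hat = sum_i K(x_i, x_new) * y_i
x_new = 1.0
y_hat = sum(kernel(xi, x_new) * yi for xi, yi in zip(x_train, y_train))
print("prediction at x_new:", y_hat)
```

Since this feature map is three-dimensional, the 4-by-4 Gram matrix has rank at most 3, so one eigenvalue is (numerically) zero: the matrix is positive semidefinite but not positive definite.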
Jacobians
For some function \(f: \mathbb{R}^n \to \mathbb{R}^m\), we define its derivative w.r.t. an input vector \(\vec{x} \in \mathbb{R}^n\) via the Jacobian matrix \(J \in \mathbb{R}^{m \times n}\): \begin{equation} J = \pdv{f}{\vec{x}} = \begin{pmatrix} \pdv{f_1}{x_1} & \cdots & \pdv{f_1}{x_n} \\ \vdots & \ddots & \vdots \\ \pdv{f_m}{x_1} & \cdots & \pdv{f_m}{x_n} \end{pmatrix} \end{equation}
We then define the gradient as the transposed Jacobian, \(\nabla_{\vec{x}} f = J^\top \in \mathbb{R}^{n \times m}\).
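To make the shapes concrete, here is a small NumPy sketch for a made-up map \(f: \mathbb{R}^2 \to \mathbb{R}^3\) (the function is an illustrative choice). It estimates the Jacobian column by column with central finite differences, compares it against the hand-derived Jacobian, and forms the gradient as the transpose.

```python
import numpy as np

def f(x):
    # Example map R^2 -> R^3
    return np.array([x[0] * x[1], np.sin(x[0]), x[1] ** 2])

def jacobian_analytic(x):
    # Row i holds the partial derivatives of f_i with respect to x_1 and x_2
    return np.array([[x[1], x[0]],
                     [np.cos(x[0]), 0.0],
                     [0.0, 2 * x[1]]])

def jacobian_fd(func, x, eps=1e-6):
    """Central finite-difference Jacobian: column j is (f(x + eps e_j) - f(x - eps e_j)) / (2 eps)."""
    m, n = len(func(x)), len(x)
    J = np.zeros((m, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        J[:, j] = (func(x + e) - func(x - e)) / (2 * eps)
    return J

x = np.array([0.7, -1.2])
J = jacobian_fd(f, x)        # shape (m, n) = (3, 2)
grad = J.T                   # nabla_x f = J^T, shape (n, m) = (2, 3)
print(J.shape, grad.shape)   # (3, 2) (2, 3)
```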
Gaussian Processes
Gaussian processes are super interesting and probably deserve a post of their own, but here we just define the basics. Given a collection of data points \(\{x_i\}_{i=1}^n\), a GP assumes that the function values \(f(x_1), \ldots, f(x_n)\) are jointly Gaussian, with mean vector \(\mu_i = \mu(x_i)\) and covariance matrix \(\Sigma\) defined entrywise as \(\Sigma_{i,j} = K(x_i, x_j)\), where \(K\) is some kernel of our choosing. Making predictions with a GP then amounts to conditioning this joint Gaussian on the observed training values: the prediction at new points is the resulting posterior (conditional) distribution, whose mean is typically used as the point estimate.
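Here is a minimal NumPy sketch of GP regression with a zero prior mean. It assumes an RBF kernel \(K(x, x^\prime) = \exp(-\|x - x^\prime\|^2 / 2s^2)\) with lengthscale \(s\) and a small observation-noise term; both choices are mine for illustration, not part of the definition above. Conditioning the joint Gaussian on the training values gives the standard posterior mean \(K_{*}K^{-1}y\) and covariance \(K_{**} - K_{*}K^{-1}K_{*}^\top\).

```python
import numpy as np

def rbf_kernel(A, B, lengthscale=1.0):
    # Squared-exponential kernel between the rows of A (p, d) and B (q, d) -> (p, q)
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)

def gp_posterior(X_train, y_train, X_test, kernel, noise=1e-6):
    """Posterior mean and covariance of f(X_test) given y_train, for a zero-mean GP prior."""
    K = kernel(X_train, X_train) + noise * np.eye(len(X_train))
    K_s = kernel(X_test, X_train)       # cross-covariances, shape (n_test, n_train)
    K_ss = kernel(X_test, X_test)
    L = np.linalg.cholesky(K)           # K = L L^T, more stable than inverting K directly
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))  # K^{-1} y
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)       # L^{-1} K_s^T
    cov = K_ss - v.T @ v                # K_ss - K_s K^{-1} K_s^T
    return mean, cov

X_train = np.linspace(-3, 3, 7)[:, None]
y_train = np.sin(X_train[:, 0])
X_test = np.array([[0.5], [10.0]])
mean, cov = gp_posterior(X_train, y_train, X_test, rbf_kernel)
print("posterior mean:", mean)
print("posterior std:", np.sqrt(np.diag(cov)))
```

Far from the data (the test point at 10), the posterior reverts to the prior: a mean near 0 and a standard deviation near 1.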
Setup
Consider a fully connected network of depth \(L\), with \(n_i\) neurons in layer \(i\), for \(i = 0, 1, \ldots, L\). Layer \(0\) and layer \(L\) are our read-in (input) and read-out (output) layers respectively. We define the action of the network as a function \(f_\theta: \mathbb{R}^{n_0} \to \mathbb{R}^{n_L}\), where the subscript \(\theta\) indicates the parameters of our model. Our training dataset contains \(n\) input-output pairs, \(\mathcal{D} = \{(x_i, y_i)\}^n_{i=1}\); we denote all training inputs by \(\mathcal{X}\) and all training outputs by \(\mathcal{Y}\).
Let's also define the forward pass of our network: \begin{equation} A^{(0)} = \vec{x}, \end{equation} \begin{equation} \widetilde{A}^{(l+1)}(\vec{x}) = \frac{1}{\sqrt{n_l}}{W^{(l)}}^\top A^{(l)} + \beta b^{(l)}, \end{equation} \begin{equation} A^{(l+1)}(\vec{x}) = \sigma(\widetilde{A}^{(l+1)}(\vec{x})), \end{equation}
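where \(\widetilde{A}^{(l+1)}(\vec{x})\) denotes pre-activations and \(A^{(l+1)}(\vec{x})\) denotes post-activations; the weights are \(W^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}\), the biases are \(b^{(l)} \in \mathbb{R}^{n_{l+1}}\), and the scalar \(\beta > 0\) controls the contribution of the biases. As in Jacot et al., we take the network output to be the last pre-activation, \(f_\theta(\vec{x}) = \widetilde{A}^{(L)}(\vec{x})\). We apply a \(1/\sqrt{n_l}\) scaling so that each pre-activation has \(O(1)\) variance regardless of width; without it, pre-activations would grow like \(\sqrt{n_l}\) and infinite width networks (where most of our analysis will take place) would diverge. We initialize all parameters i.i.d. Gaussian, \(\sim \mathcal{N}(0, 1)\).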
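Here is a minimal NumPy sketch of this forward pass. It assumes a tanh nonlinearity, \(\beta = 0.1\), and that the output is the last pre-activation; these choices and all function names are illustrative. Each first-layer pre-activation is a sum of \(n_0\) Gaussian terms scaled by \(1/\sqrt{n_0}\) plus \(\beta\) times a Gaussian bias, so its variance is \(x^\top x / n_0 + \beta^2\) no matter how wide the layer is. We can check this empirically across the hidden units of one very wide layer:

```python
import numpy as np

def init_params(widths, rng):
    """widths = [n_0, ..., n_L]; W^(l) has shape (n_l, n_{l+1}); every entry is i.i.d. N(0, 1)."""
    return [(rng.standard_normal((n_in, n_out)), rng.standard_normal(n_out))
            for n_in, n_out in zip(widths[:-1], widths[1:])]

def forward(x, params, beta=0.1, sigma=np.tanh):
    """Return the pre-activations A~^(1), ..., A~^(L); the last one is used as the network output."""
    a, preacts = x, []
    for W, b in params:
        pre = W.T @ a / np.sqrt(W.shape[0]) + beta * b   # (1 / sqrt(n_l)) W^T A + beta b
        preacts.append(pre)
        a = sigma(pre)                                   # post-activation A^(l+1)
    return preacts

rng = np.random.default_rng(0)
x = rng.standard_normal(5)
preacts = forward(x, init_params([5, 100_000, 1], rng))
beta = 0.1
print("empirical variance of first-layer pre-activations:", preacts[0].var())
print("predicted x^T x / n_0 + beta^2:", x @ x / 5 + beta**2)
print("output:", preacts[-1])
```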
As always, our objective is to minimize some total loss function \(\mathcal{L}\) defined via a per-sample loss \(\ell\): \begin{equation} \mathcal{L}(\theta) = \frac{1}{n} \sum_{i=1}^n \ell(f(x_i; \theta), y_i). \end{equation}
By the chain rule, we have \begin{equation} \nabla_\theta \mathcal{L}(\theta) = \frac{1}{n}\sum_{i=1}^n \nabla_\theta f(x_i; \theta) \nabla_f \ell(f(x_i; \theta), y_i), \end{equation}
where \(\nabla_\theta f = \left(\pdv{f}{\theta}\right)^\top\) is the transposed Jacobian of the network output with respect to the parameters. If we now take the step size to be infinitesimally small, gradient descent becomes gradient flow, an ODE for \(\theta\) in continuous time \(t\): \begin{equation} \dv{\theta}{t} = - \nabla_\theta \mathcal{L}(\theta) = -\frac{1}{n}\sum_{i=1}^n \nabla_\theta f(x_i; \theta) \nabla_f \ell(f(x_i; \theta), y_i), \end{equation}
and applying the chain rule again, \begin{equation} \dv{f(x; \theta)}{t} = \pdv{f(x; \theta)}{\theta} \dv{\theta}{t} = -\frac{1}{n}\sum_{i=1}^n \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x_i; \theta) \nabla_f \ell(f(x_i; \theta), y_i). \end{equation}
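Now we see the Neural Tangent Kernel appear: \(K(x, x^\prime; \theta) = \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x^\prime; \theta)\). Employing the intuition of an inner product of some transformation of the inputs, we can define \(\varphi(x) = \nabla_\theta f(x; \theta)\), the gradient of the output with respect to the parameters, and then our inner product is the standard dot product in parameter space. (For a scalar output, \(n_L = 1\), \(K\) is a scalar; in general it is an \(n_L \times n_L\) matrix whose \((j, k)\) entry is the dot product of the parameter gradients of outputs \(j\) and \(k\).)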
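To make this concrete, here is a NumPy sketch that computes the empirical NTK of a small one-hidden-layer tanh network with scalar output and \(\beta = 0.1\), using hand-derived gradients. The architecture, data, and all names are illustrative assumptions. It also checks the derivation above numerically: after one small gradient-descent step with step size \(\eta\) (a name introduced here) on the squared loss \(\ell = \frac{1}{2}(f - y)^2\), the change in the outputs on the training points should match \(-\frac{\eta}{n} K (f - y)\) to first order in \(\eta\).

```python
import numpy as np

BETA = 0.1  # bias scale beta

def unpack(theta, n0, n1):
    """Split a flat parameter vector into W0 (n0, n1), b0 (n1,), W1 (n1,), b1 (scalar)."""
    W0 = theta[: n0 * n1].reshape(n0, n1)
    b0 = theta[n0 * n1 : n0 * n1 + n1]
    W1 = theta[n0 * n1 + n1 : n0 * n1 + 2 * n1]
    return W0, b0, W1, theta[-1]

def f(theta, x, n0, n1):
    """One-hidden-layer tanh network with scalar output (the last pre-activation)."""
    W0, b0, W1, b1 = unpack(theta, n0, n1)
    a = np.tanh(W0.T @ x / np.sqrt(n0) + BETA * b0)
    return W1 @ a / np.sqrt(n1) + BETA * b1

def grad_f(theta, x, n0, n1):
    """Analytic gradient of f with respect to theta, in the same flat layout as theta."""
    W0, b0, W1, b1 = unpack(theta, n0, n1)
    a = np.tanh(W0.T @ x / np.sqrt(n0) + BETA * b0)
    df_dh = W1 * (1 - a**2) / np.sqrt(n1)     # tanh'(h) = 1 - tanh(h)^2
    return np.concatenate([np.outer(x, df_dh).ravel() / np.sqrt(n0),  # d f / d W0
                           BETA * df_dh,                               # d f / d b0
                           a / np.sqrt(n1),                            # d f / d W1
                           [BETA]])                                    # d f / d b1

def empirical_ntk(theta, X, n0, n1):
    """K_ij = grad f(x_i)^T grad f(x_j), i.e. the feature map phi(x) = grad_theta f(x; theta)."""
    G = np.stack([grad_f(theta, x, n0, n1) for x in X])   # (n, P)
    return G @ G.T, G

rng = np.random.default_rng(0)
n0, n1, n = 3, 512, 5
X, y = rng.standard_normal((n, n0)), rng.standard_normal(n)
theta = rng.standard_normal(n0 * n1 + 2 * n1 + 1)   # all parameters i.i.d. N(0, 1)

K, G = empirical_ntk(theta, X, n0, n1)
f0 = np.array([f(theta, x, n0, n1) for x in X])

# One small gradient-descent step on l = 0.5 (f - y)^2, so grad_f l = f - y
eta = 1e-3
theta_new = theta - eta * G.T @ (f0 - y) / n
actual = np.array([f(theta_new, x, n0, n1) for x in X]) - f0
predicted = -eta * K @ (f0 - y) / n   # the NTK expression for df/dt, times the step size
print("max deviation from the NTK prediction:", np.max(np.abs(actual - predicted)))
```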
Infinite Width Networks
NTKs as Gaussian Processes
At initialization, in the limit where the hidden widths \(n_1, \ldots, n_{L-1} \to \infty\), the output functions \(f_i(x; \theta)\) are i.i.d. centered Gaussian processes with covariance \(\Sigma^{(L)}\). We define this covariance recursively: \begin{equation} \Sigma^{(1)}(x, x^\prime) = \frac{1}{n_0}x^\top x^\prime + \beta^2, \end{equation} \begin{equation} \lambda^{(l)}(x, x^\prime) = \begin{pmatrix} \Sigma^{(l)}(x, x) & \Sigma^{(l)}(x, x^\prime)\\ \Sigma^{(l)}(x^\prime, x) & \Sigma^{(l)}(x^\prime, x^\prime) \end{pmatrix}, \end{equation} \begin{equation} \Sigma^{(l+1)}(x, x^\prime) = \mathbb{E}_{(u, v) \sim \mathcal{N}(0, \lambda^{(l)}(x, x^\prime))} [\sigma(u)\sigma(v)] + \beta^2. \end{equation}
The proof is by induction over the layers and is omitted here; it is covered in Lee & Bahri et al. (linked in references). We refer to this Gaussian process as the Neural Network Gaussian Process (NNGP).
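We can sanity-check this numerically. The NumPy sketch below uses a one-hidden-layer ReLU network with \(\beta = 0.1\); the inputs, seeds, sample counts, and names are illustrative. It computes \(\Sigma^{(2)}(x, x^\prime)\) from the recursion, estimating the expectation by Monte Carlo, and compares it against the output covariance across many independently initialized networks. For a single hidden layer the covariance matches at any width; what the infinite-width limit adds is Gaussianity, which we probe with the excess kurtosis of \(f(x)\) (zero for a Gaussian).

```python
import numpy as np

BETA = 0.1  # bias scale beta

def relu(z):
    return np.maximum(z, 0.0)

def nngp_sigma2(x, xp, n_samples=400_000, seed=0):
    """Sigma^(2)(x, x') from the recursion, with the expectation estimated by Monte Carlo."""
    n0 = len(x)
    def sigma1(a, b):
        return a @ b / n0 + BETA**2                                     # Sigma^(1)
    lam = np.array([[sigma1(x, x), sigma1(x, xp)],
                    [sigma1(xp, x), sigma1(xp, xp)]])                   # lambda^(1)
    uv = np.random.default_rng(seed).multivariate_normal(np.zeros(2), lam, size=n_samples)
    return np.mean(relu(uv[:, 0]) * relu(uv[:, 1])) + BETA**2

def random_net_outputs(X, n1, n_nets, seed=1):
    """Outputs of n_nets independently initialized one-hidden-layer ReLU nets on the rows of X."""
    rng = np.random.default_rng(seed)
    n0 = X.shape[1]
    W0 = rng.standard_normal((n_nets, n0, n1))
    b0 = rng.standard_normal((n_nets, 1, n1))
    W1 = rng.standard_normal((n_nets, n1))
    b1 = rng.standard_normal((n_nets, 1))
    H = np.einsum("ni,sij->snj", X, W0) / np.sqrt(n0) + BETA * b0   # (n_nets, len(X), n1)
    return np.einsum("snj,sj->sn", relu(H), W1) / np.sqrt(n1) + BETA * b1

def excess_kurtosis(z):
    z = z - z.mean()
    return np.mean(z**4) / np.mean(z**2) ** 2 - 3.0   # 0 for a Gaussian

X = np.array([[1.0, 0.5, -0.3], [0.8, 0.9, 0.1]])
target = nngp_sigma2(X[0], X[1])
results = {}
for n1 in [1, 256]:
    F = random_net_outputs(X, n1, n_nets=5000)
    results[n1] = (np.mean(F[:, 0] * F[:, 1]), excess_kurtosis(F[:, 0]))
    print(f"width {n1:3d}: Cov(f(x), f(x')) = {results[n1][0]:.3f} "
          f"(recursion: {target:.3f}), excess kurtosis of f(x) = {results[n1][1]:.2f}")
```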
Why is this useful? NNGPs let us do Bayesian inference using neural networks without ever having to train them! We use the kernel \(\Sigma^{(L)}\) to do GP regression, which allows us to analyze and predict generalization in the lazy learning regime.
Deterministic NTK
The main contribution of Jacot et al.'s original NTK paper is showing that, as the network approaches the infinite width limit, the NTK has two very important properties:
- it is deterministic at initialization (i.e. the random initialization parameters don't affect the kernel; it is determined solely by the architecture, including the nonlinearity \(\sigma\) and the bias scale \(\beta\))
- it stays constant during training (under gradient flow, over any finite training time)
This is a pretty surprising result: it tells us that in the infinite width limit, the kernel governing the network's training dynamics is determined entirely by its architecture, not by the particular initialization (the initial function \(f(x; \theta(0))\) itself is still a random draw from the NNGP). This is consistent with what we observe in practice: large enough networks often converge to similar solutions.
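Both properties can be observed on a small example. This NumPy sketch uses a one-hidden-layer tanh network with \(\beta = 0.1\), trained by full-batch gradient descent on the squared loss; the data, widths, step size, number of seeds, and names are all illustrative choices. It measures how much the empirical NTK varies across random initializations and how much it moves during training, and both should shrink as the width grows.

```python
import numpy as np

BETA = 0.1  # bias scale beta

def init(n0, n1, rng):
    # W0: (n0, n1), b0: (n1,), W1: (n1,), b1: scalar; all i.i.d. N(0, 1)
    return [rng.standard_normal((n0, n1)), rng.standard_normal(n1),
            rng.standard_normal(n1), rng.standard_normal()]

def hidden(params, X):
    W0, b0 = params[0], params[1]
    return np.tanh(X @ W0 / np.sqrt(W0.shape[0]) + BETA * b0)

def ntk(params, X):
    """Empirical NTK Gram matrix of the scalar-output network, summed over all parameter groups."""
    W0, W1 = params[0], params[2]
    n0, n1 = W0.shape
    A = hidden(params, X)
    S = (1 - A**2) * W1   # tanh'(h) times the read-out weight, shape (n, n1)
    return (X @ X.T / n0 + BETA**2) * (S @ S.T) / n1 + A @ A.T / n1 + BETA**2

def gd_step(params, X, y, lr):
    """One full-batch gradient-descent step on L = (1/n) sum_i 0.5 (f(x_i) - y_i)^2."""
    W0, b0, W1, b1 = params
    n0, n1 = W0.shape
    A = hidden(params, X)
    r = (A @ W1 / np.sqrt(n1) + BETA * b1 - y) / len(y)     # dL/df
    dH = np.outer(r, W1) * (1 - A**2) / np.sqrt(n1)         # dL/dh, shape (n, n1)
    grads = [X.T @ dH / np.sqrt(n0), BETA * dH.sum(axis=0), A.T @ r / np.sqrt(n1), BETA * r.sum()]
    return [p - lr * g for p, g in zip(params, grads)]

data_rng = np.random.default_rng(123)
X, y = data_rng.standard_normal((4, 3)), data_rng.standard_normal(4)

spread, drift = {}, {}
for n1 in [16, 256, 4096]:
    k01, changes = [], []
    for seed in range(5):
        params = init(3, n1, np.random.default_rng(seed))
        K0 = ntk(params, X)
        for _ in range(200):
            params = gd_step(params, X, y, lr=0.25)
        k01.append(K0[0, 1])
        changes.append(np.linalg.norm(ntk(params, X) - K0) / np.linalg.norm(K0))
    spread[n1], drift[n1] = np.std(k01), np.mean(changes)
    print(f"width {n1:4d}: std of K(x_1, x_2) across inits = {spread[n1]:.4f}, "
          f"relative change of K during training = {drift[n1]:.4f}")
```

The expected trend is for both numbers to decrease roughly like \(1/\sqrt{n_1}\); the exact values depend on the seeds.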
Convergence Analysis
Since the NTK is deterministic and constant during training, we can analyze the convergence of the network. A constant kernel means the network behaves like its linearization around the initial parameters \(\theta(0)\): the output is effectively a linear function of the parameters (not of the input). We'll examine gradient flow (continuous-time gradient descent) with learning rate \(\eta\):
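For a network with parameters \(\theta\), the evolution of the network's output on any input \(x\) follows (absorbing the \(1/n\) factor from the loss into \(\eta\)): \begin{equation} \dv{f(x; \theta)}{t} = -\eta \sum_{i=1}^n K(x, x_i; \theta) \nabla_f \ell(f(x_i; \theta), y_i). \end{equation}
With the kernel frozen at its initial value \(K(x, x_i; \theta(0))\), this is an ODE in function space (a linear one when \(\ell\) is the squared loss), and its solution can be written as \begin{equation} f(x; \theta(t)) = f(x; \theta(0)) + \sum_{i=1}^n K(x, x_i; \theta(0)) \alpha_i(t), \end{equation}
where the time-dependent coefficients \(\alpha_i(t) = -\eta \int_0^t \nabla_f \ell(f(x_i; \theta(s)), y_i)\, ds\) depend on the loss function (this follows by integrating the previous equation in time with \(K\) held fixed). Effectively, this equation tells us that the change in the network's output since initialization always lies in the subspace spanned by the kernel functions \(K(\cdot, x_i; \theta(0))\).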
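For the squared loss \(\ell = \frac{1}{2}(f - y)^2\), the ODE on the training points is \(\dv{f}{t} = -\eta K (f - y)\), with solution \(f_t = y + e^{-\eta K t}(f_0 - y)\); when \(K\) is invertible the coefficients work out to \(\alpha(t) = K^{-1}(I - e^{-\eta K t})(y - f_0)\). Below is a NumPy sketch that evaluates this closed form through an eigendecomposition and checks it against a forward-Euler integration of the same ODE. As an assumption for illustration, an RBF kernel stands in for the frozen NTK \(K(\cdot, \cdot; \theta(0))\), and the initial outputs \(f_0\) are small random numbers.

```python
import numpy as np

def rbf(A, B, lengthscale=1.0):
    # Stand-in for a frozen NTK: any symmetric positive-definite kernel works for this sketch
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)

def kernel_flow(K, K_star, f0_train, f0_test, y, eta, t):
    """Closed-form outputs at time t under df/dt = -eta K (f - y) with the kernel held fixed."""
    lam, V = np.linalg.eigh(K)                       # K = V diag(lam) V^T
    coef = -np.expm1(-eta * lam * t) / lam           # (1 - exp(-eta lam t)) / lam
    alpha = V @ (coef * (V.T @ (y - f0_train)))      # alpha(t) = K^{-1} (I - e^{-eta K t}) (y - f0)
    return f0_train + K @ alpha, f0_test + K_star @ alpha

def euler(K, K_star, f0_train, f0_test, y, eta, t, steps=20_000):
    """Forward-Euler integration of the same ODE, for comparison."""
    dt = t / steps
    f_train, f_test = f0_train.copy(), f0_test.copy()
    for _ in range(steps):
        residual = f_train - y                       # grad_f l for l = 0.5 (f - y)^2
        f_train = f_train - eta * dt * K @ residual
        f_test = f_test - eta * dt * K_star @ residual
    return f_train, f_test

rng = np.random.default_rng(0)
X_train = np.linspace(-2, 2, 5)[:, None]
X_test = np.array([[-1.5], [0.3], [1.7]])
y = np.sin(X_train[:, 0])
K, K_star = rbf(X_train, X_train), rbf(X_test, X_train)
f0_train, f0_test = 0.1 * rng.standard_normal(5), 0.1 * rng.standard_normal(3)

closed = kernel_flow(K, K_star, f0_train, f0_test, y, eta=1.0, t=3.0)
numeric = euler(K, K_star, f0_train, f0_test, y, eta=1.0, t=3.0)
residuals = {}
for T in [0.0, 1.0, 10.0, 100.0]:
    f_train, _ = kernel_flow(K, K_star, f0_train, f0_test, y, eta=1.0, t=T)
    residuals[T] = np.linalg.norm(f_train - y)
    print(f"t = {T:5.1f}: training residual norm = {residuals[T]:.2e}")
```

The residual component along each eigenvector of \(K\) decays like \(e^{-\eta \lambda t}\), so the slowest modes are the ones with small eigenvalues, which connects to the conditioning point in the next section.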
Practical Implications
Several useful implications follow from the NTK being deterministic and constant during training. For example, we can use the NTK to compare different architectures before training, or to understand generalization in the lazy learning regime (the focus of my research at BAIR).
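- Architecture Design: A "good" architecture should have a well-conditioned NTK that allows for efficient learning. (An SPSD matrix like the NTK Gram matrix is well-conditioned when its eigenvalues are reasonably close to one another, i.e. the condition number \(\lambda_{\max}/\lambda_{\min}\) is small. This matters for iterative algorithms like gradient descent: the step size must stay below roughly \(2/\lambda_{\max}\) for stability, while the slowest direction converges at a rate proportional to \(\lambda_{\min}\), so a large condition number means slow convergence.)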
- Generalization: If the target function is well aligned with the top eigendirections of the NTK (the directions the network learns fastest), the network can generalize well even without explicit regularization.
- Training Dynamics: The constant NTK assumption helps explain why wide networks often train more stably and predictably than their narrower counterparts, and it lends some theoretical justification to the empirical observation that making neural networks bigger makes them easier to train.
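To illustrate the conditioning point, here is a NumPy sketch of discrete kernel gradient descent on the training residuals, \(r_{k+1} = (I - \eta K) r_k\), with \(\eta = 1/\lambda_{\max}\). The two Gram matrices are synthetic stand-ins (random eigenvectors, eigenvalues chosen to give condition numbers of 2 and 1000), not NTKs of any particular architecture, and the function name is hypothetical.

```python
import numpy as np

def steps_to_converge(K, r0, tol=1e-3, max_steps=1_000_000):
    """Iterate r <- (I - eta K) r with eta = 1 / lambda_max until ||r|| <= tol * ||r0||."""
    eta = 1.0 / np.linalg.eigvalsh(K).max()
    r, target = r0.copy(), tol * np.linalg.norm(r0)
    for step in range(max_steps):
        if np.linalg.norm(r) <= target:
            return step
        r = r - eta * K @ r
    return max_steps

rng = np.random.default_rng(0)
n = 20
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))    # random orthonormal eigenvectors
r0 = rng.standard_normal(n)

results = {}
for cond in [2.0, 1000.0]:
    eigs = np.geomspace(1.0 / cond, 1.0, n)          # lambda_max / lambda_min = cond
    K = Q @ np.diag(eigs) @ Q.T
    results[cond] = steps_to_converge(K, r0)
    print(f"condition number {cond:7.1f}: {results[cond]} steps")
```

With \(\eta = 1/\lambda_{\max}\), each residual component shrinks by a factor \(1 - \lambda/\lambda_{\max}\) per step, so the number of steps needed scales roughly with the condition number.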
Computing the NTK
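For a network with parameters \(\theta\), the empirical NTK between two inputs \(x\) and \(x^\prime\) is \(\nabla_\theta f(x; \theta)^\top \nabla_\theta f(x^\prime; \theta)\). In the infinite-width limit it converges to a deterministic kernel, which at finite width can be approximated by its expectation over initializations: \begin{equation} K(x, x^\prime) = \mathbb{E}_{\theta \sim \mathcal{N}(0, I)} \left[ \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x^\prime; \theta) \right]. \end{equation}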
This expectation can be approximated by sampling multiple initializations and averaging the resulting Jacobian products. I'm not super familiar with this, but I know that Google's neural-tangents library has some tools for computing NTKs.
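For fully connected networks, the infinite-width NTK can also be computed exactly with a layerwise recursion from Jacot et al.: \(\Theta^{(1)} = \Sigma^{(1)}\) and \(\Theta^{(l+1)} = \Theta^{(l)} \dot{\Sigma}^{(l+1)} + \Sigma^{(l+1)}\), where \(\dot{\Sigma}^{(l+1)}(x, x^\prime) = \mathbb{E}_{(u, v) \sim \mathcal{N}(0, \lambda^{(l)}(x, x^\prime))}[\sigma^\prime(u)\sigma^\prime(v)]\). For ReLU, both expectations have closed forms (the arc-cosine kernels). The NumPy sketch below uses one hidden layer and \(\beta = 0.1\); the inputs, width, number of seeds, and names are illustrative. It compares this limit against a Monte Carlo average of empirical NTKs over several initializations, which is the approximation described above.

```python
import numpy as np

BETA = 0.1  # bias scale beta

def relu_expectations(s_xx, s_xy, s_yy):
    """E[relu(u) relu(v)] and E[relu'(u) relu'(v)] for (u, v) ~ N(0, [[s_xx, s_xy], [s_xy, s_yy]])."""
    norm = np.sqrt(s_xx * s_yy)
    angle = np.arccos(np.clip(s_xy / norm, -1.0, 1.0))
    e_act = norm * (np.sin(angle) + (np.pi - angle) * np.cos(angle)) / (2 * np.pi)
    e_der = (np.pi - angle) / (2 * np.pi)
    return e_act, e_der

def analytic_ntk(X):
    """Infinite-width NTK of x -> W1^T relu(W0^T x / sqrt(n0) + beta b0) / sqrt(n1) + beta b1."""
    sigma1 = X @ X.T / X.shape[1] + BETA**2            # Sigma^(1) = Theta^(1)
    d = np.diag(sigma1)
    e_act, e_der = relu_expectations(d[:, None], sigma1, d[None, :])
    sigma2 = e_act + BETA**2                           # Sigma^(2)
    return sigma1 * e_der + sigma2                     # Theta^(2) = Theta^(1) * dot-Sigma^(2) + Sigma^(2)

def empirical_ntk(X, n1, rng):
    """Empirical NTK of the same network at one random initialization (hand-derived gradients)."""
    n0 = X.shape[1]
    W0, b0, W1 = rng.standard_normal((n0, n1)), rng.standard_normal(n1), rng.standard_normal(n1)
    H = X @ W0 / np.sqrt(n0) + BETA * b0
    A = np.maximum(H, 0.0)
    S = (H > 0) * W1                                   # relu'(h) times the read-out weight
    return (X @ X.T / n0 + BETA**2) * (S @ S.T) / n1 + A @ A.T / n1 + BETA**2

X = np.random.default_rng(0).standard_normal((4, 3))
K_inf = analytic_ntk(X)
K_mc = np.mean([empirical_ntk(X, 4096, np.random.default_rng(seed)) for seed in range(20)], axis=0)
print("max |K_mc - K_inf|:", np.abs(K_mc - K_inf).max())
```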
Visualizing the NTK
