\(\require{physics}\)
Neural Tangent Kernels

Apr 11, 2025

A rough derivation of the neural tangent kernel

What Are They and Why Do We Care?

It's long been known that overparameterized neural networks are easily able to achieve near-zero training loss, while still maintaining decent generalization performance. Even though models are initialized randomly, different initializations often converge to similarly good performance, especially when models are overparameterized (i.e. far more parameters than training data points).

Neural Tangent Kernels provide a method to explain the evolution of deep neural networks during training. By studying the NTK of a given architecture, we can understand why wide enough networks consistently converge to global minima.

Some Preliminaries

Kernels

A kernel \(K\) acts as a similarity measure between data points. Formally, we can define a kernel \(K: \mathcal{X} \times \mathcal{X} \to \mathbb{R}\), and intuitively, \(K(x, x^\prime)\) tells us how much the prediction of one data point (say \(x\)) depends on another (\(x^\prime\)). Generally, we can define some function \(\varphi\), and then the kernel is simply an inner product of the two data points after transformation so that \(K(x, x^\prime) = \langle\varphi(x), \varphi(x^\prime)\rangle\). Kernels provide a simple way to make predictions: interpreting the kernel as a measure of the similarity between two inputs, we can define a prediction for an unseen input \(x^\prime\) as: \begin{equation} \hat{y} = \sum_{i=1}^n K(x_i, x^\prime) y_i, \end{equation}

for training data \(\{(x_i, y_i)\}_{i=1}^n\). It's worth noting that a kernel function is always symmetric, and its matrix representation (the Gram matrix) is always symmetric positive semi-definite (SPSD).
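
To make the prediction rule above concrete, here's a minimal sketch in Python, with an RBF kernel as one arbitrary choice of \(K\); the kernel, length scale, and toy data are illustrative assumptions of mine, not anything prescribed by the text.

```python
import numpy as np

def rbf_kernel(x, x_prime, lengthscale=1.0):
    """One concrete choice of kernel: K(x, x') = exp(-||x - x'||^2 / (2 l^2))."""
    diff = x - x_prime
    return np.exp(-np.dot(diff, diff) / (2 * lengthscale**2))

def kernel_predict(X_train, y_train, x_new, kernel=rbf_kernel):
    """The prediction rule from the text: y_hat = sum_i K(x_i, x_new) * y_i."""
    return sum(kernel(x_i, x_new) * y_i for x_i, y_i in zip(X_train, y_train))

# Toy usage: three 1-D training points, prediction at x = 0.5.
X_train = np.array([[0.0], [1.0], [2.0]])
y_train = np.array([0.0, 1.0, 4.0])
print(kernel_predict(X_train, y_train, np.array([0.5])))
```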

Jacobians

For some function \(f: \mathbb{R}^n \to \mathbb{R}^m\), we define its derivative w.r.t. an input vector \(\vec{x} \in \mathbb{R}^n\) via the Jacobian matrix \(J \in \mathbb{R}^{m \times n}\): \begin{equation} J = \pdv{f}{\vec{x}} = \begin{pmatrix} \pdv{f_1}{x_1} & \cdots & \pdv{f_1}{x_n} \\ \vdots & \ddots & \vdots \\ \pdv{f_m}{x_1} & \cdots & \pdv{f_m}{x_n} \end{pmatrix} \end{equation}

The gradient is then defined as \(\nabla_{\vec{x}} f = J^\top\).
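
As a quick shape check, here's a small sketch using JAX's automatic differentiation; the toy map \(f: \mathbb{R}^3 \to \mathbb{R}^2\) is an arbitrary example of mine.

```python
import jax.numpy as jnp
from jax import jacobian

def f(x):
    """A toy map f: R^3 -> R^2, used only to illustrate the Jacobian's shape."""
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
J = jacobian(f)(x)   # shape (m, n) = (2, 3), matching the matrix above
grad_f = J.T         # the convention used here: gradient = J^T, shape (3, 2)
print(J.shape, grad_f.shape)
```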

Gaussian Processes

Gaussian processes are super interesting and probably deserve a post of their own, but here we just define the basics. Given a collection of data points \(\{x_i\}_{i=1}^n\), a GP fundamentally assumes their function values are jointly Gaussian, defined by a mean vector \(\mu(x)\) and a covariance matrix \(\Sigma\) whose entries are \(\Sigma_{i,j} = K(x_i, x_j)\) for some kernel \(K\) of our choosing. Making predictions with a GP then amounts to conditioning this joint Gaussian on the known (training) points and reading off the resulting posterior over the unseen ones.
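
Here's a minimal GP regression sketch under some assumptions of my own: an RBF kernel, a zero prior mean, the standard Gaussian conditioning formulas, and a small jitter term added purely for numerical stability.

```python
import numpy as np

def rbf(X1, X2, lengthscale=1.0):
    """Pairwise kernel matrix K[i, j] = exp(-||x_i - x_j||^2 / (2 l^2))."""
    d2 = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * lengthscale**2))

def gp_posterior(X_train, y_train, X_test, jitter=1e-6):
    """Condition the joint Gaussian on the training points (zero prior mean)."""
    K = rbf(X_train, X_train) + jitter * np.eye(len(X_train))
    K_star = rbf(X_test, X_train)
    K_ss = rbf(X_test, X_test)
    mean = K_star @ np.linalg.solve(K, y_train)            # posterior mean
    cov = K_ss - K_star @ np.linalg.solve(K, K_star.T)     # posterior covariance
    return mean, cov

# Toy usage: 5 training points on a sine curve, 2 test points.
X_train = np.linspace(0.0, 1.0, 5)[:, None]
y_train = np.sin(2 * np.pi * X_train[:, 0])
mean, cov = gp_posterior(X_train, y_train, np.array([[0.25], [0.75]]))
print(mean)
```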

Setup

Consider a fully connected network of depth \(L\), with \(n_i\) neurons in layer \(i\), for \(i \in \{0, 1, \ldots, L\}\). Layer \(0\) and layer \(L\) are our read-in (input) and read-out (output) layers respectively. We define the action of the network as a function \(f_\theta: \mathbb{R}^{n_0} \to \mathbb{R}^{n_L}\), where the subscript \(\theta\) indicates the parameters of our model. Our training dataset contains \(n\) input-output pairs: \(\mathcal{D} = \{(x_i, y_i)\}^n_{i=1}\); we denote the set of training inputs \(\mathcal{X}\) and the set of training outputs \(\mathcal{Y}\).

Let's also define the forward pass of our network: \begin{equation} A^{(0)}(\vec{x}) = \vec{x}, \end{equation} \begin{equation} \widetilde{A}^{(l+1)}(\vec{x}) = \frac{1}{\sqrt{n_l}}{W^{(l)}}^\top A^{(l)}(\vec{x}) + \beta b^{(l)}, \end{equation} \begin{equation} A^{(l+1)}(\vec{x}) = \sigma(\widetilde{A}^{(l+1)}(\vec{x})), \end{equation}

where \(\widetilde{A}^{(l+1)}(\vec{x})\) denotes pre-activations, and \({A}^{(l+1)}(\vec{x})\) denotes post-activations. We apply a \(1/\sqrt{n_l}\) scaling so that the pre-activations stay well-behaved as the width goes to infinity (the limit where most of our analysis will take place). We initialize all parameters i.i.d. Gaussian, \(\sim \mathcal{N}(0, 1)\).
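
A rough sketch of this forward pass in Python/JAX might look like the following; the layer widths, the activation \(\sigma = \tanh\), the value of \(\beta\), and the choice to leave the read-out layer linear are all illustrative assumptions of mine.

```python
import jax
import jax.numpy as jnp

def init_params(key, widths):
    """All weights W^(l) and biases b^(l) drawn i.i.d. from N(0, 1), as above."""
    params = []
    for n_in, n_out in zip(widths[:-1], widths[1:]):
        key, w_key, b_key = jax.random.split(key, 3)
        params.append((jax.random.normal(w_key, (n_in, n_out)),
                       jax.random.normal(b_key, (n_out,))))
    return params

def f(params, x, beta=0.1, sigma=jnp.tanh):
    """Forward pass: pre-activation = W^T A / sqrt(n_l) + beta * b, then sigma."""
    A = x
    for i, (W, b) in enumerate(params):
        pre = A @ W / jnp.sqrt(W.shape[0]) + beta * b
        # Assumption: no nonlinearity on the read-out layer (a common convention).
        A = pre if i == len(params) - 1 else sigma(pre)
    return A

params = init_params(jax.random.PRNGKey(0), [3, 512, 512, 1])   # n_0 = 3, n_L = 1
print(f(params, jnp.ones(3)))
```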

As always, our objective is to minimize some total loss function \(\mathcal{L}\) defined via a per-sample loss \(\ell\): \begin{equation} \mathcal{L}(\theta) = \frac{1}{n} \sum_{i=1}^n \ell(f(x_i; \theta), y_i). \end{equation}

Via the chain rule, we have \begin{equation} \nabla_\theta \mathcal{L}(\theta) = \frac{1}{n}\sum_{i=1}^n \nabla_\theta f(x_i; \theta)\, \nabla_f \ell(f(x_i; \theta), y_i). \end{equation}

If we now take the gradient descent step size to be infinitesimally small, training becomes gradient flow and the parameter update can be written as a time derivative of \(\theta\), so that \begin{equation} \dv{\theta}{t} = - \nabla_\theta \mathcal{L}(\theta) = -\frac{1}{n}\sum_{i=1}^n \nabla_\theta f(x_i; \theta)\, \nabla_f \ell(f(x_i; \theta), y_i), \end{equation}

and using the chain rule again, \begin{equation} \dv{f(x; \theta)}{t} = \dv{f(x; \theta)}{\theta} \dv{\theta}{t} = -\frac{1}{n}\sum_{i=1}^n \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x_i; \theta)\, \nabla_f \ell(f(x_i; \theta), y_i). \end{equation}

Now, we see the Neural Tangent Kernel appear: \(K(x, x^\prime; \theta) = \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x^\prime; \theta)\). Employing the intuition of taking an inner product of some transformation of the inputs, we can define \(\varphi(x) = \nabla_\theta f(x; \theta)\), and then our inner product is the standard dot product.
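
To make this concrete, here's a minimal sketch of the empirical NTK for a scalar-output network, computed with JAX autodiff and reusing init_params and f from the forward-pass sketch above; the widths are again an arbitrary choice of mine.

```python
import jax
import jax.numpy as jnp

# Reuses init_params and f from the forward-pass sketch above (scalar output, n_L = 1).

def empirical_ntk(params, x1, x2):
    """K(x, x'; theta) = <grad_theta f(x; theta), grad_theta f(x'; theta)>."""
    scalar_f = lambda p, x: f(p, x)[0]    # jax.grad needs a scalar-valued function
    g1 = jax.grad(scalar_f)(params, x1)
    g2 = jax.grad(scalar_f)(params, x2)
    return sum(jnp.vdot(a, b) for a, b in
               zip(jax.tree_util.tree_leaves(g1), jax.tree_util.tree_leaves(g2)))

params = init_params(jax.random.PRNGKey(0), [3, 256, 256, 1])
print(empirical_ntk(params, jnp.ones(3), jnp.zeros(3)))
```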

Infinite Width Networks

NTKs as Gaussian Processes

In the infinite width limit, the output functions \(f_i(x; \theta)\), for \(i = 1, \ldots, n_L\), are i.i.d. centered Gaussian processes with covariance \(\Sigma^{(L)}\). We define this Gaussian process recursively: \begin{equation} \Sigma^{(1)}(x, x^\prime) = \frac{1}{n_0}x^\top x^\prime + \beta^2, \end{equation} \begin{equation} \Lambda^{(l)}(x, x^\prime) = \begin{pmatrix} \Sigma^{(l)}(x, x) & \Sigma^{(l)}(x, x^\prime)\\ \Sigma^{(l)}(x^\prime, x) & \Sigma^{(l)}(x^\prime, x^\prime) \end{pmatrix}, \end{equation} \begin{equation} \Sigma^{(l+1)}(x, x^\prime) = \mathbb{E}_{f \sim \mathcal{N}(0, \Lambda^{(l)})} [\sigma(f(x))\sigma(f(x^\prime))] + \beta^2. \end{equation}

The proof of this statement is by induction, and is omitted here, but is covered in Lee, Bahri, et al. (linked in the references). We refer to this as the Neural Network Gaussian Process (NNGP).
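
One way to get a feel for the recursion is to estimate it numerically: start from the closed-form \(\Sigma^{(1)}\) and, at each layer, approximate the expectation by sampling from the 2-dimensional Gaussian with covariance \(\Lambda^{(l)}\). A rough Monte Carlo sketch, where the activation, \(\beta\), depth, and sample count are all illustrative choices of mine:

```python
import numpy as np

def nngp_kernel(x, x_prime, depth, beta=0.1, sigma=np.tanh, n_samples=100_000, seed=0):
    """Monte Carlo estimate of the NNGP recursion, returning Sigma^(depth)(x, x')."""
    rng = np.random.default_rng(seed)
    n0 = len(x)
    # Base case: Sigma^(1)(x, x') = x^T x' / n_0 + beta^2.
    Sxx = x @ x / n0 + beta**2
    Sxp = x @ x_prime / n0 + beta**2
    Spp = x_prime @ x_prime / n0 + beta**2
    for _ in range(depth - 1):
        Lam = np.array([[Sxx, Sxp], [Sxp, Spp]])                  # Lambda^(l)
        fs = rng.multivariate_normal(np.zeros(2), Lam, size=n_samples)
        u, v = sigma(fs[:, 0]), sigma(fs[:, 1])
        Sxx = np.mean(u * u) + beta**2                            # Sigma^(l+1)(x, x)
        Sxp = np.mean(u * v) + beta**2                            # Sigma^(l+1)(x, x')
        Spp = np.mean(v * v) + beta**2                            # Sigma^(l+1)(x', x')
    return Sxp

print(nngp_kernel(np.ones(3), np.zeros(3), depth=3))
```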

Why is this useful? NNGPs let us do Bayesian inference with neural networks without ever having to train them! We use the NNGP kernel to do GP regression, which allows us to analyze and predict generalization in the lazy learning regime.

Deterministic NTK

The main contribution of Jacot's original NTK paper is that as the network approaches the infinite width limit, the NTK does two very important things:

  1. it is deterministic at initialization (i.e. the initialization parameters don't affect the kernel; the kernel is determined solely by the architecture)
  2. it stays constant during training

This is a pretty surprising result -- it tells us that in the infinite width limit, the network's behavior is entirely determined by its architecture, not the initialization parameters. In fact, this is observed in practice: we often see large enough networks converge to similar solutions.

Convergence Analysis

Since the NTK is deterministic and constant during training, we can analyze the convergence of the network. With a fixed kernel, the training dynamics in function space are those of a linear model: the network behaves like its linearization around initialization. We'll examine gradient flow (continuous-time gradient descent) with learning rate \(\eta\):

For a network with parameters \(\theta\), the evolution of the network's output on any input \(x\) follows: \begin{equation} \dv{f(x; \theta)}{t} = -\eta \sum_{i=1}^n K(x, x_i; \theta) \nabla_f \ell(f(x_i; \theta), y_i), \end{equation}

an ODE in function space (a linear one when \(\ell\) is the squared error), for which the solution can be written as \begin{equation} f(x; \theta(t)) = f(x; \theta(0)) + \sum_{i=1}^n K(x, x_i; \theta(0)) \alpha_i(t), \end{equation}

where \(\alpha_i(t)\) are time-dependent coefficients that depend on the loss function. Effectively, this equation tells us that the change in the network's output lies in the (at most \(n\)-dimensional) span of the kernel functions \(K(\cdot, x_i; \theta(0))\).
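
To make the coefficients \(\alpha_i(t)\) concrete, it's worth writing out the standard special case of the squared-error loss \(\ell(f, y) = \frac{1}{2}(f - y)^2\) (a choice of mine, purely for illustration). Stacking the network's outputs on the training set into a vector \(f_t = f(\mathcal{X}; \theta(t))\) and writing \(K\) for the \(n \times n\) kernel matrix with entries \(K_{ij} = K(x_i, x_j; \theta(0))\), the dynamics above restricted to the training points become \(\dv{f_t}{t} = -\eta K (f_t - \mathcal{Y})\), with solution \begin{equation} f_t = \mathcal{Y} + e^{-\eta K t}\,(f_0 - \mathcal{Y}). \end{equation} Since \(K\) is SPSD, every eigendirection with a positive eigenvalue decays exponentially toward the targets, which is exactly the sense in which wide enough networks drive the training loss to (near) zero.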

Practical Implications

Several useful implications follow from the NTK being deterministic and constant during training. For example, we can use the NTK to compare different architectures before training, or to understand generalization in the lazy learning regime (my research in BAIR).

  1. Architecture Design: A "good" architecture should have a well-conditioned NTK that allows for efficient learning. A well-conditioned NTK is one whose eigenvalues are reasonably close to one another, which lets iterative algorithms like gradient descent converge quickly (see the sketch after this list).
  2. Generalization: If the NTK is well-behaved, the network tends to generalize well even without explicit regularization.
  3. Training Dynamics: The constant NTK assumption helps explain why wide networks often train stably and predictably, unlike their narrower counterparts, and lends some theoretical justification to the empirical observation that making neural networks bigger makes them easier to train.
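
As a rough numerical illustration of the conditioning point in item 1: given any NTK Gram matrix over the training set, we can look at how spread out its eigenvalues are. The linear model below (for which \(\nabla_\theta f(x; \theta) = x\), so the Gram matrix is just \(XX^\top\)) is purely a stand-in of mine; the same check applies to the empirical NTK of a real network.

```python
import numpy as np

def ntk_condition_number(K, eps=1e-12):
    """Ratio of the largest to smallest nonzero eigenvalue of an SPSD kernel matrix."""
    eigvals = np.linalg.eigvalsh(K)
    eigvals = eigvals[eigvals > eps]
    return eigvals.max() / eigvals.min()

# Stand-in: for a linear model f(x; theta) = theta . x, grad_theta f(x) = x,
# so the NTK Gram matrix over a dataset X is simply X X^T.
rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
K = X @ X.T
print(ntk_condition_number(K))
```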

Computing the NTK

For a network with parameters \(\theta\), the (empirical) NTK between two inputs \(x\) and \(x'\) is \(K(x, x'; \theta) = \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x'; \theta)\), and the deterministic infinite-width kernel corresponds to its expectation over the random initialization: \begin{equation} K(x, x') = \mathbb{E}_{\theta \sim \mathcal{N}(0, I)} \left[ \nabla_\theta f(x; \theta)^\top \nabla_\theta f(x'; \theta) \right]. \end{equation}

In practice, this expectation can be approximated by sampling multiple initializations and averaging the resulting Jacobian products. I'm not super familiar with this, but I know that Google's neural-tangents library has some tools for computing NTKs.
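
Sticking with plain JAX rather than neural-tangents, here's a rough sketch of that averaging, reusing init_params and empirical_ntk from the earlier sketches; the widths and number of initializations are arbitrary choices of mine.

```python
import jax
import jax.numpy as jnp

# Reuses init_params and empirical_ntk from the earlier sketches.

def averaged_ntk(key, x1, x2, widths=(3, 256, 256, 1), n_inits=32):
    """Monte Carlo estimate of E_theta[K(x, x'; theta)] over random initializations."""
    kernels = [empirical_ntk(init_params(k, list(widths)), x1, x2)
               for k in jax.random.split(key, n_inits)]
    return jnp.mean(jnp.stack(kernels))

print(averaged_ntk(jax.random.PRNGKey(0), jnp.ones(3), jnp.zeros(3)))
```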

Visualizing the NTK

Figure: NTK Evolution with Network Width. The training data consists of noisy samples from a sine function. The diagonal of the Neural Tangent Kernel (NTK), \(K(x, x)\), is plotted for various network widths. As the width increases, the NTK curves become smoother, larger in magnitude, and more similar to each other, illustrating convergence to a deterministic kernel in the infinite-width limit.
You can download the code used to generate this figure here.

References

  1. Lilian Weng's blog post on Neural Tangent Kernels
  2. Jacot, Gabriel, and Hongler (2018), Neural Tangent Kernel: Convergence and Generalization in Neural Networks
  3. Lee, Bahri, Novak, Schoenholz, Pennington, and Sohl-Dickstein (2018), Deep Neural Networks as Gaussian Processes