\(\require{physics}\)
\(\mu\)P as Optimal Transport in a Vanilla MLP

Jan 05, 2026

Deriving \(\mu\)P as the unique scaling maximizing Wasserstein transport under stability constraints

Goal

In the following, we argue that \(\mu\)P can be rederived from first principles as the solution to a variational problem in Wasserstein geometry. Concretely, we posit that the "correct" scaling for a neural network is the unique solution to the following optimization problem:

Variational Selection Principle
Define the function-space distance: \begin{equation} d(f, g) \doteq \left(\mathbb{E}_{x \sim \mathcal{D}_x}[(f(x) - g(x))^2]\right)^{1/2} \end{equation} Then the width-transfer discrepancy is: \begin{equation} \mathcal{D}_T(s) \doteq \limsup_{n \to \infty} \sup_{t \in [0,T]} d(f^{(n,s)}_t, f^{(\kappa n,s)}_t) \end{equation} where \(f^{(n,s)}_t\) is the trained predictor at time \(t\) with width \(n\) and scaling \(s\), and \(\kappa > 1\) is a width ratio.

Define the total Wasserstein dissipation (kinetic energy integrated over time and summed over layers): \begin{equation} \mathcal{A}_T(s) \doteq \sum_{\ell=1}^L \liminf_{n \to \infty} \int_0^T \frac{1}{2n} \sum_{i=1}^n \norm{\dot{\theta}^{\ell}_i(t)}^2 dt \end{equation} In the warm-up, \(L=1\) and \(\mathcal{A}_T(s) = \liminf_{n \to \infty} \int_0^T \frac{1}{2n} \sum_{i=1}^n \norm{\dot{\theta}_i(t)}^2 dt\).

(Note: This is the Benamou-Brenier "action" functional from Optimal Transport, evaluated along the training trajectory. Unlike the Principle of Least Action in mechanics, here we maximize it subject to stability, which selects the critical mobility where dissipation is finite but nonzero.)

Then \(\mu\)P is the unique scaling (within power-law learning-rate scalings, modulo time-rescaling) that solves: \begin{equation} \max_s \mathcal{A}_T(s) \quad \text{subject to} \quad \mathcal{D}_T(s) = 0 \end{equation} This is a constrained criticality problem: maximize total dissipation (cumulative transport) subject to stability.
This is not meant to be a new theorem about \(\mu\)P. I just thought it was interesting to consider a geometric "selection principle": among all width-transferable scalings, pick the one with maximal non-degenerate Wasserstein transport. This reframes \(\mu\)P as a critical-mobility point selected by a constrained variational problem, which is (maybe?) portable to other parameter spaces and metrics.

In this view, the standard "lazy" (kernel-like) parameterization fails because the transport distance vanishes as \(n \to \infty\) (the particles freeze). Unstable parameterizations fail because the transport distance diverges. \(\mu\)P emerges as the scaling regime where the Wasserstein speed of the layers is non-zero, finite, and coupled.

We will proceed in two stages:

  1. Warm-Up. We analyze a 1-hidden-layer MLP with biases under square loss. We explicitly compute the Wasserstein speed of the neuron measure and show that \(\mu\)P is the only scaling that allows non-degenerate transport.
  2. Generalization. We lift this logic to a deep MLP. We formulate a "Total Transport" objective -- summing the Wasserstein path lengths across all layers -- and show that maximizing this quantity forces every hidden layer to adopt the specific \(\sqrt{n}\) scaling factors characteristic of \(\mu\)P.

Networks as Particle Systems

We treat the network as a dynamical system of interacting particles.

Model

We focus on a standard 1-hidden-layer MLP of width \(n\) with biases. For conceptual clarity in the OT setting, we use mean-field normalization: \begin{equation} f_n(x) = \frac{1}{n} \sum_{i=1}^n a_i \sigma(w_i^\top x + b_i) + c \end{equation}

(Note: Standard \(\mu\)P uses \(\frac{1}{\sqrt{n}}\) normalization. The two formulations are related by rescaling \(a\) and time. The mean-field form aligns perfectly with OT geometry where measures are probability distributions.)

Here, the "particles" are the hidden neurons. Each neuron \(i\) is defined by its parameter tuple \(\theta_i\): \begin{equation} \theta_i \doteq (a_i, w_i, b_i) \in \mathbb{R} \times \mathbb{R}^d \times \mathbb{R} \end{equation}

where \(a_i\) is the output weight, \(w_i\) is the input weight vector, and \(b_i\) is the bias. The scalar \(c\) is the global output bias.
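As a concrete reference point, here is a minimal pure-Python sketch of the mean-field forward pass above. The choice \(\sigma = \tanh\), the standard-normal initialization, and the helper names `init_neurons` and `mean_field_mlp` are assumptions made for illustration; the text does not fix them.

```python
import math
import random


def init_neurons(n, d, seed=0):
    """Sample n neurons theta_i = (a_i, w_i, b_i) with O(1) entries (assumed N(0, 1))."""
    rng = random.Random(seed)
    a = [rng.gauss(0.0, 1.0) for _ in range(n)]
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    b = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return a, W, b


def mean_field_mlp(x, a, W, b, c=0.0):
    """f_n(x) = (1/n) * sum_i a_i * tanh(w_i . x + b_i) + c."""
    n = len(a)
    total = 0.0
    for a_i, w_i, b_i in zip(a, W, b):
        u_i = sum(w_ij * x_j for w_ij, x_j in zip(w_i, x)) + b_i  # preactivation
        total += a_i * math.tanh(u_i)
    return total / n + c


a, W, b = init_neurons(n=16, d=3)
print(mean_field_mlp([0.1, -0.2, 0.3], a, W, b, c=0.5))
```

Because \(|\tanh| < 1\), the output differs from \(c\) by at most the average of \(|a_i|\), so \(f_n\) stays \(O(1)\) at any width under this normalization.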

We train this network to minimize the population square loss \(\mathcal{L}\) over a data distribution \(\mathcal{D}\): \begin{equation} \mathcal{L}(\theta) = \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \frac{1}{2} (f_n(x) - y)^2 \right] \end{equation}

Mean-Field Limit

In the "Mean-Field" or Optimal Transport limit, we don't track indices \(i\) and instead track the distribution of neurons. We define the empirical measure \(\mu_t\) at training time \(t\): \begin{equation} \mu_t = \frac{1}{n} \sum_{i=1}^n \delta_{\theta_i(t)} \end{equation}

This allows us to write the network output as an integral against this measure. As \(n \to \infty\), if \(\mu_t\) converges to a smooth probability density \(\rho_t\), the network becomes: \begin{equation} f(x) = \int a \sigma(w^\top x + b) d\rho_t(\theta) + c \end{equation}

This is exactly the mean-field limit: the network function is an integral against the neuron distribution.
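The convergence of the finite sum to the integral is a law-of-large-numbers statement, which a small experiment can illustrate. The sketch below assumes a specific \(\rho\) (independent standard normals for \(a\), \(w\), \(b\)), a scalar input, and \(\sigma = \tanh\); none of these choices come from the text. Under this \(\rho\) the integral term vanishes, so the limiting predictor is simply \(f(x) = c\).

```python
import math
import random


def sampled_network(x, n, c=0.3, seed=0):
    """f_n(x) for n neurons drawn i.i.d. from rho = N(0, I) over (a, w, b); x is a scalar."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        a, w, b = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        total += a * math.tanh(w * x + b)
    return total / n + c


# a is mean-zero and independent of (w, b), so the integral of
# a * tanh(w x + b) against rho is zero and the limit is f(x) = c.
for n in (100, 10_000):
    gap = abs(sampled_network(0.7, n) - 0.3)
    print(n, gap)  # the gap is expected to shrink roughly like n ** -0.5
```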

Dynamics and Wasserstein Speed

We model training as gradient flow (continuous-time gradient descent). The parameters evolve according to: \begin{equation} \dot{\theta}_i = -\eta_n \nabla_{\theta_i} \mathcal{L}(\theta) \end{equation}

where \(\eta_n\) is some width-dependent learning rate (mobility).

The core quantity we care about is the Wasserstein kinetic energy. In Optimal Transport geometry (specifically the \(W_2\) metric), the instantaneous kinetic energy of the measure \(\mu_t\) is: \begin{equation} E_{\text{kin}}(\mu_t) = \frac{1}{2} \int \norm{v_t(\theta)}^2 d\mu_t(\theta) \approx \frac{1}{2n} \sum_{i=1}^n \norm{\dot{\theta}_i}^2 \end{equation}

where \(v_t\) is the velocity field. The Wasserstein speed (metric speed) is the square root of twice the kinetic energy: \begin{equation} \text{Speed}(\mu_t) = \left( \frac{1}{n} \sum_{i=1}^n \norm{\dot{\theta}_i}^2 \right)^{1/2} = \sqrt{2 E_{\text{kin}}(\mu_t)} \end{equation}

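This quantity dictates how the system behaves as \(n \to \infty\):

  1. Speed \(\to 0\): the neuron measure freezes. This is the lazy (kernel-like) regime, with no feature learning.
  2. Speed \(= \Theta(1)\): the measure moves a finite, nonzero distance. This is the regime \(\mu\)P selects.
  3. Speed \(\to \infty\): the particles accelerate without bound and the training dynamics are unstable.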

This is exactly a phase transition story: \(\mu\)P sits at the critical point where kinetic energy is \(O(1)\).
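The particle formulas for kinetic energy and speed translate directly into code. The following is a minimal sketch; the function names and the two-neuron example are illustrative assumptions.

```python
import math


def kinetic_energy(velocities):
    """E_kin = (1 / 2n) * sum_i ||theta_dot_i||^2 for a list of per-neuron velocity vectors."""
    n = len(velocities)
    return sum(sum(v * v for v in vel) for vel in velocities) / (2.0 * n)


def wasserstein_speed(velocities):
    """Speed = sqrt(2 * E_kin), i.e. the root-mean-square velocity over neurons."""
    return math.sqrt(2.0 * kinetic_energy(velocities))


velocities = [[3.0, 4.0], [0.0, 0.0]]  # two neurons: one moving with norm 5, one frozen
print(kinetic_energy(velocities))      # 25 / (2 * 2) = 6.25
print(wasserstein_speed(velocities))   # sqrt(12.5)
```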

Learning Rates as Mobilities

In physics, the relationship between force and velocity is mediated by a mobility tensor. Here, learning rates play the role of mobilities (inverse friction coefficients). We write the dynamics as: \begin{equation} \dot{\theta}_i = -M(n) \nabla_{\theta_i} \mathcal{L}(\theta) \end{equation}

where \(M(n)\) is a block-diagonal mobility tensor that scales with width. The choice of scaling for \(M(n)\) determines the mobility scale: too small (lazy regime) and particles freeze; too large and particles accelerate infinitely. \(\mu\)P corresponds to the critical mobility where kinetic energy stays finite but nonzero.

This is the noiseless limit of an overdamped Langevin system: the force is \(-\nabla\mathcal{L}\), the velocity is \(\dot{\theta}\), and \(\mathcal{A}_T\) measures total dissipation.
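The link between mobility and dissipation can be made explicit with a short calculation. As a simplifying assumption (not made in the text above), give every neuron parameter the same scalar mobility \(\eta_n\) and hold the output bias \(c\) fixed. Along the gradient flow, the chain rule gives: \begin{equation} \frac{d}{dt}\mathcal{L} = \sum_{i=1}^n \nabla_{\theta_i}\mathcal{L} \cdot \dot{\theta}_i = -\frac{1}{\eta_n} \sum_{i=1}^n \norm{\dot{\theta}_i}^2 = -\frac{2n}{\eta_n} E_{\text{kin}}(\mu_t) \end{equation}

Integrating over \([0, T]\): \begin{equation} \int_0^T E_{\text{kin}}(\mu_t)\, dt = \frac{\eta_n}{2n} \left( \mathcal{L}(\theta(0)) - \mathcal{L}(\theta(T)) \right) \end{equation}

With \(\eta_n = \eta_0 n\), the prefactor is \(\eta_0 / 2\), so the integrated kinetic energy is bounded by \(\tfrac{\eta_0}{2}\mathcal{L}(\theta(0))\) and is nonzero whenever the loss decreases by an \(O(1)\) amount. This is one way to read "finite but nonzero dissipation" under the stated assumptions.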

Computing the Gradients

To understand how the scaling affects dynamics, we need to compute the gradients. Let \(u_i \doteq w_i^\top x + b_i\) denote the preactivation. The gradients are: \begin{equation} \partial_{a_i}\mathcal{L} = \mathbb{E}\left[e(x) \frac{1}{n} \sigma(u_i)\right] \end{equation} \begin{equation} \nabla_{w_i}\mathcal{L} = \mathbb{E}\left[e(x) \frac{1}{n} a_i \sigma'(u_i) x\right] \end{equation} \begin{equation} \partial_{b_i}\mathcal{L} = \mathbb{E}\left[e(x) \frac{1}{n} a_i \sigma'(u_i)\right] \end{equation} \begin{equation} \partial_c\mathcal{L} = \mathbb{E}[e(x)] \end{equation}

where \(e(x) = f_n(x) - y\) is the error.
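The gradient formulas can be checked against finite differences. The sketch below assumes \(\sigma = \tanh\) and replaces the population expectation with an average over a tiny hand-written dataset; the data values and helper names are illustrative only.

```python
import math


def forward(x, a, W, b, c):
    """Mean-field network f_n(x) with sigma = tanh (assumed activation)."""
    n = len(a)
    acts = [math.tanh(sum(wj * xj for wj, xj in zip(W[i], x)) + b[i]) for i in range(n)]
    return sum(a[i] * acts[i] for i in range(n)) / n + c


def loss(data, a, W, b, c):
    """Empirical square loss, standing in for the population expectation."""
    return sum(0.5 * (forward(x, a, W, b, c) - y) ** 2 for x, y in data) / len(data)


def gradients(data, a, W, b, c):
    """Analytic gradients from the formulas above, averaged over the dataset."""
    n, d, m = len(a), len(W[0]), len(data)
    g_a, g_b, g_c = [0.0] * n, [0.0] * n, 0.0
    g_W = [[0.0] * d for _ in range(n)]
    for x, y in data:
        e = forward(x, a, W, b, c) - y  # error e(x)
        g_c += e / m
        for i in range(n):
            s = math.tanh(sum(wj * xj for wj, xj in zip(W[i], x)) + b[i])
            ds = 1.0 - s * s  # tanh'(u_i)
            g_a[i] += e * s / (n * m)
            g_b[i] += e * a[i] * ds / (n * m)
            for j in range(d):
                g_W[i][j] += e * a[i] * ds * x[j] / (n * m)
    return g_a, g_W, g_b, g_c


data = [([1.0, -0.5], 0.3), ([0.2, 0.7], -1.0)]
a, W, b, c = [0.5, -1.2], [[0.3, -0.4], [1.0, 0.2]], [0.1, -0.3], 0.2
g_a, g_W, g_b, g_c = gradients(data, a, W, b, c)

# Central finite difference in a_0 as a sanity check on the analytic formula.
eps = 1e-6
fd = (loss(data, [a[0] + eps, a[1]], W, b, c) - loss(data, [a[0] - eps, a[1]], W, b, c)) / (2 * eps)
print(g_a[0], fd)  # the two values should agree to many digits
```

Note the explicit `1 / n` in every per-neuron gradient and its absence in `g_c`; that asymmetry is what drives the different critical exponents later on.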

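Assume we are in a well-scaled regime where \(f_n(x) = O(1)\) and the residual \(e(x) = f_n(x) - y\) remains \(O(1)\) on the time horizon of interest (as is standard in scaling analyses). Under standard initialization with \(a_i = O(1)\), \(w_i = O(1)\), and \(b_i = O(1)\), every per-neuron gradient inherits the explicit \(\frac{1}{n}\) prefactor, while the output-bias gradient does not. This gives us: \begin{equation} \abs{\partial_{a_i}\mathcal{L}} = O(n^{-1}), \quad \norm{\nabla_{w_i}\mathcal{L}} = O(n^{-1}), \quad \abs{\partial_{b_i}\mathcal{L}} = O(n^{-1}), \quad \abs{\partial_c\mathcal{L}} = O(1) \end{equation}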

Learning Rate Scaling and Parameter Updates

Now suppose we use different learning rates for different parameter types, scaling with width as: \begin{equation} \eta_a(n) = \eta_{a,0} n^{\gamma_a}, \quad \eta_w(n) = \eta_{w,0} n^{\gamma_w}, \quad \eta_b(n) = \eta_{b,0} n^{\gamma_b}, \quad \eta_c(n) = \eta_{c,0} n^{\gamma_c} \end{equation}

Under gradient flow, the parameter velocities are: \begin{equation} \abs{\dot{a}_i} = \eta_a(n) \abs{\partial_{a_i}\mathcal{L}} = O(n^{\gamma_a - 1}) \end{equation} \begin{equation} \norm{\dot{w}_i} = \eta_w(n) \norm{\nabla_{w_i}\mathcal{L}} = O(n^{\gamma_w - 1}) \end{equation} \begin{equation} \abs{\dot{b}_i} = \eta_b(n) \abs{\partial_{b_i}\mathcal{L}} = O(n^{\gamma_b - 1}) \end{equation} \begin{equation} \abs{\dot{c}} = \eta_c(n) \abs{\partial_c\mathcal{L}} = O(n^{\gamma_c}) \end{equation}

Critical Scaling

The Wasserstein speed depends on the maximum velocity across all parameter types: \begin{equation} \text{Speed}(\mu_t) = O\left(n^{\max(\gamma_a, \gamma_w, \gamma_b) - 1}\right) \end{equation}

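This gives us three regimes:

  1. Subcritical, \(\max(\gamma_a, \gamma_w, \gamma_b) < 1\): the speed vanishes as \(n \to \infty\) and the neurons freeze (lazy regime).
  2. Critical, \(\max(\gamma_a, \gamma_w, \gamma_b) = 1\): the speed is \(O(1)\), finite and nonzero.
  3. Supercritical, \(\max(\gamma_a, \gamma_w, \gamma_b) > 1\): the speed diverges and the dynamics are unstable.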

For the output bias \(c\), the critical scaling is \(\gamma_c = 0\) since its gradient is already \(O(1)\).
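The predicted exponent \(\gamma - 1\) can be probed numerically at initialization. The sketch below is an illustration under several assumptions that are not in the text: \(\sigma = \tanh\), standard-normal initialization, a three-point toy dataset, \(c = 0\), and one shared exponent \(\gamma\) for \(a\), \(w\), and \(b\). It estimates the log-log slope of the speed against width from two widths only.

```python
import math
import random


def speed_at_init(n, gamma, d=2, seed=0):
    """Wasserstein speed at initialization with eta = n**gamma for every neuron parameter."""
    rng = random.Random(seed)
    data = [([1.0, -0.5], 2.0), ([0.2, 0.7], -2.0), ([-0.8, 0.3], 1.5)]  # toy dataset (assumed)
    a = [rng.gauss(0.0, 1.0) for _ in range(n)]
    W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    b = [rng.gauss(0.0, 1.0) for _ in range(n)]
    # Cache activations s[k][i] = tanh(w_i . x_k + b_i) and errors e[k] = f_n(x_k) - y_k.
    s = [[math.tanh(sum(wj * xj for wj, xj in zip(W[i], x)) + b[i]) for i in range(n)]
         for x, _ in data]
    e = [sum(a[i] * s[k][i] for i in range(n)) / n - y for k, (_, y) in enumerate(data)]
    m, eta, total = len(data), float(n) ** gamma, 0.0
    for i in range(n):
        g_a = sum(e[k] * s[k][i] for k in range(m)) / (n * m)
        g_b = sum(e[k] * a[i] * (1 - s[k][i] ** 2) for k in range(m)) / (n * m)
        g_w = [sum(e[k] * a[i] * (1 - s[k][i] ** 2) * data[k][0][j] for k in range(m)) / (n * m)
               for j in range(d)]
        total += eta ** 2 * (g_a ** 2 + g_b ** 2 + sum(g * g for g in g_w))
    return math.sqrt(total / n)  # sqrt((1/n) * sum_i ||theta_dot_i||^2)


def fitted_exponent(gamma, n_small=200, n_large=3200):
    """Log-log slope of speed versus width between two widths."""
    ratio = speed_at_init(n_large, gamma) / speed_at_init(n_small, gamma)
    return math.log(ratio) / math.log(n_large / n_small)


for gamma in (0.0, 1.0, 1.5):
    print(gamma, round(fitted_exponent(gamma), 2))  # expected near gamma - 1
```

This only checks the scaling at \(t = 0\); whether the residual stays \(O(1)\) along the whole trajectory is the separate assumption stated earlier.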

Selection Principle: Critical Mobility

Width-transferability requires that the training trajectory in function space converges as \(n \to \infty\). Empirically and theoretically, this rules out the supercritical regime. So \(\mathcal{D}_T = 0\) (the transferability constraint) effectively enforces: \begin{equation} \max(\gamma_a, \gamma_w, \gamma_b) \le 1, \qquad \gamma_c \le 0 \end{equation}

To maximize feature learning (non-zero transport), we want to maximize the Wasserstein speed subject to this constraint. This forces us to the boundary: \begin{equation} \boxed{\gamma_a = \gamma_w = \gamma_b = 1, \qquad \gamma_c = 0} \end{equation}

This corresponds to the \(\mu\)P scaling: the per-neuron parameters \((a_i, w_i, b_i)\) get learning rates that scale as \(n\), while the global output bias \(c\) uses a constant learning rate. Uniqueness is meant within the family of power-law learning-rate scalings (mobilities) considered here, modulo an overall rescaling of time.
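The selection step can be mimicked by a brute-force search over a small grid of exponents. One modeling choice in this sketch should be flagged: the objective sums the kinetic-energy contribution of each parameter block \((a, w, b)\) separately. The speed formula alone only pins the largest exponent, so counting each block is what pushes all three to the boundary. The grid and the reference width `n_ref` are arbitrary illustrative values.

```python
import itertools


def kinetic_exponents(gammas):
    """Per-block exponents of E_kin: velocity ~ n**(gamma - 1), so energy ~ n**(2 * (gamma - 1))."""
    return [2.0 * (g - 1.0) for g in gammas]


def select_scaling(grid, n_ref=1e6):
    """Maximize the summed kinetic energy at a large reference width, subject to stability."""
    best, best_score = None, float("-inf")
    for gammas in itertools.product(grid, repeat=3):  # (gamma_a, gamma_w, gamma_b)
        if max(gammas) > 1.0:  # supercritical: excluded by the constraint D_T = 0
            continue
        score = sum(n_ref ** p for p in kinetic_exponents(gammas))
        if score > best_score:
            best, best_score = gammas, score
    return best


grid = [0.0, 0.5, 1.0, 1.5]
print(select_scaling(grid))  # (1.0, 1.0, 1.0)
```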

Gauge Equivalence (Mean-Field \(\leftrightarrow\) \(\mu\)P Convention)
In mean-field form, the critical mobility is \(\eta \asymp n\) for neuron parameters \((a, w, b)\) to achieve \(O(1)\) Wasserstein speed in the original time variable. To pass to the \(\frac{1}{\sqrt{n}}\) form, define \(\tilde{a}_i \doteq a_i / \sqrt{n}\). Then: \begin{equation} f_n(x) = \frac{1}{\sqrt{n}} \sum_{i=1}^n \tilde{a}_i \sigma(w_i^\top x + b_i) + c \end{equation} which is the standard \(\frac{1}{\sqrt{n}}\) form. In this gauge one conventionally takes \(\tilde{a}_i = O(1)\) (that is, \(a_i = O(\sqrt{n})\), so the two conventions differ in initialization scale as well as notation). The per-neuron gradients then carry a prefactor \(\frac{1}{\sqrt{n}}\) instead of \(\frac{1}{n}\) and scale as \(O(n^{-1/2})\) rather than \(O(n^{-1})\). Under the same \(O(1)\)-residual assumption, learning rates (equivalently, the gradient-flow time variable) rescale by a factor \(\sqrt{n}\): the critical mobility \(\eta \asymp n\) in mean-field gauge corresponds to \(\eta \asymp \sqrt{n}\) in the standard \(\mu\)P gauge.
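A quick numerical check of the change of variables: the two forms define the same function, and the sensitivity to an output weight differs by exactly \(\sqrt{n}\). The sketch uses a scalar input, \(\sigma = \tanh\), \(c = 0\), and hand-picked parameter values, all of which are illustrative assumptions.

```python
import math


def f_mean_field(x, a, w, b):
    """(1/n) * sum_i a_i * tanh(w_i * x + b_i), scalar input for brevity."""
    n = len(a)
    return sum(ai * math.tanh(wi * x + bi) for ai, wi, bi in zip(a, w, b)) / n


def f_sqrt_gauge(x, a_tilde, w, b):
    """(1/sqrt(n)) * sum_i a_tilde_i * tanh(w_i * x + b_i)."""
    n = len(a_tilde)
    return sum(ai * math.tanh(wi * x + bi) for ai, wi, bi in zip(a_tilde, w, b)) / math.sqrt(n)


a, w, b = [0.5, -1.2, 2.0, 0.3], [0.3, 1.0, -0.7, 0.1], [0.1, -0.3, 0.0, 0.4]
n = len(a)
a_tilde = [ai / math.sqrt(n) for ai in a]  # change of variables from the text

print(f_mean_field(0.7, a, w, b), f_sqrt_gauge(0.7, a_tilde, w, b))  # same function value

# Sensitivity to one output weight in each gauge: df/da_0 versus df/da_tilde_0.
d_mf = math.tanh(w[0] * 0.7 + b[0]) / n
d_sq = math.tanh(w[0] * 0.7 + b[0]) / math.sqrt(n)
print(d_sq / d_mf)  # sqrt(n), here 2.0
```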

Deep Networks: Layerwise Transport

From here on we revert to the standard \(\frac{1}{\sqrt{n}}\) neural tangent / \(\mu\)P gauge, because it matches common deep-network parameterizations. The mean-field gauge differs by a simple rescaling of output weights and time (as noted above).

Now we lift this logic to a deep MLP. Consider a depth-\(L\) fully-connected network with width \(n\) in each hidden layer: \begin{equation} h^0 = x, \qquad z^\ell = \frac{1}{\sqrt{n}} W^\ell h^{\ell-1} + b^\ell, \quad h^\ell = \sigma(z^\ell), \quad \ell = 1, \dots, L \end{equation} \begin{equation} f_n(x) = \frac{1}{\sqrt{n}} w^\top h^L + c \end{equation}

We train on the same population square loss: \(\mathcal{L} = \tfrac{1}{2} \mathbb{E}[(f_n(x) - y)^2]\).

Row-Gradient Scaling

The key observation is that gradients for rows of weight matrices scale as \(O(n^{-1/2})\). To see this, define the backprop signal \(\delta^\ell \doteq \partial f / \partial z^\ell \in \mathbb{R}^n\). Under standard forward scaling, each coordinate satisfies \(\delta^\ell_i = O(n^{-1/2})\).

For a row \(W^\ell_{i:} \in \mathbb{R}^n\) of the weight matrix, the gradient is: \begin{equation} \nabla_{W^\ell_{i:}}\mathcal{L} = \mathbb{E}\left[e(x) \delta^\ell_i(x) \frac{1}{\sqrt{n}} h^{\ell-1}(x)\right] \end{equation}

Since \(\norm{h^{\ell-1}} = O(\sqrt{n})\), we have \(\norm{(1/\sqrt{n}) h^{\ell-1}} = O(1)\), hence: \begin{equation} \boxed{\norm{\nabla_{W^\ell_{i:}}\mathcal{L}} = O(n^{-1/2})} \end{equation}

Similarly, for biases: \begin{equation} \boxed{\abs{\partial_{b^\ell_i}\mathcal{L}} = O(n^{-1/2})} \end{equation}

For the readout and output bias, we get: \begin{equation} \norm{\nabla_w\mathcal{L}} = O(1), \qquad \abs{\partial_c\mathcal{L}} = O(1) \end{equation}
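The \(n^{-1/2}\) row-gradient scaling can be probed at initialization with a small experiment. The sketch below makes several assumptions not fixed by the text: two hidden layers, \(\sigma = \tanh\), standard-normal weights and biases, a single random input, and the gradient of \(f\) rather than \(\mathcal{L}\) (that is, the \(O(1)\) factor \(e(x)\) is dropped). Weights are sampled on the fly to avoid storing matrices.

```python
import math
import random


def rms_row_gradient(n, d=4, seed=0):
    """RMS over neurons i of ||d f / d W^2_{i:}|| in a 2-hidden-layer tanh net at initialization."""
    rng = random.Random(seed)
    s = 1.0 / math.sqrt(n)  # the 1/sqrt(n) forward scaling used in the text
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    h1 = []
    for _ in range(n):  # layer 1: z^1_i = s * W^1_{i:} x + b^1_i
        z = s * sum(rng.gauss(0.0, 1.0) * xj for xj in x) + rng.gauss(0.0, 1.0)
        h1.append(math.tanh(z))
    h1_norm = math.sqrt(sum(h * h for h in h1))  # O(sqrt(n))
    total = 0.0
    for _ in range(n):  # layer 2: z^2_i = s * W^2_{i:} h^1 + b^2_i
        z = s * sum(rng.gauss(0.0, 1.0) * h for h in h1) + rng.gauss(0.0, 1.0)
        w_i = rng.gauss(0.0, 1.0)  # readout weight for neuron i
        delta = s * w_i * (1.0 - math.tanh(z) ** 2)  # delta^2_i = d f / d z^2_i
        total += (delta * s * h1_norm) ** 2  # ||delta_i * s * h^1||^2
    return math.sqrt(total / n)


n_small, n_large = 64, 1024
slope = math.log(rms_row_gradient(n_large) / rms_row_gradient(n_small)) / math.log(n_large / n_small)
print(round(slope, 2))  # expected near -0.5
```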

Layerwise Neuron Measures

For each hidden layer \(\ell\), we can define per-neuron parameters \(\theta^\ell_i \doteq (W^\ell_{i:}, b^\ell_i)\) and the empirical measure: \begin{equation} \mu^{\ell,n}_t = \frac{1}{n} \sum_{i=1}^n \delta_{\theta^\ell_i(t)} \end{equation}

The Wasserstein speed for layer \(\ell\) is: \begin{equation} \norm{\dot{\mu}^{\ell,n}_t}_{W_2} = O\left(n^{\max(\gamma_W, \gamma_b) - 1/2}\right) \end{equation}

where \(\gamma_W\) and \(\gamma_b\) are the learning rate exponents for weights and biases in hidden layers.

Total Transport Objective

Define the total transport up to time \(T\) as the sum of Wasserstein path lengths across all layers: \begin{equation} \mathcal{S}_T(s) \doteq \sum_{\ell=1}^L \liminf_{n \to \infty} \int_0^T \norm{\dot{\mu}^{\ell,n}_t}_{W_2} dt \end{equation}

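This measures how much "movement" happens across all layers combined. Unlike \(\mathcal{A}_T\), which integrates kinetic energy, \(\mathcal{S}_T\) integrates speed; under the power-law scalings considered here the two vanish, stay finite, or diverge together, so they select the same exponents. The selection principle is the same: maximize \(\mathcal{S}_T\) subject to width-transferability.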
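For a discretized trajectory, the path length can be approximated by summing root-mean-square displacements between consecutive snapshots. This is a minimal sketch; the snapshot layout `trajectory[k][i]` and the function names are illustrative assumptions.

```python
import math


def path_length(trajectory):
    """Approximate the W2 path length from snapshots trajectory[k][i] = theta_i at step k."""
    length = 0.0
    for prev, curr in zip(trajectory, trajectory[1:]):
        n = len(prev)
        sq = sum(sum((c - p) ** 2 for p, c in zip(tp, tc)) for tp, tc in zip(prev, curr))
        length += math.sqrt(sq / n)  # speed * dt, using finite-difference velocities
    return length


def total_transport(layer_trajectories):
    """S_T: sum of per-layer path lengths."""
    return sum(path_length(traj) for traj in layer_trajectories)


# Three neurons, each moving with constant velocity (3, 4) for T = 2, sampled every 0.5.
traj = [[[3.0 * 0.5 * k + i, 4.0 * 0.5 * k] for i in range(3)] for k in range(5)]
print(path_length(traj))              # speed 5 for time 2 -> 10.0
print(total_transport([traj, traj]))  # two identical layers -> 20.0
```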

Deep \(\mu\)P Result (Scaling Argument)

For deep networks, rigorous OT gradient-flow limits are less clean than in the two-layer case. The following is a somewhat handwavy scaling argument that relies on the following assumptions: signals remain \(O(1)\), no gradient explosion, and approximate exchangeability across neurons.

If we assign learning rates as: \begin{equation} \eta_{W^\ell}(n) = \eta_{\ell,0} n^{\gamma_W}, \quad \eta_{b^\ell}(n) = \eta_{b,0} n^{\gamma_b}, \quad \eta_w(n) = \eta_{w,0} n^{\gamma_{\text{out}}}, \quad \eta_c(n) = \eta_{c,0} n^{\gamma_c} \end{equation}

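then each hidden layer has the same critical exponent \(\tfrac{1}{2}\). The same vanish/explode argument applies:

  1. If \(\max(\gamma_W, \gamma_b) < \tfrac{1}{2}\), the layer's Wasserstein speed vanishes as \(n \to \infty\) and its features freeze.
  2. If \(\max(\gamma_W, \gamma_b) = \tfrac{1}{2}\), the speed is \(O(1)\), finite and nonzero.
  3. If \(\max(\gamma_W, \gamma_b) > \tfrac{1}{2}\), the speed diverges and width-transferability is lost.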

The mechanism is clearest at the preactivation level. For a typical neuron \(i\) at layer \(\ell\): \begin{equation} z^\ell_i = \frac{1}{\sqrt{n}} W^\ell_{i:} h^{\ell-1} + b^\ell_i \end{equation}

Differentiating and focusing on the dominant term: \begin{equation} \dot{z}^\ell_i \supset \frac{1}{\sqrt{n}} \dot{W}^\ell_{i:} h^{\ell-1} \end{equation}

Since \(\norm{\dot{W}^\ell_{i:}} \sim n^{\gamma_W - 1/2}\) and \(\norm{h^{\ell-1}} \sim \sqrt{n}\): \begin{equation} \abs{\frac{1}{\sqrt{n}} \dot{W}^\ell_{i:} h^{\ell-1}} \sim n^{\gamma_W - 1/2} \end{equation}

This makes the vanish/critical/blow-up trichotomy immediate: if \(\gamma_W \lt \tfrac{1}{2}\), preactivations freeze; if \(\gamma_W > \tfrac{1}{2}\), they blow up; at \(\gamma_W = \tfrac{1}{2}\), preactivation updates are \(O(1)\), the maximal non-degenerate update.
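The alignment that makes this estimate tight can be written out. The following is a sketch under the same \(O(1)\)-signal assumptions, where \(x'\) denotes an independent copy of the input introduced here for the expectation. Substituting the gradient flow \(\dot{W}^\ell_{i:} = -\eta_{\ell,0} n^{\gamma_W} \nabla_{W^\ell_{i:}}\mathcal{L}\) and the row gradient derived earlier: \begin{equation} \frac{1}{\sqrt{n}} \dot{W}^\ell_{i:} h^{\ell-1}(x) = -\eta_{\ell,0}\, n^{\gamma_W}\, \mathbb{E}_{x'}\left[ e(x')\, \delta^\ell_i(x')\, \frac{1}{n} \left\langle h^{\ell-1}(x'), h^{\ell-1}(x) \right\rangle \right] \end{equation}

The normalized overlap \(\frac{1}{n}\langle h^{\ell-1}(x'), h^{\ell-1}(x) \rangle\) is \(O(1)\), the error \(e\) is \(O(1)\), and \(\delta^\ell_i = O(n^{-1/2})\), so the right-hand side is \(O(n^{\gamma_W - 1/2})\). The update direction is built from the same activations it is contracted with, which is why no further cancellation is expected.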

Maximizing total transport while maintaining transferability forces every hidden layer to be critical: \begin{equation} \boxed{\gamma_W = \gamma_b = \tfrac{1}{2}, \qquad \gamma_{\text{out}} = \gamma_c = 0} \end{equation}

This is the deep \(\mu\)P prescription: every hidden layer gets \(\sqrt{n}\) learning rate scaling, while output parameters use constant rates.

Summary

We've shown that \(\mu\)P emerges naturally from an Optimal Transport perspective:

The OT Characterization of \(\mu\)P
Among power-law learning-rate scalings that achieve width-transferability, \(\mu\)P is the unique one (modulo time-rescaling) that maximizes the total Wasserstein transport distance traveled by neuron measures across all layers. This gives maximal feature learning while maintaining stable, transferable training dynamics. The key insight is that \(\mu\)P corresponds to the critical mobility where kinetic energy is \(O(1)\): too small and particles freeze (the lazy, NTK-like regime), too large and the dynamics explode.