
$\mu$P as Optimal Transport in a Vanilla MLP

Jan 05, 2026

Deriving $\mu$P as the unique scaling maximizing Wasserstein transport under stability constraints

Goal

In the following, we will show that $\mu$P can be rederived from first principles as the solution to a variational problem in Wasserstein geometry. Concretely, we posit that the "correct" scaling for a neural network is the unique solution to the following optimization problem:

Variational Selection Principle
Define the function-space distance:

$$d(f, g) \doteq \left(\mathbb{E}_{x \sim \mathcal{D}_x}\left[(f(x) - g(x))^2\right]\right)^{1/2}$$

Then the width-transfer discrepancy is:

$$\mathcal{D}_T(s) \doteq \limsup_{n \to \infty} \sup_{t \in [0,T]} d\left(f^{(n,s)}_t, f^{(\kappa n,s)}_t\right)$$

where $f^{(n,s)}_t$ is the trained predictor at time $t$ with width $n$ and scaling $s$, and $\kappa > 1$ is a width ratio.

Define the total Wasserstein dissipation (cumulative kinetic energy integrated over time):

$$\mathcal{A}_T(s) \doteq \sum_{\ell=1}^L \liminf_{n \to \infty} \int_0^T \frac{1}{2n} \sum_{i=1}^n \left\| \dot{\theta}^{\ell}_i(t) \right\|^2 dt$$

In the warm-up, $L = 1$ and $\mathcal{A}_T(s) = \liminf_{n \to \infty} \int_0^T \frac{1}{2n} \sum_{i=1}^n \| \dot{\theta}_i(t) \|^2 \, dt$.

(Note: This is the Benamou-Brenier "action" functional from Optimal Transport. Unlike the Principle of Least Action in mechanics, here we maximize it subject to stability; this selects the critical mobility where dissipation is finite but nonzero.)

Then $\mu$P is the unique scaling (within power-law learning-rate scalings, modulo time-rescaling) that solves:

$$\max_s \mathcal{A}_T(s) \quad \text{subject to} \quad \mathcal{D}_T(s) = 0$$

This is a constrained criticality problem: maximize total dissipation (cumulative transport) subject to stability.
This is not meant to be a new theorem about $\mu$P. I just thought it was interesting to consider a geometric "selection principle": among all width-transferable scalings, pick the one with maximal non-degenerate Wasserstein transport. This reframes $\mu$P as a critical-mobility point selected by a constrained variational problem, which is (maybe?) portable to other parameter spaces and metrics.

In this view, the standard "lazy" (kernel-like) parameterization fails because the transport distance is zero (the particles freeze). Unstable parameterizations fail because the transport distance explodes. $\mu$P emerges as the precise scaling regime where the Wasserstein speed of the layers is non-zero, finite, and coupled.

We will proceed in two stages:

  1. Warm-Up. We analyze a 1-hidden-layer MLP with biases under square loss. We explicitly compute the Wasserstein speed of the neuron measure and show that $\mu$P is the only scaling that allows non-degenerate transport.
  2. Generalization. We lift this logic to a deep MLP. We formulate a "Total Transport" objective -- summing the Wasserstein path lengths across all layers -- and show that maximizing this quantity forces every hidden layer to adopt the specific $\sqrt{n}$ scaling factors characteristic of $\mu$P.

Networks as Particle Systems

We shall treat the network as a dynamical system of interacting particles.

Model

We focus on a standard 1-hidden-layer MLP of width $n$ with biases. For conceptual clarity in the OT setting, we use the mean-field normalization:

$$f_n(x) = \frac{1}{n} \sum_{i=1}^n a_i \, \sigma(w_i^\top x + b_i) + c$$

(Note: Standard $\mu$P uses $\frac{1}{\sqrt{n}}$ normalization. The two formulations are related by rescaling $a$ and time. The mean-field form aligns naturally with OT geometry, where measures are probability distributions.)

Here, the "particles" are the hidden neurons. Each neuron $i$ is defined by its parameter tuple $\theta_i$:

$$\theta_i \doteq (a_i, w_i, b_i) \in \mathbb{R} \times \mathbb{R}^d \times \mathbb{R}$$

where $a_i$ is the output weight, $w_i$ is the input weight vector, and $b_i$ is the bias. The scalar $c$ is the global output bias.

We train this network to minimize the population square loss $\mathcal{L}$ over a data distribution $\mathcal{D}$:

$$\mathcal{L}(\theta) = \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \frac{1}{2} (f_n(x) - y)^2 \right]$$
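As a concrete numerical sketch of the model and loss above (assumptions not fixed by the text: tanh activation, standard Gaussian initialization, and a small random sample standing in for the population distribution $\mathcal{D}$):

```python
import numpy as np

# Mean-field 1-hidden-layer MLP and its square loss, on a finite sample.
rng = np.random.default_rng(0)
n, d, m = 512, 3, 32

a = rng.standard_normal(n)        # output weights a_i
W = rng.standard_normal((n, d))   # input weight vectors w_i (as rows)
b = rng.standard_normal(n)        # biases b_i
c = 0.0                           # global output bias
X = rng.standard_normal((m, d))   # sample inputs
y = rng.standard_normal(m)        # sample targets

def f_n(X):
    """f_n(x) = (1/n) sum_i a_i sigma(w_i^T x + b_i) + c, on a batch."""
    return (np.tanh(X @ W.T + b) @ a) / n + c

def loss(X, y):
    """Population square loss, approximated on the finite sample."""
    return 0.5 * np.mean((f_n(X) - y) ** 2)

print(loss(X, y))  # O(1): at init f_n is O(1/sqrt(n)), so loss ~ 0.5 * E[y^2]
```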

Mean-Field Limit

In the "Mean-Field" or Optimal Transport limit, we don't track indices $i$ and instead track the distribution of neurons. We define the empirical measure $\mu_t$ at training time $t$:

$$\mu_t = \frac{1}{n} \sum_{i=1}^n \delta_{\theta_i(t)}$$

This allows us to write the network output as an integral against this measure. As $n \to \infty$, if $\mu_t$ converges to a smooth probability density $\rho_t$, the network becomes:

$$f(x) = \int a \, \sigma(w^\top x + b) \, d\rho_t(\theta) + c$$

This is exactly the mean-field limit: the network function is an integral against the neuron distribution.
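A quick law-of-large-numbers illustration of this limit (an illustrative sketch, with tanh activation and widths chosen arbitrarily): the output at initialization is an empirical average over neuron particles, so its fluctuations shrink as width grows.

```python
import numpy as np

# At init, f_n(x) is an average of n iid terms, so its spread over fresh
# random networks shrinks roughly like 1/sqrt(n).
rng = np.random.default_rng(1)

def f_init(n, d, x, rng):
    """Fresh random mean-field network of width n, evaluated at input x."""
    a = rng.standard_normal(n)
    W = rng.standard_normal((n, d))
    b = rng.standard_normal(n)
    return (a @ np.tanh(W @ x + b)) / n

d = 4
x = np.ones(d)
stds = [np.std([f_init(n, d, x, rng) for _ in range(200)])
        for n in (100, 400, 1600)]
print(stds)  # decreasing: quadrupling n roughly halves the spread
```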

Dynamics and Wasserstein Speed

We model training as gradient flow (continuous-time gradient descent). The parameters evolve according to:

$$\dot{\theta}_i = -\eta_n \nabla_{\theta_i} \mathcal{L}(\theta)$$

where $\eta_n$ is some width-dependent learning rate (mobility).

The core quantity we care about is the Wasserstein kinetic energy. In Optimal Transport geometry (specifically the $W_2$ metric), the instantaneous kinetic energy of the measure $\mu_t$ is:

$$E_{\text{kin}}(\mu_t) = \frac{1}{2} \int \left\| v_t(\theta) \right\|^2 d\mu_t(\theta) \approx \frac{1}{2n} \sum_{i=1}^n \left\| \dot{\theta}_i \right\|^2$$

where $v_t$ is the velocity field. The Wasserstein speed (metric speed) is the square root of twice the kinetic energy:

$$\text{Speed}(\mu_t) = \left( \frac{1}{n} \sum_{i=1}^n \left\| \dot{\theta}_i \right\|^2 \right)^{1/2} = \sqrt{2 E_{\text{kin}}(\mu_t)}$$
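In code, these two quantities are one line each (a minimal sketch; the velocity array is hypothetical stand-in data, not the output of an actual training run):

```python
import numpy as np

# Wasserstein speed of the empirical measure = RMS particle velocity,
# which equals sqrt(2 * kinetic energy) by definition.
rng = np.random.default_rng(2)

def kinetic_energy(theta_dot):
    """E_kin = (1/2n) sum_i ||theta_dot_i||^2 for an (n, dim) velocity array."""
    return 0.5 * (theta_dot ** 2).sum() / theta_dot.shape[0]

def wasserstein_speed(theta_dot):
    """Metric speed: ((1/n) sum_i ||theta_dot_i||^2)^{1/2}."""
    return np.sqrt((theta_dot ** 2).sum() / theta_dot.shape[0])

v = rng.standard_normal((1000, 5))  # hypothetical particle velocities
print(wasserstein_speed(v), np.sqrt(2 * kinetic_energy(v)))  # identical
```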

This quantity dictates how the system behaves:

  * If the speed vanishes as $n \to \infty$, the measure is frozen: no feature learning (the lazy/NTK regime).
  * If the speed diverges, the dynamics blow up and training is unstable.
  * If the speed is $\Theta(1)$, the neuron measure travels a finite, nonzero distance: genuine feature learning.

This is exactly a phase-transition story: $\mu$P sits at the critical point where kinetic energy is $O(1)$.

Learning Rates as Mobilities

In physics, the relationship between force and velocity is mediated by a mobility tensor. Here, learning rates play the role of mobilities (inverse friction coefficients). We write the dynamics as:

$$\dot{\theta}_i = -M(n) \nabla_{\theta_i} \mathcal{L}(\theta)$$

where $M(n)$ is a block-diagonal mobility tensor that scales with width. The choice of scaling for $M(n)$ sets the mobility scale: too small (lazy regime) and the particles freeze; too large and the velocities diverge. $\mu$P corresponds to the critical mobility at which the kinetic energy stays finite but nonzero.

Learning rates act like mobilities (inverse friction coefficients) in an overdamped Langevin system: the force is $-\nabla\mathcal{L}$, the velocity is $\dot{\theta}$, and $\mathcal{A}_T$ measures the total dissipation.

Computing the Gradients

To understand how the scaling affects dynamics, we need to compute the gradients. Let $u_i \doteq w_i^\top x + b_i$ denote the preactivation. The gradients are:

$$\partial_{a_i}\mathcal{L} = \mathbb{E}\left[e(x) \, \frac{1}{n} \sigma(u_i)\right]$$

$$\nabla_{w_i}\mathcal{L} = \mathbb{E}\left[e(x) \, \frac{1}{n} a_i \sigma'(u_i) \, x\right]$$

$$\partial_{b_i}\mathcal{L} = \mathbb{E}\left[e(x) \, \frac{1}{n} a_i \sigma'(u_i)\right]$$

$$\partial_c\mathcal{L} = \mathbb{E}[e(x)]$$

where $e(x) = f_n(x) - y$ is the error.
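These formulas can be checked against finite differences (a sanity-check sketch, assuming tanh activation and a small random dataset standing in for $\mathcal{D}$; the probed index and step size are arbitrary):

```python
import numpy as np

# Finite-difference check of the partial_{a_i} L formula above.
rng = np.random.default_rng(3)
n, d, m = 8, 3, 16

a = rng.standard_normal(n)
W = rng.standard_normal((n, d))
b = rng.standard_normal(n)
c = 0.1
X = rng.standard_normal((m, d))
y = rng.standard_normal(m)

def f(a):
    return (np.tanh(X @ W.T + b) @ a) / n + c

def loss(a):
    return 0.5 * np.mean((f(a) - y) ** 2)

e = f(a) - y                         # residual e(x) on the sample
U = X @ W.T + b                      # preactivations u_i
grad_a = (e @ np.tanh(U)) / (n * m)  # E[e(x) (1/n) sigma(u_i)], all i at once

i, eps = 0, 1e-6
a_pert = a.copy()
a_pert[i] += eps
fd = (loss(a_pert) - loss(a)) / eps
print(grad_a[i], fd)  # analytic and finite-difference values agree
```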

Assume we are in a well-scaled regime where $f_n(x) = O(1)$ and the residual $e(x) = f_n(x) - y$ remains $O(1)$ on the time horizon of interest (as is standard in scaling analyses). Under standard initialization with $a_i = O(1)$ and $w_i = O(1)$, this gives us:

$$\left|\partial_{a_i}\mathcal{L}\right| = O\!\left(\tfrac{1}{n}\right), \quad \left\|\nabla_{w_i}\mathcal{L}\right\| = O\!\left(\tfrac{1}{n}\right), \quad \left|\partial_{b_i}\mathcal{L}\right| = O\!\left(\tfrac{1}{n}\right), \quad \left|\partial_c\mathcal{L}\right| = O(1)$$

Learning Rate Scaling and Parameter Updates

Now suppose we use different learning rates for different parameter types, scaling with width as:

$$\eta_a(n) = \eta_{a,0} \, n^{\gamma_a}, \quad \eta_w(n) = \eta_{w,0} \, n^{\gamma_w}, \quad \eta_b(n) = \eta_{b,0} \, n^{\gamma_b}, \quad \eta_c(n) = \eta_{c,0} \, n^{\gamma_c}$$

Under gradient flow, the parameter velocities are:

$$\left| \dot{a}_i \right| = \eta_a(n) \left| \partial_{a_i}\mathcal{L} \right| = O(n^{\gamma_a - 1})$$

$$\left\| \dot{w}_i \right\| = \eta_w(n) \left\| \nabla_{w_i}\mathcal{L} \right\| = O(n^{\gamma_w - 1})$$

$$\left| \dot{b}_i \right| = \eta_b(n) \left| \partial_{b_i}\mathcal{L} \right| = O(n^{\gamma_b - 1})$$

$$\left| \dot{c} \right| = \eta_c(n) \left| \partial_c\mathcal{L} \right| = O(n^{\gamma_c})$$
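The $1/n$ factor driving these rates can be observed empirically (an illustrative check, not from the text: tanh activation, random data, and two arbitrary widths differing by a factor of 4):

```python
import numpy as np

# The per-neuron output-weight gradient shrinks like 1/n, so at a fixed
# learning rate the particle velocity is O(n^{-1}); the n^gamma mobility
# is what compensates for this.
rng = np.random.default_rng(4)

def mean_grad_a(n, d=3, m=32):
    """Average |dL/da_i| for a fresh random network of width n."""
    a = rng.standard_normal(n)
    W = rng.standard_normal((n, d))
    b = rng.standard_normal(n)
    X = rng.standard_normal((m, d))
    y = rng.standard_normal(m)
    H = np.tanh(X @ W.T + b)
    e = (H @ a) / n - y          # residual (c = 0 for simplicity)
    return np.abs(e @ H).mean() / (n * m)

g_small, g_large = mean_grad_a(256), mean_grad_a(1024)
print(g_small / g_large)  # roughly 4: gradient magnitude ~ 1/n
```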

Critical Scaling

The Wasserstein speed depends on the maximum velocity across all parameter types:

$$\text{Speed}(\mu_t) = O\!\left(n^{\max(\gamma_a, \gamma_w, \gamma_b) - 1}\right)$$

This gives us three regimes:

  * Subcritical ($\max(\gamma_a, \gamma_w, \gamma_b) < 1$): the speed vanishes as $n \to \infty$ and the neuron measure freezes (lazy regime).
  * Critical ($\max(\gamma_a, \gamma_w, \gamma_b) = 1$): the speed is $\Theta(1)$; the measure moves a finite, nonzero distance.
  * Supercritical ($\max(\gamma_a, \gamma_w, \gamma_b) > 1$): the speed diverges and training is unstable.

For the output bias $c$, the critical scaling is $\gamma_c = 0$, since its gradient is already $O(1)$.

Selection Principle: Critical Mobility

Width-transferability requires that the training trajectory in function space converges as $n \to \infty$. Empirically and theoretically, this rules out the supercritical regime. So $\mathcal{D}_T = 0$ (the transferability constraint) effectively enforces:

$$\max(\gamma_a, \gamma_w, \gamma_b) \le 1, \qquad \gamma_c \le 0$$

To maximize feature learning (non-zero transport), we want to maximize the Wasserstein speed subject to this constraint. This forces us to the boundary:

$$\boxed{\gamma_a = \gamma_w = \gamma_b = 1, \qquad \gamma_c = 0}$$

This corresponds to the $\mu$P scaling: hidden-unit parameters get learning rates that scale as $n$, while the output parameters use constant learning rates. Uniqueness is meant within the family of power-law learning-rate scalings (mobilities) considered here, modulo an overall rescaling of time.

Gauge Equivalence (Mean-Field $\leftrightarrow$ $\mu$P Convention)
In mean-field form, the critical mobility is $\eta \asymp n$ for the neuron parameters $(a, w, b)$ to achieve $O(1)$ Wasserstein speed in the original time variable. To pass to the standard $\mu$P convention, define $\tilde{a}_i \doteq a_i / \sqrt{n}$. Then:

$$f_n(x) = \frac{1}{\sqrt{n}} \sum_{i=1}^n \tilde{a}_i \, \sigma(w_i^\top x + b_i) + c$$

which is the standard $\frac{1}{\sqrt{n}}$ form. Under this change of variables, the gradient-flow time variable rescales by a factor of $\sqrt{n}$ (equivalently, learning rates rescale by $\sqrt{n}$). Thus the critical mobility $\eta \asymp n$ in the mean-field gauge corresponds to $\eta \asymp \sqrt{n}$ in the standard $\mu$P gauge.
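The change of variables itself is a one-line identity, checked numerically below (tanh and the sizes are arbitrary choices for the sketch):

```python
import numpy as np

# Gauge check: the mean-field network (1/n normalization) equals the
# muP-style network (1/sqrt(n) normalization) after a_tilde = a / sqrt(n).
rng = np.random.default_rng(5)
n, d = 64, 3
a = rng.standard_normal(n)
W = rng.standard_normal((n, d))
b = rng.standard_normal(n)
x = rng.standard_normal(d)

phi = np.tanh(W @ x + b)                 # sigma(w_i^T x + b_i) for all neurons
mean_field = (a @ phi) / n               # (1/n) sum_i a_i sigma(.)
a_tilde = a / np.sqrt(n)                 # change of variables
mup_form = (a_tilde @ phi) / np.sqrt(n)  # (1/sqrt(n)) sum_i a_tilde_i sigma(.)
print(mean_field, mup_form)  # identical up to floating-point rounding
```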

Deep Networks: Layerwise Transport

From here on we revert to the standard $\frac{1}{\sqrt{n}}$ neural tangent / $\mu$P gauge, because it matches common deep-network parameterizations. The mean-field gauge differs by a simple rescaling of output weights and time (as noted above).

Now we lift this logic to a deep MLP. Consider a depth-$L$ fully-connected network with width $n$ in each hidden layer:

$$h^0 = x, \qquad z^\ell = \frac{1}{\sqrt{n}} W^\ell h^{\ell-1} + b^\ell, \qquad h^\ell = \sigma(z^\ell), \qquad \ell = 1, \dots, L$$

$$f_n(x) = \frac{1}{\sqrt{n}} w^\top h^L + c$$

We train on the same population square loss: $\mathcal{L} = \tfrac{1}{2} \mathbb{E}[(f_n(x) - y)^2]$.
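A forward-pass sketch of this architecture, checking the $\|h^\ell\| \sim \sqrt{n}$ behavior used below. One practical adjustment not in the text: the first layer is normalized by its fan-in $\sqrt{d}$ rather than $\sqrt{n}$, since $h^0 = x \in \mathbb{R}^d$; hidden layers use $1/\sqrt{n}$ as written.

```python
import numpy as np

# Depth-L MLP in the 1/sqrt(n) gauge. With O(1) weight entries, each hidden
# vector has O(1) coordinates, so ||h^l|| ~ sqrt(n). Assumption: tanh.
rng = np.random.default_rng(6)
n, d, L = 1024, 8, 3

Ws = [rng.standard_normal((n, d))] + [rng.standard_normal((n, n)) for _ in range(L - 1)]
bs = [rng.standard_normal(n) for _ in range(L)]
w = rng.standard_normal(n)

h = rng.standard_normal(d)
norms = []
for W, b in zip(Ws, bs):
    h = np.tanh(W @ h / np.sqrt(W.shape[1]) + b)  # z^l = W h / sqrt(fan-in) + b
    norms.append(np.linalg.norm(h))
out = (w @ h) / np.sqrt(n)  # readout f_n(x) = (1/sqrt(n)) w^T h^L (c = 0)

print([v / np.sqrt(n) for v in norms])  # each ratio is O(1): ||h^l|| ~ sqrt(n)
```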

Row-Gradient Scaling

The key observation is that gradients for rows of weight matrices scale as $O(n^{-1/2})$. To see this, define the backprop signal $\delta^\ell \doteq \partial f / \partial z^\ell \in \mathbb{R}^n$. Under standard forward scaling, each coordinate satisfies $\delta^\ell_i = O(n^{-1/2})$.

For a row $W^\ell_{i:} \in \mathbb{R}^n$ of the weight matrix, the gradient is:

$$\nabla_{W^\ell_{i:}}\mathcal{L} = \mathbb{E}\left[e(x) \, \delta^\ell_i(x) \, \frac{1}{\sqrt{n}} h^{\ell-1}(x)\right]$$

Since $\left\| h^{\ell-1} \right\| = O(\sqrt{n})$, we have $\left\| (1/\sqrt{n}) h^{\ell-1} \right\| = O(1)$, hence:

$$\boxed{\left\| \nabla_{W^\ell_{i:}}\mathcal{L} \right\| = O(n^{-1/2})}$$

Similarly, for biases:

$$\boxed{\left| \partial_{b^\ell_i}\mathcal{L} \right| = O(n^{-1/2})}$$

For the readout and output bias, we get:

$$\left\| \nabla_w\mathcal{L} \right\| = O(1), \qquad \left| \partial_c\mathcal{L} \right| = O(1)$$

Layerwise Neuron Measures

For each hidden layer $\ell$, we can define per-neuron parameters $\theta^\ell_i \doteq (W^\ell_{i:}, b^\ell_i)$ and the empirical measure:

$$\mu^{\ell,n}_t = \frac{1}{n} \sum_{i=1}^n \delta_{\theta^\ell_i(t)}$$

The Wasserstein speed for layer $\ell$ is:

$$\left\| \dot{\mu}^{\ell,n}_t \right\|_{W_2} = O\!\left(n^{\max(\gamma_W, \gamma_b) - 1/2}\right)$$

where $\gamma_W$ and $\gamma_b$ are the learning-rate exponents for weights and biases in hidden layers.

Total Transport Objective

Define the total transport up to time $T$ as the sum of Wasserstein path lengths across all layers:

$$\mathcal{S}_T(s) \doteq \sum_{\ell=1}^L \liminf_{n \to \infty} \int_0^T \left\| \dot{\mu}^{\ell,n}_t \right\|_{W_2} \, dt$$

This measures how much "movement" happens across all layers combined. The selection principle is the same: maximize $\mathcal{S}_T$ subject to width-transferability.
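The path length can be estimated at finite width by accumulating the discrete Wasserstein speed over gradient-descent steps. The sketch below does this for the warm-up (one-layer) network with the critical mean-field mobility, learning rate proportional to $n$; all sizes, the learning rate, and the sin target are illustrative choices, not from the text.

```python
import numpy as np

# Toy estimate of the transport path length sum ||theta_{k+1} - theta_k||_{W_2}
# for a 1-hidden-layer mean-field net trained by gradient descent.
rng = np.random.default_rng(8)
n, d, m, lr, steps = 128, 2, 64, 0.5, 200

X = rng.standard_normal((m, d))
y = np.sin(X[:, 0])
a = rng.standard_normal(n)
W = rng.standard_normal((n, d))
b = rng.standard_normal(n)

def loss():
    return 0.5 * np.mean(((np.tanh(X @ W.T + b) @ a) / n - y) ** 2)

loss_init = loss()
path_len = 0.0
for _ in range(steps):
    H = np.tanh(X @ W.T + b)
    e = (H @ a) / n - y
    ga = (e @ H) / (n * m)                          # dL/da
    gU = (e[:, None] * (1 - H ** 2)) * a / (n * m)  # dL/dU (preactivations)
    gW = gU.T @ X                                   # dL/dW
    gb = gU.sum(axis=0)                             # dL/db
    da, dWm, db = -lr * n * ga, -lr * n * gW, -lr * n * gb  # critical mobility
    path_len += np.sqrt(((da ** 2).sum() + (dWm ** 2).sum() + (db ** 2).sum()) / n)
    a += da; W += dWm; b += db

print(loss_init, loss(), path_len)  # loss drops; path length is finite, nonzero
```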

Deep $\mu$P Result (Scaling Argument)

For deep networks, rigorous OT gradient-flow limits are less clean than in the two-layer case. The following is a heuristic scaling argument resting on three assumptions: signals remain $O(1)$, gradients do not explode, and neurons within a layer are approximately exchangeable.

If we assign learning rates as:

$$\eta_{W^\ell}(n) = \eta_{\ell,0} \, n^{\gamma_W}, \quad \eta_{b^\ell}(n) = \eta_{b,0} \, n^{\gamma_b}, \quad \eta_w(n) = \eta_{w,0} \, n^{\gamma_{\text{out}}}, \quad \eta_c(n) = \eta_{c,0} \, n^{\gamma_c}$$

then each hidden layer has the same critical exponent $\tfrac{1}{2}$, and the same vanish/explode argument applies.

The mechanism is clearest at the preactivation level. For a typical neuron $i$ at layer $\ell$:

$$z^\ell_i = \frac{1}{\sqrt{n}} W^\ell_{i:} h^{\ell-1} + b^\ell_i$$

Differentiating and focusing on the dominant term:

$$\dot{z}^\ell_i \supset \frac{1}{\sqrt{n}} \dot{W}^\ell_{i:} h^{\ell-1}$$

Since $\left\| \dot{W}^\ell_{i:} \right\| \sim n^{\gamma_W - 1/2}$ and $\left\| h^{\ell-1} \right\| \sim \sqrt{n}$:

$$\left| \frac{1}{\sqrt{n}} \dot{W}^\ell_{i:} h^{\ell-1} \right| \sim n^{\gamma_W - 1/2}$$

This makes the vanish/critical/blow-up trichotomy immediate: if $\gamma_W < \tfrac{1}{2}$, preactivations freeze; if $\gamma_W > \tfrac{1}{2}$, they blow up; at $\gamma_W = \tfrac{1}{2}$, preactivation updates are $O(1)$, the maximal non-degenerate update.
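The trichotomy can be reproduced in one computation (a sketch under the assumptions above: the row gradient is taken proportional to $h^{\ell-1}$, as the true gradient is up to scalars, with norm exactly $n^{-1/2}$; the widths probed are arbitrary):

```python
import numpy as np

# Update a single hidden row with lr = n**gamma and measure |Delta z_i|.
rng = np.random.default_rng(7)

def dz_scale(n, gamma):
    """|Delta z_i| after one row update at width n with lr = n**gamma."""
    h = rng.standard_normal(n)                       # h^{l-1}: O(1) coords, norm ~ sqrt(n)
    grad_row = (h / np.linalg.norm(h)) / np.sqrt(n)  # norm n^{-1/2}, aligned with h
    dW = n ** gamma * grad_row                       # gradient-flow increment
    return abs(h @ dW) / np.sqrt(n)                  # Delta z_i for z_i = (1/sqrt n) W_i: h

n_small, n_big = 2 ** 10, 2 ** 16
for gamma in (0.25, 0.5, 0.75):
    print(gamma, dz_scale(n_small, gamma), dz_scale(n_big, gamma))
# gamma < 1/2: updates vanish as n grows; gamma = 1/2: O(1); gamma > 1/2: blow up
```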

Maximizing total transport while maintaining transferability forces every hidden layer to be critical:

$$\boxed{\gamma_W = \gamma_b = \tfrac{1}{2}, \qquad \gamma_{\text{out}} = \gamma_c = 0}$$

This is the deep $\mu$P prescription: every hidden layer gets $\sqrt{n}$ learning-rate scaling, while output parameters use constant rates.

Summary

We've shown that $\mu$P emerges naturally from an Optimal Transport perspective:

The OT Characterization of $\mu$P
Among all parameterizations that achieve width-transferability, $\mu$P is the unique one that maximizes the total Wasserstein transport distance traveled by neuron measures across all layers. This ensures maximal feature learning while maintaining stable, transferable training dynamics. The key insight is that $\mu$P corresponds to the critical mobility where kinetic energy is $O(1)$: too small and the particles freeze (NTK), too large and the dynamics explode.