Deriving μP as the unique scaling maximizing Wasserstein transport under stability constraints
Goal
In the following, we will show that μP can be rederived from first principles as the solution to a variational problem in Wasserstein geometry. Concretely, we posit that the "correct" scaling for a neural network is the unique solution to the following optimization problem:
Variational Selection Principle
Define the function-space distance:
$$d(f,g) \doteq \Big(\mathbb{E}_{x\sim \mathcal{D}_x}\big[(f(x)-g(x))^2\big]\Big)^{1/2}$$
Then the width-transfer discrepancy is:
$$D_T(s) \doteq \limsup_{n\to\infty}\ \sup_{t\in[0,T]}\, d\big(f_t^{(n,s)},\, f_t^{(\kappa n,\, s)}\big)$$
where $f_t^{(n,s)}$ is the trained predictor at time $t$ with width $n$ and scaling $s$, and $\kappa > 1$ is a fixed width ratio.
Define the total Wasserstein dissipation (cumulative kinetic energy integrated over time):
$$A_T(s) \doteq \sum_{\ell=1}^{L} \liminf_{n\to\infty} \int_0^T \frac{1}{2n}\sum_{i=1}^{n} \big\|\dot\theta_i^{\ell}(t)\big\|^2\, dt$$
In the warm-up, $L=1$ and $A_T(s) = \liminf_{n\to\infty}\int_0^T \frac{1}{2n}\sum_{i=1}^{n} \|\dot\theta_i(t)\|^2\, dt$.
(Note: This is the Benamou-Brenier "action" functional from Optimal Transport. Unlike the Principle of Least Action in mechanics, here we maximize it subject to stability; this selects the critical mobility where dissipation is finite but nonzero.)
Then μP is the unique scaling (within power-law learning-rate scalings, modulo time-rescaling) that solves:
$$\max_s\ A_T(s) \quad \text{subject to} \quad D_T(s) = 0$$
This is a constrained criticality problem: maximize total dissipation (cumulative transport) subject to stability.
This is not meant to be a new theorem about μP. I just thought it was interesting to consider a geometric "selection principle": among all width-transferable scalings, pick the one with maximal non-degenerate Wasserstein transport. This reframes μP as a critical-mobility point selected by a constrained variational problem, which is (maybe?) portable to other parameter spaces and metrics.
In this view, the standard "lazy" (kernel-like) parameterization fails because the transport distance is zero (the particles freeze). Unstable parameterizations fail because the transport distance explodes. μP emerges as the precise scaling regime where the Wasserstein speed of the layers is non-zero, finite, and coupled.
We will proceed in two stages:
Warm-Up. We analyze a 1-hidden-layer MLP with biases under square loss. We explicitly compute the Wasserstein speed of the neuron measure and show that μP is the only scaling that allows non-degenerate transport.
Generalization. We lift this logic to a deep MLP. We formulate a "Total Transport" objective -- summing the Wasserstein path lengths across all layers -- and show that maximizing this quantity forces every hidden layer to adopt the √n learning-rate scaling characteristic of μP.
Networks as Particle Systems
We shall treat the network as a dynamic system of interacting particles.
Model
We focus on a standard 1-hidden-layer MLP of width n, with biases. For conceptual clarity in the OT setting, we use the mean-field normalization:
$$f_n(x) = \frac{1}{n}\sum_{i=1}^{n} a_i\, \sigma(w_i^\top x + b_i) + c$$
(Note: Standard μP uses 1/√n normalization. The two formulations are related by rescaling a and time. The mean-field form aligns naturally with OT geometry, where measures are probability distributions.)
Here, the "particles" are the hidden neurons. Each neuron i is defined by its parameter tuple θi:
$$\theta_i \doteq (a_i, w_i, b_i) \in \mathbb{R}\times\mathbb{R}^d\times\mathbb{R}$$
where ai is the output weight, wi is the input weight vector, and bi is the bias. The scalar c is the global output bias.
We train this network to minimize the population square loss L over a data distribution D:
$$\mathcal{L}(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\Big[\tfrac{1}{2}\big(f_n(x) - y\big)^2\Big]$$
Mean-Field Limit
In the "Mean-Field" or Optimal Transport limit, we don't track indices i and instead track the distribution of neurons. We define the empirical measure μt at training time t:
$$\mu_t = \frac{1}{n}\sum_{i=1}^{n} \delta_{\theta_i(t)}$$
This allows us to write the network output as an integral against this measure. As n→∞, if μt converges to a smooth probability density ρt, the network becomes:
$$f(x) = \int a\,\sigma(w^\top x + b)\, d\rho_t(\theta) + c$$
This is exactly the mean-field limit: the network function is an integral against the neuron distribution.
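As a quick numerical sanity check of this limit, here is a minimal sketch (the Gaussian neuron distribution, tanh activation, and the coupling of a to w are illustrative assumptions, not part of the derivation): drawing neurons i.i.d. from a fixed ρ, the empirical average stabilizes as n grows.

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_field_output(x, a, W, b, c=0.0):
    """f_n(x) = (1/n) sum_i a_i * sigma(w_i . x + b_i) + c, i.e. the
    network output as an average against the empirical neuron measure."""
    return np.mean(a * np.tanh(W @ x + b)) + c

d = 3
x = rng.normal(size=d)

# Neurons drawn i.i.d. from a fixed rho (Gaussian, an assumption): as n
# grows, the empirical average converges to the integral against rho.
outputs = []
for n in (2_000, 200_000):
    W = rng.normal(size=(n, d))
    b = rng.normal(size=n)
    a = np.tanh(W[:, 0])   # couple a to w so the limiting integral is nontrivial
    outputs.append(mean_field_output(x, a, W, b))
```

The two outputs agree up to Monte Carlo fluctuation of order 1/√n, which is the empirical-measure convergence the text describes.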
Dynamics and Wasserstein Speed
We model training as gradient flow (continuous-time gradient descent). The parameters evolve according to:
$$\dot\theta_i = -\eta_n\, \nabla_{\theta_i}\mathcal{L}(\theta)$$
where ηn is some width-dependent learning rate (mobility).
The core quantity we care about is the Wasserstein kinetic energy. In Optimal Transport geometry (specifically the W2 metric), the instantaneous kinetic energy of the measure μt is:
$$E_{\mathrm{kin}}(\mu_t) = \frac{1}{2}\int \|v_t(\theta)\|^2\, d\mu_t(\theta) \approx \frac{1}{2n}\sum_{i=1}^{n} \|\dot\theta_i\|^2$$
where vt is the velocity field. The Wasserstein speed (metric speed) is the square root of twice the kinetic energy:
$$\mathrm{Speed}(\mu_t) = \left(\frac{1}{n}\sum_{i=1}^{n}\|\dot\theta_i\|^2\right)^{1/2} = \sqrt{2\,E_{\mathrm{kin}}(\mu_t)}$$
This quantity dictates how the system behaves:
If Speed = 0, kinetic energy vanishes and the system is frozen (lazy, kernel-like regime).
If Speed = ∞, kinetic energy diverges and the system explodes (unstable regime).
If Speed = O(1), kinetic energy is finite and nonzero; the system flows effectively (feature learning / μP).
This is exactly a phase transition story: μP sits at the critical point where kinetic energy is O(1).
Learning Rates as Mobilities
In physics, the relationship between force and velocity is mediated by a mobility tensor. Here, learning rates play the role of mobilities (inverse friction coefficients). We write the dynamics as:
$$\dot\theta_i = -M(n)\, \nabla_{\theta_i}\mathcal{L}(\theta)$$
where M(n) is a block-diagonal mobility tensor that scales with width. The choice of scaling for M(n) determines the mobility scale: too small (lazy regime) and particles freeze; too large and particles accelerate infinitely. μP corresponds to the critical mobility where kinetic energy stays finite but nonzero.
Learning rates act like mobilities (inverse friction coefficients) in an overdamped Langevin system: force is $-\nabla\mathcal{L}$, velocity is $\dot\theta$, and $A_T$ measures the total dissipation.
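The dissipation functional can be sketched numerically: the snippet below Euler-integrates the mean-field gradient flow on a toy problem and accumulates the Benamou-Brenier action. The sin target, tanh activation, data sizes, and the `dissipation` helper are all assumptions for illustration; the point is only that the critical mobility η ≍ n yields finite nonzero dissipation while an O(1) mobility yields essentially none.

```python
import numpy as np

def dissipation(n, eta, T=1.0, steps=200, d=2, m=128, seed=0):
    """Euler-integrate the mean-field gradient flow and accumulate the
    Benamou-Brenier action A_T = int_0^T (1/2n) sum_i |theta_i_dot|^2 dt."""
    r = np.random.default_rng(seed)
    X = r.normal(size=(m, d))
    y = np.sin(X[:, 0])                        # toy target (an assumption)
    a = r.normal(size=n)
    W = r.normal(size=(n, d)) / np.sqrt(d)     # keeps preactivations O(1)
    b = r.normal(size=n)
    dt, A = T / steps, 0.0
    for _ in range(steps):
        U = X @ W.T + b                        # (m, n) preactivations
        S, Sp = np.tanh(U), 1.0 - np.tanh(U) ** 2
        e = (S @ a) / n - y                    # O(1) residual
        # mean-field gradients, each O(1/n) as in the text
        g_a = (e @ S) / (m * n)
        g_b = (e @ (a * Sp)) / (m * n)
        g_W = ((e[:, None] * (a * Sp)).T @ X) / (m * n)
        va, vb, vW = -eta * g_a, -eta * g_b, -eta * g_W
        A += dt * 0.5 * (va**2 + vb**2 + np.sum(vW**2, axis=1)).mean()
        a, b, W = a + dt * va, b + dt * vb, W + dt * vW
    return A

A_crit = dissipation(n=500, eta=500.0)   # critical mobility eta = n
A_lazy = dissipation(n=500, eta=1.0)     # O(1) mobility: particles freeze
```

With eta = n the action is O(1) and roughly width-independent; with eta = O(1) it collapses by a factor of order n².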
Computing the Gradients
To understand how the scaling affects dynamics, we need to compute the gradients. Let $u_i \doteq w_i^\top x + b_i$ denote the preactivation. The gradients are:
$$\partial_{a_i}\mathcal{L} = \mathbb{E}\Big[e(x)\,\tfrac{1}{n}\,\sigma(u_i)\Big]$$
$$\nabla_{w_i}\mathcal{L} = \mathbb{E}\Big[e(x)\,\tfrac{1}{n}\,a_i\,\sigma'(u_i)\,x\Big]$$
$$\partial_{b_i}\mathcal{L} = \mathbb{E}\Big[e(x)\,\tfrac{1}{n}\,a_i\,\sigma'(u_i)\Big]$$
$$\partial_{c}\mathcal{L} = \mathbb{E}\big[e(x)\big]$$
where e(x)=fn(x)−y is the error.
Assume we are in a well-scaled regime where fn(x)=O(1) and the residual e(x)=fn(x)−y remains O(1) on the time horizon of interest (as is standard in scaling analyses). Under standard initialization with ai=O(1) and wi=O(1), this gives us:
$$|\partial_{a_i}\mathcal{L}| = O(n^{-1}), \qquad \|\nabla_{w_i}\mathcal{L}\| = O(n^{-1}), \qquad |\partial_{b_i}\mathcal{L}| = O(n^{-1}), \qquad |\partial_{c}\mathcal{L}| = O(1)$$
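These O(1/n) scales are easy to verify numerically. Below is a minimal sketch (toy sin target, tanh activation, and the `grad_scales` helper are assumptions): an 8x increase in width should shrink the per-neuron gradients by roughly 8x.

```python
import numpy as np

def grad_scales(n, d=4, m=256, seed=0):
    """Mean |dL/da_i| and |dL/db_i| for the mean-field network on a toy
    dataset; both should shrink like O(1/n)."""
    r = np.random.default_rng(seed)
    X = r.normal(size=(m, d))
    y = np.sin(X[:, 0])                      # toy target (an assumption)
    a = r.normal(size=n)
    W = r.normal(size=(n, d)) / np.sqrt(d)   # keeps u_i = O(1)
    b = r.normal(size=n)
    U = X @ W.T + b
    S, Sp = np.tanh(U), 1.0 - np.tanh(U) ** 2
    e = (S @ a) / n - y                      # O(1) residual e(x)
    g_a = (e @ S) / (m * n)                  # dL/da_i = E[e (1/n) sigma(u_i)]
    g_b = (e @ (a * Sp)) / (m * n)           # dL/db_i = E[e (1/n) a_i sigma'(u_i)]
    return np.abs(g_a).mean(), np.abs(g_b).mean()

ga_1, _ = grad_scales(n=500)
ga_2, _ = grad_scales(n=4000)
ratio = ga_1 / ga_2   # 8x width increase, so expect a ratio near 8
```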
Learning Rate Scaling and Parameter Updates
Now suppose we use different learning rates for different parameter types, scaling with width as:
$$\eta_a(n) = n^{\gamma_a}, \qquad \eta_w(n) = n^{\gamma_w}, \qquad \eta_b(n) = n^{\gamma_b}, \qquad \eta_c(n) = n^{\gamma_c}$$
Under gradient flow, the parameter velocities are:
$$|\dot a_i| = \eta_a(n)\,|\partial_{a_i}\mathcal{L}| = O(n^{\gamma_a - 1})$$
$$\|\dot w_i\| = \eta_w(n)\,\|\nabla_{w_i}\mathcal{L}\| = O(n^{\gamma_w - 1})$$
$$|\dot b_i| = \eta_b(n)\,|\partial_{b_i}\mathcal{L}| = O(n^{\gamma_b - 1})$$
$$|\dot c| = \eta_c(n)\,|\partial_{c}\mathcal{L}| = O(n^{\gamma_c})$$
Critical Scaling
The Wasserstein speed depends on the maximum velocity across all parameter types:
$$\mathrm{Speed}(\mu_t) = O\big(n^{\max(\gamma_a,\gamma_w,\gamma_b) - 1}\big)$$
This gives us three regimes:
Subcritical. max(γ_a, γ_w, γ_b) < 1 ⟹ Speed → 0 as n → ∞. The neurons barely move; this is the lazy (kernel-like) regime.
Critical. max(γ_a, γ_w, γ_b) = 1 ⟹ Speed = O(1). The neurons move at a non-degenerate rate; this is feature learning.
Supercritical. max(γ_a, γ_w, γ_b) > 1 ⟹ Speed → ∞ as n → ∞. The dynamics explode and width-transferability breaks.
For the output bias c, the critical scaling is γc=0 since its gradient is already O(1).
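The trichotomy can be seen directly by measuring the instantaneous metric speed at two widths. The sketch below (toy data, tanh activation, and a single shared exponent γ for all hidden-unit parameters are simplifying assumptions) compares widths n and 8n: the speed ratio should behave like 8^(γ−1).

```python
import numpy as np

def wasserstein_speed(n, gamma, d=4, m=256, seed=0):
    """Metric speed ((1/n) sum_i |theta_i_dot|^2)^(1/2) of the neuron
    measure when all hidden-unit learning rates scale as eta = n**gamma."""
    r = np.random.default_rng(seed)
    X = r.normal(size=(m, d))
    y = np.sin(X[:, 0])                      # toy target (an assumption)
    a = r.normal(size=n)
    W = r.normal(size=(n, d)) / np.sqrt(d)
    b = r.normal(size=n)
    U = X @ W.T + b
    S, Sp = np.tanh(U), 1.0 - np.tanh(U) ** 2
    e = (S @ a) / n - y
    g_a = (e @ S) / (m * n)
    g_b = (e @ (a * Sp)) / (m * n)
    g_W = ((e[:, None] * (a * Sp)).T @ X) / (m * n)
    eta = float(n) ** gamma
    v2 = (eta * g_a) ** 2 + (eta * g_b) ** 2 + np.sum((eta * g_W) ** 2, axis=1)
    return float(np.sqrt(v2.mean()))

# Compare widths n and 8n: the speed should scale like n**(gamma - 1),
# so the ratio shrinks (subcritical), stays O(1) (critical), or grows.
ratios = {g: wasserstein_speed(8000, g) / wasserstein_speed(1000, g)
          for g in (0.5, 1.0, 1.5)}
```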
Selection Principle: Critical Mobility
Width-transferability requires that the training trajectory in function space converges as n→∞. Empirically and theoretically, this rules out the supercritical regime. So DT=0 (the transferability constraint) effectively enforces:
$$\max(\gamma_a, \gamma_w, \gamma_b) \le 1, \qquad \gamma_c \le 0$$
To maximize feature learning (non-zero transport), we want to maximize the Wasserstein speed subject to this constraint. This forces us to the boundary:
$$\gamma_a = \gamma_w = \gamma_b = 1, \qquad \gamma_c = 0$$
This corresponds to μP in the mean-field gauge: all per-neuron parameters (a, w, b) get learning rates that scale as n, while the global output bias c uses a constant rate. Uniqueness is meant within the family of power-law learning-rate scalings (mobilities) considered here, modulo an overall rescaling of time.
Gauge Equivalence (Mean-Field ↔μP Convention)
In mean-field form, the critical mobility is η ≍ n for the neuron parameters (a, w, b) to achieve O(1) Wasserstein speed in the original time variable. To pass to the standard 1/√n parameterization, define $\tilde a_i \doteq a_i/\sqrt n$. Then:
$$f_n(x) = \frac{1}{\sqrt n}\sum_{i=1}^{n} \tilde a_i\, \sigma(w_i^\top x + b_i) + c$$
which is the standard 1/√n form. Under this change of variables, the gradients of the neuron parameters pick up a factor of √n. Thus the critical mobility η ≍ n in the mean-field gauge corresponds to η ≍ √n in the standard μP gauge (with an O(1) rate for the rescaled output weights ã).
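The chain-rule bookkeeping behind this gauge change can be spelled out (this is just the substitution $a_i = \sqrt{n}\,\tilde a_i$ applied to the two-layer gradients computed above):

```latex
% Output weights: since \partial_{a_i} = n^{-1/2}\,\partial_{\tilde a_i},
\dot{\tilde a}_i
  = \frac{\dot a_i}{\sqrt n}
  = -\frac{\eta_a}{\sqrt n}\,\partial_{a_i}\mathcal{L}
  = -\frac{\eta_a}{n}\,\partial_{\tilde a_i}\mathcal{L},
\qquad \eta_a \asymp n \;\Longrightarrow\; \text{an } O(1) \text{ rate for } \tilde a_i.

% Input weights: rewriting the mean-field gradient in the new variables,
\nabla_{w_i}\mathcal{L}
  = \mathbb{E}\!\left[e(x)\,\tfrac{1}{n}\,a_i\,\sigma'(u_i)\,x\right]
  = \mathbb{E}\!\left[e(x)\,\tfrac{1}{\sqrt n}\,\tilde a_i\,\sigma'(u_i)\,x\right]
  = O(n^{-1/2}),
\qquad \dot w_i = O(1) \;\Longrightarrow\; \eta_w \asymp \sqrt n.
```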
Deep Networks: Layerwise Transport
From here on we revert to the standard 1/√n (neural tangent / μP) gauge, because it matches common deep-network parameterizations. The mean-field gauge differs by a simple rescaling of output weights and time (as noted above).
Now we lift this logic to a deep MLP. Consider a depth-L fully-connected network with width n in each hidden layer:
$$h^0 = x, \qquad z^\ell = \frac{1}{\sqrt n}\, W^\ell h^{\ell-1} + b^\ell, \qquad h^\ell = \sigma(z^\ell), \qquad \ell = 1,\dots,L$$
$$f_n(x) = \frac{1}{\sqrt n}\, w^\top h^L + c$$
We train on the same population square loss: $\mathcal{L} = \frac{1}{2}\,\mathbb{E}\big[(f_n(x) - y)^2\big]$.
Row-Gradient Scaling
The key observation is that gradients for rows of the hidden weight matrices scale as $O(n^{-1/2})$. To see this, define the backprop signal $\delta^\ell \doteq \partial f/\partial z^\ell \in \mathbb{R}^n$. Under standard forward scaling, each coordinate satisfies $\delta_i^\ell = O(n^{-1/2})$.
For a row $W_{i:}^\ell \in \mathbb{R}^n$ of the weight matrix, the gradient is:
$$\nabla_{W_{i:}^\ell}\mathcal{L} = \mathbb{E}\Big[e(x)\,\delta_i^\ell(x)\,\tfrac{1}{\sqrt n}\,h^{\ell-1}(x)\Big]$$
Since $\|h^{\ell-1}\| = O(\sqrt n)$, we have $\tfrac{1}{\sqrt n}\|h^{\ell-1}\| = O(1)$, hence:
$$\big\|\nabla_{W_{i:}^\ell}\mathcal{L}\big\| = O(n^{-1/2})$$
Similarly, for biases:
$$\big|\partial_{b_i^\ell}\mathcal{L}\big| = O(n^{-1/2})$$
For the readout and output bias, we get:
$$\|\nabla_w \mathcal{L}\| = O(1), \qquad |\partial_c \mathcal{L}| = O(1)$$
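The O(n^{-1/2}) row-gradient scale can also be checked numerically. Below is a minimal sketch with a 2-hidden-layer MLP and explicit backprop through the middle layer (the toy target, tanh activation, specific widths, and the `middle_row_grad` helper are assumptions): a 9x increase in width should shrink row-gradient norms by roughly √9 = 3.

```python
import numpy as np

def middle_row_grad(n, d=4, m=64, seed=0):
    """Mean norm of a hidden-to-hidden row gradient in a 2-hidden-layer MLP
    with 1/sqrt(n) forward scaling (explicit backprop, square loss)."""
    r = np.random.default_rng(seed)
    X = r.normal(size=(m, d))
    y = np.sin(X[:, 0])                          # toy target (an assumption)
    W1 = r.normal(size=(n, d))
    W2 = r.normal(size=(n, n))
    w = r.normal(size=n)
    H1 = np.tanh(X @ W1.T / np.sqrt(d))          # (m, n) first hidden layer
    Z2 = H1 @ W2.T / np.sqrt(n)                  # middle-layer preactivations
    H2 = np.tanh(Z2)
    e = H2 @ w / np.sqrt(n) - y                  # O(1) residual
    delta = (w / np.sqrt(n)) * (1.0 - H2 ** 2)   # df/dz2_i = O(n^{-1/2})
    # row gradient: grad_{W2[i,:]} L = E[e * delta_i * h1 / sqrt(n)]
    G = ((e[:, None] * delta).T @ H1) / (m * np.sqrt(n))
    return float(np.linalg.norm(G, axis=1).mean())

ratio = middle_row_grad(256) / middle_row_grad(2304)   # widths 1 : 9
# O(n^{-1/2}) scaling predicts a ratio near sqrt(9) = 3
```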
Layerwise Neuron Measures
For each hidden layer $\ell$, we define per-neuron parameters $\theta_i^\ell \doteq (W_{i:}^\ell, b_i^\ell)$ and the empirical measure:
$$\mu_t^{\ell,n} = \frac{1}{n}\sum_{i=1}^{n} \delta_{\theta_i^\ell(t)}$$
The Wasserstein speed for layer ℓ is:
$$\big\|\dot\mu_t^{\ell,n}\big\|_{W_2} = O\big(n^{\max(\gamma_W,\gamma_b) - 1/2}\big)$$
where γW and γb are the learning rate exponents for weights and biases in hidden layers.
Total Transport Objective
Define the total transport up to time T as the sum of Wasserstein path lengths across all layers:
$$S_T(s) \doteq \sum_{\ell=1}^{L} \liminf_{n\to\infty} \int_0^T \big\|\dot\mu_t^{\ell,n}\big\|_{W_2}\, dt$$
This measures how much "movement" happens across all layers combined. The selection principle is the same: maximize ST subject to width-transferability.
Deep μP Result (Scaling Argument)
For deep networks, rigorous OT gradient-flow limits are less clean than in the two-layer case. The following is a heuristic scaling argument that relies on three assumptions: signals remain O(1), gradients do not explode, and neurons within a layer are approximately exchangeable.
Under these assumptions, each hidden layer has the same critical exponent 1/2, and the same vanish/explode argument applies:
If $\max(\gamma_W, \gamma_b) < \tfrac{1}{2}$, then $\|\dot\mu_t^{\ell,n}\|_{W_2} \to 0$ and that layer freezes (lazy).
If $\max(\gamma_W, \gamma_b) > \tfrac{1}{2}$, then $\|\dot\mu_t^{\ell,n}\|_{W_2} \to \infty$ and width-transferability breaks.
The boundary $\max(\gamma_W, \gamma_b) = \tfrac{1}{2}$ gives non-degenerate transport.
The mechanism is clearest at the preactivation level. For a typical neuron i at layer ℓ:
$$z_i^\ell = \frac{1}{\sqrt n}\, W_{i:}^\ell\, h^{\ell-1} + b_i^\ell$$
Differentiating and focusing on the dominant term:
$$\dot z_i^\ell \supset \frac{1}{\sqrt n}\, \dot W_{i:}^\ell\, h^{\ell-1}$$
Since $\|\dot W_{i:}^\ell\| \sim n^{\gamma_W - 1/2}$ and $\|h^{\ell-1}\| \sim \sqrt n$ (with the gradient row aligned with $h^{\ell-1}$):
$$\frac{1}{\sqrt n}\, \dot W_{i:}^\ell\, h^{\ell-1} \sim n^{\gamma_W - 1/2}$$
This makes the vanish/critical/blow-up trichotomy immediate: if $\gamma_W < \tfrac{1}{2}$, preactivations freeze; if $\gamma_W > \tfrac{1}{2}$, they blow up; at $\gamma_W = \tfrac{1}{2}$, preactivation updates are O(1), the maximal non-degenerate update.
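This preactivation-level trichotomy can be sketched numerically as well: take one Euler step on the hidden-to-hidden weights with learning rate n^{γ_W} and measure the resulting preactivation change at two widths (toy data, tanh activation, and the `preact_update` helper are assumptions). The ratio across a 9x width increase should behave like 9^(γ_W − 1/2).

```python
import numpy as np

def preact_update(n, gamma_W, d=4, m=64, seed=0):
    """Mean |Delta z_i| after one Euler step on a hidden-to-hidden layer
    with row learning rate eta = n**gamma_W (1/sqrt(n) forward scaling)."""
    r = np.random.default_rng(seed)
    X = r.normal(size=(m, d))
    y = np.sin(X[:, 0])                          # toy target (an assumption)
    W1 = r.normal(size=(n, d))
    W2 = r.normal(size=(n, n))
    w = r.normal(size=n)
    H1 = np.tanh(X @ W1.T / np.sqrt(d))
    Z2 = H1 @ W2.T / np.sqrt(n)
    H2 = np.tanh(Z2)
    e = H2 @ w / np.sqrt(n) - y                  # O(1) residual
    delta = (w / np.sqrt(n)) * (1.0 - H2 ** 2)   # df/dz2_i = O(n^{-1/2})
    G = ((e[:, None] * delta).T @ H1) / (m * np.sqrt(n))   # row gradients
    dW2 = -(float(n) ** gamma_W) * G                       # one gradient step
    dZ2 = H1 @ dW2.T / np.sqrt(n)                          # preactivation change
    return float(np.abs(dZ2).mean())

# Compare widths 256 and 2304: |Delta z| should scale like n**(gamma_W - 1/2),
# so the ratio shrinks below 1/2, stays O(1) at 1/2, and grows above 1/2.
ratios = {g: preact_update(2304, g) / preact_update(256, g)
          for g in (0.25, 0.5, 0.75)}
```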
Maximizing total transport while maintaining transferability forces every hidden layer to be critical:
$$\gamma_W = \gamma_b = \tfrac{1}{2}, \qquad \gamma_{\mathrm{out}} = \gamma_c = 0$$
This is the deep μP prescription: every hidden layer's learning rate scales as √n, while the output parameters use constant rates.
Summary
We've shown that μP emerges naturally from an Optimal Transport perspective:
The OT Characterization of μP
Among all parameterizations that achieve width-transferability, μP is the unique one that maximizes the total Wasserstein transport distance traveled by neuron measures across all layers. This ensures maximal feature learning while maintaining stable, transferable training dynamics. The key insight is that μP corresponds to the critical mobility where kinetic energy is O(1): too small and particles freeze (NTK), too large and dynamics explode.