DR

Rao-Blackwellized Score Matching on Manifolds

Figuring out what denoising score matching actually learns when data is drawn from a low-dimensional structure in a high-dimensional space [arXiv]

(Part of a series of short writeups covering recent work.)

A lot of modern machine learning assumes the manifold hypothesis: the assumption that the data we care about has some structure to it (formally, this means that it lies on or near some manifold). Under the manifold hypothesis, score matching (a widely used method for learning a data distribution) is fundamentally ill-posed. As a result, practitioners use a variant known as denoising score matching, which resolves the ill-posedness problem, but doesn't necessarily guarantee that we learn the right data distribution. In this work, we examine what exactly is learned by denoising score matching.

Main results.

Background

Score Matching

A lot of times, we want to learn a data distribution, but we can't because computing the partition function¹ is intractable. So, instead of matching the distribution's values at every point in the domain, we match the gradient of its log-density, the score [Hyv05]. This gets rid of our partition function problem: the partition function is independent of the coordinate $x$, so it drops out when we differentiate the log-density. This works fantastically for general data, but if we constrain our data to lie on some low-dimensional structure, things fall apart rather quickly. The probability distribution is nonzero on the manifold but zero right off it. So there are points at which the derivative does not exist, and score matching is not well defined. To get around this, we add noise to the data so that it no longer lies exactly on the manifold.²
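The sketch below makes the partition-function argument concrete. It is our own illustration, not code from the paper. It uses Hyvärinen's score-matching objective $\mathbb{E}\bigl[\tfrac12 s(x)^2 + s'(x)\bigr]$ for a 1D Gaussian model whose score is $s(x) = -(x - m)/v$. The model's normalizer $\sqrt{2\pi v}$ never appears. Even so, a brute-force grid search over $(m, v)$ should land on the sample mean and variance. The function name `sm_objective`, the data parameters, and the grid ranges are illustrative assumptions.

```python
import numpy as np

def sm_objective(x, m, v):
    """Hyvarinen score-matching loss E[0.5 * s(x)^2 + s'(x)] for the
    Gaussian model score s(x) = -(x - m) / v. Only the score is used,
    so the model's normalizer never has to be computed."""
    s = -(x - m) / v      # model score at each sample
    ds = -1.0 / v         # derivative of the score w.r.t. x (constant here)
    return np.mean(0.5 * s**2) + ds

rng = np.random.default_rng(0)
x = rng.normal(loc=1.5, scale=2.0, size=10_000)   # data with mean 1.5, variance 4

# Brute-force search over candidate (mean, variance) pairs.
ms = np.linspace(0.0, 3.0, 61)
vs = np.linspace(2.0, 6.0, 81)
losses = np.array([[sm_objective(x, m, v) for v in vs] for m in ms])
i, j = np.unravel_index(np.argmin(losses), losses.shape)
m_hat, v_hat = ms[i], vs[j]
print(m_hat, v_hat)  # close to x.mean() and x.var()
```

Setting the derivatives of this loss to zero gives the same answer in closed form: $m = \mathbb{E}[x]$ and $v = \mathbb{E}[(x - m)^2]$.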

Denoising score matching. The practical way to do this is denoising score matching (DSM):

- Take a clean sample $Z \sim q$, where $q$ is our data distribution defined on the manifold.
- Corrupt it with Gaussian noise to get $X \doteq Z + \sigma\xi$ for $\xi \sim \mathcal{N}(0, I_D)$.
- Train a network to regress against the denoising direction $(Z - X)/\sigma^2$.

The minimizer of this regression loss is the conditional mean $\mathbb{E}[(Z - X)/\sigma^2 \mid X]$. By Tweedie's formula [Efr11], this equals $\nabla \log q_\sigma(X)$, the score of the noisy density $q_\sigma = q * \phi_\sigma$ (here $\phi_\sigma$ is the $\mathcal{N}(0, \sigma^2 I_D)$ density). As $\sigma \to 0^+$, this converges to the score of $q$ itself. When $q$ is a nice density on $\mathbb{R}^D$, we are done: DSM gives you the score up to a $\sigma^2$ smoothing bias.
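Here is a quick numerical check of the Tweedie step. For illustration we assume that $q$ is a 1D Gaussian $\mathcal{N}(m, s^2)$ on the real line rather than a distribution on a manifold. Then $q_\sigma = \mathcal{N}(m, s^2 + \sigma^2)$, and its score $-(x - m)/(s^2 + \sigma^2)$ is affine. An ordinary least-squares fit of the DSM target $(Z - X)/\sigma^2$ on $X$ should therefore recover it. All parameter values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
m, s, sigma, n = 1.0, 1.0, 0.5, 400_000

z = rng.normal(m, s, size=n)              # clean samples Z ~ q = N(m, s^2)
x = z + sigma * rng.standard_normal(n)    # noisy samples X = Z + sigma * xi
target = (z - x) / sigma**2               # DSM regression target

# Least-squares fit target ~ a * x + b; the conditional mean is affine in this case.
A = np.stack([x, np.ones_like(x)], axis=1)
(a, b), *_ = np.linalg.lstsq(A, target, rcond=None)

# Tweedie's formula predicts grad log q_sigma(x) = -(x - m) / (s^2 + sigma^2).
a_true = -1.0 / (s**2 + sigma**2)
b_true = m / (s**2 + sigma**2)
print(a, a_true, b, b_true)
```

As `sigma` shrinks, the fitted slope approaches $-1/s^2$, the score of $q$ itself. Meanwhile the per-sample target $-\xi/\sigma$ has variance $1/\sigma^2$.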

But under the manifold hypothesis the ambient score doesn't exist at all ($q$ has no density with respect to Lebesgue measure), so there is nothing to take the gradient of. DSM at $\sigma > 0$ is still a well-defined regression problem; it's just not clear what exactly it's estimating.

To start thinking about this, it helps to split the denoising direction at each noisy point $X$ into two pieces: one along the manifold (denoted $M$ from here on out) and one orthogonal to it, both taken at the nearest-point projection $\pi(X) \in M$. To leading order in $\sigma$, the normal piece is just $(\pi(X) - X)/\sigma^2$. That is a deterministic function of $X$, so it contains no information about where $Z$ actually came from on $M$. All the signal lives in the tangent piece:

$$T_\sigma \doteq \frac{1}{\sigma^2} P_{T_{\pi(X)} M}(Z - X) \in T_{\pi(X)} M.$$
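The tangent/normal split can be sketched on the unit sphere $S^2 \subset \mathbb{R}^3$. There, $\pi(x) = x/|x|$ and the tangent projector at $p$ is $I - pp^\top$. This is a minimal illustration: it assumes uniform clean samples (any $q$ on the sphere would do), and the helper names are hypothetical. It also shows that the normal piece differs from $(\pi(X) - X)/\sigma^2$ only by $P_N(Z - \pi(X))/\sigma^2$. That leftover is an $O(1)$ curvature term, small next to the $O(1/\sigma)$ offset.

```python
import numpy as np

def project_to_sphere(x):
    """Nearest-point projection pi(x) = x / |x| onto the unit sphere (batched)."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def tangent_projector(p):
    """Batched orthogonal projector onto T_p S^2, i.e. I - p p^T."""
    return np.eye(p.shape[-1]) - p[..., :, None] * p[..., None, :]

rng = np.random.default_rng(0)
sigma, n = 0.1, 1000

z = project_to_sphere(rng.standard_normal((n, 3)))   # clean samples on S^2 (uniform)
x = z + sigma * rng.standard_normal((n, 3))          # noisy samples X = Z + sigma * xi
p = project_to_sphere(x)                             # pi(X)
P_T = tangent_projector(p)

direction = (z - x) / sigma**2                       # full denoising direction
T_sigma = np.einsum("nij,nj->ni", P_T, direction)    # tangent piece T_sigma
normal_piece = direction - T_sigma                   # normal piece
offset = (p - x) / sigma**2                          # deterministic part (pi(X) - X) / sigma^2

# What is left over is P_N (Z - pi(X)) / sigma^2 = (<Z, pi(X)> - 1) pi(X) / sigma^2.
residual = normal_piece - offset
print(np.abs(residual).mean(), np.abs(offset).mean())
```

Because $\pi(X) - X$ is normal, the tangent piece equals $P_T(Z - \pi(X))/\sigma^2$. It depends on $X$ only through where $\pi(X)$ lands.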

This leaves a few questions open:

- What is $T_\sigma$ actually estimating?
- Does it converge to the intrinsic Riemannian score $\nabla_M \log q$?
- Does it even have a sensible variance as $\sigma \to 0^+$?

In order to answer these, we'll need to set up a bit of differential geometry.

Differential Geometry

Differential geometry is concerned with differentiable manifolds: geometric structures that, when we zoom in enough, look like regular flat (Euclidean) space.³ For example, a plane embedded in a three-dimensional space is a manifold. A plane is special because it has no curvature at all. In general, a manifold has two distinct notions of curvature.

Curvature. Extrinsic curvature tells us how $M$ bends inside the ambient $\mathbb{R}^D$. Intrinsic curvature is inherent to $M$ and requires no knowledge of the embedding in ambient space. Both end up as linear maps on the tangent space, so we can compare them directly.

At each $z \in M$, the ambient space splits into tangent and normal spaces, $\mathbb{R}^D = T_z M \oplus N_z M$, with orthogonal projections $P_{T_z M}$ and $P_{N_z M}$.

Second fundamental form, Weingarten operator, and mean curvature.

Consider traveling along $M$ in some tangent direction $u$. The velocity vector can't stay constant in $\mathbb{R}^D$: the manifold is curving, so the velocity has to turn to keep us on it. The normal part of that acceleration is captured by the second fundamental form $II_z(u, v) \in N_z M$, a symmetric bilinear map on $T_z M \times T_z M$. For any curve in $M$ through $z$ with velocity $u$, the normal component of its acceleration is $II_z(u, u)$. It's a super clean object for measuring extrinsic curvature, since it directly measures how tangent directions get pushed off the manifold.

The Weingarten operator carries exactly the same information, just in the form of a linear map. For each normal direction $\nu \in N_z M$, define $W_\nu : T_z M \to T_z M$ by $\langle W_\nu u, v \rangle = \langle II_z(u, v), \nu \rangle$. Since $W_\nu$ acts on tangent vectors, we can apply it directly to things like $\nabla_M \log q(z)$, which is what shows up in the result.

The mean curvature vector is the trace of $II$: $H(z) \doteq \sum_i II_z(e_i, e_i) \in N_z M$ for any orthonormal basis $(e_i)$ of $T_z M$ (the sum is basis-independent). It points in the direction $M$ is curving "on average."
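The second fundamental form can be estimated numerically straight from its definition. The sketch below is our illustration; the helper names and the step size `h` are assumptions. It works in four steps:

1. Take the curve $t \mapsto \pi(z + t w)$. It stays on $M$ and, for tangent $w$, has velocity $w$ at $t = 0$.
2. Estimate its acceleration by central differences.
3. Keep the normal component of that acceleration to get $II_z(w, w)$.
4. Recover $II_z(u, v)$ by polarization.

The sketch assumes we can evaluate the nearest-point projection $\pi$ and have an orthonormal tangent basis $E$ at $z$. On the unit sphere it should return $W_z = -\mathrm{Id}$ and $H(z) = -2z$.

```python
import numpy as np

def second_fundamental_form(proj, z, E, u, v, h=1e-3):
    """Finite-difference estimate of II_z(u, v) for tangent vectors u, v.

    proj: nearest-point projection onto M; E: (D, d) orthonormal tangent basis at z.
    """
    P_N = np.eye(len(z)) - E @ E.T                  # projector onto N_z M

    def quad(w):
        # Central-difference acceleration of t -> proj(z + t w), normal part only.
        acc = (proj(z + h * w) - 2.0 * z + proj(z - h * w)) / h**2
        return P_N @ acc

    return 0.25 * (quad(u + v) - quad(u - v))       # polarization identity

def weingarten(proj, z, E, nu):
    """Matrix of W_nu in the basis E: entries <II_z(e_i, e_j), nu>."""
    d = E.shape[1]
    return np.array([[nu @ second_fundamental_form(proj, z, E, E[:, i], E[:, j])
                      for j in range(d)] for i in range(d)])

def mean_curvature(proj, z, E):
    """H(z) = sum_i II_z(e_i, e_i)."""
    return sum(second_fundamental_form(proj, z, E, E[:, i], E[:, i])
               for i in range(E.shape[1]))

def proj_sphere(x):
    """Nearest-point projection onto the unit sphere."""
    return x / np.linalg.norm(x)

# Example: the unit sphere S^2 at the north pole.
z = np.array([0.0, 0.0, 1.0])
E = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])   # columns e_1, e_2 span T_z S^2
W = weingarten(proj_sphere, z, E, nu=z)              # expect approximately -Id
H = mean_curvature(proj_sphere, z, E)                # expect approximately -2 z
```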

Ricci endomorphism. Now let us turn to intrinsic curvature, ignoring the embedding entirely. If we live on $M$, we can still measure curvature directly by noticing things like how geodesics⁴ spread apart. The Ricci endomorphism $\mathrm{Ric}_z^\sharp : T_z M \to T_z M$ packages this intrinsic curvature into a linear operator on the tangent space.⁵

Gauss equation. Intrinsic and extrinsic curvature are linked, because the metric on $M$ is induced from the ambient inner product. The exact relationship is given by the Gauss equation $\mathrm{Ric}_z^\sharp = W_{H(z)} - \sum_\alpha W_{n_\alpha}^2$, where $(n_\alpha)$ is any orthonormal basis of $N_z M$. The takeaway is that intrinsic curvature is determined by the embedding. It is the mean-curvature Weingarten operator minus a separate "sum of squared Weingartens" contribution.
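The Gauss equation is easy to evaluate once the Weingarten matrices are known. In the minimal sketch below (our illustration; `ricci_from_gauss` is a hypothetical helper), we pass one Weingarten matrix per vector of an orthonormal normal basis, written in an orthonormal tangent basis. We then use $\langle H, n_\alpha \rangle = \operatorname{tr} W_{n_\alpha}$, so that $W_H = \sum_\alpha \operatorname{tr}(W_{n_\alpha})\, W_{n_\alpha}$ by linearity of $\nu \mapsto W_\nu$. There are two standard checks:

- The unit sphere $S^d$ has one normal with $W = -\mathrm{Id}$, and should give $\mathrm{Ric}^\sharp = (d-1)\,\mathrm{Id}$.
- The flat torus $S^1(a) \times S^1(b) \subset \mathbb{R}^4$, where each outward radial normal bends only its own circle, should give $\mathrm{Ric}^\sharp = 0$ even though both Weingarten matrices are nonzero.

```python
import numpy as np

def ricci_from_gauss(weingartens):
    """Ricci endomorphism from the Gauss equation Ric = W_H - sum_a W_{n_a}^2.

    weingartens: list of (d, d) matrices W_{n_a}, one per vector n_a of an
    orthonormal normal basis, written in an orthonormal tangent basis.
    Since <H, n_a> = tr W_{n_a} and nu -> W_nu is linear, W_H = sum_a tr(W_{n_a}) W_{n_a}.
    """
    W_H = sum(np.trace(W) * W for W in weingartens)
    return W_H - sum(W @ W for W in weingartens)

# Unit sphere S^d in R^{d+1}: one normal n = z with W_n = -Id.
d = 4
ric_sphere = ricci_from_gauss([-np.eye(d)])

# Flat torus S^1(a) x S^1(b) in R^4: each outward radial normal bends only its own circle.
a, b = 1.0, 2.0
ric_torus = ricci_from_gauss([np.diag([-1.0 / a, 0.0]), np.diag([0.0, -1.0 / b])])
print(ric_sphere)   # (d - 1) * identity
print(ric_torus)    # zero matrix, although both Weingarten matrices are nonzero
```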

Rao-Blackwellization and the Nearest-Point Projection

Having established our machinery, we may now return to the question we really care about: what is $T_\sigma$ actually estimating, and can we control its variance?

The classical statistics move here is Rao-Blackwellization. The Rao-Blackwell theorem tells us that conditioning an unbiased estimator on a sufficient statistic can only reduce its variance. Intuitively, we keep the expected value but integrate out the part of the noise that doesn't carry any signal.

The natural thing to condition on in our setting is the nearest-point projection $\pi(X) \in M$. It captures the manifold-valued summary of the noisy observation ("where on $M$ did $Z$ likely come from?"). It discards the normal offset, which by construction carries no information about $Z$.

Formally, consider fiber-collapsing summaries $S(X)$: statistics depending on $X$ only through $\pi(X)$, i.e. $\sigma(S) \subseteq \sigma(\pi(X))$. Here $\sigma(\cdot)$ denotes the generated $\sigma$-algebra, not the noise level. Among all such summaries, the projection $\pi(X)$ is the finest. Anything coarser throws away information that $\pi(X)$ retains.

Define the Rao-Blackwellized tangent target $r_\sigma(z) \doteq \mathbb{E}\bigl[T_\sigma \mid \pi(X) = z\bigr] \in T_z M$. From the Pythagorean identity for $L^2$ projections, we can show that among all functions $h$ of $\pi(X)$, the choice $h(\pi(X)) = r_\sigma(\pi(X))$ minimizes the regression risk:

$$\mathbb{E}|T_\sigma - h(\pi(X))|^2 = \mathbb{E}|T_\sigma - r_\sigma(\pi(X))|^2 + \mathbb{E}|r_\sigma(\pi(X)) - h(\pi(X))|^2.$$

Thus, $r_\sigma$ is the unique (up to null sets) $L^2$-optimal target in this class. The risk of any other function $h$ of $\pi(X)$ splits into two terms:

- an irreducible "noise" term, which no choice of $h$ can reduce;
- a separable "estimator error" term, which is exactly the excess risk of $h$.

Intuitively, $r_\sigma$ keeps all the signal-bearing information in $T_\sigma$ and averages out the rest.
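The Pythagorean identity is easy to check numerically when the conditioning variable is discrete, since the conditional expectation is then just a within-group mean. The sketch below uses synthetic stand-ins rather than the manifold setup: `S` plays the role of $\pi(X)$ and `T` the role of $T_\sigma$. In it, `r` is the group-mean estimate of $\mathbb{E}[T \mid S]$ and `h` is a coarser predictor. On the empirical distribution the cross term vanishes exactly, so both sides should agree to floating-point precision.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 100_000, 20

S = rng.integers(0, k, size=n)                  # discrete stand-in for pi(X)
T = np.sin(S / 3.0) + rng.standard_normal(n)    # stand-in for T_sigma: signal + noise

# r(S) = E[T | S], estimated exactly (for the sample) by within-group means.
r = (np.bincount(S, weights=T, minlength=k) / np.bincount(S, minlength=k))[S]

# A coarser predictor h: group means over the merged groups S // 5.
C = S // 5
h = (np.bincount(C, weights=T) / np.bincount(C))[C]

risk_h = np.mean((T - h) ** 2)
noise = np.mean((T - r) ** 2)          # irreducible part
excess = np.mean((r - h) ** 2)         # estimator-error part
print(risk_h, noise + excess)
```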

Variance Collapse

Okay, so we found the $L^2$-optimal target; now we want to know whether it actually buys us anything quantitatively. The answer is a resounding yes.

The conditional variance of $T_\sigma$ given $\pi(X)$ is the irreducible portion that can't be removed by any fiber-collapsing summary, and we can pin down its leading-order behavior:

Theorem 1 (Variance collapse under Rao-Blackwellization). Let $d = \dim M$. Uniformly in $z \in M$,

$$\operatorname{Var}\bigl(T_\sigma \mid \pi(X) = z\bigr) = \frac{d}{\sigma^2} + O(1).$$

Consequently $\operatorname{Var}(T_\sigma) = d/\sigma^2 + O(1)$, while $\operatorname{Var}(r_\sigma(\pi(X))) = O(1)$.

The intuition for the $d/\sigma^2$ is pretty straightforward. Conditioned on $\pi(X) = z$, the latent $Z$ is concentrated on $M$ within a tangent ball of radius $\sim \sigma$ around $z$ (this is just the posterior of $Z$ given $X$). So $Z - X$ has typical tangent magnitude $\sim \sigma$, and dividing by $\sigma^2$ gives a typical magnitude of $1/\sigma$. Squaring and summing over the $d$ tangent dimensions gives $d/\sigma^2$. So as we shrink the noise, the variance of $T_\sigma$ grows, which is exactly the opposite of what we want from a regression target.

[Figure: second moment of the raw denoising target vs. the Rao-Blackwellized target vs. noise scale on $S^2$.]
Second moment of $T_\sigma$ (black, slope $-2$ on log-log axes, matching $d/\sigma^2$ with $d = 2$) versus the Rao-Blackwellized target $r_\sigma(\pi(X))$ (blue, bounded at the theoretical $\mathbb{E}|\nabla_M \log q|^2$). Computed on $S^2$ under vMF($\mu$, $\kappa = 2$).

The $d/\sigma^2$ floor isn't just a property of our particular estimator; it's an irreducible Bayes-risk floor for any fiber-collapsing summary. That is, for any statistic $S(X)$ with $\sigma(S) \subseteq \sigma(\pi(X))$, the best possible $L^2$ predictor of $T_\sigma$ based on $S$ has expected squared error at least $d/\sigma^2 + O(1)$. Coarsening $\pi(X)$ doesn't help; only the projection itself achieves the floor exactly.
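Theorem 1 and the Bayes-risk floor can be checked by simulation. This is our illustration, not the paper's experiment. To keep conditioning on $\pi(X)$ easy, the sketch makes three simplifying choices:

- It swaps $S^2$ for the circle $S^1 \subset \mathbb{R}^2$, so $d = 1$.
- It takes $q$ to be a von Mises distribution.
- It estimates $r_\sigma(\pi(X))$ by averaging $T_\sigma$ within angle bins of $\pi(X)$.

Binning is itself a slightly coarsened summary, so this only approximates $r_\sigma$. We expect four things:

- $\sigma^2\,\mathbb{E}|T_\sigma|^2 \approx d = 1$;
- a bounded second moment for the binned estimate of $r_\sigma$;
- close agreement between that estimate and the intrinsic score $-\kappa \sin(\phi - \mu)$;
- a larger error for a coarser 4-bin summary.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma, kappa, mu = 1_000_000, 0.1, 2.0, 0.0

theta = rng.vonmises(mu, kappa, size=n)                   # Z ~ von Mises(mu, kappa) on S^1
z = np.stack([np.cos(theta), np.sin(theta)], axis=1)
x = z + sigma * rng.standard_normal((n, 2))               # X = Z + sigma * xi
phi = np.arctan2(x[:, 1], x[:, 0])                        # pi(X), stored as an angle
e_phi = np.stack([-np.sin(phi), np.cos(phi)], axis=1)     # unit tangent at pi(X)
T = np.sum(e_phi * (z - x), axis=1) / sigma**2            # T_sigma in tangent coordinates

def binned_conditional_mean(phi, T, n_bins):
    """Approximate E[T | pi(X)] by averaging T over n_bins equal angle bins."""
    idx = np.minimum(((phi + np.pi) / (2 * np.pi) * n_bins).astype(int), n_bins - 1)
    sums = np.bincount(idx, weights=T, minlength=n_bins)
    counts = np.bincount(idx, minlength=n_bins)
    return (sums / counts)[idx]

r_fine = binned_conditional_mean(phi, T, 64)
r_coarse = binned_conditional_mean(phi, T, 4)
score = -kappa * np.sin(phi - mu)                         # intrinsic score d/dphi log q

raw_second_moment = np.mean(T**2)                         # grows like d / sigma^2
rb_second_moment = np.mean(r_fine**2)                     # stays O(1)
mse_fine = np.mean((T - r_fine) ** 2)                     # near the d / sigma^2 floor
mse_coarse = np.mean((T - r_coarse) ** 2)                 # coarser summary does worse
print(sigma**2 * raw_second_moment, rb_second_moment, mse_fine, mse_coarse)
```

With 64 bins, the residual error `mse_fine` sits close to $d/\sigma^2 = 100$. Merging down to 4 bins only increases it, in line with the floor.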

$\sigma^2$ Correction

Now that we have a target with bounded variance, the natural next question is what it actually equals. Expanding $r_\sigma(z)$ in powers of $\sigma$ gives, uniformly on $M$,

$$r_\sigma(z) = \nabla_M \log q(z) + \sigma^2 \bigl[b_q(z) + g_M^{\text{ext}}(z)\bigr] + o(\sigma^2).$$

The leading term is the intrinsic Riemannian score $\nabla_M \log q$, exactly what fully intrinsic methods regress against. So at leading order, ambient DSM (after the Rao-Blackwellization fix) recovers the right thing. The $\sigma^2$ correction is more interesting. It splits into two distinct pieces:

  1. The intrinsic Tweedie bias $b_q(z) = \frac{1}{2} \nabla_M\left(\Delta_M \log q + |\nabla_M \log q|^2\right)$. This is the manifold analog of the flat Tweedie bias: it comes from Gaussian smoothing of the density and would appear even on a flat support. It depends on $q$ but not on the embedding, and it's the same correction you'd get from intrinsic methods.
  2. The extrinsic curvature term $g_M^{\text{ext}}(z) = \bigl(\tfrac{1}{2} W_{H(z)} - \mathrm{Ric}_z^\sharp\bigr)\nabla_M \log q(z)$. This piece can only be nonzero when $M$ is extrinsically curved ($II \neq 0$). It depends on the embedding of $M$ in $\mathbb{R}^D$ through the Weingarten and Ricci operators defined above. It's invisible to any method that corrupts $Z$ with intrinsic Riemannian noise: intrinsic noise never leaves the manifold, and thus never feels the embedding.
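Both correction terms can be written in closed form for one concrete case. This worked example is our own specialization, not a result stated in the paper. It assumes $q$ is a von Mises-Fisher density $q(z) \propto \exp(\kappa \langle \mu, z \rangle)$ on the unit sphere $S^d$. Write $t = \langle \mu, z \rangle$ and use the standard facts $\nabla_{S^d} t = \mu - t z$, $|\nabla_{S^d} t|^2 = 1 - t^2$, and $\Delta_{S^d} t = -d\,t$. Also use the unit-sphere curvature values $W_{H} = d\,\mathrm{Id}$ and $\mathrm{Ric}^\sharp = (d-1)\,\mathrm{Id}$:

```latex
\begin{aligned}
\nabla_{S^d} \log q(z) &= \kappa\,(\mu - t z), \\
b_q(z) &= \tfrac{1}{2} \nabla_{S^d}\bigl(-d\kappa\, t + \kappa^2 (1 - t^2)\bigr)
        = -\tfrac{1}{2}\bigl(d\kappa + 2\kappa^2 t\bigr)(\mu - t z), \\
g_{S^d}^{\text{ext}}(z) &= \bigl(\tfrac{d}{2} - (d - 1)\bigr)\,\kappa\,(\mu - t z)
        = \bigl(1 - \tfrac{d}{2}\bigr)\,\kappa\,(\mu - t z), \\
r_\sigma(z) &= \kappa\,\bigl[1 + \sigma^2 (1 - d - \kappa t)\bigr](\mu - t z) + o(\sigma^2).
\end{aligned}
```

Both corrections point along the score direction $\mu - t z$. However, only $g^{\text{ext}}$ is a constant multiple of the score; $b_q$ also varies with position through $t$.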

A nice property of $g_M^{\text{ext}}$ is that the operator $\tfrac{1}{2} W_H - \mathrm{Ric}^\sharp$ is purely geometric: it depends only on how $M$ sits in $\mathbb{R}^D$, not on $q$. So if we've trained a score model $\hat s(z) \approx \nabla_M \log q(z)$, we can apply this operator to the model's output and subtract $\sigma^2$ times the result. The correction is fully post-hoc: you don't need to retrain, and you don't need to know $q$.⁶
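Here is a minimal sketch of the post-hoc correction on the unit sphere $S^d \subset \mathbb{R}^{d+1}$. Using the standard values $W_H = d\,\mathrm{Id}$ and $\mathrm{Ric}^\sharp = (d-1)\,\mathrm{Id}$, the operator $\tfrac12 W_H - \mathrm{Ric}^\sharp$ is $(1 - d/2)$ times the tangent projector. Here `s_hat` stands in for the output of some trained score model, written as an ambient vector tangent at $z$. The function names are hypothetical. On other manifolds, `ext_operator` would have to be replaced by that manifold's curvature operator.

```python
import numpy as np

def ext_operator_sphere(z):
    """(1/2) W_H - Ric^sharp on the unit sphere S^d in R^{d+1}, as a (D, D) matrix:
    (1 - d/2) times the tangent projector I - z z^T."""
    D = z.shape[0]
    d = D - 1
    return (1.0 - d / 2.0) * (np.eye(D) - np.outer(z, z))

def debias(s_hat, z, sigma, ext_operator=ext_operator_sphere):
    """Post-hoc correction: subtract sigma^2 * (1/2 W_H - Ric^sharp) applied to s_hat."""
    return s_hat - sigma**2 * (ext_operator(z) @ s_hat)

sigma = 0.3
# S^2: the extrinsic operator vanishes, so the model output is left unchanged.
z2 = np.array([0.0, 0.0, 1.0])
s2 = np.array([0.3, -0.2, 0.0])
# S^6: the operator is -2 times the projector, so tangent outputs are scaled by (1 + 2 sigma^2).
z6 = np.eye(7)[0]
s6 = 0.5 * np.eye(7)[1]
print(debias(s2, z2, sigma), debias(s6, z6, sigma))
```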

$S^2$ Cancellation

On the unit sphere $S^d \subset \mathbb{R}^{d+1}$, with outward unit normal $\nu = z$, we have $W_\nu = -\mathrm{Id}$, $H(z) = -d\,z$, $W_{H(z)} = d\,\mathrm{Id}$, and $\mathrm{Ric}^\sharp = (d-1)\,\mathrm{Id}$. Plugging in:

$$g_{S^d}^{\text{ext}}(z) = \bigl(\tfrac{d}{2} - (d-1)\bigr) \mathrm{Id} \cdot \nabla_{S^d}\log q(z) = (1 - d/2) \nabla_{S^d}\log q(z).$$

So on the sphere, the extrinsic correction is a scalar multiple of the intrinsic score, with coefficient $\alpha_d \doteq 1 - d/2$:

On $S^2$, the extrinsic curvature correction vanishes identically. Ambient DSM then recovers the intrinsic score with only the intrinsic Tweedie bias $b_q$ at order $\sigma^2$.

It's worth emphasizing that this is a coincidence specific to the round sphere, not a general property of 2-manifolds. The flat torus $T^2 = S^1 \times S^1 \subset \mathbb{R}^4$ (a product of two unit circles) is intrinsically flat ($\mathrm{Ric}^\sharp = 0$) but extrinsically curved. Its $g_{T^2}^{\text{ext}}$ has coefficient $+1/2$, a measurable embedding-only bias. Higher-dimensional spheres also show the bias clearly. On $S^6$ and $S^{10}$ ($\alpha_6 = -2$, $\alpha_{10} = -4$), raw ambient DSM under-concentrates the equilibrium distribution by about 17% and 33% respectively, and the post-hoc correction removes both.
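The sphere and torus coefficients follow from the Gauss equation alone. The sketch below is our illustration, and `ext_operator` is a hypothetical helper. It assembles $\tfrac12 W_H - \mathrm{Ric}^\sharp$ from the Weingarten matrices of an orthonormal normal frame, using $W_H = \sum_\alpha \operatorname{tr}(W_{n_\alpha})\, W_{n_\alpha}$. For $S^d$ it should give $\alpha_d = 1 - d/2$. For the flat torus built from two unit circles in $\mathbb{R}^4$ it should give $+\tfrac12\,\mathrm{Id}$.

```python
import numpy as np

def ext_operator(weingartens):
    """(1/2) W_H - Ric^sharp from Weingarten matrices W_a of an orthonormal normal frame.

    Uses <H, n_a> = tr W_a, so W_H = sum_a tr(W_a) W_a, and the Gauss equation
    Ric^sharp = W_H - sum_a W_a^2.
    """
    W_H = sum(np.trace(W) * W for W in weingartens)
    ric = W_H - sum(W @ W for W in weingartens)
    return 0.5 * W_H - ric

# Unit sphere S^d: a single normal n = z with W_n = -Id; the result is alpha_d * Id.
alphas = {d: ext_operator([-np.eye(d)])[0, 0] for d in (2, 6, 10)}

# Flat torus S^1 x S^1 in R^4 (unit circles): W_{n_1} = diag(-1, 0), W_{n_2} = diag(0, -1).
torus = ext_operator([np.diag([-1.0, 0.0]), np.diag([0.0, -1.0])])
print(alphas)
print(torus)
```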

[Figure: equilibrium density of $z \cdot \mu$ on $S^2$, $S^6$, and $S^{10}$ under ambient DSM and RB+debias Langevin drifts.]
Top: equilibrium of Langevin dynamics on $S^d$. Ambient DSM (orange) matches the truth (black) exactly on $S^2$ but under-concentrates on $S^6$ and $S^{10}$; the post-hoc correction (blue) recovers the truth in all panels. Bottom: error of an MLP regressor trained with (blue) and without (orange) Rao-Blackwellization, as a function of sample size.

This explains a longstanding empirical observation: ambient diffusion models work surprisingly well on data such as climate and weather fields, which naturally live on $S^2$.

  1. The partition function is the normalization constant that ensures the probability distribution integrates to unity. ↩
  2. To get some intuition for why this works, it's nice to consider a toy case: for a random variable $X$ distributed uniformly between 0 and 1, the probability of the event $X = 0.5$ is exactly 0, but the probability of the event $X \in (0.49, 0.51)$ is nonzero. The idea here works similarly in principle. ↩
  3. Here, we just cover the absolute basics, but do Carmo [doC92] is a great reference. ↩
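  4. Geodesics are curves on $M$ with zero acceleration as measured within $M$ (any ambient acceleration they have is purely normal); locally, they are shortest paths. ↩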
  5. Concretely, $\mathrm{Ric}^\sharp$ is the metric-raised version of the Ricci tensor, which is itself a trace of the full Riemann curvature tensor. For a unit vector $u$, you can think of $\langle \mathrm{Ric}^\sharp u, u \rangle$ as $(\dim M - 1)$ times the average sectional curvature of the 2-planes containing $u$. ↩
  6. Strictly, you do need your score estimate to be reasonably accurate at the points where you're applying the correction. But you don't need access to $q$ itself, which is the usual obstacle. ↩

References