A Category Theoretic View of Machine Learning

Nov 02, 2025

Using symmetries to better understand how to design attention kernels

Overview

Deep learning models often bake in symmetry and locality through ad-hoc tricks (e.g., convolutional weight tying, RoPE embeddings, \(SE(3)\)[1] kernels, etc.). Our goal is to see if we can derive attention kernels from a declared symmetry and geometry, rather than hand-designing them.

Preliminaries

Categories

Category
A category \(\mathcal C\) is a collection of objects together with maps between them. It consists of
  • Objects (the "things"), e.g., sets, vector spaces, etc., and
  • Morphisms \(f: X \to Y\), the maps between objects,
so that morphisms compose, i.e., \(g \circ f: X \to Z\) for \(f: X \to Y\) and \(g: Y \to Z\), and we have an identity morphism \(1_X: X\to X\) for each object. Composition is associative (\(h \circ (g \circ f) = (h \circ g) \circ f\)), and the identity acts trivially (\(f \circ 1_X = f = 1_Y \circ f\)).
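
To ground the definition, here is a minimal sketch (in Python, with hypothetical helper names of my own) using types as objects and functions as morphisms; the asserts check the composition and identity laws on a toy example.

```python
# A minimal sketch: Python functions as morphisms, with explicit composition and identity.

def identity(x):
    """The identity morphism 1_X."""
    return x

def compose(g, f):
    """Composition g . f: apply f first, then g."""
    return lambda x: g(f(x))

f = lambda n: n + 1                  # f: int -> int
g = lambda n: "value=" + str(n)      # g: int -> str
h = lambda s: len(s)                 # h: str -> int

# Composition respects types: (g . f): int -> str
assert compose(g, f)(3) == "value=4"

# Identities act trivially: f . 1_X = f = 1_Y . f
assert compose(f, identity)(3) == f(3) == compose(identity, f)(3)

# Composition is associative: h . (g . f) = (h . g) . f
assert compose(h, compose(g, f))(3) == compose(compose(h, g), f)(3)
```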

Categories are particularly useful because they give us (i) a nice way to think about symmetry, and (ii) a way to "translate" problems from one domain to another via functors.

Functors

Functor
A functor \(F: \mathcal C \to \mathcal D\) maps objects to objects and morphisms to morphisms while preserving identities and composition: \begin{equation} F(f): F(X) \to F(Y) \ \text{ for } f: X \to Y, \quad F(g \circ f) = F(g) \circ F(f), \quad F(1_X) = 1_{F(X)}. \end{equation}

A functor acts as an embedding that maps geometric inputs into feature spaces.

If this embedding respects symmetries (e.g. rotation, permutation), we call it an equivariant functor: \begin{equation} F(g \cdot x) = \rho(g)F(x), \end{equation} where \(\rho(g)\) is a tensor representation of the group action on outputs.
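
As a quick numerical sanity check (a toy example of mine, not from the post): the feature map \(F(x) = \lVert x \rVert\, x\) on \(\mathbb{R}^2\) is equivariant for planar rotations, with \(\rho(g)\) equal to the rotation matrix itself.

```python
import numpy as np

def rot(theta):
    """2D rotation matrix: the group element g acting on inputs (here also rho(g))."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def F(x):
    """A simple equivariant feature map: scale x by its (rotation-invariant) norm."""
    return np.linalg.norm(x) * x

rng = np.random.default_rng(0)
x = rng.normal(size=2)
g = rot(0.7)

# Equivariance: F(g . x) == rho(g) F(x), with rho(g) = g for this feature type
assert np.allclose(F(g @ x), g @ F(x))
```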

Natural Transformations

Two functors \(F, G: \mathcal C \to \mathcal D\) are related by a natural transformation \(\eta: F \Rightarrow G\) if for every object \(X\) there is a component \begin{equation} \eta_X: F(X) \to G(X) \end{equation} such that for every morphism \(f: X \to Y\), \begin{equation} G(f) \circ \eta_X = \eta_Y \circ F(f). \end{equation}

This says that changing representation and then applying a map is the same as applying the map and then changing representation.
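
A tiny numerical check of the naturality square, assuming for illustration that both categories are finite-dimensional vector spaces with linear maps: take \(F\) the identity functor, \(G\) the duplication functor \(X \mapsto X \oplus X\), and \(\eta_X(v) = (v, v)\).

```python
import numpy as np

def eta(v):
    """Component eta_X: F(X) = X  ->  G(X) = X (+) X, sending v to (v, v)."""
    return np.concatenate([v, v])

def G_of(A):
    """G on morphisms: the linear map A acts block-diagonally on X (+) X."""
    return np.block([[A, np.zeros_like(A)], [np.zeros_like(A), A]])

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 2))   # a morphism f: R^2 -> R^3, so F(f) = A
v = rng.normal(size=2)

# Naturality: G(f) . eta_X  ==  eta_Y . F(f)
assert np.allclose(G_of(A) @ eta(v), eta(A @ v))
```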

Ends

End
An end is a universal construction that represents the space of maps invariant under a family of morphisms. Formally, given a functor \begin{equation} F: \mathcal{C}^{\rm{op}} \times \mathcal{C} \to \mathbf{Set}, \end{equation} the end \begin{equation} \int_{C \in \mathcal{C}} F(C, C) \end{equation} is the equalizer of all pairs of maps \(F(f, 1)\) and \(F(1, f)\) for every morphism \(f\) in \(\mathcal{C}\). Intuitively, it is the space of "globally consistent" elements -- things that remain unchanged as we move along all morphisms.

Intuitively, ends enforce symmetry consistency.
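
To make this concrete, here is a small sketch under an assumed toy setup: the cyclic group \(C_3\) acting on \(\mathbb{R}^3\) by shifting coordinates, viewed as a one-object category. The equalizer condition picks out exactly the constant vectors.

```python
import numpy as np

# Generator of the cyclic group C_3 acting on R^3 by cyclically shifting coordinates.
P = np.roll(np.eye(3), 1, axis=0)

# End-as-equalizer: v lies in the end iff P^k v = v for every group element P^k.
# Stack the constraints (P^k - I) v = 0 and compute their common null space.
constraints = np.vstack([np.linalg.matrix_power(P, k) - np.eye(3) for k in (1, 2)])
_, s, Vt = np.linalg.svd(constraints)
invariants = Vt[s < 1e-10]

# The space of "globally consistent" vectors is 1-dimensional: the constant vectors.
print(invariants)   # approximately [[0.577, 0.577, 0.577]] up to sign
assert np.allclose(P @ invariants[0], invariants[0])
```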

Coends

Coend
A coend is the dual notion of an end: instead of enforcing invariance, it glues local data into a consistent global object. Formally, for a functor \begin{equation} F: \mathcal{C}^{\mathrm{op}} \times \mathcal{C} \to \mathbf{Set}, \end{equation} the coend \begin{equation} \int^{C \in \mathcal{C}} F(C, C) \end{equation} is the coequalizer that identifies elements across overlaps induced by morphisms in \(\mathcal{C}\).[2]

Intuitively, coends enforce spatial consistency.
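
Dually, here is a toy sketch of the gluing step (a colimit computation in \(\mathbf{Set}\) in the spirit of the coequalizer; the chart data is made up): two overlapping local charts are identified along their overlap to form one global index set.

```python
# Two local charts covering points 0..5, overlapping on {2, 3}.
chart_A = [0, 1, 2, 3]
chart_B = [2, 3, 4, 5]

# Disjoint union of the local data: tag each element with its chart.
local = [("A", p) for p in chart_A] + [("B", p) for p in chart_B]

# Union-find: the gluing identifies the two copies of each overlap point.
parent = {x: x for x in local}
def find(x):
    while parent[x] != x:
        x = parent[x]
    return x
def union(x, y):
    parent[find(x)] = find(y)

for p in set(chart_A) & set(chart_B):
    union(("A", p), ("B", p))

# The glued (global) object has one element per underlying point: 6, not 8.
global_points = {find(x) for x in local}
assert len(global_points) == 6
```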

Fibrations

A fibration is a functor \(\pi: \mathcal{E} \to \mathcal{B}\) that organizes "fibers" of structured objects over a base space. Intuitively, it tells us how a family of parameter spaces (the total category \(\mathcal{E}\)) projects onto model configurations (the base \(\mathcal{B}\)).

Putting it All Together

Equipped with the basic tools (categories, functors, natural transformations, and (co)ends), we are now ready to begin using these tools to better understand machine learning itself through the lens of category theory.

The key theme of the following is more or less that learning is composition inside a structured category.

Computation as a Category

A deep network is a morphism in a computational category \(\mathcal C\): \begin{equation} f_\theta: X \to Y \end{equation} where objects are data/feature spaces and morphisms are differentiable[3] maps parameterized by \(\theta\).

Then, because reverse-mode differentiation (backprop) is itself a functor \(\mathrm{RD}: \mathcal C \to \mathcal C\), the category is closed under gradients. Thus, learning is not some new addition to the model -- it is a morphism living in the same category.
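
A numerical sanity check of this functoriality (a sketch, not from the post): the Jacobian of a composite is the composite of the Jacobians, which is exactly the statement that differentiation maps composite morphisms to composites.

```python
import numpy as np

A = np.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])

f = lambda x: np.sin(x)          # f: R^2 -> R^2, Jacobian diag(cos(x))
g = lambda y: A @ y              # g: R^2 -> R^3, Jacobian A

def jacobian(h, x, eps=1e-6):
    """Finite-difference Jacobian of h at x."""
    cols = [(h(x + eps * e) - h(x - eps * e)) / (2 * eps) for e in np.eye(x.size)]
    return np.stack(cols, axis=1)

x = np.array([0.3, -1.2])

# Functoriality of differentiation: D(g . f)(x) == D(g)(f(x)) @ D(f)(x)
lhs = jacobian(lambda x: g(f(x)), x)
rhs = A @ np.diag(np.cos(x))
assert np.allclose(lhs, rhs, atol=1e-5)
```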

Symmetry and Equivariance as Functorial Structure

A symmetry group \(G\) acting on data \(X\) induces a category of representations, where morphisms represent the group action.

An equivariant layer is then just a functor that preserves the action: \begin{equation} F: \mathrm{Rep}(G,X) \to \mathrm{Rep}(G,Y) \end{equation} so that \begin{equation} F(g \cdot x) = \rho_Y(g)F(x). \end{equation}

Thus, CNNs, \(SE(3)\) Transformers, and permutation-invariant networks are all functors that commute with their respective group actions. From this category-theoretic view, we get for free that a composition of equivariant functors is again equivariant: equivariance is closed under composition.
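
A quick check of this closure property, using made-up DeepSets-style layers: each layer mixes a per-element term with a mean-pooled term, and the composite still commutes with any permutation.

```python
import numpy as np

def layer(a, b):
    """A permutation-equivariant layer: x |-> a*x + b*mean(x) (DeepSets-style)."""
    return lambda x: a * x + b * x.mean()

rng = np.random.default_rng(0)
x = rng.normal(size=5)
perm = rng.permutation(5)

f1, f2 = layer(1.5, -0.3), layer(0.2, 2.0)
composite = lambda x: f2(f1(x))

# Each layer is equivariant, and so is their composite.
assert np.allclose(f1(x[perm]), f1(x)[perm])
assert np.allclose(composite(x[perm]), composite(x)[perm])
```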

Universality

Every architecture embeds assumptions about what should remain the same (invariance) and what should fit together (locality). Category theory expresses these as universal constructions:
  • Ends enforce symmetry consistency: they carve out the quantities invariant under every declared transformation.
  • Coends enforce spatial consistency: they glue local data into a consistent global object.

Thus, together, Ends and Coends provide the principles that organize geometry and locality in learning systems.

Learning Dynamics as Fibration

Training is more than just optimization in parameter space -- it is a geometric process.

We may think of parameters \(\mathcal M\) as fibers sitting over models \(\Theta\) via a fibration \begin{equation} \pi: \mathcal M \to \Theta. \end{equation}

Each fiber of \(\mathcal M\) collects the parameterizations (specific coordinates) that realize the same function \(f_\theta\) in \(\Theta\). A good learning rule is then one that moves "horizontally" along this bundle so that it changes the function itself, not just the coordinates.

A cartesian (natural) update \begin{equation} T \circ \phi = \phi \circ T \end{equation} commutes with any reparametrization \(\phi\), meaning it is coordinate-free. This is the fundamental principle behind natural gradient and K-FAC methods.
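
Here is a one-dimensional sketch of why this matters (toy loss and reparameterization chosen for illustration): plain gradient descent takes different steps in different coordinates, while a metric-aware ("natural") step, with the metric transformed as a bilinear form, produces the same function-space update.

```python
import numpy as np

loss_grad_theta = lambda theta: 2 * (theta - 3.0)   # d/dtheta of (theta - 3)^2
lr, theta0, c = 0.1, 0.0, 2.0                       # reparameterization: phi = c * theta

# Plain gradient descent in theta coordinates.
theta_gd = theta0 - lr * loss_grad_theta(theta0)

# Plain gradient descent in phi coordinates, mapped back to theta: NOT the same step.
phi0 = c * theta0
grad_phi = loss_grad_theta(phi0 / c) / c            # chain rule: dL/dphi = (dL/dtheta)/c
theta_from_phi_gd = (phi0 - lr * grad_phi) / c
print(theta_gd, theta_from_phi_gd)                  # 0.6 vs 0.15 -- coordinate dependent

# Natural step: precondition by the inverse metric. The metric transforms as G_phi = G_theta / c^2.
G_theta, G_phi = 1.0, 1.0 / c**2
theta_nat = theta0 - lr * (1 / G_theta) * loss_grad_theta(theta0)
theta_from_phi_nat = (phi0 - lr * (1 / G_phi) * grad_phi) / c
assert np.isclose(theta_nat, theta_from_phi_nat)    # same step, coordinate-free
```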

Gauge Theories

The categorical framework we just built closely parallels the structure of gauge theories in physics.

Fields, Bundles, and Equivariance

A field in physics assigns a quantity (e.g., a vector or tensor) to every point in space, subject to transformation laws under symmetries. Mathematically, this is a functor: \begin{equation} F: \text{(Spacetime)} \to \text{(Fields)} \end{equation} that maps each region to its field content and each coordinate transformation to the corresponding tensor transformation.

Equivariance in ML plays the same role: \begin{equation} F(g \cdot x) = \rho(g) F(x), \end{equation} where \(g\) is a symmetry (rotation, translation, permutation) and \(\rho(g)\) is its representation on the output space. An equivariant network is then a discrete field theory: a functor assigning representations to points on the data manifold.

Ends as Gauge Constraints

In a gauge theory, physical quantities are those that remain invariant under local transformations. An end \begin{equation} \int_{g \in G} F(g,g) \end{equation} plays exactly this role: it collects all quantities consistent under the action of every \(g \in G\). In our framework, the end defines the space of legal attention kernels -- those consistent with the declared symmetry. This is the categorical version of a gauge constraint.
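
As a concrete sketch of "the space of legal attention kernels" (a toy example of mine using cyclic shifts of token positions): averaging any kernel over the group -- the Reynolds operator -- projects it onto the end, i.e., onto the kernels satisfying \(K(g \cdot q, g \cdot k) = K(q, k)\). For cyclic shifts the result is a circulant, relative-position kernel.

```python
import numpy as np

n = 6
shift = np.roll(np.eye(n), 1, axis=0)                       # generator of cyclic shifts
group = [np.linalg.matrix_power(shift, k) for k in range(n)]

def project_to_invariant(K, group):
    """Reynolds operator: average K over the group action on both index sets."""
    return sum(P.T @ K @ P for P in group) / len(group)

rng = np.random.default_rng(0)
K = rng.normal(size=(n, n))                                 # an arbitrary (illegal) kernel
K_inv = project_to_invariant(K, group)

# K_inv is a "legal" kernel: it satisfies the gauge constraint for every group element...
assert all(np.allclose(P.T @ K_inv @ P, K_inv) for P in group)
# ...and it is circulant, i.e. it depends only on the relative position (j - i) mod n.
print(np.round(K_inv[0], 3))                                # the first row determines the rest
```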

Coends as Bundle Gluing

A gauge field is not defined globally but patched together from local charts related by transition functions. The coend \begin{equation} \int^{U_i \cap U_j} F(U_i, U_j) \end{equation} formalizes this gluing. This ensures smooth global behavior, much like a well-constructed vector bundle or gauge potential.

Fibration as Connection

In gauge theory, a connection specifies how to move (parallel transport) along a fiber bundle without leaving the gauge orbit. In our framework, a fibration \begin{equation} \pi: \mathcal M \to \Theta \end{equation} plays the same role: it organizes parameter space (\(\mathcal M\)) over model space (\(\Theta\)), and the cartesian update \begin{equation} T \circ \phi = \phi \circ T \end{equation} ensures that learning follows a connection that is natural -- invariant to reparameterization, just as covariant derivatives respect the gauge connection.

Deriving Known Kernels

This is cool and all, but without any application, it's not particularly useful. Here, we use the categorical language to derive some known results.

Namely, by recognizing that an end computes the space of all symmetry-consistent kernels, we may see that all equivariant attention mechanisms are instances of this same universal construction.

The End to Equivariant Kernels

Given a group \(G\) acting on an input manifold \(X\) and an output representation \(\rho_Y\) on \(Y\), the space of all admissible kernels is the end: \begin{equation} \mathsf{Ker}_G(X, Y) = \int_{g \in G} \mathcal{C}\left(g \cdot (X \times X),\, g \cdot \mathrm{End}(Y)\right). \end{equation}

In other words, a kernel \(K\) in this space commutes with the symmetry action -- the categorical analog of a gauge-invariant kernel.

Translations: Convolutions and CNN Tying

For \(G = \mathbb{R}^2\) (translations on the image plane) acting by shifts \(g \cdot x = x + g\), equivariance requires \begin{equation} K(q + g, k + g) = K(q, k). \end{equation} Thus, \(K(q, k) = \kappa(k - q)\): the kernel depends only on relative position. This is exactly the weight-tying rule of convolution -- a convolution layer is simply an end over the translation group.
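
A small check of the weight-tying rule (a sketch using circular positions mod \(n\) so the check is exact): a kernel built from relative offsets only automatically satisfies \(K(q + g, k + g) = K(q, k)\), and applying it to a signal is a convolution.

```python
import numpy as np

n = 8
kappa = {-1: 0.5, 0: 1.0, 1: 0.25}                     # a local filter over relative offsets

# Weight tying: K[q, k] = kappa(k - q), with circular (mod n) positions for exactness.
K = np.zeros((n, n))
for q in range(n):
    for k in range(n):
        offset = (k - q + n // 2) % n - n // 2         # signed relative position mod n
        K[q, k] = kappa.get(offset, 0.0)

# Translation equivariance: shifting both query and key positions leaves K unchanged.
for g in range(1, n):
    shifted = K[np.roll(np.arange(n), -g)][:, np.roll(np.arange(n), -g)]
    assert np.allclose(shifted, K)

# Applying K to a signal is exactly a (circular) convolution with kappa.
x = np.arange(n, dtype=float)
print(K @ x)
```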

Rotations: RoPE and \(SO(2)\) Equivariant Attention

For \(G = \mathrm{SO}(2)\)[4], the end yields kernels invariant under rotation: \begin{equation} K(R_\phi q, R_\phi k) = R_\phi K(q, k) R_\phi^{-1}. \end{equation} Parameterizing \(K\) in polar coordinates gives \begin{equation} K(q, k) = \sum_{\ell=0}^{L} f_\ell(r) [\cos(\ell \phi), \sin(\ell \phi)], \end{equation} where \(r = \lvert k - q \rvert\) and \(\phi\) is the relative angle. This recovers rotary position embeddings (RoPE) and circular harmonic kernels as the unique solutions to the equivariance constraint.
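
The defining identity can be checked directly on a single 2D head dimension (a sketch; the frequency and positions are arbitrary): rotating the query and key by their absolute positions makes the attention score depend only on the relative position.

```python
import numpy as np

def rot(theta):
    """2D rotation matrix."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

rng = np.random.default_rng(0)
q, k = rng.normal(size=2), rng.normal(size=2)
omega = 0.37                      # an arbitrary rotation frequency
m, n = 5, 11                      # absolute positions of the query and the key

# RoPE: rotate q by its position m and k by its position n, then take the dot product.
score = (rot(m * omega) @ q) @ (rot(n * omega) @ k)

# The score depends only on the relative position n - m.
relative = q @ (rot((n - m) * omega) @ k)
assert np.allclose(score, relative)
```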

3D Rotations: \(SE(3)\) and Molecular Attention

For \(G = \mathrm{SO}(3)\)[5] or \(E(3)\)[6], the end requires rotational (and possibly translational) invariance: \begin{equation} K(g q, g k) = K(q, k). \end{equation} Expanding in spherical harmonics gives \begin{equation} K(x, y) = \sum_{\ell=0}^L a_\ell P_\ell(x \cdot y), \end{equation} where \(P_\ell\) are Legendre polynomials or real spherical harmonics. These are the basis functions used in \(SE(3)\) Transformers and tensor field networks.
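
A numerical check of this invariance (a sketch with arbitrary coefficients \(a_\ell\)): because the kernel sees the directions only through their inner product, it is unchanged under any simultaneous rotation.

```python
import numpy as np
from numpy.polynomial.legendre import legval

def kernel(x, y, coeffs):
    """K(x, y) = sum_l a_l P_l(x_hat . y_hat), a rotation-invariant kernel on directions."""
    xh, yh = x / np.linalg.norm(x), y / np.linalg.norm(y)
    return legval(xh @ yh, coeffs)

def random_rotation(rng):
    """A random 3D rotation via QR decomposition (sign-fixed to det +1)."""
    Q, R = np.linalg.qr(rng.normal(size=(3, 3)))
    Q = Q @ np.diag(np.sign(np.diag(R)))
    return Q if np.linalg.det(Q) > 0 else -Q

rng = np.random.default_rng(0)
x, y = rng.normal(size=3), rng.normal(size=3)
a = np.array([1.0, 0.5, 0.25, 0.1])       # arbitrary coefficients a_0..a_3
R = random_rotation(rng)

# Invariance: K(R x, R y) == K(x, y), since (Rx).(Ry) = x.y.
assert np.isclose(kernel(R @ x, R @ y, a), kernel(x, y, a))
```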

Permutations: DeepSets and Graph Attention

For \(G = S_n\)[7], equivariance requires \begin{equation} K(\sigma q, \sigma k) = K(q, k). \end{equation} The general permutation-equivariant linear kernel is \begin{equation} K = \alpha I + \beta \mathbf{1}\mathbf{1}^\top, \end{equation} which is exactly the kernel used in DeepSets and permutation-equivariant graph networks. If we further restrict interactions by adjacency, we recover message passing.
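
The permutation case can be checked in a couple of lines (a sketch with arbitrary \(\alpha, \beta\)): conjugating \(K = \alpha I + \beta \mathbf{1}\mathbf{1}^\top\) by any permutation matrix leaves it fixed.

```python
import numpy as np

n, alpha, beta = 5, 0.7, 0.2
K = alpha * np.eye(n) + beta * np.ones((n, n))

rng = np.random.default_rng(0)
P = np.eye(n)[rng.permutation(n)]           # a random permutation matrix

# Permutation equivariance: K(sigma q, sigma k) = K(q, k), i.e. P K P^T = K.
assert np.allclose(P @ K @ P.T, K)
```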

Conclusion

Thus, we have that convolutions, rotary embeddings, \(SE(3)\) attention kernels, and permutation-equivariant (DeepSets/graph) kernels are not separate inventions -- each is the end over its declared symmetry group. Declare the symmetry and geometry, and the space of admissible attention kernels follows.

At some point, it'd probably be cool to write up some code that computes legal kernels given the symmetries that your data respects, but that's a project for later.

  1. \(SE(3)\) is the special Euclidean group in 3D: the group of rigid body motions (rotations + translations) in 3D space.
  2. Note the superscript \(C \in \mathcal{C}\) instead of the subscript in the end; this is how we denote duality and know that this is a coend not an end (similar to upper-lower index notation).
  3. We require differentiability so that backprop works.
  4. \(SO(2)\) is the special orthogonal group in 2D: the group of rotations in the plane.
  5. Similarly to \(SO(2)\), \(SO(3)\) is the special orthogonal group in 3D: the group of rotations in 3D space.
  6. \(E(3)\) is the Euclidean group in 3D: rotations, reflections, translations in 3D (similar to \(SE(3)\)). \(SE(3)\) is a subgroup of \(E(3)\).
  7. \(S_n\) is the symmetric group: the group of all permutations of \(n\) elements.