A Category Theoretic View of Machine Learning
Nov 02, 2025
Using symmetries to better understand how to design attention kernels
Overview
Deep learning models often bake in symmetry and locality through ad-hoc tricks (e.g., convolutional weight tying, RoPE embeddings, \(\text{SE(3)}\)1 kernels, etc.). Our goal is to see whether we can derive attention kernels from a declared symmetry and geometry, rather than hand-designing them.
Preliminaries
Categories
A category \(\mathcal C\) consists of
- Objects (the "things"), e.g., sets, vector spaces, etc., and
- Morphisms \(f: X \to Y\), the maps between objects,
together with an associative composition of morphisms and an identity morphism on every object.
Categories are particularly useful because they give us (i) a nice way to think about symmetry, and (ii) a way to "translate" problems from one domain to another via functors.
Functors
A functor \(F: \mathcal C \to \mathcal D\) maps objects to objects and morphisms to morphisms, preserving composition and identities. In our setting, a functor acts as an embedding that maps geometric inputs into feature spaces.
If this embedding respects symmetries (e.g., rotation, permutation), we call it an equivariant functor: \begin{equation} F(g \cdot x) = \rho(g)F(x), \end{equation} where \(\rho(g)\) is a tensor representation of the group action on outputs.
Natural Transformations
Two functors \(F, G: \mathcal C \to \mathcal D\) are related by a natural transformation \(\eta: F \Rightarrow G\) if for every object \(X\) there is a component \begin{equation} \eta_X: F(X) \to G(X) \end{equation} such that for every morphism \(f: X \to Y\), \begin{equation} G(f) \circ \eta_X = \eta_Y \circ F(f). \end{equation}
This is effectively equivalent to saying that changing the representation and then applying a map is the same as applying the map and then changing the representation.
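To make the naturality square concrete, here is a toy numerical check (my own example, not part of the original discussion): take \(F\) to be the identity functor on vector spaces, \(G\) the "doubling" functor \(G(X) = X \oplus X\), and \(\eta_X\) the diagonal embedding \(v \mapsto (v, v)\).

```python
import numpy as np

# Toy check of the naturality square G(f) ∘ η_X = η_Y ∘ F(f).
# F is the identity functor on vector spaces; G doubles a space, G(X) = X ⊕ X,
# acting on morphisms by G(f) = blockdiag(f, f). The components η_X(v) = (v, v)
# (diagonal embeddings) form a natural transformation η: F ⇒ G.

rng = np.random.default_rng(0)
f = rng.normal(size=(4, 3))            # a morphism f: X -> Y, with X = R^3, Y = R^4

def F(mat):                            # F on morphisms: the identity functor
    return mat

def G(mat):                            # G on morphisms: blockdiag(mat, mat)
    zero = np.zeros_like(mat)
    return np.block([[mat, zero], [zero, mat]])

def eta(dim):                          # component η_X: X -> X ⊕ X, v ↦ (v, v)
    return np.vstack([np.eye(dim), np.eye(dim)])

x = rng.normal(size=3)
lhs = G(f) @ (eta(3) @ x)              # change representation, then apply the map
rhs = eta(4) @ (F(f) @ x)              # apply the map, then change representation
assert np.allclose(lhs, rhs)           # the square commutes
```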
Ends
Given a functor \(F: \mathcal C^{\mathrm{op}} \times \mathcal C \to \mathcal D\), its end is an object \(\int_{C \in \mathcal C} F(C, C)\) compatible with \(F\) along every morphism of \(\mathcal C\). Intuitively, ends enforce symmetry consistency.
Coends
Dually, the coend \(\int^{C \in \mathcal C} F(C, C)\)2 glues the values of \(F\) together across all objects. Intuitively, coends enforce spatial consistency.
Fibrations
A fibration is a functor \(\pi: \mathcal{E} \to \mathcal{B}\) that organizes "fibers" of structured objects over a base category. Intuitively, it tells us how a family of parameter spaces (the total category \(\mathcal{E}\)) projects onto model configurations (the base \(\mathcal{B}\)).
Putting it All Together
Equipped with the basic tools (categories, functors, natural transformations, and (co)ends), we are now ready to begin using these tools to better understand machine learning itself through the lens of category theory.
The key theme of the following is more or less that learning is composition inside a structured category.
Computation as a Category
A deep network is a morphism in a computational category \(\mathcal C\): \begin{equation} f_\theta: X \to Y \end{equation} where objects are data/feature spaces and morphisms are differentiable3 maps parameterized by \(\theta\).
Then,
- composition corresponds to stacking layers \((h \circ g) (x) = h(g(x))\),
- identity corresponds to the skip connection,
- and associativity guarantees that we may regroup computations arbitrarily (the chain rule is associative composition).
Moreover, because differentiation (backprop) is itself a functor \(\mathrm{RD}: \mathcal C \to \mathcal C\), the category is closed under gradients. Thus, learning is not some new addition to the model -- it is a morphism living in the same category.
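Here is a minimal sketch of that claim (my own construction, assuming each layer is bundled with its vector-Jacobian product): composition stacks layers, and the composite's reverse derivative is just the pieces' VJPs composed in reverse order, so gradients never leave the category.

```python
import numpy as np

# Morphisms are differentiable maps bundled with their reverse-mode derivative (VJP).
# Composing morphisms stacks layers, and the chain rule composes the VJPs in
# reverse order -- backprop stays inside the same category.

class Morphism:
    def __init__(self, forward, vjp):
        self.forward = forward         # x ↦ f(x)
        self.vjp = vjp                 # (x, dy) ↦ J_f(x)^T dy

    def __matmul__(self, other):       # (h @ g)(x) = h(g(x))
        def forward(x):
            return self.forward(other.forward(x))
        def vjp(x, dy):                # chain rule: pull dy back through h, then g
            return other.vjp(x, self.vjp(other.forward(x), dy))
        return Morphism(forward, vjp)

# Two "layers": an affine map and an elementwise nonlinearity.
W = np.array([[1.0, 2.0], [0.5, -1.0]])
affine = Morphism(lambda x: W @ x, lambda x, dy: W.T @ dy)
tanh = Morphism(np.tanh, lambda x, dy: (1 - np.tanh(x) ** 2) * dy)

net = tanh @ affine                    # composition = stacking layers
x, dy = np.array([0.3, -0.7]), np.array([1.0, 0.0])
print(net.forward(x), net.vjp(x, dy))  # the gradient map is a morphism too
```

(Using `@` for composition is just a notational convenience here; the point is that the composite carries its own VJP built from those of its factors.)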
Symmetry and Equivariance as Functorial Structure
A symmetry group \(G\) acting on data \(X\) induces a category of representations, where morphisms represent the group action.
An equivariant layer is then just a functor that preserves the action: \begin{equation} F: \rm{Rep}(G,X) \to \rm{Rep}(G,Y) \end{equation} so that \begin{equation} F(g \cdot x) = \rho_Y(g)F(x). \end{equation}
Thus, CNNs, \(SE(3)\) Transformers and permutation-invariant networks are all functors that commute with their respective group actions. From this category theoretic view, we have the guarantee that a composition of equivariant functors is equivariant, so that equivariance is closed under composition.
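As a quick numerical check (a toy DeepSets-style example of my own, not taken from the architectures above), two permutation-equivariant layers compose into another permutation-equivariant map:

```python
import numpy as np

# Each layer f(x) = a·x + b·mean(x)·1 commutes with permutations,
# and so does the composition of two such layers.

rng = np.random.default_rng(0)

def layer(a, b):
    return lambda x: a * x + b * x.mean() * np.ones_like(x)

f = layer(1.5, -0.3)
g = layer(0.7, 2.0)
h = lambda x: g(f(x))                  # composition of equivariant layers

x = rng.normal(size=5)
sigma = rng.permutation(5)             # a group element of S_5, acting by indexing
assert np.allclose(h(x[sigma]), h(x)[sigma])   # h(σ·x) = σ·h(x)
```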
Universality
Every architecture embeds assumptions about what should remain the same (invariance) and what should fit together (locality). Category theory expresses these as universal constructions:
- The End \(\int_C F(C,C)\) is the “maximally consistent” object under all morphisms -- the space of global invariants.
- The Coend \(\int^C F(C,C)\) is the “minimal gluing” that merges overlapping structures.
Learning Dynamics as Fibration
Training is more than just optimization in parameter space -- it is a geometric process.
We may think of parameters \(\mathcal M\) as fibers sitting over models \(\Theta\) via a fibration \begin{equation} \pi: \mathcal M \to \Theta. \end{equation}
Each point in a fiber of \(\mathcal M\) (a specific parameterization) corresponds to the same function \(f_\theta\) in \(\Theta\). A good learning rule is then one that moves "horizontally" along this bundle so that it changes the function itself, not just the coordinates.
A cartesian (natural) update \begin{equation} T \circ \phi = \phi \circ T \end{equation} commutes with any reparametrization \(\phi\), meaning it is coordinate-free. This is the fundamental principle behind natural gradient and K-FAC methods.
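To see what "coordinate-free" buys us, here is a one-dimensional sketch (my own illustration; I'm assuming the model is the mean of a unit-variance Gaussian, so the Fisher information is \(1\) in the original coordinate and \(c^2\) after the reparameterization \(\theta = c\,u\)): the natural-gradient step yields the same updated function in either coordinate system, while the vanilla gradient step does not.

```python
import numpy as np

# Fit the mean θ of a unit-variance Gaussian to data with mean 2.0.
# Reparameterize θ = c·u and compare the updates mapped back to θ-coordinates.

c, lr = 3.0, 0.1
data_mean = 2.0

def grad_theta(theta):                 # d/dθ of the (average) negative log-likelihood
    return theta - data_mean

theta0 = 0.5
u0 = theta0 / c                        # same model, different coordinates

# Vanilla gradient: the result depends on the parameterization.
theta_vanilla = theta0 - lr * grad_theta(theta0)
theta_from_u = c * (u0 - lr * c * grad_theta(c * u0))        # chain rule: dL/du = c·dL/dθ
print(theta_vanilla, theta_from_u)     # different functions

# Natural gradient: precondition by the inverse Fisher (1 in θ-coordinates, c² in u-coordinates).
theta_natural = theta0 - lr * grad_theta(theta0)
theta_nat_from_u = c * (u0 - lr * (1 / c**2) * c * grad_theta(c * u0))
assert np.isclose(theta_natural, theta_nat_from_u)           # same function either way
```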
Gauge Theories
The categorical framework we just built closely parallels the structure of gauge theories in physics.
Fields, Bundles, and Equivariance
A field in physics assigns a quantity (e.g., a vector or tensor) to every point in space, subject to transformation laws under symmetries. Mathematically, this is a functor: \begin{equation} F: \text{(Spacetime)} \to \text{(Fields)} \end{equation} that maps each region to its field content and each coordinate transformation to the corresponding tensor transformation.
Equivariance in ML plays the same role: \begin{equation} F(g \cdot x) = \rho(g) F(x), \end{equation} where \(g\) is a symmetry (rotation, translation, permutation) and \(\rho(g)\) is its representation on the output space. An equivariant network is then a discrete field theory: a functor assigning representations to points on the data manifold.
Ends as Gauge Constraints
In a gauge theory, physical quantities are those that remain invariant under local transformations. An end \begin{equation} \int_{g \in G} F(g,g) \end{equation} plays exactly this role: it collects all quantities consistent under the action of every \(g \in G\). In our framework, the end defines the space of legal attention kernels -- those consistent with the declared symmetry. This is the categorical version of a gauge constraint.
Coends as Bundle Gluing
A gauge field is not defined globally but patched together from local charts related by transition functions. The coend \begin{equation} \int^{U_i \cap U_j} F(U_i, U_j) \end{equation} formalizes this gluing. This ensures smooth global behavior, much like a well-constructed vector bundle or gauge potential.
Fibration as Connection
In gauge theory, a connection specifies how to move (parallel transport) along a fiber bundle without leaving the gauge orbit. In our framework, a fibration \begin{equation} \pi: \mathcal M \to \Theta \end{equation} plays the same role: it organizes parameter space (\(\mathcal M\)) over model space (\(\Theta\)), and the cartesian update \begin{equation} T \circ \phi = \phi \circ T \end{equation} ensures that learning follows a connection that is natural -- invariant to reparameterization, just as covariant derivatives respect the gauge connection.
Deriving Known Kernels
This is cool and all, but without any application, it's not particularly useful. Here, we use the categorical language to derive some known results.
Namely, by recognizing that an end computes the space of all symmetry-consistent kernels, we may see that all equivariant attention mechanisms are instances of this same universal construction.
The End to Equivariant Kernels
Given a group \(G\) acting on an input manifold \(X\) and an output representation \(\rho_Y\) on \(Y\), the space of all admissible kernels is the end: \begin{equation} \mathsf{Ker}_G(X, Y) = \int_{g \in G} \mathcal{C}\left(g \cdot (X \times X),\, g \cdot \mathrm{End}(Y)\right). \end{equation}
In other words, any kernel \(K \in \mathsf{Ker}_G(X, Y)\) commutes with the symmetry action -- the categorical analog of a gauge-invariant kernel.
Translations: Convolutions and CNN Tying
For \(G = \mathbb{R}^2\) (translations on the image plane) acting by shifts \(g \cdot x = x + g\), equivariance requires \begin{equation} K(q + g, k + g) = K(q, k). \end{equation} Thus, \(K(q, k) = \kappa(k - q)\): the kernel depends only on relative position. This is exactly the weight-tying rule of convolution -- a convolution layer is simply an end over the translation group.
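Here is a small sketch of that tying (my own example, assuming circular boundary conditions to keep the indexing clean): the kernel matrix \(K(q, k) = \kappa(k - q)\) is circulant, applying it is exactly a convolution, and shifting the input shifts the output.

```python
import numpy as np

# A translation-tied attention kernel K(q, k) = κ(k − q) on a ring of n positions.

n = 8
rng = np.random.default_rng(0)
kappa = rng.normal(size=n)             # κ depends only on relative position
x = rng.normal(size=n)

K = np.array([[kappa[(k - q) % n] for k in range(n)] for q in range(n)])
attn_out = K @ x                       # attention with the tied kernel

# The same computation written as a (circular) convolution.
conv_out = np.array([sum(kappa[d] * x[(q + d) % n] for d in range(n)) for q in range(n)])
assert np.allclose(attn_out, conv_out)

# Equivariance: K(q + g, k + g) = K(q, k) means shifting the input shifts the output.
g = 3
assert np.allclose(K @ np.roll(x, g), np.roll(K @ x, g))
```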
Rotations: RoPE and \(SO(2)\) Equivariant Attention
For \(G = \mathrm{SO}(2)\)4, the end yields kernels invariant under rotation: \begin{equation} K(R_\phi q, R_\phi k) = R_\phi K(q, k) R_\phi^{-1}. \end{equation} Parameterizing \(K\) in polar coordinates gives \begin{equation} K(q, k) = \sum_{\ell=0}^{L} f_\ell(r) [\cos(\ell \phi), \sin(\ell \phi)], \end{equation} where \(r = \lvert k - q \rvert\) and \(\phi\) is the relative angle. This recovers rotary position embeddings (RoPE) and circular-harmonics kernels as the unique solutions to the equivariance constraint.
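A toy 2-D check (my own sketch with a single rotation frequency, rather than RoPE's full frequency spectrum): rotating queries and keys by their position angles gives a score that depends only on the relative position, and applying the same extra rotation to both arguments leaves the score unchanged.

```python
import numpy as np

# RoPE-style score in 2-D: positions act by planar rotations R_{mθ}, R_{nθ}.

def rot(phi):
    return np.array([[np.cos(phi), -np.sin(phi)],
                     [np.sin(phi),  np.cos(phi)]])

rng = np.random.default_rng(0)
q, k = rng.normal(size=2), rng.normal(size=2)
theta = 0.1                            # base frequency
m, n = 5, 12                           # token positions

score = (rot(m * theta) @ q) @ (rot(n * theta) @ k)
relative = q @ (rot((n - m) * theta) @ k)        # depends only on n − m
assert np.isclose(score, relative)

# Invariance under a common rotation of both arguments (the SO(2) constraint).
phi = 0.7
assert np.isclose((rot(phi) @ rot(m * theta) @ q) @ (rot(phi) @ rot(n * theta) @ k), score)
```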
3D Rotations: \(SE(3)\) and Molecular Attention
For \(G = \mathrm{SO}(3)\)5 or \(E(3)\)6, the end requires rotational (and possibly translational) invariance: \begin{equation} K(g q, g k) = K(q, k). \end{equation} Expanding in spherical harmonics gives \begin{equation} K(x, y) = \sum_{\ell=0}^L a_\ell P_\ell(x \cdot y), \end{equation} where \(P_\ell\) are Legendre polynomials or real spherical harmonics. These are the basis functions used in \(SE(3)\) Transformers and tensor field networks.
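A quick numerical check (my own sketch, using unit vectors and numpy's Legendre evaluation): a kernel of the form \(\sum_\ell a_\ell P_\ell(x \cdot y)\) only sees the angle between its arguments, so it is unchanged when both inputs are rotated together.

```python
import numpy as np

# K(x, y) = Σ_ℓ a_ℓ P_ℓ(x·y) is invariant under a common rotation of x and y.

rng = np.random.default_rng(0)
a = rng.normal(size=4)                 # coefficients a_0, ..., a_3

def kernel(x, y):
    return np.polynomial.legendre.legval(x @ y, a)

def random_orthogonal():
    Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # orthogonal, so dot products are preserved
    return Q

x = rng.normal(size=3); x /= np.linalg.norm(x)
y = rng.normal(size=3); y /= np.linalg.norm(y)
R = random_orthogonal()
assert np.isclose(kernel(R @ x, R @ y), kernel(x, y))   # K(gx, gy) = K(x, y)
```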
Permutations: DeepSets and Graph Attention
For \(G = S_n\)7, equivariance requires \begin{equation} K(\sigma q, \sigma k) = K(q, k). \end{equation} The unique linear form is \begin{equation} K = \alpha I + \beta \mathbf{1}\mathbf{1}^\top, \end{equation} which is exactly the kernel used in DeepSets and permutation-equivariant graph networks. If we further restrict interactions by adjacency, we recover message passing.
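And the corresponding check for permutations (a direct numerical verification I added): \(K = \alpha I + \beta \mathbf{1}\mathbf{1}^\top\) commutes with every permutation matrix, which is exactly the \(S_n\)-equivariance constraint in matrix form.

```python
import numpy as np

# K = αI + β·11ᵀ satisfies P K Pᵀ = K for every permutation matrix P.

n = 6
alpha, beta = 1.3, -0.4
K = alpha * np.eye(n) + beta * np.ones((n, n))

rng = np.random.default_rng(0)
P = np.eye(n)[rng.permutation(n)]      # a random permutation matrix
assert np.allclose(P @ K @ P.T, K)

# Equivalently, applying K then permuting equals permuting then applying K.
x = rng.normal(size=n)
assert np.allclose(K @ (P @ x), P @ (K @ x))
```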
Conclusion
Thus, we have that
- category theory gives the language (composition, functors, ends, coends, fibrations),
- physics supplies the intuition (symmetries, gluing, connections),
- and machine learning is the instantiation (attention, patches, optimization).
Footnotes
1. \(SE(3)\) is the special Euclidean group in 3D: the group of rigid body motions (rotations + translations) in 3D space. ↩
2. Note the superscript \(C \in \mathcal{C}\) instead of the subscript in the end; this is how we denote duality and know that this is a coend not an end (similar to upper-lower index notation). ↩
3. We require differentiability so that backprop works. ↩
4. \(SO(2)\) is the special orthogonal group in 2D: the group of rotations in the plane. ↩
5. Similarly to \(SO(2)\), \(SO(3)\) is the special orthogonal group in 3D: the group of rotations in 3D space. ↩
6. \(E(3)\) is the Euclidean group in 3D: rotations, reflections, translations in 3D (similar to \(SE(3)\)). \(SE(3)\) is a subgroup of \(E(3)\). ↩
7. \(S_n\) is the symmetric group: the group of all permutations of \(n\) elements. ↩