What really matters in matrix-whitening optimizers?

Kevin Frans

UC Berkeley
Oct 2025

In recent years, the rapid growth in model scale has created a strong need to understand how neural networks can be trained efficiently. One aspect of training is the choice of optimization strategy, often thought of as the proper transformation to apply to a raw gradient such that convergence is smoother and faster. The incumbent default is Adam; however, a range of optimizers have been proposed claiming superior performance.

Interestingly, the most performant optimizers all share a similar matrix-whitening transformation, which we will describe shortly, and can generally be derived from the same core principles. At the same time, these optimizers differ in their exact approximations and auxiliary implementation details. A natural question arises – which of these design choices truly influence downstream performance? Let's dissect them and take a closer look.

Descent on the Whitening Metric#

Gradient descent can be seen as solving for a trade-off between linear improvement and a distance penalty over parameters. While standard gradient descent assumes a Euclidean distance, we can generally represent second-order distances and their corresponding solutions via a symmetric positive-definite matrix \(M\):

\[ \Delta \theta = \text{argmin}_{\Delta\theta} \; \underbrace{\; -g^T\Delta\theta \;}_{\text{Improvement}} + \underbrace{(1/2)\Delta\theta^TM\Delta\theta}_{\text{Distance Penalty}} \quad = \quad M^{-1}g, \]

where \(g = \nabla_\theta \; L(\theta, x, y)\) and \(M^{-1}\) is sometimes referred to as a preconditioner.

While there are many possibilities for choosing \(M\), many recent optimizers have converged on a specific form which we refer to as the whitening metric:

\[ \begin{equation} \label{eq:whitening} M_\text{Whitening} = \mathbb{E}_{x,y} \left[ gg^T \right]^{1/2}. \end{equation} \]

Prior works have related this whitening metric to the Hessian and the Fisher information matrix. Adam can be understood as utilizing an elementwise (diagonal) approximation to the whitening metric, resulting in an efficient update with a per-parameter preconditioner \(m\):

\[ m = \mathbb{E}_{x,y} \left[ g^2\right]^{-1/2} \qquad \Delta \theta = g \odot m. \]
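
To make this concrete, here is a minimal numpy sketch of the elementwise update (names and the moving-average convention are illustrative; momentum and bias correction are omitted):

```python
import numpy as np

def diag_whitening_update(g, v, beta2=0.999, eps=1e-8):
    """Adam-style elementwise whitening: divide the gradient by a running
    estimate of E[g^2]^(1/2). Momentum and bias correction are omitted."""
    v = beta2 * v + (1 - beta2) * g**2   # running estimate of E[g^2]
    update = g / (np.sqrt(v) + eps)      # g * E[g^2]^{-1/2}, elementwise
    return update, v
```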

Matrix-whitening metric. Two powerful connections appear when we accept that in neural networks, parameters are structured matrices rather than an arbitrary collection of scalars. First, for dense layers with a gradient \(G \in \mathbb{R}^{m \times n}\), we can approximate the full whitening metric with its Kronecker factors, as done in the Shampoo formulation. This reduces an \((mn, mn)\) matrix inversion to cheaper \((m,m)\) and \((n,n)\) inversions:

\[ \mathbb{E}_{x,y}[gg^T]^{-1/2}g \quad \leftarrow \text{approx.} \rightarrow \quad \mathbb{E}_{x,y}[GG^T]^{-1/4} \; G \; \mathbb{E}_{x,y}[G^TG]^{-1/4}. \]

Second, if we ignore the expectation, the term above is equivalent to the orthogonalization of \(G\), i.e. the matrix sign of \(G\), where each of its nonzero singular values is set to \(1\). We can derive this by writing \(G\) in terms of its singular value decomposition, \(G = U \Sigma V^T\), after which:

\[ (GG^T)^{-1/4} \; G \; (G^TG)^{-1/4} = (U \Sigma^2 U^T)^{-1/4} \; U \Sigma V^T \; (V \Sigma^2 V^T)^{-1/4} = U \Sigma^{-1/2} \Sigma \Sigma^{-1/2} V^T = UV^T. \]

The view of matrix-whitening as a form of spectral normalization is insightful: we immediately see that the resulting update descends uniformly along each singular direction. The orthogonal update \(UV^T\) is in fact the steepest-descent direction under the spectral norm.
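
As a quick sanity check, the identity above can be verified numerically. A throwaway numpy snippet (not part of any optimizer; a square \(G\) is assumed so that \(GG^T\) is invertible):

```python
import numpy as np

def matpow(A, p):
    # Fractional power of a symmetric positive-definite matrix via eigendecomposition.
    w, Q = np.linalg.eigh(A)
    return Q @ np.diag(w ** p) @ Q.T

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 4))          # square, full rank with probability 1

U, s, Vt = np.linalg.svd(G)
orth_G = U @ Vt                          # orthogonalization: singular values -> 1

whitened = matpow(G @ G.T, -0.25) @ G @ matpow(G.T @ G, -0.25)
print(np.allclose(whitened, orth_G, atol=1e-6))   # True, up to numerical precision
```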



Figure: Effect of taking a descent step under different metrics (gradient descent, signed descent, spectral descent). Weights are initialized as a 2D identity matrix, represented via the unit circle.


As shown above, spectral descent (and accordingly, matrix-whitening) has a number of desirable properties. Spectral normalization amplifies update directions which, while often helpful, are "small" under the raw gradient. For certain losses – specifically, when the gradient is an orthogonal transformation of the input features – spectral descent results in an update that is invariant to the specific distribution of inputs. The same is not true for raw gradient descent or elementwise signed descent.

Practically, there are a variety of concrete algorithms that approximate descent on the matrix-whitening metric, which we generally refer to as matrix-whitening optimizers. Three common classes approximate the whitening transformation in different ways:

Shampoo + (SOAP, SPlus)

Shampoo works by explicitly tracking the left/right factors \(E[GG^T]\), \(E[G^TG]\) using a moving average:

\[\begin{split} L \leftarrow (1-\beta)L + \beta GG^T \\ R \leftarrow (1-\beta)R + \beta G^TG \end{split}\]

Every \(N\) iterations, these factors are taken to the \((-1/4)\) power, and the results are cached as \(P_L\) and \(P_R\):

\[\begin{split} P_L = \text{matpow}(L, -1/4) \\ P_R = \text{matpow}(R, -1/4) \end{split}\]

The update is calculated as:

\[ \Delta \theta = P_L \; G \; P_R. \]
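
A minimal sketch of this procedure for a single dense layer follows the three equations above (names are illustrative, not the reference implementation; momentum, grafting, blocking, and other practical details are omitted):

```python
import numpy as np

def matpow(A, p, ridge=1e-8):
    # Fractional power of a symmetric PSD matrix via eigendecomposition,
    # with a small ridge for numerical stability.
    w, Q = np.linalg.eigh(A + ridge * np.eye(A.shape[0]))
    return Q @ np.diag(np.maximum(w, ridge) ** p) @ Q.T

def shampoo_step(G, L, R, P_L, P_R, step, beta=0.05, precond_every=10):
    """One Shampoo step for a single dense layer (sketch only)."""
    # Moving averages of the left/right Kronecker factors.
    L = (1 - beta) * L + beta * (G @ G.T)
    R = (1 - beta) * R + beta * (G.T @ G)
    # Refresh the cached (-1/4) powers only every `precond_every` steps
    # (step 0 computes them for the first time).
    if step % precond_every == 0:
        P_L = matpow(L, -0.25)
        P_R = matpow(R, -0.25)
    update = P_L @ G @ P_R
    return update, L, R, P_L, P_R
```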
PSGD-Fisher

PSGD works by fitting \(P\) to the following objective via relative gradient descent:

\[ \text{cost}(P) = \mathbb{E} \left[ g^T P g \right] + \text{tr}\left( P^{-1} \right), \]

which is minimized exactly at the whitening preconditioner:

\[ P^{*} = E \left[ gg^T \right]^{-1/2}. \]

In practice, \(P\) is constrained to be symmetric, and additionally factored as:

\[ P = (Q_L^T Q_L) \otimes (Q_R ^T Q_R) \]

after which the update is calculated as:

\[ \Delta \theta = (Q_L^T Q_L) \; G \; (Q_R ^T Q_R). \]
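
Applying the factored preconditioner is cheap; a sketch of the update alone (the fitting of \(Q_L, Q_R\) via relative gradient descent is omitted, and names are illustrative):

```python
import numpy as np

def psgd_kron_update(G, Q_L, Q_R):
    """Apply the factored preconditioner P = kron(Q_L^T Q_L, Q_R^T Q_R)
    to a matrix gradient G."""
    return (Q_L.T @ Q_L) @ G @ (Q_R.T @ Q_R)
```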
Muon

Muon works by directly orthogonalizing the gradients, asserting:

\[ \Delta \theta = UV^T = \text{orth}(G). \]

The specific trick to avoid an expensive SVD at every iteration is to use Newton-Schulz iterations to calculate this orthogonalization:

\[ G_{t+1} \leftarrow \dfrac{3}{2} G_t - \dfrac{1}{2} G_t G_t^T G_t, \]

which converges in the limit to

\[ \Delta \theta = G_{\infty} = UV^T. \]
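
A minimal sketch of this orthogonalization using the cubic iteration above (the gradient is first scaled so its singular values fall inside the iteration's region of convergence; Muon in practice uses a tuned higher-order polynomial for faster convergence):

```python
import numpy as np

def newton_schulz_orth(G, n_iters=10):
    """Approximate orth(G) = U V^T with the cubic Newton-Schulz iteration.
    Dividing by the Frobenius norm bounds all singular values by 1,
    which keeps the iteration in its region of convergence."""
    X = G / (np.linalg.norm(G) + 1e-8)
    for _ in range(n_iters):
        X = 1.5 * X - 0.5 * (X @ X.T @ X)
    return X
```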

Performance is not solely explained by accurate spectral descent#

Given the tight relationship between descent under the matrix-whitening metric and the spectral norm, it is tempting to identify accurate spectral descent as the "true" aim of matrix whitening. However, experiments show that this is not necessarily the case.

First, we find that across the optimizer families, Muon achieves the most accurate spectral descent direction, as shown by a spread of update singular values at almost exactly \(1\). In contrast, methods such as SOAP achieve a looser spread of around \(2\) to \(3\), and Adam's spread is around \(12\). It is reasonable that Shampoo-style optimizers (and to a greater extent SOAP) result in a less-faithful orthogonalization of the update than Muon, as the preconditioner is calculated using a historical buffer of gradients.

However, when running a carefully tuned comparison between these methods, we find that SOAP displays the largest per-step gains, and this gain is consistent across variations in hyperparameters. We use the same data and network architecture for every method, and independently sweep over learning rate, weight decay, beta1, and beta2.

Figure: Carefully tuned per-step comparison across optimizer families.

In other words, our hypothesis that accurate descent under the spectral norm is the full picture has been countered by the following observation: while Muon descends under the spectral norm more accurately than SOAP, it is SOAP that reliably achieves the greatest per-step improvement.

So, what explains the gap in performance?

Variance adaptation is a crucial matrix-whitening ingredient#

We find that an answer lies in a design choice present in many optimizers – regardless of the transformations applied before or after, a raw update can be normalized either 1) by its instantaneous sign, or 2) by the square root of its historical (uncentered) variance, which we refer to as variance adaptation. This distinction can be made explicit by considering three pairs of optimizers:

\[\begin{split} \begin{align} \textbf{Signum:} \; \text{sign($\bar{g}$)} & \quad \rightarrow \quad \textbf{Adam:} \; \bar{g} \; \odot \; \mathbb{E}[g^2]^{-1/2} \\ \textbf{SPlus:} \;\text{unrot(sign(rot($\bar{g}$)))} & \quad \rightarrow \quad \textbf{SOAP:} \; \text{unrot(rot($\bar{g}$)} \; \odot \; \mathbb{E}[\text{rot}(g)^2]^{-1/2}) \\ \textbf{Muon:} \; \text{NS($\bar{g}$)} &\quad \rightarrow \quad \textbf{AdaMuon:} \; \text{NS($\bar{g}$)} \; \odot \; \mathbb{E}[\text{NS}(\bar{g})^2]^{-1/2} \end{align} \end{split}\]

For each pair, the same basis-projection behavior is used (e.g. an identity basis, a rotated eigenbasis, or the implicit Newton-Schulz basis), but the elementwise normalizations are handled differently. Note that the Newton-Schulz operator of Muon is implicitly a signed-descent method, as it approximates the orthogonalization of \(\bar{g}\) such that all singular values are \(1\).
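
In code, the contrast for a single normalization step looks roughly like this (a sketch with illustrative names; \(u\) is the update after whatever basis projection the optimizer applies):

```python
import numpy as np

def normalize_signed(u):
    # Option 1: instantaneous sign of the (possibly basis-projected) update.
    return np.sign(u)

def normalize_variance_adapted(u, v, beta2=0.999, eps=1e-8):
    # Option 2: divide by the square root of a running uncentered second moment.
    v = beta2 * v + (1 - beta2) * u**2
    return u / (np.sqrt(v) + eps), v
```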

Figure: Signed-descent optimizers (Signum, SPlus, Muon) versus their variance-adapted counterparts (Adam, SOAP, AdaMuon).

The trend is clear: variance-adapted variants of optimizers outperform their strictly signed-descent counterparts. The performance difference is nontrivial – for example, the gap between Muon and AdaMuon is almost as large as the gap between Adam and Muon itself, indicating that variance adaptation is roughly as important as the spectral-descent aspect of matrix whitening (on our setup).

Notably, variance adaptation is a natural consequence of the whitening metric, but equivalences between matrix-whitening methods and spectral descent often rely on "disabling the accumulation" and drawing an equivalence to signed descent, which may not capture the full picture. In fact, comparing Adam and Muon may understate the gains from Newton-Schulz orthogonalization; a more fine-grained comparison would be Signum vs. Muon, or Adam vs. AdaMuon. Alternate optimizers that focus solely on orthogonalizing updates will likely benefit from reintroducing variance adaptation in some form.

Figure: A raw gradient update, after spectral normalization, and after adding variance adaptation.

Decoupled variance adaptation. In SOAP, variance adaptation is performed in the rotated eigenbasis, as a pure alternative to signed descent; the same exchange occurs between Adam and Signum. However, in the AdaMuon setup, variance adaptation is performed in the original elementwise basis, after the update has already been spectrally normalized. In both cases, variance adaptation provides a reliable performance boost. One interpretation that may explain this phenomenon is to view variance adaptation as a heuristic for dynamically adjusting a trust region in proportion to a signal-to-noise ratio. When \(\beta_1=\beta_2\), Adam can be rewritten as:

\[ \textbf{Adam:} \quad \text{sign}(\bar{g}) \; \cdot \; (1 + \bar{\sigma}^2/\bar{g}^2)^{-1/2} \]
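
To spell out the rewrite: with \(\beta_1 = \beta_2\), the first- and second-moment estimates share the same averaging horizon, so elementwise \(\mathbb{E}[g^2] \approx \bar{g}^2 + \bar{\sigma}^2\), and the Adam update factors as

\[ \bar{g} \odot \mathbb{E}[g^2]^{-1/2} \approx \frac{\bar{g}}{(\bar{g}^2 + \bar{\sigma}^2)^{1/2}} = \text{sign}(\bar{g}) \cdot \left( 1 + \bar{\sigma}^2/\bar{g}^2 \right)^{-1/2}. \]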

Under this interpretation, the variance adaptation term serves as a dynamic adjustment to our distance metric, and does not necessarily have to share the same basis as the sign term. We can therefore describe a wider class of matrix-whitening algorithms with the following form:

\[ \textbf{Matrix-Whitened Update:} \quad \underbrace{\text{orth}(\bar{g})}_{\text{Spectral Normalization}} \; \circ \; \; \underbrace{(1 + \tilde{\sigma}^2/\tilde{g}^2)^{-1/2}}_{\text{Variance Adaptation}} \]

where \(\tilde{g}\) and \(\tilde{\sigma}\) can be computed in a basis of one's choosing. In other words, we can decouple the spectral-normalizing and variance-adapting properties of matrix-whitening, and implement each in a compute-efficient way.
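
As one illustration of this decoupling, here is a sketch that spectrally normalizes the averaged gradient and applies elementwise variance adaptation on top (one possible basis choice; the SVD is used for clarity where Newton-Schulz would be used in practice, and all names are ours):

```python
import numpy as np

def decoupled_whitened_update(g_bar, v, beta2=0.999, eps=1e-8):
    """Spectral normalization of the averaged gradient, rescaled elementwise
    by a variance-adaptation factor (1 + sigma^2 / g_bar^2)^{-1/2}."""
    # Spectral normalization: orth(g_bar) = U V^T.
    U, _, Vt = np.linalg.svd(g_bar, full_matrices=False)
    orth = U @ Vt
    # Variance adaptation in the elementwise basis.
    v = beta2 * v + (1 - beta2) * g_bar**2        # running E[g^2]
    sigma2 = np.maximum(v - g_bar**2, 0.0)        # variance estimate: E[g^2] - g_bar^2
    adapt = (1.0 + sigma2 / (g_bar**2 + eps)) ** -0.5
    return orth * adapt, v
```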

Discussion#

  • This blog post goes along with our paper, "What really matters in matrix-whitening optimizers?". The paper contains a larger set of experiments, including a carefully tuned benchmark of proposed optimizer families. We encourage readers to try out changes using our evaluation suite, and we have open-sourced the codebase along with tuned hyperparameters for each family.

  • The question of which basis to perform variance-adaptation in is a promising open problem. Should we adapt on an elementwise basis, or along a rotated eigenbasis, or something else?

  • On a more fundamental level, the question of where the variance comes from also remains unclear. Do performance gains come from adapting to variance due to the stochasticity of a random batch, or to variance over time due to oscillations in parameters?

  • A challenge. It is clear that matrix-whitening methods reliably outperform Adam, to a degree of around a 30% speedup in our setting. Within flavors of such optimizers, performance can vary, but only by an additional 5-10%. (See the paper for exact numbers.) What would it take to achieve double these gains? Does a strategy exist that does not follow a matrix-whitening approach?