SOAP: Improving and Stabilizing Shampoo Using Adam



Introduction

The success of Shampoo has drawn increasing attention from the deep learning community. Several works have explored ways to scale Shampoo by improving its memory and compute efficiency (Wang et al., 2024; Anil et al., 2020; Shi et al., 2023). Other research (Morwani et al., 2024) has examined the theoretical foundations of Shampoo and proposed minor adjustments (such as using power $1/2$ rather than $1/4$) that align with prior empirical findings (Anil et al., 2020).

We study SOAP (ShampoO with Adam in the Preconditioner's eigenbasis), an algorithm that runs AdamW in the eigenbasis provided by Shampoo. Our main contributions are as follows:

  • We make a formal connection between the Shampoo and Adafactor algorithms. This insight leads us to consider the SOAP algorithm, which runs AdamW in the preconditioned space provided by Shampoo.
  • SOAP outperforms both Shampoo and Adam on language model pre-training tasks with 360m and 660m parameter models, even after extensive hyperparameter tuning of Shampoo.
  • SOAP reduces the number of hyperparameters compared to Shampoo, resulting in only one additional hyperparameter compared to AdamW: preconditioning frequency.
  • SOAP demonstrates greater robustness to large preconditioning frequency compared to Shampoo on language model pre-training tasks.

Notation and background

We denote the weight matrix of a neural network layer by $W \in \mathbb{R}^{m \times n}$, and the corresponding gradient by $G \in \mathbb{R}^{m \times n}$. At a given time step $t$, these are denoted $W_t$ and $G_t$, respectively. For a batch of inputs at time $t$, denoted $B_t$, the loss and its gradient evaluated at $W_t$ are written $\phi_{B_t}(W_t)$ and $\nabla_W \phi_{B_t}(W_t)$, respectively.

Adam (Kingma & Ba, 2015), a widely used first-order optimization algorithm in deep learning, is a diagonal approximation of Adagrad. It maintains an exponential moving average of the gradients $G_t$ (denoted $M_t$) and of the element-wise squared gradients $G_t^2$ (denoted $V_t$) for a given weight matrix $W$. Its update rule with learning rate $\eta$ is given by

$W_t \leftarrow W_t - \eta M_t / \sqrt{V_t}$

where the division is performed element-wise.
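The update above can be written as a minimal NumPy sketch; the bias correction of the full Adam algorithm is omitted to match the simplified rule shown, and the small `eps` added to the denominator is an assumption for numerical stability, not part of the equation above.

```python
import numpy as np

def adam_step(W, G, M, V, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One simplified Adam update for a weight matrix W given gradient G.

    M and V are the exponential moving averages of the gradient and of the
    element-wise squared gradient; all operations are element-wise.
    """
    M = beta1 * M + (1 - beta1) * G       # first-moment EMA of G_t
    V = beta2 * V + (1 - beta2) * G**2    # second-moment EMA of G_t^2
    W = W - lr * M / (np.sqrt(V) + eps)   # element-wise preconditioned step
    return W, M, V
```

Because $V_t$ is stored element-wise, the preconditioner is diagonal: each parameter gets its own scalar step size.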

Adafactor (Shazeer & Stern, 2018; Zhai et al., 2022), a variant of Adam, replaces $V_t$ with its best rank-1 approximation $V'_t$ to reduce memory usage. While the original Adafactor paper (Shazeer & Stern, 2018) proposed additional modifications, such as changes to the learning rate schedule, we focus on the version of Adafactor proposed in recent works (Zhai et al., 2022; Zhao et al., 2024c), whose update with learning rate $\eta$ is given by

$W_t \leftarrow W_t - \eta M_t / \sqrt{V'_t}$
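Adafactor's rank-1 approximation is built from the row and column sums of $V_t$, so only $m + n$ accumulators are stored instead of $mn$ entries. A sketch of the reconstruction, assuming $V$ has nonnegative entries (it is a mean of squared gradients, so this holds):

```python
import numpy as np

def factored_second_moment(V):
    """Adafactor-style rank-1 factorization of the second-moment matrix V.

    Reconstructs V' = (row_sums @ col_sums) / total_sum, the rank-1
    approximation that Adafactor stores via per-row and per-column statistics.
    """
    r = V.sum(axis=1, keepdims=True)   # (m, 1) row accumulator
    c = V.sum(axis=0, keepdims=True)   # (1, n) column accumulator
    return (r @ c) / V.sum()           # rank-1 reconstruction of V
```

When $V$ is itself a nonnegative rank-1 matrix, the reconstruction is exact; otherwise it is an approximation.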

Shampoo (Gupta et al., 2018b) is a second-order optimization algorithm that approximates Adagrad and maintains two preconditioners, $L_t \in \mathbb{R}^{m \times m}$ and $R_t \in \mathbb{R}^{n \times n}$, for a given weight matrix $W \in \mathbb{R}^{m \times n}$. The updates for the preconditioners and the weights with learning rate $\eta$ are as follows:

$L_t \leftarrow L_t + G_t G_t^T, \quad R_t \leftarrow R_t + G_t^T G_t, \quad W_{t+1} \leftarrow W_t - \eta L_t^{-1/4} G_t R_t^{-1/4}$
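The update above can be sketched in a few lines of NumPy. Computing the matrix power via an eigendecomposition is standard for symmetric PSD accumulators; the eigenvalue floor `eps` is an assumption added for numerical stability (practical implementations instead add a damping term $\epsilon I$).

```python
import numpy as np

def matrix_power(A, p, eps=1e-12):
    """Fractional power of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(A)
    return vecs @ np.diag(np.maximum(vals, eps) ** p) @ vecs.T

def shampoo_step(W, G, L, R, lr=0.1):
    """One Shampoo update: accumulate L and R, precondition G on both sides."""
    L = L + G @ G.T   # left preconditioner accumulator (m x m)
    R = R + G.T @ G   # right preconditioner accumulator (n x n)
    W = W - lr * matrix_power(L, -0.25) @ G @ matrix_power(R, -0.25)
    return W, L, R
```

Unlike Adam's diagonal $V_t$, the pair $(L_t, R_t)$ captures correlations across rows and columns of the gradient at $O(m^2 + n^2)$ memory cost.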

Algorithms

We begin by describing an equivalence between Shampoo and running Adafactor in the eigenbasis of the Shampoo preconditioner. For simplicity, we omit momentum, but the equivalence also holds with momentum. For this equivalence, we use Shampoo with the following modifications of the original Shampoo optimizer (Gupta et al., 2018b):

  • We use power $1/2$ instead of power $1/4$. This was already recommended in practical implementations (Anil et al., 2020; Shi et al., 2023), and a theoretical connection between the optimal Kronecker approximation of the Adagrad (Duchi et al., 2011b) preconditioner and Shampoo with power $1/2$ was established in Morwani et al. (2024).
  • We also use the scalar correction to per-layer learning rates described in Ren & Goldfarb (2021); Morwani et al. (2024).
  • Instead of the running average of LL and RR across time steps, we use dataset averages.

Algorithm 1 Single step of idealized Shampoo with power $1/2$

  • Sample batch $B_t$.
  • $G_t \in \mathbb{R}^{m \times n} \leftarrow \nabla_W \phi_{B_t}(W_t)$
  • $L \leftarrow \mathbb{E}_B [G_B G_B^T]$ (where the expectation is over a random batch $B$)
  • $R \leftarrow \mathbb{E}_B [G_B^T G_B]$
  • $\hat{H} \leftarrow (L \otimes R) / \operatorname{tr}(L)$
  • $W_{t+1} \leftarrow W_t - \eta \hat{H}^{-1/2} G_t = W_t - \eta \operatorname{tr}(L)^{1/2} L^{-1/2} G_t R^{-1/2}$
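Algorithm 1 can be sketched as follows, assuming the expectation $\mathbb{E}_B$ is approximated by averaging over a list of sampled batch gradients (in a practical optimizer, running averages would replace this dataset average). The last line uses the identity $\hat{H}^{-1/2} G = \operatorname{tr}(L)^{1/2}\, L^{-1/2} G R^{-1/2}$ for $\hat{H} = (L \otimes R)/\operatorname{tr}(L)$.

```python
import numpy as np

def matrix_power(A, p, eps=1e-12):
    """Fractional power of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(A)
    return vecs @ np.diag(np.maximum(vals, eps) ** p) @ vecs.T

def idealized_shampoo_step(W, G, batch_grads, lr=0.1):
    """One step of idealized Shampoo with power 1/2 and trace correction.

    `batch_grads` is a list of gradients from random batches standing in
    for the expectations E_B[G_B G_B^T] and E_B[G_B^T G_B].
    """
    L = sum(g @ g.T for g in batch_grads) / len(batch_grads)   # E_B[G_B G_B^T]
    R = sum(g.T @ g for g in batch_grads) / len(batch_grads)   # E_B[G_B^T G_B]
    # H^{-1/2} G = tr(L)^{1/2} L^{-1/2} G R^{-1/2}
    update = np.trace(L) ** 0.5 * matrix_power(L, -0.5) @ G @ matrix_power(R, -0.5)
    return W - lr * update
```

The trace factor is the per-layer scalar correction from the modifications listed above; without it the Kronecker factorization of $\hat{H}$ would be off by a constant.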
