Optimizing Neural Networks with Kronecker-factored Approximate Curvature


1 Introduction

We know that the curvature associated with neural network objective functions is highly non-diagonal, and that updates which properly respect and account for this non-diagonal curvature, such as those generated by HF, can make much more progress minimizing the objective than the plain gradient or updates computed from diagonal approximations of the curvature (usually $\sim 10^2$ HF updates are required to adequately minimize most objectives, compared to the $\sim 10^4 - 10^5$ required by methods that use diagonal approximations). Thus, if we had an efficient and direct way to compute the inverse of a high-quality non-diagonal approximation to the curvature matrix (i.e. without relying on first-order methods like CG), this could potentially yield an optimization method whose updates would be large and powerful like HF's, while being (almost) as cheap to compute as the stochastic gradient.

In this work we develop such a method, which we call Kronecker-factored Approximate Curvature (K-FAC). We show that our method can be much faster in practice than even highly tuned implementations of SGD with momentum on certain standard neural network optimization benchmarks.

The main ingredient in K-FAC is a sophisticated approximation to the Fisher information matrix, which despite being neither diagonal nor low-rank, nor even block-diagonal with small blocks, can be inverted very efficiently, and can be estimated in an online fashion using arbitrarily large subsets of the training data (without increasing the cost of inversion).

2 Background and notation

2.1 Neural networks

For convenience we will define the following additional notation:

\mathcal{D} v = \frac{d\mathcal{L}(y,f(x,\theta))}{dv} = -\frac{d\log p(y | x,\theta)}{dv} \quad \text{and} \quad g_i = \mathcal{D} s_i
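
(To make the notation concrete: the paper's feed-forward setup, not reproduced in this excerpt, has $s_i = W_i \bar{a}_{i-1}$ and $a_i = \phi_i(s_i)$, where $\bar{a}_{i-1}$ is the previous layer's activation vector with a homogeneous coordinate of 1 appended so the bias is folded into $W_i$. The sketch below, written under those assumptions with a tanh hidden layer and a softmax output, shows how $\bar{a}_i$, $g_i = \mathcal{D}s_i$, and $\mathcal{D}W_i = g_i \bar{a}_{i-1}^T$ arise from a single forward/backward pass; all concrete choices here are illustrative.)

```python
import numpy as np

rng = np.random.default_rng(5)

# Assumed (standard) feed-forward setup: s_i = W_i a_bar_{i-1}, a_i = tanh(s_i),
# with a_bar appending a constant 1 so the bias is folded into W_i.
sizes = [4, 5, 3]                                # input dim, hidden dim, output dim
Ws = [rng.normal(size=(sizes[i + 1], sizes[i] + 1)) for i in range(2)]

x = rng.normal(size=4)
y = 1                                            # an observed class label

# Forward pass, recording each a_bar_{i-1} and s_i.
a = x
a_bars, ss = [], []
for W in Ws:
    a_bar = np.append(a, 1.0)
    s = W @ a_bar
    a_bars.append(a_bar)
    ss.append(s)
    a = np.tanh(s) if W is not Ws[-1] else s     # last layer stays linear (logits)

# Loss L = -log p(y | x, theta) with a softmax output.
logits = ss[-1]
p = np.exp(logits - logits.max()); p /= p.sum()

# Backward pass: g_i = D s_i, and DW_i = g_i a_bar_{i-1}^T.
g = p.copy(); g[y] -= 1.0                        # D s for the output layer (softmax cross-entropy)
gs, DWs = [None] * 2, [None] * 2
for i in reversed(range(2)):
    gs[i] = g
    DWs[i] = np.outer(g, a_bars[i])
    if i > 0:
        # Chain rule back through W_i and the tanh nonlinearity of the layer below.
        g = (Ws[i].T @ g)[:-1] * (1.0 - np.tanh(ss[i - 1]) ** 2)

print([DW.shape for DW in DWs])                  # one DW_i per weight matrix
```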

2.2 The natural gradient

F = \mathbb{E}\left[\frac{d\log p(y | x,\theta)}{d\theta} \left(\frac{d\log p(y | x,\theta)}{d\theta}\right)^T\right] = \mathbb{E}[\mathcal{D}\theta \mathcal{D}\theta^T]
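
In practice $F$ can be estimated by averaging outer products of $\mathcal{D}\theta$ over training inputs, with the targets $y$ sampled from the model's own predictive distribution (this is what distinguishes the Fisher from the "empirical Fisher", which uses the observed targets). Below is a minimal Monte Carlo sketch for a toy softmax-regression model; the model and every name in it are illustrative, not taken from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy softmax-regression model p(y | x, theta), where theta is a C x d weight matrix.
d, C, N = 5, 3, 2000
W = rng.normal(size=(C, d))
X = rng.normal(size=(N, d))

F = np.zeros((C * d, C * d))
for x in X:
    z = W @ x
    p = np.exp(z - z.max()); p /= p.sum()   # predictive distribution p(y | x, theta)
    y = rng.choice(C, p=p)                  # sample y from the model's own distribution
    # Dtheta = -d log p(y | x, theta)/dW, flattened; for softmax this is (p - e_y) x^T.
    dz = p.copy(); dz[y] -= 1.0
    Dtheta = np.kron(dz, x)                 # row-major vec of the C x d gradient matrix
    F += np.outer(Dtheta, Dtheta)
F /= N                                      # Monte Carlo estimate of E[Dtheta Dtheta^T]

print(F.shape)                              # (15, 15)
```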

3 A block-wise Kronecker-factored Fisher approximation

The main computational challenge associated with using the natural gradient is computing $F^{-1}$ (or its product with $\nabla h$). In this section we develop an initial approximation of $F$ which will be a key ingredient in deriving our efficiently computable approximation to $F^{-1}$ and the natural gradient.

Noting that

\mathcal{D}\theta = \begin{pmatrix} \operatorname{vec}(\mathcal{D}W_1)^T, \operatorname{vec}(\mathcal{D}W_2)^T, \dots, \operatorname{vec}(\mathcal{D}W_l)^T \end{pmatrix}^T

and

F = \mathbb{E}[\mathcal{D}\theta \mathcal{D}\theta^T],

we see that $F$ can be viewed as an $l$ by $l$ block matrix, with the $(i, j)$-th block $F_{(i,j)}$ given by

F_{(i,j)} = \mathbb{E}[\operatorname{vec}(\mathcal{D}W_i) \operatorname{vec}(\mathcal{D}W_j)^T].

Noting that $\mathcal{D}W_i = g_i \bar{a}_{i-1}^T$ and $\operatorname{vec}(u v^T) = v \otimes u$, we have

\begin{aligned}
F_{(i,j)} &= \mathbb{E}[\operatorname{vec}(\mathcal{D}W_i) \operatorname{vec}(\mathcal{D}W_j)^T]\\
&= \mathbb{E}[(\bar{a}_{i-1} \otimes g_i)(\bar{a}_{j-1} \otimes g_j)^T]\\
&= \mathbb{E}[(\bar{a}_{i-1} \bar{a}_{j-1}^T) \otimes (g_i g_j^T)].
\end{aligned}
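
(The last step uses the mixed-product property of the Kronecker product. Since the identity holds for each training case individually, it can be checked numerically; the sketch below uses random vectors as stand-ins for $\bar{a}_{i-1}$ and $g_i$, and column-major vectorization, which is the convention under which $\operatorname{vec}(u v^T) = v \otimes u$.)

```python
import numpy as np

rng = np.random.default_rng(1)

# Random stand-ins for \bar{a}_{i-1} (layer input with homogeneous coordinate)
# and g_i = Ds_i, for a single training case.
a_bar = rng.normal(size=4)
g = rng.normal(size=3)

DW = np.outer(g, a_bar)                      # DW_i = g_i a_bar_{i-1}^T

# Column-major vec: the convention under which vec(u v^T) = v (x) u.
vecDW = DW.flatten(order="F")

lhs = np.outer(vecDW, vecDW)                               # vec(DW) vec(DW)^T
rhs = np.kron(np.outer(a_bar, a_bar), np.outer(g, g))      # (a a^T) (x) (g g^T)

print(np.allclose(lhs, rhs))                 # True: the per-case block is exactly Kronecker-factored
```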

Recall also the identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$: the inverse of a Kronecker product is the Kronecker product of the inverses, which is what makes Kronecker-factored blocks cheap to invert.
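
Combined with the standard identity $(A \otimes B)\operatorname{vec}(X) = \operatorname{vec}(B X A^T)$, this means that multiplying a vector by the inverse of a Kronecker product never requires forming, or inverting, the full matrix. A minimal numerical illustration (all names are placeholders, and $A$ is taken symmetric so that $A^{-T} = A^{-1}$):

```python
import numpy as np

rng = np.random.default_rng(2)

n, m = 6, 4
A = np.cov(rng.normal(size=(n, 50)))       # small symmetric positive-definite factors
G = np.cov(rng.normal(size=(m, 50)))
V = rng.normal(size=(m, n))                # a gradient-shaped m x n matrix

# Direct route: build and invert the full (m*n) x (m*n) Kronecker product.
direct = np.linalg.solve(np.kron(A, G), V.flatten(order="F"))

# Factored route: (A (x) G)^{-1} vec(V) = vec(G^{-1} V A^{-1}) for symmetric A,
# so only the small factors are ever inverted.
factored = (np.linalg.solve(G, V) @ np.linalg.inv(A)).flatten(order="F")

print(np.allclose(direct, factored))       # True
```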

Our initial approximation $\tilde{F}$ to $F$ will be defined by the following block-wise approximation:

F_{(i,j)} = \mathbb{E}[(\bar{a}_{i-1} \bar{a}_{j-1}^T) \otimes (g_i g_j^T)] \approx \mathbb{E}[\bar{a}_{i-1} \bar{a}_{j-1}^T] \otimes \mathbb{E}[g_i g_j^T] = \bar{A}_{i-1,j-1} \otimes \bar{G}_{i,j} = \tilde{F}_{(i,j)},

where $\bar{A}_{i-1,j-1} = \mathbb{E}[\bar{a}_{i-1} \bar{a}_{j-1}^T]$ and $\bar{G}_{i,j} = \mathbb{E}[g_i g_j^T]$.

This gives

\tilde{F} = \begin{pmatrix}
\bar{A}_{0,0} \otimes \bar{G}_{1,1} & \bar{A}_{0,1} \otimes \bar{G}_{1,2} & \dots & \bar{A}_{0,l-1} \otimes \bar{G}_{1,l}\\
\bar{A}_{1,0} \otimes \bar{G}_{2,1} & \bar{A}_{1,1} \otimes \bar{G}_{2,2} & \dots & \bar{A}_{1,l-1} \otimes \bar{G}_{2,l}\\
\vdots & \vdots & \ddots & \vdots\\
\bar{A}_{l-1,0} \otimes \bar{G}_{l,1} & \bar{A}_{l-1,1} \otimes \bar{G}_{l,2} & \dots & \bar{A}_{l-1,l-1} \otimes \bar{G}_{l,l}
\end{pmatrix}

which has the form of what is known as a Khatri-Rao product in multivariate statistics.
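
The approximation in each block replaces the expectation of a Kronecker product with the Kronecker product of the expectations; it is exact when the $\bar{a}$'s and $g$'s are statistically independent, and an approximation otherwise. A small numerical sketch of what this looks like for a single diagonal block, using random stand-in data (illustrative only, not the paper's estimator):

```python
import numpy as np

rng = np.random.default_rng(3)

N, d_in, d_out = 512, 4, 3
a_bar = rng.normal(size=(N, d_in))         # stand-ins for \bar{a}_{i-1}, one row per case
g = rng.normal(size=(N, d_out))            # stand-ins for g_i, one row per case

# Exact diagonal block: the batch average of (a a^T) (x) (g g^T).
F_exact = np.zeros((d_in * d_out, d_in * d_out))
for a_n, g_n in zip(a_bar, g):
    F_exact += np.kron(np.outer(a_n, a_n), np.outer(g_n, g_n))
F_exact /= N

# Kronecker-factored approximation: push the expectation inside each factor.
A_bar = a_bar.T @ a_bar / N                # \bar{A} = E[a a^T]
G_bar = g.T @ g / N                        # \bar{G} = E[g g^T]
F_approx = np.kron(A_bar, G_bar)

rel_err = np.linalg.norm(F_exact - F_approx) / np.linalg.norm(F_exact)
print(f"relative Frobenius error of the Kronecker-factored block: {rel_err:.3f}")
```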

4 Additional approximations to $\tilde{F}$ and inverse computations

To the best of our knowledge there is no efficient general method for inverting a Khatri-Rao product like $\tilde{F}$. Thus, we must make further approximations if we hope to obtain an efficiently computable approximation of the inverse Fisher.

In the following subsections we argue that the inverse of $\tilde{F}$ can be reasonably approximated as having one of two special structures, either of which makes it efficiently computable. The second of these will be slightly less restrictive than the first (and hence a better approximation), at the cost of some additional complexity.
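
To preview where the simpler of the two structures leads: under a block-diagonal approximation of $\tilde{F}^{-1}$, applying the approximate inverse Fisher to a layer's gradient reduces to operations on the two small Kronecker factors of that layer's block. The sketch below only illustrates that per-layer computation; the damping term and the function name are made up for the example, and this is not the paper's complete update rule:

```python
import numpy as np

def kfac_block_diag_step(grad_W, A_bar, G_bar, damping=1e-3):
    """Apply the inverse of one diagonal block, (A_bar (x) G_bar)^{-1}, to a layer's
    gradient:  (A_bar (x) G_bar)^{-1} vec(grad_W) = vec(G_bar^{-1} grad_W A_bar^{-1}).
    The additive `damping` is an illustrative regularizer, not the paper's scheme."""
    d_out, d_in = grad_W.shape
    G_damped = G_bar + damping * np.eye(d_out)
    A_damped = A_bar + damping * np.eye(d_in)
    return np.linalg.solve(G_damped, grad_W) @ np.linalg.inv(A_damped)

# Illustrative usage with random stand-ins for the gradient and the two factors.
rng = np.random.default_rng(4)
grad_W = rng.normal(size=(3, 4))
A_bar = np.cov(rng.normal(size=(4, 100)))
G_bar = np.cov(rng.normal(size=(3, 100)))
print(kfac_block_diag_step(grad_W, A_bar, G_bar).shape)   # (3, 4), same shape as the layer's W
```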

4.1 Structured inverses and the connection to linear regression

4.2 Approximating $\tilde{F}^{-1}$ as block-diagonal

4.3 Approximating $\tilde{F}^{-1}$ as block-tridiagonal
