Optimizing Neural Networks with Kronecker-factored Approximate Curvature
1 Introduction
We know that the curvature associated with neural network objective functions is highly non-diagonal, and that updates which properly respect and account for this non-diagonal curvature, such as those generated by HF, can make much more progress minimizing the objective than the plain gradient or updates computed from diagonal approximations of the curvature (usually $\sim 10^2$ HF updates are required to adequately minimize most objectives, compared to the $\sim 10^4$–$10^5$ required by methods that use diagonal approximations). Thus, if we had an efficient and direct way to compute the inverse of a high-quality non-diagonal approximation to the curvature matrix (i.e. without relying on first-order methods like CG), this could potentially yield an optimization method whose updates would be large and powerful like HF's, while being (almost) as cheap to compute as the stochastic gradient.
In this work we develop such a method, which we call Kronecker-factored Approximate Curvature (K-FAC). We show that our method can be much faster in practice than even highly tuned implementations of SGD with momentum on certain standard neural network optimization benchmarks.
The main ingredient in K-FAC is a sophisticated approximation to the Fisher information matrix, which despite being neither diagonal nor low-rank, nor even block-diagonal with small blocks, can be inverted very efficiently, and can be estimated in an online fashion using arbitrarily large subsets of the training data (without increasing the cost of inversion).
2 Background and notation
2.1 Neural networks
For convenience we will define the following additional notation:
$$\mathcal{D}v = \frac{d\,L(y, f(x,\theta))}{dv} = -\frac{d \log p(y \mid x,\theta)}{dv} \qquad \text{and} \qquad g_i = \mathcal{D}s_i$$
2.2 The natural gradient
$$F = \mathrm{E}\!\left[\frac{d \log p(y \mid x,\theta)}{d\theta} \left(\frac{d \log p(y \mid x,\theta)}{d\theta}\right)^{\!\top}\right] = \mathrm{E}\!\left[\mathcal{D}\theta\, \mathcal{D}\theta^{\top}\right]$$
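To make the natural gradient concrete, here is a minimal NumPy sketch (not from the paper) that forms a dense estimate of $F$ from per-example log-likelihood gradients and solves against it with Tikhonov damping; the names `grads`, `obj_grad`, and `damping` are illustrative, and for a faithful Fisher estimate the targets $y$ should be sampled from the model's own distribution $p(y \mid x, \theta)$ rather than taken from the training set.

```python
import numpy as np

def natural_gradient(grads, obj_grad, damping=1e-3):
    """Toy dense natural-gradient computation (illustrative only).

    grads    : (N, d) array; row i holds d log p(y_i | x_i, theta) / d theta
    obj_grad : (d,) gradient of the objective h
    damping  : Tikhonov damping added to F before solving
    """
    n, d = grads.shape
    fisher = grads.T @ grads / n                 # F = E[D theta D theta^T]
    return np.linalg.solve(fisher + damping * np.eye(d), obj_grad)

# Usage with random stand-in gradients (d kept tiny: the dense solve is O(d^3),
# which is exactly the cost K-FAC is designed to avoid)
rng = np.random.default_rng(0)
g = rng.normal(size=(256, 10))
print(natural_gradient(g, g.mean(axis=0)).shape)   # (10,)
```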
3 A block-wise Kronecker-factored Fisher approximation
The main computational challenge associated with using the natural gradient is computing $F^{-1}$ (or its product with the gradient $\nabla h$). In this section we develop an initial approximation of $F$ which will be a key ingredient in deriving our efficiently computable approximation to $F^{-1}$ and the natural gradient.
Noting that
$$\mathcal{D}\theta = \left[\operatorname{vec}(\mathcal{D}W_1)^{\top}\ \ \operatorname{vec}(\mathcal{D}W_2)^{\top}\ \ \cdots\ \ \operatorname{vec}(\mathcal{D}W_\ell)^{\top}\right]^{\top}$$
and
$$F = \mathrm{E}\!\left[\mathcal{D}\theta\, \mathcal{D}\theta^{\top}\right],$$
we see that $F$ can be viewed as an $\ell \times \ell$ block matrix, with the $(i,j)$-th block $F_{i,j}$ given by
$$F_{i,j} = \mathrm{E}\!\left[\operatorname{vec}(\mathcal{D}W_i)\operatorname{vec}(\mathcal{D}W_j)^{\top}\right].$$
Noting that $\mathcal{D}W_i = g_i \bar{a}_{i-1}^{\top}$ and $\operatorname{vec}(uv^{\top}) = v \otimes u$, we have
$$F_{i,j} = \mathrm{E}\!\left[\operatorname{vec}(g_i \bar{a}_{i-1}^{\top})\operatorname{vec}(g_j \bar{a}_{j-1}^{\top})^{\top}\right] = \mathrm{E}\!\left[(\bar{a}_{i-1} \otimes g_i)(\bar{a}_{j-1} \otimes g_j)^{\top}\right] = \mathrm{E}\!\left[\bar{a}_{i-1}\bar{a}_{j-1}^{\top} \otimes g_i g_j^{\top}\right].$$
Our initial approximation $\tilde{F}$ to $F$ is obtained by pushing the expectation inside the Kronecker product, i.e. by taking each block to be
$$\tilde{F}_{i,j} = \mathrm{E}\!\left[\bar{a}_{i-1}\bar{a}_{j-1}^{\top}\right] \otimes \mathrm{E}\!\left[g_i g_j^{\top}\right] = \bar{A}_{i,j} \otimes G_{i,j}.$$
The resulting matrix $\tilde{F}$, a block matrix whose blocks are Kronecker products, has the form of what is known as a Khatri-Rao product in multivariate statistics.
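As a sanity check on this block-wise factorization, the following self-contained NumPy sketch (not part of the paper) compares the exact diagonal block $\mathrm{E}[\bar{a}\bar{a}^{\top} \otimes g g^{\top}]$ of a single layer with its Kronecker-factored counterpart $\mathrm{E}[\bar{a}\bar{a}^{\top}] \otimes \mathrm{E}[g g^{\top}]$, using synthetic activation and pre-activation-gradient samples; all variable names and dimensions are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, in_dim, out_dim = 10_000, 6, 4

# Synthetic samples of a layer's input activations (homogeneous coordinate
# appended) and of the loss gradients w.r.t. its pre-activations.
a_bar = np.hstack([rng.normal(size=(n, in_dim)), np.ones((n, 1))])
g = rng.normal(size=(n, out_dim))

# vec(DW) = vec(g a_bar^T) = a_bar ⊗ g, one row per sample.
v = np.einsum("np,nq->npq", a_bar, g).reshape(n, -1)

# Exact diagonal block: E[(a_bar ⊗ g)(a_bar ⊗ g)^T]
exact = v.T @ v / n

# Kronecker-factored approximation: split the expectation into two factors.
A_bar = a_bar.T @ a_bar / n          # E[a_bar a_bar^T]
G = g.T @ g / n                      # E[g g^T]
approx = np.kron(A_bar, G)

rel_err = np.linalg.norm(exact - approx) / np.linalg.norm(exact)
print(f"relative error of the Kronecker factorization: {rel_err:.3f}")
```

With the independently drawn synthetic $\bar{a}$ and $g$ used here the factorization is exact up to sampling noise; for a real network its accuracy depends on how weakly the second-order statistics of the activations and the back-propagated gradients interact.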
4 Additional approximations to $\tilde{F}$ and inverse computations
To the best of our knowledge, there is no efficient general method for inverting a Khatri-Rao product like $\tilde{F}$. Thus, we must make further approximations if we hope to obtain an efficiently computable approximation of the inverse Fisher.
In the following subsections we argue that the inverse of $\tilde{F}$ can be reasonably approximated as having one of two special structures, either of which makes it efficiently computable. The second of these will be slightly less restrictive than the first (and hence a better approximation), at the cost of some additional complexity.
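As a preview of why such structure helps, the sketch below (an illustration under stated assumptions, not the paper's full algorithm) uses the simpler of the two structures: each diagonal block $\bar{A}_{i,i} \otimes G_{i,i}$ is inverted factor-wise via the identity $(\bar{A} \otimes G)^{-1}\operatorname{vec}(V) = \operatorname{vec}(G^{-1} V \bar{A}^{-1})$, so a layer's update never requires forming or inverting the full block. The function name, the way damping is simply added to each factor, and the stand-in statistics are all illustrative simplifications.

```python
import numpy as np

def kfac_layer_update(grad_W, a_bar, g, damping=1e-3):
    """Block-diagonal, Kronecker-factored natural-gradient step for one layer.

    grad_W : (out_dim, in_dim + 1) gradient of the objective w.r.t. the
             layer's (bias-augmented) weight matrix
    a_bar  : (N, in_dim + 1) input activations with homogeneous coordinate
    g      : (N, out_dim) loss gradients w.r.t. the layer's pre-activations
    """
    n = a_bar.shape[0]
    A_bar = a_bar.T @ a_bar / n + damping * np.eye(a_bar.shape[1])
    G = g.T @ g / n + damping * np.eye(g.shape[1])
    # (A_bar ⊗ G)^(-1) vec(grad_W) = vec(G^(-1) grad_W A_bar^(-1)):
    # only the two small factors are ever inverted.
    return np.linalg.solve(G, grad_W) @ np.linalg.inv(A_bar)

# Usage with random stand-in statistics
rng = np.random.default_rng(0)
a_bar = np.hstack([rng.normal(size=(512, 6)), np.ones((512, 1))])
g = rng.normal(size=(512, 4))
update = kfac_layer_update(rng.normal(size=(4, 7)), a_bar, g)
print(update.shape)  # (4, 7)
```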
4.1 Structured inverses and the connection to linear regression