A Stochastic PCA Algorithm with an Exponential Convergence Rate

Ohad Shamir
Weizmann Institute of Science
NIPS Optimization Workshop
December 2014
Ohad Shamir
Stochastic PCA with Exponential Convergence
1/19
Principal Component Analysis
PCA
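Input: $x_1, \dots, x_n \in \mathbb{R}^d$
Goal: Find $k$ directions with most variance:
$$\max_{U \in \mathbb{R}^{d \times k} :\, U^\top U = I} \; \frac{1}{n} \sum_{i=1}^{n} \left\| U^\top x_i \right\|^2$$
For $k = 1$: Find leading eigenvector of covariance matrix:
$$\max_{w \in \mathbb{R}^d :\, \|w\| = 1} \; w^\top \left( \frac{1}{n} \sum_{i=1}^{n} x_i x_i^\top \right) w$$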
Existing Approaches
$$\max_{w \in \mathbb{R}^d :\, \|w\| = 1} \; w^\top \left( \frac{1}{n} \sum_{i=1}^{n} x_i x_i^\top \right) w$$
Regime: n, d “large”, non-sparse matrix
Approach 1: Eigendecomposition
Compute leading eigenvector of $\frac{1}{n} \sum_{i=1}^{n} x_i x_i^\top$ exactly (e.g. via QR decomposition)
Runtime: $O(d^3)$
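As a rough illustration of Approach 1, the NumPy sketch below forms the $d \times d$ matrix and takes its top eigenvector with `numpy.linalg.eigh`. The function name `leading_eigenvector` and the synthetic data are assumptions made for illustration, not part of the talk. It also checks that the $k = 1$ objective $w^\top A w$ equals the mean squared projection.

```python
import numpy as np

def leading_eigenvector(X):
    """Exact leading eigenvector of A = (1/n) * sum_i x_i x_i^T (rows of X are the x_i)."""
    n = X.shape[0]
    A = X.T @ X / n                        # d x d matrix
    eigvals, eigvecs = np.linalg.eigh(A)   # symmetric solver, eigenvalues in ascending order
    return eigvecs[:, -1], eigvals[-1]

# Hypothetical synthetic data: the first coordinate carries most of the variance.
rng = np.random.default_rng(0)
X = rng.standard_normal((500, 5)) * np.array([3.0, 1.0, 1.0, 1.0, 1.0])
v1, top_eigval = leading_eigenvector(X)

# The k = 1 objective w^T A w equals the mean squared projection (1/n) * sum_i <w, x_i>^2.
objective = np.mean((X @ v1) ** 2)
```

Note that forming the matrix already costs $O(nd^2)$ on top of the $O(d^3)$ decomposition, which is part of why this route is unattractive when both $n$ and $d$ are large.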
Existing Approaches
Approach 2: Power Iterations
Initialize $w_1$ randomly on unit sphere
For $t = 1, 2, \dots$
  $w'_{t+1} := \frac{1}{n} \sum_{i=1}^{n} x_i x_i^\top w_t = \frac{1}{n} \sum_{i=1}^{n} \langle w_t, x_i \rangle x_i$
  $w_{t+1} := w'_{t+1} / \|w'_{t+1}\|$
$O\left( \frac{1}{\lambda} \log \frac{1}{\epsilon} \right)$ iterations for $\epsilon$-optimality
$\lambda$: Eigengap
$O(nd)$ runtime per iteration
Overall runtime $O\left( \frac{nd}{\lambda} \log \frac{d}{\epsilon} \right)$
Approach 2.5: Lanczos Iterations
More complex algorithm, but roughly similar iteration runtime, and only $O\left( \frac{1}{\sqrt{\lambda}} \log \frac{1}{\epsilon} \right)$ iterations [Kuczyński and Woźniakowski 1989]
Overall runtime $O\left( \frac{nd}{\sqrt{\lambda}} \log \frac{d}{\epsilon} \right)$
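The power iteration loop above can be sketched in a few lines of NumPy. This is a minimal illustration under assumed names (`power_iterations`, a fixed iteration count instead of a stopping rule) and hypothetical synthetic data; it applies the matrix implicitly, so each iteration costs $O(nd)$.

```python
import numpy as np

def power_iterations(X, num_iters, rng):
    """Power iterations on A = (1/n) * sum_i x_i x_i^T without forming A."""
    n, d = X.shape
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)                 # random point on the unit sphere
    for _ in range(num_iters):
        w_next = X.T @ (X @ w) / n         # (1/n) * sum_i <w, x_i> x_i in O(nd) time
        w = w_next / np.linalg.norm(w_next)
    return w

# Hypothetical synthetic data with a large eigengap.
rng = np.random.default_rng(0)
X = rng.standard_normal((500, 5)) * np.array([3.0, 1.0, 1.0, 1.0, 1.0])
w = power_iterations(X, num_iters=50, rng=rng)
rayleigh = np.mean((X @ w) ** 2)           # objective value w^T A w
```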
Existing Approaches
Approach 3: Stochastic/Incremental Algorithms
Example (Oja's algorithm)
Initialize $w_1$ randomly on unit sphere
For $t = 1, 2, \dots$
  Pick $i_t \in \{1, \dots, n\}$ (randomly or otherwise)
  $w'_{t+1} := w_t + \eta_t x_{i_t} x_{i_t}^\top w_t$
  $w_{t+1} := w'_{t+1} / \|w'_{t+1}\|$
Also Krasulina 1969; Arora, Cotter, Livescu, Srebro 2012; Mitliagkas, Caramanis, Jain 2013; De Sa, Olukotun, Ré 2014...
$O(d)$ runtime per iteration
Iteration bounds:
  Balsubramani, Dasgupta, Freund 2013: $\tilde{O}\left( \frac{d}{\lambda^2} \left( \frac{1}{\epsilon} + d \right) \right)$
  De Sa, Olukotun, Ré 2014: For a different SGD method, $\tilde{O}\left( \frac{d}{\lambda^2 \epsilon} \right)$
Runtime: $\tilde{O}\left( \frac{d^2}{\lambda^2 \epsilon} \right)$
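A minimal sketch of Oja's algorithm follows. The function name `oja`, the uniform sampling of $i_t$, the step-size schedule $\eta_t = \eta_0 / t$, and the synthetic data are all assumptions chosen for illustration; the slide leaves the sampling rule and the step sizes open.

```python
import numpy as np

def oja(X, num_iters, eta0, rng):
    """Oja's algorithm with step size eta_t = eta0 / t and uniformly sampled indices."""
    n, d = X.shape
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)                 # random point on the unit sphere
    for t in range(1, num_iters + 1):
        x = X[rng.integers(n)]             # pick i_t uniformly at random
        w = w + (eta0 / t) * (x @ w) * x   # w_t + eta_t * x x^T w_t in O(d) time
        w /= np.linalg.norm(w)             # project back to the unit sphere
    return w

# Hypothetical synthetic data with a large eigengap.
rng = np.random.default_rng(0)
X = rng.standard_normal((500, 5)) * np.array([3.0, 1.0, 1.0, 1.0, 1.0])
w = oja(X, num_iters=20000, eta0=1.0, rng=rng)
```

Each step touches a single data point, which is what gives the $O(d)$ per-iteration cost, at the price of a noisy update and a slowly decaying step size.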
Existing Approaches
Up to constants/log-factors:

Algorithm       Time per iter.   # iter.                     Runtime
Exact           -                -                           $d^3$
Power/Lanczos   $nd$             $1/\lambda^p$               $nd/\lambda^p$
Incremental     $d$              $d/(\lambda^2 \epsilon)$    $d^2/(\lambda^2 \epsilon)$

($p = 1$ for power iterations, $p = 1/2$ for Lanczos)

Main Question
Can we get the best of both worlds? $O(d)$ time per iteration and fast convergence (logarithmic dependence on $\epsilon$)?
Convex Optimization to the Rescue?
Our problem is equivalent to:
$$\min_{w :\, \|w\| = 1} \; -\frac{1}{n} \sum_{i=1}^{n} \langle w, x_i \rangle^2$$
Much recent progress in solving strongly convex + smooth problems with finite-sum structure
$$\min_{w \in \mathcal{W}} \; \frac{1}{n} \sum_{i=1}^{n} f_i(w)$$
Stochastic algorithms with O(d) runtime per iteration and
exponential convergence
[Le Roux, Schmidt, Bach 2012; Shalev-Shwartz and Zhang 2012;
Johnson and Zhang 2013; Zhang, Mahdavi, Jin 2013; Konečný and
Richtárik 2013; Xiao and Zhang 2014; Zhang and Xiao, 2014...]
Convex Optimization to the Rescue?
$$\min_{w :\, \|w\| = 1} \; -\frac{1}{n} \sum_{i=1}^{n} \langle w, x_i \rangle^2$$
Unfortunately:
Function not strongly convex, or even convex
(in fact, concave everywhere)
Has > 1 global optima, plateaus...
⇒ Existing results don’t work as-is
But: Maybe we can borrow some ideas...
Algorithm
$$\min_{w :\, \|w\| = 1} \; -\frac{1}{n} \sum_{i=1}^{n} \langle w, x_i \rangle^2$$
Oja Iteration
Choose $i_t \in \{1, \dots, n\}$ at random
$w'_{t+1} = w_t + \eta_t \langle w_t, x_{i_t} \rangle x_{i_t}$
$w_{t+1} := w'_{t+1} / \|w'_{t+1}\|$
Essentially projected stochastic gradient descent
Algorithm
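Letting $A = \frac{1}{n} \sum_{i=1}^{n} x_i x_i^\top$, update step is
$$w'_{t+1} = w_t + \eta_t x_{i_t} x_{i_t}^\top w_t = w_t + \underbrace{\eta_t A w_t}_{\text{power/gradient step}} + \underbrace{\eta_t \left( x_{i_t} x_{i_t}^\top - A \right) w_t}_{\text{zero-mean noise}}$$
Main idea: Replace by
$$w'_{t+1} = w_t + \underbrace{\eta A w_t}_{\text{power/gradient step}} + \underbrace{\eta \left( x_{i_t} x_{i_t}^\top - A \right) (w_t - \tilde{u})}_{\text{zero-mean noise}}$$
where $\tilde{u}$ "close" to $w_t$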
(similar to SVRG of Johnson and Zhang (2013))
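To see why the replacement keeps the power step while shrinking the noise, it may help to expand the new update. The following is a short derivation sketch in the slide's notation; the symbol $\mathbb{E}_{i_t}$ (expectation over the uniformly sampled index) is introduced here for illustration.

```latex
\begin{align*}
w'_{t+1} &= w_t + \eta A w_t + \eta\,(x_{i_t} x_{i_t}^\top - A)(w_t - \tilde{u}) \\
         &= w_t + \eta\,\bigl(x_{i_t} x_{i_t}^\top (w_t - \tilde{u}) + A\tilde{u}\bigr), \\
\mathbb{E}_{i_t}\bigl[(x_{i_t} x_{i_t}^\top - A)(w_t - \tilde{u})\bigr]
         &= (A - A)(w_t - \tilde{u}) = 0, \\
\bigl\|(x_{i_t} x_{i_t}^\top - A)(w_t - \tilde{u})\bigr\|
         &\le \bigl\|x_{i_t} x_{i_t}^\top - A\bigr\|\,\|w_t - \tilde{u}\|.
\end{align*}
```

So the noise stays zero-mean, but its size scales with $\|w_t - \tilde{u}\|$. The vector $A\tilde{u}$ can be computed once per epoch in $O(nd)$ time, leaving $O(d)$ work per step. In the VR-PCA pseudocode the reference point is written $\tilde{w}_{s-1}$, and the symbol $\tilde{u}$ there denotes the precomputed vector $A\tilde{w}_{s-1}$.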
Algorithm
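VR-PCA
Parameters: Step size $\eta$, epoch length $m$
Input: Data set $\{x_i\}_{i=1}^{n}$, Initial unit vector $\tilde{w}_0$
For $s = 1, 2, \dots$
  $\tilde{u} = \frac{1}{n} \sum_{i=1}^{n} x_i x_i^\top \tilde{w}_{s-1}$
  $w_0 = \tilde{w}_{s-1}$
  For $t = 1, 2, \dots, m$
    Pick $i_t \in \{1, \dots, n\}$ uniformly at random
    $w'_t = w_{t-1} + \eta \left( x_{i_t} x_{i_t}^\top (w_{t-1} - \tilde{w}_{s-1}) + \tilde{u} \right)$
    $w_t = \frac{1}{\|w'_t\|} w'_t$
  $\tilde{w}_s = w_m$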
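The pseudocode translates fairly directly into NumPy. The sketch below is one possible reading of it, with assumed names (`vr_pca`, `num_epochs`) and hypothetical synthetic data. The settings $m = n$ and $\eta = 0.05/\sqrt{n}$ mirror the ones reported on the experiments slide, but whether they suit other data scales is not something this sketch establishes.

```python
import numpy as np

def vr_pca(X, w0, eta, m, num_epochs, rng):
    """VR-PCA for the leading direction; rows of X are the data points x_i."""
    n, d = X.shape
    w_tilde = w0 / np.linalg.norm(w0)              # initial unit vector
    for _ in range(num_epochs):
        u_tilde = X.T @ (X @ w_tilde) / n          # (1/n) * sum_i x_i x_i^T w_tilde, O(nd)
        w = w_tilde.copy()
        for _ in range(m):
            x = X[rng.integers(n)]                 # pick i_t uniformly at random
            # O(d) step: x x^T (w - w_tilde) + u_tilde
            w_prime = w + eta * (x * (x @ (w - w_tilde)) + u_tilde)
            w = w_prime / np.linalg.norm(w_prime)  # normalization step
        w_tilde = w
    return w_tilde

# Hypothetical synthetic data; the first coordinate has the largest variance.
rng = np.random.default_rng(0)
n, d = 2000, 10
X = rng.standard_normal((n, d)) * np.array([1.0] + [0.5] * (d - 1))
w0 = np.ones(d) / np.sqrt(d)
w = vr_pca(X, w0, eta=0.05 / np.sqrt(n), m=n, num_epochs=20, rng=rng)
```

Only the computation of `u_tilde` touches the whole data set, once per epoch; every inner step uses a single data point.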
To get $k > 1$ directions: Either repeat, or perform orthogonal-like iterations:
Replace all vectors by $d \times k$ matrices
Replace normalization step by orthogonalization step
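One way to read the orthogonal-like variant is sketched below. The slide does not specify the orthogonalization, so the QR factorization with a sign fix is an assumption of this sketch, as are the names `orthonormalize` and `vr_pca_block`, the column layout (one direction per column of a $d \times k$ matrix), and the synthetic data.

```python
import numpy as np

def orthonormalize(W):
    """QR-based orthonormalization; signs are fixed so the result varies continuously with W."""
    Q, R = np.linalg.qr(W)
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0
    return Q * signs                               # flip columns so diag(R) would be positive

def vr_pca_block(X, W0, eta, m, num_epochs, rng):
    """Block variant of VR-PCA: d x k matrices replace vectors, QR replaces normalization."""
    n, d = X.shape
    W_tilde = orthonormalize(W0)
    for _ in range(num_epochs):
        U_tilde = X.T @ (X @ W_tilde) / n          # full-data term, computed once per epoch
        W = W_tilde.copy()
        for _ in range(m):
            x = X[rng.integers(n)]                 # pick i_t uniformly at random
            W_prime = W + eta * (np.outer(x, x @ (W - W_tilde)) + U_tilde)
            W = orthonormalize(W_prime)            # orthogonalization step
        W_tilde = W
    return W_tilde

# Hypothetical synthetic data with two dominant directions.
rng = np.random.default_rng(0)
n, d, k = 2000, 10, 2
X = rng.standard_normal((n, d)) * np.array([1.0, 0.8] + [0.3] * (d - 2))
W = vr_pca_block(X, rng.standard_normal((d, k)), eta=0.05 / np.sqrt(n), m=n,
                 num_epochs=15, rng=rng)
```

The sign fix keeps successive iterates close to the reference matrix, so that the difference `W - W_tilde` in the noise term stays small; a plain QR call may flip column signs between steps.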
Analysis
Theorem
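Suppose $\max_i \|x_i\|^2 \le r$, and $A$ has leading eigenvector $v_1$.
Assuming $\langle \tilde{w}_0, v_1 \rangle \ge \frac{1}{\sqrt{2}}$, then for any $\delta, \epsilon \in (0, 1)$, if
$$\eta \le \frac{c_1 \delta^2}{r^2} \lambda, \qquad m \ge \frac{c_2 \log(2/\delta)}{\eta \lambda}, \qquad m \eta^2 r^2 + r \sqrt{m \eta^2 \log(2/\delta)} \le c_3,$$
(where $c_1, c_2, c_3$ are constants) and we run $T = \left\lceil \frac{\log(1/\epsilon)}{\log(2/\delta)} \right\rceil$ epochs, then
$$\Pr\left( \langle \tilde{w}_T, v_1 \rangle^2 \ge 1 - \epsilon \right) \ge 1 - 2 \log(1/\epsilon)\, \delta$$
Corollary
Picking $\eta, m$ appropriately, w.h.p. $\epsilon$-convergence in $O\left( d \left( n + \frac{1}{\lambda^2} \right) \log \frac{1}{\epsilon} \right)$ runtime
Exponential convergence with $O(d)$-time iterations
Proportional to # examples plus inverse squared eigengap
Proportional to single data pass if $\lambda \ge 1/\sqrt{n}$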
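To get a feel for the quantities in the theorem, the small helper below evaluates the number of epochs $T$ and the failure-probability bound for given $\epsilon$ and $\delta$. The function name and the example values are assumptions for illustration; the constants $c_1, c_2, c_3$ are not specified on the slide, so the conditions on $\eta$ and $m$ are not checked here.

```python
import math

def epochs_and_failure_bound(eps, delta):
    """Number of epochs T and the failure-probability bound from the theorem."""
    T = math.ceil(math.log(1.0 / eps) / math.log(2.0 / delta))
    failure_bound = 2.0 * math.log(1.0 / eps) * delta
    return T, failure_bound

T, failure_bound = epochs_and_failure_bound(eps=1e-6, delta=0.01)
# T = 3 epochs; the guarantee holds with probability at least 1 - failure_bound (about 0.72)
```

A smaller $\delta$ improves the probability bound and reduces $T$, but the step-size condition $\eta \le c_1 \delta^2 \lambda / r^2$ then forces a smaller $\eta$ and hence a longer epoch.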
Proof Idea
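Track decay of $F(w_t) = 1 - \langle w_t, v_1 \rangle^2$
Key Lemma
Assuming $\eta = \alpha \lambda$ and $F(w_t) \le 3/4$,
$$\mathbb{E}\left[ F(w_{t+1}) \mid w_t \right] \le \left( 1 - \Theta(\alpha \lambda^2) \right) F(w_t) + O\left( \alpha^2 \lambda^2 F(\tilde{w}_{s-1}) \right).$$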
Proof Idea
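Assume $\eta = \alpha \lambda$ ($\alpha \ll 1$)
Using martingale arguments: W.h.p., never reach "flat" region in $m \le O\left( \frac{1}{\alpha^2 \lambda^2} \right)$ iterations
$\Rightarrow$ For all $t \le m$
$$\mathbb{E}\left[ F(w_{t+1}) \mid w_t \right] \le \left( 1 - \Theta(\alpha \lambda^2) \right) F(w_t) + O\left( \alpha^2 \lambda^2 F(\tilde{w}_{s-1}) \right).$$
Carefully unwinding recursion, and using $w_0 = \tilde{w}_{s-1}$,
$$\mathbb{E}\left[ F(w_m) \mid w_0 \right] \le \left( \left( 1 - \Theta(\alpha \lambda^2) \right)^m + O(\alpha) \right) F(\tilde{w}_{s-1})$$
$\Rightarrow$ If $m \ge \Omega\left( \frac{1}{\alpha \lambda^2} \right)$, $F(w_m)$ smaller than $F(\tilde{w}_{s-1})$ by some constant factor
Overall, if $\frac{1}{\alpha \lambda^2} \lesssim m \lesssim \frac{1}{\alpha^2 \lambda^2}$, every epoch shrinks $F$ by constant factor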
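The unwinding step can be sketched as a geometric-series bound. This is a simplified version that ignores the conditioning on staying out of the flat region; the shorthand $a = \Theta(\alpha\lambda^2)$, $b = O(\alpha^2\lambda^2)$ and $\tilde{F} = F(\tilde{w}_{s-1})$ is introduced here for illustration.

```latex
\begin{align*}
\mathbb{E}[F(w_m) \mid w_0]
  &\le (1-a)^m F(w_0) + b\,\tilde{F} \sum_{j=0}^{m-1} (1-a)^j \\
  &\le (1-a)^m \tilde{F} + \frac{b}{a}\,\tilde{F}
   = \Bigl( \bigl(1-\Theta(\alpha\lambda^2)\bigr)^m + O(\alpha) \Bigr)\,\tilde{F}.
\end{align*}
```

The first line iterates the one-step bound $m$ times; the second uses $F(w_0) = \tilde{F}$, bounds the geometric sum by $1/a$, and notes that $b/a = O(\alpha^2\lambda^2)/\Theta(\alpha\lambda^2) = O(\alpha)$. Once $m \ge \Omega(1/(\alpha\lambda^2))$, the factor $(1-a)^m \le e^{-am}$ drops below a constant smaller than 1, and the $O(\alpha)$ term is small when $\alpha \ll 1$.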
Experiments
Ran VR-PCA with
Random initialization
m = n (epoch ≈ one pass over data)
$\eta = 0.05/\sqrt{n}$ (based on theory)
Compared to Oja’s algorithm with hand-tuned step-size
Experiments
[Figure: five panels, one per eigengap: λ=0.258, λ=0.163, λ=0.078, λ=0.016, λ=0.004. Legend: VR-PCA, $\eta_t = 1/t$, $\eta_t = 3/t$, $\eta_t = 9/t$, $\eta_t = 27/t$. Horizontal axis ticks from 0 to 60; vertical axis ticks from 0 to −5.]
Preliminary Experiments
[Figure: single panel. Legend: VR-PCA, Hybrid, $\eta_t = 3$, $\eta_t = 9$, $\eta_t = 27$, $\eta_t = 81$, $\eta_t = 243$. Horizontal axis ticks from 0 to 100; vertical axis ticks from 0 to −9.]
Work-in-progress and Open Questions
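Runtime is $O\left( d \left( n + \frac{1}{\lambda^2} \right) \log \frac{1}{\epsilon} \right)$; can we get $O\left( d \left( n + \frac{1}{\lambda} \right) \log \frac{1}{\epsilon} \right)$ or better, analogous to convex case?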
Experimentally, required parameters do seem somewhat
different
Generalizing analysis to PCA with several directions
Theoretically optimal parameters without knowing λ
Other “fast-stochastic” approaches
Other non-convex problems
More details: arXiv:1409.2848
THANKS!