Accelerated Stochastic Gradient Descent Praneeth Netrapalli MSR India Presented at OSL workshop, Les Houches, France. Joint work with Prateek Jain, Sham M. Kakade, Rahul Kidambi and Aaron Sidford Linear regression min ๐ ๐ฅ = ๐ด๐ฅ โ ๐ ๐ฅ 2 2 ๐ฅ โ โ๐ , ๐ด โ โ๐×๐ , ๐ โ โ๐ โข Basic problem and arises pervasively in applications โข Deeply studied in literature Gradient descent for linear regression โค ๐ฅ๐ก+1 = ๐ฅ๐ก โ ๐ฟ โ ๐ด ๐ด๐ฅ๐ก โ ๐ Gradient โข Convergence rate: ๐ โข ๐ โ = min ๐ ๐ฅ ; ๐ฅ ๐ ๐ฅ0 โ๐โ ๐ log ๐ ๐ = Target suboptimality โข Condition number: ๐ = ๐๐๐๐ฅ ๐ดโค ๐ด ๐๐๐๐ ๐ดโค ๐ด Question: Is it possible to do better? Gradient descent: โข Hope: GD does not reuse past gradients โข Answer: Yes! ๐ฅ๐ก+1 = ๐ฅ๐ก โ ๐ฟ โ ๐ป๐ ๐ฅ๐ก ๏Conjugate gradient (Hestenes and Stiefel 1952) ๏Heavy ball method (Polyak 1964) ๏Accelerated gradient descent (Nemirovsky and Yudin 1977, Nesterov 1983) Accelerated gradient descent (AGD) ๐ฅ๐ก+1 = ๐ฆ๐ก โ ๐ฟ๐ป๐ ๐ฆ๐ก ๐ฆ๐ก+1 = ๐ฅ๐ก+1 + ๐พ ๐ฅ๐ก+1 โ ๐ฅ๐ก ๐ ๐ฅ0 โ๐โ ๐ log ๐ โข Convergence rate: ๐ โข ๐ โ = min ๐ ๐ฅ ; ๐ฅ ๐ = Target suboptimality โข Condition number: ๐ = ๐๐๐๐ฅ ๐ดโค ๐ด ๐๐๐๐ ๐ดโค ๐ด Accelerated gradient descent (AGD) Compared to: ๐ ๐ฅ for GD ๐ ๐ฅ0 โ๐โ ๐ log ๐ โข Convergence rate: ๐ โข ๐ โ = min ๐ ๐ฅ ; ๐ ๐ฅ0 โ๐โ ๐ log ๐ ๐ = Target suboptimality โข Condition number: ๐ = ๐๐๐๐ฅ ๐ดโค ๐ด ๐๐๐๐ ๐ดโค ๐ด Source: http://blog.mrtz.org/2014/08/18/robustness-versus-acceleration.html Stochastic approximation (Robbins and Monro 1951) โข Distribution ๐ on โ๐ × โ โข ๐ ๐ฅ โ๐ผ โข Equivalent to ๐ด๐ฅ โ ๐ โข Observe ๐ pairs ๐1 , ๐1 , โฏ , ๐๐ , ๐๐ โข Interested in entire distribution ๐ rather than data points like in ML โข Fit a linear model to the distribution โข Cannot compute exact gradients ๐,๐ โผ๐ ๐โค ๐ฅ โ ๐ 2 2 2 where ๐ด has infinite rows Stochastic gradient descent (SGD) (Robbins and Monro 1951) เท ๐ฅ๐ก where ๐ผ ๐ป๐ เท ๐ฅ๐ก โข ๐ฅ๐ก+1 = ๐ฅ๐ก โ ๐ฟ โ ๐ป๐ โข Return 1 ฯ๐ ๐ฅ๐ ๐ = ๐ป๐ ๐ฅ๐ก (Polyak and Juditsky 1992) โข Is gradient descent in expectation โข For linear regression, SGD: ๐ฅ๐ก+1 = ๐ฅ๐ก โ ๐ฟ โ ๐๐กโค ๐ฅ๐ก โ ๐๐ก ๐๐ก โข Streaming algorithm: extremely efficient and widely used in practice Best possible rate โข Consider ๐ = ๐โค ๐ฅ โ + noise; noise โผ ๐ฉ 0, ๐ 2 โข Recall: ๐ ๐ฅ โ ๐ผ โข ๐ฅเท โ argmin ฯ๐๐=1 ๐ฅ โข ๐ผ ๐ ๐ฅเท โ๐ ๐ฅ โ ๐,๐ โผ๐ ๐๐โค ๐ฅ ๐โค ๐ฅ โ ๐ โ ๐๐ 2 2 = 1+๐ 1 ๐2 ๐ ๐ (van der Vaart, 2000) Best possible rate โข In general: ๐ฅ โ โ argmin ๐ ๐ฅ ๐ฅ โค โ 2 โข ๐ผ ๐ ๐ฅ โ ๐ ๐๐ โค 2 โผ ๐ ๐ผ ๐๐ โค Equivalently ๐ โค 1+๐ 1 โข ๐ฅเท โ ๐ ฯ argmin ๐=1 ๐ฅ โข ๐ผ ๐ ๐ฅเท โ๐ ๐ฅโ โค ๐๐ ๐ฅ โ๐ โค 1+๐ 1 2 ๐2 ๐ ๐ (van der Vaart, 2000) ๐ 2๐ ๐ Convergence rate of SGD โข Convergence rate: ๐เทจ โข ๐ โ = min ๐ ๐ฅ ; ๐ฅ ๐ ๐ฅ0 โ๐โ ๐ log ๐ + ๐2 ๐ ๐ (Jain et al. 2016) ๐ = Target suboptimality โข Condition number: ๐ โ max ๐ 22 ๐๐๐๐ ๐ผ ๐๐โค โข Noise level: ๐ผ ๐โค ๐ฅ โ โ ๐ 2 ๐๐โค โผ ๐ 2 ๐ผ ๐๐โค Recap Deterministic case Stochastic approximation GD ๐ ๐ฅ0 โ ๐ โ ๐ ๐ log ๐ SGD ๐ ๐ฅ0 โ ๐ โ ๐ 2 ๐ ๐เทจ ๐ log + ๐ ๐ ๐ AGD ๐ ๐ฅ0 โ ๐ โ ๐ log ๐ Accelerated SGD? Unknown Question: Is accelerating SGD possible? Is this really important? โข Extremely important in practice ๏As we saw, acceleration can really give orders of magnitude improvement ๏Neural network training uses Nesterovโs AGD as well as Adam; but no theoretical understanding ๏Jain et al. 2016 shows acceleration leads to more parallelizability โข Existing results show AGD not robust to deterministic noise (dโAspremont 2008, Devolder et al. 2014) but is robust to random additive noise (Ghadimi and Lan 2010, Dieuleveut et al. 2016) โข Stochastic approximation falls between the above two cases โข Key issue: mixes optimization and statistics (i.e., # iterations = #samples) Is acceleration possible? โข ๐ = ๐โค ๐ฅ โ โ Noise level: ๐ 2 = 0 โข SGD convergence rate: ๐เทจ โข Accelerated rate: ๐เทจ ๐ ๐ฅ0 โ๐โ ๐ log ๐ ๐ ๐ฅ0 โ๐โ ๐ log ๐ ? Example I: Discrete distribution โข๐= 0 โฎ 0 1 0 โฎ 0 with probability ๐๐ โข In this case, ๐ โ โข Is ๐เทจ max ๐ 22 ๐๐๐๐ ๐ผ ๐๐โค ๐ ๐ฅ0 โ๐โ ๐ log ๐ = 1 ๐min possible? โข Or, halve the error using ๐เทจ ๐ samples? Example I: Discrete distribution โข๐= 0 โฎ 0 1 0 โฎ 0 with probability ๐๐ ; ๐ โ 1 ๐min โข Fewer than ๐ samples โ do not observe ๐min direction โค ฯ โ ๐ ๐๐ ๐๐ not invertible โข Cannot do better than ๐ ๐ โข Acceleration not possible Example II: Gaussian โข ๐ โผ ๐ฉ 0, ๐ป , ๐ป is a PSD matrix โข In this case, ๐ โผ Tr ๐ป ๐min ๐ป โฅ๐ โข However, after ๐ ๐ samples: 1 ฯ ๐ ๐ ๐๐ ๐๐โค โผ ๐ป โข Possible to solve ๐๐โค ๐ฅ โ = ๐๐ after O ๐ samples โข Acceleration might be possible Discrete vs Gaussian Discrete distribution Gaussian distribution Key issue: matrix spectral concentration โข Recall: ๐๐ โผ ๐. Let ๐ป โ ๐ผ ๐๐ ๐๐โค . โข For ๐ฅเท โ argmin ฯ๐๐=1 ๐ฅ ๐๐โค ๐ฅ โ ๐๐ 2 to be good, need: ๐ 1 1 โ ๐ฟ ๐ป โผ เท ๐๐ ๐๐โค โผ 1 + ๐ฟ ๐ป ๐ ๐=1 โข How many samples are required for spectral concentration? Separating optimization and statistics Matrix variance (Tropp 2012): ๐ผ ๐ 2 โค 2 ๐๐ 2 Recall ๐ป โ ๐ผ ๐๐โค Statistical condition number: ๐ ว โ ๐ผ 1 โ2 ๐ป ๐ 2 2 1 โ2 1 โ2 ๐ป ๐ ๐ป ๐ โค 2 Matrix Bernstein Theorem (Tropp 2015) If ๐ > ๐ ๐ ว , then 1 โ ๐ฟ ๐ป โผ 1 ๐ ฯ๐=1 ๐๐ ๐๐โค ๐ โผ 1+๐ฟ ๐ป Is acceleration possible? โข ๐ ๐ ว samples sufficient โข Recall SGD convergence rate: ๐เทจ โข Always ๐ ว โค ๐ . ๐ ๐ฅ0 โ๐โ ๐ log ๐ Acceleration might be possible if ๐ ว โช ๐ โข Discrete case: ๐ ว = 1 ๐min = ๐ ; Gaussian case: ๐ ว = ๐ ๐ โค ๐ Result โข Convergence rate of ASGD: ๐เทจ โข Compared to SGD: ๐เทจ ๐ ๐ฅ0 โ๐โ ๐ ๐ ว log ๐ ๐ ๐ฅ0 โ๐ โ ๐ log ๐ ๐2 ๐ + ๐ ๐2 ๐ + ๐ โข Improvement since ๐ ว โค ๐ โข Conjecture: lower bound ฮฉ Srebro 2016) ๐ ๐ฅ0 โ๐ โ ๐ ๐ ว log ๐ (inspired by Woodworth and Key takeaway โข โข Acceleration possible! Gain depends on statistical condition number Simulations โ No noise Discrete distribution Gaussian distribution Simulations โ With noise Discrete distribution Gaussian distribution High level challenges โข Several versions of accelerated algorithms known e.g., conjugate gradient 1952, heavy ball 1964, momentum methods 1983, accelerated coordinate descent 2012, linear coupling 2014 โข Many of them are equivalent in deterministic setting but not in stochastic setting โข Many different analyses even for momentum methods: Nesterovโs analysis 1983, coordinate descent 2012, ODE analysis 2013, linear coupling 2014 Algorithm Parameters: ๐ผ, ๐ฝ, ๐พ, ๐ฟ 1. ๐ฃ0 = ๐ฅ0 2. ๐ฆ๐กโ1 = ๐ผ๐ฅ๐กโ1 + 1 โ ๐ผ ๐ฃ๐กโ1 3. ๐ฅ๐ก = ๐ฆ๐กโ1 โ ๐ฟ๐ป๐ ๐ฆ๐กโ1 4. ๐ง๐กโ1 = ๐ฝ๐ฆ๐กโ1 + 1 โ ๐ฝ ๐ฃ๐กโ1 5. ๐ฃ๐ก = ๐ง๐กโ1 โ ๐พ๐ป๐ ๐ฆ๐กโ1 Nesterov 2012 Parameters: ๐ผ, ๐ฝ, ๐พ, ๐ฟ 1. ๐ฃ0 = ๐ฅ0 2. ๐ฆ๐กโ1 = ๐ผ๐ฅ๐กโ1 + 1 โ ๐ผ ๐ฃ๐กโ1 3. ๐ฅ๐ก = ๐ฆ๐กโ1 โ ๐ฟ ๐ปเท ๐ก ๐ ๐ฆ๐กโ1 4. ๐ง๐กโ1 = ๐ฝ๐ฆ๐กโ1 + 1 โ ๐ฝ ๐ฃ๐กโ1 5. ๐ฃ๐ก = ๐ง๐กโ1 โ ๐พ๐ปเท ๐ก ๐ ๐ฆ๐กโ1 Our algorithm Proof overview โข Recall our guarantee: ๐เทจ ๐ ๐ฅ0 โ๐โ ๐ ๐ ว log ๐ + ๐2 ๐ ๐ โข First term depends on initial error; second is statistical error โข Different analyses for the two terms โข For the first term, analyze assuming ๐ = 0 โข For the second term, analyze assuming ๐ฅ0 = ๐ฅ โ Part I: Potential function โข Iterates ๐ฅ๐ก , ๐ฃ๐ก of ASGD. ๐ป โ ๐ผ ๐๐โค . โข Existing analyses use potential function ๐ฅ๐ก โ ๐ฅ โ 2๐ป + ๐min ๐ป โ ๐ฃ๐ก โ ๐ฅ โ โข We show โค 1โ 2 2 ๐ฅ๐ก โ ๐ฅ โ โข We use ๐ฅ๐ก โ ๐ฅ โ 1 ๐ เทฅ ๐ ๐ฅ๐กโ1 โ ๐ฅ โ 2 2 + ๐min ๐ป โ ๐ฃ๐ก โ ๐ฅ โ + ๐min ๐ป โ ๐ฃ๐ก โ ๐ฅ โ 2 2 2 ๐ป โ1 + ๐min ๐ป โ ๐ฃ๐กโ1 โ ๐ฅ โ 2 ๐ป โ1 2 2 2 ๐ป โ1 Part II: Stochastic process analysis ๐ฅ๐ก+1 โ ๐ฅ โ ๐ฅ๐ก โ ๐ฅ โ โ =๐ถ โ + noise ๐ฆ๐ก+1 โ ๐ฅ ๐ฆ๐ก โ ๐ฅ โ โค โ ๐ฅ๐ก โ ๐ฅ ๐ฅ๐ก โ ๐ฅ Let ๐๐ก โ ๐ผ ๐ฆ โ ๐ฅ โ ๐ฆ โ ๐ฅ โ ๐ก ๐ก ๐๐ก+1 = ๐ ๐๐ก + noise โ noiseโค ๐๐ โ เท ๐ ๐ noise โ noiseโค ๐ = ๐โ๐ โ1 noise โ noiseโค Parameters: ๐ผ, ๐ฝ, ๐พ, ๐ฟ 1. ๐ฃ0 = ๐ฅ0 2. ๐ฆ๐กโ1 = ๐ผ๐ฅ๐กโ1 + 1 โ ๐ผ ๐ฃ๐กโ1 เท ๐ฆ๐กโ1 3. ๐ฅ๐ก = ๐ฆ๐กโ1 โ ๐ฟ ๐ป๐ 4. ๐ง๐กโ1 = ๐ฝ๐ฆ๐กโ1 + 1 โ ๐ฝ ๐ฃ๐กโ1 เท ๐ฆ๐กโ1 5. ๐ฃ๐ก = ๐ง๐กโ1 โ ๐พ๐ป๐ Our algorithm Part II: Stochastic process analysis โข Need to understand ๐ โ ๐ โ1 noise โ noiseโค โข ๐ has singular values > 1, but fortunately eigenvalues < 1 โข Solve the 1-dim version of ๐ โ ๐ computations โ1 noise โ noiseโค via explicit โข Combine the 1-dim bounds with (statistical) condition number bounds โข ๐โ๐ โ1 noise โ noiseโค โผ ๐ ๐ป ว โ1 + ๐ฟ โ ๐ผ Recap Deterministic case Stochastic approximation GD ๐ ๐ฅ0 โ ๐ โ ๐ ๐ log ๐ SGD ๐ ๐ฅ0 โ ๐ โ ๐ 2 ๐ ๐เทจ ๐ log + ๐ ๐ ๐ AGD ๐ ๐ฅ0 โ ๐ โ ๐ log ๐ ๐เทจ ASGD ๐ ๐ฅ0 โ ๐ โ ๐ 2 ๐ ๐ ๐ ว log + ๐ ๐ โข Acceleration possible โ depends on statistical condition number โข Techniques: new potential function, stochastic process analysis โข Conjecture: Our result is tight Streaming optimization for ML โข Streaming algorithms are very powerful for ML applications โข SGD and variants widely used in practice โข Classical stochastic approximation focuses on asymptotic rates โข Tools from optimization help obtain strong finite sample guarantees โข Have implications for parallelization as well Some examples โข Linear regression ๏Finite sample guarantees: Moulines and Bach 2011, Defossez and Bach 2015 ๏Parallelization: Jain et al. 2016 ๏Acceleration: This talk โข Smooth convex functions: ๏Finite sample guarantees: Bach and Moulines 2013 โข PCA: Ojaโs algorithm ๏Rank-1: Balsubramani et al. 2013, Jain et al. 2016 ๏Higher rank: Allen-Zhu and Li 2016 Open problems โข Linear regression: Parameter free algorithm e.g., conjugate gradient โข General convex functions: Acceleration, parallelization? โข Non-convex functions: Streaming algorithms, acceleration, parallelization? โข PCA: Tight finite sample guarantees? โข Quasi-Newton methods Thank you! Questions? [email protected]
© Copyright 2026 Paperzz