A Fast Variational Approach for Learning Markov Random Field Language Models
Yacine Jernite, Alexander Rush and David Sontag

Let’s talk about language modeling
● What is language modeling? A probability distribution over sentences.
● Applications: text generation, machine translation, speech recognition.
● The learned parameters are useful in their own right.

Let’s talk about language modeling
● Language modeling with a size-$K$ context, e.g. "The rats scared the cat ."
● Language model: $p(w) = \prod_{i=1}^{n} p(w_i \mid w_{i-1}, \dots, w_{i-K})$
● Multinomial $p(w_i \mid w_{i-1}, \dots, w_{i-K})$: n-gram models.

Let’s talk about language modeling
● Neural language models are based on embeddings of the vocabulary $V$ into $\mathbb{R}^D$.
● They rely on distributional similarity, the same idea as Latent Semantic Analysis.
● Example context: "scared the ___" with target word "cat".
● Distance-dependent context embeddings: $U^2_{\text{scared}},\, U^1_{\text{the}} \in \mathbb{R}^D$.
● Target word embedding: $W_{\text{cat}} \in \mathbb{R}^D$.
● Distribution:
$$p(\text{cat} \mid \text{scared}, \text{the}) = \frac{\exp\big((U^2_{\text{scared}} + U^1_{\text{the}})\, W_{\text{cat}}^T\big)}{Z(\text{scared}, \text{the})}$$

Bi-directional embedding systems
● How to get better embeddings: use both the left and the right context of the target word in "The rats scared the cat ."
● Word2Vec takes advantage of the right context:
$$p(w_i \mid w_{C_i}) = \frac{\exp\Big(\big(\sum_{l=-K}^{-1} U^l_{w_{i+l}} + \sum_{l=1}^{K} U^l_{w_{i+l}}\big)\, W_{w_i}^T\Big)}{Z(w_{C_i})}$$
● This does not define a distribution over sentences; it maximizes $\sum_{i=1}^{n} \log p(w_i \mid w_{C_i})$.

Our contributions
In this work:
● We propose a new language model, with similarities to recent embedding learning algorithms.
● We provide a fast learning algorithm whose cost is independent of the corpus size.

Markov random fields
● Graph structure $G = (V, E)$, cliques $c \in G$:
$$p(X_1, \dots, X_n) = \exp\Big(\sum_{c \in G} \theta_c(x_c) - A(\theta)\Big)$$
● [Figure: a four-node graph on A, B, C, D with clique potentials $\theta_{A,B}$ and $\theta_{B,C,D}$.]
● Log-partition function: $A(\theta) = \log \sum_x \exp\big(\sum_{c \in G} \theta_c(x_c)\big)$

Low rank Markov sequence model
● The word distribution depends on a size-$K$ context: <S> <S> The rats scared the cat . <S> <S>
● Low-rank log-potentials: $\theta_{i,j}(w_i, w_j) = U^{(j-i)}_{w_i} W_{w_j}^T$
● Conditional distribution (sketched in code below):
$$p(w_i \mid w_{C_i}) = \frac{\exp\Big(\big(\sum_{l=-K}^{-1} U'^l_{w_{i+l}} + \sum_{l=1}^{K} U'^l_{w_{i+l}}\big)\, W_{w_i}'^T\Big)}{Z(w_{C_i})}$$
● The pseudo-likelihood is equivalent to the Word2Vec objective:
$$\psi(w) = \sum_{i=1}^{N} \log p(w_i \mid w_{-i}) = \sum_{i=1}^{N} \log p(w_i \mid w_{C_i})$$
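To make the low-rank parameterization concrete, here is a minimal sketch of the conditional distribution $p(w_i \mid w_{C_i})$ induced by potentials of the form $U^{(l)} W^T$. This is my own illustration, not the authors' released code; the vocabulary size, embedding dimension, context size, and word ids are all illustrative assumptions.

```python
# Minimal sketch (not the authors' implementation) of the low-rank conditional
# p(w_i | w_{C_i}): score each candidate word by the dot product between the
# summed context embeddings and its target embedding, then normalise.
import numpy as np

V, D, K = 1000, 50, 2          # illustrative vocabulary size, embedding dim, context size
rng = np.random.default_rng(0)

# One context-embedding matrix U^(l) per signed offset l in {-K..-1, 1..K},
# plus a shared target-embedding matrix W.
U = {l: rng.normal(scale=0.1, size=(V, D))
     for l in range(-K, K + 1) if l != 0}
W = rng.normal(scale=0.1, size=(V, D))

def conditional(context):
    """p(w_i | w_{C_i}) as a length-V vector.

    `context` maps signed offsets l to the word id observed at position i+l.
    The score of candidate word w is (sum_l U^(l)_{w_{i+l}}) . W_w,
    normalised by the local partition function Z(w_{C_i}).
    """
    h = sum(U[l][w] for l, w in context.items())   # aggregated context vector
    scores = W @ h                                  # one score per candidate target word
    scores -= scores.max()                          # numerical stability
    p = np.exp(scores)
    return p / p.sum()

# Example: predict the word at position i given two left and two right neighbours
# (the word ids 17, 4, 256, 42 are arbitrary placeholders).
probs = conditional({-2: 17, -1: 4, 1: 256, 2: 42})
print(probs.shape, probs.sum())                     # (1000,) ~1.0
```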
Running time to learn language models
● N-gram model: $O(N \times K)$.
● Neural language models: $O(N \times |V|)$.
  o Approximations such as hierarchical softmax can reduce this further.
● MRF likelihood learning: $O(N \times |V|^{K+1})$.
  o The MRF has treewidth $K$.
  o We need an approximation of $A(\theta)$.

Obtaining a tractable approximation for $A(\theta)$
1. Lifted inference
● Derive a symmetrical model.
● Complexity: $O(|V|^{K+1})$, independent of the number of words $N$.
2. Tree re-weighted (TRW) approximation
● Wainwright et al., 2005.
● Complexity: $O(K \times |V|^2)$.

What is lifting?
● Exploit symmetric graphs.
● Lifted loopy belief propagation; lifted TRW (Bui et al., UAI ’14).

1. Getting symmetries: cyclic model
● The corpus has regularities: the same pairwise statistics recur across positions and sentences (e.g. "The rats scared the cat ." and "Then they stole its milk").
● Border effects: sentences are padded with <S> tokens and joined into a cycle, so every position sees the same local structure.
● The symmetry broken by conditioning on the sentence boundary is handled via $p(\text{<S>})\, p(w \mid \text{<S>}) = p(w)$.

2. Tree re-weighted approximation
● $A(\theta)$ is convex in $\theta$. [Figure: the graph $G$ and two of its spanning trees $T_1$, $T_2$.]

Deriving the bound
$$A(\theta) = \max_{\mu \in \mathcal{M}} \langle \theta, \mu \rangle + H(\mu) \;\le\; \max_{\mu \in \mathcal{M}} \langle \theta^{\mathrm{cycl}}, \mu \rangle + H(\mu)$$
$$\le\; \max_{\mu \in \mathcal{L}} \sum_{i=1}^{N} \Big( \langle \theta_i^{\mathrm{cycl}}, \mu_i \rangle + H(\mu_i) \Big) + \sum_{(i,j) \in E} \Big( \langle \theta_{i,j}^{\mathrm{cycl}}, \mu_{i,j} \rangle - \rho_{i,j}\, I(\mu_{i,j}) \Big) \qquad \text{(TRW)}$$
$$=\; \frac{N}{K+1} \max_{\substack{\mu \in \mathcal{L} \\ \mu_0 = \mu_1 = \dots = \mu_K}} \sum_{i=0}^{K} \Big( \langle \theta_i^{\mathrm{cycl}}, \mu_i \rangle + H(\mu_i) \Big) + \sum_{j=1}^{K} \Big( \langle \theta_{0,j}^{\mathrm{cycl}}, \mu_{0,j} \rangle - I(\mu_{0,j}) \Big) \qquad \text{(lifted TRW, } \rho = \tfrac{1}{K+1}\text{)}$$

Deriving the bound
● The equality constraint $\mu_0 = \mu_1 = \dots = \mu_K$ is handled by dual decomposition: node $0$ absorbs $-\delta_1 - \dots - \delta_K$ and each node $j$ receives $+\delta_j$, giving
$$\max_{\substack{\mu \in \mathcal{L} \\ \mu_0 = \dots = \mu_K}} (\cdot) \;=\; \min_{\delta} \max_{\mu \in \mathcal{L}} \sum_{i=0}^{K} \Big( \langle \theta_i^{\mathrm{cycl}} + \delta_i, \mu_i \rangle + H(\mu_i) \Big) + \sum_{j=1}^{K} \Big( \langle \theta_{0,j}^{\mathrm{cycl}}, \mu_{0,j} \rangle - I(\mu_{0,j}) \Big) \;=\; \min_{\delta} A_K(\theta^{\delta})$$
● Putting the chain together: $A(\theta) \le \frac{N}{K+1} \min_{\delta} A_K(\theta^{\delta})$.

Algorithm
● Objective: $\max_{\theta}\; \langle \theta, \hat{\mu} \rangle - A(\theta)$.
● Collect the moments $\hat{\mu}$: for every distance $l = 1, \dots, K$, count word co-occurrences at offset $l$ with a sliding window over the corpus (e.g. over "Whereas great delays have been used by sheriffs, gaolers and other officers, to whose custody, any of the King's subjects have been committed for criminal or supposed criminal matters, in making returns of writs of habeas corpus…"); see the sketch after this section.

Algorithm
● Objective: $\mathcal{L}(w; U, W) = \langle \theta, \hat{\mu} \rangle - A(\theta)$.
● Gradient descent:
  o Approximate $A(\theta)$ by $\min_{\delta} A_K(\theta^{\delta})$, computing $\nabla_{\delta}$ with belief propagation and iterating until convergence, or simply using $A^{\delta}(\theta)$ at the current $\delta$.
  o Compute $\nabla_{\theta} A$ by belief propagation with the optimal $\delta^*$.
  o Compute $\nabla_{\theta} \mathcal{L}$ and $\nabla_{U,W} \mathcal{L}$ by the chain rule, using $\theta^l = U^l W^T$.
● Complexity: $O(N \times K + K \times |V|^2)$.
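As a concrete illustration of the moment-collection step above, the sketch below counts pairwise word co-occurrences at every offset $l = 1, \dots, K$ and normalizes them into empirical moments $\hat{\mu}$. The helper name `collect_moments` and the toy corpus are my own assumptions, not part of the released code.

```python
# Minimal sketch (my own illustration) of collecting the pairwise moments
# mu-hat used in the objective <theta, mu> - A(theta): for each distance
# l = 1..K, count how often the pair (w_i, w_{i+l}) occurs in the corpus,
# then normalise the counts into frequencies.
from collections import Counter

def collect_moments(corpus, K):
    """corpus: list of sentences, each a list of word ids (with <S> padding).
    Returns {l: Counter{(w_i, w_{i+l}): frequency}} for l = 1..K."""
    moments = {l: Counter() for l in range(1, K + 1)}
    totals = {l: 0 for l in range(1, K + 1)}
    for sent in corpus:
        for i, w in enumerate(sent):
            for l in range(1, K + 1):
                if i + l < len(sent):
                    moments[l][(w, sent[i + l])] += 1
                    totals[l] += 1
    for l in moments:                      # normalise counts to frequencies
        if totals[l]:
            moments[l] = Counter({pair: c / totals[l]
                                  for pair, c in moments[l].items()})
    return moments

# Toy usage: K = 2, a single sentence over a 4-word vocabulary (arbitrary ids).
corpus = [[0, 1, 2, 3, 1, 2, 0]]
mu_hat = collect_moments(corpus, K=2)
print(mu_hat[1].most_common(3))
```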
Comparison to exact inference
● Toy dataset, small enough for exact inference: $V = \{a, b, c, d\}$, $N = 14$: <S>abdbbcaabdcdac<S>
● [Plot: objective values vs. LBFGS iterations, showing the learned and test bounds alongside the exact learned and test log-likelihoods.]

Language modeling
● Penn Treebank dataset: $|V| = 10{,}000$, $N = 1{,}000{,}000$.
● [Plot: objective values vs. LBFGS iterations while learning the log-potentials, showing the lower bound on test log-likelihood and the NLM test log-likelihood.]

Learning word embeddings
● Evaluated on language modeling and part-of-speech tagging.

Take away points
● A new language model: a low-rank MRF over <S> <S> The rats scared the cat . <S> <S>
● A fast learning algorithm: $O(K \times |V|^2)$.
● Wider applicability.
● Find the code at: https://github.com/srush/MRF-LM
● Questions?