A Fast Variational Approach for
Learning Markov Random Field
Language Models
Yacine Jernite, Alexander Rush and David Sontag
Let’s talk about language modeling
● What is language modeling?
● Probability distribution over sentences
● Text generation, machine translation, speech recognition
● Useful parameters
Let’s talk about language modeling
● Language modeling: size K context
(figure: "The rats scared the cat .")
● Language model:
p(w) = Π_{i=1}^{n} p(w_i | w_{i−1}, …, w_{i−K})
● Multinomial p(w_i | w_{i−1}, …, w_{i−K}): n-gram models
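The multinomial n-gram model above can be sketched as a count-based estimator. This is a minimal illustration without smoothing; the helper names are illustrative, not from the talk:

```python
from collections import Counter, defaultdict

def train_ngram(tokens, K=2):
    """Count-based estimate of p(w_i | w_{i-K}, ..., w_{i-1})."""
    ctx_counts = defaultdict(Counter)
    padded = ["<S>"] * K + tokens + ["<S>"]
    for i in range(K, len(padded)):
        ctx = tuple(padded[i - K:i])     # size-K context window
        ctx_counts[ctx][padded[i]] += 1
    return ctx_counts

def prob(ctx_counts, ctx, w):
    """Maximum-likelihood multinomial: relative frequency in this context."""
    counts = ctx_counts[tuple(ctx)]
    total = sum(counts.values())
    return counts[w] / total if total else 0.0

model = train_ngram("the rats scared the cat .".split(), K=2)
print(prob(model, ("rats", "scared"), "the"))  # 1.0
```

In practice n-gram models add smoothing (e.g. Kneser–Ney) so unseen contexts do not get zero probability.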
Let’s talk about language modeling
● Neural language models
● Based on embeddings of the vocabulary V into ℝ^D
● Exploits distributional similarity, same idea as Latent Semantic Analysis
(figure: embedding the context "scared the" to predict "cat")
Let’s talk about language modeling
● Neural language models (context "scared the", target "cat")
● Distance-dependent context embeddings: U²_scared, U¹_the ∈ ℝ^D
● Target word embedding: W_cat ∈ ℝ^D
● Distribution:
p(cat | scared, the) = exp((U²_scared + U¹_the)ᵀ W_cat) / Z(scared, the)
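The distribution above can be sketched in a few lines of NumPy: sum the distance-dependent context embeddings, score every target word, and normalize by Z. The sizes and variable names below are a hypothetical toy setup, not the talk's configuration:

```python
import numpy as np

rng = np.random.default_rng(0)
V, D, K = 5, 4, 2                 # vocabulary size, embedding dim, context size
U = rng.normal(size=(K, V, D))    # U[l-1][w]: embedding of word w at distance l
W = rng.normal(size=(V, D))       # target-word embeddings

def next_word_dist(context):
    """p(w | context) ∝ exp((sum_l U^l_{w_{i-l}})^T W_w)."""
    # reversed: the last context word is at distance 1, the one before at 2, ...
    h = sum(U[l][w] for l, w in enumerate(reversed(context[-K:])))
    scores = W @ h
    e = np.exp(scores - scores.max())   # subtract max for numerical stability
    return e / e.sum()                  # dividing by e.sum() is the 1/Z term

p = next_word_dist([3, 1])   # context word ids, e.g. "scared the"
```

The normalizer Z(context) requires touching all V target embeddings, which is the O(V) cost per token discussed later in the talk.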
Bi-directional embedding systems
● How to get better embeddings?
(figure: predicting "cat" from the surrounding words "The rats scared the …")
● Word2Vec: take advantage of the right context
p(w_i | w_{C_i}) = exp((Σ_{l=−K}^{−1} U^l_{w_{i+l}} + Σ_{l=1}^{K} U^l_{w_{i+l}})ᵀ W_{w_i}) / Z(w_{C_i})
● Does not define a distribution over sentences; maximizes Σ_{i=1}^{n} log p(w_i | w_{C_i})
Our contributions
In this work:
• We propose a new language model, with
similarities to recent embedding learning
algorithms
• We provide a fast learning algorithm whose inference cost is independent of the corpus size
Markov random fields
● Graph structure G = (V, E), cliques c ∈ G:
p(X_1, …, X_n) = exp(Σ_{c∈G} θ_c(x_c) − A(θ))
(figure: four nodes A, B, C, D with clique potentials θ_{A,B} and θ_{B,C,D})
Markov random fields
● Graph structure G = (V, E), cliques c ∈ G:
p(X_1, …, X_n) = exp(Σ_{c∈G} θ_c(x_c) − A(θ))
● Log-partition function:
A(θ) = log Σ_x exp(Σ_{c∈G} θ_c(x_c))
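For intuition, A(θ) can be computed by brute-force enumeration on a toy pairwise chain. This is only feasible for tiny state spaces, which is exactly why the talk needs an approximation; all names below are illustrative:

```python
import itertools
import numpy as np

rng = np.random.default_rng(1)
V, n = 3, 4                              # 3 states, chain of length 4
theta = rng.normal(size=(n - 1, V, V))   # pairwise log-potentials theta_{i,i+1}

def log_partition(theta):
    """A(theta) = log sum_x exp(sum_c theta_c(x_c)), by explicit enumeration."""
    n = theta.shape[0] + 1
    total = -np.inf
    for x in itertools.product(range(V), repeat=n):
        score = sum(theta[i, x[i], x[i + 1]] for i in range(n - 1))
        total = np.logaddexp(total, score)   # stable log-sum-exp accumulation
    return total

A = log_partition(theta)
# exp(score(x) - A) now defines a normalized distribution over all V^n sequences
```

The loop visits all V^n joint assignments, which is what makes exact computation of A(θ) intractable for a real vocabulary.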
Low-rank Markov sequence model
● Word distribution depends on a size-K context
(figure: cyclic chain over "<S> The rats scared the cat . <S>")
● Low-rank log-potentials:
θ_{i,j}(w_i, w_j) = U^{(j−i)}_{w_i} Wᵀ_{w_j}
● Conditional distribution:
p(w_i | w_{C_i}) = exp((Σ_{l=−K}^{−1} U′^l_{w_{i+l}} + Σ_{l=1}^{K} U′^l_{w_{i+l}}) W′ᵀ_{w_i}) / Z(w_{C_i})
● Pseudo-likelihood equivalent to the Word2Vec objective!
ψ(w) = Σ_{i=1}^{N} log p(w_i | w_{−i}) = Σ_{i=1}^{N} log p(w_i | w_{C_i})
Running time to learn language models
● N-gram model: O(N × K)
● Neural language models: O(N × |V|)
o Approximations such as hierarchical softmax can reduce this further
● MRF likelihood learning: O(N × |V|^{K+1})
o The MRF has treewidth K
o We need an approximation of A(θ)
Obtaining a tractable approximation for A(θ)
1. Lifted inference
• Deriving a symmetrical model
• Complexity: O(|V|^{K+1})
• Independent of the number of words N
2. Tree-reweighted approximation
• Wainwright et al., 2005
• Complexity: O(K × |V|²)
What is lifting?
• Symmetric graphs
• Lifted loopy belief propagation
• Lifted TRW (Bui et al., UAI ’14)
1. Getting symmetries: cyclic model
● Some regularities
(figure: two example sentences, "The rats scared the cat" and "Then they stole its milk", shown in parallel)
1. Getting symmetries: cyclic model
● Border effects: sentences
(figure: both example sentences padded with <S> symbols at their borders)
1. Getting symmetries: cyclic model
● Broken symmetry: conditioning
p(<S>) p(w | <S>) = p(w)
2. Tree Re-Weighted approximation
● A(θ) is convex in θ:
(figure: graph G decomposed into spanning trees T1 and T2)
2. Tree Re-Weighted approximation
Deriving the bound
A(θ) = max_{μ∈ℳ} ⟨θ, μ⟩ + H(μ)
 ≤ max_{μ∈ℳ} ⟨θ^cycl, μ⟩ + H(μ)    [symmetrical version]
 ≤ max_{μ∈ℒ} Σ_{i=1}^{N} [⟨θ_i^cycl, μ_i⟩ + H(μ_i)] + Σ_{(i,j)∈E} ρ_{i,j} (⟨θ_{i,j}^cycl, μ_{i,j}⟩ − I(μ_{i,j}))    [TRW]
 = N/(K+1) max_{μ∈ℒ} Σ_{i=0}^{K} [⟨θ_i^cycl, μ_i⟩ + H(μ_i)] + Σ_{j=1}^{K} [⟨θ_{0,j}^cycl, μ_{0,j}⟩ − I(μ_{0,j})]
   s.t. μ_0 = μ_1 = ⋯ = μ_K    [lifted TRW, ρ = 1/(K+1)]
(figure: lifted model over positions 0, 1, …, K)
Deriving the bound
max_{μ∈ℒ} Σ_{i=0}^{K} [⟨θ_i^cycl, μ_i⟩ + H(μ_i)] + Σ_{j=1}^{K} [⟨θ_{0,j}^cycl, μ_{0,j}⟩ − I(μ_{0,j})]  s.t. μ_0 = μ_1 = ⋯ = μ_K
 = min_δ max_{μ∈ℒ} Σ_{i=0}^{K} [⟨θ_i^cycl + δ_i, μ_i⟩ + H(μ_i)] + Σ_{j=1}^{K} [⟨θ_{0,j}^cycl, μ_{0,j}⟩ − I(μ_{0,j})]
 = min_δ A_K(θ^δ)
(figure: nodes 0, 1, …, K with −δ1 − ⋯ − δK added at node 0 and +δj at node j)
Deriving the bound
A(θ) = max_{μ∈ℳ} ⟨θ, μ⟩ + H(μ)
 ≤ max_{μ∈ℳ} ⟨θ^cycl, μ⟩ + H(μ)
 ≤ max_{μ∈ℒ} Σ_{i=1}^{N} [⟨θ_i^cycl, μ_i⟩ + H(μ_i)] + Σ_{(i,j)∈E} ρ_{i,j} (⟨θ_{i,j}^cycl, μ_{i,j}⟩ − I(μ_{i,j}))
 = N/(K+1) max_{μ∈ℒ} Σ_{i=0}^{K} [⟨θ_i^cycl, μ_i⟩ + H(μ_i)] + Σ_{j=1}^{K} [⟨θ_{0,j}^cycl, μ_{0,j}⟩ − I(μ_{0,j})]  s.t. μ_0 = μ_1 = ⋯ = μ_K
 = N/(K+1) min_δ max_{μ∈ℒ} Σ_{i=0}^{K} [⟨θ_i^cycl + δ_i, μ_i⟩ + H(μ_i)] + Σ_{j=1}^{K} [⟨θ_{0,j}^cycl, μ_{0,j}⟩ − I(μ_{0,j})]
 = N/(K+1) min_δ A_K(θ^δ)
Algorithm
● Objective: max_θ ⟨θ, μ⟩ − A(θ)
● Collect moments μ from the corpus, one set of pair counts per distance l = 1, …, K:
"Whereas great delays have been used by sheriffs, gaolers and other officers, to whose custody, any of the King's subjects have been committed for criminal or supposed criminal matters, in making returns of writs of habeas corpus…"
(figure: sliding windows over the text collecting word pairs at distance 1 and distance 2)
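Moment collection amounts to counting word pairs at each distance l = 1, …, K in a single pass over the corpus. A minimal sketch; the normalization to a distribution per distance is an illustrative choice:

```python
from collections import Counter

def collect_moments(tokens, K=2):
    """Empirical pair frequencies mu[l][(a, b)] for each distance l = 1..K."""
    mu = {l: Counter() for l in range(1, K + 1)}
    for l in range(1, K + 1):
        pairs = list(zip(tokens, tokens[l:]))   # all pairs at distance l
        for pair in pairs:
            mu[l][pair] += 1 / len(pairs)       # normalize to a distribution
    return mu

text = "whereas great delays have been used by sheriffs".split()
mu = collect_moments(text, K=2)
```

This pass is the only part of learning that touches the corpus, which is why the per-iteration inference cost afterwards does not depend on N.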
Algorithm
● Objective: ℒ(w; U, W) = ⟨θ, μ⟩ − A(θ)
● Gradient descent:
o Compute A(θ) = min_δ A_K(θ^δ): take steps on δ, with ∇_δ computed by belief propagation, until convergence to δ*
o Compute ∇_θ A: belief propagation with δ*
o Compute ∇_θ ℒ and ∇_{U,W} ℒ by the chain rule, with θ^l = U^l Wᵀ
(figure: lifted model over positions 0, 1, 2 with shifts −δ1* − δ2* at node 0 and +δ1*, +δ2* at nodes 1 and 2)
Complexity: O(N × K + K × |V|²)
Comparison to exact inference
● Toy dataset, small enough for exact inference
● V = {a, b, c, d}
● N = 14
<S>abdbbcaabdcdac<S>
(figures: objective values vs. LBFGS iterations — learning on the bound vs. learning on the exact likelihood, evaluated with the test bound, the test likelihood, the exact log-likelihood, and the lower bound)
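At this toy scale the gradient of the objective can be checked against brute-force exact inference: ∇_θ A(θ) equals the model's pair marginals, so the gradient of ⟨θ, μ⟩ − A(θ) is the gap between data and model moments. The sketch below substitutes enumeration for lifted TRW and uses a single shared distance-1 potential; all names are illustrative:

```python
import itertools
import numpy as np

V, n = 4, 6                        # tiny vocabulary and sequence length (K = 1)
rng = np.random.default_rng(3)
theta = rng.normal(size=(V, V))    # shared distance-1 log-potentials

def exact_pair_marginals(theta, n):
    """Model pair marginals mu_theta(a, b), averaged over edges, by enumeration."""
    scores = {x: sum(theta[x[i], x[i + 1]] for i in range(n - 1))
              for x in itertools.product(range(V), repeat=n)}
    A = np.logaddexp.reduce(list(scores.values()))   # exact log-partition
    mu = np.zeros((V, V))
    for x, s in scores.items():
        p = np.exp(s - A)
        for i in range(n - 1):
            mu[x[i], x[i + 1]] += p / (n - 1)
    return mu

# Empirical moments from a toy "corpus", then the moment-matching gradient
# direction: d/dtheta [<theta, mu_bar> - A(theta)] = mu_bar - mu_theta.
data = [0, 1, 2, 1, 0, 3]
mu_bar = np.zeros((V, V))
for a, b in zip(data, data[1:]):
    mu_bar[a, b] += 1 / (len(data) - 1)
grad = mu_bar - exact_pair_marginals(theta, n)
```

At a stationary point the model's pair marginals match the empirical ones, which is the moment-matching condition the LBFGS runs above are driving toward.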
Language modeling
● Penn Treebank dataset
● |V| = 10,000
● N = 1,000,000
(figures: objective values vs. LBFGS iterations — lower bound on the test log-likelihood while learning the log-potentials, compared to the NLM test log-likelihood)
Learning word embeddings
(tables: language modeling and part-of-speech tagging results)
Take away points
● New language model
(figure: cyclic MRF over "<S> The rats scared the cat . <S>")
● Fast learning algorithm: O(K × |V|²)
● Wider applicability
Take away points
● Find the code at:
https://github.com/srush/MRF-LM
● Questions?