
Joint distribution optimal transportation for domain adaptation
R. Flamary
Joint work with N. Courty, A. Habrard, A. Rakotomamonjy,
Optimal Transport meets Probability, Statistics and Machine Learning
Oaxaca, Mexico, 2017
Table of contents
Optimal transport for domain adaptation
Supervised learning
Domain adaptation and optimal transport
Discussion: labels and final classifier?
Joint distribution optimal transport for domain adaptation (JDOT)
Joint distribution and classifier estimation
Generalization bound
Learning with JDOT: regression and classification
Numerical experiments
Caltech-Office classification dataset
Amazon Review Classification dataset
Wifi localization regression dataset
Conclusion
Supervised learning
Traditional supervised learning
I We want to learn a predictor f such that y ≈ f(x).
I The true joint distribution P(X, Y) is unknown.
I We have access to a training dataset (x_i, y_i)_{i=1,...,n}, which defines the empirical distribution P̂(X, Y).
I We choose a loss function L(y, f(x)) that measures the discrepancy.
Empirical risk minimization
We seek a predictor f minimizing

\min_f \Big\{ \mathbb{E}_{(x,y)\sim \hat{P}}\, L(y, f(x)) = \frac{1}{n}\sum_j L(y_j, f(x_j)) \Big\} \qquad (1)
I Well-known generalization results exist for prediction on new data.
I The loss is typically L(y, f(x)) = (y − f(x))² for least squares regression, and L(y, f(x)) = max(0, 1 − y f(x))² for the squared Hinge loss SVM (see the sketch below).
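A minimal sketch of these two empirical risk minimization problems, assuming scikit-learn is available; the toy data and parameter values are illustrative and not from the slides.

```python
# Empirical risk minimization on toy data: squared Hinge loss SVM and least squares.
import numpy as np
from sklearn.svm import LinearSVC        # squared Hinge loss (default in scikit-learn)
from sklearn.linear_model import Ridge   # squared loss with quadratic regularization

rng = np.random.RandomState(0)
X = rng.randn(200, 2)
y = np.sign(X[:, 0] + 0.5 * rng.randn(200))      # noisy linear labels

# Classification: roughly minimizes sum_j max(0, 1 - y_j f(x_j))^2 plus regularization
clf = LinearSVC(loss='squared_hinge', C=1.0).fit(X, y)

# Regression: minimizes sum_j (y_j - f(x_j))^2 + lambda ||f||^2
reg = Ridge(alpha=1.0).fit(X, X[:, 0] + 0.1 * rng.randn(200))
```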
Domain Adaptation problem
Our context
I Classification problem with data coming from different sources (domains).
I The distributions are different but related.
Problem
I Labels are only available in the source domain, and classification is conducted in the target domain.
I A classifier trained on the source domain data performs badly in the target domain.
Unsupervised domain adaptation problem
(Figure: source domain (Amazon images: feature extraction + labels) vs. target domain (DSLR images: feature extraction, no labels); the source decision function does not work on the target.)
Problem
I Labels are only available in the source domain, and classification is conducted in the target domain.
I A classifier trained on the source domain data performs badly in the target domain.
Domain adaptation: a short state of the art
Reweighting schemes [Sugiyama et al., 2008]
I The distribution changes between domains.
I Reweight the samples to compensate for this change.
Subspace methods
I Data is invariant in a common latent subspace.
I Minimize a divergence between the projected domains [Si et al., 2010].
I Use additional label information [Long et al., 2014b].
Gradual alignment
I Alignment along the geodesic between the source and target subspaces [R. Gopalan and Chellappa, 2014].
I Geodesic flow kernel [Gong et al., 2012].
Optimal transport for domain adaptation
(Figure: dataset with class 1 and class 2 samples, optimal transport between the domains, and classification on the transported samples.)
Assumptions
I There exists a transport T in the feature space between the two domains.
I The transport preserves the conditional distributions: P_s(y|x^s) = P_t(y|T(x^s)).
3-step strategy [Courty et al., 2016a]
1. Estimate the optimal transport between the distributions.
2. Transport the training samples with the barycentric mapping.
3. Learn a classifier on the transported training samples (see the sketch below).
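A minimal sketch of this 3-step strategy using the POT toolbox (https://github.com/rflamary/POT) and scikit-learn; the toy data, the entropic solver and all variable names are assumptions of this sketch, not the original experimental code.

```python
import numpy as np
import ot                                   # POT: Python Optimal Transport
from sklearn.svm import SVC

# toy source/target samples and source labels
rng = np.random.RandomState(0)
Xs = rng.randn(50, 2)
ys = (Xs[:, 0] > 0).astype(int)
Xt = rng.randn(60, 2) + np.array([2.0, 1.0])    # shifted target domain

# Step 1: estimate the optimal transport plan (here with entropic regularization)
a, b = ot.unif(len(Xs)), ot.unif(len(Xt))       # uniform sample weights
M = ot.dist(Xs, Xt)                             # squared Euclidean ground metric
M /= M.max()                                    # scale the cost for numerical stability
G = ot.sinkhorn(a, b, M, reg=1e-1)              # coupling gamma

# Step 2: barycentric mapping of the source samples onto the target domain
Xs_mapped = (G / G.sum(axis=1, keepdims=True)) @ Xt

# Step 3: learn a classifier on the transported (still labeled) source samples,
# which can then be applied directly to the target samples Xt
clf = SVC(kernel='linear').fit(Xs_mapped, ys)
```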
Optimal transport for domain adaptation
(Figure: same illustration as above, showing the dataset, the optimal transport, and the classification on the transported samples.)
Discussion
I Works very well in practice and handles a large class of transformations.
I Steps 1 and 2 can be fused by directly estimating the mapping [Perrot et al., 2016].
But
I It models the transformation only in the feature space.
I It requires the same class proportions between domains [Tuia et al., 2015].
I The barycentric mapping is an approximation.
I In the end we search for a classifier f : R^d → R; the mapping is a much more complex object.
Joint distribution and classifier estimation
Objectives of JDOT
I Model the transformation of the labels (allowing changes of proportion/value).
I Learn an optimal target predictor with no labels on the target samples.
I An approach that is theoretically justified.
Joint distributions and dataset
I We work with the joint feature/label distributions.
I Let Ω ⊂ R^d be a compact input measurable space of dimension d and C the set of labels.
I Let P_s(X, Y) ∈ P(Ω × C) and P_t(X, Y) ∈ P(Ω × C) be the source and target joint distributions.
I We have access to an empirical sampling \hat{P}_s = \frac{1}{N_s}\sum_{i=1}^{N_s} \delta_{x_i^s, y_i^s} of the source distribution, defined by the samples X_s = \{x_i^s\}_{i=1}^{N_s} and the label information Y_s = \{y_i^s\}_{i=1}^{N_s}.
I The target domain is defined only by an empirical distribution in the feature space, with samples X_t = \{x_i^t\}_{i=1}^{N_t}.
Joint distribution OT (JDOT)
Proxy joint distribution
I Let f : Ω → C be a function from a given class of hypotheses H.
I We define the following joint distribution that uses f as a proxy of y:

P_t^f = (x, f(x))_{x \sim \mu_t} \qquad (2)

and its empirical counterpart \hat{P}_t^f = \frac{1}{N_t}\sum_{i=1}^{N_t} \delta_{x_i^t, f(x_i^t)}.

Learning with JDOT
We propose to learn the predictor f that minimizes:

\min_f \Big\{ W_1(\hat{P}_s, \hat{P}_t^f) = \inf_{\gamma \in \Delta} \sum_{ij} D(x_i^s, y_i^s; x_j^t, f(x_j^t))\,\gamma_{ij} \Big\} \qquad (3)

I Δ is the transport polytope.
I D(x_i^s, y_i^s; x_j^t, f(x_j^t)) = α‖x_i^s − x_j^t‖² + L(y_i^s, f(x_j^t)) with α > 0.
I We search for the predictor f that best aligns the joint distributions (see the sketch below).
I The objective value is the Transportation L^p distance of [Thorpe et al., 2016]; here we optimize over f.
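For a fixed f, the inner problem of (3) is a standard OT problem between the source (features, labels) and the target (features, predictions). A minimal sketch with POT, assuming a squared loss L; the toy data and names are illustrative.

```python
import numpy as np
import ot

def jdot_coupling(Xs, ys, Xt, f, alpha=1.0):
    """Solve the inner OT problem of (3) for a fixed predictor f (squared loss here)."""
    C_feat = ot.dist(Xs, Xt)                          # ||x_i^s - x_j^t||^2
    preds = f(Xt)                                     # proxy labels f(x_j^t)
    C_lab = (ys[:, None] - preds[None, :]) ** 2       # L(y_i^s, f(x_j^t))
    D = alpha * C_feat + C_lab                        # joint ground cost
    a, b = ot.unif(len(Xs)), ot.unif(len(Xt))
    gamma = ot.emd(a, b, D)                           # exact OT (network simplex)
    return gamma, np.sum(gamma * D)                   # coupling and objective value

# usage with a dummy linear predictor
rng = np.random.RandomState(0)
Xs, Xt = rng.randn(30, 2), rng.randn(40, 2) + 1.0
ys = Xs[:, 0]
gamma, obj = jdot_coupling(Xs, ys, Xt, f=lambda X: X[:, 0])
```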
Generalization bound (1)
Expected loss
The target expected loss for a given predictor f is defined as

err_T(f) \overset{\text{def}}{=} \mathbb{E}_{(x,y)\sim P_t}\, L(y, f(x)).

Similarly, we define the source expected loss err_S(f) = \mathbb{E}_{(x,y)\sim P_s}\, L(y, f(x)) and the inter-function loss err_T(f, g) = \mathbb{E}_{x\sim \mu_t}\, L(g(x), f(x)).

Probabilistic Lipschitzness [Urner et al., 2011, Ben-David et al., 2012]
Let φ : R → [0, 1]. A labeling function f : Ω → R is φ-Lipschitz with respect to a distribution P over Ω if for all λ > 0:

\Pr_{x\sim P}\big[\exists y : |f(x) − f(y)| > \lambda\, d(x, y)\big] \le \phi(\lambda).
Generalization bound (2)
Theorem 1
Let f_T^* and f_S^* be the two optimal labeling functions that verify the φ-probabilistic Lipschitzness assumption. Let L be any loss function bounded by M, symmetric, k-Lipschitz, and satisfying the triangle inequality. Consider a sample of N_s labeled source instances drawn from P_s and N_t unlabeled instances drawn from μ_t, and any f ∈ H. Then for all λ > 0, with α = kλ, we have with probability at least 1 − δ that:

err_T(f) \le W_1(\hat{P}_s, \hat{P}_t^f) + \sqrt{\tfrac{2}{c'}\log\tfrac{2}{\delta}}\left(\tfrac{1}{\sqrt{N_s}} + \tfrac{1}{\sqrt{N_t}}\right) + err_S(f_S^*) + err_T(f_S^*, f_T^*) + err_T(f_T^*) + k M \phi(\lambda).

I The first term is the JDOT objective function.
I The second term is an empirical sampling bound.
I The last terms are usual in DA bounds [Mansour et al., 2009, Ben-David et al., 2010].
Optimization problem
\min_{f\in\mathcal{H},\,\gamma\in\Delta}\; \sum_{i,j} \gamma_{i,j}\big(\alpha\, d(x_i^s, x_j^t) + L(y_i^s, f(x_j^t))\big) + \lambda\,\Omega(f) \qquad (4)

Optimization procedure
I Ω(f) is a regularization term for the predictor f.
I We propose to use block coordinate descent (BCD) / Gauss-Seidel (see the sketch below).
I It provably converges to a stationary point of the problem.
γ update for a fixed f
I This is a classical OT problem that can be solved with the network simplex.
I Regularized OT can also be used (it just adds a term to problem (4)).
f update for a fixed γ

\min_{f\in\mathcal{H}}\; \sum_{i,j} \gamma_{i,j}\, L(y_i^s, f(x_j^t)) + \lambda\,\Omega(f) \qquad (5)

I A loss weighted over all source labels.
I γ performs label propagation.
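A minimal sketch of the BCD loop, alternating the exact OT solve (γ update) and a least squares f update; the ridge regressor, the warm start on the source data, and all names are assumptions of this sketch.

```python
import numpy as np
import ot
from sklearn.linear_model import Ridge

def jdot_bcd(Xs, ys, Xt, alpha=1.0, n_iter=10, ridge_alpha=1.0):
    nt = len(Xt)
    a, b = ot.unif(len(Xs)), ot.unif(nt)
    C_feat = ot.dist(Xs, Xt)                        # feature part of the joint cost
    f = Ridge(alpha=ridge_alpha).fit(Xs, ys)        # warm start on the source data
    for _ in range(n_iter):
        # gamma update for a fixed f: exact OT on the joint cost of problem (4)
        C_lab = (ys[:, None] - f.predict(Xt)[None, :]) ** 2
        gamma = ot.emd(a, b, alpha * C_feat + C_lab)
        # f update for a fixed gamma: weighted least squares, see eq. (6)
        y_hat = nt * gamma.T @ ys                   # labels propagated to the target samples
        f = Ridge(alpha=ridge_alpha).fit(Xt, y_hat)
    return f, gamma
```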
Regression with JDOT
Least squares regression with quadratic regularization
For a fixed γ, the optimization problem is equivalent to

\min_{f\in\mathcal{H}}\; \sum_j \frac{1}{n_t}\,\|\hat{y}_j - f(x_j^t)\|^2 + \lambda\,\|f\|^2 \qquad (6)

I \hat{y}_j = n_t \sum_i \gamma_{i,j}\, y_i^s is a weighted average of the source values.
I Note that the number of loss terms is now linear in n_t instead of quadratic (n_s × n_t).
I Any solver can be used (linear, kernel ridge, neural network); see the sketch below.
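Since any solver can be plugged in, the f update can for instance use a kernel ridge model fitted on the propagated targets ŷ; this tiny sketch assumes scikit-learn and an RBF kernel, both illustrative choices.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge

def f_update(gamma, ys, Xt, lam=1.0):
    """Solve (6) with a non-linear model, for a fixed coupling gamma."""
    y_hat = len(Xt) * gamma.T @ ys      # y_hat_j = n_t * sum_i gamma_ij * y_i^s
    return KernelRidge(alpha=lam, kernel='rbf').fit(Xt, y_hat)
```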
Classification with JDOT
(Figure: target accuracy along the BCD iterations for several values of α ∈ {0.1, 0.5, 1, 10, 50, 100}.)
Multiclass classification with the Hinge loss
For a fixed γ, the optimization problem is equivalent to

\min_{f_k\in\mathcal{H}}\; \sum_{j,k} \hat{P}_{j,k}\, L(1, f_k(x_j^t)) + (1 - \hat{P}_{j,k})\, L(-1, f_k(x_j^t)) + \lambda \sum_k \|f_k\|^2 \qquad (7)

I \hat{P} = \frac{1}{N_t}\,\gamma^\top P^s is the class proportion matrix.
I P^s and Y^s are defined from the source data with a one-vs-all strategy as

Y^s_{i,k} = \begin{cases} 1 & \text{if } y_i^s = k \\ -1 & \text{else} \end{cases}, \qquad P^s_{i,k} = \begin{cases} 1 & \text{if } y_i^s = k \\ 0 & \text{else} \end{cases}

with k ∈ {1, · · · , K} and K the number of classes (see the sketch below).
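To illustrate how γ propagates the labels, here is a minimal sketch building P^s, Y^s and the target class-proportion matrix; the feature-only coupling and the explicit row normalization of P̂ are assumptions of this sketch.

```python
import numpy as np
import ot

def propagate_class_proportions(Xs, ys, Xt, alpha=1.0, n_classes=None):
    K = n_classes or int(ys.max()) + 1
    Ps = np.eye(K)[ys]                      # P^s_{i,k} = 1 if y_i^s = k, 0 otherwise
    Ys = 2 * Ps - 1                         # Y^s_{i,k} = 1 if y_i^s = k, -1 otherwise
    a, b = ot.unif(len(Xs)), ot.unif(len(Xt))
    gamma = ot.emd(a, b, alpha * ot.dist(Xs, Xt))   # coupling from the feature cost only (illustration)
    P_hat = gamma.T @ Ps                    # mass of each class sent to each target sample
    P_hat /= P_hat.sum(axis=1, keepdims=True)       # normalize rows into class proportions
    return P_hat, Ys

# P_hat[j, k] then weights the positive Hinge term of class k for target sample j in (7).
```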
Caltech-Office classification dataset
Domains            Base    SurK    SA      OT-IT   OT-MM   JDOT
caltech→amazon     92.07   91.65   90.50   89.98   92.59   91.54
caltech→webcam     76.27   77.97   81.02   80.34   78.98   88.81
caltech→dslr       84.08   82.80   85.99   78.34   76.43   89.81
amazon→caltech     84.77   84.95   85.13   85.93   87.36   85.22
amazon→webcam      79.32   81.36   85.42   74.24   85.08   84.75
amazon→dslr        86.62   87.26   89.17   77.71   79.62   87.90
webcam→caltech     71.77   71.86   75.78   84.06   82.99   82.64
webcam→amazon      79.44   78.18   81.42   89.56   90.50   90.71
webcam→dslr        96.18   95.54   94.90   99.36   99.36   98.09
dslr→caltech       77.03   76.94   81.75   85.57   83.35   84.33
dslr→amazon        83.19   82.15   83.19   90.50   90.50   88.10
dslr→webcam        96.27   92.88   88.47   96.61   96.61   96.61

Mean               83.92   83.63   85.23   86.02   86.95   89.04
Avg. rank           4.50    4.75    3.58    3.00    2.42    2.25
Numerical experiments
I This classical dataset [Saenko et al., 2010] is dedicated to visual adaptation.
I Features are extracted with a convolutional neural network [Donahue et al., 2014].
I Comparison with Surrogate Kernels [Zhang et al., 2013], Subspace Alignment [Fernando et al., 2013] and OT Domain Adaptation [Courty et al., 2016b].
I Parameters are selected via reverse cross-validation [Zhong et al., 2010].
I SVM (Hinge loss) classifiers with a linear kernel.
I JDOT is the best-ranked method, with a 2% accuracy gain on average.
Amazon Review Classification dataset
Domains                 NN      DANN    JDOT (mse)   JDOT (Hinge)
books→dvd               0.805   0.806   0.794        0.795
books→kitchen           0.768   0.767   0.791        0.794
books→electronics       0.746   0.747   0.778        0.781
dvd→books               0.725   0.747   0.761        0.763
dvd→kitchen             0.760   0.765   0.811        0.821
dvd→electronics         0.732   0.738   0.778        0.788
kitchen→books           0.704   0.718   0.732        0.728
kitchen→dvd             0.723   0.730   0.764        0.765
kitchen→electronics     0.847   0.846   0.844        0.845
electronics→books       0.713   0.718   0.740        0.749
electronics→dvd         0.726   0.726   0.738        0.737
electronics→kitchen     0.855   0.850   0.868        0.872

Mean                    0.759   0.763   0.783        0.787
Numerical experiments
I The dataset aims at predicting review sentiment across domains [Blitzer et al., 2006].
I Comparison with the Domain Adversarial Neural Network (DANN) [Ganin et al., 2016].
I The classifier f is a neural network with the same architecture as DANN.
I JDOT has better accuracy, and the Hinge (classification) loss performs better than the mean squared error.
Wifi localization regression dataset
Domains    KRR          SurK         DIP          DIP-CC       GeTarS       CTC          CTC-TIP      JDOT
t1→t2      80.84±1.14   90.36±1.22   87.98±2.33   91.30±3.24   86.76±1.91   89.36±1.78   89.22±1.66   93.03±1.24
t1→t3      76.44±2.66   94.97±1.29   84.20±4.29   84.32±4.57   90.62±2.25   94.80±0.87   92.60±4.50   90.06±2.01
t2→t3      67.12±1.28   85.83±1.31   80.58±2.10   81.22±4.31   82.68±3.71   87.92±1.87   89.52±1.14   86.76±1.72

hallway1   60.02±2.60   76.36±2.44   77.48±2.68   76.24±5.14   84.38±1.98   86.98±2.02   86.78±2.31   98.83±0.58
hallway2   49.38±2.30   64.69±0.77   78.54±1.66   77.8±2.70    77.38±2.09   87.74±1.89   87.94±2.07   98.45±0.67
hallway3   48.42±1.32   65.73±1.57   75.10±3.39   73.40±4.06   80.64±1.76   82.02±2.34   81.72±2.25   99.27±0.41
Numerical experiments
I The objective is to predict the position of a device on a discretized grid [Zhang et al., 2013].
I Same experimental protocol as [Zhang et al., 2013, Gong et al., 2016].
I Comparison with domain-invariant projection and its cluster-regularized version ([Baktashmotlagh et al., 2013], DIP and DIP-CC), generalized target shift ([Zhang et al., 2015], GeTarS), and conditional transferable components with its target information preservation regularization ([Gong et al., 2016], CTC and CTC-TIP).
I JDOT solves the adaptation problem for transfer across devices (10% accuracy gain on the hallway tasks).
Conclusion
Joint distribution optimal transportation for domain adaptation
I A general framework for domain adaptation.
I Models the transformation of the joint distribution.
I Theoretically justified with a generalization bound.
I Similar in scope to [Long et al., 2014a], but uses the Wasserstein distance instead of MMD.
I Does not depend on the hypothesis class of the function (linear, kernel, neural network).
Thank you
Python code available on GitHub:
https://github.com/rflamary/POT
I OT LP solver, Sinkhorn (stabilized, ε-scaling, GPU).
I Domain adaptation with OT.
I Barycenters, Wasserstein unmixing.
Papers available on my website:
https://remi.flamary.com/
References I
Baktashmotlagh, M., Harandi, M., Lovell, B., and Salzmann, M. (2013).
Unsupervised domain adaptation by domain invariant projection.
In ICCV, pages 769–776.
Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., and Vaughan, J. W.
(2010).
A theory of learning from different domains.
Machine Learning, 79(1-2):151–175.
Ben-David, S., Shalev-Shwartz, S., and Urner, R. (2012).
Domain adaptation–can quantity compensate for quality?
In Proc of ISAIM.
Benamou, J.-D., Carlier, G., Cuturi, M., Nenna, L., and Peyré, G. (2015).
Iterative Bregman projections for regularized transportation problems.
SISC.
References II
Blitzer, J., McDonald, R., and Pereira, F. (2006).
Domain adaptation with structural correspondence learning.
In Proc. of the 2006 conference on empirical methods in natural language processing,
pages 120–128.
Courty, N., Flamary, R., Tuia, D., and Rakotomamonjy, A. (2016a).
Optimal transport for domain adaptation.
Pattern Analysis and Machine Intelligence, IEEE Transactions on.
Courty, N., Flamary, R., Tuia, D., and Rakotomamonjy, A. (2016b).
Optimal transport for domain adaptation.
IEEE Transactions on Pattern Analysis and Machine Intelligence.
Cuturi, M. (2013).
Sinkhorn distances: Lightspeed computation of optimal transportation.
In Neural Information Processing Systems (NIPS), pages 2292–2300.
Donahue, J., Jia, Y., Vinyals, O., Hoffman, J., Zhang, N., Tzeng, E., and Darrell, T.
(2014).
Decaf: A deep convolutional activation feature for generic visual recognition.
In ICML.
References III
Fernando, B., Habrard, A., Sebban, M., and Tuytelaars, T. (2013).
Unsupervised visual domain adaptation using subspace alignment.
In ICCV.
Ferradans, S., Papadakis, N., Peyré, G., and Aujol, J.-F. (2014).
Regularized discrete optimal transport.
SIAM Journal on Imaging Sciences, 7(3).
Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F.,
Marchand, M., and Lempitsky, V. (2016).
Domain-adversarial training of neural networks.
Journal of Machine Learning Research, 17(59):1–35.
Gong, B., Shi, Y., Sha, F., and Grauman, K. (2012).
Geodesic flow kernel for unsupervised domain adaptation.
In CVPR.
Gong, M., Zhang, K., Liu, T., Tao, D., Glymour, C., and Schölkopf, B. (2016).
Domain adaptation with conditional transferable components.
In ICML, volume 48, pages 2839–2848.
References IV
Kantorovich, L. (1942).
On the translocation of masses.
C.R. (Doklady) Acad. Sci. URSS (N.S.), 37:199–201.
Long, M., Wang, J., Ding, G., Pan, S. J., and Philip, S. Y. (2014a).
Adaptation regularization: A general framework for transfer learning.
IEEE Transactions on Knowledge and Data Engineering, 26(5):1076–1089.
Long, M., Wang, J., Ding, G., Sun, J., and Yu, P. (2014b).
Transfer joint matching for unsupervised domain adaptation.
In CVPR, pages 1410–1417.
Mansour, Y., Mohri, M., and Rostamizadeh, A. (2009).
Domain adaptation: Learning bounds and algorithms.
In Proc. of COLT.
Monge, G. (1781).
Mémoire sur la théorie des déblais et des remblais.
De l’Imprimerie Royale.
Perrot, M., Courty, N., Flamary, R., and Habrard, A. (2016).
Mapping estimation for discrete optimal transport.
In Neural Information Processing Systems (NIPS).
References V
R. Gopalan, R. L. and Chellappa, R. (2014).
Unsupervised adaptation across domain shifts by generating intermediate data
representations.
IEEE Transactions on Pattern Analysis and Machine Intelligence, page To be published.
Redko, I., Habrard, A., and Sebban, M. (2016).
Theoretical Analysis of Domain Adaptation with Optimal Transport.
ArXiv e-prints.
Saenko, K., Kulis, B., Fritz, M., and Darrell, T. (2010).
Adapting visual category models to new domains.
In ECCV, LNCS, pages 213–226.
Si, S., Tao, D., and Geng, B. (2010).
Bregman divergence-based regularization for transfer subspace learning.
IEEE Transactions on Knowledge and Data Engineering, 22(7):929–942.
Sugiyama, M., Suzuki, T., Nakajima, S., Kashima, H., von Bünau, P., and Kawanabe,
M. (2008).
Direct importance estimation for covariate shift adaptation.
Annals of the Institute of Statistical Mathematics, 60(4):699–746.
References VI
Thorpe, M., Park, S., Kolouri, S., Rohde, G., and Slepcev, D. (2016).
A transportation Lp distance for signal analysis.
CoRR, abs/1609.08669.
Tuia, D., Flamary, R., Rakotomamonjy, A., and Courty, N. (2015).
Multitemporal classification without new labels: a solution with optimal transport.
In 8th International Workshop on the Analysis of Multitemporal Remote Sensing Images.
Urner, R., Shalev-Shwartz, S., and Ben-David, S. (2011).
Access to unlabeled data can speed up prediction time.
In Proceedings of ICML, pages 641–648.
Zhang, K., Gong, M., and Schölkopf, B. (2015).
Multi-source domain adaptation: A causal view.
In AAAI Conference on Artificial Intelligence, pages 3150–3157.
Zhang, K., Zheng, V. W., Wang, Q., Kwok, J. T., Yang, Q., and Marsic, I. (2013).
Covariate shift in Hilbert space: A solution via surrogate kernels.
In ICML.
Zhong, E., Fan, W., Yang, Q., Verscheure, O., and Ren, J. (2010).
Cross validation framework to choose amongst models and datasets for transfer learning.
In ECML/PKDD.
Generalization error in domain adaptation
Theoretical bounds [Ben-David et al., 2010]
The error made by a given classifier in the target domain is upper-bounded by the sum of three terms:
I The generalization error of the classifier in the source domain;
I A divergence measure between the densities of the two domains (W1 in [Redko et al., 2016]);
I A third term measuring how much the two classification tasks are related to each other.
Optimal transport for domain adaptation [Courty et al., 2016a]
I Model the discrepancy between the distributions through a general transformation.
I Use optimal transport to estimate the transportation map between the two distributions.
I Use regularization terms for the optimal transport problem that exploit labels from the source domain.
OT for domain adaptation: Step 1
(Figure: same illustration as before, showing the dataset, the optimal transport, and the classification on the transported samples.)
Step 1: Estimate the optimal transport between the distributions.
I Choose the ground metric (squared Euclidean in our experiments).
I Using regularization allows:
  I Large-scale and regular OT with entropic regularization [Cuturi, 2013].
  I Taking class labels into account in the transport with a group lasso [Courty et al., 2016a].
I Efficient optimization based on Bregman projections [Benamou et al., 2015] and:
  I Majorization-minimization for the non-convex group lasso.
  I Generalized conditional gradient for general regularization (convex lasso, Laplacian).
OT for domain adaptation: Steps 2 & 3
(Figure: same illustration as before, showing the dataset, the optimal transport, and the classification on the transported samples.)
Step 2: Transport the training samples onto the target distribution.
I The mass of each source sample is spread onto the target samples (row of γ_0).
I We estimate the transported position of each source sample [Ferradans et al., 2014]:

\hat{x}_i^s = \arg\min_x \sum_j \gamma_0(i, j)\, c(x, x_j^t). \qquad (8)

I It can be computed efficiently for a quadratic loss (see the sketch below).
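For the quadratic cost, the argmin in (8) is simply the γ_0-weighted barycenter of the target samples. A minimal sketch with POT; the toy data and names are illustrative assumptions.

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(50, 2)
Xt = rng.randn(60, 2) + 1.0

gamma0 = ot.emd(ot.unif(len(Xs)), ot.unif(len(Xt)), ot.dist(Xs, Xt))

# For c(x, x_j^t) = ||x - x_j^t||^2, the solution of (8) is the weighted average of the
# target samples: each row of gamma0, normalized, multiplied by Xt.
Xs_transported = (gamma0 / gamma0.sum(axis=1, keepdims=True)) @ Xt
```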
Step 3: Learn a classifier on the transported training samples.
I This becomes a classic ML problem when the samples are well transported.
Efficient regularized optimal transport
(Figure: transportation cost matrix C and optimal coupling matrix γ obtained with Sinkhorn.)
Entropic regularization [Cuturi, 2013]

\gamma_0^\lambda = \arg\min_{\gamma\in\mathcal{P}}\; \langle \gamma, C\rangle_F - \lambda\, h(\gamma), \qquad (9)

where h(\gamma) = -\sum_{i,j} \gamma(i,j)\log\gamma(i,j) is the entropy of γ.
I The entropy term introduces smoothness in \gamma_0^\lambda.
I Solved with the Sinkhorn-Knopp algorithm (efficient implementations in parallel and on GPU).
I A general framework using Bregman projections is given in [Benamou et al., 2015] (see the sketch below).
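For reference, a didactic sketch of the Sinkhorn-Knopp iterations solving (9); this is not the stabilized implementation shipped with POT.

```python
import numpy as np

def sinkhorn(a, b, C, reg, n_iter=1000):
    """Return the entropically regularized coupling between histograms a and b."""
    K = np.exp(-C / reg)              # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)             # scale columns to match marginal b
        u = a / (K @ v)               # scale rows to match marginal a
    return u[:, None] * K * v[None, :]

# usage (with POT helpers): gamma = sinkhorn(ot.unif(ns), ot.unif(nt), ot.dist(Xs, Xt), reg=1e-1)
```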