
Optimal Transport in Large-Scale

Machine Learning Applications

Nhat Ho

The University of Texas, Austin

1
Talk Outline
• Applications/ Methods of Optimal Transport (OT): Brief Introduction
• Foundations of Optimal Transport
• Monge’s Optimal Transport Formulation
• Kantorovich’s Optimal Transport Formulation
• Entropic Regularized Optimal Transport
• Application of Optimal Transport to Deep Generative Model
• Wasserstein GAN
• Issues of Wasserstein GAN and Solutions

2
Some Applications/ Methods of Optimal
Transport (OT): Brief Introduction

3
OT’s Method: Deep Generative Model

(Example data: CIFAR-10 images, speech)

Goal: Given a set of high-dimensional data (e.g., images, speech, text, etc.),
we would like to learn the underlying data distribution

4
OT’s Method: Deep Generative Model
• OT is used as a loss between the push-forward distribution from a low-dimensional
latent space and the empirical distribution of the data

• Popular examples: Wasserstein GAN [1, 2], Wasserstein Autoencoder [3]

Image from Internet

5
OT’s Method: Transfer Learning

Image from Internet

• Domain Adaptation: An important problem in designing autonomous vehicles is to
make sure that a model trained in some particular weather/ environment/ time
(source domains) still performs well under other weathers/ environments/ times
(target domains)

• Optimal transport is an efficient loss function to capture the difference between these
domains (e.g., [4] and [5])
6
OT’s Method: Transfer Learning

• Domain Generalization: An important example is developing a face recognition
system for a new generation of iPhone (target domain) based on previous iPhones
(source domains), without the expensive cost of collecting new data for the new
iPhone

• Optimal Transport also offers a great solution for this application


7
OT’s method: 3D Objects’ Representation

Above: Input 3D images

Below: Reconstruction of 3D images based on optimal transport [6]

[6] Trung Nguyen, Hieu Pham, Tam Le, Tung Pham, Nhat Ho, Son Hua. Point-set distances for learning representations of 3D point clouds. ICCV, 2021
8
OT’s Method: (Multilevel) Clustering

• Each image contains several annotated regions, such as those of animals,
buildings, trees, etc.

• Goal: Based on the clustering behaviors of annotated regions from the images,
we would like to learn the themes/ clusters of images
9
OT’s Method: Multilevel Clustering

Three clusters of images obtained using optimal transport (cf. [7], [8])

[7] Nhat Ho, Long Nguyen, Mikhail Yurochkin, Hung Bui, Viet Huynh, and Dinh Phung. Multilevel clustering via Wasserstein means. ICML, 2017

[8] Viet Huynh, Nhat Ho, Nhan Dam, Long Nguyen, Mikhail Yurochkin, Hung Bui, Dinh Phung. On efficient multilevel clustering via Wasserstein distances. Journal of Machine
Learning Research (JMLR), 2021
10
OT’s Method: Other Applications
• Optimal Transport is also a powerful tool for other important applications:
• Forecasting Time Series (e.g., forecasting sales (Walmart), forecasting
expenses (Amazon), etc.) [9]

• Machine Translation [10]


• Robust/ Reliable Machine Learning [11]
• Fairness/ Responsible AI

11
OT is also useful as a foundational theory tool

(Figure: a deconvolutional generative model in which the object category and latent variables at layer L are rendered through intermediate layers L-1, ..., 1 down to the image at layer 0)

• Optimal transport can be used to understand the behaviors of latent variables
associated with ReLU and max-pooling in Convolutional Neural Networks (CNNs)
(cf. [12])

[12] Tan Nguyen, Nhat Ho, Ankit Patel, Anima Anandkumar, Michael I. Jordan, Richard Baraniuk. A Bayesian Perspective of Convolutional Neural Networks through a
Deconvolutional Generative Model. Under Revision, Journal of Machine Learning Research (JMLR), 2022
12
OT is also useful as a foundational theory tool
• A few other popular applications of OT for understanding machine learning methods and
models include:

• Mixture models and hierarchical models: characterizing the convergence rates
of parameter estimation, performing model selection, etc. (cf. [13], [14], [15])

• Distributionally robust optimization: Optimal Transport can be used to define a
perturbed neighborhood of the true distribution (cf. [16], [17])

• Some potential new research directions: Optimal Transport can be useful to understand
• (i) the self-training procedure in semi-supervised learning
• (ii) self-attention in Transformers
• (iii) contrastive learning, self-supervised learning, etc.

13
Foundations of Optimal Transport
• Monge’s Optimal Transport Formulation
• Kantorovich’s Optimal Transport Formulation
• Entropic Regularized Optimal Transport

14
Monge’s OT Formulation: Motivation
• Optimal Transport was created by mathematician Gaspard Monge to find
optimal ways to transport commodities and products under certain constraints

Image from Internet

15
Monge’s OT Formulation: Motivation
• We start with a simple practical example of moving products from bakeries
(denoted by B) to restaurants (denoted by R)

• Two bakeries will not transport their products to the same restaurant

• We denote by $C_{ij}$ the distance between bakery $B_i$ and restaurant $R_j$

• Goal: Find the shortest total distance to move the products from the bakeries to the
restaurants

(Figure: bakeries $B_1, \ldots, B_4$ and restaurants $R_1, \ldots, R_4$ with pairwise distances $C_{ij}$)
16
Monge’s OT Formulation
• Monge's Optimal Transport is:

$$\min_{\sigma \in \mathrm{Per}_n} \; \frac{1}{n} \sum_{i=1}^{n} C_{i, \sigma(i)}, \qquad (1)$$

where $n$ is the number of restaurants (equal to the number of bakeries) and $\mathrm{Per}_n$ is the set of all permutations of $\{1, 2, \ldots, n\}$

• Monge's formulation finds the optimal matching between the bakeries and the
restaurants
17
Monge’s OT Formulation
• If we search over all possible permutations in the optimization problem, the
complexity of solving Monge's Optimal Transport is $\mathcal{O}(n!)$ (the total number
of permutations of $\{1, 2, \ldots, n\}$ is $n!$)

• By using the Hungarian algorithm for bipartite matching, we can obtain an
improved complexity of $\mathcal{O}(n^3)$

• When $C_{ij} = |B_i - R_j|^2$, i.e., in the one-dimensional setting, we can sort both
sets of points to compute Monge's Optimal Transport in equation (1)
with a complexity of $\mathcal{O}(n \log n)$ (see the sketch below)

18
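Below is a minimal numerical sketch (toy data, not from the slides) of both routes: a general assignment solver, which plays the role of the Hungarian-style $\mathcal{O}(n^3)$ matching, and the one-dimensional sorting shortcut. The data, sizes, and the use of SciPy are illustrative assumptions.

```python
# A minimal sketch: Monge's matching via an assignment solver vs. 1D sorting.
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
n = 6
bakeries = rng.normal(size=n)      # toy 1D locations B_i
restaurants = rng.normal(size=n)   # toy 1D locations R_j

# General case: cost matrix C_ij = |B_i - R_j|^2, solved by an O(n^3) assignment solver
C = (bakeries[:, None] - restaurants[None, :]) ** 2
rows, cols = linear_sum_assignment(C)      # optimal permutation sigma
cost_assignment = C[rows, cols].mean()     # (1/n) * sum_i C_{i, sigma(i)}

# 1D shortcut: sorting both sets gives the optimal matching in O(n log n)
cost_sorted = np.mean((np.sort(bakeries) - np.sort(restaurants)) ** 2)

print(cost_assignment, cost_sorted)        # the two values coincide
```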
Monge’s OT Formulation: Equivalent Form
• We define $P_n = \frac{1}{n}\sum_{i=1}^{n} \delta_{B_i}$ and $Q_n = \frac{1}{n}\sum_{i=1}^{n} \delta_{R_i}$ as the corresponding
empirical measures of the bakeries and the restaurants

• We denote $C_{ij} = \lVert B_i - R_j \rVert^2$ as the distance between $B_i$ and $R_j$

• Monge's formulation in equation (1) can be rewritten as

$$\inf_{T} \int \lVert x - T(x) \rVert^2 \, dP_n(x),$$

where the mapping $T: \mathbb{R}^d \to \mathbb{R}^d$ in the infimum is such that $T_{\sharp} P_n = Q_n$

• Here, $T_{\sharp} P_n$ denotes the push-forward measure of $P_n$ via the mapping $T$
19
Push-forward measure

• Recall that $P_n = \frac{1}{n}\sum_{i=1}^{n} \delta_{B_i}$ and $T: \mathbb{R}^d \to \mathbb{R}^d$

• Then, $T_{\sharp} P_n = \frac{1}{n}\sum_{i=1}^{n} \delta_{T(B_i)}$

• The equation $T_{\sharp} P_n = Q_n$ implies that

$$\{T(B_1), T(B_2), \ldots, T(B_n)\} \equiv \{R_1, R_2, \ldots, R_n\}$$
20
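As a small illustration of the push-forward of an empirical measure (a toy sketch with a made-up map $T$, not from the slides), one can check $T_{\sharp}P_n = Q_n$ by comparing the mapped supports with the supports of $Q_n$:

```python
# Push-forward of an empirical measure: map the support points through T.
import numpy as np

B = np.array([0.0, 1.0, 2.0])          # supports of P_n (uniform weights 1/n)
T = lambda x: 2.0 * x + 1.0            # a toy map T: R -> R (illustrative choice)

pushforward_support = T(B)             # supports of T#P_n: {T(B_1), ..., T(B_n)}
R = np.array([1.0, 3.0, 5.0])          # supports of Q_n

# T#P_n = Q_n holds iff the two support sets coincide
print(np.allclose(np.sort(pushforward_support), np.sort(R)))
```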
General Monge’s OT Formulation
• In general, we can define Monge's optimal transport beyond discrete probability
distributions, for example between Gaussian distributions

• For any two probability distributions $P$ and $Q$, the Monge's Optimal Transport
between $P$ and $Q$ can be defined as

$$\inf_{T} \int \lVert x - T(x) \rVert^2 \, dP(x), \qquad (2)$$

where the mapping $T: \mathbb{R}^d \to \mathbb{R}^d$ in the infimum is such that $T_{\sharp} P = Q$

• Note that, for continuous distributions, $T_{\sharp} P = Q$ means that
$P(T^{-1}(A)) = Q(A)$ for any measurable set $A$ of $\mathbb{R}^d$
21
General Monge’s OT Formulation: Challenges
• Good settings: When (i) $P$ and $Q$ admit density functions or (ii) $P$ and $Q$ are
discrete with uniform weights, there exist optimal maps $T$ that solve Monge's OT
in equation (2)

• Pathological settings:
• In certain settings when $P$ and $Q$ are discrete, a mapping $T$ such that
$T_{\sharp} P = Q$ may not exist
• Assume that $P = \delta_{x}$ and $Q = \frac{1}{2}\delta_{y_1} + \frac{1}{2}\delta_{y_2}$; the equation $T_{\sharp} P = Q$ means that
$$P(T^{-1}(\{y_1\})) = Q(\{y_1\}) = \frac{1}{2}$$
• However, this is not possible, as $P(T^{-1}(\{y_1\})) \in \{0, 1\}$ depending on
whether $x \in T^{-1}(\{y_1\})$
22
General Monge’s OT Formulation: Challenges
• The non-existence of a transport map $T$ under pathological settings makes it
challenging to use Monge's OT formulation when the probability distributions
$P$ and $Q$ are discrete
• Furthermore, due to the non-linearity of the constraint $T_{\sharp} P = Q$, it is
non-trivial to solve for or approximate the optimal mapping $T$ in equation (2)

• A relaxed, optimization-friendly form of Monge's OT formulation is needed

23
Kantorovich’s Optimal Transport Formulation

24
Kantorovich’s OT Formulation
• Given two probability distributions $P$ and $Q$, the Kantorovich's Optimal Transport
between $P$ and $Q$ can be defined as

$$\mathrm{OT}(P, Q) := \inf_{\pi \in \Pi(P, Q)} \int c(x, y) \, d\pi(x, y), \qquad (3)$$

where $\Pi(P, Q)$ is the set of all joint distributions with marginals $P$ and $Q$,
and $c(\cdot, \cdot)$ is a given cost metric

• $\pi$ is called a transportation plan

Image from Internet

• Under certain assumptions (see Section 4 in [18]), the Kantorovich's OT and the Monge's
OT are equivalent
25
Kantorovich’s OT for Discrete Measures
• When $P = \delta_{\eta}$ and $Q = \sum_{i=1}^{m} q_i \delta_{\theta_i}$, then

$$\mathrm{OT}(P, Q) = \sum_{i=1}^{m} q_i \cdot c(\eta, \theta_i)$$

• When $P = \sum_{i=1}^{n} p_i \delta_{\eta_i}$ and $Q = \sum_{j=1}^{m} q_j \delta_{\theta_j}$, then

$$\mathrm{OT}(P, Q) = \min_{\pi \ge 0} \; \sum_{i=1}^{n} \sum_{j=1}^{m} \pi_{ij} \cdot c(\eta_i, \theta_j), \qquad (4)$$

$$\text{s.t.} \quad \sum_{i=1}^{n} \pi_{ij} = q_j \ \text{ for all } 1 \le j \le m; \qquad \sum_{j=1}^{m} \pi_{ij} = p_i \ \text{ for all } 1 \le i \le n$$

(Figure: the transportation plan $\pi = (\pi_{ij})$ between source weights $p_1, \ldots, p_n$ at supports $\eta_i$ and target weights $q_1, \ldots, q_m$ at supports $\theta_j$)

• These simple examples show that there always exists an optimal transportation
plan when $P$ and $Q$ are discrete, which is in contrast to the Monge's OT
formulation
26
Kantorovich’s OT for Discrete Measures
• We can rewrite problem (4) as follows:

$$\mathrm{OT}(P, Q) = \min_{\pi \in \mathbb{R}^{n \times m}} \; \langle C, \pi \rangle \qquad (5)$$

$$\text{s.t.} \quad \pi \ge 0; \quad \pi 1_m = p; \quad \pi^{\top} 1_n = q,$$

where $p = (p_1, p_2, \ldots, p_n)$ and $q = (q_1, q_2, \ldots, q_m)$
• Problem (5) is a linear programming problem
• The set $\mathcal{P} = \{\pi \in \mathbb{R}^{n \times m} : \pi \ge 0, \ \pi 1_m = p, \ \pi^{\top} 1_n = q\}$ is called a
transportation polytope, which is a convex set
27
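Since (5) is a plain linear program, it can be solved with any off-the-shelf LP solver. The following is a minimal sketch on toy data using SciPy's linprog; the supports, weights, and sizes are illustrative assumptions, and the marginal constraints are encoded as equality constraints on the flattened plan.

```python
# Solving the discrete Kantorovich problem (5) as a linear program.
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)
n, m = 4, 5
p = np.full(n, 1.0 / n)                               # source weights p
q = np.full(m, 1.0 / m)                               # target weights q
x = rng.normal(size=(n, 2))                           # source supports eta_i
y = rng.normal(size=(m, 2))                           # target supports theta_j
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)    # cost c(eta_i, theta_j)

# Equality constraints: row sums of pi equal p, column sums equal q
A_eq = np.zeros((n + m, n * m))
for i in range(n):
    A_eq[i, i * m:(i + 1) * m] = 1.0                  # sum_j pi_ij = p_i
for j in range(m):
    A_eq[n + j, j::m] = 1.0                           # sum_i pi_ij = q_j
b_eq = np.concatenate([p, q])

res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
pi = res.x.reshape(n, m)                              # optimal transportation plan
print("OT(P, Q) =", res.fun)
```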
Computational Complexity of Kantorovich’s Formulation
• The theorem below gives the best known computational complexity of the network
simplex algorithm for solving the linear program (5)

Theorem 1: The best computational complexity of the network simplex
algorithm for solving the linear program (5) is of the order [19]

$$\mathcal{O}\big((n + m)\, n m \, \log(n + m) \, \log\big((n + m)\lVert C \rVert_{\infty}\big)\big)$$

• When $n = m$, the complexity becomes $\mathcal{O}(n^3 \log n)$, which is practically very
expensive when $n$ is very large

• Therefore, the network simplex algorithm is not sufficiently scalable for
large-scale machine learning and deep learning applications
28
Entropic (Regularized) Optimal Transport

29
Entropic (Regularized) Optimal Transport
• We now discuss a useful approach to obtain a scalable approximation of optimal
transport

• The idea is to regularize the optimal transport problem (5) by the entropy of the
transportation plan [20], giving the entropic (regularized) optimal transport:

$$\mathrm{EOT}_{\eta}(P, Q) = \min_{\pi \in \mathcal{P}(p, q)} \; \langle C, \pi \rangle - \eta H(\pi), \qquad (6)$$

where $\eta > 0$ is the regularization parameter;

$$H(\pi) = - \sum_{i=1}^{n} \sum_{j=1}^{m} \pi_{ij} \log(\pi_{ij});$$

$$\mathcal{P}(p, q) = \{\pi \in \mathbb{R}^{n \times m} : \pi 1_m = p, \ \pi^{\top} 1_n = q\};$$

here, we use the convention that $\log(x) = -\infty$ when $x \le 0$
30
Properties of Entropic Optimal Transport
• For each regularization parameter $\eta > 0$, the objective function of the entropic
regularized optimal transport is an $\eta$-strongly convex function

• This is because the function $-H(\cdot)$ is a 1-strongly convex function as long
as $\pi_{ij} \le 1$ for all $(i, j)$

• Since the constraint set $\mathcal{P}(p, q)$ is convex, there exists a unique
optimal transportation plan, denoted by $\pi^{*}_{\eta}$, for the entropic regularized
optimal transport
31
Properties of Entropic Optimal Transport
Theorem 2: (a) When $\eta \to 0$, we have

$$\mathrm{EOT}_{\eta}(P, Q) \to \mathrm{OT}(P, Q), \qquad \pi^{*}_{\eta} \to \arg\min_{\pi \in \mathcal{P}: \ \langle C, \pi \rangle = \mathrm{OT}(P,Q)} \{-H(\pi)\}$$

(b) When $\eta \to \infty$, we have

$$\mathrm{EOT}_{\eta}(P, Q) \to \langle C, p \otimes q \rangle, \qquad \pi^{*}_{\eta} \to p \otimes q = p q^{\top}$$

• The results of part (b) indicate that when the regularization parameter $\eta$ is
sufficiently large, we can treat the distributions $P$ and $Q$ as if they were independent
32
Sinkhorn Algorithm
• We now discuss a popular algorithm, the Sinkhorn algorithm, for solving
the entropic regularized optimal transport (6)

• Optimization challenges of the primal form: The primal form (6) is a
constrained optimization problem with several constraints; therefore, it may be
non-trivial to solve the primal form directly

• Dual form of entropic optimal transport (6): We will demonstrate that solving
the dual form of (6), which is an unconstrained optimization problem, is easier

• Solving the dual form is equivalent to solving

$$\min_{u \in \mathbb{R}^n, \, v \in \mathbb{R}^m} \; \sum_{i=1}^{n} \sum_{j=1}^{m} \exp\Big(u_i + v_j - \frac{C_{ij}}{\eta}\Big) - u^{\top} p - v^{\top} q \qquad (7)$$
33
Sinkhorn Algorithm: Detailed Description
• Step 1: Initialize $u^{0} = 0 \in \mathbb{R}^n$ and $v^{0} = 0 \in \mathbb{R}^m$
• Step 2: For any $t \ge 0$, we perform
• If $t$ is an even number, then for all $(i, j)$:

$$u_i^{t+1} = \log(p_i) - \log\Big(\sum_{j'=1}^{m} \exp\Big(v_{j'}^{t} - \frac{C_{ij'}}{\eta}\Big)\Big), \qquad v_j^{t+1} = v_j^{t}$$

• If $t$ is an odd number, then for all $(i, j)$:

$$v_j^{t+1} = \log(q_j) - \log\Big(\sum_{i'=1}^{n} \exp\Big(u_{i'}^{t} - \frac{C_{i'j}}{\eta}\Big)\Big), \qquad u_i^{t+1} = u_i^{t}$$

• Increase $t \leftarrow t + 1$
34
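The updates above translate directly into a log-domain implementation. The sketch below uses toy data; the regularization value and iteration count are illustrative choices, and the final plan follows the diagonal-scaling form used on the next slide.

```python
# A minimal log-domain Sinkhorn sketch following the alternating updates above.
import numpy as np
from scipy.special import logsumexp

def sinkhorn_log(C, p, q, eta=0.05, n_iters=2000):
    u = np.zeros(C.shape[0])
    v = np.zeros(C.shape[1])
    for t in range(n_iters):
        if t % 2 == 0:   # even step: update u, keep v
            u = np.log(p) - logsumexp(v[None, :] - C / eta, axis=1)
        else:            # odd step: update v, keep u
            v = np.log(q) - logsumexp(u[:, None] - C / eta, axis=0)
    # Transportation plan pi_ij = exp(u_i + v_j - C_ij / eta)
    return np.exp(u[:, None] + v[None, :] - C / eta)

rng = np.random.default_rng(0)
n, m = 5, 6
C = rng.random((n, m))
p, q = np.full(n, 1 / n), np.full(m, 1 / m)
pi = sinkhorn_log(C, p, q)
print(pi.sum(axis=1), pi.sum(axis=0))   # approximately p and q
print("approx OT cost:", (C * pi).sum())
```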
Approximation of Optimal Transport via Sinkhorn algorithm
• We now briefly discuss the complexity of approximating the value of optimal
transport via the Sinkhorn algorithm

• Goal: We would like to find a transportation plan $\bar{\pi} \in \mathcal{P}$ (see the definition of $\mathcal{P}$
on Slide 27) such that

$$\langle C, \bar{\pi} \rangle \le \min_{\pi \in \mathcal{P}} \langle C, \pi \rangle + \epsilon$$

• We call $\bar{\pi}$ an $\epsilon$-approximation plan
35
Approximation of Optimal Transport via Sinkhorn algorithm
• Denote by $(u^t, v^t)$ the updates at step $t$ of the Sinkhorn algorithm (see
Slide 34)

• The corresponding transportation plan is

$$\pi^{t} := \mathrm{diag}(\exp(u^t)) \cdot K \cdot \mathrm{diag}(\exp(v^t)),$$

where $K = \exp(-C/\eta)$ entrywise and $\mathrm{diag}(\exp(u^t))$ denotes the diagonal matrix with $\exp(u^t_1), \ldots, \exp(u^t_n)$ on
its diagonal

• Unfortunately, $\pi^{t} \notin \mathcal{P}$; namely, in general we have neither $\pi^{t} 1_m = p$ nor
$(\pi^{t})^{\top} 1_n = q$
36
Approximation of Optimal Transport via Sinkhorn algorithm
• Therefore, we need an extra rounding step to transform $\pi^{t}$ into $\bar{\pi}^{t}$ such
that $\bar{\pi}^{t} 1_m = p$ and $(\bar{\pi}^{t})^{\top} 1_n = q$

• Details of that rounding step are in Algorithm 2 of [21] (we skip this step in the
lecture for simplicity)

Theorem 3: Assume that $\eta = \dfrac{\epsilon}{4 \log(\max\{n, m\})}$. Denote by $(u^t, v^t)$ the updates from the
Sinkhorn algorithm for the entropic optimal transport with regularization parameter $\eta$, and
denote by $\bar{\pi}^{t}$ the rounded transportation plan obtained from these updates. Then, we
have

$$\langle C, \bar{\pi}^{t} \rangle \le \min_{\pi \in \mathcal{P}} \langle C, \pi \rangle + \epsilon$$

as long as $t = \mathcal{O}\Big(\dfrac{\lVert C \rVert_{\infty}^{2} \log(\max\{n, m\})}{\epsilon^{2}}\Big)$.
37
Approximation of Optimal Transport via Sinkhorn algorithm
• The proof of Theorem 3 can be found in Theorem 2 of [22]

• Each iteration of the Sinkhorn algorithm requires $\max\{n, m\}^{2}$ arithmetic
operations

• The result of Theorem 3 indicates that the total computational complexity of
approximating the optimal transport via the Sinkhorn algorithm is

$$\mathcal{O}\Big(\max\{n, m\}^{2} \, \frac{\lVert C \rVert_{\infty}^{2} \log(\max\{n, m\})}{\epsilon^{2}}\Big)$$

• This is much cheaper than the complexity of the network simplex algorithm in
Theorem 1, which is of the order $\mathcal{O}(\max\{n, m\}^{3})$
38
Other Approximations of Optimal Transport
• There are other optimization algorithms that outperform Sinkhorn:
• Greedy version of Sinkhorn (Greenkhorn) [23]
• Accelerated Sinkhorn [24]
• The scalable approximations of optimal transport via these optimization
algorithms have led to several interesting methodological developments in
machine learning

[23] Tianyi Lin, Nhat Ho, Michael I. Jordan. On efficient optimal transport: an analysis of greedy and accelerated mirror descent algorithms. ICML, 2019

[24] Tianyi Lin, Nhat Ho, Michael I. Jordan. On the efficiency of entropic regularized algorithms for optimal transport. Journal of Machine Learning Research (JMLR), 2022

39
Deep Generative Model via Optimal Transport
• Wasserstein GAN
• Issues of Wasserstein GAN:
• Misspecified Matchings of Minibatch Schemes
• Curse of Dimensionality

40
Generative Model
• We now discuss an important application of optimal transport: the generative
modeling task

(Example datasets: ImageNet, CIFAR-10)

• Goal: Given a collection of very high-dimensional data, we would like to learn
the underlying data distribution P effectively
41
Generative Model
• There are several approaches:
• Nonparametric approaches:
• Frequentist density estimator
• Bayesian nonparametric models
• Parametric approaches via latent variable assumption:
• Bayesian hierarchical models
• Deep learning models, e.g., Variational Auto-Encoder (VAE)
[25], Generative Adversarial Networks (GANs) [26], etc.

42
Generative Adversarial Networks (GANs)
• Generative Adversarial Networks are an instance of implicit methods, i.e., they
do not require explicit density estimation

• May allow a smooth interpolation across images


• May be able to capture the underlying variation of the data (images
with unseen patterns, etc.)

• It is different from Variational Auto-Encoder, which is an instance of explicit


methods

43
Generative Adversarial Networks (GANs)
General recipe of implicit methods:

• We generate z from some distribution pZ( . ) (e.g., Gaussian distribution)


• We consider a "fake" data generating distribution Tϕ(z), where Tϕ is some
vector-valued function parametrized by ϕ

• We need to make sure that Tϕ( . ) is as close as possible to the true


distribution P of the data (Here, we do not make any parametric
assumption on the true distribution)

Some divergences between Tϕ( . ) and P are needed

44
Generative Adversarial Networks (GANs)
• For GANs [26], the choice of divergence is the Jensen-Shannon (JS) divergence:

$$\min_{\phi} \; \mathrm{JS}(T_{\phi}(z), P), \qquad (8)$$

where $\mathrm{JS}(T_{\phi}(z), P) := \mathrm{KL}\Big(T_{\phi}(z), \frac{P + T_{\phi}(z)}{2}\Big) + \mathrm{KL}\Big(P, \frac{P + T_{\phi}(z)}{2}\Big)$

• If we denote $G = T_{\phi}$, this is equivalent to the following minimax game:

$$\min_{G} \max_{D} \; \mathbb{E}_{x \sim P}[\log(D(x))] + \mathbb{E}_{z \sim p_Z}[\log(1 - D(G(z)))],$$

where $G$ is the generator and $D$ is the discriminator

• This is an instance of a non-convex non-concave minimax optimization problem
45
Continuity Issue of GANs
• The JS divergence used in GANs is problematic [27] when $T_{\phi}(z)$ and $P$ fall
into the following cases:

• Disjoint supports
• One is a continuous distribution and the other is a discrete distribution
• Example: To see this, we consider the following simple example:
$T_{\phi}(z) = (\phi, z)$ where $z \sim U(0,1)$, and $P = (0, U(0,1))$
• Direct calculation shows that
$\mathrm{JS}(T_{\phi}(z), P) = \log(2)$ if $\phi \ne 0$ and $0$ otherwise

• Therefore, the JS divergence is discontinuous at the true parameter $\phi = 0$ and
takes a constant value when $\phi \ne 0$ (gradient descent methods cannot be used!)
46
Wasserstein GANs
• One solution to the continuity issue of the JS divergence is to use a weaker
metric, such as optimal transport

• The paper [27] suggests using the first-order Wasserstein metric
• For any two distributions $P$ and $Q$, the first-order Wasserstein metric between
$P$ and $Q$ is defined as follows:

$$W_1(P, Q) = \inf_{\pi \in \Pi(P, Q)} \int \lVert x - y \rVert \, d\pi(x, y),$$

where $\Pi(P, Q)$ denotes the set of joint probability measures with marginals $P$ and $Q$
47
Wasserstein GANs
• The objective of Wasserstein GANs is then given by:

$$\min_{\phi} \; W_1(T_{\phi}(z), P) \qquad (9)$$

• The first-order Wasserstein metric remains meaningful even when the two
distributions

• have disjoint supports

• or one distribution is discrete and the other is continuous
• To see this, we reconsider the example on Slide 46
48
Wasserstein GANs
• In this case, we can verify that $W_1(T_{\phi}(z), P) = |\phi|$ for all $\phi \in \mathbb{R}$ (a numerical
check is sketched below)
• This function is clearly continuous in $\phi$, and we can use an optimization
method to solve $\min_{\phi} |\phi|$

• In general, if $T_{\phi}(\cdot)$ is continuous in $\phi$, the first-order Wasserstein metric
$W_1(T_{\phi}(z), P)$ is also continuous in $\phi$

• If $T_{\phi}(\cdot)$ is locally Lipschitz and satisfies some regularity conditions, then
$W_1(T_{\phi}(z), P)$ is differentiable almost everywhere (see Theorem 1 in [27])
49
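As referenced above, the claim $W_1(T_{\phi}(z), P) = |\phi|$ can be checked numerically on empirical samples. The sketch below uses a toy sample size chosen for illustration and solves the empirical OT exactly with an assignment solver.

```python
# Numerical check of W1(T_phi(z), P) = |phi| for the example on Slide 46.
import numpy as np
from scipy.optimize import linear_sum_assignment

def empirical_w1(X, Y):
    # Exact W1 between two uniform empirical measures with the same number of points
    C = np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)
    r, c = linear_sum_assignment(C)
    return C[r, c].mean()

rng = np.random.default_rng(0)
n, phi = 500, 0.7
X = np.stack([np.full(n, phi), rng.random(n)], axis=1)   # samples of T_phi(z) = (phi, z)
Y = np.stack([np.zeros(n), rng.random(n)], axis=1)       # samples of P = (0, U(0,1))

print(empirical_w1(X, Y), abs(phi))   # the two numbers are close
```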
Wasserstein GANs
• These observations indicate that the first-order Wasserstein metric is a valid
choice for GANs

• From the definition of the first-order Wasserstein metric, we can rewrite equation
(9) as follows:

$$\min_{\phi} \; W_1(T_{\phi}(z), P) = \min_{\phi} \; \inf_{\pi \in \Pi(T_{\phi}(z), P)} \int \lVert x - y \rVert \, d\pi(x, y) \qquad (10)$$

• Directly optimizing the objective function in equation (10) is not feasible in
general

• We will discuss a dual-function approach for dealing with this optimization
problem
50
Wasserstein GANs: Dual Function Approach
• Dual-Function Approach: For any two probability distributions $P$ and $Q$, the
dual form of the first-order Wasserstein metric between $P$ and $Q$ has the
following form:

$$W_1(P, Q) = \sup_{f \in \mathcal{L}_1} \; \mathbb{E}_{x \sim P}[f(x)] - \mathbb{E}_{x \sim Q}[f(x)], \qquad (11)$$

where $\mathcal{L}_1$ is the set of 1-Lipschitz functions $f$, i.e., $|f(x) - f(y)| \le \lVert x - y \rVert$ for
all $x, y \in \mathbb{R}^d$

• Please refer to Section 5 in [27] for how to derive the dual form (11)
51
Wasserstein GANs: Dual Function Approach
• Given the dual form of the first-order Wasserstein metric in equation (11), we
can rewrite Wasserstein GANs as follows:

$$\min_{\phi} W_1(T_{\phi}(z), P) = \min_{\phi} \max_{f \in \mathcal{L}_1} \; \mathbb{E}_{x \sim T_{\phi}(z)}[f(x)] - \mathbb{E}_{x \sim P}[f(x)] = \min_{\phi} \max_{f \in \mathcal{L}_1} \mathcal{T}(\phi, f) \qquad (12)$$

• Updating the function $f$ in Wasserstein GANs is non-trivial, as it is a
maximization problem over a functional space

• We consider approximating the space $\mathcal{L}_1$ using deep neural networks: we
parametrize it as $\{f_{\omega}\}$, where $\omega$ are the weights of the neural network
52
Wasserstein GANs: Dual Function Approach
• Therefore, we approximate the Wasserstein GANs objective (12) as

$$\min_{\phi} \max_{\omega} \; \mathbb{E}_{z \sim p_Z}[f_{\omega}(T_{\phi}(z))] - \mathbb{E}_{x \sim P}[f_{\omega}(x)] \qquad (13)$$

• We can solve for both $\phi$ and $\omega$ via (stochastic) gradient descent methods

• The detailed optimization algorithm for solving the approximated Wasserstein
GANs (13) is Algorithm 1 in [27]
53
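The sketch below is a compact PyTorch rendering of objective (13) on toy 2D data. It mimics the spirit of Algorithm 1 in [27] (several critic steps per generator step, weight clipping as a crude Lipschitz constraint), but the network sizes, learning rates, clipping value, and toy data distribution are illustrative assumptions rather than the paper's exact settings.

```python
# A minimal WGAN-style training loop for objective (13) on toy 2D data.
import torch
import torch.nn as nn

def mlp(d_in, d_out):
    return nn.Sequential(nn.Linear(d_in, 64), nn.ReLU(),
                         nn.Linear(64, 64), nn.ReLU(),
                         nn.Linear(64, d_out))

generator = mlp(2, 2)   # T_phi: latent z -> data space
critic = mlp(2, 1)      # f_omega: neural approximation of the dual function
opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
opt_f = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

def sample_data(b):     # stand-in for samples from the true distribution P
    return torch.randn(b, 2) * 0.5 + torch.tensor([2.0, 0.0])

for step in range(2000):
    # Critic: maximize E_z[f(T_phi(z))] - E_x[f(x)], several steps per generator step
    for _ in range(5):
        z, x = torch.randn(64, 2), sample_data(64)
        loss_f = -(critic(generator(z).detach()).mean() - critic(x).mean())
        opt_f.zero_grad(); loss_f.backward(); opt_f.step()
        for w in critic.parameters():
            w.data.clamp_(-0.01, 0.01)   # weight clipping as a crude Lipschitz constraint

    # Generator: minimize E_z[f(T_phi(z))]; the second term in (13) does not depend on phi
    z = torch.randn(64, 2)
    loss_g = critic(generator(z)).mean()
    opt_g.zero_grad(); loss_g.backward(); opt_g.step()

# Generated samples should drift toward the data mean [2, 0] as training proceeds
print(generator(torch.randn(1000, 2)).mean(dim=0))
```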
Limitations of Dual Function Approach
• Limitations of the dual-function approach:
• It relies on the choice of the first-order Wasserstein metric and the Euclidean
distance to have a nice dual form

• The Euclidean distance assumption can be very strong in practice, as it is
not well suited to capturing the differences between high-dimensional data

• In general, we would like a more general form of Wasserstein GANs,
named optimal transport GANs (OT-GANs):

$$\min_{\phi} \; \mathrm{OT}(T_{\phi}(z), P), \qquad (14)$$

where $\mathrm{OT}(T_{\phi}(z), P) = \inf_{\pi \in \Pi(T_{\phi}(z), P)} \int c(x, y) \, d\pi(x, y)$ and $c(\cdot, \cdot)$ is some metric
54
Optimal Transport GANs (OT-GANs)
• For a general cost $c(\cdot, \cdot)$, the dual form of OT-GANs (14) can be non-trivial
to use

• Therefore, people also advocate the direct optimization of OT-GANs

• Challenge: Since both $T_{\phi}(z)$ and $P$ are continuous, we generally cannot
compute $\mathrm{OT}(T_{\phi}(z), P)$ directly

• Solution: We can use sample versions of $T_{\phi}(z)$ and $P$ to approximate
$\mathrm{OT}(T_{\phi}(z), P)$
55
Optimal Transport GANs (OT-GANs)
• For the distribution $P$, we can use $P_n = \frac{1}{n}\sum_{i=1}^{n} \delta_{X_i}$, where $X_1, X_2, \ldots, X_n$ are the
data

• For $T_{\phi}(z)$, we can use $\frac{1}{M}\sum_{i=1}^{M} \delta_{T_{\phi}(z_i)}$, where $z_1, z_2, \ldots, z_M$ are i.i.d. samples from
$p_Z(\cdot)$

• This suggests the following approximation of OT-GANs (14):

$$\inf_{\phi} \; \mathrm{OT}\Big(\frac{1}{M}\sum_{i=1}^{M} \delta_{T_{\phi}(z_i)}, \ \frac{1}{n}\sum_{i=1}^{n} \delta_{X_i}\Big) \qquad (15)$$
56
Computational Challenge of OT-GANs

57
Computational Challenge of OT-GANs
• Computational Challenge:
• The computational complexity of approximating the optimal transport between
$\frac{1}{M}\sum_{i=1}^{M} \delta_{T_{\phi}(z_i)}$ and $\frac{1}{n}\sum_{i=1}^{n} \delta_{X_i}$ is $\mathcal{O}(\max\{M, n\}^2)$

• In practice, $n$ can be very large (as large as a few million) and $M$ needs to be
chosen quite large (scaling with the dimension) to guarantee a good
approximation of $T_{\phi}(z)$ via the empirical distribution $\frac{1}{M}\sum_{i=1}^{M} \delta_{T_{\phi}(z_i)}$
• Unfortunately, the memory cost of optimal transport at this scale is unavoidable
• Practical Solution: A popular approach is to consider minibatches
of the entire data, which we refer to as minibatch optimal transport GANs
58
Minibatch Optimal Transport

59
Minibatch Optimal Transport GANs (mOT-GANs)
• To set the stage, we need the following notation:
• We denote by $m$ the minibatch size, where $m \le \min\{M, n\}$

• We denote by $X_n^{(m)}$ and $z_M^{(m)}$ the sets of all $m$-element subsets of $\{X_1, \ldots, X_n\}$
and $\{z_1, \ldots, z_M\}$, respectively

• For any $X^{m} \in X_n^{(m)}$ and $z^{m} \in z_M^{(m)}$, we respectively denote by
$P_{X^m} = \frac{1}{m}\sum_{x \in X^m} \delta_{x}$ and $P_{z^m} = \frac{1}{m}\sum_{z' \in z^m} \delta_{z'}$ the empirical measures of $X^{m}$ and
$z^{m}$
60
Minibatch Optimal Transport GANs (mOT-GANs)
Minibatch Optimal Transport GANs (mOT-GANs): For any batch size
$1 \le m \le \min\{M, n\}$ and number of minibatches $k$, we draw $X_1^{m}, \ldots, X_k^{m}$ and $z_1^{m}, \ldots, z_k^{m}$
uniformly from $X_n^{(m)}$ and $z_M^{(m)}$. The minibatch optimal transport GANs objective is given by:

$$\min_{\phi} \; \frac{1}{k} \sum_{i=1}^{k} \mathrm{OT}\big(T_{\phi}(P_{z_i^m}), P_{X_i^m}\big) \qquad (16)$$

• The common choice in practice is $k = 1$, with $m$ chosen
based on the memory of the GPU

• Note that the choice $k = 1$ can lead to sub-optimal results in practice
61
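For fixed generator samples, the loss (16) can be sketched as follows. The data and batch sizes are toy assumptions; the exact OT between equal-size uniform minibatches is computed here with an assignment solver, whereas in practice one would use an entropic solver and differentiate through it.

```python
# A self-contained sketch of the minibatch OT loss (16) on toy numpy data.
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_uniform(X, Y):
    # Exact OT between uniform empirical measures with the same number of points
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    r, c = linear_sum_assignment(C)
    return C[r, c].mean()

def minibatch_ot(gen_samples, data, m=64, k=4, rng=None):
    rng = rng or np.random.default_rng(0)
    losses = []
    for _ in range(k):
        zb = gen_samples[rng.choice(len(gen_samples), m, replace=False)]
        xb = data[rng.choice(len(data), m, replace=False)]
        losses.append(ot_uniform(zb, xb))
    return np.mean(losses)    # (1/k) * sum_i OT(T_phi(P_{z_i^m}), P_{X_i^m})

rng = np.random.default_rng(0)
data = rng.normal(size=(10000, 2))        # stand-in for {X_1, ..., X_n}
gen = rng.normal(size=(5000, 2)) + 1.0    # stand-in for {T_phi(z_1), ..., T_phi(z_M)}
print(minibatch_ot(gen, data, m=64, k=4, rng=rng))
```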
Minibatch Optimal Transport GANs (mOT-GANs)
• Computational Complexity of mOT-GANs:
• When $\phi$ is given, the complexity of computing $\mathrm{OT}(T_{\phi}(P_{z_i^m}), P_{X_i^m})$ exactly
is of the order $\mathcal{O}(m^3)$ if we use an exact solver for the linear
program

• We can improve the complexity to $\mathcal{O}(m^2)$ by using entropic regularized
optimal transport to approximate $\mathrm{OT}(T_{\phi}(P_{z_i^m}), P_{X_i^m})$

• Therefore, the best complexity of approximating the full objective
$\frac{1}{k}\sum_{i=1}^{k} \mathrm{OT}(T_{\phi}(P_{z_i^m}), P_{X_i^m})$ is $\mathcal{O}(k m^2)$
62
OT GANs: Minibatch Approach
• For the approximation of OT-GANs in equation (15), the complexity is
$\mathcal{O}(\max\{M, n\}^2)$

• As long as $k m^2 \ll \max\{M, n\}^2$, the complexity of mOT-GANs is much
cheaper than that of OT-GANs for each parameter $\phi$

• mOT-GANs are thus convenient for large-scale settings of deep generative
modeling

• Similar to OT-GANs, we can solve for the optimal parameter $\phi$ of mOT-GANs (16) via
(stochastic) gradient descent methods
63
Wasserstein GANs: Minibatch Approach
• Examples of CIFAR-10 images generated by mOT-GANs:

(Figure: real data and generated samples with minibatch size m = 200 and number of minibatches k = 2, 4, 8)
64
Issues of mOT-GANs
• mOT-GANs suffer from a misspecified matching issue, i.e., the optimal transport
plan from mOT-GANs contains wrong matchings that do not appear in the
original optimal transport plan of OT-GANs

• These misspecified matchings lead to a decline in the performance of mOT-GANs

• There are a few recent proposals to solve the misspecified matching issue,
including partial optimal transport [28], hierarchical optimal transport
[29], and unbalanced optimal transport [30]

65
Minibatch Partial Optimal Transport [28]

[28] Khai Nguyen, Dang Nguyen, Tung Pham, Nhat Ho. Improving minibatch optimal transport via partial transportation. ICML, 2022

66
Misspecified Matching Issue of m-OT
• We consider a simple example where $P_n, Q_n$ are two empirical distributions
with 5 supports in 2D: $\{(0,1), (0,2), (0,3), (0,4), (0,5)\}$ and
$\{(1,1), (1,2), (1,3), (1,4), (1,5)\}$

(Figure. LHS: optimal matching (black) between $P_n, Q_n$; RHS: wrong matchings (red) induced by minibatches)

Alleviating the Misspecified Matching of m-OT via Partial Transportation
• We now demonstrate that we can alleviate the misspecified matching issue
via partial optimal transport

• The Partial Optimal Transport (POT) between $P_n$ and $Q_n$ is defined as follows:

$$\mathrm{POT}_{s}(P_n, Q_n) = \min_{\pi \in \Pi_{s}(u_n, u_n)} \langle C, \pi \rangle,$$

where $C$ is the distance matrix; $s$ is the transportation fraction; $u_n$ is the uniform
measure over $n$ supports; and

$$\Pi_{s}(u_n, u_n) := \{\pi \in \mathbb{R}_{+}^{n \times n} : \pi 1_n \le u_n, \ \pi^{\top} 1_n \le u_n, \ 1_n^{\top} \pi 1_n = s\}$$
Minibatch Partial Optimal Transport
• The Minibatch Partial Optimal Transport (m-POT) [28] between $P_n$ and $Q_n$ with
transportation fraction $s$ is defined as

$$\text{m-POT}_{s}(P_n, Q_n) = \frac{1}{k} \sum_{i=1}^{k} \mathrm{POT}_{s}\big(P_{X_i^m}, P_{Y_i^m}\big),$$

where $X_1^{m}, \ldots, X_k^{m} \in X_n^{(m)}$; $Y_1^{m}, \ldots, Y_k^{m} \in Y_n^{(m)}$; and
$P_{X_i^m}, P_{Y_i^m}$ are the empirical measures associated with $X_i^{m}$ and $Y_i^{m}$
Computational Complexity of Minibatch Partial Optimal Transport
• We have an equivalent way to write m-POT in terms of m-OT as follows:

$$\text{m-POT}_{s}(P_n, Q_n) = \frac{1}{k} \sum_{i=1}^{k} \min_{\pi \in \Pi(\bar{\alpha}_i, \bar{\alpha}_i)} \langle \bar{C}_i, \pi \rangle,$$

where $\bar{C}_i = \begin{pmatrix} C_i & 0 \\ 0 & A_i \end{pmatrix} \in \mathbb{R}_{+}^{(m+1) \times (m+1)}$;
$C_i$ is the cost matrix formed by the differences of the elements of $X_i^{m}$ and $Y_i^{m}$;
$A_i > 0$ for all $i = 1, 2, \ldots, k$; and
$\bar{\alpha}_i = [u_m, 1 - s]$ for all $i = 1, 2, \ldots, k$
• By using the entropic regularized approach, we can compute m-POT with
computational complexity $\mathcal{O}(k(m+1)^2)$, which is comparable to that of m-OT
(a direct linear-programming sketch of $\mathrm{POT}_s$ is given below)
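As referenced above, here is a direct linear-programming sketch of $\mathrm{POT}_s$ on toy data. The supports, cost, and fraction $s$ are illustrative assumptions; the slide's $(m+1)\times(m+1)$ dummy-point reduction gives an equivalent formulation that is friendlier to entropic solvers.

```python
# Partial OT, POT_s(P_n, Q_n), solved directly as a linear program on toy data.
import numpy as np
from scipy.optimize import linprog

def partial_ot(C, s):
    n, m = C.shape
    un, um = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    # Inequality constraints: row sums <= u_n, column sums <= u_m
    A_ub = np.zeros((n + m, n * m))
    for i in range(n):
        A_ub[i, i * m:(i + 1) * m] = 1.0
    for j in range(m):
        A_ub[n + j, j::m] = 1.0
    b_ub = np.concatenate([un, um])
    # Equality constraint: total transported mass equals s
    A_eq = np.ones((1, n * m))
    b_eq = np.array([s])
    res = linprog(C.ravel(), A_ub=A_ub, b_ub=b_ub,
                  A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
    return res.fun, res.x.reshape(n, m)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
Y = rng.normal(size=(5, 2)) + 2.0
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
cost, plan = partial_ot(C, s=0.6)    # transport only 60% of the mass
print(cost, plan.sum())              # plan.sum() is approximately 0.6
```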
Minibatch Partial Optimal Transport
• The corresponding transportation plan of minibatch partial optimal transport
with transportation fraction $s$ is given by:

$$\pi^{\text{m-POT}_s} = \frac{1}{k} \sum_{i=1}^{k} \pi^{\mathrm{POT}_s}_{P_{X_i^m}, P_{Y_i^m}},$$

where $\pi^{\mathrm{POT}_s}_{P_{X_i^m}, P_{Y_i^m}}$ is the transportation matrix obtained from solving $\mathrm{POT}_{s}(P_{X_i^m}, P_{Y_i^m})$;
it is expanded to an $n \times n$ matrix by padding zero entries at indices that do not belong
to $X_i^{m}$ and $Y_i^{m}$
Minibatch Partial Optimal Transport
• m-POT can alleviate misspecified matchings

(Figure: $P_n, Q_n$ are two empirical distributions with 5 supports in 2D,
$\{(0,1), (0,2), (0,3), (0,4), (0,5)\}$ and $\{(1,1), (1,2), (1,3), (1,4), (1,5)\}$)
Minibatch Partial Optimal Transport
• m-POT can alleviate misspecified matchings

(Figure: the transportation between two empirical measures with 10 supports drawn
from two mixtures of Gaussians with two components)
Experiments: Deep Generative Model

CelebA is a large-scale face-attributes dataset with more than 200,000 celebrity images.
Batch of Minibatches Optimal Transport [29]

[29] Khai Nguyen, Dang Nguyen, Quoc Nguyen, Tung Pham, Dinh Phung, Hung Bui, Trung Le, Nhat Ho. On transportation of mini-batches: A
hierarchical approach. ICML, 2022

75
Alleviating the Misspecified Matching of m-OT via a Hierarchical Approach
• m-POT requires choosing a good transportation fraction $s$, which can be non-trivial in
practice

• We now describe another approach that can alleviate the misspecified
matching of m-OT without any tuning parameter

• The Batch of Minibatches Optimal Transport (BoMb-OT) between $P_n$ and $Q_n$ is defined as

$$\text{BoMb-OT}(P_n, Q_n) = \min_{\gamma \in \Pi(P_k^{\otimes m}, Q_k^{\otimes m})} \; \sum_{i=1}^{k} \sum_{j=1}^{k} \gamma_{ij} \, \mathrm{OT}\big(P_{X_i^m}, P_{Y_j^m}\big),$$

where $X_1^{m}, \ldots, X_k^{m} \in X_n^{(m)}$; $Y_1^{m}, \ldots, Y_k^{m} \in Y_n^{(m)}$;

$$P_k^{\otimes m} = \frac{1}{k}\sum_{i=1}^{k} \delta_{X_i^m} \quad \text{and} \quad Q_k^{\otimes m} = \frac{1}{k}\sum_{i=1}^{k} \delta_{Y_i^m};$$

and $P_{X_i^m}, P_{Y_j^m}$ are the empirical measures associated with $X_i^{m}$ and $Y_j^{m}$
(a small numerical sketch follows below)
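As mentioned above, a small numerical sketch of BoMb-OT on toy data follows. Since both outer marginals are uniform over the $k$ minibatches, the outer coupling $\gamma$ can be taken to be a permutation matrix divided by $k$, so an assignment solver suffices for this illustration; data and batch sizes are assumptions.

```python
# A toy sketch of BoMb-OT: an outer OT over minibatches weighted by inner OT costs.
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_uniform(X, Y):
    # Exact OT between uniform empirical measures with the same number of points
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    r, c = linear_sum_assignment(C)
    return C[r, c].mean()

def bomb_ot(X, Y, m=32, k=4, rng=None):
    rng = rng or np.random.default_rng(0)
    Xb = [X[rng.choice(len(X), m, replace=False)] for _ in range(k)]
    Yb = [Y[rng.choice(len(Y), m, replace=False)] for _ in range(k)]
    D = np.array([[ot_uniform(xb, yb) for yb in Yb] for xb in Xb])  # inner OT costs
    # Outer OT between uniform measures over the k minibatches; with uniform
    # marginals the optimal gamma is a permutation divided by k
    r, c = linear_sum_assignment(D)
    return D[r, c].mean()

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 2))
Y = rng.normal(size=(2000, 2)) + 1.0
print(bomb_ot(X, Y, m=32, k=4, rng=rng))
```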
Batch of Minibatches Optimal Transport
• The corresponding transportation plan of the Batch of Minibatches Optimal
Transport (BoMb-OT) between $P_n$ and $Q_n$ is defined as

$$\pi_k^{\text{BoMb-OT}} = \sum_{i=1}^{k} \sum_{j=1}^{k} \gamma_{ij} \, \pi^{\mathrm{OT}}_{P_{X_i^m}, P_{Y_j^m}},$$

where $\pi^{\mathrm{OT}}_{P_{X_i^m}, P_{Y_j^m}}$ is the transportation matrix returned by solving $\mathrm{OT}(P_{X_i^m}, P_{Y_j^m})$;
it is expanded to an $n \times n$ matrix by padding zero entries at indices that do not belong
to $X_i^{m}$ and $Y_j^{m}$; and
$\gamma$ is the transportation matrix between $P_k^{\otimes m}$ and $Q_k^{\otimes m}$
Batch of Minibatches Optimal Transport

(Figure: the transportation between two empirical measures with 10 supports drawn
from two Gaussians)
Experiments: Deep Generative Model
Curse of Dimensionality of OT-GANs

81
Curse of Dimensionality of OT-GANs
• Another important issue of OT-GANs is the curse of dimensionality
• The number of samples required for OT-GANs to obtain a good
estimate of the underlying distribution of the data grows exponentially with
the dimension

• Therefore, using OT-GANs for large-scale deep generative models can
be expensive in terms of sample size

• Solution: We utilize sliced OT-GANs and their variants [31], [32], [33], [34]

[31] Khai Nguyen, Nhat Ho, Tung Pham, Hung Bui. Distributional sliced-Wasserstein and applications to deep generative modeling. ICLR, 2021

[32] Khai Nguyen, Nhat Ho, Tung Pham, Hung Bui. Improving relational regularized autoencoders with spherical sliced fused Gromov Wasserstein. ICLR, 2021

[33] Khai Nguyen, Nhat Ho. Revisiting projected Wasserstein metric on images: from vectorization to convolution. Arxiv Preprint, 2022

[34] Khai Nguyen, Nhat Ho. Amortized projection optimization for sliced Wasserstein generative models. Arxiv Preprint, 2022
82
Sliced Optimal Transport
• We first define sliced optimal transport, which is the key ingredient of sliced OT-GANs
• The sliced optimal transport (OT) distance between two probability distributions $\mu$ and $\nu$ is defined
as follows:

$$\mathrm{SW}_{p}(\mu, \nu) := \Big( \int_{\mathbb{S}^{d-1}} W_{p}^{p}(\theta_{\sharp}\mu, \theta_{\sharp}\nu) \, d\theta \Big)^{1/p},$$

where $\theta_{\sharp}\mu$ is the push-forward probability measure of $\mu$ through the function $T_{\theta}: \mathbb{R}^d \to \mathbb{R}$
with $T_{\theta}(x) = \theta^{\top} x$;

$p \ge 1$ is the order of the sliced optimal transport;

$W_p$ is the $p$-th order Wasserstein metric
83
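A minimal Monte Carlo sketch of $\mathrm{SW}_p$ follows (toy data, assumed sizes). It assumes two uniform empirical measures with the same number of points, so each one-dimensional Wasserstein distance reduces to sorting the projections.

```python
# Monte Carlo estimate of the sliced Wasserstein distance SW_p on toy data.
import numpy as np

def sliced_wasserstein(X, Y, n_projections=200, p=2, rng=None):
    rng = rng or np.random.default_rng(0)
    d = X.shape[1]
    theta = rng.normal(size=(n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)   # directions on S^{d-1}
    total = 0.0
    for t in theta:
        x_proj = np.sort(X @ t)        # 1D push-forward theta#mu
        y_proj = np.sort(Y @ t)        # 1D push-forward theta#nu
        total += np.mean(np.abs(x_proj - y_proj) ** p)       # W_p^p in 1D via sorting
    return (total / n_projections) ** (1.0 / p)

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 16))
Y = rng.normal(size=(1000, 16)) + 0.5
print(sliced_wasserstein(X, Y))
```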
Properties of Sliced OT
There are three key properties of sliced optimal transport that make it
appealing for large-scale applications:

• Sliced OT is a proper metric on the space of probability measures,
namely, it satisfies the identity, symmetry, and triangle-inequality
properties

• The computational complexity of sliced OT between probability
measures with at most $n$ supports is $\mathcal{O}(n \log n)$, which is (much) faster
than that of OT, which is $\mathcal{O}(n^2)$ (via the entropic regularized approach)

• Sliced OT does not suffer from the curse of dimensionality, namely, the
number of samples required for sliced OT to obtain a good estimate of the
underlying probability distribution does not scale exponentially with the
dimension

84
Sliced-OT GANs
• Given the definition of sliced OT, the sliced optimal transport GANs (Sliced-OT GANs) objective is:

$$\min_{\phi} \; \mathrm{SW}_{p}(T_{\phi}(z), P),$$

where $T_{\phi}$ is some vector-valued function parametrized by $\phi$ and
$P$ is the true distribution of the data

• However, for generative models on images, this form of Sliced-OT GANs means that
we first vectorize the images and then project them to a one-dimensional space

• The spatial structure of the images is not captured efficiently by the vectorization
step

• Memory is used inefficiently, since each slicing direction is a vector with the same
dimension as the images

85
Sliced-OT GANs

86
Convolution Sliced-OT GANs [33]

[33] Khai Nguyen, Nhat Ho. Revisiting projected Wasserstein metric on images: from vectorization to convolution. Arxiv Preprint, 2022

87
Convolution
• To efficiently capture spatial structure and improve the memory efficiency
of sliced OT, we apply convolution operators in the slicing process of
sliced optimal transport

• Convolution operators have been demonstrated to be very efficient for
images in Convolutional Neural Networks (CNNs)

88
Convolution Slicer

• There are three useful types of convolution slicers for images:

• Convolution-base slicer: reduces the width and the height of the image
by half after each convolution operator

• Convolution-stride slicer: the size of its kernels does not depend on
the width and the height of the image, unlike the convolution-base
slicer

• Convolution-dilation slicer: has a bigger receptive field in each
convolution operator than the convolution-stride slicer
89
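The sketch below gives a rough, hedged illustration of the convolution-slicing idea in PyTorch: a stack of random strided convolutions maps each image to a scalar projection, and the per-slice cost is a one-dimensional Wasserstein distance computed by sorting. The kernel sizes, strides, normalization, and number of slicers are assumptions for illustration and may differ from the slicers defined in [33].

```python
# A rough sketch of a convolution-style slicer for sliced OT on images.
import torch
import torch.nn.functional as F

def conv_slice(images, kernels):
    # images: (B, C, H, W); apply random conv kernels with stride 2, then pool to a scalar
    x = images
    for w in kernels:
        x = F.conv2d(x, w, stride=2)
    return x.flatten(1).sum(dim=1)          # one scalar projection per image

def random_kernels(channels=3):
    sizes = [(1, channels, 3, 3), (1, 1, 3, 3), (1, 1, 3, 3)]
    ks = [torch.randn(*s) for s in sizes]
    return [k / k.norm() for k in ks]       # normalize each random kernel

X = torch.randn(64, 3, 32, 32)              # a minibatch of "real" images (toy data)
Y = torch.randn(64, 3, 32, 32)              # a minibatch of generated images (toy data)
cost = 0.0
for _ in range(50):                          # average over random convolution slicers
    k = random_kernels()
    px, py = conv_slice(X, k), conv_slice(Y, k)
    cost += torch.mean((px.sort().values - py.sort().values) ** 2)   # 1D W_2^2 via sorting
print((cost / 50).sqrt())
```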
Convolution Sliced Optimal Transport

90
Convolution Sliced Optimal Transport

91
Experiments: Deep Generative Models

L: the number of slices used to approximate the integral (equivalently, the expectation) in
sliced and convolution sliced optimal transport;
b: base; s: stride; d: dilation
92
Experiments: Deep Generative Models

93
Experiments: Deep Generative Models

94
Experiments: Deep Generative Models

95
Experiments: Deep Generative Models

96
Conclusion
• We have studied both the computational complexities of optimal transport as
well as its applications to deep generative models

• There are several interesting open directions:


• First direction: Further improving minibatch optimal transport in GANs and
other deep learning applications

• Second direction: Developing more efficient sliced optimal transport for
other applications, such as language models, etc.

• Third direction: Exploring more computationally efficient ways to compute


optimal transport

• Fourth direction: Studying further important variants of optimal
transport, such as unbalanced optimal transport, partial optimal
transport, etc.

97
Thank You!

98
References
[1] Martin Arjovsky, Soumith Chintala, Léon Bottou. Wasserstein Generative Adversarial Networks. ICML,
2017

[2] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron C. Courville. Improved
Training of Wasserstein GANs. NIPS, 2017

[3] Ilya Tolstikhin, Olivier Bousquet, Sylvain Gelly, Bernhard Scholkopf. Wasserstein Auto-Encoders. ICLR,
2018

[4] Nicolas Courty, Rémi Flamary, Devis Tuia, Alain Rakotomamonjy. Optimal Transport for Domain
Adaptation. IEEE Transactions on Pattern Analysis and Machine Intelligence (PAMI), 2017

[5] Bharath Bhushan Damodaran, Benjamin Kellenberger, Rémi Flamary, Devis Tuia, Nicolas Courty.
DeepJDOT: Deep Joint Distribution Optimal Transport for Unsupervised Domain Adaptation. ECCV, 2018

[6] Trung Nguyen, Hieu Pham, Tam Le, Tung Pham, Nhat Ho, Son Hua. Point-set distances for learning
representations of 3D point clouds. ICCV, 2021

[7] Nhat Ho, Long Nguyen, Mikhail Yurochkin, Hung Bui, Viet Huynh, and Dinh Phung. Multilevel clustering
via Wasserstein means. ICML, 2017

[8] Viet Huynh, Nhat Ho, Nhan Dam, Long Nguyen, Mikhail Yurochkin, Hung Bui, Dinh Phung. On efficient
multilevel clustering via Wasserstein distances. Journal of Machine Learning Research, 2021
99
References
[9] Xing Han, Tongzheng Ren, Jing Hu, Joydeep Ghosh, Nhat Ho. Efficient Forecasting of Large Scale Hierarchical Time
Series via Multilevel Clustering. Under review, NeurIPS, 2022

[10] Jingjing Xu, Hao Zhou, Chun Gan, Zaixiang Zheng, Lei Li. Vocabulary Learning via Optimal Transport for Neural Machine
Translation. ACL, 2021

[11] Khang Le, Huy Nguyen, Quang Nguyen, Tung Pham, Hung Bui, Nhat Ho. On robust optimal transport: Computational
complexity and barycenter computation . NeurIPS, 2021

[12] Nhat Ho, Tan Nguyen, Ankit Patel, Anima Anandkumar, Michael I. Jordan, Richard Baraniuk. A Bayesian Perspective
of Convolutional Neural Networks through a Deconvolutional Generative Model. Under Revision, Journal of Machine
Learning Research, 2021

[13] Long Nguyen. Convergence of latent mixing measures in finite and infinite mixture models. Annals of Statistics, 2013

[14] Nhat Ho, Long Nguyen. Convergence rates of parameter estimation for some weakly identifiable finite mixtures.
Annals of Statistics, 2016

[15] Nhat Ho, Chiao-Yu Yang, Michael I. Jordan. Convergence rates for Gaussian mixtures of experts. Journal of Machine
Learning Research, 2022 (Accepted Under Minor Revision)

[16] Rui Gao, Anton J Kleywegt. Distributionally robust stochastic optimization with Wasserstein distance. Arxiv preprint
arXiv:1604.02199, 2016

[17] Daniel Kuhn, Peyman Mohajerin Esfahani, Viet Anh Nguyen, Soroosh Shafieezadeh-Abadeh. Wasserstein
distributionally robust optimization: Theory and applications in machine learning. INFORMS Tutorials in Operations
Research
100
References
[18] Matthew Thorpe. Introduction to Optimal Transport (https://www.math.cmu.edu/~mthorpe/
OTNotes)

[19] Gabriel Peyré, Marco Cuturi. Computational Optimal Transport: With Applications to Data Science.
Foundations and Trends® in Machine Learning, 2019

[20] Marco Cuturi. Sinkhorn Distances: Lightspeed Computation of Optimal Transport. NIPS 2013

[21] Jason Altschuler, Jonathan Weed, Philippe Rigollet. Near-linear time approximation algorithms for
optimal transport via Sinkhorn iteration. NIPS, 2017

[22] Pavel Dvurechensky, Alexander Gasnikov, Alexey Kroshnin. Computational Optimal Transport:
Complexity by Accelerated Gradient Descent Is Better Than by Sinkhorn’s Algorithm. ICML, 2018

[23] T. Lin, N. Ho, M. I. Jordan. On efficient optimal transport: an analysis of greedy and accelerated
mirror descent algorithms. ICML, 2019

[24] T. Lin, N. Ho, M. I. Jordan. On the efficiency of entropic regularized algorithms for optimal
transport. Journal of Machine Learning Research (JMLR), 2022

101
References
[25] Diederik P Kingma, Max Welling. Auto-Encoding Variational Bayes. ICLR, 2014

[26] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil
Ozair, Aaron Courville, Yoshua Bengio. Generative Adversarial Networks. NIPS, 2014

[27] Martin Arjovsky, Soumith Chintala, Léon Bottou. Wasserstein Generative Adversarial Networks.
ICML, 2017

[28] Khai Nguyen, Dang Nguyen, Tung Pham, Nhat Ho. Improving minibatch optimal transport via partial
transportation. ICML, 2022

[29] Khai Nguyen, Dang Nguyen, Quoc Nguyen, Tung Pham, Dinh Phung, Hung Bui, Trung Le, Nhat Ho.
On transportation of mini-batches: A hierarchical approach. ICML, 2022

[30] Kilian Fatras, Thibault Sejourne, Rémi Flamary, and Nicolas Courty. Unbalanced minibatch optimal
transport; applications to domain adaptation. ICML, 2021

102
References
[31] Khai Nguyen, Nhat Ho, Tung Pham, Hung Bui. Distributional sliced-
Wasserstein and applications to deep generative modeling. ICLR, 2021

[32] Khai Nguyen, Nhat Ho, Tung Pham, Hung Bui. Improving relational
regularized autoencoders with spherical sliced fused Gromov Wasserstein. ICLR,
2021

[33] Khai Nguyen, Nhat Ho. Revisiting projected Wasserstein metric on images:
from vectorization to convolution. Arxiv Preprint, 2022

[34] Khai Nguyen, Nhat Ho. Amortized projection optimization for sliced
Wasserstein generative models. Arxiv Preprint, 2022

103
