Your Contrastive Learning Problem Is Secretly A Distribution Alignment Problem
Zihao Chen∗, Chi-Heng Lin, Ran Liu, Jingyun Xiao, Eva L. Dyer∗
School of Electrical & Computer Engineering
Georgia Tech, Atlanta, GA
Abstract
Despite the success of contrastive learning (CL) in vision and language, its theoretical foundations and mechanisms for building representations remain poorly understood. In this work, we build connections between noise contrastive estimation losses widely used in CL and distribution alignment with entropic optimal transport (OT). This connection allows us to develop a family of different losses and multistep iterative variants for existing CL methods. Intuitively, by using more information from the distribution of latents, our approach allows a more distribution-aware manipulation of the relationships within augmented sample sets. We provide theoretical insights and experimental evidence demonstrating the benefits of our approach for generalized contrastive alignment. Through this framework, it is possible to leverage tools in OT to build unbalanced losses that handle noisy views and to customize the representation space by changing the constraints on alignment. By reframing contrastive learning as an alignment problem and leveraging existing optimization tools for OT, our work provides new insights and connections between different self-supervised learning models, in addition to new tools that can be more easily adapted to incorporate domain knowledge into learning.
1 Introduction
In machine learning, the availability of vast amounts of unlabeled data has created an opportunity
to learn meaningful representations without relying on costly labeled datasets [26, 52, 27]. Self-supervised learning has emerged as a powerful solution to this problem, allowing models to leverage the inherent structure in data to build useful representations. Among self-supervised methods, contrastive learning (CL) is widely adopted for its ability to create robust representations by distinguishing
between similar (positive) and dissimilar (negative) data pairs. With success in fields like image
and language processing [8, 46], contrastive learning now also shows promise in domains where
cross-modal, noisy, or structurally complex data make labeling especially challenging [34, 56, 10].
Traditional contrastive learning methods primarily aim to bring positive pairs—often augmentations
of the same sample—closer together in representation space. While effective, this approach often
struggles with real-world challenges such as noise in views, variations in data quality, or shifts
introduced by complex transformations, where positive pairs may not perfectly align. Additionally,
in tasks requiring domain generalization, aligning representations across diverse domains (e.g.,
variations in style or sensor type) is critical but difficult to achieve with standard contrastive learning,
which typically lacks mechanisms for incorporating domain-specific relationships. These limitations
highlight the need for a more flexible approach that can adapt alignment strategies based on the data
structure, allowing for finer control over similarity and dissimilarity among samples.
∗Contact: {zchen959, evadyer}@gatech.edu
2 Background
2.1 Contrastive learning
Contrastive learning (CL) is a representation learning methodology that uses positive and negative
pairs to define similarity in the latent space. Let $\mathcal{D} = \{x_i\}_{i=1}^N$ denote our dataset. For each sample $x_i$ in a batch of training data of size $B$, we create two augmented copies $x'_i$ and $x''_i$ independently, i.e., $x'_i = \psi(x_i)$ where $\psi$ is an augmentation function drawn at random from some augmentation class $\mathcal{A}$, and likewise for $x''_i$. The pair $(x'_i, x''_i)$ is called a positive pair for $x_i$, while $(x'_i, x''_j)$ is treated as a negative pair for any $j \neq i$. One of the most widely used formulations of the CL problem, InfoNCE (INCE) [8], minimizes the negative log probability that each sample is correctly matched to its own positive view:
$$\mathcal{L}_{\mathrm{INCE}} = -\log \frac{e^{s_{ii}}}{e^{s_{ii}} + \sum_{j \neq i} e^{s_{ij}}}, \qquad (1)$$
where $s_{ij} = \varepsilon^{-1} f_\theta(x'_i)^\top f_\theta(x''_j)/(\|f_\theta(x'_i)\|\,\|f_\theta(x''_j)\|)$ is the score between augmented samples.
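As a concrete illustration, the following PyTorch-style sketch computes the InfoNCE loss in Equation (1) from two batches of augmented embeddings; the function name and the use of in-batch negatives only are our own simplifications, not a description of the authors' released code.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, eps=0.1):
    """Minimal InfoNCE (Equation 1) with in-batch negatives.

    z1[i] and z2[i] are embeddings of the two augmented views of sample i;
    s_ij is the cosine similarity scaled by 1/eps, and the positive logit
    is the diagonal entry s_ii.
    """
    z1 = F.normalize(z1, dim=1)                       # normalized f_theta(x'_i)
    z2 = F.normalize(z2, dim=1)                       # normalized f_theta(x''_j)
    scores = z1 @ z2.t() / eps                        # s_ij, shape (B, B)
    targets = torch.arange(z1.size(0), device=z1.device)
    # cross_entropy on row i computes -log( exp(s_ii) / sum_j exp(s_ij) )
    return F.cross_entropy(scores, targets)
```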
Building upon the principles of INCE, SimCLR [8] and MoCo [24] are two representative works that
form the foundation of contrastive learning methods for visual representation tasks. Alternatively,
BYOL [22] and SimSiam [9] discard the use of negative samples, avoiding the need for large batch sizes, and instead use exponential moving average-based updates to avoid representational collapse. Recent contrastive
methods have focused on improving the tolerance to noise in samples to enhance robustness in diverse
scenarios [13]. Among them, Robust INCE (RINCE) is a robust contrastive loss function characterized
by its symmetric properties and theoretical resistance to noisy labels [47, 12]. Specifically, RINCE
provides robustness to noisy views by introducing adjustable parameters λ and q [12] which rebalance
the cost of positive and negative views, resulting in the following loss:
$$\mathcal{L}^{\lambda,q}_{\mathrm{RINCE}} = \frac{1}{q}\Big(-e^{q\,s_{ii}} + \lambda^q\big(e^{s_{ii}} + \sum_{j\neq i} e^{s_{ij}}\big)^q\Big). \qquad (2)$$
By optimizing the above loss functions, the encoder f is trained to construct a semantically coherent
representation space where positive pairs of samples are positioned nearby, while those negative pairs
with divergent semantic attributes are separated [57].
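To make Equation (2) concrete, here is a hedged PyTorch sketch of the RINCE loss with in-batch negatives; the default values of λ and q are illustrative only and not taken from the paper.

```python
import torch
import torch.nn.functional as F

def rince_loss(z1, z2, lam=0.01, q=0.5, eps=0.1):
    """Robust InfoNCE (Equation 2), sketched for in-batch negatives.

    lam rebalances the positive and negative terms, and q in (0, 1]
    interpolates between an InfoNCE-like loss (small q) and a more
    robust, bounded loss (q = 1).
    """
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    scores = z1 @ z2.t() / eps                   # s_ij
    pos = torch.diagonal(scores)                 # s_ii
    denom = torch.exp(scores).sum(dim=1)         # e^{s_ii} + sum_{j != i} e^{s_ij}
    loss = (-torch.exp(q * pos) + (lam * denom) ** q) / q
    return loss.mean()
```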
2.2 Proximal Operators and Projections
To make the connections between different CL losses clearer later, we use the notion of proximal
operators. In words, the proximal operator will provide a way to find the closest point in some closed
convex set. Formally, we can define the proximal operator as follows.
Definition 1 (Proximal Operator). Let $d_\Gamma(x, v) = \Gamma(x) - \Gamma(v) - \langle \nabla\Gamma(v), x - v\rangle$ be a Bregman divergence with a convex function $\Gamma$. The proximal operator of $h : \mathcal{X} \to \mathbb{R} \cup \{+\infty\}$ is defined for a point $v \in \mathcal{X}$ and a closed convex set $\mathcal{B} \subseteq \mathcal{X}$ as
$$\mathrm{Prox}^{d_\Gamma}_{h,\mathcal{B}}(v) = \arg\min_{x \in \mathcal{B}} \{h(x) + d_\Gamma(x, v)\}.$$
Moreover, we can define a projection as the special case of the proximal operator obtained when $h(x)$ is the indicator function $h_{\mathcal{B}}(x) = \{0 \text{ if } x \in \mathcal{B};\ \infty \text{ if } x \notin \mathcal{B}\}$ of the constraint set $\mathcal{B}$. See Appendix A.2 for more details.
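As a simple illustrative example (ours, not the paper's): when $h$ is the indicator of the probability simplex and $d_\Gamma$ is the KL divergence, the proximal operator reduces to a renormalization, which is exactly the building block used in the Bregman projections below.

```python
import numpy as np

def kl_projection_simplex(v):
    """Prox of the simplex indicator under the KL divergence:
    argmin_{x in simplex} KL(x || v) is v rescaled to sum to one."""
    v = np.asarray(v, dtype=float)
    return v / v.sum()

print(kl_projection_simplex([2.0, 1.0, 1.0]))   # -> [0.5  0.25 0.25]
```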
2.3 Solving Optimal Transport Through Proximal Point Methods
Optimal transport (OT) is widely used to characterize the distance between two collections of samples $\{x_i\}_{i=1}^B$ and $\{y_j\}_{j=1}^B$ with associated measures $\mu = \sum_{i=1}^B p_i\,\delta_{x_i}$ and $\nu = \sum_{j=1}^B q_j\,\delta_{y_j}$, where $\delta_x$ and $\delta_y$ are Dirac deltas on a finite support [43]. Here, $p$ and $q$ are elements of the simplex $\Delta_B := \{v \in \mathbb{R}^B : v_i \geq 0,\ \sum_{i=1}^B v_i = 1\}$. OT aims to learn a joint coupling matrix, or transport plan, $P \in \mathbb{R}_+^{B \times B}$ that minimizes the cost of transporting mass, encoded by a cost matrix $C \in \mathbb{R}_+^{B \times B}$, from one distribution to the other. In practice, entropy regularization is used to solve the OT objective, resulting in the following entropy-regularized OT (EOT) objective:
$$\min_{P \in \mathcal{B}}\ \langle P, C\rangle - \varepsilon H(P), \quad \text{where } H(P) = -\sum_{ij} P_{ij}\log(P_{ij}), \qquad (3)$$
where $\varepsilon$ is a user-specified parameter that controls the amount of smoothing in the transport plan, and $C(x, y) = 1 - \langle x, y\rangle/(\|x\|\|y\|)$ is often set to encode the cosine similarity between pairs of samples.
The Sinkhorn Algorithm and its Interpretation as a Bregman Projection. Solving Equation (3) can be interpreted as an iterative alignment problem on a Hilbert space generated from the kernel $K_{ij} = \exp(-C_{i,j}/\varepsilon)$. This alignment problem can be solved through iterative Bregman projections onto two constraint sets that encode the marginals along the rows and columns [3, 5, 43]:
$$\mathcal{C}_1^\mu := \{P : P\mathbf{1}_B = \mu\}, \qquad \mathcal{C}_2^\nu := \{P : P^\top\mathbf{1}_B = \nu\}. \qquad (4)$$
The first step of Bregman projection finds the minimizer $P^{(1)} = \arg\min\{\varepsilon\,\mathrm{KL}(P\|K) : P\mathbf{1}_B = \mu\}$ via the proximal operator $\mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_1^\mu}(K)$ with a Lagrange multiplier $f$ on the row constraint set $\mathcal{C}_1^\mu$. Setting the derivative with respect to $P$ to zero and writing $u = e^{f/\varepsilon} > 0$ gives
$$\varepsilon\log(P^{(1)}/K) - f\mathbf{1} = 0 \;\Rightarrow\; P^{(1)} = \mathrm{diag}(u)\,K, \quad P^{(1)}\mathbf{1} = \mu \;\Rightarrow\; \mathrm{diag}(u)\,K\mathbf{1} = \mu, \quad u = \frac{\mu}{K\mathbf{1}}. \qquad (5)$$
Next, we project $P^{(1)}$ onto the column constraint set $\mathcal{C}_2^\nu$, resulting in $P^{(2)} := \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_2^\nu}(P^{(1)}) = P^{(1)}\,\mathrm{diag}\!\big(\frac{\nu}{P^{(1)\top}\mathbf{1}_B}\big)$. The iterative updates can be succinctly expressed as the Sinkhorn iterations:
$$u^{(t+1)} = \frac{\mu}{K v^{(t)}}, \qquad v^{(t+1)} = \frac{\nu}{K^\top u^{(t+1)}}, \qquad P^{(t)} = \mathrm{diag}(u^{(t)})\,K\,\mathrm{diag}(v^{(t)}). \qquad (6)$$
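A minimal numpy sketch of these alternating Bregman projections (the Sinkhorn iterations) is given below; the variable names and the fixed iteration count are our own choices rather than the paper's implementation.

```python
import numpy as np

def sinkhorn(C, mu, nu, eps=0.1, n_iters=100):
    """Entropic OT (Equation 3) via alternating Bregman projections.

    C: (B, B) cost matrix, e.g. cosine distances between augmented views.
    mu, nu: marginal weight vectors.
    Returns P = diag(u) K diag(v) with K = exp(-C / eps).
    """
    K = np.exp(-C / eps)
    u, v = np.ones_like(mu), np.ones_like(nu)
    for _ in range(n_iters):
        u = mu / (K @ v)        # projection onto the row constraint set C1^mu
        v = nu / (K.T @ u)      # projection onto the column constraint set C2^nu
    return u[:, None] * K * v[None, :]

# Example: a random 4x4 cost with uniform marginals.
rng = np.random.default_rng(0)
C = rng.random((4, 4))
mu = nu = np.full(4, 0.25)
P = sinkhorn(C, mu, nu, eps=0.5)
assert np.allclose(P.sum(axis=1), mu) and np.allclose(P.sum(axis=0), nu)
```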
The Wasserstein Dependency Measure (WDM) is a measure of deviation between two probability measures. We will use this later and thus provide the formal definition here [39].
Definition 2 (Wasserstein Dependency Measure). The WDM is the Wasserstein distance ($W_1$) between the joint distribution $\pi(x, y)$ and the product of marginal distributions $\mu \otimes \nu(x, y)$ of two random variables $x$ and $y$:
$$W_1(\pi, \mu \otimes \nu) = \sup_{f \in \mathcal{C}(\mathcal{X}\times\mathcal{Y})} \mathbb{E}_{\pi(x,y)}[f(x, y)] - \mathbb{E}_{\mu\otimes\nu(x,y)}[f(x, y)],$$
where $\mathcal{C}(\mathcal{X}\times\mathcal{Y})$ denotes the set of all 1-Lipschitz functions from $\mathcal{X}\times\mathcal{Y}$ to $\mathbb{R}$.
Distribution alignment and OT have been widely used for domain adaptation [33, 14, 30, 59],
and in generative modeling [2, 55, 49, 58]. The connections between distribution alignment and
contrastive learning, however, are still nascent. In [51], the authors explore the connection between
inverse OT (IOT) [32, 53, 18] and INCE. Our work builds on this connection to OT to build robust
divergences (RINCE) and to build a novel unbalanced optimal transport (UOT) method (Section 3.3).
Additionally, we show how our framework can be used to build flexible methods for encouraging
contrast at multiple levels. We use this concept of hierarchical contrast and show that it can be used
in domain generalization settings (Section 6.2). Note that GCA-UOT focuses on relaxing the hard constraints on the rows and columns into soft penalties, which differs from the idea of "unbalanced matching" in [51], which considers the case where the encoders may not share the same weights.
Defining the Kernel Space. Before formally stating our objective, we first need to define the
concept of an augmentation kernel for our positive and negative examples.
Definition 3 (Augmentation Kernel). Let $f_\theta$ denote an encoder with parameters $\theta$ and let $(x'_i, x''_j) \sim \mathcal{A}$ be two views drawn from the family of augmentations $\mathcal{A}$. The augmentation kernel for the encoder $f_\theta$ is defined as $K_\theta(x'_i, x''_j) = \exp(-\mathrm{dist}(\tilde{f}_\theta(x'_i), \tilde{f}_\theta(x''_j))/\varepsilon)$, where $\mathrm{dist}(\cdot,\cdot)$ is an arbitrary distance metric, $\tilde{f}_\theta(x'_i)$ is the normalized output of $f_\theta$, and $\varepsilon$ is the regularization parameter.
Main Objective. With this definition in hand, we can now formalize our objective as follows:
$$\min_\theta\ d_M(P_{\mathrm{tgt}} \,\|\, P_\theta), \quad \text{with} \quad P_\theta = \arg\min_{P \in \mathcal{B}}\{h(P) + d_\Gamma(P \,\|\, K_\theta)\}, \qquad (8)$$
where $K_\theta$ is the augmentation kernel defined in Definition 3, $h$ is a convex function (typically an indicator function), $\mathcal{B}$ is a closed convex constraint set (e.g., the Birkhoff polytope) that defines the constraints of the proximal operator, $d_\Gamma$ is a Bregman divergence used to find the nearest point $P_\theta$ to $K_\theta$ on the constraint set $\mathcal{B}$, and $d_M$ is a convex function (e.g., the KL divergence) that measures the divergence between $P_\theta$ and the target coupling plan $P_{\mathrm{tgt}}$.
Our objective is a bi-level optimization problem that aims to learn a representation minimizing the divergence between the transport plan $P_\theta$ and the target alignment plan $P_{\mathrm{tgt}}$ that encodes the matching constraints. In a standard contrastive learning setup, where pairs of positive examples form the source and target distributions, the target $P_{\mathrm{tgt}}$ is the identity matrix $I$. However, we will show later that other alignment constraints can be considered. Moreover, when $\mathcal{B}$ is the intersection of multiple constraint sets, such as $\mathcal{C}_1^\mu \cap \mathcal{C}_2^\nu$ in Equation (4), a natural way to approximate the nearest point $P_\theta$ to $K_\theta$ is to run an iterative projection algorithm [3], which can be extended to the intersection of several constraint sets $\{\cap_{i=1}^n \mathcal{C}_i\}$, resulting in a multi-marginal problem [41].
3.2 A Proximal Point Algorithm for GCA
In practice, we can solve the alignment problem above by iteratively updating the two main components in our bi-level objective. First, for fixed encoder parameters $\theta$, we obtain the transport coupling $P_\theta$ through the corresponding proximal operator. Second, we measure the deviation between the transport plan $P_\theta$ and the target $P_{\mathrm{tgt}}$ that encodes our matching constraints, i.e., the ideal alignment plan on the intersection of the constraint sets. We provide pseudocode for this iterative approach in Algorithm 1, which we refer to as generalized contrastive alignment (GCA). An implementation of our methods is available at https://github.com/nerdslab/gca.
Algorithm 1 Generalized Contrastive Alignment (GCA)
1: Generate views: For each minibatch, generate two augmented views and embed them with the encoder $f_\theta$ and projector $g_\theta$.
2: Compute the coupling: Form the augmentation kernel $K_\theta$ and obtain the transport plan $P_\theta$ via the corresponding proximal operator (e.g., iterative Bregman projections).
3: Calculate the loss: Calculate the deviation between the target and current transport plans, $\mathcal{L}_{\mathrm{GCA}} = d_M(P_\theta, P_{\mathrm{tgt}})$, and update the networks $f_\theta$ (encoder) and $g_\theta$ (projector) to minimize $\mathcal{L}_{\mathrm{GCA}}$.
4: Repeat until convergence: Repeat steps 2 and 3 until convergence.
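To make the procedure concrete, the following hedged PyTorch sketch implements one GCA-INCE-style training step in the balanced case ($P_{\mathrm{tgt}} = I$, $d_M = \mathrm{KL}$, a fixed number of differentiable Sinkhorn iterations). The hyperparameters, helper names, and uniform-marginal normalization are our own simplifications of Algorithm A1 in the appendix.

```python
import torch
import torch.nn.functional as F

def gca_step(encoder, projector, x1, x2, eps=0.1, n_iters=5):
    """One GCA step: build the augmentation kernel K_theta, project it onto the
    marginal constraint sets with differentiable Sinkhorn iterations, and
    penalize its KL divergence from the target plan (identity matching)."""
    z1 = F.normalize(projector(encoder(x1)), dim=1)
    z2 = F.normalize(projector(encoder(x2)), dim=1)
    B = z1.size(0)
    C = 1.0 - z1 @ z2.t()                          # cosine cost
    K = torch.exp(-C / eps)                        # augmentation kernel K_theta
    mu = nu = torch.full((B,), 1.0 / B, device=K.device)
    u, v = torch.ones_like(mu), torch.ones_like(nu)
    for _ in range(n_iters):                       # Bregman projections (forward pass)
        u = mu / (K @ v)
        v = nu / (K.t() @ u)
    P = u[:, None] * K * v[None, :]                # transport plan P_theta
    P_tgt = torch.eye(B, device=K.device) / B      # identity matching, rescaled to a coupling
    # KL(P_tgt || P_theta); gradients flow through the unrolled iterations
    loss = (P_tgt * (torch.log(P_tgt + 1e-12) - torch.log(P + 1e-12))).sum()
    return loss
```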
3.3 GCA-UOT Method
We can also benefit from the rich literature on optimal transport to build different relaxations of our
objective [43, 7, 54, 33, 35]. In particular, we choose to leverage a formulation of unbalanced optimal
transport (UOT) to further relax the marginal constraints [11] in our objective.
In this case, we can add the dual form of $d_\Gamma$ to Equation (8) and reformulate our objective as:
$$\min_\theta\ d_M(P_{\mathrm{tgt}} \,\|\, P_\theta) + \lambda_1 h_F(P_\theta\mathbf{1} \,\|\, \mu) + \lambda_2 h_G(P_\theta^\top\mathbf{1} \,\|\, \nu) + \varepsilon H(P_\theta). \qquad (9)$$
Here hF and hG can be different divergence measures (e.g., KL divergence) that penalize deviations
from the desired marginals µ and ν, and λ1 and λ2 are regularization parameters that control the trade-
off between the transport cost and the divergence penalties. This relaxation leads to different types
of proximal operators, which we outline in Appendix B.2. The impact of the entropy regularization parameter $\varepsilon$ on the coupling matrix is studied in Figure A5, and the sensitivity to the number of iterations is shown in Figure A6.
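As a hedged sketch of the relaxation in Equation (9), the scaling updates below replace the hard marginal constraints with KL penalties of weights λ1 and λ2, following the standard unbalanced-Sinkhorn scheme of [11]; the exact form used in GCA-UOT may differ in its details.

```python
import numpy as np

def unbalanced_sinkhorn(C, mu, nu, eps=0.1, lam1=1.0, lam2=1.0, n_iters=100):
    """Unbalanced entropic OT: marginal constraints become KL penalties,
    which damps the usual Sinkhorn scaling updates by lam / (lam + eps)."""
    K = np.exp(-C / eps)
    u, v = np.ones_like(mu), np.ones_like(nu)
    f1, f2 = lam1 / (lam1 + eps), lam2 / (lam2 + eps)
    for _ in range(n_iters):
        u = (mu / (K @ v)) ** f1        # softened row projection
        v = (nu / (K.T @ u)) ** f2      # softened column projection
    return u[:, None] * K * v[None, :]
```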
Contrastive learning objectives can be cast as a minimization of the deviations between the transport
plan Pθ and the identity matrix, i.e., Ptgt = I. However, our GCA formulation enables learning
representations that extend beyond this one-to-one matching constraint. This flexibility allows us to
incorporate additional matching constraints informed by domain-specific knowledge. For example, in
domain generalization scenarios [23, 28], where each batch contains samples from multiple domains,
the target alignment plan can be structured as:
$$P_{\mathrm{tgt}}[i, j] = I[i, j] + \alpha \cdot \mathbb{1}(D_i = D_j,\ i \neq j) + \beta \cdot \mathbb{1}(D_i \neq D_j,\ i \neq j),$$
where $\mathbb{1}(\cdot)$ is the indicator function, which equals 1 if the condition inside is true and 0 otherwise, $D_i$ denotes the domain of sample $i$, and $\alpha \geq 0$ and $\beta \geq 0$. In this case, we can improve the representation by building block constraints that encode either class information (in a supervised setting) or domain information (in cross-domain generalization, visualized in Figure 1).
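A hedged sketch of how such a block-structured target plan can be assembled from per-sample domain (or class) labels is shown below; the renormalization into a coupling with uniform marginal mass is our own convention.

```python
import torch

def block_target_plan(domains, alpha=0.1, beta=0.0):
    """Build P_tgt[i, j] = I[i, j] + alpha * 1(D_i == D_j, i != j)
                                   + beta  * 1(D_i != D_j, i != j),
    then rescale rows so the result is a valid coupling of total mass one."""
    B = domains.numel()
    same = (domains[:, None] == domains[None, :]).float()
    eye = torch.eye(B)
    P = eye + alpha * (same - eye) + beta * (1.0 - same)
    return P / P.sum(dim=1, keepdim=True) / B      # each row sums to 1/B

# Example: a batch with two domains.
print(block_target_plan(torch.tensor([0, 0, 1, 1]), alpha=0.5))
```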
The forward pass only involves the scaling operations in Equation (7) and does not affect the complexity of the backward pass. Therefore, GCA methods can be thought of as a form of batch normalization with adaptive scaling. An analysis of the complexity is provided along with experiments in Appendix B.1. Our results show that GCA iterations only slightly increase the computational complexity when compared with their single-step equivalents (GCA-INCE vs. INCE). However, we found that GCA-UOT is faster than INCE due to the improved symmetry and smoothness of the loss. Moreover, we record the floating-point operations (FLOPs) required to run the GCA methods. We find that GCA-INCE (6.65 MFLOPs) uses 5% more FLOPs than INCE (6.31 MFLOPs), while GCA-UOT saves 30% FLOPs (4.54 MFLOPs). These results show that our GCA-UOT method is superior not only in accuracy but also in speed.
Theorem 1 (INCE Equivalence). Let $K_\theta$ denote the augmentation kernel as in Definition 3 with cosine similarity, let $d_\Gamma$ and $d_M$ equal the KL divergence, and let the constraint set be $\mathcal{C}_1^\mu$ in Equation (4). The INCE objective in Equation (1) can be re-expressed as a GCA problem in Equation (8) as follows:
$$\min_\theta \mathrm{KL}\big(I \,\|\, \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_1^\mu}(K_\theta)\big). \qquad (10)$$
The proof is contained in Appendix B.3. Theorem 1 shows that the INCE loss can be viewed as solving the matching problem in Equation (3) with the row normalization constraint $\mathcal{C}_1^\mu$. This connection between GCA and INCE allows us to derive the iterative algorithm GCA-INCE by running Bregman projections iteratively on both the row and column normalization sets.
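This equivalence can also be checked numerically: one KL (Bregman) projection of $K_\theta$ onto the row constraints is a row-wise normalization, and since the cosine-cost kernel is the row-wise softmax of the scores up to a constant factor, $-\log$ of its diagonal recovers the per-sample InfoNCE terms. A small illustrative script (ours, with uniform marginals handled up to their constant scaling):

```python
import numpy as np

rng = np.random.default_rng(0)
B, eps = 4, 0.1
z1 = rng.normal(size=(B, 8)); z1 /= np.linalg.norm(z1, axis=1, keepdims=True)
z2 = rng.normal(size=(B, 8)); z2 /= np.linalg.norm(z2, axis=1, keepdims=True)

S = z1 @ z2.T / eps                        # scores s_ij
K = np.exp(-(1.0 - z1 @ z2.T) / eps)       # augmentation kernel with cosine cost

P1 = K / K.sum(axis=1, keepdims=True)      # one Bregman projection onto the row constraints
ince = -np.log(np.exp(S) / np.exp(S).sum(axis=1, keepdims=True)).diagonal()

# The constant exp(-1/eps) cancels under row normalization, so the projected
# kernel equals the row-wise softmax of the scores.
assert np.allclose(-np.log(np.diagonal(P1)), ince)
```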
4.2 Connection to RINCE
We introduce the following result to build the connection between our framework and RINCE [12].
Theorem 2 (RINCE Equivalence). Let $K_\theta$ denote the augmentation kernel as in Definition 3. Set the target plan $P_{\mathrm{tgt}} = I$, let $d_\Gamma$ equal the KL divergence, let
$$d_M(I \,\|\, P_\theta) = -\frac{1}{q}\Big(\frac{\mathrm{diag}(P_\theta)}{u}\Big)^q + \frac{1}{q}\Big(\frac{\lambda I}{u}\Big)^q$$
with parameters $\lambda$ and $q$ and $u = \mathrm{diag}(P^{(0)}\mathbf{1})$, and let the constraint set be $\mathcal{C}_1^\mu$ as defined in Equation (4). The RINCE objective in Equation (2) can be re-expressed as a GCA problem as follows:
$$\min_\theta d_M(I \,\|\, P_\theta), \quad \text{with} \quad P_\theta = \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_1^\mu}(K_\theta). \qquad (11)$$
The proof is provided in Appendix B.4.1. As we can see, RINCE introduces adjustable parameters q
and λ, with λ controlling the weight of negative samples, while q ∈ (0, 1] serves to switch between
KL divergence and Wasserstein discrepancy. When q = 1, we have the following theorem:
Theorem 3 (W1 Equivalence). Let $K_\theta$ denote the augmentation kernel as in Definition 3 with cosine similarity. Set the target plan $P_{\mathrm{tgt}} = I$, let $d_\Gamma$ equal the KL divergence, let $d_M$ equal the 1-Wasserstein distance ($W_1$) in Definition 2, and let the constraint set be $\mathcal{C}_1^\mu$ as defined in Equation (4). The RINCE objective in Equation (2) with $q = 1$ can be re-expressed as a GCA problem as follows:
$$\min_\theta W_1\big(P_{\mathrm{tgt}} \,\|\, \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_1^\mu}(K_\theta)\big). \qquad (12)$$
See Appendix B.5 for the proof. This connection to RINCE suggests an extended iterative formulation that computes the coupling plan as the projection point $P^{(\infty)} = \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_1^\mu \cap \mathcal{C}_2^\nu}(K_\theta)$ of $K_\theta$ onto the constraint set $\mathcal{C}_1^\mu \cap \mathcal{C}_2^\nu$. In this case, we can write an iterative algorithm for robust alignment, which we call GCA-RINCE:
$$\mathcal{L}^{\lambda,q}_{\mathrm{GCA\text{-}RINCE}} = \min_\theta\; -\frac{1}{q}\Big(\frac{\mathrm{diag}(P_\theta^{(2t-1)})}{u^{(t)}}\Big)^q + \frac{1}{q}\Big(\frac{\lambda P_{\mathrm{tgt}}}{u^{(t)}}\Big)^q, \qquad (13)$$
where $\lambda$ and $q$ are hyperparameters, $P^{(1)} := \mathrm{diag}(u^{(1)})\,K_\theta\,\mathrm{diag}(v^{(0)})$, and $t$ is the number of iterations.
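Below is a hedged PyTorch sketch of one possible reading of Equation (13) with $P_{\mathrm{tgt}} = I$: run a few Sinkhorn sweeps on the augmentation kernel, then apply the RINCE-style robust terms to the resulting coupling. The exact normalization by $u^{(t)}$ in the paper may differ from this simplification.

```python
import torch

def gca_rince_loss(K, mu, nu, lam=0.01, q=0.5, n_iters=3):
    """Robust GCA loss sketch: project K_theta toward C1^mu ∩ C2^nu with a few
    Sinkhorn sweeps, then combine a positive term built from diag(P)/u with a
    lambda-weighted term built from the target plan P_tgt = I."""
    u, v = torch.ones_like(mu), torch.ones_like(nu)
    for _ in range(n_iters):
        u = mu / (K @ v)
        v = nu / (K.t() @ u)
    P = u[:, None] * K * v[None, :]
    pos = torch.diagonal(P) / u          # diag(P_theta) / u^(t)
    tgt = lam / u                        # diagonal of lam * P_tgt / u^(t) with P_tgt = I
    loss = (-(pos ** q) + tgt ** q) / q
    return loss.mean()
```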
4.3 Connection to BYOL
Our framework also allows us to make connections to BYOL [22]. BYOL learns by encouraging
similarity between positive image pairs, without explicitly conditioning on negative examples. To
build this connection, recall that BYOL has an online network parameterized by $\theta$ and a target network parameterized by $\xi$, where $z'_\theta = \tilde{f}_\theta(x')$ and $z''_\xi = \tilde{f}_\xi(x'')$ are the normalized outputs of the online and target networks, respectively. A simplified version of the BYOL loss can be written as $\mathcal{L}_{\mathrm{BYOL}} = \|\tilde{q}_\theta(z'_\theta) - z''_\xi\|_2^2$, where $\tilde{q}_\theta(z'_\theta)$ is the normalized output of the predictor $q_\theta$ applied after the online network. In this case, we can make the following connection between GCA and BYOL.
Theorem 4 (BYOL Equivalence). Let $S_\theta(x'_i, x''_j) = \exp(-\|\tilde{q}_\theta(z'_i) - z''_j\|)$ denote the augmentation kernel. Set the target plan $P_{\mathrm{tgt}} = I$, let $d_\Gamma$ equal the L2 distance, let $d_M$ equal the KL divergence, and let the constraint set be $\mathbb{R}^{B\times B}$. The BYOL objective can be re-expressed as a GCA problem as follows:
$$\min_\theta \mathrm{KL}\big(I \,\|\, S_\theta\big), \quad \text{with} \quad S_\theta = \mathrm{Prox}^{\|\cdot\|}_{\mathbb{R}^{B\times B}}(S_\theta). \qquad (14)$$
5 Theoretical Analysis
In this section, we aim to show how the GCA-methods can improve alignment and uniformity in
the latent space [57]. Here, alignment means that the features of the positive samples are as close
as possible, while uniformity means that the features of negative samples are uniformly distributed
on latent space (see Appendix C.1 for formal definitions). These quantities have been studied in a
number of related works [57, 45], where one can show that improved alignment and uniformity can
lead to different benefits in representation learning.
The full proof is provided in Appendix C.1.1. The above theorem tells us that solving Equation (8) with iterative projections converges to a transport plan $P_\theta^{(\infty)}$ with lower KL divergence than the one-step solution provided by INCE. We can establish the convergence of $P^{(t)} \to P^{(\infty)}$ based on the convergence of Bregman projections.
Analysis of RINCE vs. GCA-RINCE. GCA also benefits from other Bregman divergences, like the WDM in RINCE, which provides robustness against distribution shift compared to the KL divergence in INCE. GCA-RINCE provides a lower bound on the RINCE loss in Equation (2), which allows us to develop a tighter bound with the plan $P^{(\infty)}$ obtained by several proximal steps of GCA.
Theorem 6 (Improved Alignment with RINCE). The GCA-RINCE loss with $P_\theta^{(t)}$ in Equation (13) is lower than the loss in Theorem 2, i.e., $\mathcal{L}^{\lambda,q=1}_{\mathrm{GCA\text{-}RINCE}}(P_\theta^{(t)}) \leq \mathcal{L}^{\lambda,q=1}_{\mathrm{RINCE}}(P_\theta^{(1)})$.
See Appendix C.1.1 for the full proof and an analysis of GCA methods for different choices of dM .
5.2 Improved Uniformity of Representations Through GCA
The improved alignment of GCA methods comes from maximizing uniformity under the intersection constraint $\mathcal{C}_1^\mu \cap \mathcal{C}_2^\nu$ in Equation (4), rather than the single constraint set $\mathcal{C}_1^\mu$ used in INCE (see Table 1). Finding the projection of $K_\theta$ onto $\mathcal{C}_1^\mu \cap \mathcal{C}_2^\nu$ through proximal steps is equivalent to solving the dual problem of EOT, which can be summarized through the following theorem.
Theorem 7 (Improved Uniformity). Given the constraint sets in Equation (4), the optimal transport
coupling upon convergence of Equation (6), denoted as P(∞) , achieves a higher uniformity loss
compared to the single-step transport plan P(1) obtained by INCE.
The proof is provided in the Appendix C.2. Through loss propagation, we show that the alignment
plan offered by P(∞) will guide the subsequent iterations towards more uniform representations.
5.3 Impacts of GCA on a downstream classification task
We take this one step further and examine the impact of GCA on a downstream classification task.
For a classification task, using a labeled dataset $\mathcal{D} = \{(\bar{x}_i, y_i)\} \subset \bar{\mathcal{X}} \times \mathcal{Y}$ where $\mathcal{Y} = [1, \ldots, M]$ with $M$ classes, we consider a fixed, pre-trained encoder $f_\theta \in \mathcal{F} : \mathcal{X} \to \mathcal{S}$. Assume that positive and negative views of $n$ original samples $(\bar{x}_i)_{i\in[1..n]} \subset \bar{\mathcal{X}}$ are sampled from the data distribution $p(\bar{x})$. In this case, optimizing the uniformity loss is equivalent to optimizing the downstream supervised classification task with the cross-entropy (CE) loss when the following two assumptions are satisfied [16].
Standard Setting Noisy Setting
Method CIFAR-10 CIFAR-100 SVHN ImageNet100 CIFAR-10 (Ex) CIFAR-100 (Ex) CIFAR-10C CIFAR-10C (Ex)
INCE 92.01 ± 0.40 71.09 ± 0.31 92.42 ± 0.24 73.01 ± 0.61 82.03 ± 0.32 62.54 ± 0.20 81.52 ± 1.04 83.28 ± 0.25
GCA-INCE 93.02 ± 0.19 71.55 ± 0.12 92.64 ± 0.26 73.04 ± 0.76 82.18 ± 0.69 62.65 ± 0.17 82.63 ± 0.28 82.72 ± 0.27
∆ +1.01 +0.46 +0.22 +0.03 +0.15 +0.11 +1.11 -0.56
RINCE 93.27 ± 0.20 71.63 ± 0.36 93.26 ± 0.15 71.91 ± 0.43 82.60 ± 0.63 63.55 ± 0.14 82.86 ± 0.21 83.64 ± 0.26
GCA-RINCE 93.47 ± 0.30 71.95 ± 0.48 93.57 ± 0.26 73.44 ± 0.55 82.76 ± 0.49 63.14 ± 0.41 82.87 ± 0.11 83.69 ± 0.16
∆ +0.20 +0.32 +0.31 +1.53 +0.16 -0.47 +0.01 +0.05
SimCLR 92.07 ± 0.90 70.85 ± 0.50 92.13 ± 0.34 72.20 ± 0.78 81.87 ± 0.53 62.94 ± 0.13 81.74 ± 1.54 83.25 ± 0.18
BYOL 90.56 ± 0.59 69.75 ± 0.37 89.50 ± 0.46 69.75 ± 0.83 81.55 ± 0.50 62.11 ± 0.25 82.43 ± 0.06 70.09 ± 0.34
IOT [51] 92.09 ± 0.22 68.37 ± 0.42 92.25 ± 0.19 72.27 ± 0.53 80.59 ± 0.64 62.69 ± 0.34 82.01 ± 0.80 81.57 ± 0.83
IOT-uni [51] 91.49 ± 0.11 68.62 ± 0.35 92.33 ± 0.11 72.88 ± 0.71 80.79 ± 0.24 62.56 ± 0.22 81.19 ± 1.12 81.82 ± 0.51
GCA-UOT 93.50 ± 0.31 72.16 ± 0.38 93.82 ± 0.17 74.09 ± 0.40 83.18 ± 0.44 63.62 ± 0.27 82.90 ± 0.50 83.64 ± 0.19
Table 2: Test accuracy (%) on a downstream classification task after pretraining. Results are shown for CIFAR-
10 (ResNet18), CIFAR-100 (ResNet18), SVHN (ResNet18), and ImageNet100 (ResNet50) under standard and
extreme (Ex) augmentation conditions (averaged over 5 seeds). The CIFAR-100 column is now updated with
the latest values from Table A3. The top model is in bold and the second-place model is underlined. For INCE
and RINCE, we also provide the improvement ∆ by adding GCA to each method.
Assumption 1 (Expressivity of the Encoder). Let $\mathcal{H}_{\bar{\mathcal{X}}}$ be the RKHS associated with a kernel $K_{\bar{\mathcal{X}}}$ defined on $\bar{\mathcal{X}}$, and let $(\mathcal{H}_{f_\theta}, K_\theta)$ be defined on $\mathcal{X}$ with the augmentation kernel $K_\theta = \langle f_\theta(\cdot), f_\theta(\cdot)\rangle_{\mathbb{R}^d}$ in Definition 3. We assume that $\forall g \in \mathcal{H}_{f_\theta}$, $\mathbb{E}_{\mathcal{A}(x|\cdot)}\, g(x) \in \mathcal{H}_{\bar{\mathcal{X}}}$.
Assumption 2 (Small Intra-Class Variance). For $y \neq y'$, the intra-class variances $\delta_i, \delta_j$ are negligible compared to the distance between class centroids $\mu_y, \mu_{y'}$, i.e., $\|\mu_y - \mu_{y'}\| \gg \|\delta_i - \delta_j\|$.
Claim 1. If Assumption 1 and Assumption 2 hold, then maximizing the uniformity is equivalent to
minimizing the downstream CE loss.
The proof is provided in Appendix C.2. Optimizing the self-supervised loss under ideal conditions
improves downstream CE tasks and helps to explain why maximizing uniformity aids classification.
Remark. Maximizing uniformity can enhance downstream classification but risks "feature suppression" by encouraging shortcut features that harm generalization [48]. In GCA-UOT, adding penalties modifies the transport plan relative to that of a pure uniformity loss, helping to avoid feature suppression. We find empirical evidence that UOT provides a more robust transport plan, which appears to prevent some of these shortcut features from being learned (Figure A4 in Appendix C.3).
6 Experiments
In this section, we conduct empirical evaluations to study the performance of our approach in both
handling noisy and corrupted views and in domain generalization tasks.
obtained by each method are provided in Figure A4, along with a study of the sensitivity of the methods to hyperparameters (Appendix A7).
Results on Corrupted Data and Extreme Augmentations. Next, we tested the methods in two
noisy settings. In the first set of experiments, we apply extreme augmentations to CIFAR-10 (Ex)
and CIFAR-100 (Ex) (see Appendix D.2) to introduce noisy views during training. In the second set
of experiments, we used the CIFAR-10C dataset to further test the ability of our method to work in noisy settings.
Our experimental results demonstrate that the GCA-based strategy effectively enhances the model’s
generalization ability and adaptability to aggressive data augmentations. In addition to improving
classification accuracy, the GCA-based methods also improve the representational alignment and
uniformity, as shown in Appendix E.2. This observation is in line with our theoretical analysis in
Section 5.2, where we show that the obtained representations provide better overall alignment of
positive views and better spread in terms of uniformity [57].
7 Conclusion
In this work, we introduced generalized contrastive alignment (GCA), a flexible framework that
redefines contrastive learning as a distributional alignment problem using optimal transport to
control alignment. By allowing targeted control over alignment objectives, GCA demonstrates
strong performance across both standard and challenging settings, such as noisy views and domain
generalization tasks. This work opens up broader possibilities for learning robust representations in
real-world scenarios, where data is often diverse, noisy, or comes from multiple domains.
Future work includes applications of GCA to graphs and time series data, as well as multi-modal
settings where our approach can integrate various forms of similarity. As alignment strategies become
integral to contrastive learning, GCA offers a promising foundation for more adaptive and expressive
self-supervised models.
Acknowledgements
We would like to thank Mehdi Azabou, Divyansha, Vinam Arora, Shivashriganesh Mahato, and Ian
Knight for their valuable feedback on the work. This work was funded through NSF IIS-2212182,
NSF IIS-2039741, and the support from the Canadian Institute for Advanced Research (CIFAR). We
would also like to acknowledge the use of ChatGPT for providing useful feedback and suggestions
on the writing of the paper.
References
[1] Dongsheng An, Na Lei, Xiaoyin Xu, and Xianfeng Gu. Efficient optimal transport algorithm by
accelerated gradient descent. In Proceedings of the AAAI Conference on Artificial Intelligence,
volume 36, pages 10119–10128, 2022.
[2] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial
networks. In International conference on machine learning, pages 214–223. PMLR, 2017.
[3] Jean-David Benamou, Guillaume Carlier, Marco Cuturi, Luca Nenna, and Gabriel Peyré. Itera-
tive bregman projections for regularized transportation problems. SIAM Journal on Scientific
Computing, 37(2):A1111–A1138, 2015.
[4] Robert J Berman. The sinkhorn algorithm, parabolic optimal transport and geometric monge–
ampère equations. Numerische Mathematik, 145(4):771–836, 2020.
[5] Lev M Bregman. The relaxation method of finding the common point of convex sets and
its application to the solution of problems in convex programming. USSR computational
mathematics and mathematical physics, 7(3):200–217, 1967.
[6] Xing-Ju Cai, Ke Guo, Fan Jiang, Kai Wang, Zhong-Ming Wu, and De-Ren Han. The devel-
opments of proximal point algorithms. Journal of the Operations Research Society of China,
10(2):197–239, 2022.
[7] Liqun Chen, Dong Wang, Zhe Gan, Jingjing Liu, Ricardo Henao, and Lawrence Carin. Wasser-
stein contrastive representation distillation. In Proceedings of the IEEE/CVF conference on
computer vision and pattern recognition, pages 16296–16305, 2021.
[8] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework
for contrastive learning of visual representations. In International conference on machine
learning, pages 1597–1607. PMLR, 2020.
[9] Xinlei Chen and Kaiming He. Exploring simple siamese representation learning. In Proceedings
of the IEEE/CVF conference on computer vision and pattern recognition, pages 15750–15758,
2021.
[10] Zining Chen, Weiqiu Wang, Zhicheng Zhao, Fei Su, Aidong Men, and Yuan Dong. Instance
paradigm contrastive learning for domain generalization. IEEE Transactions on Circuits and
Systems for Video Technology, 34(2):1032–1042, 2023.
[11] Lenaic Chizat, Gabriel Peyré, Bernhard Schmitzer, and François-Xavier Vialard. Scaling algo-
rithms for unbalanced optimal transport problems. Mathematics of Computation, 87(314):2563–
2609, 2018.
[12] Ching-Yao Chuang, R Devon Hjelm, Xin Wang, Vibhav Vineet, Neel Joshi, Antonio Torralba,
Stefanie Jegelka, and Yale Song. Robust contrastive learning against noisy views. In Proceedings
of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 16670–16681,
2022.
[13] Ching-Yao Chuang, Joshua Robinson, Yen-Chen Lin, Antonio Torralba, and Stefanie Jegelka.
Debiased contrastive learning. Advances in neural information processing systems, 33:8765–
8775, 2020.
[14] Nicolas Courty, Rémi Flamary, and Devis Tuia. Domain adaptation with regularized optimal
transport. In Machine Learning and Knowledge Discovery in Databases: European Conference,
ECML PKDD 2014, Nancy, France, September 15-19, 2014. Proceedings, Part I 14, pages
274–289. Springer, 2014.
[15] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-
scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern
Recognition, pages 248–255. Ieee, 2009.
[16] Benoit Dufumier, Carlo Alberto Barbano, Robin Louiset, Edouard Duchesnay, and Pietro Gori.
Integrating prior knowledge in contrastive learning with kernel. In International Conference on
Machine Learning, pages 8851–8878. PMLR, 2023.
[17] Marvin Eisenberger, Aysim Toker, Laura Leal-Taixé, Florian Bernard, and Daniel Cremers.
A unified framework for implicit sinkhorn differentiation. In Proceedings of the IEEE/CVF
Conference on Computer Vision and Pattern Recognition, pages 509–518, 2022.
[18] Zhenghan Fang, Sam Buchanan, and Jeremias Sulam. What’s in a prior? learned proximal
networks for inverse problems. arXiv preprint arXiv:2310.14344, 2023.
[19] Promit Ghosal and Marcel Nutz. On the convergence rate of sinkhorn’s algorithm. arXiv
preprint arXiv:2212.06000, 2022.
[20] Aritra Ghosh, Naresh Manwani, and PS Sastry. Making risk minimization tolerant to label
noise. Neurocomputing, 160:93–107, 2015.
[21] Will Grathwohl, Kuan-Chieh Wang, Jörn-Henrik Jacobsen, David Duvenaud, Mohammad
Norouzi, and Kevin Swersky. Your classifier is secretly an energy based model and you should
treat it like one. arXiv preprint arXiv:1912.03263, 2019.
[22] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre Richemond, Elena
Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Guo, Mohammad Gheshlaghi Azar,
et al. Bootstrap your own latent-a new approach to self-supervised learning. Advances in neural
information processing systems, 33:21271–21284, 2020.
[23] Ishaan Gulrajani and David Lopez-Paz. In search of lost domain generalization. arXiv preprint
arXiv:2007.01434, 2020.
[24] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. Momentum contrast for
unsupervised visual representation learning. In Proceedings of the IEEE/CVF conference on
computer vision and pattern recognition, pages 9729–9738, 2020.
[25] Dan Hendrycks and Thomas Dietterich. Benchmarking neural network robustness to common
corruptions and perturbations. arXiv preprint arXiv:1903.12261, 2019.
[26] Ashish Jaiswal, Ashwin Ramesh Babu, Mohammad Zaki Zadeh, Debapriya Banerjee, and Fillia
Makedon. A survey on contrastive self-supervised learning. Technologies, 9(1):2, 2020.
[27] Longlong Jing and Yingli Tian. Self-supervised visual feature learning with deep neural net-
works: A survey. IEEE transactions on pattern analysis and machine intelligence, 43(11):4037–
4058, 2020.
[28] Daehee Kim, Youngjun Yoo, Seunghyun Park, Jinkyu Kim, and Jaekoo Lee. Selfreg: Self-
supervised contrastive regularization for domain generalization. In Proceedings of the IEEE/CVF
International Conference on Computer Vision, pages 9619–9628, 2021.
[29] Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images.
Technical report, University of Toronto, 2009.
[30] John Lee, Max Dabagia, Eva Dyer, and Christopher Rozell. Hierarchical optimal transport for
multimodal distribution alignment. Advances in neural information processing systems, 32,
2019.
[31] Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy M Hospedales. Deeper, broader and artier
domain generalization. In Proceedings of the IEEE international conference on computer vision,
pages 5542–5550, 2017.
[32] Ruilin Li, Xiaojing Ye, Haomin Zhou, and Hongyuan Zha. Learning to match via inverse
optimal transport. Journal of machine learning research, 20(80):1–37, 2019.
[33] Chi-Heng Lin, Mehdi Azabou, and Eva L Dyer. Making transport more robust and interpretable
by moving data through a small number of anchor points. Proceedings of machine learning
research, 139:6631, 2021.
[34] Ran Liu, Mehdi Azabou, Max Dabagia, Chi-Heng Lin, Mohammad Gheshlaghi Azar, Keith
Hengen, Michal Valko, and Eva Dyer. Drop, swap, and generate: A self-supervised approach for
generating neural activity. Advances in neural information processing systems, 34:10587–10599,
2021.
[35] Eduardo Fernandes Montesuma, Fred Ngole Mboula, and Antoine Souloumiac. Recent advances
in optimal transport for machine learning. arXiv preprint arXiv:2306.16156, 2023.
[36] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y Ng.
Reading digits in natural images with unsupervised feature learning. In NIPS Workshop on
Deep Learning and Unsupervised Feature Learning, volume 2011, 2011.
[37] XuanLong Nguyen, Martin J Wainwright, and Michael I Jordan. Estimating divergence func-
tionals and the likelihood ratio by convex risk minimization. IEEE Transactions on Information
Theory, 56(11):5847–5861, 2010.
[38] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation learning with contrastive
predictive coding. arXiv preprint arXiv:1807.03748, 2018.
[39] Sherjil Ozair, Corey Lynch, Yoshua Bengio, Aaron Van den Oord, Sergey Levine, and Pierre
Sermanet. Wasserstein dependency measure for representation learning. Advances in Neural
Information Processing Systems, 32, 2019.
[40] Neal Parikh, Stephen Boyd, et al. Proximal algorithms. Foundations and trends® in Optimiza-
tion, 1(3):127–239, 2014.
[41] Brendan Pass. Multi-marginal optimal transport: theory and applications. ESAIM: Mathematical
Modelling and Numerical Analysis, 49(6):1771–1790, 2015.
[42] Gabriel Peyré. Entropic approximation of wasserstein gradient flows. SIAM Journal on Imaging
Sciences, 8(4):2323–2351, 2015.
[43] Gabriel Peyré, Marco Cuturi, et al. Computational optimal transport: With applications to data
science. Foundations and Trends® in Machine Learning, 11(5-6):355–607, 2019.
[44] Khiem Pham, Khang Le, Nhat Ho, Tung Pham, and Hung Bui. On unbalanced optimal transport:
An analysis of sinkhorn algorithm. In International Conference on Machine Learning, pages
7673–7682. PMLR, 2020.
[45] Shi Pu, Kaili Zhao, and Mao Zheng. Alignment-uniformity aware representation learning for
zero-shot video classification. In Proceedings of the IEEE/CVF Conference on Computer Vision
and Pattern Recognition, pages 19968–19977, 2022.
[46] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal,
Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual
models from natural language supervision. In International conference on machine learning,
pages 8748–8763. PMLR, 2021.
[47] Joshua Robinson, Ching-Yao Chuang, Suvrit Sra, and Stefanie Jegelka. Contrastive learning
with hard negative samples. arXiv preprint arXiv:2010.04592, 2020.
[48] Joshua Robinson, Li Sun, Ke Yu, Kayhan Batmanghelich, Stefanie Jegelka, and Suvrit Sra.
Can contrastive learning avoid shortcut solutions? Advances in neural information processing
systems, 34:4974–4986, 2021.
[49] Litu Rout, Alexander Korotin, and Evgeny Burnaev. Generative modeling with optimal transport
maps. arXiv preprint arXiv:2110.02999, 2021.
[50] Nikunj Saunshi, Jordan Ash, Surbhi Goel, Dipendra Misra, Cyril Zhang, Sanjeev Arora, Sham
Kakade, and Akshay Krishnamurthy. Understanding contrastive learning requires incorporating
inductive biases. In International Conference on Machine Learning, pages 19250–19286.
PMLR, 2022.
[51] Liangliang Shi, Gu Zhang, Haoyu Zhen, Jintao Fan, and Junchi Yan. Understanding and
generalizing contrastive learning from the inverse optimal transport perspective. In International
Conference on Machine Learning, pages 31408–31421. PMLR, 2023.
[52] Saeed Shurrab and Rehab Duwairi. Self-supervised learning methods and applications in
medical imaging analysis: A survey. PeerJ Computer Science, 8:e1045, 2022.
[53] Andrew M Stuart and Marie-Therese Wolfram. Inverse optimal transport. SIAM Journal on
Applied Mathematics, 80(1):599–619, 2020.
[54] Fariborz Taherkhani, Ali Dabouei, Sobhan Soleymani, Jeremy Dawson, and Nasser M Nasrabadi.
Self-supervised wasserstein pseudo-labeling for semi-supervised image classification. In Pro-
ceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages
12267–12277, 2021.
[55] Ilya Tolstikhin, Olivier Bousquet, Sylvain Gelly, and Bernhard Schoelkopf. Wasserstein auto-
encoders. arXiv preprint arXiv:1711.01558, 2017.
[56] Ankit Vishnubhotla, Charlotte Loh, Akash Srivastava, Liam Paninski, and Cole Hurwitz.
Towards robust and generalizable representations of extracellular data using contrastive learning.
Advances in Neural Information Processing Systems, 36, 2024.
[57] Tongzhou Wang and Phillip Isola. Understanding contrastive representation learning through
alignment and uniformity on the hypersphere. In International Conference on Machine Learning,
pages 9929–9939. PMLR, 2020.
[58] Yule Wang, Chengrui Li, Weihan Li, and Anqi Wu. Exploring behavior-relevant and disen-
tangled neural dynamics with generative diffusion models. arXiv preprint arXiv:2410.09614,
2024.
[59] Yule Wang, Zijing Wu, Chengrui Li, and Anqi Wu. Extraction and recovery of spatio-temporal
structure in latent dynamics alignment with diffusion model. Advances in Neural Information
Processing Systems, 36, 2024.
Appendix
Datasets and contrastive pairs: Let $x$ denote a vector and $X$ denote a matrix, with $X_b$ denoting a batch of input samples, $X_b := [x_1, x_2, \ldots, x_B]$, where $B$ is the batch size. For each sample $x_i$ in the input batch $X_b$, $x'_i$ denotes augmented view 1 of $x_i$ and $x''_i$ denotes augmented view 2 of $x_i$; the positive pairs in the input data are denoted $(x'_i, x''_i)$, and the negative pairs are denoted $(x'_i, x''_j)$, $i \neq j$. Given a representation function (artificial neural network) $f_\theta$ parameterized by weights $\theta$ with an adjustable temperature $\varepsilon$, the positive pairs in the latent space are denoted $s^+ = \langle \varepsilon^{-1}\tilde{f}_\theta(x'_i), \tilde{f}_\theta(x''_i)\rangle$ and the negative pairs are denoted $s^- = \langle \varepsilon^{-1}\tilde{f}_\theta(x'_i), \tilde{f}_\theta(x''_j)\rangle$, $i \neq j$. Here, $\langle\cdot,\cdot\rangle$ is the inner product, so $\langle \tilde{f}_\theta(x'_i), \tilde{f}_\theta(x''_j)\rangle = f_\theta(x'_i)^\top f_\theta(x''_j)/(\|f_\theta(x'_i)\|\,\|f_\theta(x''_j)\|)$ is the normalized form.
Continuous settings for optimal transport: $\mathcal{X}$ and $\mathcal{Y}$ are topological spaces and $\mathcal{X}\times\mathcal{Y}$ is their product space (e.g., a torus). $\mathcal{C}(\mathcal{X})$ is the space of all continuous functions on the compact space $\mathcal{X}$, endowed with the sup-norm. On the torus we define $\mathcal{M}$ as a compact $n$-dimensional manifold in the product space $\mathcal{X}\times\mathcal{Y}$ ($\mathcal{X} = \mathcal{Y} := \mathbb{R}^n/\mathbb{Z}^n$), endowed with the cost function $c(x, y) := d_{\mathcal{M}}(x, y)^2/2$ (squared Euclidean distance) on $\mathbb{R}^n$. A transport plan $\pi(x, y) : \mathcal{X}\times\mathcal{Y} \to \mathbb{R}$ is an element of $\mathcal{P}(\mathcal{X}\times\mathcal{Y})$, the collection of joint distributions with marginal distributions $\mu \in \mathcal{P}(\mathcal{X})$ and $\nu \in \mathcal{P}(\mathcal{Y})$, where $\mathcal{P}(\mathcal{X})$ is the space of all (Borel) probability measures on $\mathcal{X}$, and similarly for $\mathcal{P}(\mathcal{Y})$. To find a joint distribution (or plan) $\pi(x, y)$ in the set $U(\mu, \nu)$ of couplings with marginals $\mu$ and $\nu$ on the product space $\mathcal{X}\times\mathcal{Y}$, we can formulate:
$$\min_{\pi \in U(\mu,\nu)} \sum_{x\in\mathcal{X},\,y\in\mathcal{Y}} \pi(x, y)\,c(x, y) \quad \text{s.t.} \quad \sum_{y\in\mathcal{Y}} \pi(x, y) = \mu(x), \quad \sum_{x\in\mathcal{X}} \pi(x, y) = \nu(y).$$
In the discrete setting, the cost matrix is calculated by $c(x, y)$ on samples from the finite sets $\mathcal{X}$ and $\mathcal{Y}$ defined previously, and $B$ is the batch size. $u$ and $v$ are $B \times 1$ scaling vectors, and $u^{(t)}$ and $v^{(t)}$ denote the scaling vectors after $t$ iterations of the Sinkhorn algorithm. $P$ is a $B \times B$ joint distribution matrix of $\mu$ and $\nu$, representing the transport plan $\pi(x, y)$ that minimizes the cost. $P^{(2t-1)}$ means we use $u^{(t)}$ and $v^{(t-1)}$ to calculate the plan. When $t = 1$, we use $u^{(1)}$ and $v^{(0)}$ to calculate $P^{(1)}$, which is called half-step OT or one step of Bregman projection. When $t = 2$, we use $u^{(1)}$ and $v^{(1)}$ to calculate $P^{(2)}$, which corresponds to one full Sinkhorn iteration (two Bregman projections).
In this section, we provide a detailed illustration of proximal operators, how the proximal operator reduces to a projection, and how to solve the Bregman projection with the KL divergence.
The solution to the proximal operator exists and is unique due to the strong convexity of the above
function.
Proof of Lemma 1: This lemma follows from the strict convexity of the objective defining the proximal operator.
Figure A1: Illustration of proximal operators. A. Visualization of proximal operators in $\mathbb{R}^3$ on the surface defined by $h(x, y) = x^2 + y^2$ within the domain constraints $-1.2 < x < 1.2$ and $-1.2 < y < 1.2$. If $v = v_1 = (0.76, 0.76, 1.16)$, it lies within the domain of $h$, represented on the surface at the location matching its third coordinate with $h(x, y)$. If $v = v_2 = (1.5, 1.5, 6)$, which is outside the feasible region defined by $h$, the proximal operator maps it to the closest point within the domain, resulting in a projection of $v_2$ to approximately $(0.85, 0.85, 1.45)$. B. Visualization of proximal operators in $\mathbb{R}^2$. The blue dashed line represents the function $h(x) = x^2$. The orange dash-dotted line illustrates the penalty term $\frac{1}{2}\|x - v\|^2$ with $v = (2, 0)$, indicating the squared distance from any $x$ to $v$. The green solid line is the proximal objective $h(x) + \frac{1}{2}\|x - v\|^2$, whose minimizer moves toward the minimum of $h(x)$ from $v$. The red point marks $\mathrm{Prox}_h(v)$ in this space.
$\Gamma$ is a strictly convex function, smooth on $\mathrm{int}(\mathcal{B})$, and $\mathrm{Prox}^{d_\Gamma}_{d_\phi}(K) \in \mathrm{int}(\mathcal{B})$ is always uniquely defined by strict convexity. (Note that this theory is general and does not require parameterizing $K$ and $P$ as models with $\theta$.) As $\mathcal{B} = \mathrm{dom}(\Gamma)$,
$$\forall (P, K) \in \mathcal{B}\times\mathrm{int}(\mathcal{B}), \quad d_\Gamma(P\|K) = \Gamma(P) - \Gamma(K) - \langle \nabla\Gamma(K), P - K\rangle,$$
and its Legendre transform is also smooth and strictly convex:
$$\Gamma^*(\rho) = \max_{P \in \mathcal{B}} \langle P, \rho\rangle - \Gamma(P).$$
The Bregman divergence for a convex function $\Gamma$ between points $x$ and $y$ is defined as
$$d_\Gamma(x, y) = \Gamma(x) - \Gamma(y) - \langle \nabla\Gamma(y), x - y\rangle,$$
where $\nabla\Gamma(y)$ is the gradient of $\Gamma$ at $y$. For example, the squared L2 distance can be viewed as the Bregman divergence derived from the convex function $\Gamma(x) = \|x\|^2$. For this function, the Bregman divergence between two points $x$ and $y$ becomes
$$d_\Gamma(x, y) = \|x\|^2 - \|y\|^2 - 2y^\top(x - y) = \|x - y\|^2.$$
If we repeat this process for the set $\mathcal{C}_2^\nu := \{P : P^\top\mathbf{1}_n = \nu\}$, we obtain $\mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_2^\nu}(P^{(t+1)})$. In the second step, we project onto the second constraint set $\mathcal{C}_2^\nu$ with the indicator function $h_G(x)$ defined on $\mathcal{C}_2^\nu$ and get
$$P^{(2)} := \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_2^\nu}(P^{(1)}) = P^{(1)}\,\mathrm{diag}\!\Big(\frac{\nu}{P^{(1)\top}\mathbf{1}_B}\Big). \qquad (22)$$
The Sinkhorn algorithm is composed of two Bregman projection steps. We can write out this recursive relationship as follows: $P^{(t+1)}$ can be updated with dual variables $f^{(t)}$, $g^{(t)}$ and $u^{(t)} = e^{f^{(t)}/\varepsilon}$, $v^{(t)} = e^{g^{(t)}/\varepsilon}$. The set $U(\mu, \nu) = \mathcal{C}_1^\mu \cap \mathcal{C}_2^\nu$ represents the feasible transport plans with the given marginals. The constraint set can also be chosen differently; for example, $\mathcal{B} = \mathcal{C}_1^{\mathbf{1}} \cap \mathcal{C}_2^{\mathbf{1}}$ denotes the Birkhoff polytope of doubly stochastic matrices, where $\mu = \mathbf{1}$ and $\nu = \mathbf{1}$ are the uniform all-ones vectors.
A.3 Background on OT
This section defines discrete and continuous optimal transport. Since Section 2.3 lacks a discrete OT definition, we discuss it here and show the equivalence between solving the Bregman projection and solving the entropy-regularized OT (EOT) problem.
To support convergence proofs later, we introduce definitions of continuous measures. Symbols µ
and ν may represent both discrete and continuous measures for intuitive consistency, with precise
definitions at the start of each subsection.
Even though directly solving Equation (24) has high computational complexity, $O(n^3)$, we can introduce a common relaxation, entropic regularization, to smooth the transport plan:
$$\min_{P \in U(\mu,\nu)} \langle P, C\rangle - \varepsilon H(P), \qquad (25)$$
where $\varepsilon$ is a user-specified parameter that controls the amount of smoothing in the transport plan. The cost matrix $C$ can be transformed into the Gibbs kernel matrix $K$ on a Hilbert space via
$$K_{ij} = \exp(-\varepsilon^{-1} C_{ij}). \qquad (26)$$
To solve (25) in the kernel space induced by $K$, we can use the iterative Sinkhorn algorithm, initializing $u^{(0)}$ and $v^{(0)}$ as all-ones vectors divided by the batch size, with the update rules
$$u^{(t+1)} := \frac{\mu}{K v^{(t)}} \quad \text{and} \quad v^{(t+1)} := \frac{\nu}{K^\top u^{(t+1)}}. \qquad (27)$$
Lemma 2. Solving the entropic optimal transport problem in Equation (3) is equivalent to iteratively solving the Bregman projection.
Proof of Lemma 2: Given points $K$ and $P$, their distance can be measured by the KL divergence:
$$\mathrm{KL}(P\|K) = \sum_{ij} P_{ij}\log\frac{P_{ij}}{K_{ij}} - P_{ij} + K_{ij}.$$
Since $C_{ij} = -\varepsilon\log K_{ij}$ in Sinkhorn, finding $P$ to minimize Equation (3) can be rewritten in terms of $K_{ij}$:
$$\min_P \langle P, C\rangle - \varepsilon H(P) = \min_P \sum_{i,j} C_{ij}P_{ij} + \varepsilon\sum_{i,j} P_{ij}\log(P_{ij}) \qquad (30)$$
$$= \min_P\ \varepsilon\sum_{i,j}\big(-P_{ij}\log K_{ij} + P_{ij}\log(P_{ij})\big). \qquad (31)$$
Considering $K$ as a point in the Hilbert kernel space and $\varepsilon$ as a constant, and letting $\mu$ and $\nu$ form the constraint set $\mathcal{B}$, we have $\min_{P\in\mathcal{B}} \langle P, C\rangle - \varepsilon H(P) = \min_{P\in\mathcal{B}} \varepsilon\,\mathrm{KL}(P\|K)$ up to constants independent of $P$, which is exactly the KL (Bregman) projection of $K$ onto $\mathcal{B}$.
where the constraint set $U(c)$ is defined by $U(c) := \{(f, g) \in \mathcal{C}(\mathcal{X})\times\mathcal{C}(\mathcal{Y}) \,|\, f(x) + g(y) \leq c(x, y)\}$. Here, $\mathcal{C}(\mathcal{X})$ is the space of all continuous functions on $\mathcal{X}$, measured using the supremum norm $\|f\|_\infty$, with the following Legendre transform.
Definition 6 (Legendre c-transforms). For the dual variables, also called potentials, there exist the Legendre c-transforms
$$f^c(y) := \sup_{x\in\mathcal{X}}\big(-c(x, y) + f(x)\big), \qquad g^c(x) := \sup_{y\in\mathcal{Y}}\big(-c(x, y) + g(y)\big), \qquad (36)$$
in which $g^c(x)$ and $f^c(y)$ are the Legendre c-transforms of $g(y) \in \mathcal{C}(\mathcal{Y})$ and $f(x) \in \mathcal{C}(\mathcal{X})$ with cost function $c(x, y)$.
Definition 7 (Pushforward measure). The pushforward measure of $\mu$ under the map $T$, denoted $T_\mu$, is a measure defined by $T_\mu(B) = \mu(T^{-1}(B))$ for any Borel set $B$. $T_\mu = \nu$ when $T$ is an optimal transport map. Following a similar construction, we can define the push-forward operators $T_\mu$ and $T_\nu$ as
$$T_\mu : \mathcal{C}(\mathcal{X}) \to \mathcal{C}(\mathcal{Y}), \; f \mapsto \log\!\int e^{-c(x,\cdot)+f(x)}\,d\mu(x), \qquad T_\nu : \mathcal{C}(\mathcal{Y}) \to \mathcal{C}(\mathcal{X}), \; g \mapsto \log\!\int e^{-c(\cdot,y)+g(y)}\,d\nu(y). \qquad (37)$$
Definition 8 ($\varphi$-divergence regularized OT, continuous). Given two dual variables (also called potentials) $f \in \mathbb{R}^n$ and $g \in \mathbb{R}^m$ for each marginal constraint, the entropy-regularized optimal transport problem in Equation (3) can be transformed into a problem with the Kantorovich functional:
$$W^{\varphi}_{\varepsilon,c}(\mu, \nu) = \inf_{\pi\in\Pi(\mu,\nu)} \int_{\mathcal{X}\times\mathcal{Y}} c(x, y)\,d\pi(x, y) + \varepsilon\int_{\mathcal{X}\times\mathcal{Y}} \varphi\!\Big(\frac{d\pi(x, y)}{d\mu(x)\,d\nu(y)}\Big)\,d\mu(x)\,d\nu(y). \qquad (38)$$
Proposition 1 (Dual of EOT). Consider OT between two probability measures $\mu$ and $\nu$ with a convex regularizer $\varphi$ on $\mathbb{R}_+$ in Equation (38):
$$W^{\varphi}_{c,\varepsilon}(\mu, \nu) = \sup_{f,g\in\mathcal{C}(\mathcal{X})\times\mathcal{C}(\mathcal{Y})} \int_{\mathcal{X}} f\,d\mu(x) + \int_{\mathcal{Y}} g\,d\nu(y) - \varepsilon\int_{\mathcal{X}\times\mathcal{Y}} \varphi^*\!\Big(\frac{f(x) + g(y) - c(x, y)}{\varepsilon}\Big)\,d\mu(x)\,d\nu(y), \qquad (39)$$
where $\varphi^*$ is the Legendre transform of $\varphi$, defined by $\varphi^*(v) := \sup_x\, xv - \varphi(x)$.
A good choice for $\varphi^*$ is $\varphi^*(v) = e^v$. The entropy regularization term ensures that the problem is solvable, especially for computational schemes. If $\varepsilon \to \infty$, the optimal primal plan $\pi^*$ can be retrieved using the following expression, which corresponds to the mutual information formula:
$$\frac{d\pi^*}{d\mu\,d\nu}(x, y) = \exp\Big(\frac{f^*(x) + g^*(y) - c(x, y)}{\varepsilon}\Big).$$
In the discrete setting of Equation (3), the optimal transport plan $P$ can often be expressed in terms of the optimal transport map $T^*$ when it exists; one can define the so-called barycentric projection map
$$T^* : x_i \in \mathcal{X} \mapsto \frac{1}{\mu_i}\sum_j P_{i,j}\,y_j \in \mathbb{R}^d.$$
This link provides the connection between the mutual information and the optimal mapping:
$$T^* : x \in \mathcal{X} \mapsto \int_{\mathcal{Y}} y\,\frac{d\pi(x, y)}{d\mu(x)\,d\nu(y)}\,d\nu(y).$$
Note that the joint distribution $\pi$ always has a density $\frac{d\pi(x,y)}{d\mu(x)\,d\nu(y)}$ with respect to $\mu \otimes \nu$, and the mutual information method will lead us to the optimal solution.
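A tiny numpy sketch of the discrete barycentric projection map above (ours, for illustration only):

```python
import numpy as np

def barycentric_projection(P, Y, mu):
    """T*(x_i) = (1 / mu_i) * sum_j P[i, j] * Y[j]: the barycentric image of
    each source point under the transport plan P.

    P: (n, m) transport plan, Y: (m, d) target support points, mu: (n,) source weights.
    """
    return (P @ Y) / mu[:, None]
```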
B Analysis of GCA
B.1 Convergence of GCA
In this section, we provide a proof of convergence of the forward pass of our GCA algorithm. To do this, we show that the general form in Algorithm 1 converges, for any Bregman divergence $d_\Gamma$ used in the forward pass, through Dykstra's projection algorithm. Finally, we show the uniform convergence of the transport plan $P$ and the convergence of its dual variables $f^{(t)}$ across iterations.
Proof for Corollary 1: First, let us define the Hilbert projective metric: for all $(u, u') \in (\mathbb{R}^n_{+,*})^2$,
$$d_{\mathcal{H}}(u, u') := \log\max_{i,j}\frac{u_i\,u'_j}{u_j\,u'_i}.$$
For any pair of vectors $(v, v') \in (\mathbb{R}^m_{+,*})^2$, it holds that
$$d_{\mathcal{H}}(v, v') = d_{\mathcal{H}}\Big(\frac{v}{v'}, \mathbf{1}_m\Big) = d_{\mathcal{H}}(\mathbf{1}_m/v, \mathbf{1}_m/v'). \qquad (41)$$
Algorithm A1 Generalized Contrastive Alignment (GCA)
Input: Encoder $f_\theta$, projector $g_\theta$, data $\{x_k\}_{k=1}^N$, batch size $B$, cost function $c(x, y)$, entropy parameter $\varepsilon$, constant $\tau$, total iterations $T$, marginal constraints $\mu$ and $\nu$, relaxation parameters $d_1$, $d_2$, constant $\delta_{\mathrm{eps}}$, and divergences $d_M$ and $d_\Gamma$ (e.g., KL or WDM).
for each sampled minibatch $\{x_k\}_{k=1}^B$ do
  Generate two views $(z'_k, z''_k)$ using $f_\theta$, $g_\theta$ with randomly sampled augmentations.
end for
$u^{(0)} = \mathbf{1}$, $v^{(0)} = \mathbf{1}$, $f = 0$, $g = 0$, $C_{ij} = c(z'_i, z''_j)$
$d_1 \leftarrow d_1/(d_1 + \varepsilon)$, $d_2 \leftarrow d_2/(d_2 + \varepsilon)$, $K = \exp(-C_{ij}/\varepsilon)$
for $i = 1$ to $T$ do
  $\delta_f \leftarrow \exp(-f/(\varepsilon + d_1))$, $\delta_g \leftarrow \exp(-g/(\varepsilon + d_2))$
  $u \leftarrow \delta_f \cdot \mathrm{Prox}_F(K v + \delta_{\mathrm{eps}})$
  $v \leftarrow \delta_g \cdot \mathrm{Prox}_G(K^\top u + \delta_{\mathrm{eps}})$
  if $u > \tau$ or $v > \tau$ then
    $f \leftarrow f + \varepsilon\cdot\log(\max(u))$, $g \leftarrow g + \varepsilon\cdot\log(\max(v))$
    $K \leftarrow \exp((f + g - C)/\varepsilon)$, $v = \mathbf{1}$
  end if
end for
$\log u \leftarrow f/(\varepsilon + u)$, $\log v \leftarrow g/(\varepsilon + v)$
Compute the transport plan: $P \leftarrow \exp(\log u + \log v - C/\varepsilon)$
Normalize $P_u^{(T)}$ by its column sums.
Loss: $\mathcal{L}_{\mathrm{GCA}} = d_M(\mu \otimes \nu, P_u^{(T)})$
Update networks $f_\theta$ and $g_\theta$ to minimize $\mathcal{L}_{\mathrm{GCA}}$
Let $K \in \mathbb{R}^{n\times m}_{+,*}$; then for $(v, v') \in (\mathbb{R}^m_{+,*})^2$ we have
$$d_{\mathcal{H}}(u^{(t+1)}, u^*) = d_{\mathcal{H}}\Big(\frac{\mathbf{1}_n}{K v^{(t)}}, \frac{\mathbf{1}_n}{K v^*}\Big) = d_{\mathcal{H}}(K v^{(t)}, K v^*) \leq \lambda(K)\,d_{\mathcal{H}}(v^{(t)}, v^*),$$
where
$$\lambda(K) := \frac{\sqrt{\eta(K)} - 1}{\sqrt{\eta(K)} + 1} < 1, \qquad \eta(K) := \max_{i,j,k,\ell} \frac{K_{i,k} K_{j,\ell}}{K_{j,k} K_{i,\ell}}. \qquad (42)$$
Based on contraction mapping theory, one has $(u^{(\ell)}, v^{(\ell)}) \to (u^*, v^*)$ and
$$d_{\mathcal{H}}(u^{(t)}, u^*) \leq d_{\mathcal{H}}(u^{(t+1)}, u^{(t)}) + d_{\mathcal{H}}(u^{(t+1)}, u^*) \qquad (43)$$
$$\leq d_{\mathcal{H}}\Big(\frac{\mu}{K v^{(t)}}, u^{(t)}\Big) + \lambda(K)^2 d_{\mathcal{H}}(u^{(t)}, u^*) \qquad (44)$$
$$= d_{\mathcal{H}}\big(\mu,\ u^{(t)} \odot (K v^{(t)})\big) + \lambda(K)^2 d_{\mathcal{H}}(u^{(t)}, u^*), \qquad (45)$$
$$d_{\mathcal{H}}(u^{(t)}, u^*) \leq \frac{d_{\mathcal{H}}(P^{(t)}\mathbf{1}_m, \mu)}{1 - \lambda(K)^2}, \qquad d_{\mathcal{H}}(v^{(t)}, v^*) \leq \frac{d_{\mathcal{H}}(P^{(t)\top}\mathbf{1}_n, \nu)}{1 - \lambda(K)^2}, \qquad (46)$$
where we denote $P^{(t)} := \mathrm{diag}(u^{(t)})\,K\,\mathrm{diag}(v^{(t)})$. Last, one has
$$\|\log(P^{(t)}) - \log(P^*)\|_\infty \leq d_{\mathcal{H}}(u^{(t)}, u^*) + d_{\mathcal{H}}(v^{(t)}, v^*), \qquad (47)$$
where $P^*$ is the unique solution of Equation (3). The above formula also shows that the $t$-step solution gives a better bound than the 1-step solution.
We present a general convergence proof for Dykstra's projection algorithm, sharing the form in Definition 1. First, we define $d_\Gamma$ as a generic Bregman divergence on some convex set $\mathcal{B}$, and consider the proximal map of a convex function $d_\phi$ according to this divergence.
$\Gamma$ is a strictly convex function, smooth on $\mathrm{int}(\mathcal{B})$, and $\mathrm{Prox}^{d_\Gamma}_{d_\phi}(K) \in \mathrm{int}(\mathcal{B})$ is always uniquely defined by strict convexity. As $\mathcal{B} = \mathrm{dom}(\Gamma)$,
$$\forall (P, K) \in \mathcal{B}\times\mathrm{int}(\mathcal{B}), \quad d_\Gamma(P\|K) = \Gamma(P) - \Gamma(K) - \langle \nabla\Gamma(K), P - K\rangle,$$
and its Legendre transform is also smooth and strictly convex:
$$\Gamma^*(\rho) = \max_{P \in \mathcal{B}} \langle P, \rho\rangle - \Gamma(P).$$
In particular, $\nabla\Gamma$ and $\nabla\Gamma^*$ are bijective maps between $\mathrm{int}(\mathcal{B})$ and $\mathrm{int}(\mathrm{dom}(\Gamma^*))$ such that $\nabla\Gamma^* = (\nabla\Gamma)^{-1}$. For $\Gamma = \|\cdot\|^2$, one recovers the squared Euclidean norm $d_\Gamma = \|\cdot\|^2$. One has $d_\Gamma = \mathrm{KL}$ for $\Gamma(P) = \sum_{i,j=1}^B P_{ij}(\log P_{ij} - 1)$. Dykstra's algorithm starts by initializing $P^{(0)} := K$ and $U^{(0)} = U^{(-1)} := 0$. One then iteratively defines, for $k > 0$,
$$P^{(k)} := \mathrm{Prox}^{d_\Gamma}_{d_{\phi_{[k]_2}}}\big(\nabla\Gamma^*(\nabla\Gamma(P^{(k-1)}) + U^{(k-2)})\big). \qquad (49)$$
Lemma 3 (Uniform convergence). [4] When $t_1 \to \infty$, $f^{(t_1)}$ converges uniformly to a fixed point $f^{(\infty)}$, with $f^{(t_1)} \leq f^{(\infty)}$.
Proof of Lemma 3 (Uniform convergence): We follow the procedure of [4]. Given the push-forward operators $T_\mu$ and $T_\nu$ and the composed operator $S = T_\nu \circ T_\mu$, we obtain an iteration on $\mathcal{C}(\mathcal{X})$ as $S : \mathcal{C}(\mathcal{X}) \to \mathcal{C}(\mathcal{X})$, $f \mapsto f \circ g \circ f$, $f^{(m+1)} = S(f^{(m)})$, and $e^{S(f)-f}\mu$ is a probability measure on $\mathcal{X}$.
Lemma 4 (Existence and uniqueness). The following conditions are equivalent for a function $f$ in the space $\mathcal{C}(X)$, where $\mathcal{C}(X)$ denotes the space of continuous functions on a set $X$. Moreover, if $f$ is a critical point, then $f^* := S(f)$ is a fixed point for the operator $S$ on $\mathcal{C}(X)$.
Proof of the Lemma 1: Consider the functional $L$ defined in Equation (53); the differential of $L$ at an element $f \in \mathcal{C}(X)$ is represented by the probability measure $\exp(S(f) - f)\mu$. For the iterations $f^{(m+1)} - f^{(m)} = S(f^{(m)}) - f^{(m)}$, when $f$ is a critical point (derivative is zero or undefined) of the functional $J$ on $\mathcal{C}(X)$, $f^* := S(f)$ is a fixed point for the operator $S$ on $\mathcal{C}(X)$, which is proved by observing that for any $\dot{f} \in \mathcal{C}(X)$:
$$\frac{d}{dt}\Big|_{t=0} L(f + t\dot{f}) = \int_X \dot{f}\,e^{(S(f)-f)}\,d\mu. \qquad (54)$$
This follows readily from the definitions by differentiating $t \mapsto g[(f + t\dot{f})]$ to get an integral over $(X, \mu)$ and then switching the order of integration. As a consequence, $f$ is a critical point of the functional $F$ on $\mathcal{C}^0(X)$ if and only if $e^{(S(f)-f)}\mu = \mu$, i.e., if and only if $e^{(S(f)-f)} = 1$ almost everywhere with respect to $\mu$. Finally, if this is the case, then $S(f) = f$ almost everywhere with respect to $\mu$ and hence $S(S(f)) = S(f)$ (since $S(f)$ only depends on $f$ viewed as an element of $L^1(X, \mu)$).
Lemma 5. Given a point x0 ∈ X, the subset Kx0 of C(X ) defined as all elements f in the image of
S satisfying f (x0 ) = 0 is compact in C(X ).
Proof of the Lemma 5: Based on the compactness of the product space X × Y, the continuous
function c is uniformly continuous on X . So S(C(X )) is an equicontinuous family of continuous
functions on X. By Arzelà-Ascoli theorem, it follows that the set Kx0 is compact in C(X ).
Proposition 3. The operator S has a fixed point f ∗ in C(X ). Moreover, f ∗ is uniquely determined
a.e. wrt µ up to an additive constant, and f ∗ minimizes the functional F . More precisely, there exists
a unique fixed point in S(C(X ))/R.
Proof of Proposition 3: Based on Jensen's inequality, we have
$$I_\mu(f^{(m)}) - I_\mu(f^{(m+1)}) = \int \big(S(f^{(m)}) - f^{(m)}\big)\,d\mu \leq \log\int \exp\big(S(f^{(m)}) - f^{(m)}\big)\,d\mu = 0, \qquad (55)$$
$$L(f^{(m)}) - L(f^{(m+1)}) = \int \big(S(g^{(m)}) - g^{(m)}\big)\,d\nu \leq \log\int \exp\big(S(g^{(m)}) - g^{(m)}\big)\,d\nu = 0. \qquad (56)$$
So the functionals are strictly decreasing at $f^{(m)}$ unless $S(f^*) = f^*$ for $f^* := S(f^{(m)})$. Then, based on Lemma 5, we know that for each initial datum $f_0$, the closure of its image, denoted $\mathcal{K}_{f_0}$, is compact in $\mathcal{C}(\mathcal{X})/\mathbb{R}$ under the operator $S$. Hence, $f^{(m)} \to f^{(\infty)}$ in $\mathcal{C}(\mathcal{X})/\mathbb{R}$, and $J$ is decreasing along the orbit but bounded from below.
By the condition for strict monotonicity, it must be that $S(f^{(\infty)}) = f^{(\infty)}$ a.e. with respect to $\mu$. It then follows from Proposition 3 that $f^{(\infty)}$ is uniquely determined in $\mathcal{C}(\mathcal{X})/\mathbb{R}$ (by the initial data $f^{(0)}$), i.e., the whole sequence converges in $\mathcal{C}(\mathcal{X})/\mathbb{R}$. We first show that there exists a number $\lambda \in \mathbb{R}$ such that $\lim_{m\to\infty} I_\mu(f^{(m)}) = \lambda$. $I_\mu$ is decreasing, so it is enough to show that $I_\mu(f^{(m)})$ is bounded from below. This follows since $I_\mu = J + L$, $J$ is bounded from below (by $F(f^{(\infty)})$), and by the first step $L(f^{(m)}) \geq L(f^{(0)})$. Next, decompose
By Lemma 5, the sequence $(\tilde{f}^{(m)})$ is relatively compact in $\mathcal{C}(\mathcal{X})$, and we claim that $|f^{(m)}(x_0)| \leq C$ for some constant $C$. Indeed, if this were not the case, there would be a subsequence $f^{(m_j)}$ such that $|f^{(m_j)}| \to \infty$ uniformly on $X$, contradicting the fact that $I_\mu(f^{(m)})$ is uniformly bounded. It follows that the sequence $(f^{(m)})$ is also relatively compact. Hence, by the previous step, the whole sequence $f^{(m)}$ converges to the unique minimizer $f^*$ of $F$ in $S(\mathcal{C}(\mathcal{X}))$ satisfying $I_\mu(f^*) = \lambda$.
B.2 GCA version of unbalanced optimal transport (GCA-UOT)
In this section, we introduce the relaxation of the EOT plan as the unbalanced optimal transport (UOT) plan and its relationship with the dual formula of EOT. Here we need to emphasize that GCA-UOT not only adds constraints to the proximal operator that computes the coupling matrix P_θ, but also adds a penalty (i.e., a KL-divergence) to the loss function d_M. For the specific objective used in the GCA-UOT method in Table 2, we employed a version with the loss in Equation (11) plus the loss in Equation (10) with a weight-control parameter.

\[
\mathrm{UOT}(\mu, \nu) = \min_{P}\; \langle P, C\rangle + \lambda_1\, d_{\phi_1}(P\mathbf{1}\,\|\,\mu) + \lambda_2\, d_{\phi_2}(P^{\top}\mathbf{1}\,\|\,\nu) + \varepsilon H(P) \tag{57}
\]

Here ⟨P, C⟩ represents the total transport cost, and λ₁ and λ₂ are regularization parameters that control the trade-off between the transport cost and the divergence penalties.
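As an illustration of how the marginal penalties in Equation (57) relax the standard Sinkhorn scaling, the sketch below implements a generalized Sinkhorn iteration for KL-penalized UOT; the function name, iteration count, and default parameters are illustrative choices rather than the GCA-UOT implementation used in the experiments.

```python
import torch

def uot_sinkhorn(C, mu, nu, eps=0.5, lam1=1.0, lam2=1.0, n_iter=50):
    """Entropy-regularized unbalanced OT with KL marginal penalties (a sketch).

    The exponents lam/(lam + eps) soften the marginal constraints; balanced
    Sinkhorn is recovered as lam1, lam2 -> infinity.
    """
    K = torch.exp(-C / eps)               # Gibbs kernel
    u, v = torch.ones_like(mu), torch.ones_like(nu)
    f1 = lam1 / (lam1 + eps)              # relaxation exponents
    f2 = lam2 / (lam2 + eps)
    for _ in range(n_iter):
        u = (mu / (K @ v)).pow(f1)
        v = (nu / (K.T @ u)).pow(f2)
    return u[:, None] * K * v[None, :]    # unbalanced transport plan P
```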
Proof of the Lemma 6: We start with the dual formula of Equation (3) with B = C₁^µ ∩ C₂^ν; the Lagrangian E(P, f, g) of Equation (3) reads

\[
\mathrm{Prox}^{\mathrm{KL}}_{B}(K) := \min_{P \in B}\; \langle P, C\rangle - \varepsilon H(P) = E(P, f, g). \tag{59}
\]

The solution to Equation (3) is unique, with scaling variables (u, v) ∈ ℝ₊ⁿ × ℝ₊ᵐ as in Equation (23). Each entry of the optimal transport matrix P is given below, and the optimal (f, g) are linked to the non-negative vectors (u, v) through (u, v) = (e^{f/ε}, e^{g/ε}):

\[
P_{ij} = e^{f_i/\varepsilon}\, e^{-C_{ij}/\varepsilon}\, e^{g_j/\varepsilon} = u_i K_{ij} v_j, \qquad (f^{(t)}, g^{(t)}) = \varepsilon\big(\log(u^{(t)}), \log(v^{(t)})\big). \tag{62}
\]
In this section, we discuss how to build the equivalence between the GCA objective, which minimizes the KL-divergence d_M between P^{(1)} and P_tgt with respect to θ,

\[
\min_{\theta}\; \mathrm{KL}\big(I \,\|\, \mathrm{Prox}^{\mathrm{KL}}_{\mathcal{C}_1^{\mu}}(K_\theta)\big),
\]

and the INCE loss minimization in Equation (1). Here P^{(1)} is the nearest point of K_θ on the constraint set C₁^µ measured by the KL-divergence d_Γ defined in Equation (18), obtained through one step of the proximal operator (Bregman projection), and K_θ denotes the augmentation kernel as in Definition (3) with cosine similarity.
B.3.1 Proof of the Theorem 1
Suppose we have an encoder f_θ with parameters θ in INCE, and let f̃_θ denote its normalized form; then we can use the following proposition to assist our proof:
Proposition 4. Let the cost matrix be C_{i,j} = 1 − f̃_θ(x'_i)^⊤ f̃_θ(x''_j) and the Gibbs kernel K_θ = exp(−C_{i,j}/ε), based on the cosine dissimilarity scores of the inner products ⟨z_{θi}, z_{θj}⟩, with z_i = f_θ(x'_i)/‖f_θ(x'_i)‖ and z_j = f_θ(x''_j)/‖f_θ(x''_j)‖. Set d_M and d_Γ to the KL-divergence, and set the target transport plan P_tgt = I.
Then the probability matrix P after one step of the Bregman iteration for the entropic optimal transport problem can be represented as P^{(1)} = diag(u^{(1)}) K_θ diag(v^{(0)}).
Proof of the Proposition (4): The Gibbs kernel K_θ is the matrix with entries K_{θ,ij} = exp(−C_{i,j}/ε), with temperature parameter ε. The marginals µ, ν and the scalings u^{(0)}, v^{(0)} are initialized as vectors of ones with the same size as the batch size B,

\[
\mu = \mathbf{1}, \quad \nu = \mathbf{1}, \quad u^{(0)} = \mathbf{1}, \quad v^{(0)} = \mathbf{1}.
\]

So we know that

\[
u^{(1)}_i = \frac{1}{\sum_{j=1}^{B} K_{\theta, ij}}.
\]
Thus, the half-step Sinkhorn iteration, or one-step Bregman iteration, for P can be expressed as

\[
P^{(1)}_{ij} = u^{(1)}_i K_{\theta, ij} = \frac{K_{\theta, ij}}{\sum_{j'=1}^{B} K_{\theta, ij'}}.
\]

This concludes the expression of P at the half-step iteration. Recall the formula of the KL divergence KL(I‖P) and the entropy H(P):

\[
\mathrm{KL}(I\|P) \stackrel{\mathrm{def}}{=} \sum_{i,j} I_{i,j}\log\frac{I_{i,j}}{P_{i,j}} - I_{i,j} + P_{i,j}, \quad \text{where } I_{i,j}\log\frac{I_{i,j}}{P_{i,j}} = 0 \text{ if } I_{i,j} = 0. \tag{64}
\]

After the batch normalization of P, the value of ∑_{i,j} P_{i,j} is equal to the batch size B, exactly the same as ∑_{i,j} I_{i,j}, so we obtain

\[
\mathrm{KL}(I\|P^{(1)}) = -\sum_{i}\log P^{(1)}_{ii} = -\sum_{i}\log\frac{K_{\theta, ii}}{\sum_{j=1}^{B} K_{\theta, ij}},
\]

where K_{θ,ii} represents the elements on the diagonal of the similarity matrix. This has the same structure as the INCE loss:

\[
\mathcal{L}_{\mathrm{INCE}} = -\sum_{i}\log\left(\frac{\exp\!\big(f_\theta(x'_i)^{\top} f_\theta(x''_i)\big)}{\sum_{j=1}^{B}\exp\!\big(f_\theta(x'_i)^{\top} f_\theta(x''_j)\big)}\right).
\]
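The sketch below illustrates this equivalence numerically: one row-normalizing Bregman step applied to the Gibbs kernel gives P^{(1)}, and −∑_i log P^{(1)}_{ii} reproduces the INCE-style objective. It is an illustrative check under the stated cosine-cost and KL settings, not the authors' training code.

```python
import torch
import torch.nn.functional as F

def one_step_bregman_infonce(z1, z2, eps=0.5):
    """KL(I || P^(1)) after one row-normalizing Bregman projection of K_theta."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    C = 1.0 - z1 @ z2.T                      # cosine cost matrix
    K = torch.exp(-C / eps)                  # Gibbs kernel K_theta
    P1 = K / K.sum(dim=1, keepdim=True)      # one-step projection onto C_1^mu
    return -torch.log(torch.diagonal(P1)).sum()

# Note: this coincides with InfoNCE on logits <z'_i, z''_j> / eps, since the
# constant factor exp(-1/eps) in K cancels between numerator and denominator.
```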
B.4 Proximal operator version of RINCE
In this section, we discuss how to build the equivalence between the RINCE loss and minimizing a convex function of d_M, with adjustable parameters q and λ, between P^{(1)} and P_tgt:

\[
d_M(I, P) = -\frac{1}{q}\left(\frac{\mathrm{diag}(P^{(1)}_{\theta})}{u^{(1)}}\right)^{q} + \frac{1}{q}\left(\frac{\lambda I}{u^{(1)}}\right)^{q}. \tag{65}
\]
For the specific parameters θ, we record the normalized latent scores z^{i}_{θ+} = s_{ii} and z^{i}_{θ−} = s_{ij}, j ≠ i. The positive pairs are stored in the diagonal of the Gibbs kernel K_θ, and the negative pairs are stored in the off-diagonal elements, which means:

\[
K_{ii} = \exp\!\big(-\varepsilon^{-1} C_{i,i}\big) = \exp\!\big(-\varepsilon^{-1}\,|1 - \langle z'_{\theta i}, z''_{\theta i}\rangle|\big) = \exp\!\big(\varepsilon^{-1}\langle z'_{\theta i}, z''_{\theta i}\rangle - \varepsilon^{-1}\big) \propto e^{z^{i}_{\theta+}}, \tag{67}
\]
\[
K_{ij} = \exp\!\big(-\varepsilon^{-1} C_{i,j}\big) = \exp\!\big(\varepsilon^{-1}\langle z'_{\theta i}, z''_{\theta j}\rangle - \varepsilon^{-1}\big) \propto e^{z^{i}_{\theta-}}, \quad j \neq i. \tag{68}
\]
By solving ⟨u^{(1)} K, 1⟩ = µ in Equation (21), the i-th row sum satisfies ∑_{j} K_{θ,ij} = µ/u^{(1)}_i, in which u^{(1)} is given in Equation (21):

\[
\frac{\mu}{u^{(1)}_i} = \sum_{j=1}^{B} K_{\theta, ij} = \frac{1}{e^{\varepsilon^{-1}}}\Big(e^{\varepsilon^{-1}\langle z'_{\theta i}, z''_{\theta i}\rangle} + \sum_{j=1, j\neq i}^{B} e^{\varepsilon^{-1}\langle z'_{\theta i}, z''_{\theta j}\rangle}\Big), \quad i \neq j, \tag{69}
\]
\[
\mathrm{diag}(K_\theta) = \frac{e^{z^{i}_{\theta+}}}{e^{\varepsilon^{-1}}} = \frac{\mathrm{diag}(P^{(1)})}{u^{(1)}}. \tag{70}
\]
The diagonal of the K matrix contains the positive views, and the marginal distribution encoded in u contains the negative views, so we have:

\[
\mathcal{L}^{\lambda, q}_{\mathrm{RINCE}}(s^{i}_{\theta}) = -\frac{e^{q s^{i}_{\theta+}}}{q} + \frac{\Big(\lambda\big(e^{s^{i}_{\theta+}} + \sum_{j=1, j\neq i}^{B} e^{s^{ij}_{\theta-}}\big)\Big)^{q}}{q} \;\propto\; -\frac{\mathrm{diag}(K_\theta)^{q}_{ii}}{q} + \frac{\big(\lambda \sum_{j=1}^{B} K_{\theta, ij}\big)^{q}}{q}. \tag{71}
\]

Furthermore, we have:

\[
-\mathbb{E}\big(\mathcal{L}^{\lambda, q}_{\mathrm{RINCE}}(K_\theta)\big) = \frac{1}{q}\,\mathrm{diag}(K_\theta)^{q} - \frac{1}{q}\left(\frac{\lambda I}{u^{(1)}}\right)^{q}. \tag{72}
\]

Since P^{(0)} = diag(1) K_θ diag(1) and P^{(1)} = diag(u^{(1)}) K_θ diag(1), we have:

\[
\mathcal{L}^{\lambda, q}_{\mathrm{RINCE}}(P^{(1)}_{\theta}) = -\frac{1}{q}\left(\frac{\mathrm{diag}(P^{(1)})}{u^{(1)}}\right)^{q} + \frac{1}{q}\left(\frac{\lambda I}{u^{(1)}}\right)^{q}. \tag{73}
\]
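To make the kernel form of Equations (67)–(71) concrete, the sketch below writes the RINCE objective directly in terms of the Gibbs kernel computed from two augmented batches; the defaults for q, λ, and ε are illustrative, and this is a re-derivation for exposition rather than the authors' reference implementation.

```python
import torch
import torch.nn.functional as F

def rince_from_kernel(z1, z2, q=0.98, lam=0.01, eps=0.5):
    """RINCE written in terms of the Gibbs kernel K_theta (cf. Eqs. 67-71)."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    C = 1.0 - z1 @ z2.T                   # cosine cost matrix
    K = torch.exp(-C / eps)               # Gibbs kernel: diagonal holds positives
    pos = torch.diagonal(K)               # proportional to exp(s_+)
    neg = K.sum(dim=1)                    # row sums: positive + negatives
    loss = -(pos ** q) / q + ((lam * neg) ** q) / q
    return loss.mean()
```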
B.4.2 Proof of the Symmetry and robustness of RINCE
A symmetric loss is said to be noise tolerant because the classifier retains its performance under label noise in Empirical Risk Minimization (ERM). In many practical machine learning scenarios, we aim to select a model or function f_θ that minimizes the expected loss over all possible inputs and outputs from a distribution D, which is typically unknown. Instead of minimizing the true risk, which is often infeasible because D is unknown, we minimize the empirical risk R̂_L(f̃_θ), defined as the average loss over a training dataset of size B consisting of independently and identically distributed (iid) data points. Mathematically, it is given by the following formula:
\[
\hat{R}_{L}(\tilde f_{\theta}) = \frac{1}{B}\sum_{i=1}^{B} L\big(\tilde f_{\theta}(x_i), y_i\big). \tag{74}
\]
Here, L(f̃_θ(x_i), y_i) is the loss function, which measures the discrepancy between the predicted value f̃_θ(x_i) and the true value y_i. The function f̃_θ that minimizes this empirical risk is chosen as the model for making predictions. This approach is based on the assumption that minimizing the empirical risk also approximately minimizes the true risk, especially as the size of the training set increases.
First, we show that a symmetric loss is robust to noisy views with the following lemma [20], which states that such a loss achieves the same ERM performance with noisy labels. Then we show that RINCE satisfies the symmetry condition when q = 1. The lemma is:
Lemma 7. Suppose a loss function L(f̃_θ(x), y) exhibits the following symmetry for some positive constant K with respect to the labels y = 1 and y = −1:

\[
L(\tilde f_{\theta}(x), 1) + L(\tilde f_{\theta}(x), -1) = K, \quad \forall x, \forall f. \qquad \text{(Symmetry)} \tag{75}
\]

Then the symmetric loss is noise tolerant given label noise rate η < 0.5, which corresponds to flipped labels:

\[
\mathbb{P}_{D}\big[\mathrm{sign}(\tilde f_{\theta^*}(x)) = y_x\big] = \mathbb{P}_{D}\big[\mathrm{sign}(\tilde f_{\theta^*_{\eta}}(x)) = y_x\big]. \qquad \text{(Noise tolerant)} \tag{76}
\]
This formula has the same structure as the exponential loss function L(z_θ, y) = −y e^{z_θ}. To check for symmetry, we define a new binary classification loss function:

\[
\tilde L_x(z_{\theta}(x), y) = B + L_x(\tilde f_{\theta}(x), y) = B - y\, e^{\tilde f_{\theta}(x)} \ge 0,
\]

where the prediction score f̃_θ(x) is bounded by s_max = log(B). Then we can establish that the loss satisfies the symmetry property:

\[
\tilde L(\tilde f_{\theta}(x), 1) + \tilde L(\tilde f_{\theta}(x), -1) = 2B. \tag{78}
\]

This proves that the loss function is symmetric.
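A quick numeric sanity check of the symmetry property in Equation (78) follows; the batch size and the sampled scores are illustrative values only.

```python
import numpy as np

# With L_tilde(s, y) = B - y * exp(s) and |s| <= log(B), the losses for the
# two label values always sum to 2B, matching Eq. (78).
B = 256
s = np.random.uniform(-np.log(B), np.log(B), size=10)   # bounded prediction scores
loss = lambda s, y: B - y * np.exp(s)
assert np.allclose(loss(s, 1) + loss(s, -1), 2 * B)
```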
B.5 Proof that RINCE is an upper bound of the 1-Wasserstein distance
In this section, we build the connection that arises when changing d_M from the KL-divergence in Equation (10) to the 1-Wasserstein distance in Equation (12), for q = 1 in the RINCE loss.
where C(X × Y) denotes the set of all 1-Lipschitz functions from X × Y to R. A function f :
X × Y → R is defined to be 1-Lipschitz if, for any two points (x1 , y1 ), (x2 , y2 ) ∈ X × Y, the
following condition is satisfied:
|f (x1 , y1 ) − f (x2 , y2 )| ≤ d((x1 , y1 ), (x2 , y2 ))
where d((x₁, y₁), (x₂, y₂)) denotes the metric on X × Y, typically defined, for example, by the Euclidean distance:

\[
d\big((x_1, y_1), (x_2, y_2)\big) = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}.
\]
Based on Lipschitz continuity and the inner product, for two given points (x₁, y₁) and (x₂, y₂) the following properties hold with −1/ε ≤ s ≤ 1/ε, which implies |∇_s e^s| ≤ e^{1/ε}. Therefore, by the mean value theorem, we have:

\[
\big|e^{x_1^{\top} y_1/\varepsilon} - e^{x_2^{\top} y_2/\varepsilon}\big| \le \frac{1}{\varepsilon} e^{1/\varepsilon}\,\big|\langle x_1, y_1\rangle - \langle x_2, y_2\rangle\big| = \frac{1}{\varepsilon} e^{1/\varepsilon}\,\big|\langle x_1 - x_2, y_1\rangle + \langle x_2, y_1 - y_2\rangle\big| \tag{79}
\]
\[
\le \frac{1}{\varepsilon} e^{1/\varepsilon}\big(\|x_1 - x_2\|\|y_1\| + \|y_1 - y_2\|\|x_2\|\big) = \frac{1}{\varepsilon} e^{1/\varepsilon}\big(\|x_1 - x_2\| + \|y_1 - y_2\|\big).
\]
Consider two pairs of views, (z'_{θ1}, z''_{θ1}) and (z'_{θ2}, z''_{θ2}), sampled from the joint distribution π of µ and ν. Thus, each pair (z'_{θi}, z''_{θi}) for i = 1, 2 is a sample from the joint distribution π, where z'_{θi} ∼ µ and z''_{θi} ∼ ν. The RINCE loss is a symmetric loss when q = 1, so we have from Equation (2):

\[
\mathcal{L}^{\lambda, q=1}_{\mathrm{RINCE}} = -e^{z^{ii}_{\theta+}} + \lambda\Big(e^{z^{ii}_{\theta+}} + \sum_{j=1, j\neq i}^{B} e^{z^{ij}_{\theta-}}\Big), \qquad
\begin{cases}
z^{ii}_{\theta+} = \varepsilon^{-1}\, \tilde f_{\theta}(x'_i)^{\top} \tilde f_{\theta}(x''_i), & i = j,\\[2pt]
z^{ij}_{\theta-} = \varepsilon^{-1}\, \tilde f_{\theta}(x'_i)^{\top} \tilde f_{\theta}(x''_j), & i \neq j.
\end{cases} \tag{80}
\]
So we know that:

\[
-\mathbb{E}\big(\mathcal{L}^{\lambda, q=1}_{\mathrm{RINCE}}(z_\theta)\big)
= \mathbb{E}_{\substack{z'_{\theta i}\sim\mu \\ z''_{\theta i}\sim\nu|\mu=z'_{\theta i} \\ z''_{\theta j}\sim\nu}}\Big[(1-\lambda)\, e^{\varepsilon^{-1} z'^{\top}_{\theta i} z''_{\theta i}} - \lambda \sum_{j=1}^{B-1} e^{\varepsilon^{-1} z'^{\top}_{\theta i} z''_{\theta j}}\Big]
\]
\[
= \mathbb{E}_{(z'_{\theta i}, z''_{\theta i})\sim\pi}\Big[(1-\lambda)\, e^{\frac{z'^{\top}_{\theta i} z''_{\theta i}}{\varepsilon}}\Big] - \lambda(B-1)\, \mathbb{E}_{z'_{\theta i}\sim\mu,\, z''_{\theta j}\sim\nu}\Big[e^{\frac{z'^{\top}_{\theta i} z''_{\theta j}}{\varepsilon}}\Big]
\]
\[
\le (1-\lambda)\Big(\mathbb{E}_{(z'_{\theta}, z''_{\theta})\sim\pi}\Big[e^{\frac{z'^{\top}_{\theta} z''_{\theta}}{\varepsilon}}\Big] - \mathbb{E}_{z'_{\theta}\sim\mu,\, z''_{\theta}\sim\nu}\Big[e^{\frac{z'^{\top}_{\theta} z''_{\theta}}{\varepsilon}}\Big]\Big) \qquad (\text{given } \lambda(B-1) > 1-\lambda).
\]
Given two pairs of views (z'_{θ1}, z''_{θ1}) and (z'_{θ2}, z''_{θ2}) from the joint distribution π of µ and ν, with z'_θ ∼ µ and z''_θ ∼ ν, consider the difference

\[
\big|e^{\varepsilon^{-1} z'^{\top}_{\theta 1} z''_{\theta 1}} - e^{\varepsilon^{-1} z'^{\top}_{\theta 2} z''_{\theta 2}}\big|
\le (1-\lambda)\,\frac{1}{\varepsilon} e^{1/\varepsilon}\big(\|z'_{\theta 1} - z'_{\theta 2}\|\|z''_{\theta 1}\| + \|z''_{\theta 1} - z''_{\theta 2}\|\|z'_{\theta 2}\|\big) \qquad (\text{mean value theorem, Equation (79)})
\]
\[
= (1-\lambda)\,\frac{1}{\varepsilon} e^{1/\varepsilon}\big(\|z'_{\theta 1} - z'_{\theta 2}\|_2 + \|z''_{\theta 1} - z''_{\theta 2}\|_2\big)
= (1-\lambda)\,\frac{1}{\varepsilon} e^{1/\varepsilon}\, d\big((z'_{\theta 1}, z''_{\theta 1}), (z'_{\theta 2}, z''_{\theta 2})\big)
\]
\[
\le (1-\lambda)\,\frac{1}{\varepsilon} e^{1/\varepsilon}\, W_{1}(\pi, \mu \otimes \nu).
\]
In this section, we show how changing the augmentation kernel from K_θ in Definition (3) to the BYOL kernel S_θ leads to the BYOL loss.
Proof of Theorem 4:
BYOL has an online network parameterized by θ and a target network parameterized by ξ, where z'_θ = f̃_θ(x') and z''_ξ = f̃_ξ(x'') are the normalized outputs of the online and target networks, respectively. The BYOL kernel takes the form

\[
S_{\theta}(x'_i, x''_j) = \exp\!\big(-\langle \tilde q_{\theta}(\tilde f_{\theta}(x'_i)),\, \tilde f_{\xi}(x''_j)\rangle\big),
\]
The kernel involves both parameters θ and ξ; however, the target network is under a stop-gradient, so only θ needs to be updated, and we can write the kernel as S_θ(x'_i, x''_j) as in the main text. As given in the equation, the corresponding proximal operator associated with d_Γ equal to the L2-distance, with h(P) = 0 for all P ∈ ℝ^{B×B}, has the form

\[
\mathrm{Prox}^{\|\cdot\|_2}_{\mathbb{R}^{B\times B}}(S_{\theta}) = \operatorname*{arg\,min}_{P\in\mathbb{R}^{B\times B}} h(P) + \tfrac{1}{2}\|P - S_{\theta}\|_2^2 \;\Rightarrow\; P = S_{\theta}.
\]
The BYOL loss can be written as the normalized L2-distance between the predictor output of the online network, q̃_θ(z'_θ) (where q̃_θ is the predictor), and the stop-gradient output of the target network, z''_ξ; the BYOL objective reads L_BYOL = ‖q̃_θ(z'_θ) − z''_ξ‖²₂.
In this case, there exists an equivalence between

\[
\mathrm{KL}(I\,\|\,S_{\theta}) = -\sum_{i}^{B}\log S_{\theta, ii} = \sum_{i}^{B}\big\|\tilde q_{\theta}(z'_{\theta i}) - z''_{\xi i}\big\|_2^2, \tag{81}
\]

which is the BYOL loss.
In the forward pass, iteratively running GCA does not involve an inner optimization for gradient back-propagation. In the Sinkhorn algorithm, the transport plan P_θ is computed as

\[
P_{\theta, ij} = \exp\!\big((f_i + g_j - C_{\theta, ij})/\varepsilon\big),
\]

where f and g are dual variables iteratively updated in the Sinkhorn algorithm but do not carry gradients with respect to θ. The Sinkhorn optimization primarily entails scaling the rows and columns of P to satisfy the marginal constraints, which can be viewed as element-wise operations (scaling and exponentiation) on the cost matrix C_θ.
Since P_θ is computed through the fixed-point iteration of f and g, which depend only on the current values of C_θ, the gradient back-propagation process is simplified. Specifically, the gradient of the loss with respect to the cost matrix C_θ is the key quantity that needs to be differentiated, rather than differentiating through each iterative update of f and g. A typical workflow of these algorithms is shown in Figure 2 of [17]: the gradient flow primarily involves differentiating through C_θ, which is done only once, and not through each step of the Sinkhorn iterations. This approach reduces computational complexity and avoids back-propagation through every iterative update within the Sinkhorn algorithm, which would otherwise be computationally expensive.
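A minimal sketch of this strategy is shown below: the Sinkhorn scalings are computed without tracking gradients, so autograd only differentiates through the cost matrix C_θ; the uniform marginals, iteration count, and function name are illustrative choices.

```python
import torch

def sinkhorn_plan_detached(C, eps=0.5, n_iter=10):
    """Forward Sinkhorn with dual updates outside the autograd graph."""
    B = C.shape[0]
    mu = torch.full((B,), 1.0 / B, device=C.device)
    nu = torch.full((B,), 1.0 / B, device=C.device)
    K = torch.exp(-C / eps)                       # differentiable w.r.t. C (and theta)
    with torch.no_grad():                         # scalings carry no gradient
        u, v = torch.ones_like(mu), torch.ones_like(nu)
        for _ in range(n_iter):
            u = mu / (K @ v)
            v = nu / (K.T @ u)
    # Gradients of any loss on P flow only through K, i.e., through C_theta.
    return u[:, None] * K * v[None, :]
```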
In this section, we show that the GCA methods minimize the difference between the target alignment plan and the coupling matrix on the latents. The uniformity and alignment losses have been used to examine the quality of representations in self-supervised learning, and are defined as follows [57]:
Definition 9 (Alignment loss). Given π as the joint distribution of positive samples on the latents, and (z'_{θi}, z''_{θi}) as the normalized positive pairs sampled from π with an encoder parameterized by θ, the alignment loss is:

\[
\mathcal{L}_{\mathrm{align}} = \min_{\theta}\; \mathbb{E}_{(z'_{\theta i}, z''_{\theta i})\sim\pi}\,\|z'_{\theta i} - z''_{\theta i}\|_2^2 = \min_{\theta}\;\sum_{i}\mathrm{diag}(C)_{ii}, \tag{82}
\]

where C is the cost matrix defined in Equation (24).
We can alter the constraint sets of the proximal operators to provide better alignment plans; e.g., GCA-INCE changes the constraint sets by considering both row and column normalization of the coupling matrix rather than just row normalization. Such a change does not affect the alignment loss in the forward pass, but it benefits the alignment loss in the backward pass through a tighter bound in empirical risk minimization with the identity matrix.
Combining the above two items, we obtain

\[
\mathcal{L}^{\lambda, q=1, \varepsilon}_{\mathrm{GCA\text{-}RINCE}}(P^{(t)}_{\theta}) \le \mathcal{L}^{\lambda, q=1, \varepsilon}_{\mathrm{RINCE}}(P^{(1)}_{\theta}).
\]
C.2 GCA methods improve the uniformity and benefit downstream classification tasks
In this section, we provide theoretical evidence that the GCA approaches can improve the performance of downstream tasks, e.g., classification, by providing maximum uniformity through solving the EOT, as stated in Theorem (7). Here, the uniformity loss is defined as [50]:
Definition 10 (Uniformity loss). Let z'_{θi} ∼ µ and z''_{θj} ∼ ν, where µ and ν are two distributions on the representation space. We define the uniformity loss as:

\[
\mathcal{L}_{\mathrm{uniform}} = \log\, \mathbb{E}_{z'_{\theta i}, z''_{\theta j}\ \text{i.i.d.}\sim p_{\mathrm{data}}}\Big[e^{-\varepsilon\|z'_{\theta i} - z''_{\theta j}\|_2^2}\Big]. \tag{87}
\]

Here, p_data(·) is the marginal distribution of the samples. Since z'_{θi} and z''_{θj} are normalized latent variables, the term inside the expectation, e^{−ε‖z'_{θi} − z''_{θj}‖²₂}, is the same as the entropy-regularized kernel K_{ij} = e^{−εC_{ij}} with cost matrix entries C_{ij} = ‖z'_{θi} − z''_{θj}‖²₂.
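The alignment and uniformity quantities of Definitions 9 and 10 can be computed from two batches of normalized embeddings as in the sketch below; the function names and the default ε are illustrative.

```python
import torch

def alignment_loss(z1, z2):
    """Definition 9: mean squared distance between positive pairs (z1[i], z2[i])."""
    return (z1 - z2).pow(2).sum(dim=1).mean()

def uniformity_loss(z1, z2, eps=0.5):
    """Definition 10: log of the average kernel value exp(-eps * ||z'_i - z''_j||^2)
    over all pairs of independently drawn (normalized) embeddings."""
    sq_dists = torch.cdist(z1, z2).pow(2).flatten()
    return torch.logsumexp(-eps * sq_dists, dim=0) - torch.log(
        torch.tensor(float(sq_dists.numel()))
    )
```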
is the relative entropy of the transport plan π with respect to the product measure µ ⊗ ν. The corresponding dual problem of this EOT problem is:

\[
W_{c,\varepsilon}(\mu, \nu) = \max_{f\in C(\mathcal{X}),\, g\in C(\mathcal{Y})}\; \int_{\mathcal{X}} f(x)\, d\mu(x) + \int_{\mathcal{Y}} g(y)\, d\nu(y)
- \varepsilon \int_{\mathcal{X}\times\mathcal{Y}} e^{\frac{f(x)+g(y)-c(x,y)}{\varepsilon}}\, d\mu(x)\, d\nu(y) + \varepsilon \tag{90--91}
\]
\[
= \max_{f\in C(\mathcal{X}),\, g\in C(\mathcal{Y})}\; \mathbb{E}_{\mu\otimes\nu}\Big[f(x) + g(y) - e^{\frac{f(x)+g(y)-c(x,y)}{\varepsilon}}\Big] + \varepsilon. \tag{92}
\]
The measures µ(x) and ν(y) are defined as uniform distributions of Dirac delta functions on the two latent supports {z'_{θi}}_{i=1}^{B} and {z''_{θi}}_{i=1}^{B}, so the functions f(x) and g(y) can be pulled out of the expectation operators. The quantity ‖z'_{θi} − z''_{θj}‖²₂ is the entry of the cost matrix C_{ij}, which is computed through the cost function c(x, y). Since z'_{θi} and z''_{θj} are drawn independently from the latent distributions, the remaining term E_{µ⊗ν}[e^{−c(x,y)/ε}] is equivalent to the uniformity loss. The above integral can thus be turned into a sum over the entries of the matrices of dual variables f^{(t₁)} and g^{(t₁)} at each iteration. Meanwhile, based on the convergence provided in Lemma 3, when t₁ → ∞, f^{(t₁)} converges uniformly to a fixed point f^{(∞)} with f^{(t₁)} ≤ f^{(∞)}, which provides the maximum value of the dual formula at f^{(∞)}, corresponding to the coupling plan P^{(∞)}.
C.2.2 GCA benefits the downstream supervised classification task
Here, we further show how minimizing the uniformity loss is equivalent to minimizing the downstream supervised loss in classification tasks, under several assumptions [16]. Given a labeled dataset D = {(x̄_i, y_i)} ∈ X̄ × Y, where Y = [1..M] with M classes, we consider a fixed, pre-trained encoder f_θ ∈ F : X → S with its representation f_θ(X), where the input space X contains both positive and negative views of n original samples (x̄_i)_{i∈[1..n]} ∈ X̄ sampled from the data distribution p(x̄). For each positive view x'_i in X, we sample from x̄_i via x'_i ∼ A(·|x̄_i), where A(·|x̄_i) is the augmentation distribution (e.g., applying color jittering, flips, or crops with a given probability). For consistency, we assume A(x̄) = p(x̄) so that the distributions A(·|x̄) and p(x̄) induce a marginal distribution p(x) over X. Given an anchor x̄_i, all views x'' ∼ A(·|x̄_j), j ≠ i, from different samples x̄_j are considered as negatives.
Proof of claim 1: From Assumption 1, we know that the representation ability of the encoder is good enough, via the augmented samples, in the Reproducing Kernel Hilbert Space (RKHS) H_X̄ of the original sample space X̄; the kernel K_X̄, together with any function g in the RKHS defined by (H_{f_θ}, K_θ), also belongs to H_X̄ when conditioned on the distribution A(x|·). Based on this assumption, we can obtain a centroid estimator from [16]:
Definition 11 (Kernel-based centroid estimator). Let (x_i, x̄_i)_{i∈[1..n]} ∼ A(x, x̄). A consistent estimator of µ_x̄ is

\[
\forall \bar x \in \bar{\mathcal{X}}, \quad \hat\mu_{\bar x} = \sum_{i=1}^{n} \alpha_i(\bar x)\, f(x_i),
\]

where α_i(x̄) = ∑_{j=1}^{n} [(K_n + nλI_n)^{-1}]_{ij} K_X̄(x̄_j, x̄) and K_n = [K_X̄(x̄_i, x̄_j)]_{i,j∈[1..n]}. It converges to µ_x̄ in the ℓ₂ norm at a rate O(n^{−1/4}) for λ = O(n^{−1/2}).
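A small sketch of this estimator is given below; the argument names and the default choice of λ = n^{−1/2} follow the text, while the overall function is an illustrative reading of Definition 11 rather than code from [16].

```python
import torch

def kernel_centroid_estimator(K_n, k_query, Z, lam=None):
    """Kernel-based centroid estimator (cf. Definition 11).

    K_n:     (n, n) kernel matrix K(x̄_i, x̄_j) between original samples.
    k_query: (n,) kernel values K(x̄_j, x̄) between the samples and an anchor x̄.
    Z:       (n, d) representations f(x_i) of the augmented views.
    """
    n = K_n.shape[0]
    if lam is None:
        lam = n ** -0.5
    # alpha(x̄) = (K_n + n*lam*I)^{-1} k_query, matching the definition of alpha_i.
    alpha = torch.linalg.solve(K_n + n * lam * torch.eye(n), k_query)
    return alpha @ Z          # estimated centroid mu_hat(x̄)
```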
The above estimator allows us to use representations of images close to an anchor x̄ to estimate µ_x̄. From Assumption 2, we assume that all samples in the same class are reachable given an ideal augmentation, or at least lie within an ε-region of the augmented points.
Consequently, if the prior is “good enough” to connect intra-class images disconnected in the
augmentation graph suggested by Assumption 1, then this estimator allows us to tightly control
the classification risk of the representation of fθ on a classification task with a linear classifier
g(x̄) = Wfθ (x̄) (with fθ fixed) that minimizes the multi-class classification loss.
First, we show that the cross-entropy can be transformed into a centroid-based distance (optimal supervised loss). The cross-entropy (CE) measures the difference between the true distribution (actual labels) and the estimated probability distribution (predicted probabilities from the model); one usually computes logits z_k from the model and then applies the softmax function to obtain probabilities p_k. The logits z_k can be defined as negative distances between f(x̄) and the class centroids µ_k in representation space:

\[
z_k = -\|f(\bar x) - \mu_k\|^2, \qquad \mu_k = \mathbb{E}_{p(\bar x|y=k)}\, \mu_{\bar x},
\]

which encourages the model to reduce the distance to the correct class centroid while increasing the distances to the others. The probability of class k among M classes given input x̄ is:

\[
p(y = k|\bar x) = \frac{e^{z_k}}{\sum_{j=1}^{M} e^{z_j}}, \qquad p(y|\bar x) \propto e^{-\|f(\bar x) - \mu_y\|^2}.
\]
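The centroid-based classifier described above can be written compactly as in the following sketch, where the logits are the negative squared distances to the class centroids; the function name and tensor shapes are illustrative.

```python
import torch

def centroid_softmax_probs(feats, centroids):
    """p(y = k | x̄) ∝ exp(-||f(x̄) - mu_k||^2) via a softmax over negative distances.

    feats:     (N, d) representations f(x̄).
    centroids: (M, d) class centroids mu_k.
    """
    logits = -torch.cdist(feats, centroids).pow(2)   # z_k = -||f(x̄) - mu_k||^2
    return torch.softmax(logits, dim=1)
```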
If the model predictions p(y|x̄) are influenced by the distances between x̄ and the class centroids µy ,
then minimizing cross-entropy indirectly affects these distances. The standard CE loss in supervised
learning for classification tasks is:
\[
\mathcal{L}_{CE}(f_{\theta}) = -\mathbb{E}_{(\bar x, y)\sim D}\big[\log p(y|\bar x)\big] \tag{93}
\]
\[
= -\mathbb{E}_{(\bar x, y)\sim D}\big[-\|f(\bar x) - \mu_y\|^2 - \log Z\big] = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{M} y_{i,k}\log(p_{i,k}), \tag{94}
\]
which focuses on maximizing the likelihood ŷ = arg maxk p(y = k|x̄) of the correct class for each
individual sample x̄i , where yi,k is the true label indicator for example i and class k, pi,k is the
predicted probability for example i and class k. Therefore, we can rewrite the CE loss as optimal
supervised loss in [16], which is defined as:
Lemma 9 (Optimal supervised loss). Consider a downstream task D with M classes. We assume that M ≤ d + 1 (i.e., a big enough representation space), that all classes are balanced, and the realizability of an encoder f* = arg min_{f∈F} L_sup(f_θ), with

\[
\mathcal{L}_{\mathrm{sup}}(f_{\theta}) = \log\, \mathbb{E}_{y, y'\sim p(y)p(y')}\Big[e^{-\|\mu_y - \mu_{y'}\|^2}\Big],
\]

and µ_y = E_{p(x̄|y)} µ_x̄. Then the optimal centroids (µ*_y)_{y∈Y} associated with f* form a regular simplex on the hypersphere S^{d−1} and are perfectly linearly separable, i.e.,

\[
\min_{(w_y)_{y\in\mathcal{Y}} \in \mathbb{R}^{d}}\; \mathbb{E}_{(\bar x, y)\sim D}\, \mathbf{1}\big(w_y \cdot \mu^{*}_{y} < 0\big) = 0.
\]
Proof of the Lemma 9: All "labeled" centroids µ_y = E_{p(x̄|y)} µ_x̄ are bounded by 1 (‖µ_y‖ ≤ E_{p(x̄|y)} E_{A(x|x')} ‖f(x)‖ = 1 by Jensen's inequality). Then, since all classes are balanced, we can re-write the supervised loss as:

\[
\mathcal{L}_{\mathrm{sup}}(f_{\theta}) = \log \frac{1}{C^2}\sum_{y, y'=1}^{C} e^{-\|\mu_y - \mu_{y'}\|^2}.
\]
We have:

\[
\Gamma_{\mathcal{Y}}(\mu) := \sum_{y, y'} \|\mu_y - \mu_{y'}\|^2 = \sum_{y, y'} \big(\|\mu_y\|^2 + \|\mu_{y'}\|^2 - 2\mu_y\cdot\mu_{y'}\big) \le \sum_{y, y'} \big(2 - 2\mu_y\cdot\mu_{y'}\big) = 2C^2 - 2\Big\|\sum_{y}\mu_y\Big\|^2 \le 2C^2,
\]

with equality if and only if ∑_{y=1}^{C} µ_y = 0 and ∀y ∈ [1..C], ‖µ_y‖ = 1. By the strict convexity of u ↦ e^{−u}, we have:

\[
\sum_{y\neq y'} \exp\!\big(-\|\mu_y - \mu_{y'}\|^2\big) \ge C(C-1)\exp\!\Big(-\frac{\Gamma_{\mathcal{Y}}(\mu)}{C(C-1)}\Big) \ge C(C-1)\exp\!\Big(-\frac{2C}{C-1}\Big),
\]

with equality if and only if all pairwise distances ‖µ_y − µ_{y'}‖ are equal (equality case in Jensen's inequality for a strictly convex function), ∑_{y=1}^{C} µ_y = 0, and ‖µ_y‖ = 1. Thus, all centroids must form a regular (C − 1)-simplex inscribed on the hypersphere S^{d−1} centered at 0. Furthermore, since ‖µ_y‖ = 1, we have equality in Jensen's inequality:

\[
\|\mu_y\| = \big\|\mathbb{E}_{A(x|\bar x')} f_{\theta}(x)\big\| \le \mathbb{E}_{A(x|\bar x')} \|f_{\theta}(x)\| = 1,
\]

so f must be perfectly aligned for all samples belonging to the same class: ∀x, x̄' ∼ p(·|y), f_θ(x̄) = f_θ(x̄').
Second, we show that optimizing the uniformity loss is equivalent to optimizing the supervised loss. Recall the uniformity loss defined in Equation (87):

\[
\mathcal{L}_{\mathrm{uniform}}(f_{\theta}) = \log\, \mathbb{E}_{z'_i, z''_j \sim p_{\mathrm{data}}}\Big[e^{-\varepsilon\|z'_i - z''_j\|^2}\Big], \tag{95}
\]

where µ_y = E_{p(x̄|y)} µ̂_x̄. Expressing the expectation over all pairs in terms of class labels,

\[
\mathbb{E}_{z'_i, z''_j} = \mathbb{E}_{y, y'}\, \mathbb{E}_{z'_i\sim p(z|y),\, z''_j\sim p(z|y')},
\]

the uniformity loss can be decomposed into intra-class and inter-class components:

\[
\mathcal{L}_{\mathrm{uniform}}(f_{\theta}) = \log\Big(\underbrace{\mathbb{E}_{y}\big[\mathbb{E}_{z'_i, z''_j\sim p(z|y)}\, e^{-\varepsilon\|z'_i - z''_j\|^2}\big]}_{\text{intra-class term}} + \underbrace{\mathbb{E}_{y\neq y'}\big[\mathbb{E}_{z'_i\sim p(z|y),\, z''_j\sim p(z|y')}\, e^{-\varepsilon\|z'_i - z''_j\|^2}\big]}_{\text{inter-class term}}\Big).
\]
Based on Assumption 2, we can approximate each term by

\[
\|z'_i - z''_j\|^2 = \|(\mu_y + \delta_i) - (\mu_{y'} + \delta_j)\|^2 = \|\mu_y - \mu_{y'} + \delta_i - \delta_j\|^2 \approx \|\mu_y - \mu_{y'}\|^2
\]
\[
\Longrightarrow\quad \mathbb{E}_{z'_i\sim p(z|y),\, z''_j\sim p(z|y')}\Big[e^{-\varepsilon\|z'_i - z''_j\|^2}\Big] \approx e^{-\varepsilon\|\mu_y - \mu_{y'}\|^2}.
\]

Since e^{−‖µ_y − µ_y‖²} = 1 (for y = y'), the difference depends mainly on the inter-class term. Therefore, a tighter (smaller) uniformity loss leads to smaller values of the supervised loss. This supports the idea that improving uniformity in the representations can benefit downstream supervised classification tasks.
Although minimizing the uniformity loss can enhance downstream classification tasks, it may also
lead the model to learn shortcut features that could impair the encoder’s generalization ability. To
show this, we incorporate two propositions from previous work by Robinson et al. [48].
Consider a set of feature vectors z = (z₁, z₂, . . . , z_n) = (z_j)_{j∈[n]} ∈ Z, where each z_j comes from Z_j. Further, let λ denote the measure on Z induced by z, and let λ(·|z_S) denote the conditional measure on Z for fixed z_S. For S ⊆ [n], we use z_S to denote the projection of z onto Z^S. Finally, an injective map g : Z → X produces observations x = g(z). Feature suppression is defined as:
Definition 12. Consider an encoder f_θ : X → S^{d−1} and features S ⊆ [n]. For each z_S ∈ Z^S, let µ(·|z_S) be the pushforward measure on S^{d−1} by f_θ ∘ g of the conditional λ(·|z_S).
If a feature is uniformly distributed on the latent space, it may cause feature suppression, because different features can each attain the minimum of the uniformity loss, as shown in the following proposition [48]:
Proposition 5 (Feature suppression). For a set S ⊆ [n] of features, let

\[
\mathcal{L}_{S}(f_{\theta}) = \mathcal{L}_{\mathrm{align}}(f_{\theta}) + \mathbb{E}_{x^{+}}\Big[-\log \mathbb{E}_{x^{-}}\Big[e^{f(x^{+})^{\top} f(x^{-})}\,\Big|\, z^{-}_{S} = z^{+}_{S}\Big]\Big]
\]

denote the (limiting) InfoNCE loss conditioned on x⁺, x⁻ having the same features S. Suppose that p_j is uniform on Z_j = S^{d−1} for all j ∈ [n]. Then the infimum inf L_S is attained, and every f_θ ∈ arg min_{f'_θ} L_S(f'_θ) suppresses the features S almost surely.
The proof of Proposition 5 is given in [48].
C.3.2 How the GCA methods and unbalanced OT alleviate feature suppression
Here we extend the unbalanced OT in Equation (9) as follows:

\[
\min_{\theta}\; d_{M}(P_{\mathrm{tgt}}\|P_{\theta}) + \lambda_1 d_{\phi_1}(P_{\theta}) + \lambda_2 d_{\phi_2}(P_{\theta}) + \cdots + \lambda_n d_{\phi_n}(P_{\theta}). \tag{97}
\]

The UOT formulation amounts to finding the transport plan P_θ that minimizes the transportation cost between two probability measures µ and ν. Here we only need to show that relaxing the constraints or adding penalties changes the optimal transport plan P_θ, which is empirically exhibited in Figure A4.
Suppose we have empirical samples {z'_i}_{i=1}^{n} from µ and {z''_j}_{j=1}^{m} from ν. We can approximate the measures using empirical distributions:

\[
\mu \approx \frac{1}{n}\sum_{i=1}^{n}\delta_{z'_i}, \qquad \nu \approx \frac{1}{m}\sum_{j=1}^{m}\delta_{z''_j},
\]
where δz is the Dirac delta function at point z. The standard UOT objective can be written as:
\[
\min_{P\ge 0}\; \sum_{i=1}^{n}\sum_{j=1}^{m} C(z'_i, z''_j)\, P_{ij} + \lambda_1\, d_{\phi_1}\Big(\sum_{j=1}^{m} P_{ij},\, \tfrac{1}{n}\Big) + \lambda_2\, d_{\phi_2}\Big(\sum_{i=1}^{n} P_{ij},\, \tfrac{1}{m}\Big) \tag{98}
\]
\[
= \min_{P\ge 0}\; \sum_{i=1}^{n}\sum_{j=1}^{m} C_{ij} P_{ij} + \lambda_1 P_{ij}\Big(\log\frac{P_{ij}}{r_i} - 1\Big) + \lambda_2 P_{ij}\Big(\log\frac{P_{ij}}{c_j} - 1\Big), \tag{99}
\]
where C is the cost matrix, d_φ can be any divergence (e.g., the Kullback–Leibler divergence) associated with a convex function φ, P1 and P^⊤1 are the marginal distributions of the plan (compared against µ and ν), λ₁, λ₂ are regularization parameters controlling the unbalancedness, and r_i = 1/n (source marginal mass for z'_i), c_j = 1/m (target marginal mass for z''_j). Based on the UOT, we choose the divergences above and write the objective as L:
\[
L(P) = \sum_{i,j} C_{ij} P_{ij} + \lambda_1 P_{ij}\Big(\log\frac{P_{ij}}{r_i} - 1\Big) + \lambda_2 P_{ij}\Big(\log\frac{P_{ij}}{c_j} - 1\Big).
\]
To find the minimizer, we take the partial derivative of L(P) with respect to Pij and set it to zero:
\[
\frac{\partial L}{\partial P_{ij}} = C_{ij} + \lambda_1 \log\frac{P_{ij}}{r_i} + \lambda_2 \log\frac{P_{ij}}{c_j} = 0 \tag{100}
\]
\[
\Longrightarrow\quad \lambda_1(\log P_{ij} - \log r_i) + \lambda_2(\log P_{ij} - \log c_j) = -C_{ij} \tag{101}
\]
\[
\Longrightarrow\quad (\lambda_1 + \lambda_2)\log P_{ij} - \lambda_1 \log r_i - \lambda_2 \log c_j = -C_{ij} \tag{102}
\]
\[
\Longrightarrow\quad \log P_{ij} = \frac{-C_{ij} + \lambda_1 \log r_i + \lambda_2 \log c_j}{\lambda_1 + \lambda_2} \tag{103}
\]
\[
\Longrightarrow\quad P_{ij} = \exp\!\Big(\frac{-C_{ij} + \lambda_1 \log r_i + \lambda_2 \log c_j}{\lambda_1 + \lambda_2}\Big) \tag{104}
\]
The minimizer P_{ij} depends on λ₁ and λ₂ both through the weights on log r_i and log c_j, which determine the influence of the marginals, and through the scaling of the cost C_{ij} by λ₁ + λ₂. This explicit relationship shows how λ₁ and λ₂ determine the minimizer.
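The closed-form stationary point in Equation (104) can be checked numerically, as in the sketch below; the sizes and parameter values are illustrative.

```python
import numpy as np

n, m, lam1, lam2 = 4, 5, 1.0, 2.0
C = np.random.rand(n, m)
r, c = np.full(n, 1.0 / n), np.full(m, 1.0 / m)

# Closed-form minimizer from Eq. (104).
P = np.exp((-C + lam1 * np.log(r)[:, None] + lam2 * np.log(c)[None, :]) / (lam1 + lam2))

# The gradient of L(P) in Eq. (100) vanishes at this P.
grad = C + lam1 * np.log(P / r[:, None]) + lam2 * np.log(P / c[None, :])
assert np.allclose(grad, 0.0)
```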
D Details of Experiments
The experiments involving GPUs were run on an NVIDIA GeForce RTX 3090.
For the standard settings in Table 2, we used two different experimental setups. The first setup, referred to as the C0 or standard settings, was applied to the CIFAR10 and CIFAR100 tasks. The second setup was used for the SVHN and ImageNet100 tasks. Below, we present the settings for CIFAR10 and CIFAR100, followed by the setups for SVHN and ImageNet100. The setups for CIFAR10 and CIFAR100 are (a sketch of the encoder and projector follows this list):
• The SSL model has 512 feature dimensions with a ResNet-18 base model, whose first convolutional layer is replaced by a layer with 3 input channels, 64 output channels, kernel size 3, stride 1, padding 1, and no bias. We replace the max-pooling layer with the identity.
• A sequential projector comprising a linear layer mapping from feature dimension to 2048,
ReLU activation, and another linear layer mapping from 2048 to 128.
• For SSL training, an SGD optimizer is used with a learning rate of 0.6, momentum 0.9, and a weight decay of 1.0e-6. A LambdaLR scheduler is employed to linearly decay the learning rate to 1.0e-3 over the total number of steps, which equals the length of the SSL training loader times the maximum number of epochs. The SSL model is trained for a maximum of 500 epochs, without loading a pre-trained model. The parameters of the encoders are frozen after training. Temperature or epsilon: 0.5.
• For supervised training, an Adam optimizer is used with a learning rate of 0.2, momentum 0.9, and a weight decay of 0. The same LambdaLR scheduler is applied, reducing the learning rate by a factor of 1.0e-3. The model is trained for a maximum of 200 epochs using the specified train and test loaders.
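The following sketch illustrates the encoder and projector configuration described in the list above; the exact construction is an assumption based on the text (standard CIFAR-style ResNet-18 modifications), not the released training code.

```python
import torch.nn as nn
from torchvision.models import resnet18

# CIFAR-style ResNet-18 encoder: 3x3 first conv, no max-pooling.
encoder = resnet18()
encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
encoder.maxpool = nn.Identity()
feat_dim = encoder.fc.in_features        # 512 feature dimensions
encoder.fc = nn.Identity()               # expose the 512-d features

# Projector: feature dim -> 2048 -> ReLU -> 128.
projector = nn.Sequential(
    nn.Linear(feat_dim, 2048),
    nn.ReLU(inplace=True),
    nn.Linear(2048, 128),
)
```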
The setups for SVHN and ImageNet100 are:
• The SSL model has a number of feature dimensions equal to the number of incoming features of the fc layer of the base model (ResNet-50). We replace the max-pooling layer with the identity.
• A sequential projector comprising a linear layer mapping from feature dimension to 2048,
ReLU activation, and another linear layer mapping from 2048 to 128.
• For SSL training, an Adam optimizer is used with a learning rate of 3e-4. The SSL model
is trained for a maximum of 200 epochs for ImageNet100 and 500 epochs for the SVHN,
without loading a pre-trained model. The parameters of encoders are frozen after training.
Temperature or epsilon: 0.5.
• For supervised training, an Adam optimizer is also used with a learning rate of 3e-4. The
model is trained for a maximum of 100 epochs using the specified train and test loaders.
There is the "extreme DA" (Ex DA) column in Table 2, which is the average of the following three
settings:
• C1: Large Erase Settings: Here, we first employ the same standard augmentation as C0 in Appendix D.1, then we apply random erasing with 'p=1' (random erasing is applied every time) and 'scale=(0.10, 0.33)'. The large erase is applied before normalization.
• C2: Strong Crop Setting: This involves a strong cropping operation followed by resizing, applied with 'transforms.RandomCrop' and 'transforms.Resize'. The crop size varies with the severity level, with values ranging from 96 to 224 pixels. We selected level 3 during our experiments, then resized the cropped image back to 32x32 pixels.
• C3: Brightness settings: This augmentation alters the brightness of the images. The 'severity' determines the degree of brightness change, with predefined levels ranging from '.05' to '.3', corresponding to level 1 and level 5. We chose level 5 as our C3 augmentation. The brightness is adjusted in the HSV color space, specifically altering the value channel to change the brightness.
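The sketch below gives illustrative torchvision pipelines for the three settings; the standard C0 transforms, the normalization statistics, and the use of ColorJitter in place of the HSV value-channel manipulation for C3 are assumptions made for illustration.

```python
from torchvision import transforms

# C1: standard augmentation followed by a large random erase (p=1, scale=(0.10, 0.33)).
c1_large_erase = transforms.Compose([
    transforms.RandomResizedCrop(32),                        # assumed standard C0 crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=1.0, scale=(0.10, 0.33)),     # applied before normalization
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # placeholder statistics
])

# C2: strong crop (severity level 3 of the 96-224 pixel range), then resize to 32x32.
c2_strong_crop = transforms.Compose([
    transforms.RandomCrop(96, pad_if_needed=True),
    transforms.Resize(32),
    transforms.ToTensor(),
])

# C3: brightness change at severity level 5 (approximated here with ColorJitter).
c3_brightness = transforms.Compose([
    transforms.ColorJitter(brightness=0.3),
    transforms.ToTensor(),
])
```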
To evaluate performance on CIFAR10-C, we use a pretrained SSL model with frozen parameters.
Fine-tuning is performed by training only the linear layer with 10% of CIFAR10-C data for 50 epochs.
We compute the final score by averaging results across all corruption types and severity levels in
CIFAR10-C. And the details of each column are provided in Table A2, Table A3 and Table A4.
Table A2: Test accuracy for contrastive methods on CIFAR-10. Test accuracy for different contrastive methods
and their GCA equivalents on CIFAR-10 for ResNet-18 under extreme augmentation conditions, averaged over
5 seeds.
Conditions INCE GCA-INCE RINCE GCA-RINCE SimCLR BYOL IOT IOT-uni GCA-UOT
Standard 92.01 ± 0.40 93.02 ± 0.19 93.27±0.20 93.47±0.32 92.16 ±0.16 90.56 ± 0.59 92.10 ± 0.22 91.49 ± 0.11 93.50±0.31
Erase 88.40 ± 0.17 88.16 ± 0.89 88.80±1.01 89.21±0.59 88.44 ± 0.24 88.77 ± 0.58 87.02 ± 0.43 87.83 ± 0.30 89.84 ± 0.58
Crop 72.45 ± 0.40 72.79 ± 0.62 73.02±0.39 73.10±0.31 71.84 ± 1.02 70.78 ± 0.62 70.44 ± 0.64 70.78 ± 0.21 73.35 ± 0.41
Brightness 85.24 ± 0.41 85.60 ± 0.57 85.97±0.50 85.98 ± 0.58 85.32 ± 0.32 85.10 ± 0.29 84.31 ± 0.84 83.77 ± 0.21 86.36±0.34
Table A3: Test accuracy for contrastive methods on CIFAR-100. Test accuracy for different contrastive methods
and their GCA equivalents on CIFAR-100 using ResNet-18 under extreme augmentation conditions, averaged
over 5 seeds.
Conditions INCE GCA-INCE RINCE GCA-RINCE SimCLR BYOL IOT IOT-uni GCA-UOT
Standard 71.09 ± 0.31 71.55 ± 0.12 71.63 ± 0.36 71.95 ± 0.48 70.85 ± 0.50 69.75 ± 0.37 68.37 ± 0.42 68.62 ± 0.35 72.16 ± 0.38
Large Erase 62.54 ± 0.20 62.65 ± 0.17 63.55 ± 0.14 63.14 ± 0.41 62.94 ± 0.13 62.70 ± 0.31 62.69 ± 0.34 62.56 ± 0.22 63.62 ± 0.27
Strong Crop 45.67 ± 0.31 46.31 ± 0.22 46.47 ± 0.20 46.50 ± 0.35 46.05 ± 0.34 43.11 ± 0.41 45.36 ± 0.19 45.29 ± 0.12 46.60 ± 0.34
Brightness 59.87 ± 0.36 59.68 ± 0.33 60.38 ± 0.25 60.56 ± 0.25 60.46 ± 0.19 55.74 ± 0.63 57.52 ± 0.32 57.42 ± 0.18 61.02 ± 0.55
Table A4: Test accuracy for contrastive methods on CIFAR-10C. Test accuracy for different contrastive methods
and their GCA equivalents on CIFAR-10C for ResNet-18 under extreme augmentation conditions, averaged over
5 seeds.
Conditions INCE GCA-INCE RINCE GCA-RINCE SimCLR BYOL IOT IOT-uni GCA-UOT
Standard 81.52 ± 1.04 82.63 ± 0.28 82.86 ± 0.21 82.87 ± 0.11 81.74 ± 1.54 82.43 ± 0.06 82.01 ± 0.80 81.18 ± 1.12 82.90 ± 0.49
Large Erase 82.67 ± 0.31 82.19 ± 0.35 83.11 ± 0.36 83.44 ± 0.55 82.49 ± 0.24 82.53 ± 0.34 82.30 ± 0.69 82.30 ± 0.48 83.09 ± 0.62
Strong Crop 43.55 ± 0.90 41.12 ± 3.28 42.51 ± 1.48 44.96 ± 1.46 42.74 ± 1.27 43.86 ± 0.91 42.36 ± 0.87 42.43 ± 1.03 45.20 ± 1.20
Brightness 84.00 ± 0.47 84.08 ± 0.26 84.24 ± 0.36 84.54 ± 0.26 83.86 ± 0.46 83.87 ± 0.19 82.33 ± 0.31 80.86 ± 0.72 84.88 ± 0.44
This section presents the settings of the experiments in Figure 1, which involve the domain generalization task. Training was executed under the DomainBed framework. Each model was trained across multiple domains, with 5 distinct seeds (71, 68, 42, 36, 15) used to ensure reproducibility:
• For the SSL model configuration, we employed a ResNet-18 architecture as the encoder, followed by a 2048-dimensional, 3-layer projector with BatchNorm1D and ReLU activations. We extended the SelfReg algorithm in DomainBed [23] with a self-supervised contrastive learning phase using GCA-INCE, with regularization parameter ε = 0.2.
• For SSL training hyperparameters, an Adam optimizer is used with a learning rate of 3e-4,
and a weight decay of 1.5e-6. A Cosine Annealing learning rate scheduler is employed with
a maximum number of 200 iterations equal to the length of the SSL training. The learning
rate is scheduled to decrease to a minimum value of 0. The SSL model is trained for a
maximum of 1500 epochs.
• In the self-supervised learning phase, we utilized 20% of the data from each of the four
datasets in the PACS dataset. The unsupervised holdout part employed contrastive learning
augmentations to enhance generalization capabilities. Specifically, we implemented dual
augmentation, including operations such as random resized crops, flips, color jitter, and
grayscale conversion, standardized to an input shape of 3 × 224 × 224.
• The supervised learning rate was set at 5 × 10−5 using MSE loss, and the Adam optimizer
with no weight decay. Training involved both domain and class labels over 3000 epochs,
with checkpoints every 300 epochs to capture the model’s best performance. This approach
was supplemented by fine-tuning the model post-unsupervised training phase. Domain
labels were categorized into four types corresponding to the PACS dataset, and class labels
were divided into five categories. In domain classification, all four domains are used for training, with 70% of the data used for training and the remaining 30% used for testing.
Four domains are utilized for class classification tasks. We train supervised models on three
domains and test on the fourth.
• The domain accuracy is computed as the average of the highest domain accuracies across
five seeds, with each of the four test domains set sequentially as the test domain. The
standard deviation for domain accuracy is calculated from the results across these five seeds.
• Class label accuracy is determined by averaging the accuracies of the four test environments for each domain. The average of the highest performance across domains is taken as the mean accuracy. The standard deviation for each domain is computed across the five seeds, and these values are then averaged to obtain the final class standard deviation.
Both the label classification tasks and the domain classification tasks use the Mean Squared Error
(MSE) loss.
E Additional Experiments
E.1 Complexity Analysis of GCA Algorithms
Time complexity analysis: The computational complexity of GCA comprises the forward pass and backward propagation phases, and it varies across variants. For GCA-INCE, the computational complexity of the forward pass is determined by the speed of Sinkhorn when solving the EOT problem, O(n²/ε³), where ε is the regularization parameter. For GCA-UOT, the forward complexity is that of the Sinkhorn algorithm for unbalanced OT, which is characterized by

\[
O\Big(\tau(\alpha + \beta)^2/\varepsilon\, \log(n)\big[\log(\|C\|_{\infty}) + \log(\log(n)) + \log(1/\varepsilon)\big]\Big),
\]

where C is the cost matrix, α and β denote the total masses of the measures, and τ is a regularization parameter related to the KL divergences in the UOT framework [44]. Notably, the gradient back-propagation speed is not seriously affected by the scaling operations in the EOT, as explained in Section B.7. Moreover, the relaxation penalties in UOT provide an even faster speed compared with INCE and GCA-INCE (see Figure A2).
Figure A2: Time complexity analysis. (A) Time complexity of different methods. Here, we provide the time complexity for different contrastive methods (INCE, RINCE) and GCA-based methods (GCA-INCE, GCA-RINCE, and GCA-UOT) on CIFAR-10. (B) Time complexity for INCE (GCA-INCE-1) and GCA-INCE with different numbers of iterations; GCA-INCE-100 denotes GCA-INCE with 100 iterations. We ran the methods on CIFAR-10 as a self-supervised learning task for 50 epochs and compared their run times. (C) Performance of INCE (iteration = 1) and GCA-INCE (iterations > 1) on CIFAR-10 with different numbers of iterations. The shaded blue region is the standard deviation across 5 seeds.
The complexity of the forward pass is affected by the choice of proximal operator, whereas the complexity of the gradient backward pass is influenced by the form of d_M [37]. Notably, utilizing the Sinkhorn algorithm in GCA-UOT, GCA-RINCE, and GCA-INCE only requires updating the coupling matrix P (B × B), where B is the batch size, without impacting the complexity of the backward pass. OT is known to have O(B²) complexity and in many cases converges very quickly, in fewer than 10 iterations. In practice, we stop the multiple iterations early using a simple convergence criterion.
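The stopping rule mentioned above can be implemented as in the sketch below, which halts the Sinkhorn updates once the row-marginal violation of the current plan drops below a tolerance; the tolerance and iteration cap are illustrative choices.

```python
import torch

def sinkhorn_with_early_stop(K, mu, nu, max_iter=10, tol=1e-3):
    """Sinkhorn scaling on a B x B kernel K with a simple convergence criterion."""
    u, v = torch.ones_like(mu), torch.ones_like(nu)
    P = u[:, None] * K * v[None, :]
    for _ in range(max_iter):
        u = mu / (K @ v)
        v = nu / (K.T @ u)
        P = u[:, None] * K * v[None, :]
        if torch.norm(P.sum(dim=1) - mu, p=1) < tol:   # marginal violation check
            break
    return P
```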
Upon analyzing the run times of the different methods (see Figure A2), we observe that the GCA-based variants of the different base approaches (INCE or RINCE) achieve very similar run times to their equivalent losses, while different losses (RINCE vs. INCE) exhibit more significant variability. Specifically, we find that RINCE and GCA-RINCE have lower time complexity than INCE and GCA-INCE, so the running speed can be even faster if we utilize a different d_M in Equation (8).
We study the uniformity and alignment of the representations learned by the GCA-INCE vs. INCE variants of GCA in Algorithm 1. We train the model with the corresponding settings (C0: standard, C1: erase, C2: crop, C3: brightness) provided in Appendix D.1 and Appendix D.2. We find that, in general, the GCA variants improve the representation quality evaluated by alignment and uniformity on both the CIFAR-10 and CIFAR-10C datasets.
Here we compared the optimal transport (OT) plans of different methods after training for 500
epochs under standard augmentation C0 settings in Appendix D.1. Specifically, we analyzed the
− log(P) matrices of INCE, GCA-INCE, GCA-RINCE, and GCA-UOT, as shown in Figure A4. In
these matrices, darker blue regions represent higher similarity, while lighter blue areas indicate less
similarity. The matrices are rearranged based on class labels, so an effective model should display
empty diagonals and block structures aligned along the main diagonal and sub-diagonals—reflecting
high intra-class similarity and low inter-class similarity.
Figure A4(A) shows that INCE results in a matrix with only row normalization. In contrast, Fig-
ures A4(B) and (C) demonstrate that GCA-INCE and GCA-RINCE achieve both row and column
normalization, leading to more uniform distributions. Figure A4(D) reveals that GCA-UOT pro-
duces a matrix highlighting greater differences between positive and negative pairs, underscoring its
effectiveness in distinguishing them.
Figure A5: Visualization of the transport plan P for different amounts of entropic regularization. (Top) The transport plans for ε from 0.01 to 1 for INCE and (Bottom) GCA-UOT after 5 iterations. To compute each plan, we took a mini-batch of 1024 samples on CIFAR-10 and loaded the same ResNet-18 weights for each subfigure.
Figure A7: Hyperparameter sensitivity for q and λ in GCA-RINCE. Both experiments are tested on the CIFAR-10 dataset with a ResNet-18 encoder and involve strong augmentation with large erase. (Left) Given q = 0.98, we change λ from 0 to 1. (Right) Given λ = 0.01, we change q from 0 to 1. The red threshold line is the INCE performance with the large erase augmentation. Each point represents the CIFAR-10 classification accuracy of the ResNet-18 model pre-trained for 400 epochs and evaluated after 300 epochs.