
ShiftDDPMs: Exploring Conditional Diffusion Models by Shifting Diffusion Trajectories


Zijian Zhang¹, Zhou Zhao¹*, Jun Yu², Qi Tian³
¹Department of Computer Science and Technology, Zhejiang University
²School of Computer Science and Technology, Hangzhou Dianzi University
³Huawei Cloud & AI
[email protected], [email protected], [email protected], [email protected]
*Corresponding author.

Abstract

Diffusion models have recently exhibited remarkable abilities to synthesize striking image samples since the introduction of denoising diffusion probabilistic models (DDPMs). Their key idea is to disrupt images into noise through a fixed forward process and learn its reverse process to generate samples from noise in a denoising way. For conditional DDPMs, most existing practices relate conditions only to the reverse process and fit it to the reversal of the unconditional forward process. We find that this limits condition modeling and generation to a small time window. In this paper, we propose a novel and flexible conditional diffusion model by introducing conditions into the forward process. We utilize extra latent space to allocate an exclusive diffusion trajectory for each condition based on some shifting rules, which disperses condition modeling to all timesteps and improves the learning capacity of the model. We formulate our method, which we call ShiftDDPMs, and provide a unified point of view on existing related methods. Extensive qualitative and quantitative experiments on image synthesis demonstrate the feasibility and effectiveness of ShiftDDPMs.

Figure 1: Exploration of the mechanism of conditional DDPMs. The panels show latents sampled from an image on the data manifold, unconditional samples decoded from them, and a mixed sampling procedure (digits 0-9 as input conditions) that applies conditional sampling only during the early-stage, critical-stage or late-stage. We grid 1000 timesteps with a step size of 50 and perform a grid search over (t1, t2) pairs to find the shortest critical-stage that can ensure high accuracy of conditional generation. For MNIST, it is (400, 600).

Introduction and Motivation
Deep generative models such as Generative Adversarial Networks (GANs) (Goodfellow et al. 2014), Variational Autoencoders (VAEs) (Kingma and Welling 2013), autoregressive models (Van Oord, Kalchbrenner, and Kavukcuoglu 2016) and normalizing flows (Rezende and Mohamed 2015) have shown remarkable abilities to model complex data distributions and synthesize high-quality samples in various fields. Diffusion models (Sohl-Dickstein et al. 2015) were recently brought back into focus by denoising diffusion probabilistic models (DDPMs) (Ho, Jain, and Abbeel 2020), which exhibit competitive image synthesis results and have been applied to a wide range of data modalities.

Generally, DDPMs gradually disrupt images by adding noise through a fixed forward process and learn its reverse process to generate samples from noise in a denoising way. There are two main methods to achieve conditional DDPMs. One is to learn an estimator that can compute the similarity between conditions and noisy data and use it to guide pre-trained unconditional DDPMs to sample towards specified conditions (Dhariwal and Nichol 2021). Another is to train a conditional DDPM from scratch by incorporating conditions into the function approximator of the reverse process. Both methods try to fit their conditional reverse process to the reversal of a fixed unconditional forward process. This brings up a question: Can we design a more effective forward process utilizing given conditions to form a new type of conditional DDPMs and benefit from it?

We investigate this question by exploring the mechanism of how conditional DDPMs achieve conditional sampling based on an unconditional forward process, similar to that in PDAE (Zhang, Zhao, and Lin 2022). We conduct some experiments, shown in Figure 1. Concretely, we train an unconditional DDPM and a conditional one on MNIST (LeCun et al. 1998), respectively. The conditional one incorporates class labels (one-hot vectors) into the function approximator of the parameterized reverse process. The top two rows respectively show the latents x_t sampled from x_0 for various t and the samples generated by the unconditional DDPM
starting from the corresponding latents. Intuitively, the latents for smaller t preserve more high-level information (such as class) of the corresponding data, and this information is totally lost when t is large enough. It means that the diffusion trajectories originating from different data get entangled, and the latents become indistinguishable when t is large. We then divide the diffusion trajectories into three stages: early-stage (0 ~ t1), critical-stage (t1 ~ t2) and late-stage (t2 ~ T). Then we design a mixed sampling procedure that employs unconditional sampling but switches to conditional sampling during the specified stage. Note that the unconditional and conditional reverse processes can be connected because they are trained to approximate the same forward process, so they recognize the same pattern of latents. The bottom three rows show the samples generated by three different mixed sampling procedures, where each row employs conditional sampling only during the indicated stage. As we can see, only the samples conditioned on input labels during the critical-stage match the input class labels.

These phenomena show that, for an unconditional forward process, the key to achieving conditional sampling is to shift and separate the generative trajectories of different conditions during the critical-stage. Besides, to some extent, the training and sampling during the early and late stages are independent of conditions, which leaves condition modeling and generation to the limited critical-stage. If we can utilize extra latent space and allocate an exclusive diffusion trajectory for each condition, keeping the trajectories of different conditions disentangled all the time, condition modeling will be dispersed to all timesteps and the learning capacity of the model may improve.

Recently, Grad-TTS (Popov et al. 2021) and PriorGrad (Lee et al. 2021) introduced conditional forward processes with data-dependent priors for audio diffusion models and enable more efficient training than those with an unconditional forward process. However, their differences and connections have not been discussed, and there has not been a comprehensive exploration of this kind of method, especially for image diffusion models. In this work, we systematically study how to design controllable diffusion trajectories according to conditions and their effect on conditional diffusion models. Our main contributions are:

• We systematically introduce conditional forward processes for diffusion models and provide a unified point of view on existing related approaches.
• By shifting diffusion trajectories, ShiftDDPMs improve the utilization rate of the latent space and the learning capacity of the model.
• We demonstrate the feasibility and effectiveness of ShiftDDPMs on various image synthesis tasks with extensive experiments.

Related Works

Diffusion models (Sohl-Dickstein et al. 2015; Ho, Jain, and Abbeel 2020) are an emerging family of generative models and have exhibited remarkable abilities to synthesize high-quality samples. Numerous studies (Song et al. 2020; Song, Meng, and Ermon 2020; Dhariwal and Nichol 2021; Liu et al. 2022) and applications (Chen et al. 2020; Saharia et al. 2022; Huang et al. 2022a,b; Ye et al. 2022, 2023) have further improved and expanded diffusion models. Among existing practices of conditional diffusion models, only Grad-TTS (Popov et al. 2021) and PriorGrad (Lee et al. 2021) involve conditions in the forward process but, nonetheless, they are totally different methods. We will demonstrate their differences under the point of view of ShiftDDPMs.

ShiftDDPMs

Background

DDPMs (Ho, Jain, and Abbeel 2020) employ a forward process that sequentially destroys the data distribution q(x_0) into N(0, I) with Markov diffusion kernels defined by a fixed variance schedule {β_t}_{t=1}^T:

q(x_t | x_{t−1}) = N(√α_t x_{t−1}, β_t I),   (1)

which admits sampling x_t from x_0 for any timestep t in closed form:

q(x_t | x_0) = N(√ᾱ_t x_0, (1 − ᾱ_t) I).   (2)

Then a parameterized Markov chain is trained to fit the reversal of the forward process, denoising an arbitrary Gaussian noise to a data sample:

p_θ(x_{t−1} | x_t) = N(μ_θ(x_t, t), Σ_θ(x_t, t)).   (3)

Training is performed by maximizing the model log likelihood with some parameterization and simplification:

L(θ) = E_{t, x_0, ε}[ ‖ε − ε_θ(√ᾱ_t x_0 + √(1 − ᾱ_t) ε, t)‖² ].   (4)

See Appendix A for full details of DDPMs.
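For readers who prefer code to notation, here is a minimal PyTorch-style sketch of the unconditional DDPM pieces described by Eqs. (2) and (4); the network `eps_theta` and the linear β schedule are placeholders rather than the settings used in this paper.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def q_sample(x0, t, eps):
    """Eq. (2): sample x_t ~ q(x_t | x_0) in closed form."""
    a_bar = alpha_bars[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps

def ddpm_loss(eps_theta, x0):
    """Eq. (4): simplified noise-prediction objective."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    x_t = q_sample(x0, t, eps)
    return ((eps - eps_theta(x_t, t)) ** 2).mean()
```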
Conditional Forward Process

We aim to shift the diffusion trajectories in some way related to conditions. An intuitive way is to directly rewrite the Gaussian distribution in Eq.(2) as:

q(x_t | x_0, c) = N(√ᾱ_t x_0 + k_t · E(c), (1 − ᾱ_t) Σ(c)).   (5)

Specifically, k_t · E(c) is the cumulative mean shift of the diffusion trajectories at the t-th step, where k_t is a shift coefficient schedule that decides the shift mode and E(·) is a function, which we call the shift predictor, that maps conditions into the latent space. Σ(c) is a diagonal covariance matrix, where Σ(·) is some function similar to E(·). Comparing the diffusion trajectories to water pipes, k_t · E(c) is employed to change their directions and Σ(c) is employed to change their size in latent space. Note that both E(·) and Σ(·) can be fixed or trainable. In our experiments on image synthesis, a trainable Σ(·) leads to a complex training and sampling procedure, unstable training and poor results, so we fix Σ(c) = I as in Eq.(2). For generality, we still use Σ(c) in our derivations. For simplicity, we use the following substitution:

q(x_t | x_0, c) = N(√ᾱ_t x_0 + s_t, (1 − ᾱ_t) Σ),   (6)
where s_t = k_t · E(c) and Σ = Σ(c). We will discuss how to choose k_t and E(·) in later sections.

With Eq.(6), we can derive the corresponding forward diffusion kernels (See proof in Appendix A):

q(x_t | x_{t−1}, c) = N(√α_t x_{t−1} + s_t − √α_t s_{t−1}, β_t Σ),   (7)

where s_0 = 0 (i.e. k_0 = 0). Intuitively, our forward diffusion kernels introduce a small perturbation conditioned on c to the original ones shown in Eq.(1).

With Eq.(6) and Eq.(7), the posterior distributions of forward steps for t > 1 can be derived from Bayes' rule (See proof in Appendix A):

q(x_{t−1} | x_t, x_0, c) = N( (√ᾱ_{t−1} β_t/(1 − ᾱ_t)) x_0 + (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) x_t − (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) s_t + s_{t−1}, ((1 − ᾱ_{t−1})/(1 − ᾱ_t)) β_t Σ ).   (8)

Parameterized Reverse Process

The reverse process starts at p(x_T) = N(s_T, Σ), which is an approximation of q(x_T | x_0, c), and employs parameterized kernels p_θ(x_{t−1} | x_t, c) to fit q(x_{t−1} | x_t, x_0, c).

According to Eq.(6), x_0 can be represented as:

x_0 = (1/√ᾱ_t) (x_t − s_t − √(1 − ᾱ_t) ε),   (9)

where ε ∼ N(0, Σ). Then we substitute it into Eq.(8) and derive the posterior mean of the forward steps:

E[q(x_{t−1} | x_t, x_0, c)] = (1/√α_t) (x_t − (β_t/√(1 − ᾱ_t)) ε) − (1/√α_t) s_t + s_{t−1},   (10)

where everything is available except ε. We can employ a model ε_θ(x_t, t) to predict ε. Note that there is no need to feed c into ε_θ because we have encoded it into condition-dependent trajectories (i.e., in x_t), so the model does not need its guidance.

Further improvements come from another parameterization, because ε in Eq.(9) is given by:

ε = (x_t − √ᾱ_t x_0)/√(1 − ᾱ_t) − s_t/√(1 − ᾱ_t),   (11)

where the second term is available. Therefore we can employ a model g_θ(x_t, t) to predict the first term for training. We find this parameterization achieves better performance than predicting ε directly. Then we can get the predicted posterior distributions parameterized by θ:

p_θ(x_{t−1} | x_t, c) = N( (1/√α_t) (x_t − (β_t/√(1 − ᾱ_t)) g_θ(x_t, t)) − (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) s_t + s_{t−1}, ((1 − ᾱ_{t−1})/(1 − ᾱ_t)) β_t Σ ).   (12)

Algorithm 1: Training
1: repeat
2:   x_0, c ∼ q(x_0, c)
3:   t ∼ Uniform({1, . . . , T})
4:   s_t = k_t · E(c), Σ = Σ(c)
5:   ε ∼ N(0, Σ)
6:   x_t = √ᾱ_t x_0 + s_t + √(1 − ᾱ_t) ε
7:   Optimize ‖(x_t − √ᾱ_t x_0)/√(1 − ᾱ_t) − g_θ(x_t, t)‖²_{Σ⁻¹}
8: until converged

Algorithm 2: Sampling
1: s_T = k_T · E(c), Σ = Σ(c)
2: x_T ∼ N(s_T, Σ)
3: for t = T, . . . , 1 do
4:   z ∼ N(0, Σ) if t > 1, else z = 0
5:   x_{t−1} = (1/√α_t)(x_t − (β_t/√(1 − ᾱ_t)) g_θ(x_t, t)) − (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) s_t + s_{t−1} + √(((1 − ᾱ_{t−1})/(1 − ᾱ_t)) β_t) z
6: end for
7: return x_0

Training Objective

With our conditional forward process and corresponding reverse process, our training objective can be represented as (See proof in Appendix A):

L = C + Σ_{t=1}^{T} γ_t E_{x_0, ε}[ ‖(x_t − √ᾱ_t x_0)/√(1 − ᾱ_t) − g_θ(x_t, t)‖²_{Σ⁻¹} ],   (13)

where C is some constant, x_0 ∼ q(x_0), ε ∼ N(0, Σ), x_t = √ᾱ_t x_0 + s_t + √(1 − ᾱ_t) ε, ‖x‖²_{Σ⁻¹} = xᵀ Σ⁻¹ x, γ_1 = 1/(2α_1) and γ_t = β_t/(2α_t(1 − ᾱ_{t−1})) for t ≥ 2. During training, we follow DDPMs (Ho, Jain, and Abbeel 2020) and adopt the simplified training objective by uniformly sampling t between 1 and T and ignoring the loss weight γ_t. Algorithm 1 and Algorithm 2 describe our training and sampling procedures. Note that E(·) and Σ(·) will be optimized along with θ if they are trainable.

Intuitive Interpretation

Assume that Σ = I and x′_t = √ᾱ_t x_0 + √(1 − ᾱ_t) ε. DDPMs employ ε_θ(x′_t, c, t) to predict ε = (x′_t − √ᾱ_t x_0)/√(1 − ᾱ_t), while ShiftDDPMs employ g_θ(x′_t + s_t, t) to predict (x′_t + s_t − √ᾱ_t x_0)/√(1 − ᾱ_t). They are trained to predict the same pattern of objective but with different input (i.e. (input − √ᾱ_t x_0)/√(1 − ᾱ_t)). Compared with DDPMs, ShiftDDPMs transfer the input condition c onto the diffusion trajectories by shifting x′_t to x′_t + s_t, which allows conditional training and sampling without feeding c into the network. For DDPMs, only the training and sampling during the limited critical-stage plays a key role for condition modeling and generation, while ShiftDDPMs disperse it to all timesteps and improve the utilization rate of the latent space, which may lead to better performance.

Furthermore, if E(·) is trainable, it will be optimized to find an optimal shift in latent space that specializes the diffusion trajectories of different conditions and makes them as disentangled as possible. The term d_t = −(1/√α_t) s_t + s_{t−1} in Eq.(10) amends the sampling trajectory at every step to ensure that it finally falls on the data manifold.
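To make Algorithms 1 and 2 concrete, here is a hedged PyTorch-style sketch of one training step and the ancestral sampling loop under the setting used in our image experiments (Σ(c) = I); `g_theta`, `shift_predictor` (standing in for E_ψ) and the schedule tensors are placeholders rather than the exact architectures, and the loss omits the weights γ_t as in the simplified objective.

```python
import torch

def make_schedules(betas):
    """betas: shape (T,). Returns alphas and cumulative products alpha_bars."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return alphas, alpha_bars

def training_step(g_theta, shift_predictor, x0, cond, betas, k_sched):
    """One step of Algorithm 1 with Sigma(c) = I (simplified objective)."""
    alphas, alpha_bars = make_schedules(betas)
    b = x0.shape[0]
    t = torch.randint(0, betas.shape[0], (b,), device=x0.device)   # t ~ Uniform{1..T}, 0-indexed
    e_c = shift_predictor(cond)                                     # E_psi(c), same shape as x0
    s_t = k_sched[t].view(b, 1, 1, 1) * e_c                         # cumulative shift s_t = k_t * E(c)
    a_bar = alpha_bars[t].view(b, 1, 1, 1)
    eps = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + s_t + (1.0 - a_bar).sqrt() * eps      # Eq. (6)
    target = (x_t - a_bar.sqrt() * x0) / (1.0 - a_bar).sqrt()       # = eps + s_t / sqrt(1 - a_bar), Eq. (11)
    return ((g_theta(x_t, t) - target) ** 2).mean()

@torch.no_grad()
def sample(g_theta, shift_predictor, cond, betas, k_sched, shape):
    """Algorithm 2: ancestral sampling along the shifted trajectory (Sigma = I)."""
    alphas, alpha_bars = make_schedules(betas)
    T = betas.shape[0]
    e_c = shift_predictor(cond)

    def s(t):                                                       # s_t, with s_0 = 0
        return k_sched[t] * e_c if t >= 0 else torch.zeros_like(e_c)

    x = s(T - 1) + torch.randn(shape, device=e_c.device)            # x_T ~ N(s_T, I)
    for t in reversed(range(T)):
        a_bar = alpha_bars[t]
        a_bar_prev = alpha_bars[t - 1] if t > 0 else alpha_bars.new_tensor(1.0)
        t_batch = torch.full((shape[0],), t, device=x.device, dtype=torch.long)
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        # posterior mean of Eq. (12): denoise, then correct the shift
        mean = (x - betas[t] / (1.0 - a_bar).sqrt() * g_theta(x, t_batch)) / alphas[t].sqrt()
        mean = mean - alphas[t].sqrt() * (1.0 - a_bar_prev) / (1.0 - a_bar) * s(t) + s(t - 1)
        var = (1.0 - a_bar_prev) / (1.0 - a_bar) * betas[t]
        x = mean + var.sqrt() * z
    return x
```

When E_ψ is trainable, its parameters are simply optimized together with θ through the same loss.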
Next, we will show that the forward processes of Grad-TTS (Popov et al. 2021) and PriorGrad (Lee et al. 2021) each correspond to a special choice of k_t.

Prior-Shift

Grad-TTS (Popov et al. 2021) proposes a score-based text-to-speech generative model with the prior mean predicted by a text encoder and aligner. Specifically, it defines a forward process satisfying the following SDE:

dX_t = (1/2)(μ − X_t) β_t dt + √β_t dW_t,   (14)

where μ corresponds to E(c) in our notation (E(·) represents the parameterized text encoder and aligner, and c represents the input text). We show that k_t = 1 − √ᾱ_t matches a discretization of Eq.(14) (See proof in Appendix A). For the forward process, k_t increases from 0 to 1 and leads x_t to shift towards μ as t increases. For the reverse process, we have:

d_t = (1 − 1/√α_t) μ,   (15)

where 1 − 1/√α_t < 0, because the reverse process starts from N(μ, I) and needs to eliminate the cumulative shift μ of the forward process. From the view of diffusion trajectories, Grad-TTS changes the ending point of the trajectories, so we name this shift mode Prior-Shift.

Note that Grad-TTS still takes μ as an additional input to the score estimator, but we have stated that this is unnecessary. Doing so will give results that are at least no worse, but it also introduces additional parameters and computation.

Data-Normalization

PriorGrad (Lee et al. 2021) employs a forward process as follows:

x_t = √ᾱ_t (x_0 − μ) + √(1 − ᾱ_t) ε,   (16)

where ε ∼ N(0, Σ). Obviously, k_t = −√ᾱ_t satisfies Eq.(16). The forward process first normalizes x_0 by subtracting its corresponding prior mean μ and then trains a diffusion model on the normalized x_0 with prior N(0, Σ). For the reverse process, we have:

d_1 = μ,  d_{t>1} = 0.   (17)

Intuitively, the reverse process starts from N(0, Σ) and makes no amendment at any step except the last one, where it adds the prior mean μ back to the output (denormalization). From the view of diffusion trajectories, PriorGrad resets the starting point of the trajectories on the data manifold, so we name this shift mode Data-Normalization.

Unlike Prior-Shift, which disperses the cumulative shift over all points on the diffusion trajectories, Data-Normalization does not disentangle the diffusion trajectories, so it must feed c into the network to guide sampling. However, by carefully designing Σ, it can achieve the same precision with a simpler network and have a faster convergence rate under some constraints (Lee et al. 2021). Data-Normalization is more suitable for variance-sensitive data such as audio.

Figure 2: 32 × 32 conditional MNIST samples for different shift modes (columns: Prior-Shift, Data-Normalization, Quadratic-Shift) with different shift predictors. The last row visualizes the learned E_ψ(·).

Quadratic-Shift

Besides Prior-Shift, we propose a shift mode that disentangles the diffusion trajectories of different conditions by making the concave trajectories shown in Figure 1 convex. In this case, we do not change their starting or ending points, and E(c) becomes a middle point: the trajectories first progress towards it and then move away from it. Therefore k_t should be similar to a quadratic function opening downwards, with k_1 ≈ 0 and k_T ≈ 0. Empirically, we choose k_t = √ᾱ_t (1 − √ᾱ_t). We name this shift mode Quadratic-Shift.

Experiments

In this section, we conduct several conditional image synthesis experiments with ShiftDDPMs. Note that we always set Σ(c) = I. Full implementation details of all experiments can be found in Appendix B.

Effectiveness of Conditional Sampling

We first verify the effectiveness of ShiftDDPMs with three shift modes on the toy dataset MNIST (LeCun et al. 1998). We employ two fixed shift predictors (E1(·) and E2(·)) and a trainable one (E_ψ(·) with parameters ψ), each mapping a one-hot vector c to a 32 × 32 matrix. Specifically, E1(·) takes 10 evenly spaced numbers over [−1, 1] and expands each number into a 32 × 32 matrix. E2(·) takes the mean of all training data belonging to the specified class. E_ψ(·) employs stacked transposed convolution layers to compute the matrix.
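The two fixed shift predictors can be written in a few lines; this is a sketch under our reading of the description above, with the exact normalization choices as assumptions.

```python
import torch

def make_E1(num_classes: int = 10, size: int = 32) -> torch.Tensor:
    """E1: one evenly spaced scalar per class, expanded to a size x size map."""
    levels = torch.linspace(-1.0, 1.0, num_classes)          # 10 values in [-1, 1]
    return levels.view(num_classes, 1, 1).expand(num_classes, size, size).clone()

def make_E2(images: torch.Tensor, labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """E2: per-class mean of the training images (images: [N, size, size])."""
    return torch.stack([images[labels == c].mean(dim=0) for c in range(num_classes)])

def shift_of(table: torch.Tensor, onehot: torch.Tensor) -> torch.Tensor:
    """Look up E(c) for a batch of one-hot labels: [B, 10] x [10, H, W] -> [B, H, W]."""
    return torch.einsum("bc,chw->bhw", onehot.float(), table)
```

The trainable E_ψ replaces such a lookup table with a small transposed-convolution decoder over the one-hot vector.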
Figure 2 presents the conditional MNIST samples for different shift modes with different shift predictors. As we can see, all models work for conditional generation, and the visualizations of the learned E_ψ(c) for Prior-Shift and Quadratic-Shift contain the general shape of the corresponding class, which means that they learn specialized trajectories for different conditions. Data-Normalization must feed c into the model, so it may ignore the shift.

Despite the success of the fixed shift predictors on MNIST, we get poor sample results when modeling a complex data distribution such as CIFAR-10. Therefore we always employ the trainable shift predictor E_ψ with parameters ψ in the following experiments.

Sample Quality

We further evaluate ShiftDDPMs on CIFAR-10 (Krizhevsky and Hinton 2009). For a fair comparison, we retrain a DDPM as a baseline (our DDPM) and then use the same experimental settings and resources to train the other models. We train a traditional conditional DDPM (cond. DDPM) by incorporating class labels into the function approximator of the reverse process. Moreover, we train a time-dependent classifier (Sohl-Dickstein et al. 2015; Song et al. 2020; Dhariwal and Nichol 2021) on noisy images and use its gradients to guide (our DDPM) to sample towards a specified class (cls. DDPM). For ShiftDDPMs, we train three models, including Prior-Shift, Data-Normalization, and Quadratic-Shift, all with trainable shift predictors. Furthermore, we employ another two models (cond. Prior-Shift and cond. Quadratic-Shift) by incorporating class labels into the reverse process of Prior-Shift and Quadratic-Shift, with the same method as (cond. DDPM). Figure 3 presents some conditional CIFAR-10 samples generated by Quadratic-Shift. Table 1 shows the Inception Score, FID and negative log likelihood for these models.

Figure 3: 32 × 32 conditional CIFAR-10 samples for Quadratic-Shift.

Model  IS↑  FID↓  NLL↓
Unconditional:
DDPM  9.46  3.17  ≤ 3.75
our DDPM  9.52  3.13  ≤ 3.72
Conditional:
cond. DDPM  9.59  3.12  ≤ 3.74
cls. DDPM  9.17  5.85  −
Prior-Shift  9.54  3.06  ≤ 3.71
cond. Prior-Shift  9.65  3.06  ≤ 3.70
Data-Normalization  9.14  5.51  −
Quadratic-Shift  9.67  3.05  ≤ 3.69
cond. Quadratic-Shift  9.74  3.02  ≤ 3.70

Table 1: Quantitative results of conditional sample quality on CIFAR-10. NLL measured in bits/dim.

As we can see, our retrained unconditional DDPM is slightly better than the original one thanks to the improved settings. With the help of conditional knowledge, the conditional DDPM outperforms the unconditional DDPM. The classifier-guided DDPM has poor results because it is sensitive to the classifier. Data-Normalization has an unstable training process and poor results, which means that it is not suitable for image synthesis. Both Prior-Shift and Quadratic-Shift outperform the conditional DDPM, which shows that a conditional forward process can improve the learning capacity of ShiftDDPMs. Although incorporating class labels can slightly improve their performance, it also introduces additional computational and parameter complexity.

Adaptation to DDIM for Fast Sampling

DDIMs (Song, Meng, and Ermon 2020) generalize the forward process of DDPMs to a non-Markovian process with an equivalent training objective, which enables an accelerated reverse process with pre-trained DDPMs. Fortunately, ShiftDDPMs can be adapted to ShiftDDIMs. Specifically, we can generate x_{t−1} from x_t via:

x_{t−1} = (1/√α_t) (x_t − √(1 − ᾱ_t) g_θ(x_t, t)) + s_{t−1} + √(1 − ᾱ_{t−1} − σ_t²) · (g_θ(x_t, t) − s_t/√(1 − ᾱ_t)) + σ_t ε_t,   (18)

where ε_t ∼ N(0, Σ) (See proof in Appendix A). Then we employ τ = {τ_1, · · · , τ_S}, an increasing sub-sequence of [1, · · · , T] of length S, for accelerated sampling. The corresponding variance becomes σ_{τ_i}(η) = η √((1 − ᾱ_{τ_{i−1}})/(1 − ᾱ_{τ_i})) √(1 − ᾱ_{τ_i}/ᾱ_{τ_{i−1}}), where η is a hyperparameter that we can directly control. Figure 4 and Table 2 present the conditional CIFAR-10 samples generated by the Quadratic-Shift mode and its FID with different sampling steps and η. ShiftDDIMs can still keep a competitive FID even when sampling for only 100 steps.
Figure 4: 32 × 32 conditional CIFAR-10 samples for Quadratic-Shift with S = 10 and S = 100 sampling steps and η = 0.0, 0.2, 0.5, 1.0. We use fixed input and noise during sampling.

S  10  20  50  100
η = 0.0:  14.25  7.95  5.22  3.93
η = 0.2:  14.16  7.88  5.30  4.06
η = 0.5:  16.96  9.12  6.18  4.43
η = 1.0:  25.33  11.67  9.81  5.70
σ̂:  264.32  118.61  36.24  10.95

Table 2: FID of conditional sample quality on CIFAR-10 for Quadratic-Shift.

Interpolation of Diffusion Trajectories

DDPMs (Ho, Jain, and Abbeel 2020) show that one can interpolate the latents of two source data points, decode the interpolated latent by the reverse process, and get a sample similar to the interpolation of the two sources. Inspired by this phenomenon, we can try to interpolate the diffusion trajectories of different conditions, which is equivalent to interpolating between different s_t, such as ŝ_t = λ · k_t · E_ψ(c1) + (1 − λ) · k_t · E_ψ(c2) for two different conditions c1 and c2. In theory, s_t decides the direction of the diffusion trajectories and the interpolated ŝ_t takes the median direction, which can lead the reverse process to generate samples with the mixed features of c1 and c2.

We verify this idea by conducting attribute-to-image experiments (Yan et al. 2016) on the LFW dataset (Huang et al. 2008). Specifically, the task requires us to generate facial images according to input attributes. Each image (x_0) in LFW corresponds to a 73-dim real-valued vector (c), where the value of each dimension represents the degree of some attribute such as male, beard and so on. We employ Quadratic-Shift with a trainable shift predictor to train on the training set and evaluate it on the test set. Figure 5 presents some samples, which shows that ShiftDDPMs can learn a meaningful shift (like a heatmap of the face), and the generated images are consistent with the ground truth in the labeled face attributes. Figure 6 presents the interpolations generated by Quadratic-Shift. The interpolations smoothly transition from one side to the other, which verifies our assumptions about the disentangled diffusion trajectories.

Figure 5: 64 × 64 conditional LFW samples for Quadratic-Shift. From left to right are ground truth image (from the test set), generated image and learned E_ψ(c).

Figure 6: 64 × 64 conditional LFW interpolations for Quadratic-Shift, interpolating between two ground-truth attribute vectors (gt1 to gt2) with coefficients 1.0, 0.8, 0.6, 0.4, 0.2, 0.0. We use fixed input and noise during sampling.

Image Inpainting

Beyond class-conditional image synthesis, we conduct some image-to-image synthesis experiments. Compared with enumerable class labels, image space is almost infinite and it is a challenge to assign a unique trajectory to each instance. To prove the capacity of ShiftDDPMs, we conduct image inpainting experiments using the Irregular Mask Dataset (Liu et al. 2018) with three image datasets: CelebA-HQ (Liu et al. 2015), LSUN-church (Yu et al. 2015) and Places2 (Zhou et al. 2017). We employ the Quadratic-Shift mode and a UNet based architecture as the shift predictor, which takes the masked image as input and predicts the shift. Figure 7 presents some inpainting samples. As we can see, ShiftDDPMs predict a template of the complete image based on the masked one, which guides the trajectory to generate consistent and diverse completions. To further evaluate ShiftDDPMs on image inpainting, we follow prior works (Yu et al. 2019; Liu et al. 2018; Zhang et al. 2020) by reporting FID on the Places2 dataset. We choose several GAN-based models: Contextual Attention (Yu et al. 2018), EdgeConnect (Nazeri et al. 2019) and StructureFlow (Ren et al. 2019) as baselines. Besides, we take the score-based inpainting method proposed in (Song et al. 2020) as another baseline. Table 3 presents the quantitative results, and ShiftDDPMs achieve competitive results comparable to prior GAN-based methods. In addition, ShiftDDPMs also outperform the score-based inpainting method, showing that the extra utilization of the latent space to some extent improves the learning capacity of diffusion models.
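Following the interpolation rule ŝ_t = λ · k_t · E_ψ(c1) + (1 − λ) · k_t · E_ψ(c2) described in the interpolation section above, a hedged sketch: only the cumulative shift is swapped for its interpolation during sampling, and everything else in Algorithm 2 stays unchanged. Here `sample_with_shift` stands for a hypothetical sampling loop like the one sketched earlier that accepts an explicit shift function.

```python
import torch

def interpolated_shift(shift_predictor, k_sched, c1, c2, lam: float):
    """Return a function t -> s_hat_t = lam*k_t*E(c1) + (1-lam)*k_t*E(c2)."""
    e1, e2 = shift_predictor(c1), shift_predictor(c2)

    def s_hat(t: int) -> torch.Tensor:
        if t < 0:                       # s_0 = 0 by convention
            return torch.zeros_like(e1)
        return k_sched[t] * (lam * e1 + (1.0 - lam) * e2)

    return s_hat

# Usage sketch: sweep lam from 1.0 to 0.0 to interpolate between two conditions.
# for lam in [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]:
#     s_fn = interpolated_shift(E_psi, k_sched, attrs_1, attrs_2, lam)
#     img = sample_with_shift(g_theta, s_fn, betas, shape)   # hypothetical sampler
```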
Figure 7: 256 × 256 inpainting samples from the CelebA-HQ and LSUN-church test sets for Quadratic-Shift.

Mask Percentage  0-20%  20-40%  40-60%
Contextual Attention  4.8586  18.4190  37.9432
EdgeConnect  3.0097  7.2635  19.0030
StructureFlow  2.9420  7.0354  22.3803
DDPM (score)  2.0665  6.6129  17.3601
Quadratic-Shift  1.8314  6.2915  14.9667

Table 3: FID of inpainting results on the Places2 dataset.

Text-to-Image

We conduct text-to-image (text2img) experiments on the CUB dataset (Wah et al. 2011). We employ the Quadratic-Shift mode and a network as the shift predictor to generate the shift from pre-trained sentence embeddings. Figure 8 presents some generated samples. We can see that the shift predictor can predict a meaningful template according to the text and guide the trajectory to generate text-consistent images. We choose several GAN-based models, GAN-INT-CLS (Reed et al. 2016), StackGAN (Zhang et al. 2017), StackGAN++ (Zhang et al. 2018) and AttnGAN (Xu et al. 2018), as baselines. Besides, we take the traditional conditional diffusion method as another baseline, which only incorporates sentence embeddings into the function approximator of the parameterized reverse process. Table 4 presents some quantitative results, and ShiftDDPMs achieve competitive results comparable to prior GAN-based methods and the traditional conditional diffusion model.

Figure 8: 256 × 256 text2img samples from the CUB test set for Quadratic-Shift. From left to right are text, learned shift, generated sample and ground truth, respectively. Example input texts include "This bird has a brown crown, a short brown bill, and a rounded yellow belly.", "This little bird has a speckled appearance of gray and black with a white belly and breast." and "This bird has a very long wing span and crocked beak."

Methods  IS  FID
GAN-INT-CLS  2.88  68.79
StackGAN  3.70  51.89
StackGAN++  3.82  15.30
AttnGAN  4.36  -
cond. DDPM  4.18  14.79
Quadratic-Shift  4.42  14.26

Table 4: IS and FID of text2img results on the CUB dataset.

More Choice of kt

The choice of k_t is flexible. For Prior-Shift, any schedule of k_t that monotonically increases from 0 to 1 can be applied. We have tried the following three types of k_t: t/T, (t/T)² and sin(tπ/2T − π/2) + 1, and they all work well. Furthermore, k_t can also be piecewise:

k_t = 0 if t < 0.4T, and k_t = (t − 0.4T)/(0.6T) otherwise.   (19)

One can also design other reasonable k_t. We leave empirical investigations of k_t as future work.
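For reference, the shift coefficient schedules named in this paper can each be written in one line; the sketch below collects them. The tensor shapes and the 1-indexed convention are our assumptions, and the last function implements the piecewise schedule of Eq. (19).

```python
import torch

def k_prior_shift(alpha_bars):         # Prior-Shift: k_t = 1 - sqrt(a_bar_t), increases 0 -> 1
    return 1.0 - alpha_bars.sqrt()

def k_data_normalization(alpha_bars):  # Data-Normalization: k_t = -sqrt(a_bar_t)
    return -alpha_bars.sqrt()

def k_quadratic(alpha_bars):           # Quadratic-Shift: k_t = sqrt(a_bar_t) * (1 - sqrt(a_bar_t))
    s = alpha_bars.sqrt()
    return s * (1.0 - s)

def k_linear(T):                       # alternative Prior-Shift schedule: t / T
    t = torch.arange(1, T + 1, dtype=torch.float32)
    return t / T

def k_piecewise(T):                    # Eq. (19): zero before 0.4T, then linear up to 1
    t = torch.arange(1, T + 1, dtype=torch.float32)
    return ((t - 0.4 * T) / (0.6 * T)).clamp(min=0.0)
```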

Conclusion

In this work, we propose a novel and flexible conditional diffusion model called ShiftDDPMs by introducing a conditional forward process with controllable condition-dependent diffusion trajectories. We analyze the differences of existing related methods under the point of view of ShiftDDPMs and first apply them to image synthesis. With ShiftDDPMs, we can achieve better performance and learn some interesting features in latent space. Extensive qualitative and quantitative experiments on image synthesis demonstrate the feasibility and effectiveness of ShiftDDPMs.

Acknowledgments

This work was supported in part by the National Natural Science Foundation of China (Grant No.62020106007, No.U21B2040, No.62222211 and No.202100023), Zhejiang Natural Science Foundation (LR19F020006), Zhejiang Electric Power Co., Ltd. Science and Technology Project No.5211YF220006 and Yiwise.
References

Bishop, C. M. 2006. Pattern recognition. Machine learning, 128(9).
Chen, N.; Zhang, Y.; Zen, H.; Weiss, R. J.; Norouzi, M.; and Chan, W. 2020. WaveGrad: Estimating gradients for waveform generation. arXiv preprint arXiv:2009.00713.
Dhariwal, P.; and Nichol, A. 2021. Diffusion models beat gans on image synthesis. arXiv preprint arXiv:2105.05233.
Goodfellow, I.; Pouget-Abadie, J.; Mirza, M.; Xu, B.; Warde-Farley, D.; Ozair, S.; Courville, A.; and Bengio, Y. 2014. Generative adversarial nets. Advances in neural information processing systems, 27.
Ho, J.; Jain, A.; and Abbeel, P. 2020. Denoising diffusion probabilistic models. arXiv preprint arXiv:2006.11239.
Huang, G. B.; Mattar, M.; Berg, T.; and Learned-Miller, E. 2008. Labeled faces in the wild: A database for studying face recognition in unconstrained environments. In Workshop on faces in 'Real-Life' Images: detection, alignment, and recognition.
Huang, R.; Lam, M. W.; Wang, J.; Su, D.; Yu, D.; Ren, Y.; and Zhao, Z. 2022a. FastDiff: A Fast Conditional Diffusion Model for High-Quality Speech Synthesis. arXiv preprint arXiv:2204.09934.
Huang, R.; Zhao, Z.; Liu, H.; Liu, J.; Cui, C.; and Ren, Y. 2022b. Prodiff: Progressive fast diffusion model for high-quality text-to-speech. arXiv preprint arXiv:2207.06389.
Kingma, D. P.; and Welling, M. 2013. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114.
Krizhevsky, A.; and Hinton, G. 2009. Learning multiple layers of features from tiny images. Technical Report 0, University of Toronto, Toronto, Ontario.
LeCun, Y.; Bottou, L.; Bengio, Y.; and Haffner, P. 1998. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11): 2278-2324.
Lee, S.-g.; Kim, H.; Shin, C.; Tan, X.; Liu, C.; Meng, Q.; Qin, T.; Chen, W.; Yoon, S.; and Liu, T.-Y. 2021. PriorGrad: Improving Conditional Denoising Diffusion Models with Data-Driven Adaptive Prior. arXiv preprint arXiv:2106.06406.
Liu, G.; Reda, F. A.; Shih, K. J.; Wang, T.-C.; Tao, A.; and Catanzaro, B. 2018. Image inpainting for irregular holes using partial convolutions. In Proceedings of the European Conference on Computer Vision (ECCV), 85-100.
Liu, L.; Ren, Y.; Lin, Z.; and Zhao, Z. 2022. Pseudo Numerical Methods for Diffusion Models on Manifolds. In International Conference on Learning Representations.
Liu, Z.; Luo, P.; Wang, X.; and Tang, X. 2015. Deep learning face attributes in the wild. In Proceedings of the IEEE international conference on computer vision, 3730-3738.
Nazeri, K.; Ng, E.; Joseph, T.; Qureshi, F.; and Ebrahimi, M. 2019. Edgeconnect: Structure guided image inpainting using edge prediction. In Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops, 0-0.
Popov, V.; Vovk, I.; Gogoryan, V.; Sadekova, T.; and Kudinov, M. 2021. Grad-tts: A diffusion probabilistic model for text-to-speech. arXiv preprint arXiv:2105.06337.
Reed, S.; Akata, Z.; Yan, X.; Logeswaran, L.; Schiele, B.; and Lee, H. 2016. Generative adversarial text to image synthesis. In International Conference on Machine Learning, 1060-1069. PMLR.
Ren, Y.; Yu, X.; Zhang, R.; Li, T. H.; Liu, S.; and Li, G. 2019. Structureflow: Image inpainting via structure-aware appearance flow. In Proceedings of the IEEE/CVF International Conference on Computer Vision, 181-190.
Rezende, D.; and Mohamed, S. 2015. Variational inference with normalizing flows. In International conference on machine learning, 1530-1538. PMLR.
Saharia, C.; Chan, W.; Saxena, S.; Li, L.; Whang, J.; Denton, E.; Ghasemipour, S. K. S.; Ayan, B. K.; Mahdavi, S. S.; Lopes, R. G.; et al. 2022. Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. arXiv preprint arXiv:2205.11487.
Sohl-Dickstein, J.; Weiss, E.; Maheswaranathan, N.; and Ganguli, S. 2015. Deep unsupervised learning using nonequilibrium thermodynamics. In International Conference on Machine Learning, 2256-2265. PMLR.
Song, J.; Meng, C.; and Ermon, S. 2020. Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502.
Song, Y.; Sohl-Dickstein, J.; Kingma, D. P.; Kumar, A.; Ermon, S.; and Poole, B. 2020. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.
Van Oord, A.; Kalchbrenner, N.; and Kavukcuoglu, K. 2016. Pixel recurrent neural networks. In International Conference on Machine Learning, 1747-1756. PMLR.
Wah, C.; Branson, S.; Welinder, P.; Perona, P.; and Belongie, S. 2011. The Caltech-UCSD Birds-200-2011 Dataset. Technical Report CNS-TR-2011-001, California Institute of Technology.
Xu, T.; Zhang, P.; Huang, Q.; Zhang, H.; Gan, Z.; Huang, X.; and He, X. 2018. Attngan: Fine-grained text to image generation with attentional generative adversarial networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, 1316-1324.
Yan, X.; Yang, J.; Sohn, K.; and Lee, H. 2016. Attribute2image: Conditional image generation from visual attributes. In European Conference on Computer Vision, 776-791. Springer.
Ye, Z.; Jiang, Z.; Ren, Y.; Liu, J.; He, J.; and Zhao, Z. 2023. GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis. arXiv preprint arXiv:2301.13430.
Ye, Z.; Zhao, Z.; Ren, Y.; and Wu, F. 2022. SyntaSpeech: Syntax-aware Generative Adversarial Text-to-Speech. arXiv preprint arXiv:2204.11792.
Yu, F.; Seff, A.; Zhang, Y.; Song, S.; Funkhouser, T.; and Xiao, J. 2015. Lsun: Construction of a large-scale image dataset using deep learning with humans in the loop. arXiv preprint arXiv:1506.03365.
Yu, J.; Lin, Z.; Yang, J.; Shen, X.; Lu, X.; and Huang, T. S. 2018. Generative image inpainting with contextual attention. In Proceedings of the IEEE conference on computer vision and pattern recognition, 5505-5514.
Yu, J.; Lin, Z.; Yang, J.; Shen, X.; Lu, X.; and Huang, T. S. 2019. Free-form image inpainting with gated convolution. In Proceedings of the IEEE/CVF International Conference on Computer Vision, 4471-4480.
Zhang, H.; Xu, T.; Li, H.; Zhang, S.; Wang, X.; Huang, X.; and Metaxas, D. N. 2017. Stackgan: Text to photo-realistic image synthesis with stacked generative adversarial networks. In Proceedings of the IEEE international conference on computer vision, 5907-5915.
Zhang, H.; Xu, T.; Li, H.; Zhang, S.; Wang, X.; Huang, X.; and Metaxas, D. N. 2018. Stackgan++: Realistic image synthesis with stacked generative adversarial networks. IEEE transactions on pattern analysis and machine intelligence, 41(8): 1947-1962.
Zhang, Z.; Zhao, Z.; and Lin, Z. 2022. Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models. In Advances in Neural Information Processing Systems.
Zhang, Z.; Zhao, Z.; Zhang, Z.; Huai, B.; and Yuan, J. 2020. Text-guided image inpainting. In Proceedings of the 28th ACM International Conference on Multimedia, 4079-4087.
Zhou, B.; Lapedriza, A.; Khosla, A.; Oliva, A.; and Torralba, A. 2017. Places: A 10 million image database for scene recognition. IEEE transactions on pattern analysis and machine intelligence, 40(6): 1452-1464.
Appendix A

Derivation of our conditional forward diffusion kernels

According to the Markovian property, q(x_t | x_{t−1}, x_0, c) = q(x_t | x_{t−1}, c) for all t > 1. Therefore, we can assume that:

q(x_t | x_{t−1}, c) = N(A x_{t−1} + b, L⁻¹).   (20)

As we know the marginal Gaussian for x_{t−1}:

q(x_{t−1} | x_0, c) = N(√ᾱ_{t−1} x_0 + s_{t−1}, (1 − ᾱ_{t−1}) Σ),   (21)

from (Bishop 2006) (2.115), we can derive that the marginal Gaussian for x_t, i.e., q(x_t | x_0, c), is given by:

E[q(x_t | x_0, c)] = A(√ᾱ_{t−1} x_0 + s_{t−1}) + b
Cov[q(x_t | x_0, c)] = L⁻¹ + (1 − ᾱ_{t−1}) A Σ Aᵀ.   (22)

Then we need to ensure that:

q(x_t | x_0, c) = N(√ᾱ_t x_0 + s_t, (1 − ᾱ_t) Σ),   (23)

from which we can derive that:

A = √α_t I,
b = s_t − √α_t s_{t−1},   (24)
L⁻¹ = (1 − ᾱ_t − α_t(1 − ᾱ_{t−1})) Σ = (1 − α_t) Σ.

Finally, for all t > 1 we have:

q(x_t | x_{t−1}, c) = N(√α_t x_{t−1} + s_t − √α_t s_{t−1}, β_t Σ).   (25)

We further consider the case t = 1 from the following facts:

q(x_2 | x_1, c) = N(√α_2 x_1 + s_2 − √α_2 s_1, β_2 Σ),
q(x_2 | x_0, c) = N(√ᾱ_2 x_0 + s_2, (1 − ᾱ_2) Σ).   (26)

With a similar derivation based on (Bishop 2006) (2.113), we can get:

q(x_1 | x_0, c) = N(√α_1 x_0 + s_1, (1 − α_1) Σ),   (27)

which matches Eq.(23). Therefore we set s_0 = 0, i.e., k_0 = 0, to make Eq.(25) true for t = 1.

One can also verify this conclusion with the recurrence relation in Eq.(25) by the rule of the sum of normally distributed random variables.

Derivation of the posterior distributions of our conditional forward process

For all t > 1, we can derive q(x_{t−1} | x_t, x_0, c) by Bayes' rule:

q(x_{t−1} | x_t, x_0, c) = q(x_t | x_{t−1}, x_0, c) q(x_{t−1} | x_0, c) / q(x_t | x_0, c).   (28)

From (Bishop 2006) (2.116 and 2.117), we have that q(x_{t−1} | x_t, x_0, c) is Gaussian, with

Cov[q(x_{t−1} | x_t, x_0, c)] = ( (1/(1 − ᾱ_{t−1})) Σ⁻¹ + (α_t/(1 − α_t)) Σ⁻¹ )⁻¹
 = (1 / (1/(1 − ᾱ_{t−1}) + α_t/(1 − α_t))) Σ
 = ((1 − ᾱ_{t−1})(1 − α_t) / (1 − α_t + α_t − ᾱ_t)) Σ
 = ((1 − ᾱ_{t−1}) / (1 − ᾱ_t)) β_t Σ,   (29)

and

E[q(x_{t−1} | x_t, x_0, c)] = ((1 − ᾱ_{t−1}) β_t / (1 − ᾱ_t)) Σ ( (√α_t/β_t) Σ⁻¹ (x_t − s_t + √α_t s_{t−1}) + (1/(1 − ᾱ_{t−1})) Σ⁻¹ (√ᾱ_{t−1} x_0 + s_{t−1}) )
 = (√ᾱ_{t−1} β_t/(1 − ᾱ_t)) x_0 + (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) x_t − (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) s_t + s_{t−1}.   (30)

Derivation of the training objective

The training objective can be represented as:

L = E_q[ − log p_θ(x_0 | x_1, c) + D_KL(q(x_T | x_0, c) ‖ p(x_T)) + Σ_{t=2}^{T} D_KL(q(x_{t−1} | x_t, x_0, c) ‖ p_θ(x_{t−1} | x_t, c)) ].   (31)

For the first term, we have:

p_θ(x_0 | x_1, c) = N( (1/√α_1)(x_1 − (β_1/√(1 − ᾱ_1)) g_θ(x_1, 1)), β_1 Σ ).   (32)

Then we can derive the first term by the Gaussian probability density function:

− log p_θ(x_0 | x_1, c) = log((2π)^{d/2} |β_1 Σ|^{1/2}) + (1/2) ‖x_0 − (1/√α_1)(x_1 − (β_1/√(1 − ᾱ_1)) g_θ(x_1, 1))‖²_{(β_1 Σ)⁻¹}
 = (1/2) log((2π β_1)^d |Σ|) + (1/(2α_1)) ‖g_θ(x_1, 1) − (x_1 − √ᾱ_1 x_0)/√(1 − ᾱ_1)‖²_{Σ⁻¹},   (33)

where d is the dimension of x.

For the second term, we have:

q(x_T | x_0, c) = N(√ᾱ_T x_0 + s_T, (1 − ᾱ_T) Σ),   (34)

and

p(x_T) = N(s_T, Σ).   (35)
Then we can derive the second term by the Gaussian Kullback-Leibler divergence:

D_KL(q(x_T | x_0, c) ‖ p(x_T)) = (1/2) ( log(1/(1 − ᾱ_T)^d) + d(1 − ᾱ_T) − d + ‖√ᾱ_T x_0‖²_{Σ⁻¹} ).   (36)

For the third term, we have:

q(x_{t−1} | x_t, x_0, c) = N( (√ᾱ_{t−1} β_t/(1 − ᾱ_t)) x_0 + (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) x_t − (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) s_t + s_{t−1}, ((1 − ᾱ_{t−1})/(1 − ᾱ_t)) β_t Σ ),   (37)

and

p_θ(x_{t−1} | x_t, c) = N( (1/√α_t)(x_t − (β_t/√(1 − ᾱ_t)) g_θ(x_t, t)) − (√α_t (1 − ᾱ_{t−1})/(1 − ᾱ_t)) s_t + s_{t−1}, ((1 − ᾱ_{t−1})/(1 − ᾱ_t)) β_t Σ ).   (38)

Then we can derive the third term by the Gaussian Kullback-Leibler divergence:

D_KL(q(x_{t−1} | x_t, x_0, c) ‖ p_θ(x_{t−1} | x_t, c)) = (1/2) ‖ (β_t/(√α_t √(1 − ᾱ_t))) (g_θ(x_t, t) − (x_t − √ᾱ_t x_0)/√(1 − ᾱ_t)) ‖²_{(((1 − ᾱ_{t−1})/(1 − ᾱ_t)) β_t Σ)⁻¹}
 = (β_t/(2 α_t (1 − ᾱ_{t−1}))) ‖ g_θ(x_t, t) − (x_t − √ᾱ_t x_0)/√(1 − ᾱ_t) ‖²_{Σ⁻¹}.   (39)

Combining the above derivations, we can get the final training objective:

L = C + Σ_{t=1}^{T} γ_t E_{x_0, ε}[ ‖ (x_t − √ᾱ_t x_0)/√(1 − ᾱ_t) − g_θ(x_t, t) ‖²_{Σ⁻¹} ],   (40)

where C is some constant, x_0 ∼ q(x_0), ε ∼ N(0, Σ), x_t = √ᾱ_t x_0 + s_t + √(1 − ᾱ_t) ε, ‖x‖²_{Σ⁻¹} = xᵀ Σ⁻¹ x, γ_1 = 1/(2α_1) and γ_t = β_t/(2α_t(1 − ᾱ_{t−1})) for t ≥ 2.

A discretization of Grad-TTS

Grad-TTS defines a forward process with the following SDE:

dX_t = (1/2)(μ − X_t) β_t dt + √β_t dW_t,   (41)

where μ corresponds to E(c) in our notation. Consider a discretization of it:

X_{t+1} − X_t = (1/2)(μ − X_t) β_t Δt + √(β_t Δt) z_t
X_{t+1} = (1 − (1/2) β_t Δt) X_t + (1/2) β_t Δt μ + √(β_t Δt) z_t
 = (1 − (1/2) β_t Δt) X_t + (1 − (1 − (1/2) β_t Δt)) μ + √(β_t Δt) z_t
 ≈ √(1 − β_t Δt) X_t + (1 − √(1 − β_t Δt)) μ + √(β_t Δt) z_t
 = √α_t X_t + (1 − √α_t) μ + √(1 − α_t) z_t,   (42)

identifying α_t = 1 − β_t Δt, where z_t ∼ N(0, I) because for a Wiener process W_t − W_s ∼ N(0, (t − s) I) when 0 ≤ s ≤ t. With this recurrence relation, we can derive that:

X_t = √ᾱ_t X_0 + (1 − √ᾱ_t) μ + √(1 − ᾱ_t) ε,   (43)

from which we can read off k_t = 1 − √ᾱ_t for Grad-TTS.

Appendix B

Implementation Details

We use the same settings as ADM (Dhariwal and Nichol 2021), including network architecture, timesteps, variance schedule, dropout, learning rate and EMA. We set the batch size to 128 for CIFAR-10, 64 for LFW and 32 for the others. We use 4 feature map resolutions for the 32 × 32 models and 6 for the others.

To compute E_ψ(c), we employ a linear layer and stacked transposed convolution layers to map conditions (one-hot vectors or attribute vectors) to three-channel feature maps for the CIFAR-10 and LFW datasets. For image inpainting on the CelebA-HQ, LSUN-church and Places2 datasets, we employ a U-Net architecture for pixel-to-pixel prediction. For text-to-image synthesis on the CUB bird dataset, we employ a linear layer and stacked transposed convolution layers with an attention mechanism to map the pre-trained word embeddings to three-channel feature maps.

For image inpainting, we use the Irregular Mask Dataset collected by (Liu et al. 2018), which contains 55,116 irregular raw masks for training and 24,866 for testing. During training, for each image in the batch, we first randomly sample a mask from the 55,116 training masks, then perform some random augmentations on the mask, and finally use it to mask the image and get our class center for training. So the training masks are different all the time. The mask is irregular and may be a 100% hole due to augmentations. During testing, we use 12,000 test masks sampled and augmented from the 24,866 raw testing masks. These 12,000 masks are categorized by hole size according to hole-to-image area ratios (0-20%, 20-40%, 40-60%).

The classifier for (cls. DDPM) employs the encoder half of the UNet to classify the noisy images. For the class-conditional function approximator, we use AdaGN, the same as that in ADM (Dhariwal and Nichol 2021).

We train all our models on eight Nvidia RTX 2080Ti GPUs.
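As a quick numerical sanity check of the kernel derivation in Appendix A (this check is ours, not part of the original appendix), the following sketch verifies that composing the conditional kernels of Eq. (7) over all steps reproduces the marginal mean √ᾱ_t x_0 + s_t and variance 1 − ᾱ_t of Eq. (6) in a scalar toy setting with Σ = I.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
k = np.sqrt(alpha_bars) * (1.0 - np.sqrt(alpha_bars))   # Quadratic-Shift; any schedule works
E_c = 0.7                                                # scalar stand-in for E(c)
s = k * E_c
x0 = 1.3

mean, var = x0, 0.0
for t in range(T):
    s_prev = s[t - 1] if t > 0 else 0.0                  # s_0 = 0
    # Eq. (7): x_t = sqrt(alpha_t) x_{t-1} + (s_t - sqrt(alpha_t) s_{t-1}) + sqrt(beta_t) * eps
    mean = np.sqrt(alphas[t]) * mean + s[t] - np.sqrt(alphas[t]) * s_prev
    var = alphas[t] * var + betas[t]

assert np.isclose(mean, np.sqrt(alpha_bars[-1]) * x0 + s[-1])   # Eq. (6) mean
assert np.isclose(var, 1.0 - alpha_bars[-1])                    # Eq. (6) variance
```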
