Description
Is your feature request related to a problem? Please describe.
During the SDXL training process, it may be necessary to pass in zero embeddings as the Micro-Conditioning embeddings:
```python
# These lines randomly set the embedding to zero when `ucg_rate` > 0
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
    emb = (
        expand_dims_like(
            torch.bernoulli(
                (1.0 - embedder.ucg_rate)
                * torch.ones(emb.shape[0], device=emb.device)
            ),
            emb,
        )
        * emb
    )
```
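For reference, the masking above amounts to a per-sample Bernoulli dropout of the whole embedding. A minimal self-contained sketch (with a local stand-in for sgm's `expand_dims_like` helper):

```python
import torch


def zero_embedding_with_prob(emb: torch.Tensor, ucg_rate: float) -> torch.Tensor:
    # Keep each sample's embedding with probability (1 - ucg_rate),
    # otherwise replace it with an all-zero embedding.
    keep = torch.bernoulli(
        (1.0 - ucg_rate) * torch.ones(emb.shape[0], device=emb.device)
    )
    # Broadcast the per-sample keep mask over the remaining embedding dims
    # (standing in for sgm's `expand_dims_like`).
    keep = keep.view(emb.shape[0], *([1] * (emb.ndim - 1)))
    return keep * emb


# e.g. a batch of 4 pooled embeddings, each zeroed with probability 0.1
emb = zero_embedding_with_prob(torch.randn(4, 1280), ucg_rate=0.1)
```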
For example, the `original_size_as_tuple` embedder in the SDXL config:

```yaml
# SDXL sets the `ucg_rate` of the `original_size_as_tuple` embedder to 0.1,
# so during training we need to pass a zero embedding as the added embedding
# for the UNet's time embedding.
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
  outdim: 256  # multiplied by two
```
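If I read the sgm code correctly, `ConcatTimestepEmbedderND` turns each scalar of the tuple into a sinusoidal ("Fourier") embedding of size `outdim` and concatenates them, which is why `original_size_as_tuple` (height, width) ends up as 2 × 256 = 512 dimensions. A rough sketch, using diffusers' `get_timestep_embedding` as a stand-in for the sgm implementation:

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding


def concat_timestep_embed(values: torch.Tensor, outdim: int = 256) -> torch.Tensor:
    # values: (batch, n) micro-condition tuple, e.g. (original_height, original_width)
    b, n = values.shape
    # Embed every scalar independently with a sinusoidal embedding ...
    emb = get_timestep_embedding(values.flatten(), outdim)
    # ... then concatenate the per-scalar embeddings: (batch, n * outdim)
    return emb.reshape(b, n * outdim)


print(concat_timestep_embed(torch.tensor([[1024.0, 1024.0]])).shape)  # torch.Size([1, 512])
```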
The current SDXL `UNet2DConditionModel` accepts `encoder_hidden_states`, `time_ids` and `add_text_embeds` as conditions:

diffusers/src/diffusers/models/unet_2d_condition.py, lines 843 to 854 at commit 2e53936
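For context, this is roughly how those conditions are passed today; `time_embeds` is produced inside `forward()` from `time_ids` via `add_time_proj`, so the caller never touches it. The shapes below are the usual SDXL ones and are my assumption, not quoted from the linked code:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

sample = torch.randn(1, 4, 128, 128)              # noisy latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 2048)  # concatenated text-encoder states
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),          # pooled text embedding
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    "time_ids": torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]]),
}

# time_embeds are derived from time_ids inside forward(); they cannot be zeroed from here.
out = unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
```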
To correctly finetune the SDXL model, we need to randomly set the condition embeddings to 0 with a suitable probability. While it is easy to set `encoder_hidden_states` and `add_text_embeds` to zero embeddings, it is impossible to zero `time_embeds` at line 849.
The original SDXL uses different embedders to convert the different micro-conditions into Fourier features, and during training these Fourier features are independently and randomly set to 0. Therefore, `UNet2DConditionModel` needs to be able to zero the `time_embeds` parts independently.
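To illustrate what zeroing the `time_embeds` parts independently means: SDXL has three 2-tuple micro-conditions (original size, crop coordinates, target size), each embedded to 2 × 256 = 512 dims and concatenated, so per-condition dropout amounts to masking 512-dim chunks. A hypothetical sketch (the chunk layout is my assumption from the SDXL config above):

```python
import torch


def dropout_micro_conditions(
    time_embeds: torch.Tensor, ucg_rates=(0.1, 0.1, 0.1)
) -> torch.Tensor:
    # time_embeds: (batch, 3 * 512), the concatenation of the original-size,
    # crop-coordinates and target-size Fourier features.
    b = time_embeds.shape[0]
    chunks = time_embeds.reshape(b, len(ucg_rates), -1).clone()
    for i, rate in enumerate(ucg_rates):
        # Zero this micro-condition's chunk for a random subset of the batch.
        keep = torch.bernoulli((1.0 - rate) * torch.ones(b, 1, device=time_embeds.device))
        chunks[:, i] = keep * chunks[:, i]
    return chunks.reshape(b, -1)
```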
Describe the solution you'd like
Add the ability to set the SDXL Micro-Conditioning embeddings to 0.
Describe alternatives you've considered
Perhaps it is possible to allow diffusers users to pass in `time_embeds` directly, and if `time_embeds` exists, `time_ids` would no longer be used?
if "time_embeds" in added_cond_kwargs:
time_embeds = added_cond_kwargs.get("time_embeds")
else:
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
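With such a change, a training loop could build `time_embeds` itself via the UNet's existing `add_time_proj`, drop individual micro-conditions, and hand the result to `forward()`. A hypothetical usage sketch, reusing `unet`, `sample`, `timestep`, `encoder_hidden_states` and `dropout_micro_conditions` from the snippets above (the `"time_embeds"` key is the proposed extension and does not exist in diffusers today):

```python
# Compute time_embeds outside the UNet, exactly as forward() would.
time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]])
time_embeds = unet.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((time_ids.shape[0], -1))

# Independently zero the original-size / crop / target-size parts (sketch above).
time_embeds = dropout_micro_conditions(time_embeds, ucg_rates=(0.1, 0.1, 0.1))

added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),
    "time_embeds": time_embeds,  # proposed key: used instead of "time_ids" when present
}
out = unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)
```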