
Added the ability to set SDXL Micro-Conditioning embeddings as 0 #4208

Closed
Description

@budui

Is your feature request related to a problem? Please describe.

During SDXL training, it may be necessary to pass zero embeddings as the Micro-Conditioning embeddings:

https://github.com/Stability-AI/generative-models/blob/e25e4c0df1d01fb9720f62c73b4feab2e4003e3f/sgm/modules/encoders/modules.py#L151-L161

# these lines randomly set the embedding to zero when `ucg_rate` > 0
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
    emb = (
        expand_dims_like(
            torch.bernoulli(
                (1.0 - embedder.ucg_rate)
                * torch.ones(emb.shape[0], device=emb.device)
            ),
            emb,
        )
        * emb
    )

https://github.com/Stability-AI/generative-models/blob/e25e4c0df1d01fb9720f62c73b4feab2e4003e3f/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml#L65

# SDXL sets the `ucg_rate` of the `original_size_as_tuple` embedder to 0.1,
# so during training we need to pass a zero embedding as the added embedding for the UNet's time embedding
            ucg_rate: 0.1
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
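For context, `ConcatTimestepEmbedderND` turns each scalar of the condition tuple into a sinusoidal (Fourier) timestep embedding of size `outdim` and concatenates them, so `original_size_as_tuple = (h, w)` with `outdim: 256` yields a 512-dim vector (hence "multiplied by two"). A minimal sketch of that behavior with a hand-rolled sinusoidal embedding (the function names here are illustrative, not the sgm implementation):

import math
import torch

def sinusoidal_embedding(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Standard sinusoidal timestep embedding: half sin, half cos.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

def embed_condition_tuple(cond: torch.Tensor, outdim: int = 256) -> torch.Tensor:
    # cond: (batch, n) holding e.g. (height, width); each scalar gets its own
    # `outdim`-dim embedding, and the results are concatenated along the last dim.
    b, n = cond.shape
    emb = sinusoidal_embedding(cond.reshape(-1), outdim)  # (b * n, outdim)
    return emb.reshape(b, n * outdim)                     # (b, n * outdim)

sizes = torch.tensor([[1024, 1024]])
print(embed_condition_tuple(sizes).shape)  # torch.Size([1, 512])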

The current SDXL `UNet2DConditionModel` accepts `encoder_hidden_states`, `time_ids`, and `add_text_embeds` as conditions:

text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
    raise ValueError(
        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
    )
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
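For reference, a caller supplies these conditions through `added_cond_kwargs`. A minimal sketch of a forward pass with dummy tensors (the shapes are the usual SDXL-base values; this is an assumption about a typical setup, not code from this issue):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

sample = torch.randn(1, 4, 128, 128)              # noisy latents
encoder_hidden_states = torch.randn(1, 77, 2048)  # concatenated CLIP hidden states
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),          # pooled text embedding
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float32),
}
noise_pred = unet(
    sample,
    timestep=10,
    encoder_hidden_states=encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
).sample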

To correctly finetune the SDXL model, we need to randomly set the condition embeddings to 0 with a suitable probability.
While it is easy to set `encoder_hidden_states` and `add_text_embeds` to zero embeddings, it is impossible to zero `time_embeds`, which is computed internally at line 849.

The original SDXL uses different embedders to convert the different micro-conditions into Fourier features, and during training these Fourier features are independently and randomly set to 0. Therefore, `UNet2DConditionModel` needs to be able to zero the `time_embeds` parts independently. A sketch of what that would enable follows.
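Here each micro-condition's Fourier features are computed separately and zeroed with their own probability before concatenation (the function and dictionary names are illustrative, not the diffusers API):

import torch

def drop_micro_conditions(cond_embeds, ucg_rates):
    # Zero each micro-condition embedding independently per sample, then concatenate.
    parts = []
    for name, emb in cond_embeds.items():
        rate = ucg_rates.get(name, 0.0)
        if rate > 0.0:
            # keep_mask is 1 with probability (1 - rate), 0 otherwise
            keep_mask = torch.bernoulli(
                (1.0 - rate) * torch.ones(emb.shape[0], device=emb.device)
            )
            emb = keep_mask[:, None] * emb
        parts.append(emb)
    return torch.cat(parts, dim=-1)

# e.g. the three SDXL micro-conditions, each embedded to 512 dims (2 scalars x 256)
embeds = {
    "original_size_as_tuple": torch.randn(4, 512),
    "crop_coords_top_left": torch.randn(4, 512),
    "target_size_as_tuple": torch.randn(4, 512),
}
rates = {"original_size_as_tuple": 0.1, "crop_coords_top_left": 0.1, "target_size_as_tuple": 0.1}
time_embeds = drop_micro_conditions(embeds, rates)  # (4, 1536)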

Describe the solution you'd like

Add the ability to set SDXL Micro-Conditioning embeddings to 0.

Describe alternatives you've considered

Perhaps it is possible to allow diffusers users to pass in `time_embeds` directly, and if `time_embeds` exists, `time_ids` is no longer used?

if "time_embeds" in added_cond_kwargs:
    time_embeds = added_cond_kwargs.get("time_embeds") 
else:
    time_ids = added_cond_kwargs.get("time_ids") 
     time_embeds = self.add_time_proj(time_ids.flatten()) 
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 
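If this were accepted, a fine-tuning loop could do the independent zeroing itself (as in the dropout sketch above) and pass the result straight in. A hypothetical call, reusing `unet` and `time_embeds` from the earlier sketches; this only works once the proposed `time_embeds` branch exists:

added_cond_kwargs = {
    "text_embeds": torch.randn(4, 1280),  # pooled text embedding (dummy)
    "time_embeds": time_embeds,           # (4, 1536), parts independently zeroed
}
# `time_ids` would no longer be needed; the UNet skips `add_time_proj` entirely.
noise_pred = unet(
    torch.randn(4, 4, 128, 128),                     # noisy latents (dummy)
    timestep=10,
    encoder_hidden_states=torch.randn(4, 77, 2048),  # CLIP hidden states (dummy)
    added_cond_kwargs=added_cond_kwargs,
).sample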
