Skip to content

Allow UNet2DModel to use arbitrary class embeddings #2080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/diffusers/models/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
"""

@register_to_config
Expand All @@ -94,6 +99,8 @@ def __init__(
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
):
super().__init__()

Expand All @@ -113,6 +120,16 @@ def __init__(

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None

self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
Expand Down Expand Up @@ -190,12 +207,15 @@ def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

Expand Down Expand Up @@ -225,6 +245,16 @@ def forward(
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when doing class conditioning")

if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)

class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb

# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
"""

_supports_gradient_checkpointing = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
"""

_supports_gradient_checkpointing = True
Expand Down