Error distributing FluxTransformer2DModel to multiple GPUs using controlnet #11247

Open
@maflx

Describe the bug

To run InfiniteYou (https://huggingface.co/ByteDance/InfiniteYou) on GPUs with 24 GB of VRAM, I'm distributing the model across several GPUs.
After trying different device_map configurations and moving the input tensors to different devices, I always get the same error in FluxTransformer2DModel. The problem comes from controlnet_block_samples and controlnet_single_block_samples: once FluxTransformer2DModel.forward is called, inside the function they are always on "cuda:0", regardless of the device they were on at the call site.
I've been able to run the inference by modifying some lines in:

Basically, I move the tensors to hidden_states.device at the point where the error occurs:
hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)].to(hidden_states.device)

But this is only a workaround, not a proper fix.
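
The workaround described above can be sketched in isolation as a small helper. This is a minimal illustration, not the diffusers implementation: the function name add_controlnet_residual and the toy tensor shapes are made up for the example, and it only covers the "repeat" indexing used in the modified line. It falls back to the CPU when fewer than two GPUs are available, in which case the .to() call is a no-op.

```python
import torch


def add_controlnet_residual(hidden_states, samples, index_block):
    """Add the ControlNet residual for one block (hypothetical helper).

    The residual is moved to the device of hidden_states first, so the
    addition works even when the block runs on a different GPU than the
    one the residuals were created on.
    """
    residual = samples[index_block % len(samples)]
    # .to() returns the same tensor when it is already on the target device.
    return hidden_states + residual.to(hidden_states.device)


# Use two GPUs when present; otherwise run everything on the CPU.
two_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= 2
hidden_device = torch.device("cuda:1" if two_gpus else "cpu")
sample_device = torch.device("cuda:0" if two_gpus else "cpu")

# Toy shapes, far smaller than the real (1, 3888, 3072) residuals.
hidden_states = torch.zeros((1, 4, 8), device=hidden_device)
samples = [torch.full((1, 4, 8), float(i), device=sample_device) for i in range(3)]

out = add_controlnet_residual(hidden_states, samples, index_block=4)
print(out.device, out.shape)
```

The per-block copy costs a device-to-device transfer on every step, which is one reason this reads as a stopgap rather than a fix.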

Reproduction

import torch
from diffusers import FluxTransformer2DModel

transformer_path = "[path to]/FLUX.1-dev/transformer"
dtype = torch.bfloat16

# Shard the transformer across four GPUs, capping each one at 16 GB.
transformer = FluxTransformer2DModel.from_pretrained(
    transformer_path,
    subfolder=None,
    torch_dtype=dtype,
    device_map="auto",
    max_memory={0: "16GB", 1: "16GB", 2: "16GB", 3: "16GB"},
)
device = transformer.device

# Dummy ControlNet residuals, all created on the same device.
controlnet_block_samples = [torch.rand((1, 3888, 3072), device=device, dtype=dtype) for _ in range(4)]
controlnet_single_block_samples = [torch.rand((1, 3888, 3072), device=device, dtype=dtype) for _ in range(10)]

# Dummy model inputs.
latents = torch.rand((1, 3888, 64), device=device, dtype=dtype)
timestep = torch.rand(1, device=device, dtype=dtype)
guidance = torch.rand(1, device=device, dtype=dtype)
pooled_projections = torch.rand((1, 768), device=device, dtype=dtype)
encoder_hidden_states = torch.rand((1, 512, 4096), device=device, dtype=dtype)
txt_ids = torch.rand((512, 3), device=device, dtype=dtype)
img_ids = torch.rand((3888, 3), device=device, dtype=dtype)

# The forward pass fails with the device-mismatch error shown in the logs.
result = transformer(
    hidden_states=latents,
    pooled_projections=pooled_projections,
    timestep=timestep / 1000,
    controlnet_block_samples=controlnet_block_samples,
    controlnet_single_block_samples=controlnet_single_block_samples,
    return_dict=False,
    controlnet_blocks_repeat=True,
    guidance=guidance,
    encoder_hidden_states=encoder_hidden_states,
    txt_ids=txt_ids,
    img_ids=img_ids,
)[0]

Logs

Traceback (most recent call last):
  File "/secondary/projects/InfiniteYou/test.py", line 91, in <module>
    result = transformer(
  File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 540, in forward
    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
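
The error says a block on cuda:1 received a residual that lives on cuda:0. To see which blocks ended up on which GPU, it can help to group the device map by device. The sketch below assumes the loaded model exposes the mapping that accelerate stores as hf_device_map (module name to device); whether that attribute is set depends on how the model was loaded, so treat it as an assumption. The helper group_modules_by_device is hypothetical, and example_map is a made-up placement, not the one from this run.

```python
from collections import defaultdict


def group_modules_by_device(device_map):
    """Invert a {module_name: device} mapping into {device: [module_names]}."""
    grouped = defaultdict(list)
    for name, dev in device_map.items():
        # Devices may be ints (GPU index) or strings such as "cpu".
        grouped[str(dev)].append(name)
    return dict(grouped)


# Made-up placement for illustration only.
example_map = {
    "x_embedder": 0,
    "transformer_blocks.0": 0,
    "transformer_blocks.1": 1,
    "single_transformer_blocks.0": 1,
    "proj_out": "cpu",
}

for dev, names in group_modules_by_device(example_map).items():
    print(dev, names)
```

With a real model, one would pass something like getattr(transformer, "hf_device_map", {}) instead of example_map.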

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.6
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.0
  • Accelerate version: 1.0.1
  • PEFT version: 0.14.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA L4, 23034 MiB
    NVIDIA L4, 23034 MiB
    NVIDIA L4, 23034 MiB
    NVIDIA L4, 23034 MiB
    NVIDIA L4, 23034 MiB

Who can help?

No response
