Describe the bug
To work with InfiniteYou (https://huggingface.co/ByteDance/InfiniteYou) on GPUs with 24 GB of VRAM, I'm distributing the model across several GPUs.
After trying different device_map configurations and moving the input tensors to different devices, I always get the same error in FluxTransformer2DModel. The problem comes from controlnet_block_samples and controlnet_single_block_samples: once FluxTransformer2DModel.forward is called, inside the function they are always on "cuda:0", regardless of the device they were on at the call site.
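My working theory (an assumption on my part, I haven't fully traced it through the accelerate code) is that the AlignDevicesHook that accelerate attaches to the top-level module sends every forward argument, including these lists of tensors, to the module's first execution device, while hidden_states gets re-dispatched from GPU to GPU as it flows through the sharded blocks. The hook moves arguments with accelerate.utils.send_to_device, which recurses into containers:

```python
import torch
from accelerate.utils import send_to_device

# send_to_device recurses into lists/tuples/dicts of tensors; this mirrors
# what AlignDevicesHook does to a hooked module's args and kwargs.
samples = [torch.rand(2, 2, device="cuda:1") for _ in range(3)]
moved = send_to_device(samples, "cuda:0")
print({t.device for t in moved})  # {device(type='cuda', index=0)}
```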
I've been able to run the inference by modifying some lines in diffusers/models/transformers/transformer_flux.py (the file shown in the traceback below). Basically, I move the tensors to hidden_states.device at the point where the error happens:

hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)].to(hidden_states.device)

But this is a workaround, not a proper fix.
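For reference, this is roughly what the patch looks like in the two places where the controlnet residuals are added inside FluxTransformer2DModel.forward (paraphrasing the diffusers 0.32.2 source; the only change is the trailing .to(hidden_states.device)):

```python
# Double-stream block loop (around line 540 of transformer_flux.py):
if controlnet_blocks_repeat:
    hidden_states = hidden_states + controlnet_block_samples[
        index_block % len(controlnet_block_samples)
    ].to(hidden_states.device)
else:
    hidden_states = hidden_states + controlnet_block_samples[
        index_block // interval_control
    ].to(hidden_states.device)

# Same idea in the single-stream block loop:
hidden_states[:, encoder_hidden_states.shape[1]:, ...] = (
    hidden_states[:, encoder_hidden_states.shape[1]:, ...]
    + controlnet_single_block_samples[index_block // interval_control].to(hidden_states.device)
)
```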
Reproduction
import torch
from diffusers import FluxTransformer2DModel

# Shard the transformer across four GPUs with accelerate's "auto" device map.
transformer_path = "[path to]/FLUX.1-dev/transformer"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_pretrained(
    transformer_path,
    subfolder=None,
    torch_dtype=dtype,
    device_map="auto",
    max_memory={0: "16GB", 1: "16GB", 2: "16GB", 3: "16GB"},
)

# Dummy inputs with the shapes InfiniteYou produces; everything is created
# on transformer.device, yet forward still fails on the controlnet lists.
device = transformer.device
controlnet_block_samples = [torch.rand((1, 3888, 3072), device=device, dtype=dtype) for _ in range(4)]
controlnet_single_block_samples = [torch.rand((1, 3888, 3072), device=device, dtype=dtype) for _ in range(10)]
latents = torch.rand((1, 3888, 64), device=device, dtype=dtype)
timestep = torch.rand(1, device=device, dtype=dtype)
guidance = torch.rand(1, device=device, dtype=dtype)
pooled_projections = torch.rand((1, 768), device=device, dtype=dtype)
encoder_hidden_states = torch.rand((1, 512, 4096), device=device, dtype=dtype)
txt_ids = torch.rand((512, 3), device=device, dtype=dtype)
img_ids = torch.rand((3888, 3), device=device, dtype=dtype)

result = transformer(
    hidden_states=latents,
    pooled_projections=pooled_projections,
    timestep=timestep / 1000,
    controlnet_block_samples=controlnet_block_samples,
    controlnet_single_block_samples=controlnet_single_block_samples,
    return_dict=False,
    controlnet_blocks_repeat=True,
    guidance=guidance,
    encoder_hidden_states=encoder_hidden_states,
    txt_ids=txt_ids,
    img_ids=img_ids,
)[0]
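A quick diagnostic (my addition for illustration, not part of the original script): models loaded with device_map="auto" expose hf_device_map, which shows where each submodule landed. In the traceback below, the failing addition runs on a block dispatched to cuda:1 while the controlnet tensors stayed on cuda:0.

```python
# Show which GPU each submodule was dispatched to by device_map="auto".
print(transformer.hf_device_map)
```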
Logs
Traceback (most recent call last):
File "/secondary/projects/InfiniteYou/test.py", line 91, in <module>
result = transformer(
File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
output = module._old_forward(*args, **kwargs)
File "/secondary/.virtualenvs/infiniteyou/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 540, in forward
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
System Info
- 🤗 Diffusers version: 0.32.2
- Platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.6
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.0
- Accelerate version: 1.0.1
- PEFT version: 0.14.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA L4, 23034 MiB
  NVIDIA L4, 23034 MiB
  NVIDIA L4, 23034 MiB
  NVIDIA L4, 23034 MiB
  NVIDIA L4, 23034 MiB
Who can help?
No response