Description
Describe the bug
If I call pipe.enable_attention_slicing(), I get NaNs returned when the output type is 'latent', and a value error for image output.
The error is:
/Volumes/SSD2TB/AI/Diffusers/lib/python3.11/site-packages/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
images = (images * 255).round().astype("uint8")
Printing the latents gives:
$ python tas.py
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 6/6 [00:00<00:00, 8.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:10<00:00, 5.26s/it]
tensor([[[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
...
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]]], device='mps:0',
dtype=torch.float16)
Commenting out the call to enable_attention_slicing gives non-NaN latents and a proper image.
I've tested this on the current release version and the current HEAD version.
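The RuntimeWarning above looks like a downstream symptom: the uint8 cast is simply the first place the NaNs become visible. As a rough sketch (not part of diffusers; `to_uint8_checked` is a hypothetical helper name), the same cast can be guarded so that it fails early on non-finite values:

```python
import numpy as np


def to_uint8_checked(images: np.ndarray) -> np.ndarray:
    """Hypothetical helper: perform the cast shown in the warning,
    but raise if the input already contains NaN or Inf."""
    if not np.isfinite(images).all():
        raise ValueError("images contain NaN or Inf; refusing to cast to uint8")
    # Same expression as the line quoted in the RuntimeWarning.
    return (images * 255).round().astype("uint8")
```

With finite input this behaves like the original cast; with NaN input it raises instead of producing an undefined uint8 image.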
Reproduction
from diffusers import DiffusionPipeline
import torch

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16",
                                         torch_dtype=torch.float16)
pipe.to(device="mps", torch_dtype=torch.float16)
pipe.enable_attention_slicing()

prompt = "analog film photo Butterflies in a jungle, cold color palette, vivid colors, detailed, 8k, 35mm photo, Kodachrome, Lomography, highly detailed"
negative_prompt = "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"

images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=2,
    guidance_scale=7,
    output_type='latent'
).images
print(images)
Comment out or delete pipe.enable_attention_slicing() and the script works as expected.
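To narrow down whether the NaNs appear at the first denoising step or only later, a per-step check might help. The sketch below assumes the pipeline accepts a `callback_on_step_end` argument whose callback receives the current latents under the "latents" key; `record_nonfinite_latents` and `nan_steps` are hypothetical names introduced here for illustration.

```python
import torch

# Steps at which the latents contained NaN or Inf (filled in by the callback).
nan_steps = []


def record_nonfinite_latents(pipe, step, timestep, callback_kwargs):
    """Hypothetical callback: record each step whose latents are not all finite."""
    latents = callback_kwargs["latents"]
    if not torch.isfinite(latents).all():
        nan_steps.append(step)
    # The callback is expected to hand the kwargs back to the pipeline.
    return callback_kwargs
```

It would be passed to the call in the script above as `callback_on_step_end=record_nonfinite_latents`, and `nan_steps` printed afterwards. If step 0 is already listed, the problem is probably in the first UNet forward pass rather than something that accumulates over steps.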
Logs
(Diffusers) $ python tas.py
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:10<00:00, 5.34s/it]
tensor([[[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]]], device='mps:0',
dtype=torch.float16)
(Diffusers) $
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: macOS-15.3.2-arm64-arm-64bit
- Running on Google Colab?: No
- Python version: 3.11.10
- PyTorch version (GPU?): 2.6.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.1
- Transformers version: 4.50.3
- Accelerate version: 0.34.2
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: Apple M3
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
I'm assuming it's more MPS-related than SDXL-related.