add auto_wrap_policy into XLA FSDP for automatic wrapping #4318
Conversation
This is great! @ronghanghu can you also update the usage of this arg in https://github.com/pytorch/xla/blob/master/docs/fsdp.md? I think this is something many people would want to use!
@JackCaoG I added the usages to this doc. We should probably test it on more cases like GPT-2 before merging.
Hi @ronghanghu @JackCaoG, thanks so much for your great contribution! Can I ask whether auto_wrap_policy is also suitable for general HuggingFace models (T5, OPT, etc.), especially the ones without a wrapping structure like GPT2Block? Thanks
@jianguoz, yes, it should be compatible with general Hugging Face models such as BERT, T5, and OPT.
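(As an illustration of the point above, not code from the original thread: wrapping a Hugging Face T5 model with the type-based policy could look roughly like the sketch below. The `torch_xla.distributed.fsdp.wrap` import path is assumed from this PR's description, and `t5-small` is just a placeholder checkpoint.)

```python
from functools import partial

import torch_xla.core.xla_model as xm
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Block
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy  # assumed import path

# wrap every T5Block (encoder and decoder layers) in its own inner FSDP instance
auto_wrap_policy = partial(transformer_auto_wrap_policy,
                           transformer_layer_cls={T5Block})

model = T5ForConditionalGeneration.from_pretrained("t5-small").to(xm.xla_device())
model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
```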
@ronghanghu Thanks for your quick reply! That is really awesome! Since you are testing this new feature on more cases, do I also need to test it on models such as T5/OPT before it is merged?
@jianguoz Yes, although it isn't finalized yet, you're welcome to try it out on more models or cases! (And since this PR is entirely in Python, it could be added to an existing torch_xla installation by directly copying over the files in torch_xla/distributed/fsdp/)
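(A rough sketch of that file-copy approach, for reference; the checkout location `~/xla` is an assumption, not from the thread.)

```python
# copy the FSDP sources from a local checkout of this PR into the installed torch_xla package
import os
import shutil

import torch_xla

src = os.path.expanduser("~/xla/torch_xla/distributed/fsdp")  # assumed path of the PR checkout
dst = os.path.join(os.path.dirname(torch_xla.__file__), "distributed", "fsdp")
shutil.copytree(src, dst, dirs_exist_ok=True)  # dirs_exist_ok requires Python >= 3.8
print(f"copied {src} -> {dst}")
```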
@ronghanghu That is great! I will copy the files and try it on T5 and OPT. Hope you can finalize it soon and open an easy door for fine-tuning very large HuggingFace models on TPU!
Mostly lgtm, can you rebase to resolve conflicts?
Hey @jianguoz, if you tried this PR and confirmed it worked with other HF models, could you give an update here? FYI, we have a PR huggingface/transformers#20774 to add FSDP support to HF, and we will update that PR to use the auto-wrapping feature from this one.
I just rebased it to the latest master. Let me also test it on more cases before merging.
Hi @JackCaoG, thanks for the efforts :) I am trying it on other HF models, will give an update soon!
Thanks @ronghanghu, I am going to merge this one and add a test to CI for the auto-wrap policy.
Hi @ronghanghu, good afternoon :) Thanks for the new testing cases! Can I know if this also works for distributed training on TPU VM pods (more than 8 cores)?
It should work, take a look at https://github.com/pytorch/xla#how-to-run-on-tpu-vm-pods-distributed-training.
@JackCaoG @ronghanghu I tested the checkpoint saving and consolidation. The files in /tmp/mnist for the single-host (8-core) run look as expected, while for larger pods the consolidation step fails. I guess there may be errors or wrong assert conditions (e.g., assuming 8 shards) when stitching the sharded model checkpoints together into a full model state dict on pods with more than 8 cores. Can you check whether this is caused by an error in checkpoint saving? Thanks :)
Hi @jianguoz, I think it's because on v3-32 or v4-64, the filesystems are separate on each host VM in the TPU pod. The sharded checkpoints are saved in a distributed manner by each host, while the consolidation part requires them to be in the same filesystem. On v3-32 or v4-64, one can either skip this checkpoint consolidation part, or save the sharded checkpoints to a filesystem shared by all hosts (e.g. an NFS/Filestore mount) so that the consolidation step can see every shard.
Hi @ronghanghu, sorry for the late reply. Thanks for your help and super valuable contributions :) It works with the Filestore system! In addition, we get results similar to your reported testing results. I will modify your vit_10b_fsdp_example to add auto-wrapping and run a test on the 10B model.
Hi @ronghanghu, good afternoon! I am trying the vit_10b_fsdp_example now.
Hi @JackCaoG, I have started to test this and will give updates soon. Meanwhile, I believe HuggingFace xla_spawn only supports training on a single TPU node (<=8 cores) and does not support TPU pods like v3-32. Hence, it would be better if they resolve this issue first, so that auto FSDP can be scaled to more TPUs.
@JackCaoG That is awesome! Before this, I did not know that we could use …
Yes, this …
@ronghanghu Thanks so much for your suggestions! I will set them accordingly :)
Hi @ronghanghu, thanks very much for your auto_wrap FSDP contributions! I have a question regarding consolidating models while modifying your code. I see that there is a step to consolidate the sharded model checkpoints for MNIST in test_train_mp_mnist_fsdp_with_ckpt.py, but there is no such code in run_vit_training.py. I have two questions here: is the consolidation step required (for example, when resuming FSDP training), and is it still practical for a model as large as 10B parameters?
Thanks so much for your help again!
Hi @jianguoz, thanks for your test! Here, checkpoint consolidation is only needed if one wants to stitch the sharded checkpoints together into a single checkpoint file for a non-FSDP-wrapped model (i.e. the original model without FSDP wrapping).
This test was to verify that the consolidated checkpoint could work for the original MNIST model, so it does not have FSDP wrapping before loading the model. If it is needed to resume FSDP training, one can simply load the sharded checkpoint files:

```python
# the FSDP-wrapped model and its optimizer
model = fsdp_wrap(MNIST().to(device))
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=flags.momentum)

# load the sharded checkpoint file
rank = xm.get_ordinal()
world_size = xm.xrt_world_size()
ckpt_path = f'{flags.ckpt_prefix}_rank-{rank:08d}-of-{world_size:08d}.pth'
ckpt_sharded = torch.load(ckpt_path)
model.load_state_dict(ckpt_sharded['model'])
optimizer.load_state_dict(ckpt_sharded['optimizer'])
```
I think for a 10B model size (which should be around 40 GB of parameters in float32), it should still be able to fit into the host memory of a TPU VM (which typically has 300+ GB of memory). Do you experience host-side OOM when consolidating the checkpoint from the command line as follows (here the checkpoint files are the MNIST ones saved under /tmp/mnist-fsdp/)?

```bash
# consolidate the checkpoint files matching `ckpt_prefix` + `ckpt_suffix`
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"
```
Hi @ronghanghu, thank you so much for your super quick reply and suggestions! I have tested the above consolidation command on mnist-fsdp (the model is quite small) and it does not have any issues. Since I haven't saved the sharded checkpoint files for a >=10B model yet, I will give an update this week :). One more question regarding inference (i.e., test): do you usually consolidate the sharded checkpoints for very large models, or only keep the original sharded files? Or do we consolidate files based on size, say 20 GB maximum per file (like OPT 30B, BLOOM 175B), to make it easier for users to download them and load them on limited GPU devices (with model sharding)?
Hi @jianguoz, for very large models (e.g. those with 20B+ parameters), I usually just keep the original sharded checkpoint files, since these models are hard to run without FSDP anyway :) For smaller models I sometimes consolidate them into a single checkpoint to use them in other tasks.
Hi @ronghanghu, that really makes sense :) Thanks for sharing your experience, and have a nice night :)
Hi @JackCaoG @ronghanghu, good morning! I am testing HuggingFace models following the above auto_wrap_policy instructions, starting with a T5 model. I still have the same errors even though I set a very short input length (128 tokens, which just works with bfloat16) and label length (128 tokens). The issue is also not solved when I add … My inputs are 2-D. I know that for encoder-decoder models, they may share the embeddings. Does torch_xla FSDP handle such shared embeddings? Really appreciate your time and help!
@JackCaoG A further update: I wrote code to test the HuggingFace PR using the HuggingFace T5-3B model with more TPUs, i.e., v3-128. I use the type_based method (size_based has a 2-D issue) to wrap the T5Block. Each device holds only about 22M parameters (3B split across 128 devices), which is relatively small. However, it still easily hits the OOM issue.
Hi @jianguoz, regarding Hugging Face transformers, I earlier set up a small example in https://github.com/huggingface/transformers/compare/main...ronghanghu:transformers:huggingface_fsdp_example?expand=1. There is an ongoing PR to add it to the Hugging Face transformers repo (here is a draft). Regarding the issue of shared embeddings in encoder-decoder models: this layer can be used by both the encoder and the decoder.
Hi @ronghanghu, thanks for your reply and detailed information. Regarding the ongoing PR, I think my code is similar to it, except that I changed the nested FSDP wrapping to the auto-wrapping functionality in FSDP. So far, the OOM issue with the 3B model is still unsolved with bfloat16 and batch size 1 on v3-128, and I am checking for potential errors.
This PR adds the auto-wrapping feature in XLA FSDP, similar to the native PyTorch FSDP's `auto_wrap_policy` argument.

Auto-wrapping submodules based on policies

We now allow the submodules in an `nn.Module` to be wrapped automatically based on the policy specified in the `auto_wrap_policy` argument to the `XlaFullyShardedDataParallel` class. For example, one can set a transformer-based policy to automatically wrap all `GPT2Block` submodules (which is probably the most common scenario in transformer-style models), or apply a size-based policy on the parameter count of a submodule to automatically wrap all submodules with more than e.g. 1e7 (10M) parameters.
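(The original description included short code snippets for these policies that were not captured on this page; a minimal sketch of what they could look like is below. The `torch_xla.distributed.fsdp.wrap` import path and the `functools.partial` usage are assumed from the native PyTorch FSDP policies referenced further down; `GPT2LMHeadModel`/`GPT2Block` are the Hugging Face classes.)

```python
from functools import partial

import torch_xla.core.xla_model as xm
from transformers import GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
                                             transformer_auto_wrap_policy)

# type-based policy: put every GPT2Block submodule into its own inner FSDP wrap
auto_wrap_policy = partial(transformer_auto_wrap_policy,
                           transformer_layer_cls={GPT2Block})

# alternatively, a size-based policy: wrap every submodule with more than 1e7 parameters
# auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

model = GPT2LMHeadModel.from_pretrained("gpt2").to(xm.xla_device())
model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
```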
There are also more policies, such as `lambda_auto_wrap_policy`, which determines whether to wrap a module via a custom callable. The wrapping policies are directly borrowed from the native PyTorch FSDP policies in https://github.com/pytorch/pytorch/blob/v1.13.0/torch/distributed/fsdp/wrap.py.
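(Not part of the original description: a small sketch of the custom-callable policy, assuming the same import path as the other policies.)

```python
from functools import partial

import torch.nn as nn
from torch_xla.distributed.fsdp.wrap import lambda_auto_wrap_policy

# wrap a submodule whenever the custom predicate returns True
# (here: every nn.Linear gets its own inner FSDP wrap)
auto_wrap_policy = partial(lambda_auto_wrap_policy,
                           lambda_fn=lambda m: isinstance(m, nn.Linear))
```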
Gradient checkpointing (i.e. activation checkpointing/rematerialization)

Additionally, one can now also specify an `auto_wrapper_callable` argument to the `XlaFullyShardedDataParallel` class to use a custom callable wrapper for the submodules (the default wrapper is just `XlaFullyShardedDataParallel`). For example, one can use the following pattern to apply gradient checkpointing (i.e. activation checkpointing/rematerialization) to each auto-wrapped submodule.
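(The original snippet was not captured on this page; a sketch of the pattern, assuming the `checkpoint_module` helper from the XLA FSDP package, is below. `model` and `auto_wrap_policy` are as in the earlier policy sketch.)

```python
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module

# run each auto-wrapped submodule through gradient (activation) checkpointing
# before handing it to the inner FSDP wrapper
auto_wrapper_callable = lambda m, *args, **kwargs: FSDP(
    checkpoint_module(m), *args, **kwargs)

model = FSDP(model, auto_wrap_policy=auto_wrap_policy,
             auto_wrapper_callable=auto_wrapper_callable)
```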
The MNIST and ImageNet examples are updated accordingly to show auto-wrapping usage based on size or classes. Also, this PR changes the MNIST and ImageNet FSDP tests to `pin_layout=True` by default, to be consistent with #4359.

cc: @AlexWertheim @JackCaoG
New tests added:
[OK] Test MNIST size-based auto-wrap FSDP (and command line checkpoint consolidation) on v3-8
```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --auto_wrap_policy size_based
```
Results: matching expected accuracy for 2 training epochs
[OK] Test MNIST type-based auto-wrap FSDP (and command line checkpoint consolidation) on v3-8
```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --auto_wrap_policy type_based
```
Results: matching expected accuracy for 2 training epochs
[OK] Test MNIST type-based auto-wrap FSDP + gradient checkpointing (and command line checkpoint consolidation) on v3-8
```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --auto_wrap_policy type_based --use_gradient_checkpointing
```
Results: matching expected accuracy for 2 training epochs
[OK] Test ImageNet ResNet-50 size-based auto-wrap FSDP on v3-8
```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets02/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
  --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --auto_wrap_policy size_based
```
Results: matching expected accuracy for batch size 128
[OK] Test ImageNet ResNet-50 type-based auto-wrap FSDP on v3-8
```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets02/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
  --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --auto_wrap_policy type_based
```
Results: matching expected accuracy for batch size 128
[OK] Test ImageNet ResNet-50 type-based auto-wrap + gradient checkpointing FSDP on v3-8
```bash
python3 -u ~/xla_fsdp_dev/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets02/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
  --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
  --auto_wrap_policy type_based --use_gradient_checkpointing
```
Results: matching expected accuracy for batch size 128