ZeroRedundancyOptimizer: an implementation of a standalone sharded optimizer wrapper #46750
Conversation
💊 CI failures summary and remediations — as of commit 766ba48 (more details on the Dr. CI page):
🕵️ 6 new failures recognized by patterns — the following CI failures do not appear to be due to upstream breakages:
Thanks for adding this!
I saw there are some changes in android and fbgemm. If those are unintentional, shall we remove those? Thanks!
I just merged master before submitting; I'm not sure why this is considered a change from this branch. I'll update.
torch/optim/zeroptimizer.py (Outdated)
"""
return self.optim.state_dict()

def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
To start the discussion: I don't like that part, I see a couple of issues:
- the API is not the same as normal pytorch optimizers
- this is very slow (broadcast state shards everywhere)
- this adds a fairly big code footprint for a single use case (checkpoints)
I think that relying on rpc/rref would be nicer, but it would require both c10 and rpc to be initialized; very open to other ideas.
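For illustration, a rough sketch of what an rpc-based consolidation could look like; the module-level `_LOCAL_OPTIMIZER` slot and the `workerN` naming scheme are assumptions, and this presumes `torch.distributed.rpc` has been initialized alongside the process group:

```python
import torch.distributed.rpc as rpc

# Hypothetical: each rank stores its local (sharded) wrapper here at init time,
# so the rpc callee can reach it without serializing the optimizer itself.
_LOCAL_OPTIMIZER = None

def _local_state_dict():
    # Runs on the remote worker and returns only that rank's shard of the state.
    return _LOCAL_OPTIMIZER.optim.state_dict()

def consolidate_via_rpc(recipient_rank: int = 0, world_size: int = 1):
    # Only the recipient pulls the shards; the other ranks simply serve rpc calls.
    if rpc.get_worker_info().id != recipient_rank:
        return None
    # Assumes workers were registered as "worker0", "worker1", ... at rpc init.
    return [rpc.rpc_sync(f"worker{rank}", _local_state_dict) for rank in range(world_size)]
```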
@mrshenli, during model checkpointing we save the individual states of each data-parallel process involved in ZeRO into a separate file. During checkpoint loading, we do the consolidation and the partitioning based on the number of data-parallel processes.
In DeepSpeed ZeRO, since we handle model checkpointing and restoring internally, we do not expose any state consolidation APIs to the user. Regardless of whether the user is using ZeRO or not, we have common save_checkpoint and load_checkpoint APIs that handle the consolidation in the background if needed.
I feel like what you guys have is a good solution. Alternatively, you could remove the consolidate API and have state_dict() always return the consolidated states. Since the model parameters themselves are not partitioned for ZeRO Stage 1 and Stage 2, it also makes sense to return a consolidated state dictionary for the optimizer states.
Another alternative could be state_dict(self, consolidation=False), which would keep state_dict compatible with all other optimizers, while ZeRO users still get the option of consolidation.
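A rough sketch of that second alternative, keeping the parameter name suggested above; `_merge_shards` is purely illustrative and not part of the actual PR:

```python
from typing import Any, Dict

def state_dict(self, consolidation: bool = False) -> Dict[str, Any]:
    """Return this rank's shard of the optimizer state by default, or the
    full, consolidated state when ``consolidation=True`` (matching what a
    plain, non-sharded optimizer would return)."""
    if not consolidation:
        return self.optim.state_dict()
    # Pull every rank's shard onto this rank, then merge them into one dict.
    self.consolidate_state_dict(recipient_rank=self.rank)
    return self._merge_shards(self._all_states)  # _merge_shards: illustrative only
```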
Codecov Report
@@            Coverage Diff             @@
##           master   #46750      +/-   ##
==========================================
- Coverage   80.65%   80.44%    -0.21%
==========================================
  Files        1913     1914        +1
  Lines      208121   208333      +212
==========================================
- Hits       167859   167600      -259
- Misses      40262    40733      +471
torch/optim/zeroptimizer.py (Outdated)
class ZeROptimizer(Optimizer):
    """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
Shall we link this to distributed.rst and build it locally to make sure the docs are correctly formatted? Please post a screenshot of the built doc, thanks!
On second thought, this deserves its own page instead of being part of c10d, I think.
Thanks for the review @mrshenli, updating later today.
torch/optim/zeroptimizer.py (Outdated)
# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
This can be done in a follow-up PR. There is a recursive_to in DDP; I wonder if we should consolidate it with this one.
pytorch/torch/nn/parallel/distributed.py, line 698 in 5a2b537:
def _recursive_to(self, inputs, target_gpu):
Ah, that would be nice indeed, could this be made a public interface? Honestly I would be happy to remove as many helpers as possible.
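For context, a minimal sketch of what such a recursive helper typically does; this is an approximation of the classy_vision utility, not the exact code:

```python
from typing import Any

import torch

def _recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    # Tensors are moved directly; lists, tuples and dicts are traversed
    # recursively while preserving the container type; other values pass through.
    if torch.is_tensor(value):
        return value.to(device, non_blocking=non_blocking)
    if isinstance(value, (list, tuple)):
        copied = [_recursive_copy_to_device(v, non_blocking, device) for v in value]
        return copied if isinstance(value, list) else tuple(copied)
    if isinstance(value, dict):
        return {k: _recursive_copy_to_device(v, non_blocking, device) for k, v in value.items()}
    return value
```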
torch/optim/zeroptimizer.py (Outdated)
class ZeROptimizer(Optimizer):
    """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
    optimizer and shards its state as described by ZeRO_.
    ::
Shall we move this to the bottom of the doc, and expand it to be a full example?
torch/optim/zeroptimizer.py (Outdated)
self._all_states: List[Dict[str, Any]] = []

# Current default device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
What's the second dimension of self.partition_parameters()? I mean the one after self.rank. Is it device? If yes, we can avoid that, as new features in the distributed package can assume that each rank exclusively works on one device. And we are going to retire the single-process multi-device mode in DDP as well.
A sync with @msbaines would be nice; a key assumption in this code is that ranks can have parameters spread across multiple devices, which cannot be removed lightly.
I am a little surprised that the legacy single-process multi-device mode is indeed used here. @msbaines, could you please elaborate on why you need this mode?
I just realized that the above only replied to part of your question. The dimensions look like self.partition_parameters()[rank][param_group], and each param could be on a different device, which is a hard requirement to keep the sharded optimizer compatible with some techniques.
This self._device is kind of a band-aid though: for some methods we need a handle to 'a' device (for instance for state broadcasting), so this saves a handle to a device which works.
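To illustrate the indexing being discussed, a small standalone sketch; the structure here is my reading of the code above, not an official contract:

```python
from typing import Any, Dict, List

import torch

# partition_parameters() returns a list indexed by rank; each entry is that
# rank's list of param groups, in the usual optimizer param_group layout.
ParamGroups = List[Dict[str, Any]]

def pick_device(partition: List[ParamGroups], rank: int) -> torch.device:
    # Mirrors the band-aid above: take the device of the first parameter owned
    # by this rank; individual params may still live on different devices.
    return partition[rank][0]["params"][0].device
```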
> I am a little surprised that the legacy single-process multi-device mode is indeed used here. @msbaines could you please elaborate on why you need this mode?
I can reply to that: combining OSS + model parallel, for instance, means that OSS ranks could own several devices; that's a use case actively being used in fairseq.
…ributed seems not to be compatible
LGTM!
When you are ready, please import this as a diff and land.
else:
    # Fetch the optim state from the other replicas
    global_rank = _get_global_rank(self.group, rank)
    replica_state = [0]
    dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
@rohan-varma I was getting deadlocks in unit tests with this configuration, with NCCL and current pytorch. I checked that all ranks were properly calling this, but one of the ranks went through the call (the call returned) while the others timed out.
I checked:
- that the tensors were on the GPU
- that all ranks were calling
- that the group was correct
- that the src rank was correct
Replacing this with the "homemade" (from Classy) _broadcast_object, keeping the same call parameters, yields no issue. I was wondering whether you had an idea about that, or something I'm doing wrong? That worked fine with Gloo, if I remember correctly.
cc @mrshenli
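For reference, that kind of "homemade" helper boils down to roughly the following (a simplified sketch, not the exact Classy Vision implementation): pickle the object on the source rank, broadcast the payload length, then broadcast the bytes.

```python
import io
import pickle
from typing import Any

import torch
import torch.distributed as dist

def _broadcast_object(obj: Any, src_global_rank: int, group, device: torch.device) -> Any:
    # The source rank serializes the object into a uint8 tensor; every rank
    # first agrees on the payload length, then receives the bytes and unpickles.
    if dist.get_rank() == src_global_rank:
        buffer = io.BytesIO()
        pickle.dump(obj, buffer)
        payload = torch.tensor(list(buffer.getvalue()), dtype=torch.uint8, device=device)
        length = torch.tensor([payload.numel()], dtype=torch.long, device=device)
        dist.broadcast(length, src=src_global_rank, group=group)
        dist.broadcast(payload, src=src_global_rank, group=group)
        return obj
    length = torch.tensor([0], dtype=torch.long, device=device)
    dist.broadcast(length, src=src_global_rank, group=group)
    payload = torch.empty(int(length.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src_global_rank, group=group)
    return pickle.loads(bytes(payload.cpu().tolist()))
```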
@blefaudeux has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@blefaudeux merged this pull request in 87fb370.
return value

def _broadcast_object(
Sorry for the late comment, but can we modify this to use the existing broadcast_object_list in native pytorch now?
Yes we can; there are a couple of follow-up PRs planned. I removed it at some point because there were some deadlocks on CI, but I guess that I was not using it right.
@blefaudeux Can you tell me more details about the deadlocks on CI? My training procedure also gets stuck intermittently, and the error is not reproducible, so I'm having a hard time solving the problem.
Is it safer to use broadcast_object_list?
It's just that it's part of pytorch now, and less code is better code, I think. I'm not sure of the reason for the deadlocks; a guess was that I was not using broadcast_object_list properly, since it guesses the receiving device from context instead of being explicit, but I'm not 100% sure of that.
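For anyone following along, a minimal sketch of the usual NCCL precaution: pin the CUDA device for the process before calling the collective, so the implicit device choice made by broadcast_object_list matches on all ranks. Whether this was the root cause of the CI deadlocks here is only a guess.

```python
import torch
import torch.distributed as dist

def broadcast_state(obj, src_global_rank: int, group=None):
    # With the NCCL backend, broadcast_object_list infers the device from the
    # current CUDA context, so make the rank -> GPU mapping explicit first.
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    payload = [obj if dist.get_rank() == src_global_rank else None]
    dist.broadcast_object_list(payload, src=src_global_rank, group=group)
    return payload[0]
```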
Implement the first stage of ZeRO, sharding of the optimizer state, as described in this blog post and this paper. This implementation is completely independent of the DeepSpeed framework and aims at providing ZeRO-compliant building blocks within the PyTorch scheme of things.
This works by partitioning the optimizer state across the data-parallel ranks: each rank only keeps and updates the state of the parameters it owns, and the updated parameters are broadcast to all ranks after each step.
This can be used with DDP, although some communications are wasted in that case (gradients are all-reduced to all ranks). This implementation was initially developed in Fairscale, and can also be used with an optimized DDP which only reduces to the relevant ranks. More context on ZeRO and PyTorch can be found in this RFC.
The API with respect to loading and saving the state is a known pain point and should probably be discussed and updated. Other possible follow-ups include integrating more closely with a modularized DDP, making the checkpoints partition-agnostic, exposing a gradient clipping option, and making sure that mixed-precision states are properly handled.
Original authors include @msbaines, @min-xu-ai and myself.
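For readers landing here, a minimal usage sketch under the naming used in this PR (torch/optim/zeroptimizer.py, class ZeROptimizer); the exact import path and constructor signature may differ in the final API, and the fairscale-OSS-style (params, optim=..., **defaults) form below is an assumption.

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumed import path / constructor for illustration only.
from torch.optim.zeroptimizer import ZeROptimizer

def setup(model: torch.nn.Module, rank: int):
    # Gradients are still all-reduced by DDP; only the optimizer state is sharded,
    # so each rank keeps and updates the state for its own subset of parameters.
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    optimizer = ZeROptimizer(ddp_model.parameters(), optim=torch.optim.SGD, lr=0.01)
    return ddp_model, optimizer
```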