ZeroRedundancyOptimizer: an implementation of a standalone sharded optimizer wrapper #46750


Closed · wants to merge 52 commits

Conversation

@blefaudeux (Contributor) commented Oct 23, 2020

Implement the first stage of ZeRO, sharding of the optimizer state, as described in this blog post and this paper. This implementation is completely independent of the DeepSpeed framework, and aims to provide ZeRO-compliant building blocks within PyTorch proper.

This works by:

  • acting as a wrapper around a PyTorch optimizer: ZeROptimizer does not optimize anything by itself, it only shards the optimizer state for distributed jobs
  • each rank distributes parameters according to a given partitioning scheme (which could be updated) and owns the update of its shard only
  • .step() is called on each rank as usual; the fact that the optimizer actually works on a shard of the model is not visible from the outside
  • when the update is completed, each rank broadcasts its updated model shard to all the other ranks

This can be used with DDP, although some communications are wasted in that case (gradients are all-reduced to all ranks). This implementation was initially developed in Fairscale, and can also be used with an optimized DDP which only reduces to the relevant ranks. More context on ZeRO and PyTorch can be found in this RFC.
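For illustration, here is a minimal usage sketch in a DDP training script. The class name and constructor keywords (ZeroRedundancyOptimizer, optimizer_class, lr) follow the API as later documented under torch.distributed.optim; treat the exact signature as an assumption rather than what this PR's diff contains.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer

def train(model, data_loader, local_rank):
    # Assumes dist.init_process_group(...) has already been called by the launcher.
    model = DDP(model.to(local_rank), device_ids=[local_rank])

    # Every rank builds the same wrapper; the optimizer state is sharded across
    # ranks, and each rank only updates the parameters it owns.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,  # the wrapped optimizer doing the actual updates
        lr=1e-3,
    )

    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs.to(local_rank)), targets.to(local_rank))
        loss.backward()    # gradients are all-reduced by DDP as usual
        optimizer.step()   # updates the local shard, then broadcasts the updated parameters
```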

The API with respect to loading and saving the state is a known pain point and should probably be discussed and updated. Other possible follow-ups include integrating more closely with a modularized DDP, making the checkpoints partition-agnostic, exposing a gradient clipping option, and making sure that mixed-precision states are properly handled.

Original authors include @msbaines, @min-xu-ai, and myself.

dr-ci bot commented Oct 23, 2020

💊 CI failures summary and remediations

As of commit 766ba48 (more details on the Dr. CI page):


  • 7/7 failures possibly* introduced in this PR
    • 1/7 non-CircleCI failure(s)

🕵️ 6 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 (1/6)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jan 12 23:30:19 AssertionError: mypy failed: torch/distributed/optim/zero_redundancy_optimizer.pyi:2: error: Module 'torch.distributed.optim.optimizer' has no attribute '_params_t' [attr-defined]
Jan 12 23:29:57   test_run_mypy (__main__.TestTypeHints) ... FAIL (62.464s)
Jan 12 23:30:00   test_run_mypy_strict (__main__.TestTypeHints) ... ok (2.998s)
Jan 12 23:30:19   test_type_hint_examples (__main__.TestTypeHints) ... ok (18.492s)
Jan 12 23:30:19 
Jan 12 23:30:19 ======================================================================
Jan 12 23:30:19 FAIL [62.464s]: test_run_mypy (__main__.TestTypeHints)
Jan 12 23:30:19 ----------------------------------------------------------------------
Jan 12 23:30:19 Traceback (most recent call last):
Jan 12 23:30:19   File "test_type_hints.py", line 214, in test_run_mypy
Jan 12 23:30:19     self.fail(f"mypy failed: {stdout} {stderr}")
Jan 12 23:30:19 AssertionError: mypy failed: torch/distributed/optim/zero_redundancy_optimizer.pyi:2: error: Module 'torch.distributed.optim.optimizer' has no attribute '_params_t'  [attr-defined]
Jan 12 23:30:19 torch/distributed/optim/zero_redundancy_optimizer.pyi:2: error: Module 'torch.distributed.optim.optimizer' has no attribute 'Optimizer'  [attr-defined]
Jan 12 23:30:19 torch/distributed/optim/zero_redundancy_optimizer.pyi:36: error: Name 'Union' is not defined  [name-defined]
Jan 12 23:30:19 torch/distributed/optim/zero_redundancy_optimizer.pyi:36: note: Did you forget to import it from "typing"? (Suggestion: "from typing import Union")
Jan 12 23:30:19 Found 3 errors in 1 file (checked 1190 source files)
Jan 12 23:30:19  
Jan 12 23:30:19 
Jan 12 23:30:19 ----------------------------------------------------------------------
Jan 12 23:30:19 Ran 4 tests in 94.687s
Jan 12 23:30:19 
Jan 12 23:30:19 FAILED (failures=1)
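For context, the three mypy errors above are import problems in the new stub file; a header along the following lines would satisfy them (a sketch only: the `_params_t` alias is written out inline here as an assumption, and this is not necessarily the fix that landed).

```python
# torch/distributed/optim/zero_redundancy_optimizer.pyi (sketch)
from typing import Any, Dict, Iterable, Union

from torch import Tensor
from torch.optim import Optimizer

# Type of the "params" argument accepted by torch.optim optimizers.
_params_t = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
```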

See CircleCI build pytorch_linux_bionic_py3_8_gcc9_coverage_test1 (2/6)

Step: "Run tests"

Same failure as (1/6): test_run_mypy (__main__.TestTypeHints) fails with the same three mypy errors in torch/distributed/optim/zero_redundancy_optimizer.pyi.

See CircleCI build pytorch_windows_vs2019_py36_cuda10.1_test2 (3/6)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

RuntimeError: distributed/optim/test_zero_redundancy_optimizer failed!
  File "C:\Users\circleci\project\build\win_tmp\build\torch\distributed\optim\__init__.py", line 8, in <module>
    from .optimizer import DistributedOptimizer
  File "C:\Users\circleci\project\build\win_tmp\build\torch\distributed\optim\optimizer.py", line 8, in <module>
    from torch.distributed.rpc import RRef
ImportError: cannot import name 'RRef'
Traceback (most recent call last):
  File "run_test.py", line 911, in <module>
    main()
  File "run_test.py", line 890, in main
    raise RuntimeError(err_message)
RuntimeError: distributed/optim/test_zero_redundancy_optimizer failed!

(base) circleci@PACKER-5FD865C5 C:\Users\circleci\project\test>if ERRORLEVEL 1 exit /b 1 
+ cleanup
+ retcode=1
+ set +x


Exited with code exit status 1
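The Windows failure above is a plain ImportError: RRef lives in the RPC package, which was not available on that platform. A generic way to guard such an import (illustrative only, not necessarily how the test or module was eventually fixed) is:

```python
# Degrade gracefully when torch.distributed.rpc (and therefore RRef) is unavailable.
try:
    from torch.distributed.rpc import RRef
    _RPC_AVAILABLE = True
except ImportError:
    RRef = None  # type: ignore[assignment, misc]
    _RPC_AVAILABLE = False
```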

See CircleCI build pytorch_linux_xenial_py3_clang5_asan_test1 (4/6)

See CircleCI build pytorch_linux_bionic_py3_6_clang9_test (5/6)

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (6/6)

Step: "Run tests"

All three builds fail test_run_mypy (__main__.TestTypeHints) with the same three mypy errors in torch/distributed/optim/zero_redundancy_optimizer.pyi as (1/6).

ci.pytorch.org: 1 failed



@mrshenli (Contributor) left a comment:

Thanks for adding this!

I saw there are some changes in android and fbgemm. If those are unintentional, shall we remove those? Thanks!

@blefaudeux (Contributor, Author) replied:

> Thanks for adding this!
>
> I saw there are some changes in android and fbgemm. If those are unintentional, shall we remove those? Thanks!

I just merged master before submitting; I'm not sure why this is considered a change from this branch. I'll update.

@blefaudeux requested a review from apaszke as a code owner October 23, 2020 17:19
"""
return self.optim.state_dict()

def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
blefaudeux (Contributor, Author):

To start discussions: I don't like that part, I see a couple of issues:

  • the API is not the same as normal PyTorch optimizers
  • this is very slow (broadcast state shards everywhere)
  • this adds a fairly big code footprint for a single usage (checkpoints)

I think that relying on rpc/rref would be nicer, but it would require both c10 and rpc to be initialized. Very open to other ideas.

Contributor:

Hey @jeffra @samyam, how does ZeRO solve the state saving problem?

Comment:

@mrshenli, during model checkpointing, we save the individual states of each data parallel process involved in ZeRO into a separate file. During checkpoint loading, we do the consolidation and the partitioning based on the number of data parallel processes.

In DeepSpeed ZeRO, since we handle the model checkpoint and restore internally, we do not expose any state consolidation APIs to the user. Regardless of whether the user is using ZeRO or not, we have common 'save_checkpoint' and 'load_checkpoint' APIs that handle the consolidation in the background if needed.

I feel like what you guys have is a good solution. Alternatively, you could remove the consolidate API, and just use state_dict() to always return the consolidated states. I feel that since the model parameters themselves are not partitioned for ZeRO Stage 1 and Stage 2, it also makes sense to return a consolidated state dictionary for the optimizer states.

Another alternative could be to use state_dict(self, consolidation=False), which would make state_dict compatible with all other optimizers, but for ZeRO, you have the option of consolidation.
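To make the trade-off concrete, here is a short sketch of a checkpointing flow under the consolidate API from the diff above, assuming (as debated in this thread) that state_dict() on the recipient rank returns the consolidated state after the collective call; the surrounding save logic is illustrative only.

```python
import torch
import torch.distributed as dist

def save_checkpoint(model, optimizer, path):
    # Collective: every rank must call this so the state shards can be gathered
    # on the recipient rank.
    optimizer.consolidate_state_dict(recipient_rank=0)

    if dist.get_rank() == 0:
        # Only the recipient rank writes the consolidated optimizer state.
        torch.save({"model": model.state_dict(), "optim": optimizer.state_dict()}, path)
    dist.barrier()  # keep ranks in sync before training resumes
```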

codecov bot commented Oct 23, 2020

Codecov Report

Merging #46750 (8f5d539) into master (ce30dba) will decrease coverage by 0.20%.
The diff coverage is 16.50%.

@@            Coverage Diff             @@
##           master   #46750      +/-   ##
==========================================
- Coverage   80.65%   80.44%   -0.21%     
==========================================
  Files        1913     1914       +1     
  Lines      208121   208333     +212     
==========================================
- Hits       167859   167600     -259     
- Misses      40262    40733     +471     



class ZeROptimizer(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
Contributor:

Shall we link this to distributed.rst and build it locally to make sure the docs are correctly formatted? Please post a screenshot of the built doc, thanks!

Contributor:

On second thought, this deserves its own page instead of being part of c10d, I think.

@blefaudeux (Contributor, Author):
Thanks for the review @mrshenli, updating later today



# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
Contributor:

This can be done in a follow-up PR. There is a recursive_to in DDP. I wonder if we should consolidate it with this one.

def _recursive_to(self, inputs, target_gpu):

blefaudeux (Contributor, Author):

Ah, that would be nice indeed; could this be made a public interface? Honestly, I would be happy to remove as many helpers as possible.
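For reference, a generic version of such a recursive device-copy helper looks roughly like this (a sketch, not necessarily the exact code referenced above):

```python
from typing import Any

import torch

def _recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    """Move any tensors nested inside lists, tuples, and dicts to `device`,
    leaving non-tensor values untouched."""
    if torch.is_tensor(value):
        return value.to(device, non_blocking=non_blocking)
    if isinstance(value, (list, tuple)):
        copied = [_recursive_copy_to_device(v, non_blocking, device) for v in value]
        return copied if isinstance(value, list) else tuple(copied)
    if isinstance(value, dict):
        return {k: _recursive_copy_to_device(v, non_blocking, device) for k, v in value.items()}
    return value
```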

class ZeROptimizer(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
::
Contributor:

Shall we move this to the bottom of the doc, and expand it to be a full example?

self._all_states: List[Dict[str, Any]] = []

# Current default device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
Contributor:

What's the second dimension of self.partition_parameters()? I mean the one after self.rank. Is it the device? If yes, we can avoid that, as new features in the distributed package can assume that each rank exclusively works on one device. And we are going to retire the single-process multi-device mode in DDP as well.

blefaudeux (Contributor, Author):

A sync with @msbaines would be nice; a key assumption in this code is that ranks can have parameters spread across multiple devices, which cannot be removed lightly.

Contributor:

I am a little surprised that the legacy single-process multi-device mode is indeed used here. @msbaines could you please elaborate on why you need this mode?

blefaudeux (Contributor, Author):

I just realized that the above only replied to part of your question. The dimensions look like self.partition_parameters[rank][param_group], and each param could be on a different device, which is a hard requirement to keep the sharded optimizer compatible with some techniques.
This self._device is kind of a band-aid, though: some methods need a handle to "a" device (for instance for state broadcasting), so this saves a handle to a device that works.

blefaudeux (Contributor, Author):

> I am a little surprised that the legacy single-process multi-device mode is indeed used here. @msbaines could you please elaborate on why you need this mode?

I can reply to that: combining OSS + model parallel, for instance, means that OSS ranks could own several devices. That's a use case actively being used in fairseq.
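To make the indexing discussed in this thread concrete, a small sketch of the assumed layout (the helper name and wrapper attribute are taken from the diff context above; the function itself is illustrative):

```python
from typing import Any, Dict, List

import torch
import torch.distributed as dist

def default_device_for_this_rank(optimizer) -> torch.device:
    # partition_parameters() is described above as indexed [rank][param_group]:
    # for every rank, a list of param groups in the usual optim.param_groups layout.
    partitions: List[List[Dict[str, Any]]] = optimizer.partition_parameters()
    my_param_groups = partitions[dist.get_rank()]
    # "A" device handle, as saved in self._device; individual parameters may still
    # live on different devices when combined with model parallelism.
    return my_param_groups[0]["params"][0].device
```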

@blefaudeux requested a review from mrshenli January 17, 2021 17:27
@mrshenli (Contributor) left a comment:

LGTM!

@mrshenli (Contributor):

When you are ready, please import this as a diff and land.

else:
# Fetch the optim state from the other replicas
global_rank = _get_global_rank(self.group, rank)
replica_state = [0]
dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
blefaudeux (Contributor, Author):

@rohan-varma I was getting deadlocks in unit tests with this configuration, with NCCL and current PyTorch. I checked that all ranks were properly calling this, but one of the ranks went through the call (the call returned) while the others timed out.

I checked

  • that the tensors were on the gpu
  • that all ranks were calling
  • that the group was correct
  • that the src rank was correct

Replacing this with the "homemade" (from Classy) _broadcast_object and keeping the same call parameters yields no issue. I was wondering whether you had an idea about that; is there something I'm doing wrong? That worked fine with Gloo, if I remember correctly.

cc @mrshenli
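For reference, the tensor-based helper mentioned above typically looks like the following sketch (pickle the object into a byte tensor, broadcast its length, then the payload). This is a generic version, not the exact code in this PR, and src_global_rank is assumed to be the global rank as computed via _get_global_rank above.

```python
import io
from typing import Any

import torch
import torch.distributed as dist

def _broadcast_object_sketch(obj: Any, src_global_rank: int, group, device: torch.device) -> Any:
    """Broadcast an arbitrary picklable object by shipping it as a uint8 tensor."""
    if dist.get_rank() == src_global_rank:
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        payload = torch.tensor(list(buffer.getvalue()), dtype=torch.uint8, device=device)
        length = torch.tensor([payload.numel()], dtype=torch.long, device=device)
        dist.broadcast(length, src=src_global_rank, group=group)
        dist.broadcast(payload, src=src_global_rank, group=group)
        return obj
    # Receiving side: first learn the payload size, then receive and unpickle it.
    length = torch.tensor([0], dtype=torch.long, device=device)
    dist.broadcast(length, src=src_global_rank, group=group)
    payload = torch.empty(int(length.item()), dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=src_global_rank, group=group)
    return torch.load(io.BytesIO(bytes(payload.cpu().tolist())))
```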

@facebook-github-bot (Contributor) left a comment (posted three times, once per import):

@blefaudeux has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):
@blefaudeux merged this pull request in 87fb370.

return value


def _broadcast_object(
Member:

Sorry for the late comment, but can we modify this to use the existing broadcast_object_list in native PyTorch now?

blefaudeux (Contributor, Author):

Yes, we can; there are a couple of follow-up PRs planned. I removed it at some point because there were some deadlocks on CI, but I guess I was not using it right.

Comment:

@blefaudeux Can you tell me more details about the deadlocks on CI? My training procedure also gets stuck intermittently, and the error is not reproducible, so I'm having a hard time solving the problem.

Comment:

Is it safer to use broadcast_object_list?

blefaudeux (Contributor, Author):

It's just that it's part of PyTorch now, and less code is better code, I think. I'm not sure of the reason for the deadlocks; a guess was that I was not using broadcast_object_list properly, since it guesses the receiving device from context instead of being explicit, but I'm not 100% sure of that.
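As an illustration of that guess, one way to remove the ambiguity is to set the current CUDA device explicitly before calling the native collective with a NCCL group; the names local_rank, src_global_rank, and group below are placeholders:

```python
import torch
import torch.distributed as dist

def broadcast_one_object(obj, src_global_rank: int, local_rank: int, group=None):
    # NCCL collectives use the current CUDA device, so pin it explicitly instead
    # of letting broadcast_object_list guess the receiving device from context.
    torch.cuda.set_device(local_rank)

    object_list = [obj if dist.get_rank() == src_global_rank else None]
    dist.broadcast_object_list(object_list, src=src_global_rank, group=group)
    return object_list[0]
```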

@blefaudeux deleted the sharded_optimizer branch February 28, 2021 03:37