Add Muon Optimizer Support to ZeRO3 #7797
Conversation
```python
for idx, dest_offset in params_to_subgroup_maps[i]:
    momentum_buffer[idx] = self.optimizer.state[
        self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(
            0, dest_offset, param.partition_numel()).clone()
```
@pengdurice Here is a bug. The variable `param` refers to the last parameter from the previous loop (`for param in self.ipg_buckets[...]`), not the parameter corresponding to `idx`. We should change it to `use_muon_params[idx].partition_numel()`.
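A minimal sketch of the suggested fix, assuming `use_muon_params` is the list that `idx` indexes into:

```python
# Sketch only: size the narrowed slice by the parameter that idx actually refers to,
# not by the stale `param` left over from the previous loop.
for idx, dest_offset in params_to_subgroup_maps[i]:
    momentum_buffer[idx] = self.optimizer.state[
        self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(
            0, dest_offset, use_muon_params[idx].partition_numel()).clone()
```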
PKUWZP left a comment:
@pengdurice Also two more comments:

- It seems that we have excessive tensor allocations: multiple `torch.empty`, `torch.zeros`, and `.clone()` calls create memory-footprint pressure. Consider reusing buffers where possible.
- Synchronous `all_gather`: the distributed operations could potentially be overlapped with computation (a sketch of one possible overlap follows below).

I think we need to rework the PR; let's take some time to refine the code.
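A minimal sketch of the overlap idea, using the standard `torch.distributed` API (names like `local_momentum` and `group` are placeholders, not from this PR):

```python
import torch
import torch.distributed as dist

# Illustrative only: launch the all_gather asynchronously and overlap it with work
# that does not depend on the gathered momentum buffers.
gathered = [torch.empty_like(local_momentum) for _ in range(dist.get_world_size(group))]
handle = dist.all_gather(gathered, local_momentum, group=group, async_op=True)

# ... prepare the next bucket here, ideally reusing preallocated buffers instead of
# ... fresh torch.empty / torch.zeros calls ...

handle.wait()  # block only right before the Muon update actually consumes `gathered`
```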
```python
self.dp_process_group = self.parameter_offload.dp_process_group
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()

self.all2all_process_group = all2all_process_group
```
@pengdurice Question: where do we set up `all2all_process_group`? It seems that it is never set.
```python
              ] + [torch.empty_like(params[-1].grad)] * (world_sz - len(params) % world_sz)
gathered_momentums_pad = gathered_momentums + [torch.empty_like(gathered_momentums[-1])
                                               ] * (world_sz - len(gathered_momentums) % world_sz)
for base_i in range(len(params))[::world_sz]:
```
@pengdurice There's a padding error here. When `len(params) % world_sz == 0`, this adds `world_sz` empty tensors instead of 0. Should we change it to `(world_sz - len(params) % world_sz) % world_sz`?
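A small sketch of the suggested change (only the pad-length computation is shown; the surrounding list construction stays as in the diff):

```python
# Taking the remainder modulo world_sz means an exact multiple of world_sz adds
# zero padding tensors instead of world_sz of them.
grad_pad_len = (world_sz - len(params) % world_sz) % world_sz
momentum_pad_len = (world_sz - len(gathered_momentums) % world_sz) % world_sz
```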
```python
self.reduce_scatter = reduce_scatter

self.use_muon = 'muon' in self.optimizer.__class__.__name__.lower()
```
@pengdurice This is very fragile and depends purely on naming conventions. Can we use `isinstance()` instead?
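A sketch of the `isinstance()`-based check; the import path below is an assumption and should point at wherever this PR actually defines the Muon optimizer class:

```python
# Hypothetical import path -- adjust to the module that defines Muon in this PR.
from deepspeed.runtime.muon import Muon

# Type check instead of substring matching on the class name, which would also match
# any unrelated optimizer whose name happens to contain "muon".
self.use_muon = isinstance(self.optimizer, Muon)
```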
```python
self.reduce_scatter = reduce_scatter

self.use_muon = 'muon' in self.optimizer.__class__.__name__.lower()
self.save_muon_momentum_buffer_in_memory = ds_config.get('save_muon_momentum_buffer_in_memory', False)
```
@pengdurice Can we add `save_muon_momentum_buffer_in_memory` to the config schema and document it?
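For illustration, a sketch of where the flag would sit in a user config, given that the code above reads it from the top level of `ds_config` (the surrounding keys are just ordinary ZeRO-3 settings):

```python
# Sketch only: user-facing config exposing the new flag read via ds_config.get(...).
ds_config = {
    "zero_optimization": {
        "stage": 3,
    },
    # New in this PR: keep Muon momentum buffers resident in memory instead of
    # swapping them with the rest of the optimizer state.
    "save_muon_momentum_buffer_in_memory": True,
}
```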
```python
params_to_subgroup_maps[i].append((idx, dest_offset))
idx += 1
params_size_offset += param.grad.numel()
# if optimizer is swappable, swap in the momentum buffer of the parameters that need to be updated using muon and then swap them out
```
@pengdurice This doubles the NVMe I/O overhead. Can we consolidate this into a single swap-in/swap-out cycle?
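A rough sketch of the consolidation idea; `swap_in_optimizer_state` appears in the diff below, while `swap_out_optimizer_state` is assumed to be its counterpart, and the loop body is illustrative:

```python
# Illustrative only: one swap-in / swap-out cycle per subgroup instead of two.
for i in params_to_subgroup_maps:
    needs_swap = self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory
    if needs_swap:
        self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])

    # ... read the momentum buffer AND apply the Muon update while the state is resident ...

    if needs_swap:
        self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
```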
```python
                              communication_data_type)
for i in params_to_subgroup_maps:
    if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
        self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
```
@pengdurice Same issue here: can we consolidate the two swaps into one?
Kudos to @pengdurice for the great work and discussions!
We aim to add the Muon optimizer to ZeRO stage 3 in this PR:

- The momentum buffers are stored in the optimizer state of `self.fp32_partitioned_groups_flat`; when `device == NVME`, we make sure that the momentum buffers can be swapped in and out along with the other components of the optimizer state.
- The momentum buffers are partitioned like `self.fp32_partitioned_groups_flat` to save memory footprint. So, before the Muon update, we need to perform `all_gather` on top of each data-parallel group rank. The Muon updates of the parameters are also divided across the data-parallel ranks, and the results are all-gathered once all updates are complete. After the `all_gather`, the momentum buffers are partitioned and flattened again (see the sketch at the end of this description).

Next steps:
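A schematic of the partition / `all_gather` / update / re-partition flow described in the bullets above, as a standalone sketch only (all names are illustrative; the real logic lives in the ZeRO-3 optimizer):

```python
import torch
import torch.distributed as dist

def muon_momentum_step_sketch(local_momentum_shard, group):
    """Illustrative outline: gather partitioned momentum, update one slice, re-partition."""
    world_sz = dist.get_world_size(group)
    rank = dist.get_rank(group)

    # 1. Each rank holds only its shard; all_gather reconstructs the full buffers.
    gathered = [torch.empty_like(local_momentum_shard) for _ in range(world_sz)]
    dist.all_gather(gathered, local_momentum_shard, group=group)

    # 2. Each rank applies the Muon update to its assigned slice of the parameters.
    #    (Placeholder: the actual orthogonalized Muon update is not shown here.)
    updated_slice = gathered[rank].clone()

    # 3. All-gather the updated slices so every rank sees every update, then
    #    flatten / re-partition back into the fp32 partitioned groups.
    dist.all_gather(gathered, updated_slice, group=group)
    return gathered[rank]
```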