
Conversation

Collaborator

@PKUWZP PKUWZP commented Jan 20, 2026

Kudos to @pengdurice for the great work and discussions!

We aim to add the Muon Optimizer to ZeRO stage 3 in this PR:

  • Created a dedicated momentum buffer in the ZeRO stage 3 optimizer to hold the momentum state specifically for the Muon Optimizer.
  • The optimizer states can be dispatched to three devices: GPU, CPU, and NVMe. For GPU and CPU, we simply place the new buffers on the same device as self.fp32_partitioned_groups_flat; when the device is NVMe, we make sure the momentum buffers can be swapped in and out along with the other components of the optimizer state.
  • The new momentum buffers are partitioned like self.fp32_partitioned_groups_flat to reduce the memory footprint. Before the Muon update, we therefore perform an all_gather across the data-parallel ranks. The Muon updates of the parameters are likewise divided across the data-parallel ranks, and the results are all-gathered once all updates are complete. After the all_gather, the momentum buffers are partitioned and flattened again (see the sketch below).
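
A minimal sketch of the buffer placement described above, assuming fp32_partitioned_groups_flat is the list of flat FP32 partitions (one per sub-group); the helper name is illustrative, not code from this PR:

import torch

def allocate_muon_momentum_buffers(fp32_partitioned_groups_flat):
    # torch.zeros_like places each momentum buffer on the same device (GPU or CPU)
    # as the corresponding flat FP32 partition, so placement follows the existing
    # offload configuration; the NVMe case goes through the optimizer-state swapper.
    return [torch.zeros_like(flat) for flat in fp32_partitioned_groups_flat]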

Next steps:

  • Explore quantization of momentum buffers for saving memory
  • Explore using highly optimized Adam / AdamW Optimizers

for idx, dest_offset in params_to_subgroup_maps[i]:
    momentum_buffer[idx] = self.optimizer.state[
        self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(
            0, dest_offset, param.partition_numel()).clone()
Collaborator Author


@pengdurice Here is a bug. The variable param refers to the last parameter from the previous loop (for param in self.ipg_buckets[...]), not the parameter corresponding to idx. We should change it to use_muon_params[idx].partition_numel().
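
A sketch of the suggested change, assuming use_muon_params is the list built in the earlier loop (so use_muon_params[idx] is the parameter that idx refers to):

for idx, dest_offset in params_to_subgroup_maps[i]:
    # Size the narrow() by the parameter indexed by idx, not by the stale
    # loop variable `param` left over from the previous loop.
    momentum_buffer[idx] = self.optimizer.state[
        self.fp32_partitioned_groups_flat[i]]["momentum_buffer"].narrow(
            0, dest_offset, use_muon_params[idx].partition_numel()).clone()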

Collaborator Author

@PKUWZP PKUWZP left a comment


@pengdurice Two more comments:

  • It seems we have excessive tensor allocations: multiple torch.empty, torch.zeros, and .clone() calls create memory pressure. Consider reusing buffers where possible.

  • Synchronous all_gather: the distributed operations could potentially be overlapped with computation (see the sketch below).

I think we need to rework the PR; let's take some time to refine the code.
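
A minimal sketch of the overlap idea, assuming torch.distributed is already initialized; the function and variable names are illustrative, not from this PR:

import torch
import torch.distributed as dist

def start_momentum_gather(partition: torch.Tensor, world_size: int):
    # Launch the all_gather asynchronously and return the work handle, so the
    # caller can prepare the next sub-group before calling handle.wait(); the
    # wait blocks only when the gathered tensors are actually needed.
    gathered = [torch.empty_like(partition) for _ in range(world_size)]
    handle = dist.all_gather(gathered, partition, async_op=True)
    return gathered, handle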

self.dp_process_group = self.parameter_offload.dp_process_group
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()

self.all2all_process_group = all2all_process_group
Collaborator Author


@pengdurice Question: where did we set up the all2all_process_group? It seems that it's never set.

] + [torch.empty_like(params[-1].grad)] * (world_sz - len(params) % world_sz)
gathered_momentums_pad = gathered_momentums + [torch.empty_like(gathered_momentums[-1])
] * (world_sz - len(gathered_momentums) % world_sz)
for base_i in range(len(params))[::world_sz]:
Collaborator Author


@pengdurice There's a padding error here. When len(params) % world_sz == 0, this adds world_sz empty tensors instead of 0. Should we change it to: (world_sz - len(params) % world_sz) % world_sz ?
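
A small self-contained check of the proposed expression; it yields zero padding when the length is already a multiple of world_sz:

def pad_count(n: int, world_sz: int) -> int:
    # Number of filler entries needed so the padded length is a multiple of world_sz.
    return (world_sz - n % world_sz) % world_sz

assert pad_count(5, 4) == 3  # 5 -> 8
assert pad_count(8, 4) == 0  # already aligned; the current code would add 4 here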


self.reduce_scatter = reduce_scatter

self.use_muon = 'muon' in self.optimizer.__class__.__name__.lower()
Collaborator Author


@pengdurice This is fragile and depends purely on naming conventions. Can we use isinstance() instead?
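
A sketch of the type-based check; torch.optim.SGD stands in for the Muon class below, since the actual import path for Muon is not shown in this diff:

import torch

def uses_optimizer_type(optimizer: torch.optim.Optimizer, cls: type) -> bool:
    # Robust to subclassing and renaming, unlike matching on __class__.__name__.
    return isinstance(optimizer, cls)

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
assert uses_optimizer_type(opt, torch.optim.SGD)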

self.reduce_scatter = reduce_scatter

self.use_muon = 'muon' in self.optimizer.__class__.__name__.lower()
self.save_muon_momentum_buffer_in_memory = ds_config.get('save_muon_momentum_buffer_in_memory', False)
Collaborator Author


@pengdurice Can we add save_muon_momentum_buffer_in_memory to the config schema and document it?
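
A sketch of where the flag would sit in the DeepSpeed config, matching the top-level ds_config.get(...) lookup above; the surrounding keys are only placeholders:

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 3},
    # New flag to document; defaults to False per the lookup above.
    "save_muon_momentum_buffer_in_memory": False,
}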

params_to_subgroup_maps[i].append((idx, dest_offset))
idx += 1
params_size_offset += param.grad.numel()
# if optimizer is swappable, swap in the momentum buffer of the parameters that need to be updated using muon and then swap them out
Collaborator Author


@pengdurice This doubles NVMe I/O overhead. Can we consider consolidating into a single swap in/out cycle?
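
A rough sketch of a single swap cycle per sub-group; the swap_out_optimizer_state call is an assumption mirroring the swap_in call shown in this diff:

for i in params_to_subgroup_maps:
    needs_swap = (self._swappable_optimizer_subgroup(i)
                  and not self.save_muon_momentum_buffer_in_memory)
    if needs_swap:
        # Swap the sub-group's optimizer state in once, instead of once per phase.
        self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
    # ... gather the momentum buffer, run the Muon update, write the momentum back ...
    if needs_swap:
        self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])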

communication_data_type)
for i in params_to_subgroup_maps:
    if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
        self.optimizer_swapper.swap_in_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
Collaborator Author


@pengdurice Same issue here again: can we consolidate the two swaps into one?

@PKUWZP PKUWZP closed this Jan 20, 2026