Commit d0d23be
joshuadeng authored and facebook-github-bot committed
KJT a2a redesign and input dist fusion (#797)
Summary:
Pull Request resolved: #797

# tldr:
* make it easier for users to add new variants of the KeyedJaggedTensor (KJT)
* make it easier to automate communication for these new KJT variants
* improve QPS when models contain a high number of KJTs

We expanded the KJT interface to include these new calls:
* `dist_labels` -> names of tensors to transmit
* `dist_splits` -> shapes of internal tensors, split by key
* `dist_tensors` -> the actual tensor data to transmit
* `dist_init` -> builds a new KJT from raw collective output

These methods define the data the a2a collective will transmit from the KJT, so the KJT a2a works for any KJT variant that properly implements them.

Next we changed the KJT a2a so that its first awaitable transmits the tensor splits, letting each rank know the size of the tensors it will receive (this collective call is blocking). The second awaitable is asynchronous and transmits the provided tensor data to the correct ranks. Previously there was an explicit blocking awaitable for the lengths, followed by an awaitable for the indices (values).

Finally we enabled input dist fusion: wait() on the first awaitable of each KJT a2a is delayed until all KJTs have been gathered, then the first waits are called in succession (this becomes the CPU-blocking part), and then the second awaitables are called to transmit the actual tensor data completely asynchronously.

Reviewed By: dstaay-fb, RenfeiChen-FB

Differential Revision: D39520093

LaMa Project: L1138451

fbshipit-source-id: b7892a3af4e26c99ac01a88dd2c696e79e89cf42
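
The two-phase exchange and the fusion ordering described in the summary can be illustrated with a small in-process simulation. This is only a sketch under stated assumptions, not the torchrec implementation: `ToyKJT`, `all_to_all`, `chunk`, `exchange_splits` and `exchange_tensors` are hypothetical names, plain Python lists stand in for tensors, the four `dist_*` methods use simplified signatures, nothing is actually blocking or asynchronous, and received data is not reordered, so the example assumes one key per rank.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyKJT:
    """Hypothetical list-based stand-in for a KJT (not the torchrec class)."""
    keys: List[str]
    lengths: List[int]  # key-major: batch_size entries per key
    values: List[int]   # jagged values, concatenated in the same order
    batch_size: int

    def dist_labels(self) -> List[str]:
        return ["lengths", "values"]

    def dist_tensors(self) -> List[List[int]]:
        return [self.lengths, self.values]

    def dist_splits(self, key_splits: List[int]) -> List[List[int]]:
        # key_splits[r] = number of consecutive keys owned by rank r.
        length_splits, value_splits, first_key = [], [], 0
        for n_keys in key_splits:
            lo = first_key * self.batch_size
            hi = (first_key + n_keys) * self.batch_size
            length_splits.append(hi - lo)
            value_splits.append(sum(self.lengths[lo:hi]))
            first_key += n_keys
        return [length_splits, value_splits]

    @staticmethod
    def dist_init(keys, tensors, batch_size) -> "ToyKJT":
        lengths, values = tensors
        return ToyKJT(keys, lengths, values, batch_size)


def all_to_all(send):
    # In-process stand-in for a collective: send[src][dst] -> recv[dst][src].
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


def chunk(data, sizes):
    # Cut a flat list into consecutive pieces of the given sizes.
    out, pos = [], 0
    for size in sizes:
        out.append(data[pos:pos + size])
        pos += size
    return out


def exchange_splits(kjts, key_splits):
    # Phase 1: every rank learns how much of each tensor it will receive.
    world = len(kjts)
    splits = [k.dist_splits(key_splits) for k in kjts]  # [src][tensor][dst]
    send = [[[t[dst] for t in splits[src]] for dst in range(world)]
            for src in range(world)]
    return splits, all_to_all(send)  # recv_sizes[dst][src][tensor]


def exchange_tensors(kjts, key_splits, splits, recv_sizes):
    # Phase 2: move the data itself, then rebuild one ToyKJT per rank.
    world = len(kjts)
    send = []  # send[src][dst][tensor]
    for kjt, per_tensor in zip(kjts, splits):
        pieces = [chunk(t, s) for t, s in zip(kjt.dist_tensors(), per_tensor)]
        send.append([[p[dst] for p in pieces] for dst in range(world)])
    recv = all_to_all(send)  # recv[dst][src][tensor]
    out, first_key = [], 0
    for dst in range(world):
        tensors = []
        for t in range(len(kjts[0].dist_labels())):
            merged = []
            for src in range(world):
                assert len(recv[dst][src][t]) == recv_sizes[dst][src][t]
                merged.extend(recv[dst][src][t])
            tensors.append(merged)
        keys = kjts[0].keys[first_key:first_key + key_splits[dst]]
        first_key += key_splits[dst]
        out.append(ToyKJT.dist_init(keys, tensors, kjts[0].batch_size * world))
    return out


# Two simulated ranks, keys "a" and "b", local batch size 2, one key per rank.
key_splits = [1, 1]
feature_x = [
    ToyKJT(["a", "b"], [1, 2, 0, 1], [10, 11, 12, 13], 2),  # held by rank 0
    ToyKJT(["a", "b"], [2, 0, 1, 1], [20, 21, 22, 23], 2),  # held by rank 1
]
feature_y = [
    ToyKJT(["a", "b"], [1, 0, 0, 2], [1, 2, 3], 2),
    ToyKJT(["a", "b"], [0, 0, 1, 0], [4], 2),
]
# Fusion: run phase 1 for every KJT first, then phase 2 for every KJT.
pending = [exchange_splits(f, key_splits) for f in (feature_x, feature_y)]
outs = [exchange_tensors(f, key_splits, *p)
        for f, p in zip((feature_x, feature_y), pending)]
```

In this toy, `outs[0][0]` is the `ToyKJT` that rank 0 ends up with for `feature_x`: it holds only key "a", with the lengths and values from both ranks concatenated in source-rank order. Splitting the exchange into two functions is what makes the fused ordering possible, since all size exchanges can finish before any data exchange starts.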
1 parent a6b1e00 commit d0d23be

34 files changed, +445 -1479 lines changed

torchrec/distributed/comm_ops.py

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ def alltoall_sequence(
 AlltoAll.
 input_splits (Tensor): input splits.
 output_splits (Tensor): output splits.
-variable_batch_size (bool): whether varibale batch size is enabled
+variable_batch_size (bool): whether variable batch size is enabled
 group (Optional[dist.ProcessGroup]): The process group to work on. If None, the
 default process group will be used.
 codecs: Optional[QuantizedCommCodecs]: Quantized communication codecs