[ONNX] Support custom axis name through dynamic_shapes #146321
Conversation
CI status: no failures, 3 pending as of commit d59f8b8 with merge base e3839bd. See artifacts and rendered test results at hud.pytorch.org/pr/146321.
```python
    return False


def convert_str_to_export_dim(
```
Possible to have unit tests for this function?
Done
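As a rough illustration of what such a unit test could exercise, here is a hypothetical sketch of `convert_str_to_export_dim` — not the actual PyTorch implementation. The assumed behavior follows point 2 of the PR description (string axis names are swapped for `Dim.AUTO`, with the custom names recorded for renaming afterward); `AUTO` below is a stand-in sentinel for `torch.export.Dim.AUTO`:

```python
# Stand-in sentinel for torch.export.Dim.AUTO (assumption for this sketch).
AUTO = object()


def convert_str_to_export_dim(axes):
    """Replace string entries in one input's axes spec (a dict of
    axis -> spec, or a tuple/list of per-axis specs) with AUTO.
    Returns the converted spec and a map of axis index -> custom name."""
    names = {}
    if isinstance(axes, dict):
        converted = {}
        for idx, spec in axes.items():
            if isinstance(spec, str):
                names[idx] = spec
                converted[idx] = AUTO
            else:
                converted[idx] = spec
        return converted, names
    if isinstance(axes, (tuple, list)):
        converted = []
        for idx, spec in enumerate(axes):
            if isinstance(spec, str):
                names[idx] = spec
                converted.append(AUTO)
            else:
                converted.append(spec)
        return type(axes)(converted), names
    return axes, names


# A unit test would assert the string is replaced and its name recorded:
spec, names = convert_str_to_export_dim({0: "batch", 1: None})
```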
```python
)


def create_rename_mapping(
```
Possible to have unit tests for this function?
This involves creating ir.Inputs, and it's only about mapping. Unless you have some corner cases in mind that we should keep an eye on? If so, I can make that a test.
```python
# When the axis is static, or it connects to _DimHint in dynamic shapes, we skip renaming
for idx, axes in enumerate(flat_dynamic_shapes):
    input = inputs[idx]
    if isinstance(axes, dict):
```
What about nested list or dictionary?
The lists and dicts here are leaves that are kept un-flattened in line 227; please check the function _flatten_dynamic_shapes_to_axes. That is, at this point, there will not be any nested lists or dicts.
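The flattening behavior described in this reply can be sketched in pure Python — an illustrative reconstruction (not the real `_flatten_dynamic_shapes_to_axes` helper), assuming a leaf is an int-keyed dict, a flat tuple/list of per-axis specs, or `None`:

```python
def _is_axes_leaf(obj):
    # A dict keyed by int axis indices, or a flat tuple/list containing no
    # nested containers, counts as one input's axes spec (assumption).
    if isinstance(obj, dict):
        return all(isinstance(k, int) for k in obj)
    if isinstance(obj, (tuple, list)):
        return not any(isinstance(v, (dict, tuple, list)) for v in obj)
    return obj is None


def _flatten_dynamic_shapes_to_axes(dynamic_shapes):
    """Walk the nested dynamic_shapes structure, collecting the per-input
    leaves intact — so downstream code never sees a nested list/dict."""
    leaves = []

    def walk(obj):
        if _is_axes_leaf(obj):
            leaves.append(obj)
        elif isinstance(obj, dict):
            for v in obj.values():
                walk(v)
        elif isinstance(obj, (tuple, list)):
            for v in obj:
                walk(v)
        else:
            leaves.append(obj)

    walk(dynamic_shapes)
    return leaves


flat = _flatten_dynamic_shapes_to_axes(
    ({0: "dim_x"}, [("ys0",), ("ys1",)], {"a": {0: "a0"}})
)
```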
```python
isinstance(x, dict)
and all(
    isinstance(k, int)
    and (v is None or isinstance(v, (_Dim, _DimHint, str)))
```
DimDerived?
It's covered in pytorch/torch/export/dynamic_shapes.py (line 142 at 206ad9f): `class _DerivedDim(_Dim):`. Or is there another instance you are referring to?
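The point of the reply — that `_DerivedDim` subclasses `_Dim`, so the `isinstance(v, _Dim)` check in the predicate above already covers it — can be demonstrated with stand-in classes (the real ones live in `torch.export.dynamic_shapes`; these are empty assumptions for illustration):

```python
# Stand-ins for the real torch.export.dynamic_shapes classes (assumption).
class _Dim: ...
class _DimHint: ...
class _DerivedDim(_Dim): ...  # subclasses _Dim, as in the linked source


def is_int_keyed_axes_dict(x):
    """The predicate from the diff: an int-keyed dict whose values are
    None, a _Dim (including _DerivedDim), a _DimHint, or a string."""
    return isinstance(x, dict) and all(
        isinstance(k, int) and (v is None or isinstance(v, (_Dim, _DimHint, str)))
        for k, v in x.items()
    )


ok = is_int_keyed_axes_dict({0: _DerivedDim(), 1: "batch", 2: None})
bad = is_int_keyed_axes_dict({"a": None})  # non-int key fails the check
```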
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

Successfully rebased befd357 to d59f8b8.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge -f "ROCM tests queued but never runs and they are unrelated to the PR."

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #143443

This PR adds support for custom dynamic axis naming through dynamic_shapes. Currently, _Dim and _DimHint do not support dynamic axis naming (#144273).

1. **Original dynamic_shapes guarantee.** Axis renaming is applied only when the dynamic shapes include strings, rather than exclusively _Dim and _DimHint. Thus, there is no behavior inconsistent with torch.export.export as long as the given dynamic shapes follow the torch.export.export format.
2. _DimHint.AUTO is applied to the axes that are specified with custom names, to avoid an exporter crash (_DimHint.DYNAMIC crashes when the export fails).
3. There is no need to handle cases where kwargs are out of order with the model signature, as torch.export.export supports dynamism only when kwargs and dynamic_shapes are provided in order: https://github.com/pytorch/pytorch/blob/49082f9dba3b79a344cb03652972ddbe7c3729cc/torch/export/_trace.py#L2034
4. If `torch.onnx.ExportedProgram` finds that axes share the same constraints, they are given the same name (e.g. s0, s1, ...). Therefore, even if ONNX users specify such axes with different custom names, those names will not be respected.

Example model:

```python
class NestedModel(torch.nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        ys: list[torch.Tensor],
        zs: dict[str, torch.Tensor],
        c: torch.Tensor,
    ):
        y = ys[0] + ys[1] + zs["a"] + zs["b"]
        w = 5
        if x.shape[0] < 3 and c.shape[0] != 4:
            return x + w, x + y, c
        else:
            return x - w, x - y, c


input = (
    torch.ones(5),
    [torch.zeros(5), torch.ones(5)],
    {"a": torch.zeros(5), "b": torch.ones(5)},
    torch.ones(6),
)
dynamic_shapes = (
    {0: torch.export.Dim("dim_x", min=3)},  # _Dim
    [("custom_name_axis_ys_0",), (torch.export.Dim.AUTO,)],  # custom name
    {
        "a": {0: torch.export.Dim.AUTO},
        "b": ("custom_name_axis_zs_b_0",),
    },  # _DimHint
    {0: "custom_name_axis_c_0"},  # custom name
)
```

Pull Request resolved: #146321
Approved by: https://github.com/justinchuby
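Point 4 above can be illustrated with a toy, pure-Python model of the renaming step (no ONNX involved; the data layout and helper below are inventions for illustration, not the exporter's API). Axes the exporter resolved to the same symbol (e.g. "s0") share one name, so only one custom name per symbol can take effect:

```python
def apply_rename_mapping(input_shapes, requested_names):
    """Toy renamer: input_shapes maps input name -> list of dims, where a
    dim is an int (static) or a symbol string (e.g. "s0"); requested_names
    maps (input name, axis index) -> custom name. The first custom name
    requested for a symbol wins; later requests for the same symbol are
    ignored, mirroring point 4 of the description."""
    claimed = {}  # symbol -> winning custom name
    for (inp, axis), name in requested_names.items():
        sym = input_shapes[inp][axis]
        if isinstance(sym, str):
            claimed.setdefault(sym, name)
    return {
        inp: [claimed.get(d, d) if isinstance(d, str) else d for d in dims]
        for inp, dims in input_shapes.items()
    }


# x's axis 0 and c's axis 0 share the symbol "s0": "batch" wins for both,
# and the second custom name is not respected.
renamed = apply_rename_mapping(
    {"x": ["s0", 4], "c": ["s0"]},
    {("x", 0): "batch", ("c", 0): "custom_name_axis_c_0"},
)
```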
## Describe your changes

With pytorch/pytorch#146321, `torch.onnx.export(..., dynamo=True)` can now accept strings in dynamic_shapes, which fits better with Olive's configuration-driven workflow.

Major changes:
- Add support for strings in dynamic_shapes.
- Move dynamic_shapes pre-processing to io_config.py (like dynamic_axes).
- Remove the post-processing that turned [str, int, int] into torch.export.Dim(str, max=int, min=int); `torch.onnx.export(..., dynamo=True)` can now take strings directly.
- Leverage Optimum to auto-generate dynamic_shapes when Optimum models are requested.
- KV cache support.

Pitfall:
- When dynamic_shapes targets kwargs, both must follow the order of the model.forward signature. onnx/conversion.py provides a naive approach to sort them, but users should be aware of this.

Minor changes:
- Move onnxscript (released) to the official requirements.txt.
- dynamic_shapes with strings is supported since torch 2.7.

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [x] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link