Enable fixed fast_mode for complex #55699
Conversation
Dr. CI: As of commit 04831f0, there are no CI failures.
Todo: the error message should be updated to say whether the failure is for fn's real or imaginary component.
The logic looks OK to me.
I wonder if there would be a way to have `u` as an object, so that you don't need to add a new argument to every function. That object would contain one or two real-valued vectors, and you would unpack it properly depending on whether you're considering a complex input when you need it. What do you think?
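The suggested container could look something like the following minimal sketch. The class and field names (`PerturbationVectors`, `real`, `imag`) are illustrative assumptions, not names from the PR:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PerturbationVectors:
    # Hypothetical container for the direction vector(s) used in fast gradcheck.
    # `real` is the vector for the real component; `imag` is only present when
    # the corresponding input is complex.
    real: List[float]
    imag: Optional[List[float]] = None

    def unpack(self):
        # Complex-aware callers get both parts; real-only callers get one.
        return (self.real,) if self.imag is None else (self.real, self.imag)

u_real = PerturbationVectors(real=[1.0, 0.5])
u_cplx = PerturbationVectors(real=[1.0, 0.5], imag=[0.25, 0.75])
print(len(u_real.unpack()), len(u_cplx.unpack()))  # 1 2
```

With a wrapper like this, functions keep a single `u` parameter and only the code paths that handle complex inputs need to look at the second vector.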
torch/autograd/gradcheck.py (outdated)
@@ -201,7 +201,7 @@ def compute_numerical_jacobian_cols(jvp_fn, delta, input_is_complex) -> List[tor
     ds_dx_tup = jvp_fn(delta)

     if input_is_complex:  # C -> R
-        ds_dy_tup = jvp_fn(delta * 1j)
+        ds_dy_tup = jvp_fn(delta * 1j) if delta_i is None else jvp_fn(delta_i * 1j)
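The idea of perturbing a complex input once along the real axis and once along the imaginary axis (the `delta * 1j` call above) can be sketched in plain Python with a scalar C -> R function. `numerical_jvp` here is a stand-in for the PR's `jvp_fn`, not its actual implementation:

```python
def numerical_jvp(fn, x, delta):
    # Central-difference directional derivative of fn at x along delta.
    eps = 1e-6
    return (fn(x + eps * delta) - fn(x - eps * delta)) / (2 * eps)

def f(z):
    # Example C -> R function: f(z) = |z|^2 = Re(z)^2 + Im(z)^2
    return (z * z.conjugate()).real

x = 2.0 + 1.0j
delta = 1.0
ds_dx = numerical_jvp(f, x, delta)        # perturb the real part: 2*Re(x) = 4
ds_dy = numerical_jvp(f, x, delta * 1j)   # perturb the imaginary part: 2*Im(x) = 2
```

Multiplying the perturbation by `1j` turns the same step size into a step along the imaginary axis, which is exactly what the `C -> R` branch needs.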
If the input is complex, delta_i should never be None, right?
Or does the slow version use the same delta for both?
The slow version does not use delta_i. But I've now changed it to use a single delta parameter, which may be a tuple, a tensor, or a Python number. It will always be a Python number in the slow case, and a tuple or tensor in the fast case depending on whether that input is complex.
Having a single object …
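The unpacking described above could be sketched as follows. The helper name `normalize_delta` and the exact branch structure are assumptions for illustration, not the PR's code:

```python
def normalize_delta(delta):
    # Slow mode passes a plain Python number, reused for both components.
    # Fast mode passes a single vector for a real input, or a
    # (real, imag) pair for a complex one.
    if isinstance(delta, tuple):
        return delta                 # complex input in fast mode
    return delta, delta              # same step for real and imaginary parts

print(normalize_delta(1e-6))            # (1e-06, 1e-06)
print(normalize_delta(([1.0], [0.5])))  # ([1.0], [0.5])
```

Callers then always receive a `(delta_re, delta_im)` pair, regardless of which mode produced the original `delta`.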
LGTM!
Just a couple of comments to address and it is good!
else:
    if u.layout != torch.sparse_coo:
        return u.reshape(shape)
    return u
nit: add a comment here saying that we don't need to reshape for sparse Tensors.
@@ -815,6 +843,7 @@ def adjusted_atol(atol, u, v):
     #
     # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N matrix):
     # v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{j} v_j)
+    u = u[0] if isinstance(u, tuple) else u
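The identity in the comment above — that v^T M u factors into the product of the component sums when M is all-ones — is easy to check numerically with a self-contained sketch (not the PR's code):

```python
def vT_M_u(u, v):
    # v^T M u with M the all-ones matrix: sum of v_j * u_i over every (i, j).
    return sum(vj * ui for vj in v for ui in u)

u = [0.5, 1.5, -1.0]
v = [2.0, 0.25]
lhs = vT_M_u(u, v)       # direct double sum
rhs = sum(u) * sum(v)    # the factored form from the comment
print(lhs, rhs)          # both 2.25
```

This is the scaling that `adjusted_atol` applies so that a per-entry tolerance remains meaningful after projecting the Jacobian onto the directions u and v.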
Maybe a comment explaining why we ignore the second part of u? Or a todo if that needs to be done later.
Yeah, I still need to think about how to handle the complex case here. I've moved the other TODO you mentioned to here.
@soulitzer merged this pull request in 201ad93.
Summary:
Pull Request resolved: pytorch#55699

Todo:
- error message should be updated to say whether the failure is for fn's real or imaginary component

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D28007887

Pulled By: soulitzer

fbshipit-source-id: 1819201f59c8586a1d9631db05983969438bde66