pin update #8908
Conversation
I was able to compile by adding the following patch to OpenXLA:
Thank you @ysiraichi! I added this patch for now.
The persistent cache test is failing on GPU due to a deserialization issue. I'm skipping the test for now and will file a GitHub issue for it.
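A minimal sketch of what such a temporary skip could look like; the class name, test body, and exact skip condition here are assumptions for illustration, not the PR's actual diff:

```python
# Hypothetical sketch of a temporary GPU skip for the persistent cache
# test; the class name, condition, and message are assumptions.
import unittest

import torch_xla.runtime as xr


class PersistentCacheTest(unittest.TestCase):

  @unittest.skipIf(
      xr.device_type() == 'CUDA',
      'Deserialization fails on GPU; skipped pending a tracking issue')
  def test_persistent_cache(self):
    # Real test body lives in the repo; elided here.
    pass
```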
Thanks for the amazing work. It's a really huge change to adopt in one PR. LGTM.
Thanks @lsy323 for updating the pin. Regarding the paged_attention hang, could you update this line (`xla/torch_xla/experimental/custom_kernel.py`, line 1212 in c044c69) to `step = torch.ones((1,), dtype=torch.int32).to("xla")`? It should make the test pass. I tested locally.
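For context, a minimal runnable snippet around the suggested line; only the `step` assignment itself comes from this comment, the rest is scaffolding:

```python
# Minimal context for the suggested fix; only the `step` line is from
# the comment above, the rest is assumed scaffolding to make it runnable.
import torch
import torch_xla  # registers the "xla" device with PyTorch

step = torch.ones((1,), dtype=torch.int32).to("xla")
print(step.device)  # e.g. xla:0
```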
Thanks @vanbasten23! Updated the PR. Also, do you mind elaborating a bit on this?
“#8908 accidentally enabled some pallas tests on CPU, which is not supported.”
Yeah, jax-ml/jax@8c73799 made a change (it's not a bug but a valid change). As a result, the torch_xla wrapper needs to change accordingly.
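A sketch of how such tests are commonly kept off unsupported backends; this module-level guard is an assumed pattern for illustration, not the wrapper change the PR actually made:

```python
# Hypothetical module-level guard keeping Pallas kernel tests off CPU;
# the message and exit behavior are assumptions for illustration.
import sys

import torch_xla.runtime as xr

if xr.device_type() != 'TPU':
  print('Pallas custom kernels require TPU; skipping these tests.')
  sys.exit(0)
```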
Accommodate the following changes:
- `xla::Shape::rank()` is renamed to `xla::Shape::dimensions_size()`
- changes to the `xla::Shape` ctor