Dynamo backend for torchxla2
Goal
Have a dynamo backend backed by torch_xla2.
Users should be able to do the following:
```python
import torch

m = ...  # any torch.nn.Module
m_compiled = torch.compile(m, backend='torch_xla2_compile')  # backend name TBD
result = m_compiled(*inputs)
```
The above should run on TPU with low overhead.
Challenge
Usually the challenge of a dynamo backend is the compiler that transforms an fx graph of torch (or ATen) ops into a compiled executable. In our case, however, that piece is solved. For every `call_function` node, we look up the Jax implementation of that ATen op in a dictionary and simply call it.
This is illustrated here: https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/torch_xla2/export.py#L23
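The following is a minimal sketch of that lookup-and-call loop. It is a toy and not the torch_xla2 code. A hypothetical `Node` dataclass stands in for `torch.fx.Node`. Plain Python callables in a hypothetical `OP_TABLE` stand in for the Jax implementations keyed by ATen op.

```python
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for torch.fx.Node (only three op kinds modeled)."""
    name: str
    op: str  # "placeholder", "call_function" or "output"
    target: Optional[str] = None
    args: Tuple[Any, ...] = ()


# Hypothetical op table: in torch_xla2 the keys would be ATen ops and the values
# their Jax implementations; plain Python callables stand in for both here.
OP_TABLE: Dict[str, Callable[..., Any]] = {
    "aten.add": operator.add,
    "aten.mul": operator.mul,
}


def run_graph(nodes: List[Node], *inputs: Any) -> Any:
    env: Dict[str, Any] = {}
    remaining = iter(inputs)

    def load(arg: Any) -> Any:
        # References to earlier nodes resolve to their values; constants pass through.
        return env[arg.name] if isinstance(arg, Node) else arg

    for node in nodes:
        if node.op == "placeholder":
            env[node.name] = next(remaining)
        elif node.op == "call_function":
            impl = OP_TABLE[node.target]  # look up the implementation...
            env[node.name] = impl(*(load(a) for a in node.args))  # ...and call it
        elif node.op == "output":
            return load(node.args[0])
    raise ValueError("graph has no output node")
```

A graph computing `(x + y) * y` with inputs 2 and 3 evaluates to 15. The interpreter never needs to know what an op does. It only needs the table entry.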
Now, the challenge is for dynamo to 1. produce the graph, and 2. not incur any data copies in the process.
Consider the following pseudocode:
```python
import jax
import torch

class TensorSubclass(torch.Tensor):  # e.g. XLATensor2
    _data: jax.Array

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # do stuff with _data, get new_data
        return TensorSubclass(new_data)

def dynamo_backend(fx_graph, sample_inputs):
    compiled = ...  # compile fx_graph into a function that manipulates jax.Array
    def returned_callable(*inputs):
        datas = [i._data for i in inputs]
        res = compiled(*datas)
        return TensorSubclass(res)
    return returned_callable

model = torch.compile(model, backend=dynamo_backend)
inputs = ...  # a list of TensorSubclass, or a list of torch.Tensor?
model(*inputs)
```
What would be the type of `inputs`?
If `inputs` are of type `TensorSubclass`, then dynamo will attempt to trace through the `__torch_dispatch__` method. It will then throw an error because it doesn't know what `_data` is or what the operations on it are.
If `inputs` are of type `torch.Tensor`, then it works: dynamo calls the backend, and the backend can produce the correct result. But `inputs` need to be converted to `TensorSubclass` first inside the backend, which usually means a data copy. This happens every time the compiled callable is executed, so it is not desirable.
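A toy model makes the cost of the second case concrete. Python lists stand in for `jax.Array`, and `TensorSubclass` only wraps a buffer in `_data`. The names `to_subclass`, `make_backend` and the `stats` counter are hypothetical and belong to this sketch only. Inputs that already arrive as `TensorSubclass` are unwrapped without copying. Plain inputs pay one copy each, on every call.

```python
from typing import Callable, Dict, List

stats: Dict[str, int] = {"copies": 0}  # hypothetical copy counter


class TensorSubclass:
    """Toy stand-in for XLATensor2: only wraps a backing buffer in `_data`."""

    def __init__(self, data: List[float]):
        self._data = data


def to_subclass(plain: List[float]) -> TensorSubclass:
    # Converting a plain tensor copies its storage into a new buffer
    # (modeled with list()); in the real backend this would happen per call.
    stats["copies"] += 1
    return TensorSubclass(list(plain))


def make_backend(compiled: Callable[..., List[float]]) -> Callable[..., TensorSubclass]:
    def returned_callable(*inputs):
        wrapped = [i if isinstance(i, TensorSubclass) else to_subclass(i) for i in inputs]
        datas = [w._data for w in wrapped]  # unwrapping itself copies nothing
        return TensorSubclass(compiled(*datas))
    return returned_callable
```

Calling the returned callable with `TensorSubclass` inputs leaves `stats["copies"]` unchanged. Calling it with plain lists increases the counter by one per input, on every invocation.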
The desired behavior
When tracing, dynamo should treat `TensorSubclass` as if it were a regular tensor without a dispatch override. When executing the compiled callable, `TensorSubclass` should be passed in as-is. We know that dynamo can do this with some tensor subclasses, namely `FakeTensor`.
Let's list out the possible ways we could accomplish this behavior.
Option 1. Hold the jax.Array object in C++
Roughly, we would have a `Tensor` subclass in C++. This is very similar to the current `XLATensor`, which is a `LazyTensor` subclass.
This tensor can hold its own state in C++. In our case, that would be a `PyObject*` that points to either a `jnp.ndarray` or, during `jax.jit`, Jax's `Traced<ShapedArray>`. We might further reuse the `XLA` dispatch key to route the operators to the Jax implementation, emulating what `__torch_dispatch__` does.
This way, eager mode will continue to work. Dynamo would also work, because the Python class is still `torch.Tensor` (not a subclass) and there is no Python dispatch logic for dynamo to trace into.
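Below is a toy model of the routing in Option 1. It assumes a simplified dispatcher in which each tensor carries a set of dispatch keys and kernels live in a table keyed by (op, key). The highest-priority key present among the arguments wins. `Tensor`, `KERNELS`, `register`, `KEY_PRIORITY` and `dispatch` are all hypothetical, and the real PyTorch dispatcher lives in C++ and is far more involved. The point is that the Python-visible class has no `__torch_dispatch__`, so routing happens by key and there is nothing for dynamo to trace into.

```python
from typing import Any, Callable, Dict, Tuple


class Tensor:
    """Toy C++-backed tensor: no __torch_dispatch__; `payload` stands in for
    the PyObject* (a jnp.ndarray or a tracer) held on the C++ side."""

    def __init__(self, payload: Any, keys: Tuple[str, ...]):
        self.payload = payload
        self.keys = keys


# (op name, dispatch key) -> kernel: a toy version of the dispatcher's table.
KERNELS: Dict[Tuple[str, str], Callable[..., Tensor]] = {}


def register(op: str, key: str) -> Callable:
    def decorator(fn: Callable[..., Tensor]) -> Callable[..., Tensor]:
        KERNELS[(op, key)] = fn
        return fn
    return decorator


@register("add", "CPU")
def add_cpu(a: Tensor, b: Tensor) -> Tensor:
    return Tensor([x + y for x, y in zip(a.payload, b.payload)], ("CPU",))


@register("add", "XLA")
def add_xla(a: Tensor, b: Tensor) -> Tensor:
    # Stand-in for calling the Jax implementation on the held arrays.
    return Tensor([x + y for x, y in zip(a.payload, b.payload)], ("XLA",))


KEY_PRIORITY = ("XLA", "CPU")  # hypothetical: higher-priority keys first


def dispatch(op: str, *args: Tensor) -> Tensor:
    for key in KEY_PRIORITY:
        if any(key in arg.keys for arg in args) and (op, key) in KERNELS:
            return KERNELS[(op, key)](*args)
    raise NotImplementedError(f"no kernel for {op}")
```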
Pros:
- Very clear that this will work.
Cons:
- We now need to deal with C++ builds. In particular, torch becomes a source dependency instead of a pip dependency. This means we would again need to build torch first and then build torch_xla2. This might be mitigated if that subclass can be upstreamed.
Option 2. Modify dynamo to implement the desired behavior
We have one instance where a `torch.Tensor` dispatch subclass just works with dynamo, without dynamo complaining when it traces `__torch_dispatch__`. This is `FakeTensor` (https://github.com/pytorch/pytorch/pull/100017/files).
The idea is to make dynamo trace as if the inputs are `FakeTensor` and not `XLATensor`. Only after the fx graph and backend have been created would dynamo call the compiled callable with `XLATensor`.
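The execution model Option 2 asks for can be sketched in plain Python. The sketch traces once with data-free stand-ins that play the role of `FakeTensor`. It then runs the resulting graph on the real subclass instances, which are passed in unchanged. `XLATensorToy`, `FakeToy` and `compile_toy` are hypothetical names, and only `+` is modeled. The sketch says nothing about how dynamo itself would be modified.

```python
from typing import Any, Callable, Dict, List, Tuple

Graph = List[Tuple[str, str, Tuple[str, str]]]


class XLATensorToy:
    """Hypothetical stand-in for XLATensor2: wraps a buffer in `_data`."""

    def __init__(self, data: List[float]):
        self._data = data

    def __add__(self, other: "XLATensorToy") -> "XLATensorToy":
        return XLATensorToy([x + y for x, y in zip(self._data, other._data)])


class FakeToy:
    """Stand-in for FakeTensor: holds no data, only records ops into a graph."""

    def __init__(self, graph: Graph, name: str):
        self.graph = graph
        self.name = name

    def __add__(self, other: "FakeToy") -> "FakeToy":
        out = FakeToy(self.graph, f"v{len(self.graph)}")
        self.graph.append((out.name, "add", (self.name, other.name)))
        return out


def compile_toy(fn: Callable[..., Any], num_inputs: int) -> Tuple[Callable[..., Any], Graph]:
    graph: Graph = []
    fakes = [FakeToy(graph, f"arg{i}") for i in range(num_inputs)]
    result = fn(*fakes)  # trace once, with fake inputs only

    def compiled(*real_inputs: Any) -> Any:
        # Real subclass instances are passed in as-is: no conversion, no copy.
        env: Dict[str, Any] = {f"arg{i}": x for i, x in enumerate(real_inputs)}
        for out, _op, (lhs, rhs) in graph:  # only "add" is recorded here
            env[out] = env[lhs] + env[rhs]
        return env[result.name]

    return compiled, graph
```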
Pros:
- Likely pure Python changes.
Cons:
- We also need to design a mechanism to mark which tensor subclasses dynamo should trace through and which it should not.
- Likely a significant amount of work.
Option 3. Register all the ops as custom ops
Currently, dynamo traces `__torch_dispatch__`, and we don't like that because it finds operations on Jax arrays, which it doesn't understand.
What if we make dynamo able to understand what is inside?
The Black box python functions doc points out the possibility of registering things that we don't want dynamo to go into as custom ops. So we could, theoretically, do the following:
- Register the Jax impl of an ATen op as a custom op, i.e. register `jaten.add` for `aten.add`.
- For meta kernels, just call the meta kernel of `aten.add`.
- In `__torch_dispatch__`, forward the call from `aten.add` to `jaten.add`.
When dynamo attempts to go inside of `__torch_dispatch__`, it will find `jaten.add` and record it in the `fx.Graph`.
Our backend will see the same ops but in a different namespace (`jaten`). That is fine as long as we know how to look up their implementations.
Note: we probably also need to hook up gradients of custom ops via `autograd.Function`.
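The forwarding in Option 3 can be sketched as a toy registry. The sketch assumes each custom op carries a real implementation and a meta kernel. In recent PyTorch releases this role would be played by `torch.library` custom ops (for example `torch.library.custom_op` with `register_fake`), which the sketch does not use. `CUSTOM_OPS`, `register_custom_op`, `FORWARD` and `torch_dispatch` are hypothetical. Lists stand in for arrays, and the meta kernel only propagates a length.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical registry of opaque custom ops: name -> (real impl, meta kernel).
CUSTOM_OPS: Dict[str, Tuple[Callable, Callable]] = {}


def register_custom_op(name: str, impl: Callable, meta: Callable) -> None:
    CUSTOM_OPS[name] = (impl, meta)


# jaten.add: the "Jax" impl (plain Python here) plus a meta kernel that, like
# the meta kernel of aten.add, only computes output metadata (here, a length).
register_custom_op(
    "jaten.add",
    impl=lambda a, b: [x + y for x, y in zip(a, b)],
    meta=lambda a_len, b_len: a_len,
)

# Toy forwarding table used by the toy __torch_dispatch__ below.
FORWARD: Dict[str, str] = {"aten.add": "jaten.add"}


def torch_dispatch(op: str, *args: List[float], graph: Optional[List[str]] = None):
    target = FORWARD[op]
    impl, meta = CUSTOM_OPS[target]
    if graph is not None:
        # Tracing: the custom op is a black box, so record its name and
        # run only the meta kernel on the argument metadata.
        graph.append(target)
        return meta(*(len(a) for a in args))
    # Eager execution: run the real implementation.
    return impl(*args)
```

During tracing the recorded graph contains `jaten.add` rather than the Jax operations inside it. A backend would then look `jaten.add` up in the same registry.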
Pros / Cons:
We haven't tried it, so we don't know whether it will work.
Current standing proposal
The current standing proposal is Option 2.
Meeting notes (so far):
2024-05-29
with @ezyang @williamwen42 @wconstab @JackCaoG @shauheen @Chillee @yanboliang
Went over the 3 options. Opinions were split between options 1 and 2. People seem to agree that making 1 work is desirable, and that the work done to make it work is considered the "cost of integration".
People also discussed whether interoperability with Jax should be a valid use case, i.e.:
```python
def f(jax_array_1, jax_array_2):
    # wrap jax_array_1 and jax_array_2 into XLATensor2
    # call torch
    # return the unwrapped result
    ...
```
This is a valid Jax function. Can it be used with `jax.grad` or `jax.jit`?
@Chillee raised the point that if we use `jax.grad` to get the gradient and train a model, it might yield different behavior if the user has custom backward hooks in their code.
@williamwen42 suggested using this option for torchdynamo:
```python
import torch._dynamo
import torch_xla2.tensor

torch._dynamo.config.traceable_tensor_subclasses.add(
    torch_xla2.tensor.XLATensor2)
```
With this suggestion, dynamo tracing succeeded and called the backend with the correct tensor subclass and graph (the desired behavior).
However, it raised an error when the backend attempted to construct an `XLATensor2` for the return value.
Details on the script that was run: https://gist.github.com/qihqi/aa4fd50e5ef3cb96598433bd0f62817c