[torch-xla 2.1 - 2.4] When functionalization is on, there is no aliasing for gradients when using gradient accumulation #7174

@jeffhataws

Description

🐛 Bug

When functionalization is on (XLA_DISABLE_FUNCTIONALIZATION=0), I see fewer aliased tensors. Jack has a patch that increases the number of aliased tensors (e3fc033). However, even with this change, aliasing still appears to be missing for gradients when gradient accumulation is used.

Using test_train_mp_mnist.py, make the modifications below; a self-contained sketch of the same pattern follows the diff. I added a mark_step to isolate the gradient-accumulation loop.

@@ -158,16 +163,19 @@ def train_mnist(flags, **kwargs):
       output = model(data)
       loss = loss_fn(output, target)
       loss.backward()
-      if flags.ddp:
-        optimizer.step()
-      else:
-        xm.optimizer_step(optimizer)
-      tracker.add(flags.batch_size)
-      if step % flags.log_steps == 0:
-        xm.add_step_closure(
-            _train_update,
-            args=(device, step, loss, tracker, epoch, writer),
-            run_async=flags.async_closures)
+
+      if step % 4 == 0:
+          xm.mark_step()
+          if flags.ddp:
+            optimizer.step()
+          else:
+            xm.optimizer_step(optimizer)
+          tracker.add(flags.batch_size)
+          if step % flags.log_steps == 0:
+            xm.add_step_closure(
+                _train_update,
+                args=(device, step, loss, tracker, epoch, writer),
+                run_async=flags.async_closures)

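For quick experiments outside the MNIST script, the same accumulation pattern can be written as a self-contained sketch (hypothetical toy model and optimizer, not part of the reproducer; it only mirrors the structure of the diff above):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(8):
    data = torch.randn(4, 8, device=device)
    loss = model(data).sum()
    loss.backward()                   # accumulates into each param's .grad in place
    if step % 4 == 0:                 # same hard-coded interval as the diff above
        xm.mark_step()                # cut the accumulated graph here
        xm.optimizer_step(optimizer)  # apply the accumulated grads (marks a step)
        optimizer.zero_grad()
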
I only see 2 aliases even though we expect all the gradient tensors to be aliased:

2024-06-03 21:15:37.676472: I torch_xla/csrc/xla_graph_executor.cpp:1462] Parameter sequence graph hash b8e15ed0391b82171706a34d84ca8ea0
2024-06-03 21:15:37.678822: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 13 with output 4: s64[]
2024-06-03 21:15:37.678862: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 14 with output 5: s64[]
2024-06-03 21:15:37.679222: I torch_xla/csrc/xla_graph_executor.cpp:1397] Compiling IR graph hash b8e15ed0391b82171706a34d84ca8ea0 on device CPU:0 ...

To Reproduce

Steps to reproduce the behavior:

  1. Check out the r2.1_aws_neuron branch
  2. Apply Jack's patch e3fc033
  3. Build/install as described in the contribution doc
  4. Go into xla/test
  5. Edit test_train_mp_mnist.py and add the gradient-accumulation loop as above
  6. Run with TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" to see the aliasing debug logs (a sketch for counting the aliased lines follows the command):
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"   TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" python test_train_mp_mnist.py |& tee log
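
A quick way to count the aliased parameters across the whole run is to scan the captured log; a small sketch, assuming the output was tee'd to a file named log as above:

# Count "Aliased paramter ..." lines; the spelling matches the message emitted
# by xla_graph_executor.cpp in this build.
with open("log") as f:
    aliased = [line for line in f if "Aliased paramter" in line]
print(f"{len(aliased)} aliased parameter(s)")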

Expected behavior

Expect all gradient tensors to be aliased, since each parameter's .grad buffer is read and updated in place across the accumulation boundary and should therefore be eligible for input/output buffer aliasing.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.1

Additional context
