Tensor Indexing API

Indexing a tensor in the PyTorch C++ API works very similarly to the Python API. All index types, such as None / ... / integer / boolean / slice / tensor, are available in the C++ API, which makes translating Python indexing code to C++ straightforward. The main difference is that the C++ API does not use the [] operator as Python does; instead, it provides two indexing methods:

  • torch::Tensor::index (link)
  • torch::Tensor::index_put_ (link)

Note that index types such as None / Ellipsis / Slice live in the torch::indexing namespace. It is recommended to put using namespace torch::indexing before any indexing code so that these index types can be used without qualification.

Here are some examples of translating Python indexing code to C++:

Getter

Python                                            C++ (assuming using namespace torch::indexing)
tensor[None]                                      tensor.index({None})
tensor[Ellipsis, ...]                             tensor.index({Ellipsis, "..."})
tensor[1, 2]                                      tensor.index({1, 2})
tensor[True, False]                               tensor.index({true, false})
tensor[1::2]                                      tensor.index({Slice(1, None, 2)})
tensor[torch.tensor([1, 2])]                      tensor.index({torch::tensor({1, 2})})
tensor[..., 0, True, 1::2, torch.tensor([1, 2])]  tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})

Setter

Python                                                C++ (assuming using namespace torch::indexing)
tensor[None] = 1                                      tensor.index_put_({None}, 1)
tensor[Ellipsis, ...] = 1                             tensor.index_put_({Ellipsis, "..."}, 1)
tensor[1, 2] = 1                                      tensor.index_put_({1, 2}, 1)
tensor[True, False] = 1                               tensor.index_put_({true, false}, 1)
tensor[1::2] = 1                                      tensor.index_put_({Slice(1, None, 2)}, 1)
tensor[torch.tensor([1, 2])] = 1                      tensor.index_put_({torch::tensor({1, 2})}, 1)
tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1  tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)

Translating between Python/C++ index types

The one-to-one translation between Python and C++ index types is as follows:

Python                C++ (assuming using namespace torch::indexing)
None                  None
Ellipsis              Ellipsis
...                   "..."
123                   123
True                  true
False                 false
: or ::               Slice() or Slice(None, None) or Slice(None, None, None)
1: or 1::             Slice(1, None) or Slice(1, None, None)
:3 or :3:             Slice(None, 3) or Slice(None, 3, None)
::2                   Slice(None, None, 2)
1:3                   Slice(1, 3)
1::2                  Slice(1, None, 2)
:3:2                  Slice(None, 3, 2)
1:3:2                 Slice(1, 3, 2)
torch.tensor([1, 2])  torch::tensor({1, 2})