How to perform an expand operation in PyTorch?
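The Tensor.expand() method performs an expand operation. It returns a new view of the tensor in which singleton dimensions (dimensions of size 1) are expanded to a larger size. It can also add new dimensions at the front of the shape.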
Expanding a tensor only creates a new view of the original tensor; it does not copy the data. Along each expanded dimension the view has a stride of 0, so every position in that dimension reads the same memory.
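To see that expand() returns a view rather than a copy, the sketch below compares storage pointers and strides, and contrasts expand() with Tensor.repeat(), which does copy. It assumes a recent PyTorch version. Because several elements of an expanded tensor can share one memory location, the PyTorch documentation advises cloning the result before writing to it in place.

```python
import torch

# A (3, 1) tensor; dimension 1 is a singleton
t = torch.tensor([[1], [2], [3]])

# expand() returns a view: no new memory is allocated
exp = t.expand(3, 2)
print(exp.data_ptr() == t.data_ptr())  # True: same underlying storage
print(exp.stride())                    # (1, 0): stride 0 along the expanded dimension

# Because exp is a view, changing the original shows up in it
t[0, 0] = 10
print(exp)  # the first row now holds 10 in both columns

# repeat() is the copying alternative: it allocates new memory
rep = t.repeat(1, 2)
print(rep.data_ptr() == t.data_ptr())  # False

# clone() turns the view into an independent, writable copy
safe = exp.clone()
safe[0, 0] = 0  # does not touch t or exp
```

Use expand() when you only need to read the enlarged tensor, and clone() or repeat() when you need to modify it.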
If you pass -1 as the size of a dimension, that dimension keeps its current size and is not expanded.
For example, a tensor of size (3,1) can be expanded along its second dimension, which has size 1, to a size such as (3,2) or (3,4).
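The following sketch shows both uses of the size arguments on a (3,1) tensor: -1 to keep an existing dimension unchanged, and extra leading sizes to add new dimensions. As the PyTorch documentation states, -1 is not allowed for a new leading dimension, so the last call is expected to raise a RuntimeError.

```python
import torch

t = torch.tensor([[1], [2], [3]])  # size (3, 1)

# -1 keeps dimension 0 at its current size (3); dimension 1 grows from 1 to 4
a = t.expand(-1, 4)
print(a.size())  # torch.Size([3, 4])

# Extra sizes at the front add new leading dimensions
b = t.expand(2, 3, 4)
print(b.size())  # torch.Size([2, 3, 4])

# -1 is not allowed for a new leading dimension
try:
    t.expand(-1, 3, 4)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```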
Steps
To expand a tensor, you can follow the steps given below:
Import the torch library. Make sure you have already installed it.
import torch
Define a tensor that has at least one singleton dimension (a dimension of size 1).
t = torch.tensor([[1],[2],[3]])
Expand the tensor along the singleton dimension. Expanding along a non-singleton dimension raises a RuntimeError (see Example 3).
t_exp = t.expand(3,2)
Display the expanded tensor.
print("Tensor after expand:\n", t_exp)
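The rule behind these steps can be written as a small check. The helper can_expand below is hypothetical (it is not part of PyTorch) and covers only the main rules: the sizes are aligned with the tensor's dimensions from the right, each existing dimension must keep its size, be given as -1, or be a singleton, and new leading dimensions need an explicit size. The loop at the end cross-checks the helper against Tensor.expand() itself.

```python
import torch

def can_expand(shape, sizes):
    """Hypothetical helper (not a PyTorch API): simplified check of whether
    a tensor with `shape` can be expanded to `sizes`."""
    if len(sizes) < len(shape):
        return False  # expand() cannot drop dimensions
    n_new = len(sizes) - len(shape)
    # New leading dimensions must have an explicit size, not -1
    if any(s == -1 for s in sizes[:n_new]):
        return False
    # Existing dimensions must be unchanged, -1, or singletons
    for old, new in zip(shape, sizes[n_new:]):
        if new != -1 and old != 1 and old != new:
            return False
    return True

t = torch.tensor([[1], [2], [3]])       # size (3, 1)
print(can_expand(t.shape, (3, 2)))      # True: only the singleton dimension grows
print(can_expand(t.shape, (4, 2)))      # False: dimension 0 has size 3, not 1

# Cross-check the helper against torch.Tensor.expand
checks = []
for sizes in [(3, 2), (-1, 5), (4, 2), (2, 3, 4), (-1, 3, 4), (3,)]:
    try:
        t.expand(*sizes)
        ok = True
    except RuntimeError:
        ok = False
    checks.append((sizes, can_expand(t.shape, sizes), ok))
    print(sizes, "helper:", checks[-1][1], "torch:", ok)
```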
Example 1
The following Python program shows how to expand a tensor of size (3,1) to a tensor of size (3,2). It expands the tensor along the dimension of size 1; the other dimension, of size 3, remains unchanged.
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1],[2],[3]])
# display the tensor
print("Tensor:\n", t)
print("Size of Tensor:\n", t.size())
# expand the tensor
exp = t.expand(3,2)
print("Tensor after expansion:\n", exp)
Output
Tensor:
 tensor([[1],
        [2],
        [3]])
Size of Tensor:
 torch.Size([3, 1])
Tensor after expansion:
 tensor([[1, 1],
        [2, 2],
        [3, 3]])
Example 2
The following Python program expands a tensor of size (1,3) to a tensor of size (3,3). It expands the tensor along the dimension of size 1 and passes -1 for the second dimension so that it keeps its size of 3.
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1,2,3]])
# display the tensor
print("Tensor:\n", t)
# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())
# expand the tensor
expandedTensor = t.expand(3,-1)
print("Expanded Tensor:\n", expandedTensor)
print("Size of expanded tensor:\n", expandedTensor.size())
Output
Tensor:
 tensor([[1, 2, 3]])
Size of Tensor:
 torch.Size([1, 3])
Expanded Tensor:
 tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
Size of expanded tensor:
 torch.Size([3, 3])
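A common reason to use the expand(n, -1) pattern from Example 2 is an operation that does not broadcast, such as torch.cat. The sketch below, using made-up example values, appends the same row of features to every row of a matrix by first expanding the (1,3) row to (3,3); concatenating the unexpanded row would fail because the sizes differ in dimension 0.

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])  # size (3, 2)
row = torch.tensor([[7, 8, 9]])              # size (1, 3)

# torch.cat does not broadcast, so expand the row to (3, 3) first;
# -1 keeps its second dimension at size 3, as in Example 2
combined = torch.cat([x, row.expand(x.size(0), -1)], dim=1)
print(combined.size())  # torch.Size([3, 5])
print(combined)
```

Since expand() only creates a view, no intermediate copy of the row is made; torch.cat allocates the single output tensor.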
Example 3
In the following Python program, we try to expand the tensor along a non-singleton dimension (dimension 1, which has size 3), so it raises a RuntimeError.
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1,2,3]])
# display the tensor
print("Tensor:\n", t)
# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())
# dimension 1 has size 3, not 1, so this raises a RuntimeError
t.expand(3,4)
Output
Tensor:
 tensor([[1, 2, 3]])
Size of Tensor:
 torch.Size([1, 3])
RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1. Target sizes: [3, 4]. Tensor sizes: [1, 3]
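If you really need to enlarge a non-singleton dimension, expand() cannot do it, but Tensor.repeat() can tile the data at the cost of a real copy. Suggesting repeat() here is our own addition, not part of the example above. Note that repeat() multiplies each dimension by a repetition count, so a size of 4 along a dimension of size 3 is still not reachable; the sketch below catches the error and tiles to (3,6) instead.

```python
import torch

t = torch.tensor([[1, 2, 3]])  # size (1, 3)

try:
    t.expand(3, 4)  # dimension 1 has size 3, not 1, so this fails
    expanded = True
except RuntimeError as err:
    expanded = False
    print("expand() failed:", err)

# repeat() copies the data and can tile non-singleton dimensions:
# 3 copies along dimension 0 and 2 copies along dimension 1 give size (3, 6)
tiled = t.repeat(3, 2)
print(tiled.size())  # torch.Size([3, 6])
print(tiled)
```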