Transposing and permuting tensors is a common operation. However, PyTorch handles it slightly differently from what many people are used to in, e.g., NumPy.
import numpy as np
import matplotlib.pyplot as plt
import torch
import tensorflow as tf
print("NumPy:",np.__version__)
print("PyTorch",torch.__version__)
print("TensorFlow",tf.__version__)
NumPy: 1.22.2
PyTorch 1.11.0
TensorFlow 2.8.0
Let’s initialize a 5-dimensional tensor (2D would be too simple):
tensor_np = np.random.random((3,3,3,3,3))
tensor_pt = torch.Tensor(tensor_np)
tensor_tf = tf.convert_to_tensor(tensor_np)
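One detail matters before comparing the results across libraries: np.random.random produces float64 values, and torch.Tensor(...) copies them into PyTorch's default dtype (float32 unless it has been changed). torch.from_numpy keeps float64 and shares memory with the array instead. The sketch below checks this with NumPy and PyTorch only. The name tensor_shared is purely illustrative.

```python
import numpy as np
import torch

tensor_np = np.random.random((3, 3, 3, 3, 3))  # float64 values in [0, 1)
tensor_pt = torch.Tensor(tensor_np)            # copy, converted to the default dtype
tensor_shared = torch.from_numpy(tensor_np)    # no copy, keeps float64

print(tensor_np.dtype)      # float64
print(tensor_pt.dtype)      # torch.float32 (unless the default dtype was changed)
print(tensor_shared.dtype)  # torch.float64

# Same values up to float32 rounding
print(np.allclose(tensor_np, tensor_pt.numpy()))  # True

# Writing through the shared tensor also changes the NumPy array
tensor_shared[0, 0, 0, 0, 0] = -1.0
print(tensor_np[0, 0, 0, 0, 0])  # -1.0
```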
We are going to visualize a slice of it:
plt.figure(figsize=(9,5))
plt.subplot(1,3,1)
plt.imshow(tensor_np[:,:,:,0,0])
plt.title("tensor_np[:,:,:,0,0]")
plt.subplot(1,3,2)
plt.imshow(tensor_pt[:,:,:,0,0])
plt.title("tensor_pt[:,:,:,0,0]")
plt.subplot(1,3,3)
plt.imshow(tensor_tf[:,:,:,0,0])
plt.title("tensor_tf[:,:,:,0,0]")
plt.show()

Let’s use the default transpose functions, which reverse the order of all axes:
tensor_np_T = np.transpose(tensor_np) # tensor_np.T
tensor_pt_T = tensor_pt.T
tensor_tf_T = tf.transpose(tensor_tf)
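With a 3×3×3×3×3 tensor, all shapes look the same, so it is hard to see that the axes were actually reversed. The sketch below uses a hypothetical non-cubic shape (2, 3, 4, 5, 6) to make this visible. Recent PyTorch releases warn that .T on tensors that are not 2-D is deprecated, so the sketch spells out the same reversal with permute.

```python
import numpy as np
import torch

# Hypothetical non-cubic shape, so each axis can be told apart by its length
a_np = np.random.random((2, 3, 4, 5, 6))
a_pt = torch.from_numpy(a_np)  # shares memory with a_np, keeps float64

# Default transpose: all axes in reverse order
a_np_T = np.transpose(a_np)
# PyTorch: reverse every axis explicitly (what tensor.T does for N-D tensors)
a_pt_T = a_pt.permute(*reversed(range(a_pt.ndim)))

print(a_np_T.shape)         # (6, 5, 4, 3, 2)
print(tuple(a_pt_T.shape))  # (6, 5, 4, 3, 2)

# Element [i, j, k, l, m] of the original ends up at [m, l, k, j, i]
assert a_np_T[5, 4, 3, 2, 1] == a_np[1, 2, 3, 4, 5]
assert np.array_equal(a_np_T, a_pt_T.numpy())
```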
Why can’t we use torch.transpose(tensor_pt, dim0, dim1)? Because torch.transpose only swaps two axes (dim0 and dim1); it cannot reorder more than two axes in a single call.
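To illustrate the limitation, the sketch below (again using a hypothetical 2×3×4×5×6 tensor) shows that one torch.transpose call swaps exactly one pair of axes. Reversing all five axes therefore takes two chained swaps, which is equivalent to a single permute.

```python
import torch

x = torch.rand(2, 3, 4, 5, 6)  # hypothetical non-cubic shape

# torch.transpose swaps exactly two dimensions
y = torch.transpose(x, 0, 4)
print(tuple(y.shape))  # (6, 3, 4, 5, 2)

# Reversing all five axes needs two swaps: (0 <-> 4), then (1 <-> 3)
z = x.transpose(0, 4).transpose(1, 3)
print(tuple(z.shape))  # (6, 5, 4, 3, 2)

# ...which gives the same result as one permute with the reversed axis order
assert torch.equal(z, x.permute(4, 3, 2, 1, 0))
```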

Let’s reorder arbitrary tensor axes. For this, PyTorch requires the tensor.permute() method, while NumPy and TensorFlow reuse their transpose functions with an explicit axis order:
tensor_np_P = np.transpose(tensor_np, (2,1,3,4,0))
tensor_pt_P = tensor_pt.permute(2,1,3,4,0)
tensor_tf_P = tf.transpose(tensor_tf, (2,1,3,4,0))
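The axis order (2,1,3,4,0) means "output axis i is input axis perm[i]" in all three libraries. The sketch below checks this for NumPy and PyTorch with a hypothetical non-cubic shape. It also shows that the inverse permutation, obtained here with np.argsort, restores the original layout. The variable names are illustrative only.

```python
import numpy as np
import torch

perm = (2, 1, 3, 4, 0)
a_np = np.random.random((2, 3, 4, 5, 6))  # hypothetical non-cubic shape
a_pt = torch.from_numpy(a_np)

a_np_P = np.transpose(a_np, perm)
a_pt_P = a_pt.permute(*perm)

# Output axis i is input axis perm[i]
print(tuple(a_np.shape[p] for p in perm))  # (4, 3, 5, 6, 2)
print(a_np_P.shape)                        # (4, 3, 5, 6, 2)
print(tuple(a_pt_P.shape))                 # (4, 3, 5, 6, 2)

# Index [i0, i1, i2, i3, i4] of the result reads a_np[i4, i1, i0, i2, i3]
assert a_np_P[1, 2, 3, 4, 1] == a_np[1, 2, 1, 3, 4]
assert np.array_equal(a_np_P, a_pt_P.numpy())

# The inverse permutation (argsort of perm) restores the original layout
inv = tuple(int(i) for i in np.argsort(perm))  # (4, 1, 0, 2, 3)
assert np.array_equal(np.transpose(a_np_P, inv), a_np)
assert torch.equal(a_pt_P.permute(*inv), a_pt)
```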

This is how we transpose/permute tensors.
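NB!: NumPy and PyTorch return non-contiguous views of the original data (only the strides are reordered; nothing is copied). TensorFlow does not support strided views, so tf.transpose copies the elements into a new tensor in the permuted order.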
print(tensor_np_P.flags["C_CONTIGUOUS"])  # NumPy
print(tensor_pt_P.is_contiguous())        # PyTorch
False
False
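Non-contiguity has practical consequences. The sketch below uses a hypothetical 2×3×4×5×6 tensor and shows three things. First, the permuted results share memory with the input and differ only in their strides. Second, PyTorch's .view() refuses to flatten the permuted tensor. Third, .contiguous() (or np.ascontiguousarray in NumPy) makes a copy in the new order. Treat it as an illustration, not an exhaustive account of when .view() works.

```python
import numpy as np
import torch

a_np = np.random.random((2, 3, 4, 5, 6))  # hypothetical non-cubic shape
a_pt = torch.from_numpy(a_np)

a_np_P = np.transpose(a_np, (2, 1, 3, 4, 0))
a_pt_P = a_pt.permute(2, 1, 3, 4, 0)

# Both results are views: same memory, only the strides are reordered
print(np.shares_memory(a_np, a_np_P))        # True
print(a_pt_P.data_ptr() == a_pt.data_ptr())  # True
print(a_pt.stride())    # (360, 120, 30, 6, 1)
print(a_pt_P.stride())  # (30, 120, 6, 1, 360)

# .view() cannot flatten the permuted strides; PyTorch raises a RuntimeError
view_failed = False
try:
    a_pt_P.view(-1)
except RuntimeError:
    view_failed = True  # .reshape() would copy instead
print(view_failed)  # True

# .contiguous() / np.ascontiguousarray copy the data into the permuted order
c_pt = a_pt_P.contiguous()
c_np = np.ascontiguousarray(a_np_P)
print(c_pt.is_contiguous(), c_np.flags["C_CONTIGUOUS"])  # True True
print(torch.equal(c_pt.view(-1), a_pt_P.reshape(-1)))    # True
```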