Question
Apply permutation array on multiple axes in numpy
Let's say I have an array of permutations perm, which could look like:
perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
If I want to apply it to one axis, I can write something like:
v = np.arange(9).reshape(3, 3)
v[perm]
Output:
array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]],

       [[3, 4, 5],
        [6, 7, 8],
        [0, 1, 2]],

       [[0, 1, 2],
        [6, 7, 8],
        [3, 4, 5]],

       [[6, 7, 8],
        [3, 4, 5],
        [0, 1, 2]]])
Now I would like to apply it to two axes at the same time. I figured out that I can do it via:
np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
But I find this quite inefficient: it has to build a mesh grid for every permutation, and it requires a Python-level for loop. The array in this example is small, but in reality I have much larger arrays with many permutations, so I would really like something as quick and simple as the one-axis version.
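To make the target concrete, here is a minimal sketch of the kind of loop-free expression I am after. It assumes that NumPy's integer-array indexing broadcasts the two index arrays against each other, so result[k, i, j] = v[perm[k, i], perm[k, j]]. The names looped and vectorized are only for illustration, and I have not benchmarked this on large inputs.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

# Loop version from above: one permuted (3, 3) array per permutation.
looped = np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])

# Broadcast version (a sketch, not benchmarked):
# row indices have shape (4, 3, 1), column indices have shape (4, 1, 3);
# they broadcast to (4, 3, 3), giving v[perm[k, i], perm[k, j]] at [k, i, j].
vectorized = v[perm[:, :, None], perm[:, None, :]]

# Both approaches should agree element for element.
assert np.array_equal(looped, vectorized)
```

On this small example the two results match, and the second form needs neither the mesh grid nor the Python loop.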