Question

Apply permutation array on multiple axes in numpy

Let's say I have an array of permutations perm which could look like:

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])

If I want to apply it to one axis, I can write something like:

v = np.arange(9).reshape(3, 3)
print(repr(v[perm]))

Output:

array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]],

       [[3, 4, 5],
        [6, 7, 8],
        [0, 1, 2]],

       [[0, 1, 2],
        [6, 7, 8],
        [3, 4, 5]],

       [[6, 7, 8],
        [3, 4, 5],
        [0, 1, 2]]])

Now I would like to apply it to two axes at the same time. I figured out that I can do it via:

np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])

But I find it quite inefficient: it has to build a mesh grid for every permutation, and it needs a Python for loop. The array in this example is small, but in reality I have much larger arrays and many permutations, so I would really love something as quick and simple as the one-axis version.

1 Jan 1970

Solution

 4

How about:

p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]
v[p1, p2]

The zeroth axis of p1 and p2 is just the "batch" dimension of perm, which allows you to do many permutations in one operation.

The other dimension of perm, which holds the indices, lies along axis 1 in p1 (shape (4, 3, 1)) and along axis 2 in p2 (shape (4, 1, 3)). Because those axes are orthogonal, the two index arrays are broadcast against each other to shape (4, 3, 3), much like the arrays you got from meshgrid, except that these keep the batch dimension.
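To make the shapes concrete, here is a small sketch using the arrays from the question. The names k, i, j and elementwise_ok are only illustrative; the point is that each output element is v looked up at a permuted row index and a permuted column index.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

p1 = perm[:, :, np.newaxis]  # shape (4, 3, 1): row indices
p2 = perm[:, np.newaxis, :]  # shape (4, 1, 3): column indices
res = v[p1, p2]              # index arrays broadcast to shape (4, 3, 3)

print(p1.shape, p2.shape, res.shape)
# (4, 3, 1) (4, 1, 3) (4, 3, 3)

# Element-wise meaning: res[k, i, j] == v[perm[k, i], perm[k, j]]
k, i, j = 1, 0, 2  # arbitrary illustrative position
elementwise_ok = res[k, i, j] == v[perm[k, i], perm[k, j]]
print(elementwise_ok)
# True
```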

That's the best I can do from my cell phone : ) I can try to clarify later if needed, but the key idea is broadcasting.

Comparison:

import numpy as np
perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

ref = np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])

p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]
res = v[p1, p2]

np.testing.assert_equal(res, ref)
# passes

%timeit np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
# 107 µs ± 20.6 µs per loop

%timeit v[perm[:, :, np.newaxis], perm[:, np.newaxis, :]]
# 3.73 µs ± 1.07 µs per loop

A simpler (without batch dimension) example of broadcasting indices:

import numpy as np
i = np.arange(3)
ref = np.meshgrid(i, i, indexing="ij")
res = np.broadcast_arrays(i[:, np.newaxis], i[np.newaxis, :])
np.testing.assert_equal(res, ref)
# passes

In the solution code at the top, the broadcasting is implicit. We don't need to call broadcast_arrays because it happens automatically during the indexing.
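As a quick check of that claim, the sketch below broadcasts the two index arrays explicitly and compares the result with the implicit version. The names b1, b2, explicit, implicit and same are only illustrative.

```python
import numpy as np

perm = np.array([[0, 1, 2], [1, 2, 0], [0, 2, 1], [2, 1, 0]])
v = np.arange(9).reshape(3, 3)

p1 = perm[:, :, np.newaxis]
p2 = perm[:, np.newaxis, :]

# Explicit: expand both index arrays to the common shape (4, 3, 3) first.
b1, b2 = np.broadcast_arrays(p1, p2)
explicit = v[b1, b2]

# Implicit: let the indexing operation broadcast p1 and p2 itself.
implicit = v[p1, p2]

same = np.array_equal(explicit, implicit)
print(b1.shape, b2.shape, same)
# (4, 3, 3) (4, 3, 3) True
```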

2024-07-15
Matt Haberland

Solution

 2

You can get rid of the meshgrid with chained indexing:

a = np.array([v[p][:,p] for p in perm])
b = np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm])
print(np.all(b == a)) # True

This is about 5x faster on your example array (timed with IPython's %timeit magic):

%timeit np.array([v[tuple(np.meshgrid(p, p, indexing="ij"))] for p in perm]) # 42.7 µs
%timeit np.array([v[p][:,p] for p in perm]) # 8.18 µs

I would assume the cost of the for loop to be mostly irrelevant. If you are concerned with further optimization, please specify the shapes you are working with.
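Since the real shapes are not given, here is a hedged sketch for trying this on larger inputs. The sizes n and n_perm, the random data, and the names v_big and perm_big are all assumptions made up for illustration. It only checks that the looped version and the loop-free broadcasting version agree; timing either one on your own shapes is left to %timeit.

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_perm = 50, 200  # hypothetical sizes, not taken from the question

v_big = rng.random((n, n))
# Each row is a permutation of 0..n-1: argsort of a row of random values.
perm_big = np.argsort(rng.random((n_perm, n)), axis=1)

# Looped version: permute rows, then columns, one permutation at a time.
looped = np.array([v_big[p][:, p] for p in perm_big])

# Loop-free version: broadcast row and column index arrays against each other.
batched = v_big[perm_big[:, :, np.newaxis], perm_big[:, np.newaxis, :]]

results_match = np.array_equal(looped, batched)
print(looped.shape, results_match)
```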

2024-07-15
Julien