iterate_axes

iterate_axes(arr: ndarray, axes: tuple[int, ...])

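Iterate over every combination of indices along a set of axes. Each yielded index tuple holds an integer for each axis in axes and slice(None) for every other axis, so arr[index] selects one sub-array per combination.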

Parameters:
  • arr (np.ndarray) – The array to iterate over.

  • axes (tuple[int, ...]) – The axes whose indices are iterated over; every other axis is taken in full.

Yields:

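tuple[int | slice, ...] – The index tuple for the current iteration, with an integer for each axis in axes and slice(None, None, None) for every other axis.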
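The behavior described above can be sketched with itertools.product. This is a minimal sketch, assuming the entries of axes are distinct and listed in increasing order (as in the examples below); it is not necessarily how the library implements iterate_axes.

```python
from __future__ import annotations

import itertools
from collections.abc import Iterator

import numpy as np


def iterate_axes(arr: np.ndarray, axes: tuple[int, ...]) -> Iterator[tuple[int | slice, ...]]:
    # Sketch only: assumes `axes` are distinct and given in increasing order.
    # One range per iterated axis; product() varies the last axis fastest,
    # which reproduces the row-major order shown in the examples.
    ranges = [range(arr.shape[ax]) for ax in axes]
    for combo in itertools.product(*ranges):
        # Start with a full slice on every axis, then pin the iterated axes.
        index: list[int | slice] = [slice(None)] * arr.ndim
        for ax, i in zip(axes, combo):
            index[ax] = i
        yield tuple(index)
```

For an array of shape (2, 3, 4), iterating over axes (0, 2) yields 2 × 4 = 8 index tuples, one per combination of positions on those two axes.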

Examples

>>> import numpy as np
>>> arr = np.arange(24).reshape(2, 3, 4)
>>> for sl in iterate_axes(arr, (0, 1)):
...     print(arr[sl], sl)
[0 1 2 3] (0, 0, slice(None, None, None))
[4 5 6 7] (0, 1, slice(None, None, None))
[ 8  9 10 11] (0, 2, slice(None, None, None))
[12 13 14 15] (1, 0, slice(None, None, None))
[16 17 18 19] (1, 1, slice(None, None, None))
[20 21 22 23] (1, 2, slice(None, None, None))
>>> for sl in iterate_axes(arr, (0, 2)):
...     print(arr[sl], sl)
[0 4 8] (0, slice(None, None, None), 0)
[1 5 9] (0, slice(None, None, None), 1)
[ 2  6 10] (0, slice(None, None, None), 2)
[ 3  7 11] (0, slice(None, None, None), 3)
[12 16 20] (1, slice(None, None, None), 0)
[13 17 21] (1, slice(None, None, None), 1)
[14 18 22] (1, slice(None, None, None), 2)
[15 19 23] (1, slice(None, None, None), 3)
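Because each yielded index contains only integers and slices, arr[index] uses NumPy basic indexing and returns a view, so in-place updates write through to arr. The sketch below relies on that property; to run on its own it defines a hypothetical stand-in for iterate_axes built on np.ndindex (distinct, increasing axes assumed), then subtracts the mean from every 1-D fiber along the last axis.

```python
import numpy as np


def iterate_axes(arr, axes):
    # Hypothetical stand-in for the documented function (distinct, increasing axes assumed).
    for combo in np.ndindex(*(arr.shape[ax] for ax in axes)):
        index = [slice(None)] * arr.ndim
        for ax, i in zip(axes, combo):
            index[ax] = i
        yield tuple(index)


arr = np.arange(24, dtype=float).reshape(2, 3, 4)
for index in iterate_axes(arr, (0, 1)):
    fiber = arr[index]      # ints + slices -> basic indexing -> a view into arr
    fiber -= fiber.mean()   # the in-place update writes through to arr
```

After the loop, every fiber along the last axis of arr has zero mean, without any explicit assignment back into arr.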