xp_take_along_axis

xp_take_along_axis(arr: Any, indices: Any, /, *, axis: int = -1, xp: ModuleType | None = None) -> Any

Take values from an array along an axis at the indices specified.

Parameters:
  • arr (Array) – The source array.

  • indices (Array) – The indices of the values to extract. This array must have the same number of dimensions as arr and the same shape as arr in every dimension except axis.

  • axis (int, optional) – The axis over which to select values. Default is -1 (the last axis).

  • xp (module, optional) – The array API namespace to use. If not provided, the namespace is inferred from the arrays.

Returns:

The indexed array, with the same shape as indices.

Return type:

Array
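The parameter and return descriptions can be restated element by element: for every position in indices, the output holds the element of arr at that position, except that the coordinate along axis is replaced by the stored index. The sketch below spells this out with an explicit loop in plain NumPy. `take_along_axis_reference` is a hypothetical, illustration-only name. It is neither part of this API nor how the function is implemented, and unlike NumPy it does not broadcast.

```python
import numpy as np


def take_along_axis_reference(arr, indices, axis=-1):
    # Hypothetical, illustration-only helper: a slow, loop-based restatement
    # of take-along-axis semantics (no broadcasting support).
    arr = np.asarray(arr)
    indices = np.asarray(indices)
    if arr.ndim != indices.ndim:
        raise ValueError("arr and indices must have the same number of dimensions")
    axis = axis % arr.ndim  # turn a negative axis such as -1 into a positive one
    out = np.empty(indices.shape, dtype=arr.dtype)  # output shape == indices.shape
    for idx in np.ndindex(*indices.shape):
        # Copy the output position, then swap its `axis` coordinate
        # for the index stored at that position.
        src = list(idx)
        src[axis] = indices[idx]
        out[idx] = arr[tuple(src)]
    return out


a = np.array([[10, 30, 20], [60, 40, 50]])
print(take_along_axis_reference(a, np.argsort(a)))
# [[10 20 30]
#  [40 50 60]]
```

On these inputs it should agree with np.take_along_axis, which is the easiest way to check the sketch.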

Notes

This is a dispatcher for np.take_along_axis for backends that support it; see data-apis/array-api/pull#816.
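A minimal sketch of the dispatching pattern follows. It assumes each backend is either NumPy-like (exposing take_along_axis) or PyTorch-like (exposing take_along_dim, which takes dim instead of axis). `take_along_axis_dispatch` is a hypothetical name. Inferring the namespace through the `__array_namespace__` protocol with a NumPy fallback is also an assumption, not necessarily what this function does.

```python
import numpy as np


def take_along_axis_dispatch(arr, indices, /, *, axis=-1, xp=None):
    # Hypothetical dispatcher sketch; not the documented function's implementation.
    if xp is None:
        # Assumption: infer the namespace via the array API
        # `__array_namespace__` protocol, falling back to NumPy when the
        # array does not implement it (e.g. arrays from NumPy < 2.0).
        get_namespace = getattr(arr, "__array_namespace__", None)
        xp = get_namespace() if get_namespace is not None else np
    if hasattr(xp, "take_along_axis"):
        # NumPy exposes the operation under this name.
        return xp.take_along_axis(arr, indices, axis=axis)
    if hasattr(xp, "take_along_dim"):
        # PyTorch spells the same operation take_along_dim(input, indices, dim).
        return xp.take_along_dim(arr, indices, dim=axis)
    name = getattr(xp, "__name__", repr(xp))
    raise NotImplementedError(f"{name} provides no take_along_axis equivalent")


a = np.array([[10, 30, 20], [60, 40, 50]])
print(take_along_axis_dispatch(a, np.argsort(a), xp=np))
# [[10 20 30]
#  [40 50 60]]
```

Checking for the attribute rather than the backend name keeps the sketch short. A real dispatcher may prefer explicit backend checks so that it fails clearly on namespaces that lack the operation.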

Examples

>>> import numpy as np
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
>>> # Sort along the last axis
>>> ai = np.argsort(a)
>>> ai
array([[0, 2, 1],
       [1, 2, 0]])
>>> xp_take_along_axis(a, ai, xp=np)
array([[10, 20, 30],
       [40, 50, 60]])
>>> # Sort along the first axis
>>> ai = np.argsort(a, axis=0)
>>> ai
array([[0, 0, 0],
       [1, 1, 1]])
>>> xp_take_along_axis(a, ai, axis=0, xp=np)
array([[10, 30, 20],
       [60, 40, 50]])
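Because the output takes the shape of indices, indices may be shorter than arr along axis, as long as both have the same number of dimensions. The sketch below calls np.take_along_axis directly as a stand-in for xp_take_along_axis(..., xp=np), which the Notes describe as a dispatcher for it. Passing keepdims to np.argmax assumes NumPy 1.22 or later.

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])

# Indices of the two smallest entries in each row: shape (2, 2), not (2, 3).
two_smallest = np.argsort(a, axis=-1)[:, :2]
values = np.take_along_axis(a, two_smallest, axis=-1)
print(values)        # [[10 20]
                     #  [40 50]]
print(values.shape)  # (2, 2), the same shape as two_smallest

# keepdims=True keeps argmax's result 2-D, matching arr's number of dimensions.
row_max_idx = np.argmax(a, axis=-1, keepdims=True)  # [[1], [0]]
print(np.take_along_axis(a, row_max_idx, axis=-1))  # [[30]
                                                    #  [60]]
```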