"""Utility functions to use Python Array API compatible libraries.
Copied from SciPy.
For the context about the Array API see:
https://data-apis.org/array-api/latest/purpose_and_scope.html
The SciPy use case of the Array API is described on the following page:
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
"""
import os
from types import ModuleType
from typing import Any, Literal, TypeAlias
import numpy as np
import numpy.typing as npt
from functools import reduce
from array_api_compat import (
array_namespace as xp_array_namespace,
is_array_api_obj,
size as xp_size,
numpy as np_compat,
device as xp_device,
is_numpy_namespace as is_numpy,
is_cupy_namespace as is_cupy,
is_torch_namespace as is_torch,
is_jax_namespace as is_jax,
is_array_api_strict_namespace as is_array_api_strict
)
from array_api_extra import (
at, atleast_nd, cov, create_diagonal, expand_dims, kron, nunique,
pad, setdiff1d, sinc
)
# Names exported by ``from <module> import *``: the public functions defined
# in this module plus selected re-exports from `array_api_compat` and
# `array_api_extra`.  `_asarray` is exported despite its leading underscore.
__all__ = [
    '_asarray', 'array_namespace', 'at', 'atleast_nd', 'cov',
    'create_diagonal', 'expand_dims', 'kron', 'nunique', 'pad', 'setdiff1d',
    'sinc', 'get_xp_devices', 'intersect1d',
    'is_array_api_strict', 'is_complex', 'is_cupy', 'is_jax', 'is_numpy',
    'is_torch',
    'scipy_namespace_for', 'split',
    'xp_assert_equal', 'xp_assert_less',
    'xp_copy', 'xp_copysign', 'xp_device', 'xp_float_to_complex',
    'xp_moveaxis_to_end', 'xp_ravel', 'xp_real', 'xp_sign', 'xp_size',
    'xp_take_along_axis', 'xp_vector_norm',
]
# Global switch for array API support and strict array-like input validation.
# The environment variable SCIPY_ARRAY_API is given the default "1" when it
# is not already set, so the switch is on unless the environment overrides it.
os.environ.setdefault("SCIPY_ARRAY_API", "1")
# The raw environment string is stored as-is; a truthiness test on it treats
# any non-empty value (including "0") as enabled.
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# Placeholder alias for "an array object of any backend".
Array: TypeAlias = Any  # To be changed to a Protocol later (see array-api#589)
# Anything accepted as an array: a backend array or a NumPy array-like.
ArrayLike: TypeAlias = Array | npt.ArrayLike
def _check_finite(array: Array, xp: ModuleType) -> None:
    """Raise ``ValueError`` unless every element of `array` is finite.

    Parameters
    ----------
    array : Array
        Array to inspect.
    xp : module
        Namespace providing ``isfinite`` and ``all``.

    Raises
    ------
    ValueError
        If any element is infinite or NaN, or if the finiteness test itself
        raises ``TypeError`` (e.g. for dtypes that ``isfinite`` rejects).
    """
    msg = "array must not contain infs or NaNs"
    try:
        # The truth-value conversion stays inside the ``try`` so that a
        # TypeError from it is reported the same way as one from ``isfinite``.
        everything_finite = bool(xp.all(xp.isfinite(array)))
    except TypeError:
        raise ValueError(msg)
    if not everything_finite:
        raise ValueError(msg)
def array_namespace(*arrays: Array) -> ModuleType:
    """Return the array API compatible namespace shared by the given arrays.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace. ``None`` entries are
        ignored.

    Returns
    -------
    namespace : module
        Namespace reported by `array_api_compat.array_namespace` for the
        remaining arrays.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`: the only extra
    step performed here is dropping ``None`` arguments before delegating.
    The ``SCIPY_ARRAY_API`` switch is not consulted by this function.

    Examples
    --------
    >>> import numpy as np
    >>> array_namespace(np.array([1, 2, 3]))  # doctest: +ELLIPSIS
    <module '...numpy' from ...
    """
    return xp_array_namespace(*(arr for arr in arrays if arr is not None))
def intersect1d(*arrays: Array, assume_unique: bool = False,
                xp: ModuleType | None = None) -> Array:
    """SciPy-specific replacement for `np.intersect1d` with `assume_unique` and
    `xp`, accepting any number of input arrays.

    Parameters
    ----------
    *arrays : array_like
        Input arrays whose common elements are sought.
    assume_unique : bool, optional
        If True, the input arrays are assumed to be unique, which can speed up
        the calculation.
    xp : array_namespace, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the arrays (which requires at least one array).

    Returns
    -------
    intersect1d : array
        1D array of the elements common to all inputs. An empty array is
        returned when no input arrays are given and `xp` is supplied.

    Notes
    -----
    If the namespace provides ``intersect1d`` it is applied pairwise from
    left to right; with a single input that input is returned as is.
    Otherwise the intersection is built from two `setdiff1d` calls (from
    `array_api_extra`) per additional array.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1, 2, 3, 4, 5])
    >>> y = np.array([3, 4, 5, 6, 7])
    >>> intersect1d(x, y)
    array([3, 4, 5])
    >>> z = np.array([3, 4, 7, 8])
    >>> intersect1d(x, y, z)
    array([3, 4])
    """
    if xp is None:
        xp = array_namespace(*arrays)

    # Handle the empty case before reducing: ``reduce`` raises TypeError on
    # an empty sequence without an initial value.
    if not arrays:
        return xp.asarray([])

    if hasattr(xp, 'intersect1d'):
        def pairwise(left, right):
            return xp.intersect1d(left, right, assume_unique=assume_unique)
        return reduce(pairwise, arrays)

    result = xp.asarray(arrays[0])
    for array in arrays[1:]:
        # common = result \ (result \ array)
        only_in_result = setdiff1d(result, xp.asarray(array),
                                   assume_unique=assume_unique)
        result = setdiff1d(result, only_in_result,
                           assume_unique=assume_unique)
    return result
def split(array: Array, indices_or_sections: int | list[int], axis: int = 0,
          xp: ModuleType | None = None) -> list[Array]:
    """SciPy-specific replacement for `np.split` with `axis` and `xp`.

    Parameters
    ----------
    array : array_like
        Array to be divided into sub-arrays.
    indices_or_sections : int or 1-D sequence of ints
        If an integer, N, the array is divided into N equal parts along
        `axis`; an error is raised if that is not possible. If a 1-D
        sequence of sorted integers, the entries indicate where along `axis`
        the array is split: ``[2, 5]`` yields ``array[:2]``, ``array[2:5]``
        and ``array[5:]``.
    axis : int, optional
        The axis along which to split, default is 0.
    xp : array_namespace, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from `array`.

    Returns
    -------
    subarrays : list of arrays
        The pieces of `array`, obtained by basic slicing.

    Raises
    ------
    ValueError
        If `indices_or_sections` is an integer that is not positive or does
        not evenly divide the length of `axis`.
    IndexError
        If `axis` is out of bounds for `array`.

    Notes
    -----
    Only basic slice indexing is used, so the function does not depend on a
    backend-specific ``split``.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.arange(9.0)
    >>> split(x, 3)
    [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
    >>> split(x, [2, 5])
    [array([0., 1.]), array([2., 3., 4.]), array([5., 6., 7., 8.])]
    >>> x = np.arange(8.0).reshape(2, 4)
    >>> split(x, 2, axis=1)
    [array([[0., 1.],
           [4., 5.]]), array([[2., 3.],
           [6., 7.]])]
    """
    if xp is None:
        xp = array_namespace(array)
    array = xp.asarray(array)

    # ``shape[axis]`` raises IndexError for an out-of-range axis (including
    # any axis of a 0-d array); afterwards the axis can be made non-negative.
    length = array.shape[axis]
    axis %= array.ndim

    if isinstance(indices_or_sections, (int, np.integer)):
        sections = int(indices_or_sections)
        if sections <= 0:
            raise ValueError("number of sections must be larger than 0.")
        if length % sections:
            raise ValueError(
                "array split does not result in an equal division")
        step = length // sections
        bounds = [i * step for i in range(sections + 1)]
    else:
        # Split points delimit the pieces; add both ends of the axis so the
        # leading and trailing pieces are included.
        bounds = [0, *(int(i) for i in indices_or_sections), length]

    # Index every leading axis in full and leave trailing axes to Ellipsis.
    leading = (slice(None),) * axis
    return [array[leading + (slice(start, stop), ...)]
            for start, stop in zip(bounds, bounds[1:])]
def _asarray(
        array: ArrayLike,
        dtype: Any = None,
        order: Literal['K', 'A', 'C', 'F'] | None = None,
        copy: bool | None = None,
        *,
        xp: ModuleType | None = None,
        check_finite: bool = False,
        subok: bool = False
) -> Array:
    """SciPy-specific replacement for `np.asarray` with `order`,
    `check_finite`, and `subok`.

    Parameters
    ----------
    array : array_like
        Input to convert.
    dtype : dtype, optional
        Requested data type, forwarded to the conversion routine.
    order : {'K', 'A', 'C', 'F'}, optional
        Memory layout. Not part of the Array API standard: it is honoured
        only when the namespace is NumPy and ignored otherwise.
    copy : bool, optional
        For NumPy, ``True`` forces a copy via `np.array`; other values fall
        through to `np.asanyarray`/`np.asarray`. For other namespaces the
        value is forwarded to ``xp.asarray``.
    xp : module, optional
        Namespace to use; inferred from `array` when omitted.
    check_finite : bool, optional
        If True, raise ``ValueError`` when the result contains infs or NaNs.
        Provided for convenience; not an Array API keyword.
    subok : bool, optional
        For NumPy inputs, preserve subclasses as `np.asanyarray` does.

    Returns
    -------
    Array
        The converted array.
    """
    namespace = array_namespace(array) if xp is None else xp

    if not is_numpy(namespace):
        try:
            converted = namespace.asarray(array, dtype=dtype, copy=copy)
        except TypeError:
            # Retry with the namespace that `array_namespace` reports for an
            # array created by ``namespace`` itself.
            coerced_xp = array_namespace(namespace.asarray(3))
            converted = coerced_xp.asarray(array, dtype=dtype, copy=copy)
    elif copy is True:
        # Use the NumPy API directly so that `order` is supported.
        converted = np.array(array, order=order, dtype=dtype, subok=subok)
    elif subok:
        converted = np.asanyarray(array, order=order, dtype=dtype)
    else:
        converted = np.asarray(array, order=order, dtype=dtype)

    if check_finite:
        _check_finite(converted, namespace)
    return converted
def xp_copy(x: Array, *, xp: ModuleType | None = None) -> Array:
    """
    Return a copy of an array.

    Parameters
    ----------
    x : array
        Array to copy.
    xp : array_namespace, optional
        Namespace to use; inferred from `x` when omitted.

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    This copy function does not offer all the semantics of `np.copy`, i.e. the
    `subok` and `order` keywords are not used.

    Examples
    --------
    >>> import numpy as np
    >>> xp_copy([1,2,3], xp=np)
    array([1, 2, 3])
    """
    # Older NumPy versions did not accept ``copy`` in `np.asarray`, so the
    # work is delegated to the local `_asarray` helper.
    namespace = array_namespace(x) if xp is None else xp
    return _asarray(x, copy=True, xp=namespace)
def _strict_check(actual, desired, xp, *,
                  check_namespace=True, check_dtype=True, check_shape=True,
                  check_0d=True):
    """Run the strict pre-comparison checks shared by the ``xp_assert_*``
    helpers and return both inputs as arrays of namespace `xp`.

    Depending on the flags, this asserts matching namespaces, matching
    scalar-vs-array kind (NumPy only), matching dtypes and matching shapes.
    `desired` is finally broadcast to the shape of `actual`.

    Returns
    -------
    actual, desired : arrays
        The inputs converted with ``xp.asarray``.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    if check_namespace:
        _assert_matching_namespace(actual, desired)

    # Only NumPy distinguishes scalars from 0-d arrays. Check this before
    # the inputs are cast to arrays below, which erases the distinction.
    if check_0d and is_numpy(xp):
        actual_is_scalar = xp.isscalar(actual)
        desired_is_scalar = xp.isscalar(desired)
        _msg = ("Array-ness does not match:\n Actual: "
                f"{type(actual)}\n Desired: {type(desired)}")
        assert actual_is_scalar == desired_is_scalar, _msg

    actual = xp.asarray(actual)
    desired = xp.asarray(desired)

    if check_dtype:
        _msg = (f"dtypes do not match.\nActual: {actual.dtype}\nDesired:"
                f" {desired.dtype}")
        assert actual.dtype == desired.dtype, _msg

    if check_shape:
        _msg = (f"Shapes do not match.\nActual: {actual.shape}\nDesired:"
                f" {desired.shape}")
        assert actual.shape == desired.shape, _msg

    return actual, xp.broadcast_to(desired, actual.shape)
def _assert_matching_namespace(actual, desired):
    """Assert that `actual` (one array or a tuple of arrays) uses the same
    array namespace as `desired`."""
    __tracebackhide__ = True  # Hide traceback for py.test
    candidates = actual if isinstance(actual, tuple) else (actual,)
    desired_space = array_namespace(desired)
    for candidate in candidates:
        arr_space = array_namespace(candidate)
        _msg = (f"Namespaces do not match.\n"
                f"Actual: {arr_space.__name__}\n"
                f"Desired: {desired_space.__name__}")
        assert arr_space == desired_space, _msg
def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True,
                    check_shape=True, check_0d=True, err_msg='',
                    xp: ModuleType | None = None):
    """Assert that two arrays are equal.

    Parameters
    ----------
    actual : array_like
        The array to test.
    desired : array_like
        The expected array.
    check_namespace : bool, optional
        If True, check that the arrays have the same namespace.
    check_dtype : bool, optional
        If True, check that the arrays have the same dtype.
    check_shape : bool, optional
        If True, check that the arrays have the same shape.
    check_0d : bool, optional
        If True and the namespace is NumPy, check that both inputs are
        scalars or both are arrays.
    err_msg : str, optional
        The error message to be printed in case of failure.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from `actual`.

    Returns
    -------
    None
        If the arrays are equal, otherwise raises an AssertionError.

    Notes
    -----
    After the strict checks, the comparison is dispatched to the testing
    utilities of CuPy or PyTorch when one of those is the namespace, and to
    `numpy.testing.assert_array_equal` otherwise.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([1, 2, 3])
    >>> b = np.array([1, 2, 3])
    >>> xp_assert_equal(a, b)  # No error raised
    >>> c = np.array([1, 2, 4])
    >>> xp_assert_equal(a, c)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    AssertionError: Arrays are not equal
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)

    strictness = dict(check_namespace=check_namespace,
                      check_dtype=check_dtype, check_shape=check_shape,
                      check_0d=check_0d)
    actual, desired = _strict_check(actual, desired, xp, **strictness)

    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)

    if is_torch(xp):
        # PyTorch recommends `rtol=0, atol=0` to test for exact equality.
        # The dtype was already compared above, hence check_dtype=False.
        torch_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0,
                                       equal_nan=True, check_dtype=False,
                                       msg=torch_msg)

    # Everything else (including JAX) goes through `np.testing`.
    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True,
                   check_shape=True, check_0d=True, err_msg='', verbose=True,
                   xp: ModuleType | None = None):
    """Assert that all elements of an array are strictly less than the
    corresponding elements of another array.

    Parameters
    ----------
    actual : array_like
        The array to test.
    desired : array_like
        The array to compare against.
    check_namespace : bool, optional
        If True, check that the arrays have the same namespace.
    check_dtype : bool, optional
        If True, check that the arrays have the same dtype.
    check_shape : bool, optional
        If True, check that the arrays have the same shape.
    check_0d : bool, optional
        If True and the namespace is NumPy, check that both inputs are
        scalars or both are arrays.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, print arrays that are not less.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from `actual`.

    Returns
    -------
    None
        If every element of `actual` is strictly less than its counterpart
        in `desired`, otherwise raises an AssertionError.

    Notes
    -----
    CuPy inputs are compared with CuPy's own testing helper. PyTorch
    tensors are moved to the CPU first and, like all remaining backends,
    compared with `numpy.testing.assert_array_less`.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([1, 2, 3])
    >>> b = np.array([4, 5, 6])
    >>> xp_assert_less(a, b)  # No error raised
    >>> c = np.array([3, 4, 2])
    >>> xp_assert_less(a, c)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    AssertionError: Arrays are not less
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)

    strictness = dict(check_namespace=check_namespace,
                      check_dtype=check_dtype, check_shape=check_shape,
                      check_0d=check_0d)
    actual, desired = _strict_check(actual, desired, xp, **strictness)

    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)

    if is_torch(xp):
        # `np.testing` needs host memory: bring non-CPU tensors over.
        actual = actual if actual.device.type == 'cpu' else actual.cpu()
        desired = desired if desired.device.type == 'cpu' else desired.cpu()

    # JAX uses `np.testing` as well.
    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)
def is_complex(x: Array, xp: ModuleType) -> bool:
    """Check if an array has a complex floating-point data type.

    Parameters
    ----------
    x : Array
        The array to check.
    xp : module
        The array API namespace to use; it must provide ``isdtype``.

    Returns
    -------
    bool
        True if the array has a complex floating-point data type,
        False otherwise.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1+2j, 3+4j])
    >>> is_complex(x, np)
    True
    >>> y = np.array([1, 2, 3])
    >>> is_complex(y, np)
    False
    """
    dtype = x.dtype
    return xp.isdtype(dtype, 'complex floating')
def get_xp_devices(xp: ModuleType) -> list[str] | list[None]:
    """Returns a list of available devices for the given namespace.

    Parameters
    ----------
    xp : module
        The array API namespace to check for available devices.

    Returns
    -------
    list of str or list of None
        A list of available device strings for the given namespace.
        For PyTorch, this might include 'cpu', 'cuda:0', etc.
        For CuPy, this might include 'cuda:0', etc.
        For JAX, this might include 'cpu:0', 'gpu:0', 'tpu:0', etc.
        For other namespaces, returns [None].

    Notes
    -----
    For JAX, a backend that cannot be queried (``jax.device_count`` raising
    ``RuntimeError``, as on installations without GPU or TPU support) is
    skipped rather than aborting the whole enumeration.

    Examples
    --------
    >>> import numpy as np
    >>> get_xp_devices(np)
    [None]

    If PyTorch is available

    >>> import torch  # doctest: +SKIP
    >>> get_xp_devices(torch)  # doctest: +SKIP
    ['cpu', 'cuda:0']
    """
    devices: list[str] = []

    if is_torch(xp):
        import torch  # type: ignore[import]
        devices.append('cpu')
        devices += [f'cuda:{i}' for i in range(torch.cuda.device_count())]
        if torch.backends.mps.is_available():
            devices.append('mps')
        return devices

    if is_cupy(xp):
        import cupy  # type: ignore[import]
        num_cuda = cupy.cuda.runtime.getDeviceCount()
        devices += [f'cuda:{i}' for i in range(num_cuda)]
        return devices

    if is_jax(xp):
        import jax  # type: ignore[import]
        for backend in ('cpu', 'gpu', 'tpu'):
            try:
                count = jax.device_count(backend=backend)
            except RuntimeError:
                # Backend not available in this installation: no devices.
                continue
            devices += [f'{backend}:{i}' for i in range(count)]
        return devices

    # given namespace is not known to have a list of available devices;
    # return `[None]` so that one can use this in tests for `device=None`.
    return [None]
def scipy_namespace_for(xp: ModuleType) -> ModuleType | None:
    """Return the `scipy`-like namespace corresponding to a backend.

    That is, return the namespace corresponding with backend `xp` that contains
    `scipy` sub-namespaces like `linalg` and `special`. If no such namespace
    exists, return ``None``. Useful for dispatching.

    Parameters
    ----------
    xp : module
        The array API namespace for which to find the corresponding SciPy-like
        namespace.

    Returns
    -------
    module or None
        ``cupyx.scipy`` for CuPy, ``jax.scipy`` for JAX, `xp` itself for
        PyTorch, ``scipy`` for NumPy, and None for any other namespace.

    Examples
    --------
    >>> import numpy as np
    >>> import scipy
    >>> scipy_namespace = scipy_namespace_for(np)
    >>> scipy_namespace is scipy
    True
    >>> import cupy as cp  # doctest: +SKIP
    >>> import cupyx  # doctest: +SKIP
    >>> scipy_namespace = scipy_namespace_for(cp)  # doctest: +SKIP
    >>> scipy_namespace is cupyx.scipy  # doctest: +SKIP
    True
    """
    # Backend packages are imported lazily so that none of them is required
    # unless the matching namespace is actually passed in.
    if is_cupy(xp):
        import cupyx  # type: ignore[import-not-found,import-untyped]
        return cupyx.scipy

    if is_jax(xp):
        import jax  # type: ignore[import-not-found]
        return jax.scipy

    if is_torch(xp):
        return xp

    if is_numpy(xp):
        import scipy
        return scipy

    return None
# temporary substitute for xp.moveaxis, which is not yet in all backends
# or covered by array_api_compat.
def xp_moveaxis_to_end(
        x: Array,
        source: int,
        /, *,
        xp: ModuleType | None = None) -> Array:
    """Move an axis to the end of the array.

    Parameters
    ----------
    x : Array
        The input array.
    source : int
        The source axis to move. Negative values count from the last axis.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the array.

    Returns
    -------
    Array
        Array with the source axis moved to the end.

    Raises
    ------
    IndexError
        If `source` is out of range for the dimensions of `x`.

    Notes
    -----
    This is a temporary substitute for xp.moveaxis, which is not yet
    available in all backends or covered by array_api_compat. It is
    implemented with ``xp.permute_dims``.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.zeros((2, 3, 4))
    >>> y = xp_moveaxis_to_end(x, 0)
    >>> y.shape
    (3, 4, 2)
    >>> # The first axis (axis 0) is now the last axis
    >>> np.array_equal(y, np.moveaxis(x, 0, -1))
    True
    """
    # Infer the namespace from the array itself (not from the missing `xp`).
    xp = array_namespace(x) if xp is None else xp
    axes = list(range(x.ndim))
    axes.append(axes.pop(source))
    return xp.permute_dims(x, axes)
# temporary substitute for xp.copysign, which is not yet in all backends
# or covered by array_api_compat.
def xp_copysign(x1: Array, x2: Array, /, *, xp: ModuleType | None = None
                ) -> Array:
    """Copy the sign of x2 to the magnitude of x1.

    Parameters
    ----------
    x1 : Array
        The array containing the magnitudes.
    x2 : Array
        The array containing the signs.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the arrays.

    Returns
    -------
    Array
        ``abs(x1)`` where ``x2 >= 0`` and ``-abs(x1)`` elsewhere.

    Notes
    -----
    This is a temporary substitute for xp.copysign, which is not yet available
    in all backends or covered by array_api_compat. This implementation does
    not attempt to account for special cases.

    Examples
    --------
    >>> import numpy as np
    >>> x1 = np.array([-1.3, 1.5, -3.0])
    >>> x2 = np.array([1.0, -2.2, 3.0])
    >>> xp_copysign(x1, x2, xp=np)
    array([ 1.3, -1.5,  3. ])
    >>> # Equivalent to NumPy's copysign
    >>> np.copysign(x1, x2)
    array([ 1.3, -1.5,  3. ])
    """
    # no attempt to account for special cases
    if xp is None:
        xp = array_namespace(x1, x2)
    magnitude = xp.abs(x1)
    sign_is_nonnegative = x2 >= 0
    return xp.where(sign_is_nonnegative, magnitude, -magnitude)
# partial substitute for xp.sign, which does not cover the NaN special case
# in every backend. (https://github.com/data-apis/array-api-compat/issues/136)
def xp_sign(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """Return the sign of each element in the input array.

    Parameters
    ----------
    x : Array
        The input array.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the array.

    Returns
    -------
    Array
        An array with the same shape as x, where each element has the sign of
        the corresponding element in x. The sign is defined as:

        - 1 for positive values
        - 0 for zero
        - -1 for negative values
        - NaN for NaN values

    Notes
    -----
    This is a partial substitute for xp.sign, which does not cover the NaN
    special case in some array API implementations. See
    https://github.com/data-apis/array-api-compat/issues/136 for more details.
    For NumPy, ``xp.sign`` is used directly.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([-5.0, 0.0, 3.0, np.nan])
    >>> xp_sign(x, xp=np)
    array([-1.,  0.,  1., nan])
    >>> # Equivalent to NumPy's sign
    >>> np.sign(x)
    array([-1.,  0.,  1., nan])
    """
    if xp is None:
        xp = array_namespace(x)
    if is_numpy(xp):  # only NumPy implements the special cases correctly
        return xp.sign(x)

    zeros = xp.zeros_like(x)
    unit = xp.asarray(1, dtype=x.dtype)
    # Layer the cases: positives, then negatives, then NaNs on top.
    with_positives = xp.where(x > 0, unit, zeros)
    with_negatives = xp.where(x < 0, -unit, with_positives)
    return xp.where(xp.isnan(x), xp.nan*unit, with_negatives)
# TODO: maybe use `scipy.linalg` if/when array API support is added
def xp_vector_norm(x: Array, /, *,
                   axis: int | tuple[int] | None = None,
                   keepdims: bool = False,
                   ord: int | float = 2,
                   xp: ModuleType | None = None) -> Array:
    """Compute the vector norm of an array.

    Parameters
    ----------
    x : Array
        The input array.
    axis : int or tuple of ints, optional
        The axis or axes along which to compute the norm. If None, the norm is
        computed over all elements in the array.
    keepdims : bool, optional
        If True, the axes which are reduced are left in the result as
        dimensions with size one. With this option, the result will broadcast
        correctly against the original array.
    ord : int or float, optional
        The order of the norm. Default is 2 (Euclidean norm).
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the array.

    Returns
    -------
    Array
        The vector norm of the input array.

    Raises
    ------
    ValueError
        If `ord` is not 2 and the namespace has no ``linalg`` extension
        (only when the ``SCIPY_ARRAY_API`` switch is enabled).

    Notes
    -----
    When the ``SCIPY_ARRAY_API`` switch is enabled, ``xp.linalg.vector_norm``
    is used if the namespace has a ``linalg`` attribute; otherwise only the
    Euclidean norm is available, computed as
    ``sum(conj(x) * x) ** 0.5``. When the switch is disabled,
    `numpy.linalg.norm` is used for backwards compatibility.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([3.0, 4.0])
    >>> float(xp_vector_norm(x, xp=np))  # Euclidean norm (default)
    5.0
    >>> float(xp_vector_norm(x, ord=1, xp=np))  # L1 norm
    7.0
    >>> x = np.array([[1, 2], [3, 4]])
    >>> xp_vector_norm(x, axis=0, xp=np)  # Norm along first axis
    array([3.16227766, 4.47213595])
    >>> xp_vector_norm(x, axis=1, xp=np)  # Norm along second axis
    array([2.23606798, 5.        ])
    """
    if xp is None:
        xp = array_namespace(x)

    if not SCIPY_ARRAY_API:
        # to maintain backwards compatibility
        return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)

    # check for optional `linalg` extension
    if hasattr(xp, 'linalg'):
        return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims,
                                     ord=ord)

    if ord != 2:
        raise ValueError(
            "only the Euclidean norm (`ord=2`) is currently supported"
            " in `xp_vector_norm` for backends not implementing the"
            " `linalg` extension."
        )
    # (x @ x)**0.5 would do for real 1-D input; the form below also gives
    # the right behavior with nd, complex arrays.
    squared_magnitudes = xp.conj(x) * x
    return xp.sum(squared_magnitudes, axis=axis, keepdims=keepdims)**0.5
def xp_ravel(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """Return a flattened array.

    Parameters
    ----------
    x : Array
        The input array to be flattened.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the array.

    Returns
    -------
    Array
        A 1-D array containing the elements of the input array.

    Notes
    -----
    Equivalent to `np.ravel`, written in terms of the array API as a
    reshape to ``(-1,)``. It is a one-liner, but common enough to deserve a
    named helper for readability.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([[1, 2], [3, 4]])
    >>> xp_ravel(x, xp=np)
    array([1, 2, 3, 4])
    >>> # Equivalent to NumPy's ravel
    >>> np.ravel(x)
    array([1, 2, 3, 4])
    """
    if xp is None:
        xp = array_namespace(x)
    flat_shape = (-1,)
    return xp.reshape(x, flat_shape)
def xp_real(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    """Return the real part of a complex array or the array itself if it's
    not complex.

    Parameters
    ----------
    x : Array
        The input array.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the array.

    Returns
    -------
    Array
        The real part of the input array if it has a complex data type,
        otherwise the input array.

    Notes
    -----
    This is a convenience wrapper of xp.real that allows non-complex input;
    see data-apis/array-api#824.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1+2j, 3+4j])
    >>> xp_real(x, xp=np)
    array([1., 3.])
    >>> y = np.array([1, 2, 3])
    >>> xp_real(y, xp=np)  # Non-complex input is returned as is
    array([1, 2, 3])
    """
    if xp is None:
        xp = array_namespace(x)
    if xp.isdtype(x.dtype, 'complex floating'):
        return xp.real(x)
    return x
def xp_take_along_axis(arr: Array,
                       indices: Array, /, *,
                       axis: int = -1,
                       xp: ModuleType | None = None) -> Array:
    """Take values from an array along an axis at the indices specified.

    Parameters
    ----------
    arr : Array
        The source array.
    indices : Array
        The indices of the values to extract. It must have the same number
        of dimensions as `arr`.
    axis : int, optional
        The axis over which to select values. Default is -1 (the last axis).
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from `arr`.

    Returns
    -------
    Array
        The indexed array.

    Raises
    ------
    NotImplementedError
        If the namespace is ``array_api_strict``.

    Notes
    -----
    PyTorch is served through ``take_along_dim``; every other supported
    namespace through its ``take_along_axis``. `axis` is passed by keyword
    because the array API standard declares it keyword-only, while NumPy
    accepts it either way. See data-apis/array-api#816.

    Examples
    --------
    >>> import numpy as np
    >>> a = np.array([[10, 30, 20], [60, 40, 50]])
    >>> # Sort along the last axis
    >>> ai = np.argsort(a)
    >>> ai
    array([[0, 2, 1],
           [1, 2, 0]])
    >>> xp_take_along_axis(a, ai, xp=np)
    array([[10, 20, 30],
           [40, 50, 60]])
    >>> # Sort along the first axis
    >>> ai = np.argsort(a, axis=0)
    >>> ai
    array([[0, 0, 0],
           [1, 1, 1]])
    >>> xp_take_along_axis(a, ai, axis=0, xp=np)
    array([[10, 30, 20],
           [60, 40, 50]])
    """
    xp = array_namespace(arr) if xp is None else xp
    if is_torch(xp):
        return xp.take_along_dim(arr, indices, dim=axis)
    elif is_array_api_strict(xp):
        raise NotImplementedError("Array API standard does not define"
                                  " take_along_axis")
    else:
        # Keyword form works for NumPy-style and standard-style signatures.
        return xp.take_along_axis(arr, indices, axis=axis)
# utility to promote real floating-point arrays to the matching complex dtype
def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array:
    """Convert a floating-point array to a complex array.

    Parameters
    ----------
    arr : Array
        The input array.
    xp : module, optional
        The array API namespace to use. If not provided, the namespace is
        inferred from the array.

    Returns
    -------
    Array
        The input array converted to a complex data type. float32 is converted
        to complex64, and float64 (and other real floating types) are
        converted to complex128.

    Notes
    -----
    Only arrays with a real floating-point data type are converted. Any
    other input (for example complex or integer arrays) is returned
    unchanged.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    >>> y = xp_float_to_complex(x, xp=np)
    >>> y.dtype
    dtype('complex64')
    >>> z = np.array([4.0, 5.0, 6.0], dtype=np.float64)
    >>> w = xp_float_to_complex(z, xp=np)
    >>> w.dtype
    dtype('complex128')
    >>> # Complex input is returned unchanged
    >>> c = np.array([1+2j, 3+4j])
    >>> xp_float_to_complex(c, xp=np) is c
    True
    """
    if xp is None:
        xp = array_namespace(arr)
    source_dtype = arr.dtype

    # The standard float dtypes are float32 and float64: float32 maps to
    # complex64, float64 (and non-standard real dtypes) to complex128.
    if xp.isdtype(source_dtype, xp.float32):
        target_dtype = xp.complex64
    elif xp.isdtype(source_dtype, 'real floating'):
        target_dtype = xp.complex128
    else:
        return arr
    return xp.astype(arr, target_dtype)
# Script entry point: intentionally a no-op, the module only provides helpers.
if __name__ == '__main__':
    pass