Source code for earthkit.utils.array.array_namespace
# (C) Copyright 2025 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
import typing as T
import array_api_compat
from earthkit.utils.array.namespace import _DEFAULT_NAMESPACE, _NAMESPACES, UnknownPatchedNamespace
def _get_array_name(xp):
    """Reduce a namespace object's ``__name__`` to a short backend key.

    Parameters
    ----------
    xp: module
        Any object exposing ``__name__`` (typically an array namespace module).

    Returns
    -------
    str
        The first of ``"jax"``, ``"numpy"``, ``"cupy"`` or ``"torch"`` that
        occurs as a substring of ``xp.__name__``; if none occurs, the
        unmodified ``xp.__name__``.
    """
    module_name = xp.__name__
    # The order is significant: "jax" is tested before "numpy" so that a name
    # containing both substrings (e.g. "jax.numpy") resolves to "jax".
    for backend in ("jax", "numpy", "cupy", "torch"):
        if backend in module_name:
            return backend
    return module_name
def _get_namespace_from_array(*arrays):
    """Resolve the patched namespace object for the given array(s).

    The raw namespace is obtained from
    ``array_api_compat.array_namespace(*arrays)``, its module name is reduced
    with ``_get_array_name`` and the result is used as the key into
    ``_NAMESPACES``.

    Parameters
    ----------
    *arrays: tuple
        Array objects forwarded unchanged to
        ``array_api_compat.array_namespace``.

    Returns
    -------
    object
        The entry registered in ``_NAMESPACES`` for the detected backend, or
        an ``UnknownPatchedNamespace`` wrapping the raw namespace when the
        lookup yields ``None``.
    """
    raw_xp = array_api_compat.array_namespace(*arrays)
    backend = _get_array_name(raw_xp)
    patched = _NAMESPACES.get(backend)
    if patched is not None:
        return patched
    # No registered entry for this backend: wrap the raw namespace instead.
    return UnknownPatchedNamespace(raw_xp)
[docs]
def array_namespace(*args: T.Any) -> T.Any:
    """Return the patched array namespace matching the arguments.

    Parameters
    ----------
    *args: tuple
        Scalar, string or array-like arguments.

    Returns
    -------
    xp: module
        The patched array namespace of the arguments. The namespace
        returned from array_api_compat.array_namespace(*args) is patched with
        extra/modified methods. When no argument is recognised as an array
        the resolution is:

        - a single string argument is looked up directly in ``_NAMESPACES``
          (an unregistered name fails at that lookup);
        - a single object exposing ``asarray`` (or, failing that, ``array``)
          is used to build a probe value from ``0`` whose namespace is
          returned;
        - anything else, including a lone scalar, no arguments, or several
          non-array arguments, yields ``_DEFAULT_NAMESPACE`` (the numpy
          namespace according to the original documentation).

    Notes
    -----
    The array namespace is extended with the following methods when necessary:

    - polyval: evaluate a polynomial (available in numpy)
    - percentile: compute the n-th percentile of the data along the
      specified axis (available in numpy)
    - histogram2d: compute a 2D histogram (available in numpy)

    Some other methods may be reimplemented for a given namespace to ensure
    correct behaviour. E.g. sign() for torch.
    """
    array_args = [arg for arg in args if array_api_compat.is_array_api_obj(arg)]
    if array_args:
        # Only the arguments recognised as arrays take part in the detection.
        return _get_namespace_from_array(*array_args)

    # TODO: decide if we want to support the non-array forms below or not
    # i.e. array_namespace("numpy")
    #      array_namespace(np)
    if len(args) != 1:
        return _DEFAULT_NAMESPACE

    (candidate,) = args
    if isinstance(candidate, str):
        return _NAMESPACES[candidate]

    # A namespace-like object: create a probe value with its own factory and
    # resolve the namespace from that value. "asarray" takes precedence over
    # "array".
    for factory_name in ("asarray", "array"):
        if hasattr(candidate, factory_name):
            factory = getattr(candidate, factory_name)
            return _get_namespace_from_array(factory(0))

    return _DEFAULT_NAMESPACE