Source code for earthkit.utils.array.namespace.cupy

# (C) Copyright 2025 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from earthkit.utils.array.namespace.unknown import UnknownPatchedNamespace
from earthkit.utils.decorators import thread_safe_cached_property


class PatchedCupyNamespace(UnknownPatchedNamespace):
    """Array namespace wrapper for CuPy.

    Wraps ``array_api_compat.cupy`` (exposed via :attr:`xp`) and adds
    helpers that are not part of the Array API standard: polynomial
    evaluation, percentiles/quantiles, histograms, device placement and
    random choice.
    """

    def __init__(self):
        # No namespace is handed to the base class; ``xp`` below imports
        # ``array_api_compat.cupy`` on first access instead.
        super().__init__(None)

    @thread_safe_cached_property
    def xp(self):
        """The ``array_api_compat.cupy`` module, imported on first access."""
        import array_api_compat.cupy as cp

        return cp

    @property
    def _earthkit_array_namespace_name(self):
        return "cupy"

    def polyval(self, *args, **kwargs):
        """Evaluate a polynomial via ``cupy.polynomial.polynomial.polyval``.

        All positional and keyword arguments are forwarded unchanged.
        """
        from cupy.polynomial.polynomial import polyval

        return polyval(*args, **kwargs)

    def percentile(self, a, q, axis=None):
        """Return the ``q``-th percentile(s) of ``a`` along ``axis``."""
        return self.xp.percentile(a, q, axis=axis)

    def quantile(self, a, q, axis=None):
        """Return the ``q``-th quantile(s) of ``a`` along ``axis``."""
        return self.xp.quantile(a, q, axis=axis)

    def histogram2d(self, x, y, *, bins=10):
        """Compute the 2-D histogram of ``x`` and ``y`` with ``bins`` bins."""
        return self.xp.histogram2d(x, y, bins=bins)

    def histogramdd(self, x, *, bins=10):
        """Compute the multidimensional histogram of ``x`` with ``bins`` bins."""
        return self.xp.histogramdd(x, bins=bins)

    @staticmethod
    def _cuda_device_id(device):
        """Map a ``"cuda"``/``"cuda:N"`` string to an int; pass anything else through.

        ``"cuda"`` (no index) maps to device ``0``. Non-string values (e.g.
        an int or a device object) are returned unchanged.
        """
        if isinstance(device, str) and device.startswith("cuda"):
            _, _, idx = device.partition(":")
            return int(idx) if idx else 0
        return device

    def asarray(self, *args, **kwargs):
        """Create a CuPy array, optionally on a specific GPU.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``xp.asarray``. The optional ``device`` keyword is
            removed and handled here: the array is created while
            ``xp.cuda.Device(<id>)`` is the active context, where ``<id>`` is
            derived from ``"cuda"``/``"cuda:N"`` strings or the value itself.
        """
        device = kwargs.pop("device", None)
        if device is None:
            return self.xp.asarray(*args, **kwargs)
        with self.xp.cuda.Device(self._cuda_device_id(device)):
            return self.xp.asarray(*args, **kwargs)

    def to_device(self, x, device, **kwargs):
        """Return ``x`` as a CuPy array created under ``device`` (see :meth:`asarray`)."""
        return self.asarray(x, device=device, **kwargs)

    def rad2deg(self, x):
        """Convert angles in ``x`` from radians to degrees."""
        return self.xp.rad2deg(x)

    def deg2rad(self, x):
        """Convert angles in ``x`` from degrees to radians."""
        return self.xp.deg2rad(x)

    def choice(self, a, size, replace=True, generator=None):
        """Draw ``size`` random samples from ``a``.

        Parameters
        ----------
        a
            Population to sample from (forwarded to ``choice``).
        size
            Number/shape of samples to draw.
        replace
            Whether sampling is done with replacement.
        generator
            Optional random source. It must provide a NumPy-style
            ``choice(a, size=..., replace=...)`` method (for example a
            ``cupy.random.RandomState``). When ``None``, the module-level
            ``cupy.random.choice`` is used.
        """
        if generator is None:
            # ``default_rng`` is not available at cupy's top level (it lives in
            # ``cupy.random``), and CuPy's ``Generator`` does not provide
            # ``choice``; the module-level function does.
            from cupy.random import choice as cupy_choice

            return cupy_choice(a, size=size, replace=replace)
        return generator.choice(a, size=size, replace=replace)