# Source code for earthkit.utils.array.convert

# (C) Copyright 2025 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from earthkit.utils.array.array_namespace import _get_array_name
from earthkit.utils.array.array_namespace import array_namespace as array_namespace_func
from earthkit.utils.array.converter import _CONVERTERS, FromUnknownConverter
from earthkit.utils.array.namespace import _CUPY_NAMESPACE, _NUMPY_NAMESPACE, UnknownPatchedNamespace


def _get_converter(source_array_namespace):
    """Look up the converter class registered for a source array namespace.

    Parameters
    ----------
    source_array_namespace : UnknownPatchedNamespace or str
        Either a patched namespace object, whose name is resolved with
        ``_get_array_name``, or the registry key itself as a string.

    Returns
    -------
    object
        The entry found in ``_CONVERTERS``. For a namespace object whose name
        has no entry, ``FromUnknownConverter`` is returned instead.

    Raises
    ------
    KeyError
        If a string name is given that has no entry in ``_CONVERTERS``.
    ValueError
        If the argument is neither an ``UnknownPatchedNamespace`` nor a ``str``.
    """
    if isinstance(source_array_namespace, UnknownPatchedNamespace):
        return _CONVERTERS.get(_get_array_name(source_array_namespace), FromUnknownConverter)

    if isinstance(source_array_namespace, str):
        return _CONVERTERS[source_array_namespace]

    # Anything can end up here, including objects that do not carry the
    # earthkit name attribute. Reading it unguarded would mask the intended
    # ValueError with an AttributeError, so fall back to the object's repr.
    try:
        backend = source_array_namespace._earthkit_array_namespace_name
    except AttributeError:
        backend = repr(source_array_namespace)
    raise ValueError(f"Unknown array backend: {backend}")


def convert(array, *, device=None, array_namespace=None, **kwargs):
    """Return a copy/view of a converted array.

    Parameters
    ----------
    array : array
        The array to convert.
    device : array namespace-specific device spec or str
        The device to which the array should be moved. For example, "cpu",
        "cuda:0", etc. When None, the array is not moved.
    array_namespace : str or array namespace
        The array namespace to use for the conversion. If None and ``device``
        is given, the following logic is applied:

        - if the device is "cpu" and ``array`` is a cupy array, numpy is used
        - if the device is not "cpu" and ``array`` is a numpy array, cupy is used
        - otherwise the namespace of ``array`` itself is used
    **kwargs
        Forwarded to the ``to_device`` call of the resulting array's namespace.
        They are not used when ``device`` is None.

    Returns
    -------
    array
        ``array`` itself when both ``device`` and ``array_namespace`` are None,
        otherwise the converted (and, if requested, moved) array.
    """
    # TODO: dtype conversion support also?
    # Nothing was requested: hand the input back without inspecting it.
    if array_namespace is None and device is None:
        return array

    source_xp = array_namespace_func(array)
    source_name = _get_array_name(source_xp)

    # Only a device was given: derive the target namespace from it.
    if array_namespace is None:
        if device == "cpu" and source_name == "cupy":
            array_namespace = _NUMPY_NAMESPACE
        elif device != "cpu" and source_name == "numpy":
            array_namespace = _CUPY_NAMESPACE
        else:
            array_namespace = source_xp

    if array_namespace is not None:
        target_xp = array_namespace_func(array_namespace)
        converter = _get_converter(source_xp)
        converter_instance = converter(target_xp)
        target_name = _get_array_name(target_xp)
        # TODO: decide if we want to pass device here, or later.
        # Currently, do it later
        array = converter_instance.to(array, target_name)

    # The device move happens after the namespace conversion, so it uses the
    # namespace of the already-converted array.
    if device is not None:
        xp = array_namespace_func(array)
        array = xp.to_device(array, device=device, **kwargs)

    return array
def convert_dtype(dtype, array_namespace):
    """Translate a dtype into the equivalent dtype of another array namespace.

    Parameters
    ----------
    dtype : str or dtype object
        A dtype name, which is looked up directly in the target namespace, or
        a dtype object. For a dtype object the source namespace is taken from
        the top-level module of ``type(dtype)``. Objects whose type lives in
        ``builtins`` are passed through ``numpy.dtype`` and treated as numpy
        dtypes.
    array_namespace : str or array namespace
        The target namespace, resolved with ``array_namespace_func``.

    Returns
    -------
    dtype object
        The entry of the target namespace's
        ``__array_namespace_info__().dtypes()`` that has the same name as
        ``dtype`` has in the source namespace.

    Raises
    ------
    KeyError
        If ``dtype`` is a name the target namespace does not list, or a dtype
        object with no same-named counterpart in the target namespace.
    """
    target_xp = array_namespace_func(array_namespace)
    target_dtypes = target_xp.__array_namespace_info__().dtypes()

    if isinstance(dtype, str):
        return target_dtypes[dtype]

    source_array_namespace_name = type(dtype).__module__.split(".")[0]

    # NB: this is very hacky and should be changed
    if source_array_namespace_name == "builtins":
        import numpy as np

        source_xp = _NUMPY_NAMESPACE
        dtype = np.dtype(dtype)
    else:
        source_xp = array_namespace_func(source_array_namespace_name)

    source_dtypes = source_xp.__array_namespace_info__().dtypes()

    # Pair up the dtypes by the names both namespaces share. Sorting keeps the
    # insertion order deterministic. A plain dict comprehension also copes
    # with an empty overlap, which then surfaces as the KeyError below.
    shared_names = sorted(set(source_dtypes) & set(target_dtypes))
    mapping = {source_dtypes[name]: target_dtypes[name] for name in shared_names}
    return mapping[dtype]