Source code for earthkit.utils.decorators._dispatch

# (C) Copyright 2021 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#
from __future__ import annotations

import logging
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import Callable
from functools import wraps
from importlib import import_module
from inspect import signature
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import xarray as xr  # noqa: F401
    from earthkit.data import FieldList  # noqa: F401

LOG = logging.getLogger(__name__)


def is_module_loaded(module_name):
    return module_name in sys.modules


def _is_xarray(obj: Any) -> bool:
    if not is_module_loaded("xarray"):
        return False

    try:
        import xarray as xr

    except ImportError:
        return False

    return isinstance(obj, (xr.DataArray, xr.Dataset))


def _is_fieldlist(obj: Any) -> bool:
    if not is_module_loaded("earthkit.data"):
        return False

    try:
        from earthkit.data import FieldList

    except ImportError:
        return False

    return isinstance(obj, FieldList)


def _is_array(obj: Any) -> bool:
    import array_api_compat

    return array_api_compat.is_array_api_obj(obj)


def is_array_like(obj: Any) -> bool:
    """Check if the object is array-like.

    True if it belongs to a known array namespace
    or is a scalar or list that can be converted to an array.
    """
    import numpy as np

    try:
        np.asarray(obj)
        return True
    # TODO: Improve this exception handling
    except Exception:
        return False


class DataDispatcher(metaclass=ABCMeta):
    """A dispatcher class to route function calls based on input data types."""

    @staticmethod
    @abstractmethod
    def match(obj: Any) -> bool:
        pass

    @abstractmethod
    def dispatch(self, func: str, module: str, *args: Any, **kwargs: Any) -> Any:
        pass


class XArrayDispatcher(DataDispatcher):
    @staticmethod
    def match(obj: Any) -> bool:
        return _is_xarray(obj)

    def dispatch(self, func, module, *args, **kwargs):
        module = import_module(module + ".xarray")
        return getattr(module, func)(*args, **kwargs)


class FieldListDispatcher(DataDispatcher):
    @staticmethod
    def match(obj: Any) -> bool:
        return _is_fieldlist(obj)

    def dispatch(self, func, module, *args, **kwargs):
        module = import_module(module + ".fieldlist")
        return getattr(module, func)(*args, **kwargs)


class ArrayDispatcher(DataDispatcher):
    @staticmethod
    def match(obj: Any) -> bool:
        return _is_array(obj)

    def dispatch(self, func, module, *args, **kwargs):
        module = import_module(module + ".array")
        return getattr(module, func)(*args, **kwargs)


class ArrayLikeDispatcher(ArrayDispatcher):
    @staticmethod
    def match(obj: Any) -> bool:
        return is_array_like(obj)


[docs] def dispatch( func: Callable, match: int | str = 0, xarray: bool = True, fieldlist: bool = True, array: bool = True, array_like: bool = False, ): """Decorator to dispatch function calls based on input data types. The dispatch will attempt to route the call to the appropriate implementation based on the type of the specified argument. The implementations are assumed to live in submodules named after the data type (e.g., .xarray, .fieldlist, .array) with the same function name as the toplevel function. This wrapper should be applied inline as: def func(...): return dispatch(func, match=..., xarray=..., fieldlist=..., array=..., array_like=...)(...) Parameters ---------- func: function The toplevel function to be wrapped. match: int or str The index or name of the argument to check for dispatching. Default is 0 (the first argument). xarray: bool Whether to include the xarray dispatcher. Default is True. fieldlist: bool Whether to include the FieldList dispatcher. Default is True. array: bool Whether to include the array dispatcher. Default is True. array_like: bool Whether to include the array-like dispatcher. Default is False. Returns ------- function The decorated function with dispatching capability. """ def _make_wrapper(_func): DISPATCHERS = [] if xarray: DISPATCHERS.append(XArrayDispatcher()) if fieldlist: DISPATCHERS.append(FieldListDispatcher()) if array: DISPATCHERS.append(ArrayDispatcher()) if array_like: DISPATCHERS.append(ArrayLikeDispatcher()) sig = signature(_func) params = list(sig.parameters) if isinstance(match, int): try: param_name = params[match] except IndexError as e: raise ValueError( f"'match' index {match} is invalid for function {_func.__name__} with {len(params)} arguments" ) from e elif isinstance(match, str): if match in params: param_name = match else: raise ValueError(f"'match' parameter name {match} is not in the function signature of {_func.__name__}") else: raise TypeError(f"'match' must be an integer index or a string parameter name, got {type(match)}") @wraps(_func) def wrapper(*args, **kwargs): bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() obj_to_check = bound_args.arguments[param_name] module_name = _func.__module__ parent_module, _sep, _ = module_name.rpartition(".") _module = parent_module if parent_module else module_name for dispatcher in DISPATCHERS: try: _matched = dispatcher.match(obj_to_check) except Exception as e: LOG.debug(f"Dispatcher {dispatcher.__class__.__name__} failed to match due to error: {e}") continue if _matched: return dispatcher.dispatch(_func.__name__, _module, *args, **kwargs) raise TypeError( f"No dispatcher matched for function {_func.__name__} with argument {param_name} " f"of type {type(obj_to_check)}, and no default dispatcher specified." ) return wrapper # Called as dispatch(func, ...) return _make_wrapper(func)