# Source code for earthkit.utils.decorators._xarray_ufunc
# (C) Copyright 2021 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#
from __future__ import annotations
def _infer_output_count(func) -> int:
    """Infer how many outputs ``func`` returns from its return annotation.

    A return annotation that is a fixed-length tuple (for example
    ``tuple[float, float]``) yields the number of tuple members. Everything
    else -- no annotation, a non-tuple annotation, a variadic
    ``tuple[X, ...]``, or a callable whose signature cannot be inspected --
    counts as a single output.

    Annotations stored as strings (quoted annotations, or any function
    defined under ``from __future__ import annotations``) are resolved before
    being examined. When a string annotation cannot be resolved the count
    falls back to ``1``.

    Parameters
    ----------
    func : callable
        The callable whose return annotation is examined.

    Returns
    -------
    int
        The inferred number of outputs; never less than ``1``.
    """
    import inspect
    from typing import get_args, get_origin, get_type_hints

    try:
        annotation = inspect.signature(func).return_annotation
    except (ValueError, TypeError):
        # No introspectable signature is available for this callable.
        return 1

    if annotation is inspect.Signature.empty:
        return 1

    if isinstance(annotation, str):
        # A postponed/quoted annotation is just text at this point, so
        # ``get_origin`` would never recognise it as a tuple. Evaluate it in
        # the function's own namespace first.
        try:
            annotation = get_type_hints(func).get("return", annotation)
        except Exception:
            # Evaluating arbitrary annotation text can fail in many ways
            # (NameError, SyntaxError, TypeError, ...). Inference is
            # best-effort, so fall back to the single-output default.
            return 1

    if get_origin(annotation) is tuple:
        members = get_args(annotation)
        # ``tuple[X, ...]`` has unknown length, so it is not counted.
        if members and members[-1] is not Ellipsis:
            return len(members)
    return 1
# [docs]
def xarray_ufunc(func, *args, **kwargs):
    """Call ``func`` on ``args`` through ``xarray.apply_ufunc``.

    Options for ``xarray.apply_ufunc`` start from the defaults
    ``dask="parallelized"`` and ``keep_attrs=True`` and are then overridden
    by whatever is supplied under the ``xarray_ufunc_kwargs`` keyword. All
    other keyword arguments are handed to ``func`` via the ``kwargs``
    parameter of ``xarray.apply_ufunc``.

    When ``output_dtypes`` is not supplied it defaults to ``float`` repeated
    once per output inferred from the return annotation of ``func``. When
    there is more than one output dtype, any core dimensions not supplied
    are filled in: each output gets the ``dims`` of the first positional
    argument, and each input gets its own ``dims``.

    Parameters
    ----------
    func : callable
        Function passed to ``xarray.apply_ufunc``.
    *args
        Positional arguments passed to ``xarray.apply_ufunc``. Their
        ``dims`` attribute is read when core dimensions must be derived.
    **kwargs
        Keyword arguments for ``func``. The special key
        ``xarray_ufunc_kwargs`` is removed and used as options for
        ``xarray.apply_ufunc``.

    Returns
    -------
    object
        Whatever ``xarray.apply_ufunc`` returns.
    """
    import xarray as xr

    options = {"dask": "parallelized", "keep_attrs": True}
    # Caller-supplied options win over the defaults above.
    options.update(kwargs.pop("xarray_ufunc_kwargs", {}))

    if "output_dtypes" not in options:
        options["output_dtypes"] = [float] * _infer_output_count(func)

    def _should_fill(key):
        # Core dims are only derived for multi-output calls, and only when
        # the caller has not provided them explicitly.
        return key not in options and len(options["output_dtypes"]) > 1

    if _should_fill("output_core_dims"):
        options["output_core_dims"] = [
            args[0].dims for _dtype in options["output_dtypes"]
        ]

    if _should_fill("input_core_dims"):
        options["input_core_dims"] = [operand.dims for operand in args]

    return xr.apply_ufunc(func, *args, kwargs=kwargs, **options)