Source code for earthkit.utils.array.converter.torch
# (C) Copyright 2025 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from earthkit.utils.array.converter.unknown import FromUnknownConverter
[docs]
class FromTorchConverter(FromUnknownConverter):
    """Convert ``torch.Tensor`` objects into other array namespaces.

    Each ``to_<namespace>`` method takes a torch tensor and returns the
    corresponding array object of the target namespace.
    """

    def __init__(self, xp_target):
        super().__init__(xp_target)

    def to_numpy(self, array, **kwargs):
        """Return *array* as a NumPy ndarray.

        ``Tensor.numpy()`` raises for tensors that require grad and for
        tensors not in host memory, so the tensor is detached and moved to
        the CPU first. ``detach()`` shares storage and ``cpu()`` is a no-op
        for CPU tensors, so for a CPU tensor the result still shares memory
        with *array*, as before.

        Parameters
        ----------
        array : torch.Tensor
            Tensor to convert.
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        numpy.ndarray
        """
        return array.detach().cpu().numpy()

    def to_cupy(self, array, **kwargs):
        """Return *array* as a CuPy ndarray.

        CUDA tensors are exchanged zero-copy through DLPack. Tensors on any
        other device are copied to host memory and then uploaded to the
        current CuPy device, because ``cupy.from_dlpack`` only accepts
        GPU-resident data.

        Parameters
        ----------
        array : torch.Tensor
            Tensor to convert.
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        cupy.ndarray
        """
        import cupy as cp

        # torch refuses to export tensors that require grad via DLPack;
        # detach() drops the autograd link without copying the data.
        tensor = array.detach()
        if tensor.device.type == "cuda":
            return cp.from_dlpack(tensor)
        # TODO: allow the caller to choose the target GPU device.
        return cp.asarray(self.to_numpy(tensor))

    def to_torch(self, array, **kwargs):
        """Return *array* unchanged (it is already a torch tensor).

        Parameters
        ----------
        array : torch.Tensor
            Tensor to return.
        **kwargs
            Accepted for interface compatibility; ignored.
        """
        # TODO: add device handling (e.g. moving to a requested device)?
        return array

    # TODO: add ``to_jax`` conversion.