Coverage for install/scipp/compat/xarray_compat.py: 74%
68 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
4from __future__ import annotations
6from typing import TYPE_CHECKING
7from warnings import warn
9from ..core import DataArray, Dataset, Unit, Variable
10from ..typing import VariableLike
11from ..units import default_unit
13if TYPE_CHECKING:
14 import xarray as xr
def from_xarray(obj: xr.Variable | xr.DataArray | xr.Dataset) -> VariableLike:
    """Convert an xarray object to the corresponding scipp object.

    Attributes named ``"units"`` are used to set the units of the Variables.
    All other attributes are dropped. A warning is emitted when a DataArray
    (including a data variable of a Dataset) has attributes other than
    ``"units"``, or when a Dataset has attributes of its own. Extra attributes
    of Variables and coordinates are dropped silently.

    Coordinates of a DataArray that have no xarray index are marked as
    unaligned in the resulting scipp object.

    Parameters
    ----------
    obj:
        The xarray object to convert.

    Returns
    -------
    :
        The converted scipp object.

    Raises
    ------
    ValueError
        If ``obj`` is not an xarray Variable, DataArray or Dataset.

    See Also
    --------
    scipp.compat.to_xarray
    """
    # Imported lazily so that scipp does not require xarray at import time.
    import xarray as xr

    if isinstance(obj, xr.Variable):
        return _from_xarray_variable(obj)
    elif isinstance(obj, xr.DataArray):
        return _from_xarray_dataarray(obj)
    elif isinstance(obj, xr.Dataset):
        return _from_xarray_dataset(obj)
    else:
        raise ValueError(f"from_xarray: cannot convert type '{type(obj)}'")
def to_xarray(obj: VariableLike) -> xr.Variable | xr.DataArray | xr.Dataset:
    """Convert a scipp object to the corresponding xarray object.

    Warning
    -------
    Any masks and variances in the input will be stripped during the conversion.
    Binned data is not supported.

    Parameters
    ----------
    obj:
        The scipp object to convert: a Variable, DataArray or Dataset.

    Returns
    -------
    :
        The converted xarray object (Variable, DataArray or Dataset respectively).

    Raises
    ------
    ValueError
        If ``obj`` is none of the supported scipp types.

    See Also
    --------
    scipp.compat.from_xarray
    """
    # Checked in this order; the first matching type selects the converter.
    converters = (
        (Variable, _to_xarray_variable),
        (DataArray, _to_xarray_dataarray),
        (Dataset, _to_xarray_dataset),
    )
    for scipp_type, convert in converters:
        if isinstance(obj, scipp_type):
            return convert(obj)
    raise ValueError(f"to_xarray: cannot convert type '{type(obj)}'")
def _from_xarray_variable(xr_obj: xr.Variable | xr.DataArray) -> Variable:
    """Convert an xarray Variable, or the data of an xarray DataArray, to a scipp.Variable.

    Coordinates reach this function as DataArrays (items of ``.coords``), so
    both input types are accepted.

    The ``"units"`` attribute, if present, is parsed with ``Unit``. Otherwise
    scipp's ``default_unit`` is used. All other attributes are ignored.
    """
    units = xr_obj.attrs.get('units')
    return Variable(
        dims=xr_obj.dims,
        values=xr_obj.values,
        unit=default_unit if units is None else Unit(units),
    )
def _to_xarray_variable(var: Variable) -> xr.Variable:
    """Convert a scipp.Variable to an xarray.Variable.

    The result can be used either as the data of an xarray DataArray/Dataset
    or as a coordinate. The unit, if any, is stored in the ``"units"``
    attribute. Variances are dropped with a warning.

    Raises
    ------
    ValueError
        If ``var`` holds binned data.
    """
    import xarray as xr

    if var.bins is not None:
        raise ValueError("Xarray does not support binned data.")
    if var.variances is not None:
        # stacklevel=3 points at the caller of to_xarray when a Variable is
        # converted directly. For Variables converted as part of a DataArray
        # or Dataset it points at to_xarray itself.
        warn(
            "Variances of variable were stripped when converting to Xarray.",
            stacklevel=3,
        )
    # A variable without a unit gets no attributes at all.
    attrs = {'units': str(var.unit)} if var.unit is not None else None
    return xr.Variable(dims=var.dims, data=var.values, attrs=attrs)
def _from_xarray_dataarray(da: xr.DataArray) -> DataArray:
    """Convert an xarray.DataArray to a scipp.DataArray.

    The data and every coordinate are converted with ``_from_xarray_variable``.
    Coordinate names and the array name are converted to strings. An array
    whose name is ``None`` gets the empty name ``""``. Coordinates without an
    xarray index are marked as unaligned. Attributes other than ``"units"``
    are dropped with a warning.
    """
    if da.attrs and set(da.attrs) != {"units"}:
        warn(
            "Input data contains some attributes which have been dropped during the "
            "conversion.",
            stacklevel=3,
        )
    coords = {
        f"{name}": _from_xarray_variable(coord) for name, coord in da.coords.items()
    }
    # Stringify the name the same way as the coordinate names above, so a
    # non-string name such as 0 is kept instead of collapsing to "".
    raw_name = getattr(da, "name", None)
    scipp_da = DataArray(
        data=_from_xarray_variable(da),
        coords=coords,
        name="" if raw_name is None else f"{raw_name}",
    )
    # Coordinates that xarray does not index become unaligned coords in scipp.
    for name in da.coords:
        if name not in da.indexes:
            scipp_da.coords.set_aligned(f'{name}', False)
    return scipp_da
def _to_xarray_dataarray(da: DataArray) -> xr.DataArray:
    """Convert a scipp.DataArray to an xarray.DataArray.

    Masks are dropped with a warning. Variances are dropped (with a warning)
    by ``_to_xarray_variable``. Coordinates that are unaligned in scipp have
    their xarray index removed, if xarray created one for them.

    Raises
    ------
    ValueError
        If a coordinate contains bin edges, or if the data or a coordinate is
        binned.
    """
    import xarray as xr

    if da.masks:
        warn(
            "Some masks were found in the DataArray. "
            "These have been removed when converting to Xarray.",
            stacklevel=3,
        )
    out = xr.DataArray(_to_xarray_variable(da.data))
    for key, coord in da.coords.items():
        # ``key`` comes from ``da.coords``, so query ``da.coords`` directly
        # rather than the broader ``meta`` view.
        if any(da.coords.is_edges(key, dim=dim) for dim in coord.dims):
            raise ValueError("Xarray does not support coordinates with bin edges.")
        out.coords[key] = _to_xarray_variable(coord)
    # xarray only indexes dimension coordinates, so an unaligned scalar or
    # non-dimension coordinate has no index, and drop_indexes raises
    # for such names. Only drop indexes that actually exist.
    unaligned_indexed = [
        key
        for key, coord in da.coords.items()
        if not coord.aligned and key in out.xindexes
    ]
    return out.drop_indexes(unaligned_indexed)
def _from_xarray_dataset(ds: xr.Dataset) -> Dataset:
    """Convert an xarray.Dataset to a scipp.Dataset.

    Each data variable is converted with ``_from_xarray_dataarray``, together
    with its own coordinates. Dataset coordinates that appear on no data
    variable are added as dataset-level coordinates, in the order they have in
    ``ds.coords``. Item and coordinate names are converted to strings.
    Dataset-level attributes are dropped with a warning.
    """
    if ds.attrs:
        warn(
            "Input data contains some attributes which have been dropped during the "
            "conversion.",
            stacklevel=3,
        )
    sc_data = {f"{k}": _from_xarray_dataarray(v) for k, v in ds.items()}
    # The coordinates of every data variable also show up as coordinates of
    # the xarray dataset. The data arrays above already carry them, so they
    # are excluded from the dataset-level coordinates added below.
    coords_in_data_arrays = {name for item in ds.values() for name in item.coords}
    # Iterate ds.coords rather than taking a set difference, so the order of
    # the resulting coordinates is deterministic and matches the input.
    return Dataset(
        data=sc_data,
        coords={
            f"{key}": _from_xarray_variable(ds.coords[key])
            for key in ds.coords
            if key not in coords_in_data_arrays
        },
    )
def _to_xarray_dataset(ds: Dataset) -> xr.Dataset:
    """Convert a scipp.Dataset to an xarray.Dataset.

    Only the data of each item is converted. Masks of items are dropped with a
    warning, and variances are dropped (with a warning) by
    ``_to_xarray_variable``. Dataset coordinates that are unaligned in scipp
    have their xarray index removed, if xarray created one for them.

    Raises
    ------
    ValueError
        If a coordinate contains bin edges, or if any data or coordinate is
        binned.
    """
    import xarray as xr

    # Mirror _to_xarray_dataarray: masks are not representable in xarray.
    if any(item.masks for item in ds.values()):
        warn(
            "Some masks were found in the Dataset. "
            "These have been removed when converting to Xarray.",
            stacklevel=3,
        )
    # Reject bin-edge coordinates up front with the same clear message as the
    # DataArray path, instead of a size-mismatch error from xarray.
    for key, coord in ds.coords.items():
        if any(ds.coords.is_edges(key, dim=dim) for dim in coord.dims):
            raise ValueError("Xarray does not support coordinates with bin edges.")
    out = xr.Dataset(
        data_vars={k: _to_xarray_variable(v.data) for k, v in ds.items()},
        coords={c: _to_xarray_variable(coord) for c, coord in ds.coords.items()},
    )
    # Only coordinates that xarray actually indexed can have their index dropped.
    unaligned_indexed = [
        key
        for key, coord in ds.coords.items()
        if not coord.aligned and key in out.xindexes
    ]
    return out.drop_indexes(unaligned_indexed)