Coverage for install/scipp/compat/xarray_compat.py: 74%

68 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3 

4from __future__ import annotations 

5 

6from typing import TYPE_CHECKING, Union 

7from warnings import warn 

8 

9from ..core import DataArray, Dataset, Unit, Variable 

10from ..typing import VariableLike 

11from ..units import default_unit 

12 

13if TYPE_CHECKING: 

14 import xarray as xr 

15 

16 

def from_xarray(obj: Union[xr.Variable, xr.DataArray, xr.Dataset]) -> VariableLike:
    """Convert an xarray object to the corresponding scipp object.

    Attributes named ``"units"`` are used to set the units of the Variables.
    All other attributes are dropped:

    - A warning is emitted if a DataArray has attributes other than ``"units"``.
    - A warning is emitted if a Dataset has any (dataset-level) attributes.
    - Extra attributes of Variables and coordinates are dropped silently.

    The name of a DataArray is kept, and coordinates that have no index in
    xarray become unaligned coordinates in scipp.

    Parameters
    ----------
    obj:
        The xarray object to convert.

    Returns
    -------
    :
        The converted scipp object.

    Raises
    ------
    ValueError
        If ``obj`` is not an ``xr.Variable``, ``xr.DataArray`` or ``xr.Dataset``.

    See Also
    --------
    scipp.compat.to_xarray
    """
    # Imported lazily so that xarray remains an optional dependency of scipp.
    import xarray as xr

    if isinstance(obj, xr.Variable):
        return _from_xarray_variable(obj)
    elif isinstance(obj, xr.DataArray):
        return _from_xarray_dataarray(obj)
    elif isinstance(obj, xr.Dataset):
        return _from_xarray_dataset(obj)
    else:
        raise ValueError(f"from_xarray: cannot convert type '{type(obj)}'")

47 

48 

def to_xarray(obj: VariableLike) -> Union[xr.Variable, xr.DataArray, xr.Dataset]:
    """Convert a scipp object to the corresponding xarray object.

    Warning
    -------
    Any masks and variances in the input will be stripped during the conversion.
    Binned data and coordinates with bin edges are not supported.

    Parameters
    ----------
    obj:
        The scipp object to convert.

    Returns
    -------
    :
        The converted xarray object.

    Raises
    ------
    ValueError
        If ``obj`` is not a scipp Variable, DataArray or Dataset, or if it
        cannot be represented in xarray.

    See Also
    --------
    scipp.compat.from_xarray
    """
    # Checked in this order; the first matching type wins.
    dispatch = (
        (Variable, _to_xarray_variable),
        (DataArray, _to_xarray_dataarray),
        (Dataset, _to_xarray_dataset),
    )
    for scipp_type, convert in dispatch:
        if isinstance(obj, scipp_type):
            return convert(obj)
    raise ValueError(f"to_xarray: cannot convert type '{type(obj)}'")

80 

81 

def _from_xarray_variable(xr_obj: Union[xr.Variable, xr.DataArray]) -> Variable:
    """Converts an xarray Variable, or the data of an xarray DataArray, to a
    scipp.Variable.

    Coordinates are passed in as DataArrays (as obtained from ``.coords``).
    If the ``"units"`` attribute is present it is parsed into the unit;
    otherwise ``default_unit`` is passed to the Variable constructor.
    All other attributes are ignored.
    """
    unit = xr_obj.attrs.get('units', None)
    return Variable(
        dims=xr_obj.dims,
        values=xr_obj.values,
        unit=Unit(unit) if unit is not None else default_unit,
    )

90 

91 

def _to_xarray_variable(var: Variable) -> xr.Variable:
    """Converts a scipp.Variable to an xarray.Variable.

    The unit is stored as the ``"units"`` attribute (no attributes are set if
    the unit is None). Variances are dropped with a warning.

    Raises
    ------
    ValueError
        If the variable holds binned data.
    """
    import xarray as xr

    if var.bins is not None:
        raise ValueError("Xarray does not support binned data.")

    has_variances = var.variances is not None
    if has_variances:
        warn(
            "Variances of variable were stripped when converting to Xarray.",
            stacklevel=3,
        )

    unit = var.unit
    if unit is None:
        attrs = None
    else:
        attrs = {'units': str(unit)}
    return xr.Variable(dims=var.dims, data=var.values, attrs=attrs)

107 

108 

def _from_xarray_dataarray(da: xr.DataArray) -> DataArray:
    """Converts an xarray.DataArray object to a scipp.DataArray object.

    The DataArray name is kept (None becomes ""). Attributes other than
    ``"units"`` are dropped with a warning. Coordinates without an xarray
    index are marked as unaligned in the result.
    """
    dropped_attrs = set(da.attrs) - {"units"}
    if dropped_attrs:
        warn(
            "Input data contains some attributes which have been dropped during the "
            "conversion.",
            stacklevel=3,
        )

    # ``format(name)`` is exactly what ``f"{name}"`` evaluates to.
    converted_coords = {
        format(name): _from_xarray_variable(coord)
        for name, coord in da.coords.items()
    }
    result = DataArray(
        data=_from_xarray_variable(da),
        coords=converted_coords,
        name=getattr(da, "name", None) or "",
    )

    unindexed = [name for name in da.coords if name not in da.indexes]
    for name in unindexed:
        result.coords.set_aligned(format(name), False)
    return result

129 

130 

def _to_xarray_dataarray(da: DataArray) -> xr.DataArray:
    """Converts a scipp.DataArray object to an xarray.DataArray object.

    Masks are dropped with a warning. Coordinates that are unaligned in scipp
    are added without an xarray index.

    Raises
    ------
    ValueError
        If any coordinate has bin edges, or if the data or a coordinate is binned.
    """
    import xarray as xr

    if da.masks:
        warn(
            "Some masks were found in the DataArray. "
            "These have been removed when converting to Xarray.",
            stacklevel=3,
        )
    out = xr.DataArray(_to_xarray_variable(da.data))
    for key, coord in da.coords.items():
        for dim in coord.dims:
            # Query the coords mapping we iterate over rather than ``da.meta``,
            # which is deprecated in newer scipp releases.
            if da.coords.is_edges(key, dim=dim):
                raise ValueError("Xarray does not support coordinates with bin edges.")
        out.coords[key] = _to_xarray_variable(coord)
    # Only coordinates that xarray actually indexed can have their index dropped;
    # ``drop_indexes`` raises for unindexed ones. An example is the scalar coord
    # left behind by slicing in scipp, which is unaligned but has no xarray index.
    unaligned_indexed = [
        key
        for key, coord in da.coords.items()
        if not coord.aligned and key in out.indexes
    ]
    return out.drop_indexes(unaligned_indexed)

149 

150 

def _from_xarray_dataset(ds: xr.Dataset) -> Dataset:
    """Converts an xarray.Dataset object to a scipp.Dataset object.

    Dataset-level attributes are dropped with a warning. Each data variable is
    converted like a DataArray. Coordinates not attached to any data variable
    are added to the scipp Dataset explicitly.
    """
    if ds.attrs:
        warn(
            "Input data contains some attributes which have been dropped during the "
            "conversion.",
            stacklevel=3,
        )
    sc_data = {k: _from_xarray_dataarray(v) for k, v in ds.items()}
    # In xarray, the coordinates of items also show up as global coordinates of
    # the Dataset. Those are already part of the converted data arrays, so only
    # the remaining coordinates are added.
    coords_in_data_arrays = {name for item in ds.values() for name in item.coords}
    # Iterating ``ds.coords`` (instead of a set difference) keeps the coordinate
    # order deterministic and consistent with the xarray input.
    return Dataset(
        data=sc_data,
        coords={
            key: _from_xarray_variable(coord)
            for key, coord in ds.coords.items()
            if key not in coords_in_data_arrays
        },
    )

173 

174 

def _to_xarray_dataset(ds: Dataset) -> xr.Dataset:
    """Converts a scipp.Dataset object to an xarray.Dataset object.

    This mirrors the DataArray conversion:

    - Masks of the items are dropped with a warning.
    - Coordinates with bin edges are rejected.
    - Coordinates that are unaligned in scipp are added without an xarray index.

    Raises
    ------
    ValueError
        If any coordinate has bin edges, or if any data or coordinate is binned.
    """
    import xarray as xr

    if any(item.masks for item in ds.values()):
        warn(
            "Some masks were found in the Dataset. "
            "These have been removed when converting to Xarray.",
            stacklevel=3,
        )
    for key, coord in ds.coords.items():
        for dim in coord.dims:
            if ds.coords.is_edges(key, dim=dim):
                raise ValueError("Xarray does not support coordinates with bin edges.")
    out = xr.Dataset(
        data_vars={k: _to_xarray_variable(v.data) for k, v in ds.items()},
        coords={c: _to_xarray_variable(coord) for c, coord in ds.coords.items()},
    )
    # ``drop_indexes`` raises for coordinates without an index, so restrict the
    # drop to unaligned coordinates that xarray actually indexed.
    unaligned_indexed = [
        key
        for key, coord in ds.coords.items()
        if not coord.aligned and key in out.indexes
    ]
    return out.drop_indexes(unaligned_indexed)