Coverage for install/scipp/operations.py: 72%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4import functools 

5from inspect import signature 

6from typing import Callable, Optional 

7 

8from ._scipp.core import transform as cpp_transform 

9from .core import Variable 

10 

11 

def _as_numba_cfunc(function, unit_func=None):
    """Compile ``function`` into a numba C callback operating on doubles.

    Every positional parameter of ``function`` and its return value are typed
    as C ``double``. Two attributes are attached to the compiled object:

    - ``unit_func``: the given ``unit_func``, or ``function`` itself when
      ``unit_func`` is ``None``.
    - ``name``: ``function.__name__``.
    """
    # Deferred import: numba is only needed once a kernel is actually compiled.
    import numba

    n_params = len(signature(function).parameters)
    arg_types = ','.join(['double'] * n_params)
    compiled = numba.cfunc(f'double({arg_types})')(function)
    compiled.unit_func = unit_func if unit_func is not None else function
    compiled.name = function.__name__
    return compiled

21 

22 

def elemwise_func(
    func: Optional[Callable] = None,
    *,
    unit_func: Optional[Callable] = None,
    dtype: str = 'float64',
    auto_convert_dtypes: bool = False,
) -> Callable:
    """
    Create a function for transforming input variables based on element-wise operation.

    This uses ``numba.cfunc`` to compile a kernel that Scipp can use for transforming
    the variable contents. Only variables with dtype=float64 are supported. Variances
    are not supported.

    Custom kernels can reduce intermediate memory consumption and improve performance
    in multi-step operations with large input variables.

    This can be called directly as ``elemwise_func(f, ...)``. It can also be used as a
    decorator, either bare (``@elemwise_func``) or with keyword arguments
    (``@elemwise_func(unit_func=...)``).

    Parameters
    ----------
    func:
        Function to compute an output element from input element values.
        If ``None``, a decorator is returned that accepts the function instead.
    unit_func:
        Function to compute the output unit. If ``None``, ``func`` will be used.
    dtype:
        Element dtype of the kernel. Only ``'float64'`` is currently supported;
        any other value raises :class:`ValueError`.
    auto_convert_dtypes:
        Set to ``True`` to automatically convert all inputs to float64.

    Returns
    -------
    :
        A callable that applies ``func`` to the elements of the variables passed to it.
        If ``func`` is ``None``, a decorator producing such a callable is returned.

    Raises
    ------
    ValueError
        If ``dtype`` is not ``'float64'``.

    Examples
    --------

    We can define a fused multiply-add operation as follows:

    >>> def fmadd(a, b, c):
    ...     return a * b + c

    >>> func = sc.elemwise_func(fmadd)

    >>> x = sc.linspace('x', 0.0, 1.0, num=4, unit='m')
    >>> y = x - 0.2 * x
    >>> z = sc.scalar(1.2, unit='m**2')

    >>> func(x, y, z)
    <scipp.Variable> (x: 4) float64 [m^2] [1.2, 1.28889, 1.55556, 2]

    Note that ``fmadd(x, y, z)`` would have the same effect in this case, but requires
    a potentially large intermediate allocation for the result of "a * b".
    """
    # The kernel compiled by ``_as_numba_cfunc`` is hard-wired to C ``double``.
    # Reject other dtypes instead of silently ignoring the argument. The check
    # runs eagerly so that both direct and decorator usage fail immediately.
    if str(dtype) != 'float64':
        raise ValueError(
            f"elemwise_func only supports dtype='float64', got dtype={dtype!r}."
        )

    def decorator(f):
        cfunc = _as_numba_cfunc(f, unit_func=unit_func)

        @functools.wraps(f)
        def transform_custom(*args: Variable) -> Variable:
            if auto_convert_dtypes:
                # copy=False presumably lets scipp reuse inputs that are already
                # float64 rather than copying them -- confirm against Variable.to.
                args = [arg.to(dtype='float64', copy=False) for arg in args]
            return cpp_transform(cfunc, *args)

        return transform_custom

    # Decorator-with-arguments form: ``@elemwise_func(unit_func=...)``.
    if func is None:
        return decorator
    return decorator(func)