Coverage for install/scipp/operations.py: 72%
25 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4import functools
5from collections.abc import Callable
6from inspect import signature
8from ._scipp.core import transform as cpp_transform
9from .core import Variable
def _as_numba_cfunc(function, unit_func=None):
    """Compile ``function`` into a numba C callback that works on C doubles.

    The compiled signature is ``double(double, ..., double)``. It has one
    ``double`` argument per parameter reported by ``inspect.signature``.

    Two extra attributes are attached to the returned object:

    - ``unit_func``: ``unit_func`` if given, otherwise ``function`` itself.
    - ``name``: ``function.__name__``.

    These attributes are presumably read by the C++ transform that consumes
    the callback; that consumer is not visible here.
    """
    import numba

    scalar = 'double'
    arity = len(signature(function).parameters)
    argument_list = ','.join(scalar for _ in range(arity))
    compiled = numba.cfunc(f'{scalar}({argument_list})')(function)
    # Fall back to the element kernel itself for unit propagation.
    compiled.unit_func = unit_func if unit_func is not None else function
    compiled.name = function.__name__
    return compiled
23def elemwise_func(
24 func: Callable | None = None,
25 *,
26 unit_func: Callable | None = None,
27 dtype: str = 'float64',
28 auto_convert_dtypes: bool = False,
29) -> Callable:
30 """
31 Create a function for transforming input variables based on element-wise operation.
33 This uses ``numba.cfunc`` to compile a kernel that Scipp can use for transforming
34 the variable contents. Only variables with dtype=float64 are supported. Variances
35 are not supported.
37 Custom kernels can reduce intermediate memory consumption and improve performance
38 in multi-step operations with large input variables.
40 Parameters
41 ----------
42 func:
43 Function to compute an output element from input element values.
44 unit_func:
45 Function to compute the output unit. If ``None``, ``func`` will be used.
46 auto_convert_dtypes:
47 Set to ``True`` to automatically convert all inputs to float64.
49 Returns
50 -------
51 :
52 A callable that applies ``func`` to the elements of the variables passed to it.
54 Examples
55 --------
57 We can define a fused multiply-add operation as follows:
59 >>> def fmadd(a, b, c):
60 ... return a * b + c
62 >>> func = sc.elemwise_func(fmadd)
64 >>> x = sc.linspace('x', 0.0, 1.0, num=4, unit='m')
65 >>> y = x - 0.2 * x
66 >>> z = sc.scalar(1.2, unit='m**2')
68 >>> func(x, y, z)
69 <scipp.Variable> (x: 4) float64 [m^2] [1.2, 1.28889, 1.55556, 2]
71 Note that ``fmadd(x, y, z)`` would have the same effect in this case, but requires
72 a potentially large intermediate allocation for the result of "a * b".
73 """
75 def decorator(f):
76 cfunc = _as_numba_cfunc(f, unit_func=unit_func)
78 @functools.wraps(f)
79 def transform_custom(*args: Variable) -> Variable:
80 if auto_convert_dtypes:
81 args = [arg.to(dtype='float64', copy=False) for arg in args]
82 return cpp_transform(cfunc, *args)
84 return transform_custom
86 if func is None:
87 return decorator
88 return decorator(func)