# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
"""This module provides tools for running workflows in a streaming fashion."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import Any, Generic, TypeVar
import networkx as nx
import sciline
import scipp as sc
T = TypeVar('T')
def maybe_hist(value: T) -> T:
"""
Convert value to a histogram if it is not already a histogram.
This is the default pre-processing used by accumulators.
Parameters
----------
value:
Value to be converted to a histogram.
Returns
-------
:
The input value, histogrammed if it was binned event data; otherwise unchanged.
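Examples
--------
A minimal sketch; ``scipp.data.table_xyz`` is only used here to create example
event data:

>>> import scipp as sc
>>> binned = sc.data.table_xyz(100).bin(x=10)
>>> maybe_hist(binned).bins is None
True
>>> dense = sc.scalar(1.2)
>>> maybe_hist(dense) is dense
True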
"""
if not isinstance(value, sc.Variable | sc.DataArray):
return value
return value if value.bins is None else value.hist()
class Accumulator(ABC, Generic[T]):
"""
Abstract base class for accumulators.
Accumulators are used to accumulate values over multiple chunks.
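Examples
--------
A minimal sketch of a custom accumulator that counts pushed chunks; the class
below is illustrative and not part of this module:

>>> class CountingAccumulator(Accumulator[int]):
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         self._count = 0
...     def _do_push(self, value):
...         self._count += 1
...     def _get_value(self):
...         return self._count
...     def clear(self):
...         self._count = 0
>>> acc = CountingAccumulator()
>>> acc.push(1)
>>> acc.value
1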
"""
def __init__(self, preprocess: Callable[[T], T] | None = maybe_hist) -> None:
"""
Parameters
----------
preprocess:
Preprocessing function to be applied to pushed values prior to accumulation.
"""
self._preprocess = preprocess
def push(self, value: T) -> None:
"""
Push a value to the accumulator.
Parameters
----------
value:
Value to be pushed to the accumulator.
"""
if self._preprocess is not None:
value = self._preprocess(value)
self._do_push(value)
@abstractmethod
def _do_push(self, value: T) -> None: ...
@property
def is_empty(self) -> bool:
"""
Check if the accumulator is empty.
Returns
-------
:
True if the accumulator is empty, False otherwise.
"""
return False
@property
def value(self) -> T:
"""
Get the accumulated value.
Returns
-------
:
Accumulated value.
Raises
------
ValueError
If the accumulator is empty.
"""
if self.is_empty:
raise ValueError("Cannot get value from empty accumulator")
return self._get_value()
@abstractmethod
def _get_value(self) -> T:
"""Return the accumulated value, assuming it exists."""
@abstractmethod
def clear(self) -> None:
"""
Clear the accumulator, resetting it to its initial state.
"""
class EternalAccumulator(Accumulator[T]):
"""
Simple accumulator that immediately adds pushed values to a running total.
Does not support event data.
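Examples
--------
A minimal sketch with scalar values:

>>> import scipp as sc
>>> acc = EternalAccumulator()
>>> acc.push(sc.scalar(1.0))
>>> acc.push(sc.scalar(2.0))
>>> float(acc.value.value)
3.0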
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._value: T | None = None
@property
def is_empty(self) -> bool:
return self._value is None
def _get_value(self) -> T:
return deepcopy(self._value)
def _do_push(self, value: T) -> None:
if self._value is None:
self._value = deepcopy(value)
else:
self._value += value
def clear(self) -> None:
"""Clear the accumulated value."""
self._value = None
class RollingAccumulator(Accumulator[T]):
"""
Accumulator that keeps pushed values in a rolling window; the accumulated
value is the sum over the window.
Does not support event data.
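Examples
--------
A minimal sketch with a window of two chunks:

>>> import scipp as sc
>>> acc = RollingAccumulator(window=2)
>>> acc.push(sc.scalar(1.0))
>>> acc.push(sc.scalar(2.0))
>>> acc.push(sc.scalar(3.0))
>>> float(acc.value.value)
5.0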
"""
def __init__(self, window: int = 10, **kwargs: Any) -> None:
"""
Parameters
----------
window:
Size of the rolling window.
"""
super().__init__(**kwargs)
self._window = window
self._values: list[T] = []
@property
def is_empty(self) -> bool:
return len(self._values) == 0
def _get_value(self) -> T:
# Naive and potentially slow implementation if values and/or window are large!
return sc.reduce(self._values).sum()
def _do_push(self, value: T) -> None:
self._values.append(value)
if len(self._values) > self._window:
self._values.pop(0)
def clear(self) -> None:
"""Clear the accumulated values."""
self._values = []
class MinAccumulator(Accumulator):
"""Keeps the minimum value seen so far.
Only supports scalar values.
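Examples
--------
A minimal sketch with scalar values:

>>> import scipp as sc
>>> acc = MinAccumulator()
>>> acc.push(sc.scalar(3.0))
>>> acc.push(sc.scalar(1.0))
>>> float(acc.value.value)
1.0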
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._cur_min: sc.Variable | None = None
def _do_push(self, value: sc.Variable) -> None:
if self._cur_min is None:
self._cur_min = value
else:
self._cur_min = min(self._cur_min, value)
@property
def is_empty(self) -> bool:
"""Check if the accumulator has collected a minimum value."""
return self._cur_min is None
def _get_value(self) -> sc.Variable | None:
return self._cur_min
def clear(self) -> None:
"""Clear the accumulated minimum value."""
self._cur_min = None
class MaxAccumulator(Accumulator):
"""Keeps the maximum value seen so far.
Only supports scalar values.
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._cur_max: sc.Variable | None = None
@property
def is_empty(self) -> bool:
"""Check if the accumulator has collected a maximum value."""
return self._cur_max is None
def _do_push(self, value: sc.Variable) -> None:
if self._cur_max is None:
self._cur_max = value
else:
self._cur_max = max(self._cur_max, value)
def _get_value(self) -> sc.Variable | None:
return self._cur_max
def clear(self) -> None:
"""Clear the accumulated maximum value."""
self._cur_max = None
class StreamProcessor:
"""
Wrap a base workflow for streaming processing of chunks.

Note that this class cannot determine from the input keys whether the workflow
is valid for streamed processing. In particular, it is the responsibility of the
user to ensure that the workflow is "linear" with respect to the dynamic keys up
to the accumulation keys.

Similarly, the stream processor cannot determine from the workflow structure
whether context updates are compatible with the accumulated data. Accumulators
are not cleared automatically. This is best illustrated with an example:

- If the context is the detector rotation angle, and we accumulate I(Q) (or a
  prerequisite of I(Q)), then updating the detector-angle context is compatible
  with previous data, assuming Q for each new chunk is computed based on the
  angle.
- If the context is the sample temperature, and we accumulate I(Q), then
  updating the temperature context is not compatible with previous data.
  Accumulating I(Q, T) could be compatible in this case.

Since correctness cannot be determined from the workflow structure, we recommend
implementing processing steps in a way that catches such problems. For example,
adding the temperature as a coordinate to the I(Q) data array should make it
possible for the accumulator to raise automatically if the temperature changes.
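Examples
--------
A minimal sketch with a toy workflow; the domain types and the provider below
are illustrative and not part of this module:

>>> from typing import NewType
>>> import sciline
>>> import scipp as sc
>>> RawData = NewType('RawData', sc.Variable)
>>> ScaleFactor = NewType('ScaleFactor', sc.Variable)
>>> Scaled = NewType('Scaled', sc.Variable)
>>> def scale(data: RawData, factor: ScaleFactor) -> Scaled:
...     return Scaled(data * factor)
>>> workflow = sciline.Pipeline((scale,), params={ScaleFactor: sc.scalar(2.0)})
>>> processor = StreamProcessor(
...     workflow,
...     dynamic_keys=(RawData,),
...     target_keys=(Scaled,),
...     accumulators=(Scaled,),
... )
>>> processor.accumulate({RawData: sc.scalar(1.0)})
>>> processor.accumulate({RawData: sc.scalar(3.0)})
>>> float(processor.finalize()[Scaled].value)
8.0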
"""
def __init__(
self,
base_workflow: sciline.Pipeline,
*,
dynamic_keys: tuple[sciline.typing.Key, ...],
context_keys: tuple[sciline.typing.Key, ...] = (),
target_keys: tuple[sciline.typing.Key, ...],
accumulators: dict[sciline.typing.Key, Accumulator | Callable[..., Accumulator]]
| tuple[sciline.typing.Key, ...],
allow_bypass: bool = False,
) -> None:
"""
Create a stream processor.
Parameters
----------
base_workflow:
Workflow to be used for processing chunks.
dynamic_keys:
Keys that are expected to be updated with each chunk. These keys cannot
depend on each other or on context_keys.
context_keys:
Keys that define context for processing chunks and may change occasionally.
These keys cannot overlap with dynamic_keys or depend on each other or on
dynamic_keys.
target_keys:
Keys to be computed and returned.
accumulators:
Keys at which to accumulate values and their accumulators. If a tuple is
passed, :py:class:`EternalAccumulator` is used for all keys. Otherwise, a
dict mapping keys to accumulator instances can be passed. If a dict value is
a callable, base_workflow.bind_and_call(value) is used to make an instance.
allow_bypass:
If True, allow bypassing accumulators for keys that are not in the
accumulators dict. This is useful for dynamic keys that are not "terminated"
in any accumulator. USE WITH CARE! This will lead to incorrect results
unless the values for these keys are valid for all chunks contained in the
final accumulators at the point where :py:meth:`finalize` is called.
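Examples
--------
A sketch of passing explicit accumulators; ``IofQ``, ``QBins``, and ``RawData``
are illustrative workflow keys. A callable value is passed through
``base_workflow.bind_and_call``, so its arguments are computed from the
workflow::

    def make_iofq_accumulator(edges: QBins) -> Accumulator:
        # 'edges' could be used to configure the accumulator; ignored here.
        return EternalAccumulator()

    processor = StreamProcessor(
        base_workflow,
        dynamic_keys=(RawData,),
        target_keys=(IofQ,),
        accumulators={IofQ: make_iofq_accumulator},
    )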
"""
self._dynamic_keys = set(dynamic_keys)
self._context_keys = set(context_keys)
# Validate that dynamic and context keys do not overlap
overlap = self._dynamic_keys & self._context_keys
if overlap:
raise ValueError(f"Keys cannot be both dynamic and context: {overlap}")
# Check dynamic/context keys don't depend on other dynamic/context keys
graph = base_workflow.underlying_graph
special_keys = self._dynamic_keys | self._context_keys
for key in special_keys:
if key not in graph:
continue
ancestors = nx.ancestors(graph, key)
special_ancestors = ancestors & special_keys
downstream = 'Dynamic' if key in self._dynamic_keys else 'Context'
if special_ancestors:
raise ValueError(
f"{downstream} key '{key}' depends on other dynamic/context keys: "
f"{special_ancestors}. This is not supported."
)
workflow = sciline.Pipeline()
for key in target_keys:
workflow[key] = base_workflow[key]
for key in dynamic_keys:
workflow[key] = None # hack to prune branches
for key in context_keys:
workflow[key] = None
# Find and pre-compute static nodes as far down the graph as possible
nodes = _find_descendants(workflow, dynamic_keys + context_keys)
last_static = _find_parents(workflow, nodes) - nodes
for key, value in base_workflow.compute(last_static).items():
workflow[key] = value
# Nodes that may need updating on context change but should be cached otherwise.
dynamic_nodes = _find_descendants(workflow, dynamic_keys)
# Nodes as far "down" in the graph as possible, right before the dynamic nodes.
# This also includes target keys that are not dynamic but context-dependent.
context_to_cache = (
(_find_parents(workflow, dynamic_nodes) | set(target_keys)) - dynamic_nodes
) & _find_descendants(workflow, context_keys)
graph = workflow.underlying_graph
self._context_key_to_cached_context_nodes_map = {
context_key: ({context_key} | nx.descendants(graph, context_key))
& context_to_cache
for context_key in self._context_keys
if context_key in graph
}
self._context_workflow = workflow.copy()
self._process_chunk_workflow = workflow.copy()
self._finalize_workflow = workflow.copy()
self._accumulators = (
accumulators
if isinstance(accumulators, dict)
else {key: EternalAccumulator() for key in accumulators}
)
# Map each accumulator to its dependent dynamic keys
self._accumulator_dependencies = {
acc_key: nx.ancestors(graph, acc_key) & self._dynamic_keys
for acc_key in self._accumulators
if acc_key in graph
}
# Depending on the target_keys, some accumulators can be unused and should not
# be computed when adding a chunk.
self._accumulators = {
key: value for key, value in self._accumulators.items() if key in graph
}
# Create accumulators unless instances were passed. This allows for initializing
# accumulators with arguments that depend on the workflow such as bin edges,
# which would otherwise be hard to obtain.
self._accumulators = {
key: value
if isinstance(value, Accumulator)
else base_workflow.bind_and_call(value)
for key, value in self._accumulators.items()
}
self._target_keys = target_keys
self._allow_bypass = allow_bypass
def set_context(self, context: dict[sciline.typing.Key, Any]) -> None:
"""
Set the context for processing chunks.
Parameters
----------
context:
Context to be set.
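Examples
--------
A sketch, assuming ``DetectorRotation`` was passed as one of the
``context_keys``::

    processor.set_context({DetectorRotation: sc.scalar(45.0, unit='deg')})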
"""
needs_recompute = set()
for key in context:
if key not in self._context_keys:
raise ValueError(f"Key '{key}' is not a context key")
needs_recompute |= self._context_key_to_cached_context_nodes_map[key]
for key, value in context.items():
self._context_workflow[key] = value
results = self._context_workflow.compute(needs_recompute)
for key, value in results.items():
if key in self._target_keys:
# Context-dependent key is direct target, independent of dynamic nodes.
self._finalize_workflow[key] = value
else:
self._process_chunk_workflow[key] = value
def add_chunk(
self, chunks: dict[sciline.typing.Key, Any]
) -> dict[sciline.typing.Key, Any]:
"""
Legacy interface for accumulating values from chunks and finalizing the result.
It is recommended to use :py:meth:`accumulate` and :py:meth:`finalize` instead.
Parameters
----------
chunks:
Chunks to be processed.
Returns
-------
:
Finalized result.
"""
self.accumulate(chunks)
return self.finalize()
def accumulate(self, chunks: dict[sciline.typing.Key, Any]) -> None:
"""
Accumulate values from chunks without finalizing the result.
Parameters
----------
chunks:
Chunks to be processed.
Raises
------
ValueError
If non-dynamic keys are provided in chunks.
If accumulator computation requires dynamic keys not provided in chunks.
"""
non_dynamic = set(chunks) - self._dynamic_keys
if non_dynamic:
raise ValueError(
f"Can only update dynamic keys. Got non-dynamic keys: {non_dynamic}"
)
accumulators_to_update = []
for acc_key, deps in self._accumulator_dependencies.items():
if deps.isdisjoint(chunks.keys()):
continue
if not deps.issubset(chunks.keys()):
raise ValueError(
f"Accumulator '{acc_key}' requires dynamic keys "
f"{deps - chunks.keys()} not provided in the current chunk."
)
accumulators_to_update.append(acc_key)
for key, value in chunks.items():
self._process_chunk_workflow[key] = value
# There can be dynamic keys that do not "terminate" in any accumulator. In
# that case, we need to make sure they are available when computing the
# target keys.
if self._allow_bypass:
self._finalize_workflow[key] = value
to_accumulate = self._process_chunk_workflow.compute(accumulators_to_update)
for key, processed in to_accumulate.items():
self._accumulators[key].push(processed)
def finalize(self) -> dict[sciline.typing.Key, Any]:
"""
Get the final result by computing the target keys based on accumulated values.
Returns
-------
:
Finalized result.
"""
for key in self._accumulators:
self._finalize_workflow[key] = self._accumulators[key].value
return self._finalize_workflow.compute(self._target_keys)
def clear(self) -> None:
"""
Clear all accumulators, resetting them to their initial state.
This is useful for restarting a streaming computation without
creating a new StreamProcessor instance.
"""
for accumulator in self._accumulators.values():
accumulator.clear()
def _find_descendants(
workflow: sciline.Pipeline, keys: tuple[sciline.typing.Key, ...]
) -> set[sciline.typing.Key]:
graph = workflow.underlying_graph
descendants = set()
for key in keys:
descendants |= nx.descendants(graph, key)
return descendants | set(keys)
def _find_parents(
workflow: sciline.Pipeline, keys: tuple[sciline.typing.Key, ...]
) -> set[sciline.typing.Key]:
graph = workflow.underlying_graph
parents = set()
for key in keys:
parents |= set(graph.predecessors(key))
return parents