diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md index 8f035042..340e9276 100644 --- a/DESIGN_ISSUES.md +++ b/DESIGN_ISSUES.md @@ -416,10 +416,10 @@ message text. A shared parameterized validation helper would eliminate the dupli ## `src/orcapod/core/operators/` — Async execution ### O1 — Operators use barrier-mode `async_execute` only; streaming/incremental overrides needed -**Status:** open +**Status:** in progress **Severity:** medium -All operators currently use the default barrier-mode `async_execute` inherited from +All operators originally used the default barrier-mode `async_execute` inherited from `StaticOutputPod`: collect all input rows into memory, materialize to `ArrowTableStream`(s), run the existing sync `static_process`, then emit results. This works correctly but negates the latency and memory benefits of the push-based channel model. @@ -428,20 +428,27 @@ Three categories of improvement are planned: 1. **Streaming overrides (row-by-row, zero buffering)** — for operators that process rows independently: - - `PolarsFilter` — evaluate predicate per row, emit or drop immediately - - `MapTags` / `MapPackets` — rename columns per row, emit immediately - - `SelectTagColumns` / `SelectPacketColumns` — project columns per row, emit immediately - - `DropTagColumns` / `DropPacketColumns` — drop columns per row, emit immediately + - ~~`PolarsFilter` — evaluate predicate per row, emit or drop immediately~~ (kept barrier: + Polars expressions require DataFrame context for evaluation) + - `MapTags` / `MapPackets` — rename columns per row, emit immediately ✅ + - `SelectTagColumns` / `SelectPacketColumns` — project columns per row, emit immediately ✅ + - `DropTagColumns` / `DropPacketColumns` — drop columns per row, emit immediately ✅ 2. 
**Incremental overrides (stateful, eager emit)** — for multi-input operators that can produce partial results before all inputs are consumed: - - `Join` — symmetric hash join: index each input by tag keys, emit matches as they arrive - - `MergeJoin` — same approach, with list-merge on colliding packet columns - - `SemiJoin` — buffer the right (filter) input fully, then stream the left input and emit - matches (right must be fully consumed first, but left can stream) + - `Join` — symmetric hash join for 2 inputs (streaming, with correct + system-tag name-extending via `input_pipeline_hashes` passed directly + to `async_execute`); barrier fallback for N>2 inputs via `static_process`. ✅ + - `MergeJoin` — kept barrier: complex column-merging logic + - `SemiJoin` — build right, stream left through hash lookup ✅ + +3. **Streaming accumulation:** + - `Batch` — emit full batches as they accumulate (`batch_size > 0`); barrier fallback + when `batch_size == 0` (batch everything) ✅ -3. **Barrier-only (no change needed):** - - `Batch` — inherently requires all rows before grouping; barrier mode is correct +**Remaining:** `PolarsFilter` (barrier), `MergeJoin` (barrier) could receive incremental +overrides in the future but require careful handling of Polars expression evaluation and +system-tag evolution respectively. --- @@ -527,6 +534,29 @@ await AddResult(grade_pf).async_execute([input_ch], output_ch) --- +## `src/orcapod/hashing/semantic_hashing/` + +### H1 — Semantic hasher does not support PEP 604 union types (`int | None`) +**Status:** open +**Severity:** medium + +The `BaseSemanticHasher` raises `BeartypeDoorNonpepException` when hashing a +`PythonPacketFunction` whose return type uses PEP 604 syntax (`int | None`). +The hasher's `_handle_unknown` path receives `types.UnionType` (the Python 3.10+ type for +`X | Y` expressions) and has no registered handler for it. + +`typing.Optional[int]` also fails (different error path through beartype). 
+ +This means packet functions cannot use union return types — a common pattern for functions +that may filter packets by returning `None`. + +**Workaround:** Use non-union return types and raise/return sentinel values instead. + +**Fix needed:** Register a `TypeHandlerProtocol` for `types.UnionType` (and +`typing.Union`/`typing.Optional`) in the semantic hasher's type handler registry. + +--- + ### G2 — Pod Group abstraction for other composite pod patterns **Status:** open **Severity:** low diff --git a/orcapod-design.md b/orcapod-design.md index 169ab142..7684a2cc 100644 --- a/orcapod-design.md +++ b/orcapod-design.md @@ -468,15 +468,57 @@ async def async_execute( Nodes consume `(Tag, Packet)` pairs from input channels and produce them to an output channel. This enables push-based, streaming execution where data flows through the pipeline as soon as it's available, with backpressure propagated via bounded channel buffers. -**Operator async strategies:** +**FunctionPod async strategy:** Streaming mode — each input `(tag, packet)` is processed independently with semaphore-controlled concurrency. Uses `asyncio.TaskGroup` for structured concurrency. + +#### Operator Async Strategies + +Each operator overrides `async_execute` with the most efficient streaming pattern its semantics permit. The default fallback (inherited from `StaticOutputPod`) is barrier mode: collect all inputs via `asyncio.gather`, materialize to `ArrowTableStream`, call `static_process`, and emit results. Operators override this default when a more incremental strategy is possible. 
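The barrier-mode default described above can be sketched as follows. This is an illustrative reduction, not the real implementation: `Channel` here is a minimal `asyncio.Queue`-based stand-in for OrcaPod's bounded channels, and `static_process` is passed in as a plain callable over fully collected rows.

```python
import asyncio


class Channel:
    """Illustrative stand-in for OrcaPod's bounded channel (not the real API)."""

    def __init__(self) -> None:
        self._q: asyncio.Queue = asyncio.Queue()

    async def send(self, item) -> None:
        await self._q.put(item)

    async def close(self) -> None:
        await self._q.put(None)  # sentinel marks end-of-stream

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._q.get()
        if item is None:
            raise StopAsyncIteration
        return item


async def barrier_async_execute(inputs, output, static_process) -> None:
    """Barrier mode: drain every input fully, run the sync path, emit results."""

    async def drain(ch):
        return [item async for item in ch]

    try:
        # Block until every input channel is exhausted (the "barrier").
        collected = await asyncio.gather(*(drain(ch) for ch in inputs))
        # Hand the fully materialized inputs to the existing sync logic.
        for row in static_process(*collected):
            await output.send(row)
    finally:
        await output.close()
```

Each strategy in the table below is an override of this shape that starts emitting before the `asyncio.gather` barrier would have completed.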
| Strategy | Description | Operators | |---|---|---| -| **Barrier mode** (default) | Collect all inputs, run `static_process`, emit results | Batch (inherently barrier) | -| **Streaming overrides** | Process rows individually, zero buffering | PolarsFilter, MapTags, MapPackets, Select/Drop columns | -| **Incremental overrides** | Stateful, emit partial results as inputs arrive | Join (symmetric hash join), MergeJoin, SemiJoin (buffer right, stream left) | +| **Per-row streaming** | Transform each `(Tag, Packet)` independently as it arrives; zero buffering beyond the current row | SelectTagColumns, SelectPacketColumns, DropTagColumns, DropPacketColumns, MapTags, MapPackets | +| **Accumulate-and-emit** | Buffer rows up to `batch_size`, emit full batches immediately, flush partial at end | Batch (`batch_size > 0`) | +| **Build-probe** | Collect one side fully (build), then stream the other through a hash lookup (probe) | SemiJoin | +| **Symmetric hash join** | Read both sides concurrently, buffer + index both, emit matches as they're found | Join (2 inputs) | +| **Barrier mode** | Collect all inputs, run `static_process`, emit results | PolarsFilter, MergeJoin, Batch (`batch_size = 0`), Join (N > 2 inputs) | -**FunctionPod async strategy:** Streaming mode — each input `(tag, packet)` is processed independently with semaphore-controlled concurrency. Uses `asyncio.TaskGroup` for structured concurrency. +#### Per-Row Streaming (Unary Column/Map Operators) + +For operators that transform each row independently (column selection, column dropping, column renaming), the async path iterates `async for tag, packet in inputs[0]` and applies the transformation per row. Column metadata (which columns to drop, the rename map, etc.) is computed lazily on the first row and cached for subsequent rows. This avoids materializing the entire input into an Arrow table, enabling true pipeline-level streaming where upstream producers and downstream consumers run concurrently. 
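A minimal sketch of the per-row pattern, using column dropping as the example. Packets are modeled as plain dicts and the input as any async iterable; the real operators work on Tag/Packet protocol objects and Arrow columns, but the lazy first-row metadata computation is the same idea.

```python
from collections.abc import AsyncIterator


async def streaming_drop_columns(rows: AsyncIterator, drop: set) -> AsyncIterator:
    """Per-row streaming sketch: each (tag, packet) pair is transformed and
    re-emitted as soon as it arrives; packets are modeled as plain dicts."""
    keep = None  # column metadata, computed lazily from the first row
    async for tag, packet in rows:
        if keep is None:
            # First row decides which columns survive; cached for the rest.
            keep = [name for name in packet if name not in drop]
        yield tag, {name: packet[name] for name in keep}
```

Because each row is re-emitted immediately, a downstream consumer receives output after the first upstream row rather than after full materialization.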
+ +#### Accumulate-and-Emit (Batch) + +When `batch_size > 0`, Batch accumulates rows into a buffer and emits a batched result stream each time the buffer reaches `batch_size`. Any partial batch at the end is emitted unless `drop_partial_batch` is set. When `batch_size = 0` (meaning "batch everything into one group"), the operator must see all input before producing output, so it falls back to barrier mode. + +#### Build-Probe (SemiJoin) + +SemiJoin is non-commutative: the left side is filtered by the right side. The async implementation collects the right (build) side fully, constructs a hash set of its key tuples, then streams the left (probe) side through the lookup — emitting each left row whose keys appear in the right set. This is the same pattern as Kafka's KStream-KTable join: the table side is materialized, the stream side drives output. + +#### Symmetric Hash Join + +The 2-input Join uses a symmetric hash join — the same algorithm used by Apache Kafka for KStream-KStream joins and by Apache Flink for regular streaming joins. Both input channels are drained concurrently into a shared `asyncio.Queue`. For each arriving row: + +1. Buffer the row on its side and index it by the shared key columns. +2. Probe the opposite side's index for matching keys. +3. Emit all matches immediately. + +When the first rows from both sides have arrived, the shared key columns are determined (intersection of tag column names). Any rows that arrived before shared keys were known are re-indexed and cross-matched in a one-time reconciliation step. 
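The three steps above can be sketched as follows. This is an illustrative reduction of the operator: rows are plain dicts, the shared key columns are assumed known up front (so the one-time reconciliation step is omitted), and both sides feed a single `asyncio.Queue`.

```python
import asyncio
from collections import defaultdict


async def symmetric_hash_join(left, right, keys):
    """Symmetric hash join sketch over two async iterables of dict rows.

    Both sides are pumped concurrently into one queue; each arriving row is
    buffered and indexed on its own side, probed against the other side's
    index, and every match is emitted immediately.
    """
    queue: asyncio.Queue = asyncio.Queue()
    DONE = object()  # per-side end-of-stream marker

    async def pump(side: int, rows) -> None:
        async for row in rows:
            await queue.put((side, row))
        await queue.put((side, DONE))

    tasks = [asyncio.create_task(pump(0, left)),
             asyncio.create_task(pump(1, right))]
    index = (defaultdict(list), defaultdict(list))  # key tuple -> buffered rows
    finished = 0
    while finished < 2:
        side, row = await queue.get()
        if row is DONE:
            finished += 1
            continue
        key = tuple(row[c] for c in keys)
        index[side][key].append(row)          # 1. buffer + index this side
        for other in index[1 - side][key]:    # 2. probe the opposite index
            l, r = (row, other) if side == 0 else (other, row)
            yield {**l, **r}                  # 3. emit the match immediately
    for t in tasks:
        await t
```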
+ +**Comparison with industry stream processors:** + +| Aspect | Kafka Streams (KStream-KStream) | Apache Flink (Regular Join) | OrcaPod | +|---|---|---|---| +| Algorithm | Symmetric windowed hash join | Symmetric hash join with state TTL | Symmetric hash join | +| Windowing | Required (sliding window bounds state) | Optional (TTL evicts old state) | Not needed (finite streams) | +| State backend | RocksDB state stores for fault tolerance | RocksDB / heap state with checkpointing | In-memory buffers | +| State cleanup | Window expiry evicts old records | TTL or watermark eviction | Natural termination — inputs are finite | +| N-way joins | Chained pairwise joins | Chained pairwise joins | 2-way: symmetric hash; N > 2: barrier + Arrow join | + +The symmetric hash join is optimal for our use case: it emits results with minimum latency (as soon as a match exists on both sides) and requires no windowing complexity since OrcaPod streams are finite. For N > 2 inputs, the operator falls back to barrier mode with Arrow-level join execution, which is efficient for bounded data and avoids the complexity of chaining pairwise streaming joins. + +**Why not build-probe for Join?** Since Join is commutative and input sizes are unknown upfront, there is no principled way to choose which side to build vs. probe. Symmetric hash join avoids this asymmetry. SemiJoin, being non-commutative, has a natural build (right) and probe (left) side. + +**Why barrier for PolarsFilter and MergeJoin?** PolarsFilter requires a Polars DataFrame context for predicate evaluation, which needs full materialization. MergeJoin's column-merging semantics (colliding columns become sorted `list[T]`) require seeing all rows to produce correctly typed output columns. 
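For contrast with the symmetric case, the build-probe pattern used by SemiJoin reduces to a few lines. As in the other sketches, rows are plain dicts and the key columns are given explicitly; this is not the real operator code.

```python
async def semi_join(left, right, keys):
    """Build-probe sketch: fully consume the right (build) side into a set of
    key tuples, then stream the left (probe) side through the lookup."""
    build = {tuple(row[c] for c in keys) async for row in right}  # build phase
    async for row in left:                                        # probe phase
        if tuple(row[c] for c in keys) in build:
            yield row
```

The asymmetry is visible in the code: the right side must be fully drained before the first left row can be emitted, which is exactly why Join (commutative, with unknown input sizes) uses the symmetric variant instead.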
### Sync / Async Equivalence diff --git a/plan.md b/plan.md new file mode 100644 index 00000000..5ac73a63 --- /dev/null +++ b/plan.md @@ -0,0 +1,830 @@ +# Plan: Unified `process_packet` / `async_process_packet` + Node `async_execute` + +## Goal + +Establish `process_packet` and `async_process_packet` as **the** universal per-packet +interface across FunctionPod, FunctionPodStream, FunctionNode, and PersistentFunctionNode. +All iteration paths — sequential, concurrent, and async — route through these methods. +Add `async_execute` to all four Node classes. Add cache-aware `async_call` to +`CachedPacketFunction`. Remove `_execute_concurrent` module-level helper. + +--- + +## What exists today + +### Class hierarchy + +``` +_FunctionPodBase (TraceableBase) + ├── process_packet(tag, packet) → calls packet_function.call(packet) + ├── FunctionPod + │ ├── process() → FunctionPodStream + │ └── async_execute() → calls packet_function.async_call(packet) DIRECTLY + │ + FunctionPodStream (StreamBase) + │ ├── _iter_packets_sequential() → calls _function_pod.process_packet(tag, packet) ✓ + │ └── _iter_packets_concurrent() → calls _execute_concurrent(packet_function, ...) DIRECTLY + │ + FunctionNode (StreamBase) + │ ├── _iter_packets_sequential() → calls _packet_function.call(packet) DIRECTLY + │ ├── _iter_packets_concurrent() → calls _execute_concurrent(_packet_function, ...) DIRECTLY + │ └── (no async_execute) + │ + PersistentFunctionNode (FunctionNode) + ├── process_packet(tag, packet) → calls _packet_function.call(packet, skip_cache_*=...) + │ then add_pipeline_record(...) 
+ ├── iter_packets() → Phase 1: replay from DB + │ Phase 2: calls self.process_packet(tag, packet) ✓ + └── (no async_execute) + +OperatorNode (StreamBase) + ├── run() → calls _operator.process(*streams) + └── (no async_execute) + +PersistentOperatorNode (OperatorNode) + ├── _compute_and_store() → calls _operator.process() + bulk DB write + ├── _replay_from_cache() → loads from DB + └── (no async_execute) +``` + +### Module-level helpers + +```python +def _executor_supports_concurrent(packet_function) -> bool: + """True if the pf's executor supports concurrent execution.""" + +def _execute_concurrent(packet_function, packets) -> list[PacketProtocol | None]: + """Submit all packets concurrently via asyncio.gather(pf.async_call(...)). + Falls back to sequential pf.call() if already inside a running event loop.""" +``` + +### Problems + +1. **FunctionPod.async_execute** bypasses `process_packet` — calls `packet_function.async_call` + directly (line 317). +2. **FunctionPodStream._iter_packets_concurrent** bypasses `process_packet` — calls + `_execute_concurrent(packet_function, ...)` directly (line 472). +3. **FunctionNode._iter_packets_sequential** bypasses any process_packet — calls + `_packet_function.call(packet)` directly (line 831). +4. **FunctionNode._iter_packets_concurrent** same — calls `_execute_concurrent` directly + (line 852). +5. **CachedPacketFunction.async_call** inherits from `PacketFunctionWrapper` — completely + **bypasses the cache** (no lookup, no recording). +6. **No `async_process_packet`** exists anywhere. +7. **No `async_execute`** on any Node class. +8. **`_execute_concurrent`** is a module-level function that takes a raw `packet_function` + and list of bare `packets` — no way to route through `process_packet`. + +--- + +## Design principles + +### A. `process_packet` / `async_process_packet` is the single per-packet entry point + +Every class in the function pod hierarchy defines these two methods. 
**All** iteration and +execution paths go through them — sequential, concurrent, and async. No direct +`packet_function.call()` or `packet_function.async_call()` calls outside of these methods. + +``` +_FunctionPodBase.process_packet(tag, pkt) → packet_function.call(pkt) +_FunctionPodBase.async_process_packet(tag, pkt) → await packet_function.async_call(pkt) + +FunctionNode.process_packet(tag, pkt) → self._function_pod.process_packet(tag, pkt) +FunctionNode.async_process_packet(tag, pkt) → await self._function_pod.async_process_packet(tag, pkt) + +PersistentFunctionNode.process_packet(tag, pkt) → cache check → self._function_pod.process_packet → pipeline record +PersistentFunctionNode.async_process_packet(tag, pkt) → cache check → await self._function_pod.async_process_packet → pipeline record +``` + +Wait — there's a subtlety with PersistentFunctionNode. Today its `process_packet` calls +`self._packet_function.call(packet, skip_cache_lookup=..., skip_cache_insert=...)` directly, +where `self._packet_function` is a `CachedPacketFunction` (which wraps the original pf). +It does NOT delegate to the pod's `process_packet`. That's because PersistentFunctionNode +needs to pass `skip_cache_*` kwargs that the base `process_packet` doesn't accept. + +The cleanest structure: + +``` +PersistentFunctionNode.process_packet(tag, pkt) + → self._packet_function.call(pkt, skip_cache_*=...) # CachedPacketFunction (sync) + → self.add_pipeline_record(...) # pipeline DB (sync) + +PersistentFunctionNode.async_process_packet(tag, pkt) + → await self._packet_function.async_call(pkt, skip_cache_*=...) # CachedPacketFunction (async) + → self.add_pipeline_record(...) # pipeline DB (sync) +``` + +This is the same as today for the sync path. The `CachedPacketFunction` handles the result +cache internally. The `PersistentFunctionNode` handles pipeline records. 
Neither delegates +to the pod's `process_packet` — the pod is bypassed because the `CachedPacketFunction` +replaced the raw packet function in `__init__`. + +### B. Concurrent iteration routes through `async_process_packet` + +The concurrent path is inherently async — it uses `asyncio.gather`. So it naturally routes +through `async_process_packet`. The fallback path (when already inside an event loop) routes +through `process_packet` (sync). + +For **FunctionPodStream**, the target is the pod: +```python +# concurrent +await self._function_pod.async_process_packet(tag, pkt) +# fallback +self._function_pod.process_packet(tag, pkt) +``` + +For **FunctionNode**, the target is `self` — so overrides (PersistentFunctionNode) kick in: +```python +# concurrent +await self.async_process_packet(tag, pkt) +# fallback +self.process_packet(tag, pkt) +``` + +This means PersistentFunctionNode's concurrent path **automatically** gets cache checks + +pipeline records via polymorphism. No special handling needed. + +### C. `_execute_concurrent` is removed + +The module-level `_execute_concurrent(packet_function, packets)` helper is removed. Its +logic (asyncio.gather with event-loop fallback) is inlined into `_iter_packets_concurrent` +methods, but now routes through `process_packet` / `async_process_packet` instead of raw +`packet_function.call` / `packet_function.async_call`. + +The `_executor_supports_concurrent` helper stays — it's just a predicate check. + +### D. Sync and async are cleanly separated execution modes + +- Sync: `iter_packets()` / `as_table()` / `run()` +- Async: `async_execute(inputs, output)` + +They don't populate each other's caches. DB persistence (for Persistent variants) provides +durability that works across both modes. + +### E. OperatorNode delegates to operator, PersistentOperatorNode intercepts for storage + +Operators are opaque stream transformers — no per-packet hook. `OperatorNode` passes through +directly. 
`PersistentOperatorNode` uses an intermediate channel + `TaskGroup` to forward +results downstream immediately while collecting them for post-hoc DB storage. + +### F. DB operations stay synchronous + +The `ArrowDatabaseProtocol` is sync. All DB reads/writes within async methods are sync calls. +Acceptable because DB is typically in-process and fast. Async DB protocol is deferred. + +--- + +## Implementation steps + +### Step 1: Add `async_process_packet` to `_FunctionPodBase` + +**File:** `src/orcapod/core/function_pod.py` + +Add alongside existing `process_packet` (after line 180): + +```python +async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + return tag, await self.packet_function.async_call(packet) +``` + +### Step 2: Fix `FunctionPod.async_execute` to use `async_process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Change the `process_one` inner function (lines 315-322): + +```python +async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: + try: + tag, result_packet = await self.async_process_packet(tag, packet) + if result_packet is not None: + await output.send((tag, result_packet)) + finally: + if sem is not None: + sem.release() +``` + +### Step 3: Fix `FunctionPodStream._iter_packets_concurrent` to use `async_process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Replace the `_execute_concurrent` call (lines 454-482) with direct `async_process_packet` +routing: + +```python +def _iter_packets_concurrent( + self, +) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Collect remaining inputs, execute concurrently, and yield results in order.""" + input_iter = self._cached_input_iterator + + all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] + to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] + for i, (tag, packet) in enumerate(input_iter): + 
all_inputs.append((i, tag, packet)) + if i not in self._cached_output_packets: + to_compute.append((i, tag, packet)) + self._cached_input_iterator = None + + if to_compute: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + self._function_pod.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + # No event loop — run concurrently via asyncio.run + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self._function_pod.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): + self._cached_output_packets[i] = (tag, output_packet) + + for i, *_ in all_inputs: + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet +``` + +**Note:** The method signature drops the `packet_function` parameter — it no longer needs +it since it routes through `self._function_pod`. 
+ +The `iter_packets` method that calls this also needs updating — remove the `pf` argument: + +```python +def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + if self.is_stale: + self.clear_cache() + if self._cached_input_iterator is not None: + if _executor_supports_concurrent(self._function_pod.packet_function): + yield from self._iter_packets_concurrent() + else: + yield from self._iter_packets_sequential() + else: + for i in range(len(self._cached_output_packets)): + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet +``` + +### Step 4: Fix `FunctionNode._iter_packets_sequential` to use `process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Change line 831 from: +```python +output_packet = self._packet_function.call(packet) +self._cached_output_packets[i] = (tag, output_packet) +``` +to: +```python +tag, output_packet = self.process_packet(tag, packet) +self._cached_output_packets[i] = (tag, output_packet) +``` + +### Step 5: Fix `FunctionNode._iter_packets_concurrent` to use `async_process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Same transformation as Step 3, but routing through `self` instead of `self._function_pod`: + +```python +def _iter_packets_concurrent( + self, +) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Collect remaining inputs, execute concurrently, and yield results in order.""" + input_iter = self._cached_input_iterator + + all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] + to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] + for i, (tag, packet) in enumerate(input_iter): + all_inputs.append((i, tag, packet)) + if i not in self._cached_output_packets: + to_compute.append((i, tag, packet)) + self._cached_input_iterator = None + + if to_compute: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + 
self.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + # No event loop — run concurrently via asyncio.run + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): + self._cached_output_packets[i] = (tag, output_packet) + + for i, *_ in all_inputs: + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet +``` + +**Critical difference from Step 3:** Uses `self.process_packet` / `self.async_process_packet` +instead of `self._function_pod.*`. This means when `PersistentFunctionNode` inherits this +method, it automatically routes through its overridden `process_packet` / +`async_process_packet` which include cache checks + pipeline record storage. + +### Step 6: Remove `_execute_concurrent` + +**File:** `src/orcapod/core/function_pod.py` + +Delete the `_execute_concurrent` function (lines 52-82). Its logic is now inlined into the +`_iter_packets_concurrent` methods. + +### Step 7: Add `process_packet` and `async_process_packet` to `FunctionNode` + +**File:** `src/orcapod/core/function_pod.py` + +FunctionNode currently has no `process_packet`. 
Add delegation to the function pod: + +```python +def process_packet( + self, tag: TagProtocol, packet: PacketProtocol +) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a single packet by delegating to the function pod.""" + return self._function_pod.process_packet(tag, packet) + +async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + return await self._function_pod.async_process_packet(tag, packet) +``` + +### Step 8: Add `FunctionNode.async_execute` + +**File:** `src/orcapod/core/function_pod.py` + +Sequential streaming through `async_process_packet`: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Streaming async execution — process each packet via async_process_packet.""" + try: + async for tag, packet in inputs[0]: + tag, result_packet = await self.async_process_packet(tag, packet) + if result_packet is not None: + await output.send((tag, result_packet)) + finally: + await output.close() +``` + +### Step 9: Add async cache-aware `async_call` to `CachedPacketFunction` + +**File:** `src/orcapod/core/packet_function.py` + +Override `async_call` to mirror the sync `call()` logic (lines 508-533): + +```python +async def async_call( + self, + packet: PacketProtocol, + *, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, +) -> PacketProtocol | None: + """Async counterpart of ``call`` with cache check and recording.""" + output_packet = None + if not skip_cache_lookup: + logger.info("Checking for cache...") + output_packet = self.get_cached_output_for_packet(packet) + if output_packet is not None: + logger.info(f"Cache hit for {packet}!") + if output_packet is None: + output_packet = await self._packet_function.async_call(packet) + if output_packet is not None: + 
if not skip_cache_insert: + self.record_packet(packet, output_packet) + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + return output_packet +``` + +### Step 10: Add `async_process_packet` to `PersistentFunctionNode` + +**File:** `src/orcapod/core/function_pod.py` + +PersistentFunctionNode already has `process_packet` (line 1027-1066) which calls +`self._packet_function.call(packet, skip_cache_*=...)` (where `_packet_function` is a +`CachedPacketFunction`) then `self.add_pipeline_record(...)`. Add the async counterpart: + +```python +async def async_process_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``. + + Uses the CachedPacketFunction's async_call for computation + result caching. + Pipeline record storage is synchronous (DB protocol is sync). + """ + output_packet = await self._packet_function.async_call( + packet, + skip_cache_lookup=skip_cache_lookup, + skip_cache_insert=skip_cache_insert, + ) + + if output_packet is not None: + result_computed = bool( + output_packet.get_meta_value( + self._packet_function.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + + return tag, output_packet +``` + +### Step 11: Add `PersistentFunctionNode.async_execute` (two-phase) + +**File:** `src/orcapod/core/function_pod.py` + +Overrides `FunctionNode.async_execute`: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Two-phase async execution: replay cached, then compute missing.""" + try: + # Phase 1: emit existing results from DB + existing = self.get_all_records(columns={"meta": True}) + 
computed_hashes: set[str] = set() + if existing is not None and existing.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + hash_col = constants.INPUT_PACKET_HASH_COL + computed_hashes = set( + cast(list[str], existing.column(hash_col).to_pylist()) + ) + data_table = existing.drop([hash_col]) + existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) + for tag, packet in existing_stream.iter_packets(): + await output.send((tag, packet)) + + # Phase 2: process packets not already in the DB + async for tag, packet in inputs[0]: + input_hash = packet.content_hash().to_string() + if input_hash in computed_hashes: + continue + tag, output_packet = await self.async_process_packet(tag, packet) + if output_packet is not None: + await output.send((tag, output_packet)) + finally: + await output.close() +``` + +### Step 12: Add `OperatorNode.async_execute` + +**File:** `src/orcapod/core/operator_node.py` + +Direct pass-through: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Delegate to operator's async_execute.""" + await self._operator.async_execute(inputs, output) +``` + +### Step 13: Extract `_store_output_stream` from `PersistentOperatorNode._compute_and_store` + +**File:** `src/orcapod/core/operator_node.py` + +```python +def _store_output_stream(self, stream: StreamProtocol) -> None: + """Materialize stream and store in the pipeline database with per-row dedup.""" + output_table = stream.as_table( + columns={"source": True, "system_tags": True}, + ) + + arrow_hasher = self.data_context.arrow_hasher + record_hashes = [] + for batch in output_table.to_batches(): + for i in range(len(batch)): + record_hashes.append( + arrow_hasher.hash_table(batch.slice(i, 1)).to_hex() + ) + + output_table = output_table.add_column( + 0, + self.HASH_COLUMN_NAME, + pa.array(record_hashes, type=pa.large_string()), + ) + + 
self._pipeline_database.add_records( + self.pipeline_path, + output_table, + record_id_column=self.HASH_COLUMN_NAME, + skip_duplicates=True, + ) + + self._cached_output_table = output_table.drop(self.HASH_COLUMN_NAME) +``` + +Refactor `_compute_and_store`: + +```python +def _compute_and_store(self) -> None: + self._cached_output_stream = self._operator.process(*self._input_streams) + if self._cache_mode == CacheMode.OFF: + self._update_modified_time() + return + self._store_output_stream(self._cached_output_stream) + self._update_modified_time() +``` + +### Step 14: Add `PersistentOperatorNode.async_execute` + +**File:** `src/orcapod/core/operator_node.py` + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Async execution with cache mode handling. + + REPLAY: emit from DB, close output. + OFF: delegate to operator, forward results. + LOG: delegate to operator, forward + collect results, then store in DB. 
+ """ + try: + if self._cache_mode == CacheMode.REPLAY: + self._replay_from_cache() + assert self._cached_output_stream is not None + for tag, packet in self._cached_output_stream.iter_packets(): + await output.send((tag, packet)) + return # finally block closes output + + # OFF or LOG: delegate to operator, forward results downstream + intermediate = Channel[tuple[TagProtocol, PacketProtocol]]() + collected: list[tuple[TagProtocol, PacketProtocol]] = [] + + async def forward() -> None: + async for item in intermediate.reader: + collected.append(item) + await output.send(item) + + async with asyncio.TaskGroup() as tg: + tg.create_task( + self._operator.async_execute(inputs, intermediate.writer) + ) + tg.create_task(forward()) + + # TaskGroup has completed — all results are in `collected` + # Store if LOG mode (sync DB write, post-hoc) + if self._cache_mode == CacheMode.LOG and collected: + stream = StaticOutputPod._materialize_to_stream(collected) + self._cached_output_stream = stream + self._store_output_stream(stream) + + self._update_modified_time() + finally: + await output.close() +``` + +### Step 15: Add imports + +**`src/orcapod/core/operator_node.py`** — add: +```python +import asyncio +from collections.abc import Sequence + +from orcapod.channels import Channel, ReadableChannel, WritableChannel +from orcapod.core.static_output_pod import StaticOutputPod +``` + +**`src/orcapod/core/function_pod.py`** — already has all needed imports. + +### Step 16: Update regression test for `_execute_concurrent` removal + +**File:** `tests/test_core/test_regression_fixes.py` + +`TestExecuteConcurrentInRunningLoop` imports and tests `_execute_concurrent` directly. +Since we're removing that function, this test class needs to be rewritten to test the +behavior through the actual classes: + +- Test that `FunctionPodStream._iter_packets_concurrent` falls back to sequential + `process_packet` when called inside a running event loop. 
+- Test that `FunctionNode._iter_packets_concurrent` does the same. + +The tested behavior (event-loop fallback) is preserved — it's just now method-internal +rather than in a standalone helper. + +### Step 17: Tests for new functionality + +**File:** `tests/test_channels/test_node_async_execute.py` (new) + +``` +TestProtocolConformance + - test_function_node_satisfies_async_executable_protocol + - test_persistent_function_node_satisfies_async_executable_protocol + - test_operator_node_satisfies_async_executable_protocol + - test_persistent_operator_node_satisfies_async_executable_protocol + +TestCachedPacketFunctionAsync + - test_async_call_cache_miss_computes_and_records + - test_async_call_cache_hit_returns_cached + - test_async_call_skip_cache_lookup + - test_async_call_skip_cache_insert + +TestProcessPacketRouting + - test_function_pod_stream_sequential_uses_process_packet + - test_function_pod_stream_concurrent_uses_async_process_packet + - test_function_node_sequential_uses_process_packet + - test_function_node_concurrent_uses_async_process_packet + - test_persistent_function_node_concurrent_uses_overridden_async_process_packet + - test_concurrent_fallback_in_event_loop_uses_sync_process_packet + +TestFunctionNodeAsyncExecute + - test_basic_streaming_matches_sync + - test_empty_input_closes_cleanly + - test_none_packets_filtered_out + +TestPersistentFunctionNodeAsyncExecute + - test_no_cache_processes_all_inputs + - test_phase1_emits_cached_results + - test_phase2_skips_cached_computes_new + - test_pipeline_records_created_for_new_packets + - test_result_cache_populated_for_new_packets + +TestOperatorNodeAsyncExecute + - test_unary_op_delegation (SelectPacketColumns) + - test_binary_op_delegation (SemiJoin) + - test_nary_op_delegation (Join) + - test_results_match_sync_run + +TestPersistentOperatorNodeAsyncExecute + - test_off_mode_computes_no_db_write + - test_log_mode_computes_and_stores + - test_log_mode_results_match_sync + - 
test_replay_mode_emits_from_db + - test_replay_empty_db_returns_empty + +TestEndToEnd + - test_source_to_persistent_function_node_pipeline + - test_source_to_persistent_operator_node_pipeline +``` + +### Step 18: Run full test suite + +```bash +uv run pytest tests/ -x +``` + +--- + +## Summary of all changes + +### Call chains after changes + +**Sync sequential path:** +``` +FunctionPodStream._iter_packets_sequential + → self._function_pod.process_packet(tag, pkt) # already correct + → packet_function.call(pkt) + +FunctionNode._iter_packets_sequential + → self.process_packet(tag, pkt) # CHANGED: was _packet_function.call(pkt) + → self._function_pod.process_packet(tag, pkt) + → packet_function.call(pkt) + +PersistentFunctionNode._iter_packets_sequential (inherited from FunctionNode) + → self.process_packet(tag, pkt) # polymorphism kicks in + → CachedPacketFunction.call(pkt, skip_cache_*=...) # cache check + compute + record + → self.add_pipeline_record(...) # pipeline DB +``` + +**Sync concurrent path:** +``` +FunctionPodStream._iter_packets_concurrent + → asyncio.run(gather( + self._function_pod.async_process_packet(tag, pkt) ... # CHANGED: was _execute_concurrent + )) + OR (if event loop running): + self._function_pod.process_packet(tag, pkt) ... # fallback + +FunctionNode._iter_packets_concurrent + → asyncio.run(gather( + self.async_process_packet(tag, pkt) ... # CHANGED: was _execute_concurrent + )) + OR (if event loop running): + self.process_packet(tag, pkt) ... # fallback + +PersistentFunctionNode._iter_packets_concurrent (inherited from FunctionNode) + → asyncio.run(gather( + self.async_process_packet(tag, pkt) ... # polymorphism kicks in + → await CachedPacketFunction.async_call(pkt) # cache + compute + → self.add_pipeline_record(...) 
# pipeline DB + )) +``` + +**Async execution path:** +``` +FunctionPod.async_execute + → await self.async_process_packet(tag, pkt) # CHANGED: was packet_function.async_call + → await packet_function.async_call(pkt) + +FunctionNode.async_execute # NEW + → await self.async_process_packet(tag, pkt) + → await self._function_pod.async_process_packet(tag, pkt) + → await packet_function.async_call(pkt) + +PersistentFunctionNode.async_execute # NEW (two-phase) + Phase 1: emit from DB + Phase 2: + → await self.async_process_packet(tag, pkt) # polymorphic override + → await CachedPacketFunction.async_call(pkt) # cache + compute + → self.add_pipeline_record(...) # pipeline DB (sync) + +OperatorNode.async_execute # NEW + → await operator.async_execute(inputs, output) + +PersistentOperatorNode.async_execute # NEW + REPLAY: emit from DB + OFF/LOG: + TaskGroup: + operator.async_execute(inputs, intermediate.writer) + forward(intermediate.reader → output + collect) + if LOG: _store_output_stream(materialize(collected)) # sync DB write +``` + +### Files modified + +| File | Changes | +|------|---------| +| `src/orcapod/core/packet_function.py` | Add `CachedPacketFunction.async_call` override with cache logic | +| `src/orcapod/core/function_pod.py` | (1) Add `_FunctionPodBase.async_process_packet` | +| | (2) Fix `FunctionPod.async_execute` to use `async_process_packet` | +| | (3) Rewrite `FunctionPodStream._iter_packets_concurrent` — route through `_function_pod.async_process_packet` / `process_packet`, drop `packet_function` param | +| | (4) Update `FunctionPodStream.iter_packets` — remove `pf` arg to `_iter_packets_concurrent` | +| | (5) Fix `FunctionNode._iter_packets_sequential` to use `self.process_packet` | +| | (6) Rewrite `FunctionNode._iter_packets_concurrent` — route through `self.async_process_packet` / `self.process_packet` | +| | (7) Add `FunctionNode.process_packet` + `async_process_packet` (delegate to pod) | +| | (8) Add `FunctionNode.async_execute` | +| | (9) Add 
`PersistentFunctionNode.async_process_packet` (cache + pipeline records) | +| | (10) Add `PersistentFunctionNode.async_execute` (two-phase) | +| | (11) Remove `_execute_concurrent` module-level helper | +| `src/orcapod/core/operator_node.py` | (1) Add imports | +| | (2) Add `OperatorNode.async_execute` (pass-through) | +| | (3) Extract `PersistentOperatorNode._store_output_stream` | +| | (4) Refactor `PersistentOperatorNode._compute_and_store` | +| | (5) Add `PersistentOperatorNode.async_execute` (TaskGroup + post-hoc storage) | +| `tests/test_core/test_regression_fixes.py` | Rewrite `TestExecuteConcurrentInRunningLoop` — test through classes instead of removed helper | +| `tests/test_channels/test_node_async_execute.py` | New test file | diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 7fa5ca51..a7618d7a 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -49,38 +49,6 @@ def _executor_supports_concurrent( return executor is not None and executor.supports_concurrent_execution -def _execute_concurrent( - packet_function: PacketFunctionProtocol, - packets: list[PacketProtocol], -) -> list[PacketProtocol | None]: - """Submit all *packets* to the executor concurrently and return results in order. - - Uses ``asyncio.gather`` to run all tasks concurrently, then blocks - until all complete. If an event loop is already running (e.g. inside - ``async def`` code, notebooks, or pytest-asyncio), falls back to - sequential execution to avoid ``RuntimeError``. - """ - import asyncio - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop is not None: - # Already inside an event loop -- cannot call asyncio.run(). - # Fall back to sequential synchronous execution. 
- return [packet_function.call(pkt) for pkt in packets] - - async def _gather() -> list[PacketProtocol | None]: - return list( - await asyncio.gather( - *[packet_function.async_call(pkt) for pkt in packets] - ) - ) - - return asyncio.run(_gather()) - class _FunctionPodBase(TraceableBase): """Base pod that applies a packet function to each input packet.""" @@ -179,6 +147,12 @@ def process_packet( """ return tag, self.packet_function.call(packet) + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + return tag, await self.packet_function.async_call(packet) + def handle_input_streams(self, *streams: StreamProtocol) -> StreamProtocol: """Handle multiple input streams by joining them if necessary. @@ -314,7 +288,7 @@ async def async_execute( async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: try: - result_packet = await self.packet_function.async_call(packet) + tag, result_packet = await self.async_process_packet(tag, packet) if result_packet is not None: await output.send((tag, result_packet)) finally: @@ -419,9 +393,8 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: if self.is_stale: self.clear_cache() if self._cached_input_iterator is not None: - pf = self._function_pod.packet_function - if _executor_supports_concurrent(pf): - yield from self._iter_packets_concurrent(pf) + if _executor_supports_concurrent(self._function_pod.packet_function): + yield from self._iter_packets_concurrent() else: yield from self._iter_packets_sequential() else: @@ -453,7 +426,6 @@ def _iter_packets_sequential( def _iter_packets_concurrent( self, - packet_function: PacketFunctionProtocol, ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: """Collect remaining inputs, execute concurrently, and yield results in order.""" input_iter = self._cached_input_iterator @@ -467,12 +439,33 @@ def _iter_packets_concurrent( 
to_compute.append((i, tag, packet)) self._cached_input_iterator = None - # Submit uncached packets concurrently and cache results. + # Submit uncached packets concurrently via async_process_packet. if to_compute: - results = _execute_concurrent( - packet_function, [pkt for _, _, pkt in to_compute] - ) - for (i, tag, _), output_packet in zip(to_compute, results): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + self._function_pod.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self._function_pod.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): self._cached_output_packets[i] = (tag, output_packet) # Yield everything in original order. 
@@ -818,6 +811,18 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: if packet is not None: yield tag, packet + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a single packet by delegating to the function pod.""" + return self._function_pod.process_packet(tag, packet) + + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + return await self._function_pod.async_process_packet(tag, packet) + def _iter_packets_sequential( self, ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: @@ -828,7 +833,7 @@ def _iter_packets_sequential( if packet is not None: yield tag, packet else: - output_packet = self._packet_function.call(packet) + tag, output_packet = self.process_packet(tag, packet) self._cached_output_packets[i] = (tag, output_packet) if output_packet is not None: yield tag, output_packet @@ -849,10 +854,31 @@ def _iter_packets_concurrent( self._cached_input_iterator = None if to_compute: - results = _execute_concurrent( - self._packet_function, [pkt for _, _, pkt in to_compute] - ) - for (i, tag, _), output_packet in zip(to_compute, results): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + self.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): self._cached_output_packets[i] = (tag, output_packet) for i, *_ in all_inputs: @@ -945,13 +971,22 @@ def as_table( ) return output_table + # 
------------------------------------------------------------------ + # Async channel execution + # ------------------------------------------------------------------ + async def async_execute( self, inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], pipeline_config: PipelineConfig | None = None, ) -> None: - """Streaming async execution for FunctionNode.""" + """Streaming async execution for FunctionNode. + + Routes each packet through ``async_process_packet`` so that + subclasses (e.g. ``PersistentFunctionNode``) can override the + per-packet logic without re-implementing the concurrency scaffold. + """ try: pipeline_config = pipeline_config or PipelineConfig() node_config = ( @@ -964,9 +999,11 @@ async def async_execute( async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: try: - result_packet = self._packet_function.call(packet) + tag_out, result_packet = await self.async_process_packet( + tag, packet + ) if result_packet is not None: - await output.send((tag, result_packet)) + await output.send((tag_out, result_packet)) finally: if sem is not None: sem.release() @@ -1099,6 +1136,39 @@ def process_packet( return tag, output_packet + async def async_process_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``. + + Uses the CachedPacketFunction's async_call for computation + result + caching. Pipeline record storage is synchronous (DB protocol is sync). 
+ """ + output_packet = await self._packet_function.async_call( + packet, + skip_cache_lookup=skip_cache_lookup, + skip_cache_insert=skip_cache_insert, + ) + + if output_packet is not None: + result_computed = bool( + output_packet.get_meta_value( + self._packet_function.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + + return tag, output_packet + def add_pipeline_record( self, tag: TagProtocol, @@ -1262,6 +1332,42 @@ def run(self) -> None: for _ in self.iter_packets(): pass + # ------------------------------------------------------------------ + # Async channel execution (two-phase) + # ------------------------------------------------------------------ + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + ) -> None: + """Two-phase async execution: replay cached, then compute missing.""" + try: + # Phase 1: emit existing results from DB + existing = self.get_all_records(columns={"meta": True}) + computed_hashes: set[str] = set() + if existing is not None and existing.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + hash_col = constants.INPUT_PACKET_HASH_COL + computed_hashes = set( + cast(list[str], existing.column(hash_col).to_pylist()) + ) + data_table = existing.drop([hash_col]) + existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) + for tag, packet in existing_stream.iter_packets(): + await output.send((tag, packet)) + + # Phase 2: process packets not already in the DB + async for tag, packet in inputs[0]: + input_hash = packet.content_hash().to_string() + if input_hash in computed_hashes: + continue + tag, output_packet = await self.async_process_packet(tag, packet) + if output_packet is not None: + await output.send((tag, output_packet)) + finally: + await output.close() + def as_source(self): """Return 
a DerivedSource backed by the DB records of this node.""" from orcapod.core.sources.derived_source import DerivedSource diff --git a/src/orcapod/core/operator_node.py b/src/orcapod/core/operator_node.py index 3bf87485..6de06e70 100644 --- a/src/orcapod/core/operator_node.py +++ b/src/orcapod/core/operator_node.py @@ -1,13 +1,14 @@ from __future__ import annotations +import asyncio import logging from collections.abc import Iterator, Sequence from typing import TYPE_CHECKING, Any -from orcapod.channels import ReadableChannel, WritableChannel - from orcapod import contexts +from orcapod.channels import Channel, ReadableChannel, WritableChannel from orcapod.config import Config +from orcapod.core.static_output_pod import StaticOutputPod from orcapod.core.streams.base import StreamBase from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER from orcapod.protocols.core_protocols import ( @@ -158,13 +159,25 @@ def as_table( assert self._cached_output_stream is not None return self._cached_output_stream.as_table(columns=columns, all_info=all_info) + # ------------------------------------------------------------------ + # Async channel execution + # ------------------------------------------------------------------ + async def async_execute( self, inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], ) -> None: - """Delegate to the wrapped operator's async_execute.""" - await self._operator.async_execute(inputs, output) + """Delegate to the wrapped operator's async_execute. + + Passes pipeline hashes from the input streams so that + multi-input operators can compute canonical system-tag + column names without storing state during validation. 
+ """ + hashes = [s.pipeline_hash() for s in self._input_streams] + await self._operator.async_execute( + inputs, output, input_pipeline_hashes=hashes + ) def __repr__(self) -> str: return ( @@ -242,18 +255,9 @@ def pipeline_path(self) -> tuple[str, ...]: + (f"node:{self._pipeline_node_hash}",) ) - def _compute_and_store(self) -> None: - """Compute operator output, optionally store in DB.""" - self._cached_output_stream = self._operator.process( - *self._input_streams, - ) - - if self._cache_mode == CacheMode.OFF: - self._update_modified_time() - return - - # Materialize for DB storage (LOG and REPLAY modes) - output_table = self._cached_output_stream.as_table( + def _store_output_stream(self, stream: StreamProtocol) -> None: + """Materialize stream and store in the pipeline database with per-row dedup.""" + output_table = stream.as_table( columns={"source": True, "system_tags": True}, ) @@ -281,6 +285,18 @@ def _compute_and_store(self) -> None: ) self._cached_output_table = output_table.drop(self.HASH_COLUMN_NAME) + + def _compute_and_store(self) -> None: + """Compute operator output, optionally store in DB.""" + self._cached_output_stream = self._operator.process( + *self._input_streams, + ) + + if self._cache_mode == CacheMode.OFF: + self._update_modified_time() + return + + self._store_output_stream(self._cached_output_stream) self._update_modified_time() def _replay_from_cache(self) -> None: @@ -368,6 +384,61 @@ def get_all_records( return results if results.num_rows > 0 else None + # ------------------------------------------------------------------ + # Async channel execution + # ------------------------------------------------------------------ + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + ) -> None: + """Async execution with cache mode handling. + + REPLAY: emit from DB, close output. + OFF: delegate to operator, forward results. 
+ LOG: delegate to operator, forward + collect results, then store in DB. + """ + try: + if self._cache_mode == CacheMode.REPLAY: + self._replay_from_cache() + assert self._cached_output_stream is not None + for tag, packet in self._cached_output_stream.iter_packets(): + await output.send((tag, packet)) + return # finally block closes output + + # OFF or LOG: delegate to operator, forward results downstream + intermediate: Channel[tuple[TagProtocol, PacketProtocol]] = Channel() + should_collect = self._cache_mode == CacheMode.LOG + collected: list[tuple[TagProtocol, PacketProtocol]] = [] + + async def forward() -> None: + async for item in intermediate.reader: + if should_collect: + collected.append(item) + await output.send(item) + + hashes = [s.pipeline_hash() for s in self._input_streams] + async with asyncio.TaskGroup() as tg: + tg.create_task( + self._operator.async_execute( + inputs, + intermediate.writer, + input_pipeline_hashes=hashes, + ) + ) + tg.create_task(forward()) + + # TaskGroup has completed — store if LOG mode (sync DB write, post-hoc) + if should_collect and collected: + stream = StaticOutputPod._materialize_to_stream(collected) + self._cached_output_stream = stream + self._store_output_stream(stream) + + self._update_modified_time() + finally: + await output.close() + # ------------------------------------------------------------------ # DerivedSource # ------------------------------------------------------------------ diff --git a/src/orcapod/core/operators/base.py b/src/orcapod/core/operators/base.py index ab7b5fc2..fbc20fbf 100644 --- a/src/orcapod/core/operators/base.py +++ b/src/orcapod/core/operators/base.py @@ -13,7 +13,7 @@ StreamProtocol, TagProtocol, ) -from orcapod.types import ColumnConfig, Schema +from orcapod.types import ColumnConfig, ContentHash, Schema class UnaryOperator(StaticOutputPod): @@ -72,6 +72,8 @@ async def async_execute( self, inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: 
WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, ) -> None: """Barrier-mode: collect single input, run unary_static_process, emit.""" try: @@ -154,6 +156,8 @@ async def async_execute( self, inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, ) -> None: """Barrier-mode: collect both inputs concurrently, run binary_static_process, emit.""" try: diff --git a/src/orcapod/core/operators/batch.py b/src/orcapod/core/operators/batch.py index d49eeaa6..28e5cc4c 100644 --- a/src/orcapod/core/operators/batch.py +++ b/src/orcapod/core/operators/batch.py @@ -1,8 +1,10 @@ +from collections.abc import Sequence from typing import TYPE_CHECKING, Any +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import UnaryOperator from orcapod.core.streams import ArrowTableStream -from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol from orcapod.types import ColumnConfig from orcapod.utils.lazy_module import LazyModule @@ -91,5 +93,48 @@ def unary_output_schema( # TODO: check if this is really necessary return Schema(batched_tag_types), Schema(batched_packet_types) + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming batch: emit full batches as they accumulate. + + When ``batch_size > 0``, each group of ``batch_size`` rows is + materialized and emitted immediately, allowing downstream consumers + to start processing before all input is consumed. When + ``batch_size == 0`` (batch everything), falls back to barrier mode. 
+ """ + try: + if self.batch_size == 0: + # Must collect all rows — barrier fallback + rows = await inputs[0].collect() + if rows: + stream = self._materialize_to_stream(rows) + result = self.unary_static_process(stream) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + return + + batch: list[tuple[TagProtocol, PacketProtocol]] = [] + async for tag, packet in inputs[0]: + batch.append((tag, packet)) + if len(batch) >= self.batch_size: + stream = self._materialize_to_stream(batch) + result = self.unary_static_process(stream) + for out_tag, out_packet in result.iter_packets(): + await output.send((out_tag, out_packet)) + batch = [] + + # Flush partial batch + if batch and not self.drop_partial_batch: + stream = self._materialize_to_stream(batch) + result = self.unary_static_process(stream) + for out_tag, out_packet in result.iter_packets(): + await output.send((out_tag, out_packet)) + finally: + await output.close() + def identity_structure(self) -> Any: return (self.__class__.__name__, self.batch_size, self.drop_partial_batch) diff --git a/src/orcapod/core/operators/column_selection.py b/src/orcapod/core/operators/column_selection.py index ee09cd11..96099ee1 100644 --- a/src/orcapod/core/operators/column_selection.py +++ b/src/orcapod/core/operators/column_selection.py @@ -1,11 +1,12 @@ import logging -from collections.abc import Collection, Mapping +from collections.abc import Collection, Mapping, Sequence from typing import TYPE_CHECKING, Any +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import UnaryOperator from orcapod.core.streams import ArrowTableStream from orcapod.errors import InputValidationError -from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol from orcapod.system_constants import constants from orcapod.types import ColumnConfig, Schema from orcapod.utils.lazy_module import 
LazyModule @@ -82,6 +83,34 @@ def unary_output_schema( return Schema(new_tag_schema), packet_schema + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: select tag columns per row without materializing.""" + try: + tags_to_drop: list[str] | None = None + async for tag, packet in inputs[0]: + if tags_to_drop is None: + tag_keys = tag.keys() + if self.strict: + missing = set(self.columns) - set(tag_keys) + if missing: + raise InputValidationError( + f"Missing tag columns: {missing}. Make sure all " + f"specified columns to select are present or use " + f"strict=False to ignore missing columns" + ) + tags_to_drop = [c for c in tag_keys if c not in self.columns] + if not tags_to_drop: + await output.send((tag, packet)) + else: + await output.send((tag.drop(*tags_to_drop), packet)) + finally: + await output.close() + def identity_structure(self) -> Any: return ( self.__class__.__name__, @@ -163,6 +192,34 @@ def unary_output_schema( return tag_schema, Schema(new_packet_schema) + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: select packet columns per row without materializing.""" + try: + pkts_to_drop: list[str] | None = None + async for tag, packet in inputs[0]: + if pkts_to_drop is None: + pkt_keys = packet.keys() + if self.strict: + missing = set(self.columns) - set(pkt_keys) + if missing: + raise InputValidationError( + f"Missing packet columns: {missing}. 
Make sure all " + f"specified columns to select are present or use " + f"strict=False to ignore missing columns" + ) + pkts_to_drop = [c for c in pkt_keys if c not in self.columns] + if not pkts_to_drop: + await output.send((tag, packet)) + else: + await output.send((tag, packet.drop(*pkts_to_drop))) + finally: + await output.close() + def identity_structure(self) -> Any: return ( self.__class__.__name__, @@ -237,6 +294,38 @@ def unary_output_schema( return Schema(new_tag_schema), packet_schema + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: drop tag columns per row without materializing.""" + try: + effective_drops: list[str] | None = None + async for tag, packet in inputs[0]: + if effective_drops is None: + tag_keys = tag.keys() + if self.strict: + missing = set(self.columns) - set(tag_keys) + if missing: + raise InputValidationError( + f"Missing tag columns: {missing}. 
Make sure all " + f"specified columns to drop are present or use " + f"strict=False to ignore missing columns" + ) + effective_drops = ( + list(self.columns) + if self.strict + else [c for c in self.columns if c in tag_keys] + ) + if not effective_drops: + await output.send((tag, packet)) + else: + await output.send((tag.drop(*effective_drops), packet)) + finally: + await output.close() + def identity_structure(self) -> Any: return ( self.__class__.__name__, @@ -314,6 +403,38 @@ def unary_output_schema( return tag_schema, Schema(new_packet_schema) + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: drop packet columns per row without materializing.""" + try: + effective_drops: list[str] | None = None + async for tag, packet in inputs[0]: + if effective_drops is None: + pkt_keys = packet.keys() + if self.strict: + missing = set(self.columns) - set(pkt_keys) + if missing: + raise InputValidationError( + f"Missing packet columns: {missing}. 
Make sure all " + f"specified columns to drop are present or use " + f"strict=False to ignore missing columns" + ) + effective_drops = ( + list(self.columns) + if self.strict + else [c for c in self.columns if c in pkt_keys] + ) + if not effective_drops: + await output.send((tag, packet)) + else: + await output.send((tag, packet.drop(*effective_drops))) + finally: + await output.close() + def identity_structure(self) -> Any: return ( self.__class__.__name__, diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index 9a1f793d..7fd6d3fc 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -1,12 +1,19 @@ -from collections.abc import Collection +import asyncio +from collections.abc import Collection, Sequence from typing import TYPE_CHECKING, Any +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import NonZeroInputOperator from orcapod.core.streams import ArrowTableStream from orcapod.errors import InputValidationError -from orcapod.protocols.core_protocols import ArgumentGroup, StreamProtocol +from orcapod.protocols.core_protocols import ( + ArgumentGroup, + PacketProtocol, + StreamProtocol, + TagProtocol, +) from orcapod.system_constants import constants -from orcapod.types import ColumnConfig, Schema +from orcapod.types import ColumnConfig, ContentHash, Schema from orcapod.utils import arrow_data_utils, schema_utils from orcapod.utils.lazy_module import LazyModule @@ -28,6 +35,7 @@ def kernel_id(self) -> tuple[str, ...]: return (f"{self.__class__.__name__}",) def validate_nonzero_inputs(self, *streams: StreamProtocol) -> None: + """Validate that input streams are compatible for joining.""" try: self.output_schema(*streams) except Exception as e: @@ -168,6 +176,282 @@ def static_process(self, *streams: StreamProtocol) -> StreamProtocol: tag_columns=tuple(tag_keys), ) + # ------------------------------------------------------------------ + # Async 
execution
+    # ------------------------------------------------------------------
+
+    def _compute_system_tag_suffixes(
+        self,
+        input_pipeline_hashes: Sequence[ContentHash],
+    ) -> list[str]:
+        """Compute per-input system-tag suffixes from pipeline hashes.
+
+        Each suffix is ``{truncated_hash}:{canonical_position}`` where
+        canonical position is determined by a stable sort of the hashes
+        (matching the deterministic ordering used by ``static_process``).
+
+        Args:
+            input_pipeline_hashes: Pipeline hash per input, positionally
+                matching the input channels.
+
+        Returns:
+            List of suffix strings, one per input position.
+        """
+        n_char = self.orcapod_config.system_tag_hash_n_char
+        hex_strings = [h.to_hex() for h in input_pipeline_hashes]
+
+        # Stable argsort gives distinct positions even for duplicate hashes
+        order = sorted(range(len(hex_strings)), key=hex_strings.__getitem__)
+        canon_pos = {orig: pos for pos, orig in enumerate(order)}
+
+        suffixes: list[str] = []
+        for orig_idx in range(len(hex_strings)):
+            truncated = input_pipeline_hashes[orig_idx].to_hex(n_char)
+            suffixes.append(f"{truncated}:{canon_pos[orig_idx]}")
+        return suffixes
+
+    async def async_execute(
+        self,
+        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
+        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
+        *,
+        input_pipeline_hashes: Sequence[ContentHash] | None = None,
+    ) -> None:
+        """Async join with streaming symmetric hash join for two inputs.
+
+        Single input: streams through directly without any buffering.
+
+        Two inputs: symmetric hash join — each arriving row is
+        immediately probed against the opposite side's buffer, emitting
+        matches as soon as found. System-tag columns are correctly
+        renamed using the ``input_pipeline_hashes``.
+
+        Three or more inputs: collects all inputs concurrently, then
+        delegates to ``static_process`` for the Polars N-way join.
+
+        Args:
+            inputs: Readable channels, one per upstream.
+            output: Writable channel for downstream.
+ input_pipeline_hashes: Pipeline hash for each input, + positionally matching ``inputs``. Required for + correct system-tag renaming with 2+ inputs. + """ + try: + if len(inputs) == 1: + async for tag, packet in inputs[0]: + await output.send((tag, packet)) + return + + if len(inputs) == 2: + suffixes = ( + self._compute_system_tag_suffixes(input_pipeline_hashes) + if input_pipeline_hashes is not None + else ["0", "1"] + ) + await self._symmetric_hash_join( + inputs[0], inputs[1], output, suffixes + ) + return + + # N > 2: concurrent collection + static_process + all_rows = await asyncio.gather(*(ch.collect() for ch in inputs)) + + # Guard against empty inputs — join with an empty side is empty + if any(len(rows) == 0 for rows in all_rows): + return + + streams = [self._materialize_to_stream(rows) for rows in all_rows] + result = self.static_process(*streams) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + finally: + await output.close() + + async def _symmetric_hash_join( + self, + left_ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + right_ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + suffixes: list[str], + ) -> None: + """Symmetric hash join for two inputs. + + Both sides are read concurrently via a merged bounded queue. + Each arriving row is added to its side's index and immediately + probed against the opposite side. Matched rows are emitted to + ``output`` as soon as found, so downstream consumers can begin + work before either input is fully consumed. + + Args: + left_ch: Left input channel. + right_ch: Right input channel. + output: Output channel for matched rows. + suffixes: Per-input system-tag suffixes (positional), + computed from pipeline hashes and canonical ordering. + """ + # Bounded queue preserves backpressure — producers block when full. 
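As an aside on the algorithm used here: the symmetric hash join can be sketched in a few lines of plain Python. This is a toy, synchronous sketch under assumed simplifications (rows are `(key, value)` tuples and `symmetric_hash_join` is an illustrative name, not part of this codebase; channels, the sentinel queue, and backpressure are stripped out). It shows why matches can be emitted as soon as both sides of a key have been seen, before either input is exhausted:

```python
# Toy symmetric hash join: events interleave rows from two sides.
def symmetric_hash_join(events):
    """events: iterable of (side, row) pairs with side in {0, 1} and
    row = (key, value). Yields (left_row, right_row) matches eagerly."""
    buffers = ({}, {})  # per side: key -> list of rows seen so far
    for side, row in events:
        key, _ = row
        other = 1 - side
        # Index the new row on its own side...
        buffers[side].setdefault(key, []).append(row)
        # ...then immediately probe the opposite side's buffer.
        for match in buffers[other].get(key, []):
            # Emit in (left, right) order regardless of arrival order
            yield (row, match) if side == 0 else (match, row)

events = [(0, ("a", 1)), (1, ("b", 9)), (1, ("a", 7)), (0, ("b", 2))]
print(list(symmetric_hash_join(events)))
# → [(('a', 1), ('a', 7)), (('b', 2), ('b', 9))]
```

Because each row is buffered after emission, later rows with the same key still match it, preserving inner-join multiplicity.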
+ _SENTINEL = object() + queue: asyncio.Queue = asyncio.Queue(maxsize=64) + + async def _drain( + ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + side: int, + ) -> None: + async for item in ch: + await queue.put((side, item)) + await queue.put((side, _SENTINEL)) + + block_sep = constants.BLOCK_SEPARATOR + + async with asyncio.TaskGroup() as tg: + tg.create_task(_drain(left_ch, 0)) + tg.create_task(_drain(right_ch, 1)) + + # buffers[i] holds all rows seen so far from input i + buffers: list[list[tuple[TagProtocol, PacketProtocol]]] = [[], []] + # indexes[i] maps shared-key tuple → list of indices into buffers[i] + indexes: list[dict[tuple, list[int]]] = [{}, {}] + + shared_keys: tuple[str, ...] | None = None + needs_reindex = False + closed_count = 0 + + while closed_count < 2: + side, item = await queue.get() + + if item is _SENTINEL: + closed_count += 1 + continue + + tag, pkt = item + other = 1 - side + + # Determine shared tag keys once we have rows from both sides + if shared_keys is None: + if not buffers[other]: + # Other side empty — just buffer this row for later + buffers[side].append((tag, pkt)) + continue + + # We have data from both sides; compute shared keys + this_keys = set(tag.keys()) + other_keys = set(buffers[other][0][0].keys()) + shared_keys = tuple(sorted(this_keys & other_keys)) + needs_reindex = True + + # One-time re-index of all rows buffered before shared_keys + if needs_reindex: + needs_reindex = False + for buf_side in (0, 1): + for j, (bt, _bp) in enumerate(buffers[buf_side]): + btd = bt.as_dict() + k = ( + tuple(btd[sk] for sk in shared_keys) + if shared_keys + else (0,) + ) + indexes[buf_side].setdefault(k, []).append(j) + + # Emit matches for all already-buffered rows across sides + for li, (lt, lp) in enumerate(buffers[0]): + ltd = lt.as_dict() + lk = ( + tuple(ltd[sk] for sk in shared_keys) + if shared_keys + else (0,) + ) + for ri in indexes[1].get(lk, []): + rt, rp = buffers[1][ri] + await output.send( + 
self._merge_row_pair(
+                                lt, lp, rt, rp, suffixes, block_sep
+                            )
+                        )
+
+                # Index the new row
+                td = tag.as_dict()
+                key = (
+                    tuple(td[sk] for sk in shared_keys) if shared_keys else (0,)
+                )
+                row_idx = len(buffers[side])
+                buffers[side].append((tag, pkt))
+                indexes[side].setdefault(key, []).append(row_idx)
+
+                # Probe the opposite buffer for matches
+                matching_indices = indexes[other].get(key, [])
+                for mi in matching_indices:
+                    other_tag, other_pkt = buffers[other][mi]
+                    if side == 0:
+                        merged = self._merge_row_pair(
+                            tag, pkt, other_tag, other_pkt,
+                            suffixes, block_sep,
+                        )
+                    else:
+                        merged = self._merge_row_pair(
+                            other_tag, other_pkt, tag, pkt,
+                            suffixes, block_sep,
+                        )
+                    await output.send(merged)
+
+    @staticmethod
+    def _merge_row_pair(
+        left_tag: TagProtocol,
+        left_pkt: PacketProtocol,
+        right_tag: TagProtocol,
+        right_pkt: PacketProtocol,
+        suffixes: list[str],
+        block_sep: str,
+    ) -> tuple[TagProtocol, PacketProtocol]:
+        """Merge a matched pair of rows into one joined (Tag, Packet).
+
+        System-tag keys are renamed by appending
+        ``{block_sep}{suffix}`` to match the canonical name-extending
+        scheme used by ``static_process``. Suffixes derive from the
+        sorted pipeline hashes, so merged names are order-independent.
+ """ + from orcapod.core.datagrams import Packet, Tag + + sys_prefix = constants.SYSTEM_TAG_PREFIX + + # Merge tag dicts (shared keys come from left) + merged_tag_d: dict = {} + merged_tag_d.update(left_tag.as_dict()) + for k, v in right_tag.as_dict().items(): + if k not in merged_tag_d: + merged_tag_d[k] = v + + # Rename and merge system tags with canonical suffixes + merged_sys: dict = {} + for k, v in left_tag.system_tags().items(): + new_key = ( + f"{k}{block_sep}{suffixes[0]}" + if k.startswith(sys_prefix) + else k + ) + merged_sys[new_key] = v + for k, v in right_tag.system_tags().items(): + new_key = ( + f"{k}{block_sep}{suffixes[1]}" + if k.startswith(sys_prefix) + else k + ) + merged_sys[new_key] = v + + merged_tag = Tag(merged_tag_d, system_tags=merged_sys) + + # Merge packet dicts (non-overlapping by Join's validation) + merged_pkt_d: dict = {} + merged_pkt_d.update(left_pkt.as_dict()) + merged_pkt_d.update(right_pkt.as_dict()) + + merged_si: dict = {} + merged_si.update(left_pkt.source_info()) + merged_si.update(right_pkt.source_info()) + + merged_pkt = Packet(merged_pkt_d, source_info=merged_si) + + return merged_tag, merged_pkt + def identity_structure(self) -> Any: return self.__class__.__name__ diff --git a/src/orcapod/core/operators/mappers.py b/src/orcapod/core/operators/mappers.py index d28b2dec..e9c51510 100644 --- a/src/orcapod/core/operators/mappers.py +++ b/src/orcapod/core/operators/mappers.py @@ -1,10 +1,11 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import UnaryOperator from orcapod.core.streams import ArrowTableStream from orcapod.errors import InputValidationError -from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol from orcapod.system_constants import constants from 
orcapod.types import ColumnConfig, Schema
 from orcapod.utils.lazy_module import LazyModule
@@ -110,6 +111,34 @@ def unary_output_schema(
         return tag_schema, Schema(new_packet_schema)
+    async def async_execute(
+        self,
+        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
+        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
+        **kwargs: Any,
+    ) -> None:
+        """Streaming: rename packet columns per row without materializing."""
+        try:
+            rename_map: dict[str, str] | None = None
+            unmapped: list[str] | None = None
+            async for tag, packet in inputs[0]:
+                if rename_map is None:
+                    pkt_keys = packet.keys()
+                    rename_map = {
+                        k: self.name_map[k] for k in pkt_keys if k in self.name_map
+                    }
+                    if self.drop_unmapped:
+                        unmapped = [k for k in pkt_keys if k not in self.name_map]
+                if not rename_map and not unmapped:
+                    await output.send((tag, packet))
+                else:
+                    new_pkt = packet.rename(rename_map) if rename_map else packet
+                    if unmapped:
+                        new_pkt = new_pkt.drop(*unmapped)
+                    await output.send((tag, new_pkt))
+        finally:
+            await output.close()
+
     def identity_structure(self) -> Any:
         return (
             self.__class__.__name__,
@@ -208,6 +237,34 @@ def unary_output_schema(
         return Schema(new_tag_schema), packet_schema
+    async def async_execute(
+        self,
+        inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]],
+        output: WritableChannel[tuple[TagProtocol, PacketProtocol]],
+        **kwargs: Any,
+    ) -> None:
+        """Streaming: rename tag columns per row without materializing."""
+        try:
+            rename_map: dict[str, str] | None = None
+            unmapped: list[str] | None = None
+            async for tag, packet in inputs[0]:
+                if rename_map is None:
+                    tag_keys = tag.keys()
+                    rename_map = {
+                        k: self.name_map[k] for k in tag_keys if k in self.name_map
+                    }
+                    if self.drop_unmapped:
+                        unmapped = [k for k in tag_keys if k not in self.name_map]
+                if not rename_map and not unmapped:
+                    await output.send((tag, packet))
+                else:
+                    new_tag = tag.rename(rename_map) if rename_map else tag
+                    if unmapped:
+                        new_tag = new_tag.drop(*unmapped)
+                    await output.send((new_tag, packet))
+        finally:
+ await output.close() + def identity_structure(self) -> Any: return ( self.__class__.__name__, diff --git a/src/orcapod/core/operators/semijoin.py b/src/orcapod/core/operators/semijoin.py index 0a36f342..59dde1fc 100644 --- a/src/orcapod/core/operators/semijoin.py +++ b/src/orcapod/core/operators/semijoin.py @@ -1,9 +1,11 @@ +from collections.abc import Sequence from typing import TYPE_CHECKING, Any +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import BinaryOperator from orcapod.core.streams import ArrowTableStream from orcapod.errors import InputValidationError -from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol from orcapod.types import ColumnConfig, Schema from orcapod.utils import schema_utils from orcapod.utils.lazy_module import LazyModule @@ -93,6 +95,9 @@ def validate_binary_inputs( """ Validates that the input streams are compatible for semi-join. Checks that overlapping columns have compatible types. + + Stores the common keys so that ``async_execute`` can use them + to determine the correct empty-right behavior without data. """ try: left_tag_schema, left_packet_schema = left_stream.output_schema() @@ -107,7 +112,8 @@ def validate_binary_inputs( ) # intersection_schemas will raise an error if types are incompatible - schema_utils.intersection_schemas(left_all_schema, right_all_schema) + common = schema_utils.intersection_schemas(left_all_schema, right_all_schema) + self._validated_common_keys: tuple[str, ...] = tuple(common.keys()) except Exception as e: raise InputValidationError( @@ -117,5 +123,88 @@ def validate_binary_inputs( def is_commutative(self) -> bool: return False + def _common_keys_from_schema(self) -> tuple[str, ...]: + """Return the common keys computed during input validation. + + Falls back to an empty tuple if validation hasn't been called + (shouldn't happen in normal pipeline execution). 
+ """ + return getattr(self, "_validated_common_keys", ()) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Build-probe: collect right input, then stream left through a hash lookup. + + Phase 1 — Build: collect all rows from the right (filter) channel and + index them by the common-key values. + Phase 2 — Probe: stream left rows one at a time; for each row whose + common-key values appear in the right-side index, emit immediately. + + Falls back to barrier mode when the right input is empty (schema + cannot be inferred from data) or when there are no common keys. + """ + try: + left_ch, right_ch = inputs[0], inputs[1] + + # Phase 1: Build right-side lookup + right_rows = await right_ch.collect() + + if not right_rows: + # Empty right: determine common keys from the validated + # input schemas (set during __init__) to match sync semantics. + # Common keys exist → empty result; no common keys → pass left through. + common = self._common_keys_from_schema() + if common: + # Drain left channel (discard) — result is empty + await left_ch.collect() + return + # No common keys — pass all left rows through unchanged + async for tag, packet in left_ch: + await output.send((tag, packet)) + return + + # Determine right-side keys from first row + right_tag_keys = set(right_rows[0][0].keys()) + right_pkt_keys = set(right_rows[0][1].keys()) + right_all_keys = right_tag_keys | right_pkt_keys + + # Phase 2: Probe — stream left rows + common_keys: tuple[str, ...] 
| None = None + right_lookup: set[tuple] | None = None + + async for tag, packet in left_ch: + if common_keys is None: + # First left row — determine common keys and build index + left_tag_keys = set(tag.keys()) + left_pkt_keys = set(packet.keys()) + left_all_keys = left_tag_keys | left_pkt_keys + common_keys = tuple(sorted(left_all_keys & right_all_keys)) + + if not common_keys: + # No common keys — pass all left rows through + await output.send((tag, packet)) + async for t, p in left_ch: + await output.send((t, p)) + return + + # Build right-side lookup + right_lookup = set() + for rt, rp in right_rows: + rd = rt.as_dict() + rd.update(rp.as_dict()) + right_lookup.add(tuple(rd[k] for k in common_keys)) + + # Probe + ld = tag.as_dict() + ld.update(packet.as_dict()) + if tuple(ld[k] for k in common_keys) in right_lookup: # type: ignore[arg-type] + await output.send((tag, packet)) + finally: + await output.close() + def identity_structure(self) -> Any: return self.__class__.__name__ diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index ed3d8234..7b27fc9f 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -532,6 +532,30 @@ def call( return output_packet + async def async_call( + self, + packet: PacketProtocol, + *, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, + ) -> PacketProtocol | None: + """Async counterpart of ``call`` with cache check and recording.""" + output_packet = None + if not skip_cache_lookup: + logger.info("Checking for cache...") + output_packet = self.get_cached_output_for_packet(packet) + if output_packet is not None: + logger.info(f"Cache hit for {packet}!") + if output_packet is None: + output_packet = await self._packet_function.async_call(packet) + if output_packet is not None: + if not skip_cache_insert: + self.record_packet(packet, output_packet) + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + 
return output_packet + def get_cached_output_for_packet( self, input_packet: PacketProtocol ) -> PacketProtocol | None: diff --git a/src/orcapod/core/static_output_pod.py b/src/orcapod/core/static_output_pod.py index c52fb3cd..e5366d04 100644 --- a/src/orcapod/core/static_output_pod.py +++ b/src/orcapod/core/static_output_pod.py @@ -21,7 +21,7 @@ TagProtocol, TrackerManagerProtocol, ) -from orcapod.types import ColumnConfig, Schema +from orcapod.types import ColumnConfig, ContentHash, Schema from orcapod.utils.lazy_module import LazyModule logger = logging.getLogger(__name__) @@ -207,11 +207,21 @@ async def async_execute( self, inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, ) -> None: """Default barrier-mode async execution. Collects all inputs, runs ``static_process``, emits results. Subclasses override for streaming or incremental strategies. + + Args: + inputs: Readable channels, one per upstream node. + output: Writable channel for downstream consumption. + input_pipeline_hashes: Pipeline hash for each input stream, + positionally matching ``inputs``. Multi-input operators + (e.g. Join) use these to compute canonical system-tag + column names. Ignored by single-input operators. 
""" try: all_rows = await asyncio.gather(*(ch.collect() for ch in inputs)) diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py index 2245b760..ae436067 100644 --- a/src/orcapod/pipeline/nodes.py +++ b/src/orcapod/pipeline/nodes.py @@ -1,10 +1,11 @@ from __future__ import annotations import logging -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from typing import TYPE_CHECKING, Any from orcapod import contexts +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.config import Config from orcapod.core.streams.arrow_table_stream import ArrowTableStream from orcapod.core.tracker import SourceNode @@ -144,6 +145,20 @@ def as_table( assert self._cached_stream is not None return self._cached_stream.as_table(columns=columns, all_info=all_info) + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + ) -> None: + """Materialize to cache DB, then push cached rows to the output channel.""" + try: + self._ensure_stream() + assert self._cached_stream is not None + for tag, packet in self._cached_stream.iter_packets(): + await output.send((tag, packet)) + finally: + await output.close() + def get_all_records(self) -> "pa.Table | None": """Retrieve all stored records from the cache database.""" return self._cache_database.get_all_records(self.cache_path) diff --git a/src/orcapod/pipeline/orchestrator.py b/src/orcapod/pipeline/orchestrator.py index 17de6743..86ca9ee3 100644 --- a/src/orcapod/pipeline/orchestrator.py +++ b/src/orcapod/pipeline/orchestrator.py @@ -1,7 +1,10 @@ """Async pipeline orchestrator for push-based channel execution. -Compiles a ``GraphTracker``'s DAG into channels and launches all nodes -concurrently via ``asyncio.TaskGroup``. 
+Walks a compiled ``Pipeline``'s persistent node graph and launches all +nodes concurrently via ``asyncio.TaskGroup``, wiring them together with +bounded channels. After execution, results are available in the +pipeline databases via the usual ``get_all_records()`` / ``as_source()`` +accessors on each persistent node. """ from __future__ import annotations @@ -12,163 +15,129 @@ from typing import TYPE_CHECKING, Any from orcapod.channels import BroadcastChannel, Channel -from orcapod.core.static_output_pod import StaticOutputPod -from orcapod.core.tracker import GraphTracker, SourceNode from orcapod.types import PipelineConfig if TYPE_CHECKING: import networkx as nx - from orcapod.core.streams.arrow_table_stream import ArrowTableStream - from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol + from orcapod.pipeline.graph import Pipeline logger = logging.getLogger(__name__) class AsyncPipelineOrchestrator: - """Executes a compiled DAG asynchronously using channels and TaskGroup. + """Execute a compiled ``Pipeline`` asynchronously using channels. - After ``GraphTracker.compile()``, the orchestrator: + After ``Pipeline.compile()``, the orchestrator: - 1. Identifies source, intermediate, and terminal nodes. - 2. Creates bounded channels (or broadcast channels for fan-out) between - connected nodes. - 3. Launches every node's ``async_execute`` concurrently. - 4. Collects the terminal node's output and materializes it as a stream. + 1. Walks ``Pipeline._node_graph`` (persistent nodes) in topological + order. + 2. Creates bounded channels (or broadcast channels for fan-out) + between connected nodes. + 3. Launches every node's ``async_execute`` concurrently via + ``asyncio.TaskGroup``. + + Results are written to the pipeline databases by the persistent + nodes themselves (``PersistentFunctionNode``, ``PersistentOperatorNode`` + in LOG mode, etc.). After ``run()`` returns, callers retrieve data + via ``pipeline.