Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to `openarmature-python` are documented in this file.

The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The package follows [Semantic Versioning](https://semver.org/); pre-1.0 minor bumps may carry behavioral changes per [spec governance](https://github.com/LunarCommand/openarmature-spec/blob/main/GOVERNANCE.md).

## [Unreleased]

### Fixed

- **`current_fan_out_index()` inside fan-out instance middleware** now returns the executing instance's index (and `current_fan_out_index_chain()` its lineage) instead of `None`. The engine set the fan-out lineage ContextVars per-node, inside the inner subgraph, which left them unset in `instance_middleware` that wraps the subgraph from outside; they are now set around the instance-middleware chain. The documented `instance_middleware` use (`RetryMiddleware`) does not read the index, so no shipped behavior changes. This corrects the value seen by custom instance middleware that reads the index or calls `set_invocation_metadata`.

## [0.15.0] — 2026-06-22

### Added
Expand Down
42 changes: 40 additions & 2 deletions src/openarmature/graph/fan_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,20 @@

import asyncio
import time
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Iterator, Mapping, Sequence
from contextlib import AbstractContextManager, contextmanager, nullcontext
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast

from openarmature.observability.correlation import (
_reset_branch_name_chain,
_reset_fan_out_index,
_reset_fan_out_index_chain,
_set_branch_name_chain,
_set_fan_out_index,
_set_fan_out_index_chain,
)

from .errors import (
FanOutEmpty,
FanOutInvalidConcurrency,
Expand All @@ -57,6 +67,25 @@
ConcurrencyResolver = Callable[[Any], int | None]


@contextmanager
def _bind_instance_lineage(child_context: _InvocationContext) -> Iterator[None]:
"""Bind the fan-out lineage ContextVars (the instance index and the
per-depth index / branch chains) to ``child_context`` for the duration of
the ``with`` block, resetting them on exit."""
# compiled.py binds these per-node, inside the inner subgraph; the
# instance_middleware chain runs outside that, so current_fan_out_index()
# and set_invocation_metadata's lineage view would otherwise be unset there.
fan_out_token = _set_fan_out_index(child_context.fan_out_index)
fan_out_chain_token = _set_fan_out_index_chain(child_context.fan_out_index_chain)
branch_chain_token = _set_branch_name_chain(child_context.branch_name_chain)
try:
yield
finally:
_reset_branch_name_chain(branch_chain_token)
_reset_fan_out_index_chain(fan_out_chain_token)
_reset_fan_out_index(fan_out_token)


@dataclass(frozen=True)
class FanOutConfig:
"""Frozen configuration for a :class:`FanOutNode`.
Expand Down Expand Up @@ -291,8 +320,17 @@ async def innermost(s: ChildT) -> Mapping[str, Any]:
return _extract_instance_partial(cfg, final_inst_state)

chain: ChainCall = compose_chain(cfg.instance_middleware, innermost)
# Bind the lineage ContextVars around the chain only when there is
# instance middleware to read them; with none, the inner subgraph's
# nodes bind them and this level would be a redundant no-op. The
# context manager resets before the error-handling below so that
# path's saves keep their existing context.
lineage: AbstractContextManager[None] = (
_bind_instance_lineage(child_context) if cfg.instance_middleware else nullcontext()
)
try:
partial = await chain(instance_state)
with lineage:
partial = await chain(instance_state)
except Exception as exc:
if cfg.error_policy == "collect":
# Per §10.11.2 collect mode: the failure becomes a
Expand Down
101 changes: 101 additions & 0 deletions tests/unit/test_fan_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,107 @@ async def maybe_fail(state: WorkerState) -> Mapping[str, Any]:
assert instance_attempts == {7: 2, 9: 2}


async def test_instance_middleware_sees_fan_out_index() -> None:
# An instance_middleware that reads current_fan_out_index() / its chain
# observes the instance's own index: the engine sets the lineage ContextVars
# around the middleware chain, not only inside node bodies. (Regression --
# the index was None here when only compiled.py set it, deeper in node
# execution, so the middleware wrapping the inner subgraph saw nothing.)
from openarmature.observability.correlation import (
current_fan_out_index,
current_fan_out_index_chain,
)

seen_index: dict[int, int | None] = {}
seen_chain: dict[int, tuple[int | None, ...]] = {}

class _RecordIndexMW:
async def __call__(self, state: WorkerState, next_: Any, /) -> Any:
# Key by the item so each instance is identifiable without relying
# on the index under test.
seen_index[state.item] = current_fan_out_index()
seen_chain[state.item] = current_fan_out_index_chain()
return await next_(state)

async def compute(state: WorkerState) -> Mapping[str, Any]:
return {"result": state.item}

inner_builder: GraphBuilder[WorkerState] = GraphBuilder(WorkerState)
inner_builder.set_entry("compute")
inner_builder.add_node("compute", compute)
inner_builder.add_edge("compute", END)
inner = inner_builder.compile()

parent_builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState)
parent_builder.set_entry("process")
parent_builder.add_fan_out_node(
"process",
subgraph=inner,
items_field="items",
item_field="item",
collect_field="result",
target_field="results",
instance_middleware=[_RecordIndexMW()],
)
parent_builder.add_edge("process", END)
parent = parent_builder.compile()

await parent.invoke(InstanceMwParentState(items=[10, 20, 30]))
await parent.drain()

# items 10/20/30 are fan-out indices 0/1/2 in order; the chain carries the
# instance index at the leaf.
assert seen_index == {10: 0, 20: 1, 30: 2}
assert seen_chain == {10: (0,), 20: (1,), 30: (2,)}


async def test_instance_middleware_lineage_reset_on_failure() -> None:
# The lineage ContextVars reset even when an instance fails: the binding's
# finally runs on the exception path, so a failed instance leaks nothing
# into the parent scope.
from openarmature.observability.correlation import current_fan_out_index

seen: list[int | None] = []

class _RecordMW:
async def __call__(self, state: WorkerState, next_: Any, /) -> Any:
seen.append(current_fan_out_index())
return await next_(state)

async def boom(_state: WorkerState) -> Mapping[str, Any]:
raise RuntimeError("boom")

inner_builder: GraphBuilder[WorkerState] = GraphBuilder(WorkerState)
inner_builder.set_entry("boom")
inner_builder.add_node("boom", boom)
inner_builder.add_edge("boom", END)
inner = inner_builder.compile()

parent_builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState)
parent_builder.set_entry("process")
parent_builder.add_fan_out_node(
"process",
subgraph=inner,
items_field="items",
item_field="item",
collect_field="result",
target_field="results",
instance_middleware=[_RecordMW()],
concurrency=1,
)
parent_builder.add_edge("process", END)
parent = parent_builder.compile()

with pytest.raises(NodeException):
await parent.invoke(InstanceMwParentState(items=[1, 2]))
await parent.drain()

# The middleware saw the instance index (the bind happened) ...
assert seen and all(idx is not None for idx in seen)
# ... and the bind's finally reset it despite the failure.
assert current_fan_out_index() is None


# ---------------------------------------------------------------------------
# Fan-in determinism under nondeterministic completion order (§9.4)
# ---------------------------------------------------------------------------
Expand Down