From f098761e12fc6ad095cfc3362faebd7b721c0ba0 Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Wed, 24 Jun 2026 17:01:37 -0700 Subject: [PATCH 1/2] Bind fan-out lineage vars in instance middleware current_fan_out_index() (and the lineage chains) returned None inside fan-out instance_middleware: the engine binds those ContextVars per-node inside the inner subgraph (compiled.py), but instance_middleware wraps the subgraph from outside, before any node runs. The documented use (RetryMiddleware) doesn't read the index, so it sat latent; custom instance middleware reading the index or calling set_invocation_metadata saw None. Bind the three fan-out lineage ContextVars (fan_out_index + the per-depth index/branch chains) to the instance's child_context around the instance_middleware chain via a _bind_instance_lineage context manager, resetting on exit. Bind only when there is instance middleware to read them (the inner nodes bind them otherwise), so the no-middleware path is unchanged. Reset before the error-handling below so its saves keep their existing context. --- CHANGELOG.md | 6 ++ src/openarmature/graph/fan_out.py | 42 ++++++++++++- tests/unit/test_fan_out.py | 97 +++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd3d08a..99f7cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to `openarmature-python` are documented in this file. The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The package follows [Semantic Versioning](https://semver.org/); pre-1.0 minor bumps may carry behavioral changes per [spec governance](https://github.com/LunarCommand/openarmature-spec/blob/main/GOVERNANCE.md). +## [Unreleased] + +### Fixed + +- **`current_fan_out_index()` inside fan-out instance middleware** now returns the executing instance's index (and `current_fan_out_index_chain()` its lineage) instead of `None`. The engine set the fan-out lineage ContextVars per-node, inside the inner subgraph, which left them unset in `instance_middleware` that wraps the subgraph from outside; they are now set around the instance-middleware chain. The documented `instance_middleware` use (`RetryMiddleware`) does not read the index, so no shipped behavior changes. This corrects the value seen by custom instance middleware that reads the index or calls `set_invocation_metadata`. + ## [0.15.0] — 2026-06-22 ### Added diff --git a/src/openarmature/graph/fan_out.py b/src/openarmature/graph/fan_out.py index f5d3223..3363599 100644 --- a/src/openarmature/graph/fan_out.py +++ b/src/openarmature/graph/fan_out.py @@ -32,10 +32,20 @@ import asyncio import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence +from contextlib import AbstractContextManager, contextmanager, nullcontext from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, cast +from openarmature.observability.correlation import ( + _reset_branch_name_chain, + _reset_fan_out_index, + _reset_fan_out_index_chain, + _set_branch_name_chain, + _set_fan_out_index, + _set_fan_out_index_chain, +) + from .errors import ( FanOutEmpty, FanOutInvalidConcurrency, @@ -57,6 +67,25 @@ ConcurrencyResolver = Callable[[Any], int | None] +@contextmanager +def _bind_instance_lineage(child_context: _InvocationContext) -> Iterator[None]: + """Bind the fan-out lineage ContextVars (the instance index and the + per-depth index / branch chains) to ``child_context`` for the duration of + the ``with`` block, resetting them on exit.""" + # compiled.py binds these per-node, inside the inner subgraph; the + # instance_middleware chain runs outside that, so current_fan_out_index() + # and set_invocation_metadata's lineage view would otherwise be unset there. + fan_out_token = _set_fan_out_index(child_context.fan_out_index) + fan_out_chain_token = _set_fan_out_index_chain(child_context.fan_out_index_chain) + branch_chain_token = _set_branch_name_chain(child_context.branch_name_chain) + try: + yield + finally: + _reset_branch_name_chain(branch_chain_token) + _reset_fan_out_index_chain(fan_out_chain_token) + _reset_fan_out_index(fan_out_token) + + @dataclass(frozen=True) class FanOutConfig: """Frozen configuration for a :class:`FanOutNode`. @@ -291,8 +320,17 @@ async def innermost(s: ChildT) -> Mapping[str, Any]: return _extract_instance_partial(cfg, final_inst_state) chain: ChainCall = compose_chain(cfg.instance_middleware, innermost) + # Bind the lineage ContextVars around the chain only when there is + # instance middleware to read them; with none, the inner subgraph's + # nodes bind them and this level would be a redundant no-op. The + # context manager resets before the error-handling below so that + # path's saves keep their existing context. + lineage: AbstractContextManager[None] = ( + _bind_instance_lineage(child_context) if cfg.instance_middleware else nullcontext() + ) try: - partial = await chain(instance_state) + with lineage: + partial = await chain(instance_state) except Exception as exc: if cfg.error_policy == "collect": # Per §10.11.2 collect mode: the failure becomes a diff --git a/tests/unit/test_fan_out.py b/tests/unit/test_fan_out.py index 7307efe..4e14b48 100644 --- a/tests/unit/test_fan_out.py +++ b/tests/unit/test_fan_out.py @@ -654,6 +654,103 @@ async def maybe_fail(state: WorkerState) -> Mapping[str, Any]: assert instance_attempts == {7: 2, 9: 2} +async def test_instance_middleware_sees_fan_out_index() -> None: + # An instance_middleware that reads current_fan_out_index() / its chain + # observes the instance's own index: the engine sets the lineage ContextVars + # around the middleware chain, not only inside node bodies. (Regression -- + # the index was None here when only compiled.py set it, deeper in node + # execution, so the middleware wrapping the inner subgraph saw nothing.) + from openarmature.observability.correlation import ( + current_fan_out_index, + current_fan_out_index_chain, + ) + + seen_index: dict[int, int | None] = {} + seen_chain: dict[int, tuple[int | None, ...]] = {} + + class _RecordIndexMW: + async def __call__(self, state: WorkerState, next_: Any, /) -> Any: + # Key by the item so each instance is identifiable without relying + # on the index under test. + seen_index[state.item] = current_fan_out_index() + seen_chain[state.item] = current_fan_out_index_chain() + return await next_(state) + + async def compute(state: WorkerState) -> Mapping[str, Any]: + return {"result": state.item} + + inner = ( + GraphBuilder(WorkerState).add_node("compute", compute).add_edge("compute", END).set_entry("compute") + ).compile() + parent = ( + GraphBuilder(InstanceMwParentState) + .add_fan_out_node( + "process", + subgraph=inner, + items_field="items", + item_field="item", + collect_field="result", + target_field="results", + instance_middleware=[_RecordIndexMW()], + ) + .add_edge("process", END) + .set_entry("process") + ).compile() + + await parent.invoke(InstanceMwParentState(items=[10, 20, 30])) + await parent.drain() + + # items 10/20/30 are fan-out indices 0/1/2 in order; the chain carries the + # instance index at the leaf. + assert seen_index == {10: 0, 20: 1, 30: 2} + assert seen_chain == {10: (0,), 20: (1,), 30: (2,)} + + +async def test_instance_middleware_lineage_reset_on_failure() -> None: + # The lineage ContextVars reset even when an instance fails: the binding's + # finally runs on the exception path, so a failed instance leaks nothing + # into the parent scope. + from openarmature.observability.correlation import current_fan_out_index + + seen: list[int | None] = [] + + class _RecordMW: + async def __call__(self, state: WorkerState, next_: Any, /) -> Any: + seen.append(current_fan_out_index()) + return await next_(state) + + async def boom(_state: WorkerState) -> Mapping[str, Any]: + raise RuntimeError("boom") + + inner = ( + GraphBuilder(WorkerState).add_node("boom", boom).add_edge("boom", END).set_entry("boom") + ).compile() + parent = ( + GraphBuilder(InstanceMwParentState) + .add_fan_out_node( + "process", + subgraph=inner, + items_field="items", + item_field="item", + collect_field="result", + target_field="results", + instance_middleware=[_RecordMW()], + concurrency=1, + ) + .add_edge("process", END) + .set_entry("process") + ).compile() + + with pytest.raises(NodeException): + await parent.invoke(InstanceMwParentState(items=[1, 2])) + await parent.drain() + + # The middleware saw the instance index (the bind happened) ... + assert seen and all(idx is not None for idx in seen) + # ... and the bind's finally reset it despite the failure. + assert current_fan_out_index() is None + + # --------------------------------------------------------------------------- # Fan-in determinism under nondeterministic completion order (§9.4) # --------------------------------------------------------------------------- From 1944bce06a35ae0455893dfbb4e36451b728e74a Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Wed, 24 Jun 2026 17:20:42 -0700 Subject: [PATCH 2/2] Build fan-out test graphs step-by-step From CoPilot review of #189: expand the two new instance-middleware tests' inner and parent graph construction from inline method chains to the step-by-step named-builder pattern used throughout the module. --- tests/unit/test_fan_out.py | 74 ++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/tests/unit/test_fan_out.py b/tests/unit/test_fan_out.py index 4e14b48..665c244 100644 --- a/tests/unit/test_fan_out.py +++ b/tests/unit/test_fan_out.py @@ -679,23 +679,25 @@ async def __call__(self, state: WorkerState, next_: Any, /) -> Any: async def compute(state: WorkerState) -> Mapping[str, Any]: return {"result": state.item} - inner = ( - GraphBuilder(WorkerState).add_node("compute", compute).add_edge("compute", END).set_entry("compute") - ).compile() - parent = ( - GraphBuilder(InstanceMwParentState) - .add_fan_out_node( - "process", - subgraph=inner, - items_field="items", - item_field="item", - collect_field="result", - target_field="results", - instance_middleware=[_RecordIndexMW()], - ) - .add_edge("process", END) - .set_entry("process") - ).compile() + inner_builder: GraphBuilder[WorkerState] = GraphBuilder(WorkerState) + inner_builder.set_entry("compute") + inner_builder.add_node("compute", compute) + inner_builder.add_edge("compute", END) + inner = inner_builder.compile() + + parent_builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState) + parent_builder.set_entry("process") + parent_builder.add_fan_out_node( + "process", + subgraph=inner, + items_field="items", + item_field="item", + collect_field="result", + target_field="results", + instance_middleware=[_RecordIndexMW()], + ) + parent_builder.add_edge("process", END) + parent = parent_builder.compile() await parent.invoke(InstanceMwParentState(items=[10, 20, 30])) await parent.drain() @@ -722,24 +724,26 @@ async def __call__(self, state: WorkerState, next_: Any, /) -> Any: async def boom(_state: WorkerState) -> Mapping[str, Any]: raise RuntimeError("boom") - inner = ( - GraphBuilder(WorkerState).add_node("boom", boom).add_edge("boom", END).set_entry("boom") - ).compile() - parent = ( - GraphBuilder(InstanceMwParentState) - .add_fan_out_node( - "process", - subgraph=inner, - items_field="items", - item_field="item", - collect_field="result", - target_field="results", - instance_middleware=[_RecordMW()], - concurrency=1, - ) - .add_edge("process", END) - .set_entry("process") - ).compile() + inner_builder: GraphBuilder[WorkerState] = GraphBuilder(WorkerState) + inner_builder.set_entry("boom") + inner_builder.add_node("boom", boom) + inner_builder.add_edge("boom", END) + inner = inner_builder.compile() + + parent_builder: GraphBuilder[InstanceMwParentState] = GraphBuilder(InstanceMwParentState) + parent_builder.set_entry("process") + parent_builder.add_fan_out_node( + "process", + subgraph=inner, + items_field="items", + item_field="item", + collect_field="result", + target_field="results", + instance_middleware=[_RecordMW()], + concurrency=1, + ) + parent_builder.add_edge("process", END) + parent = parent_builder.compile() with pytest.raises(NodeException): await parent.invoke(InstanceMwParentState(items=[1, 2]))