From 8fe388538b9bae054bf2a8344c87c1555589fba7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Fri, 19 Jun 2026 09:50:03 +0200 Subject: [PATCH 1/3] fix(datashare-python): set retry policy max attempt to 3 --- datashare-python/datashare_python/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index a4746b5..47298bb 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -98,7 +98,7 @@ async def update_progress(self, signal: ProgressSignal) -> None: def _retry_policy_with_default(retry_policy: RetryPolicy | None) -> RetryPolicy: if retry_policy is None: - retry_policy = RetryPolicy(non_retryable_error_types=[]) + retry_policy = RetryPolicy(non_retryable_error_types=[], maximum_attempts=3) retry_policy = deepcopy(retry_policy) non_retryable_error_types = set(retry_policy.non_retryable_error_types) non_retryable_error_types.update(_NEVER_RETRIABLES) From 01ad3eb2c096dff1416ff085119eeb2983292b3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Mon, 22 Jun 2026 12:20:01 +0200 Subject: [PATCH 2/3] fix(datashare-python): pydantic deserialization + refactor activity definition using interceptors --- .../datashare_python/interceptors.py | 185 ++++++++++- datashare-python/datashare_python/types_.py | 6 + datashare-python/datashare_python/utils.py | 207 ++---------- datashare-python/datashare_python/worker.py | 12 +- datashare-python/tests/test_interceptors.py | 310 +++++++++++++++++- worker-template/worker_template/activities.py | 8 - workers/asr-worker/asr_worker/activities.py | 40 ++- .../extract_worker/activities.py | 4 - .../translation_worker/activities.py | 2 - 9 files changed, 546 insertions(+), 228 deletions(-) diff --git a/datashare-python/datashare_python/interceptors.py b/datashare-python/datashare_python/interceptors.py index 8b42f42..384e56d 100644 --- a/datashare-python/datashare_python/interceptors.py +++ b/datashare-python/datashare_python/interceptors.py @@ -1,14 +1,33 @@ +import asyncio +import contextlib +import dataclasses import secrets -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager from contextvars import ContextVar from copy import deepcopy -from typing import Annotated, Any, NoReturn, Self, TypeVar +from functools import partial +from inspect import signature +from types import UnionType +from typing import ( + Annotated, + Any, + NoReturn, + Self, + TypeVar, + get_args, + get_origin, + get_type_hints, +) from nexusrpc import InputT, OutputT from pydantic import Field +from temporalio import activity +from temporalio.activity import _Definition from temporalio.api.common.v1 import Payload +from temporalio.client import WorkflowHandle from temporalio.converter import DataConverter +from temporalio.exceptions import ApplicationError from temporalio.worker import ( ActivityInboundInterceptor, ContinueAsNewInput, @@ -33,7 +52,10 @@ NexusOperationHandle, ) +from .config import PYDANTIC_DATA_CONVERTER from .objects import BaseModel +from .types_ import ProgressRateHandler, Weight +from .utils import PROGRESS_HANDLER_ARG, ActivityWithProgress, ProgressSignal _TRACEPARENT = "traceparent" _DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter @@ -196,3 +218,162 @@ def _with_trace_context_header( next_ctx.traceparent ) return new_obj + + +class ProgressInterceptor(Interceptor): + def intercept_activity( + self, + next: ActivityInboundInterceptor, # noqa: A002 + ) -> ActivityInboundInterceptor: + return _ProgressInboundInterceptor(next) + + +def _parse_progress_weight(act_fn: Callable) -> float: + hints = get_type_hints(act_fn, include_extras=True) + hint = hints["progress"] + annotated_progress = get_origin(hint) + if annotated_progress is not Annotated: + return 1.0 + annotated_args = get_args(hint) + for ann in annotated_args[1:]: + if isinstance(ann, Weight): + return ann.value + return 1.0 + + +async def progress_handler( + progress: float, + handle: WorkflowHandle, + *, + activity_id: str, + run_id: str, + weight: float = 1.0, +) -> None: + signal = ProgressSignal( + activity_id=activity_id, run_id=run_id, progress=progress, weight=weight + ) + await handle.signal("update_progress", signal) + + +def supports_progress(task_fn: Callable) -> bool: + return any( + param.name == PROGRESS_HANDLER_ARG + for param in signature(task_fn).parameters.values() + ) + + +def _get_progress_handler(act_fn: Callable) -> ProgressRateHandler: + act = getattr(act_fn, "__self__", None) + # Weirdly isinstance doesn't work here + if act is None or not isinstance(act, ActivityWithProgress): + msg = ( + f"to support progress, activities should inherit from " + f"{ActivityWithProgress.__name__}." + ) + raise TypeError(msg) + weight = _parse_progress_weight(act_fn) + info = activity.info() + run_id = info.workflow_run_id + workflow_id = info.workflow_id + activity_id = activity.info().activity_id + client = act._temporal_client + workflow_handle = client.get_workflow_handle(workflow_id, run_id=run_id) + handler = partial( + progress_handler, + handle=workflow_handle, + run_id=run_id, + activity_id=activity_id, + weight=weight, + ) + return handler + + +def _is_progress(t: type) -> bool: + if t is ProgressRateHandler: + return True + return bool( + isinstance(t, UnionType) + and any(sub_t is ProgressRateHandler for sub_t in get_args(t)) + ) + + +def _without_progress(arg_types: list[type] | None) -> list[type] | None: + if arg_types is None: + return None + filtered = [t for t in arg_types if not _is_progress(t)] + return filtered + + +class _ProgressInboundInterceptor(ActivityInboundInterceptor): + async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A002 + if not supports_progress(input.fn): + return await super().execute_activity(input) + # The progress args breaks trigger a bypass of the dataloader: + # https://github.com/temporalio/sdk-python/blob/631ebaf0e20fb214b16589b45627b358048a5d77/temporalio/worker/_activity.py#L600 + # we have to force it here again + progress_handler = _get_progress_handler(input.fn) + if input.args: + data_converter = PYDANTIC_DATA_CONVERTER + arg_types = _Definition.must_from_callable(input.fn).arg_types + arg_types = _without_progress(arg_types) + arg_types = arg_types[: len(input.args)] + try: + encoded = await data_converter.encode(input.args) + except Exception as e: + raise ApplicationError("Failed encoding arguments") from e + try: + new_args = await data_converter.decode(encoded, type_hints=arg_types) + except Exception as e: + raise ApplicationError("Failed decoding arguments") from e + new_args.append(progress_handler) + else: + new_args = [progress_handler] + new_input = dataclasses.replace(input, args=new_args) + await progress_handler(0.0) + res = await super().execute_activity(new_input) + await progress_handler(1.0) + return res + + +class HeartbeatInterceptor(Interceptor): + def __init__(self, n_missed_before_timeout: int = 5): + self._n_missed_before_timeout = n_missed_before_timeout + + def intercept_activity( + self, + next: ActivityInboundInterceptor, # noqa: A002 + ) -> ActivityInboundInterceptor: + return _HeartbeatInboundInterceptor(next, self._n_missed_before_timeout) + + +async def _heartbeat_every(period: float, *details: Any) -> None: + with contextlib.suppress(RuntimeError, asyncio.TimeoutError): + activity.heartbeat(*details) + while True: + await asyncio.sleep(period) + with contextlib.suppress(RuntimeError, asyncio.TimeoutError): + activity.heartbeat(*details) + + +class _HeartbeatInboundInterceptor(ActivityInboundInterceptor): + def __init__( + self, + next: ActivityInboundInterceptor, # noqa: A002 + n_missed_before_timeout: int = 5, + ) -> None: + super().__init__(next) + self._n_missed_before_timeout = n_missed_before_timeout + + async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A002 + heartbeat_timeout = activity.info().heartbeat_timeout + heartbeat_task = None + if heartbeat_timeout: + period = heartbeat_timeout.total_seconds() / self._n_missed_before_timeout + heartbeat_task = asyncio.create_task(_heartbeat_every(period)) + try: + activity.heartbeat() + return await super().execute_activity(input) + finally: + if heartbeat_task: + heartbeat_task.cancel() + await asyncio.wait([heartbeat_task]) diff --git a/datashare-python/datashare_python/types_.py b/datashare-python/datashare_python/types_.py index fad7ccc..7bc9fb7 100644 --- a/datashare-python/datashare_python/types_.py +++ b/datashare-python/datashare_python/types_.py @@ -1,5 +1,6 @@ from collections.abc import Coroutine from contextlib import AbstractAsyncContextManager, AbstractContextManager +from dataclasses import dataclass from typing import Protocol from temporalio.client import Client @@ -12,6 +13,11 @@ async def __call__(self, progress_rate: float) -> None: pass +@dataclass +class Weight: + value: float + + class RawProgressHandler(Protocol): async def __call__(self, iteration: int) -> None: pass diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index 47298bb..a9006e1 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -1,18 +1,15 @@ import asyncio import contextlib -import contextvars import inspect import json import os import shutil -import threading -from collections.abc import Awaitable, Callable, Coroutine, Iterable +from collections.abc import Callable, Coroutine, Iterable from copy import deepcopy from dataclasses import dataclass from datetime import timedelta -from functools import partial, wraps +from functools import wraps from hashlib import sha256 -from inspect import signature from io import BytesIO from pathlib import Path from typing import Any, ParamSpec, TypeVar @@ -21,7 +18,7 @@ import nest_asyncio import temporalio from temporalio import activity, workflow -from temporalio.client import Client, WorkflowHandle +from temporalio.client import Client from temporalio.common import RetryPolicy, SearchAttributeKey from temporalio.exceptions import ApplicationError @@ -100,7 +97,12 @@ def _retry_policy_with_default(retry_policy: RetryPolicy | None) -> RetryPolicy: if retry_policy is None: retry_policy = RetryPolicy(non_retryable_error_types=[], maximum_attempts=3) retry_policy = deepcopy(retry_policy) - non_retryable_error_types = set(retry_policy.non_retryable_error_types) + non_retryable_error_types = ( + retry_policy.non_retryable_error_types + if retry_policy.non_retryable_error_types is not None + else [] + ) + non_retryable_error_types = set(non_retryable_error_types) non_retryable_error_types.update(_NEVER_RETRIABLES) retry_policy.non_retryable_error_types = list(non_retryable_error_types) return retry_policy @@ -130,167 +132,6 @@ async def execute_activity( ) -async def progress_handler( - progress: float, - handle: WorkflowHandle, - *, - activity_id: str, - run_id: str, - weight: float = 1.0, -) -> None: - signal = ProgressSignal( - activity_id=activity_id, run_id=run_id, progress=progress, weight=weight - ) - await handle.signal("update_progress", signal) - with contextlib.suppress(RuntimeError, asyncio.TimeoutError): - activity.heartbeat() - - -def get_activity_progress_handler_async( - client: Client, weight: float -) -> ProgressRateHandler: - info = activity.info() - run_id = info.workflow_run_id - workflow_id = info.workflow_id - activity_id = activity.info().activity_id - workflow_handle = client.get_workflow_handle(workflow_id, run_id=run_id) - handler = partial( - progress_handler, - handle=workflow_handle, - run_id=run_id, - activity_id=activity_id, - weight=weight, - ) - return handler - - -def supports_progress(task_fn: Callable) -> bool: - return any( - param.name == PROGRESS_HANDLER_ARG - for param in signature(task_fn).parameters.values() - ) - - -def with_progress(weight: float = 1.0) -> Callable[P, T]: - if isinstance(weight, Callable): - return with_progress(weight=1)(weight) - - def decorator(activity_fn: Callable[P, T]) -> Callable[P, T]: - # TODO: handle the fact activities should have only positional args... - if asyncio.iscoroutinefunction(activity_fn): - - @wraps(activity_fn) - async def wrapper(self: ActivityWithProgress, *args: P.args) -> T: - if not isinstance(self, ActivityWithProgress): - msg = ( - f"{with_progress.__name__} decorator is meant to be used on " - f"activities defined as an {ActivityWithProgress.__name__}" - f" method, expected a {ActivityWithProgress.__name__} as first" - f" argument, found {self}" - ) - raise TypeError(msg) - handler = get_activity_progress_handler_async( - client=self._temporal_client, weight=weight - ) - await handler(0.0) - res = await activity_fn(self, *args, progress=handler) - await handler(1.0) - return res - - else: - - @wraps(activity_fn) - def wrapper(self: ActivityWithProgress, *args: P.args) -> T: - if not isinstance(self, ActivityWithProgress): - msg = ( - f"{with_progress.__name__} decorator is meant to be used on " - f"activities defined as an {ActivityWithProgress.__name__}" - f" method, expected a {ActivityWithProgress.__name__} as first" - f" argument, found {self}" - ) - raise TypeError(msg) - handler = get_activity_progress_handler_async( - client=self._temporal_client, weight=weight - ) - event_loop = self._event_loop - asyncio.run_coroutine_threadsafe(handler(0.0), event_loop).result() - res = activity_fn(self, *args, progress=handler) - asyncio.run_coroutine_threadsafe(handler(1.0), event_loop).result() - return res - - return wrapper - - return decorator - - -def with_async_heartbeat( - activity_fn: Callable[P, Awaitable[T]], n_missed_before_timeout: int -) -> Callable[P, Awaitable[T]]: - # Copied from - # https://github.com/temporalio/samples-python/blob/main/custom_decorator/activity_utils.py - @wraps(activity_fn) - async def wrapper(*args, **kwargs) -> T: - heartbeat_timeout = activity.info().heartbeat_timeout - heartbeat_task = None - if heartbeat_timeout: - period = heartbeat_timeout.total_seconds() / n_missed_before_timeout - heartbeat_task = asyncio.create_task(_async_heartbeat_every(period)) - try: - activity.heartbeat() - return await activity_fn(*args, **kwargs) - finally: - if heartbeat_task: - heartbeat_task.cancel() - await asyncio.wait([heartbeat_task]) - - return wrapper - - -async def _async_heartbeat_every(period: float, *details: Any) -> None: - with contextlib.suppress(RuntimeError, asyncio.TimeoutError): - activity.heartbeat(*details) - while True: - await asyncio.sleep(period) - with contextlib.suppress(RuntimeError, asyncio.TimeoutError): - activity.heartbeat(*details) - - -def with_sync_heartbeat( - activity_fn: Callable[P, T], n_missed_before_timeout: int -) -> Callable[P, T]: - @wraps(activity_fn) - def wrapper(*args, **kwargs) -> T: - heartbeat_timeout = activity.info().heartbeat_timeout - heartbeat_thread, stop_event = None, None - if heartbeat_timeout: - period = heartbeat_timeout.total_seconds() / n_missed_before_timeout - ctx = contextvars.copy_context() - run_args = (_sync_heartbeat_every, period, threading.Event()) - heartbeat_thread, stop_event = ( - threading.Thread(target=ctx.run, args=run_args), - run_args[-1], - ) - heartbeat_thread.start() - try: - return activity_fn(*args, **kwargs) - finally: - if heartbeat_thread: - stop_event.set() - heartbeat_thread.join() - - return wrapper - - -def _sync_heartbeat_every( - period: float, stop_event: threading.Event, *details: Any -) -> None: - with contextlib.suppress(RuntimeError, asyncio.TimeoutError): - activity.heartbeat(*details) - while not stop_event.wait(period): - with contextlib.suppress(RuntimeError, asyncio.TimeoutError): - activity.heartbeat(*details) - - def positional_args_only(activity_fn: Callable[P, T]) -> Callable[P, T]: sig = inspect.signature(activity_fn) @@ -395,29 +236,27 @@ def wrapper(*args, **kwargs) -> T: def activity_defn( - name: str, - progress_weight: float = 1.0, - retriables: set[type[Exception]] = None, - n_missed_heartbeats_before_timeout: int = 5, + name: str, retriables: set[type[Exception]] = None ) -> Callable[[Callable[P, T]], Callable[P, T]]: + def decorator(activity_fn: Callable[P, T]) -> Callable[P, T]: - # TODO: some of these could probably be reimplemented more elegantly using - # temporal interceptors: https://docs.temporal.io/develop/python/workers/interceptors activity_fn = positional_args_only(activity_fn) activity_fn = with_retriables(retriables)(activity_fn) - if supports_progress(activity_fn): - activity_fn = with_progress(progress_weight)(activity_fn) + activity_fn = activity.defn(activity_fn, name=name) + is_async = asyncio.iscoroutinefunction(activity_fn) if is_async: - activity_fn = with_async_heartbeat( - activity_fn, n_missed_heartbeats_before_timeout - ) + + @wraps(activity_fn) + async def wrapper(*args, **kwargs) -> T: + return await activity_fn(*args, **kwargs) else: - activity_fn = with_sync_heartbeat( - activity_fn, n_missed_heartbeats_before_timeout - ) - activity_fn = activity.defn(activity_fn, name=name) - return activity_fn + + @wraps(activity_fn) + def wrapper(*args, **kwargs) -> T: + return activity_fn(*args, **kwargs) + + return wrapper return decorator diff --git a/datashare-python/datashare_python/worker.py b/datashare-python/datashare_python/worker.py index 6931415..d7337cf 100644 --- a/datashare-python/datashare_python/worker.py +++ b/datashare-python/datashare_python/worker.py @@ -21,7 +21,11 @@ from .config import WorkerConfig from .dependencies import with_dependencies from .discovery import Activity -from .interceptors import TraceContextInterceptor +from .interceptors import ( + HeartbeatInterceptor, + ProgressInterceptor, + TraceContextInterceptor, +) from .types_ import ContextManagerFactory, TemporalClient logger = logging.getLogger(__name__) @@ -91,7 +95,11 @@ def datashare_worker( max_concurrent_activities = 1 if workflows: logger.warning(_SEPARATE_IO_AND_CPU_WORKERS) - interceptors = [TraceContextInterceptor()] + interceptors = [ + TraceContextInterceptor(), + ProgressInterceptor(), + HeartbeatInterceptor(), + ] wf_runner = SandboxedWorkflowRunner() if sandboxed else UnsandboxedWorkflowRunner() return DatashareWorker( client, diff --git a/datashare-python/tests/test_interceptors.py b/datashare-python/tests/test_interceptors.py index f948af1..93ea1e2 100644 --- a/datashare-python/tests/test_interceptors.py +++ b/datashare-python/tests/test_interceptors.py @@ -1,30 +1,56 @@ +import asyncio +import logging import uuid from collections.abc import AsyncGenerator +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from typing import Any +from typing import Annotated, Any import pytest import temporalio +from datashare_python.objects import DatashareModel +from datashare_python.utils import ( + ActivityWithProgress, + WorkflowWithProgress, + activity_defn, + execute_activity, +) +from temporalio import exceptions as temporalio_exceptions +from temporalio.common import RetryPolicy with temporalio.workflow.unsafe.imports_passed_through(): from datashare_python.config import PYDANTIC_DATA_CONVERTER, WorkerConfig from datashare_python.interceptors import ( + HeartbeatInterceptor, + ProgressInterceptor, TraceContext, TraceContextInterceptor, get_trace_context, ) - from datashare_python.types_ import TemporalClient - from temporalio import activity, workflow + from datashare_python.types_ import ProgressRateHandler, TemporalClient, Weight + from temporalio import workflow from temporalio.client import ( Interceptor, OutboundInterceptor, StartWorkflowInput, + WorkflowFailureError, WorkflowHandle, ) from temporalio.converter import DataConverter from temporalio.worker import Worker -_TEST_CTX_QUEUE = "test.ctx.queue" + +logger = logging.getLogger(__name__) + + +class TestTaskQueue: + HEARTBEAT = "test.heartbeat" + NO_HEARTBEAT = "test.no_heartbeat" + PROGRESS_SYNC = "test.progress.sync" + PROGRESS_ASYNC = "test.progress.async" + TRACE = "test.trace" + WORKFLOWS = "test.workflows" + _DUMMY_TRACE_CTX = TraceContext(version="00", trace_id="trace_id", parent_id="trace_id") _DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter @@ -46,7 +72,7 @@ def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: # return super().intercept_client(_MockOutboundInterceptor(next)) -_TIMEOUT = timedelta(seconds=10) +_TIMEOUT = timedelta(seconds=180) @workflow.defn @@ -58,45 +84,230 @@ async def run(self) -> list[TraceContext]: ctx_log = await workflow.execute_activity( ctx_test_act, ctx_log, - task_queue=_TEST_CTX_QUEUE, + task_queue=TestTaskQueue.TRACE, start_to_close_timeout=_TIMEOUT, ) ctx_log = await workflow.execute_activity( ctx_test_act, ctx_log, - task_queue=_TEST_CTX_QUEUE, + task_queue=TestTaskQueue.TRACE, start_to_close_timeout=_TIMEOUT, ) return ctx_log -@activity.defn +@activity_defn(name="ctx-test") async def ctx_test_act(previous: list[TraceContext]) -> list[TraceContext]: previous.append(get_trace_context()) return previous +@activity_defn(name="sleep-for-act") +async def sleep_for_act(duration: float) -> None: + await asyncio.sleep(duration) + + +class ProgressArg(DatashareModel): + name: str + + +class _ProgressAct(ActivityWithProgress): + @activity_defn("hello-async") + async def hello_async_act( + self, + args: ProgressArg, + # We test variable args lengths doesn't break deserialization: + # https://github.com/temporalio/sdk-python/issues/360 + extra: str | None = None, + *, + progress: Annotated[ProgressRateHandler | None, Weight(value=5.0)] = None, + ) -> str: + if progress is not None: + await progress(0.1) + hello = f"hello {args.name}" + if extra: + hello += f"{hello} + {extra}" + if progress is not None: + await progress(0.9) + return hello + + @activity_defn("hello-sync") + def hello_sync_act( + self, + args: ProgressArg, + *, + progress: ProgressRateHandler | None = None, + ) -> str: + if progress is not None: + self._event_loop.run_until_complete(progress(0.1)) + hello = f"hello {args.name}" + if progress is not None: + self._event_loop.run_until_complete(progress(0.9)) + return hello + + +@workflow.defn(name="progress") +class _TestProgressWorkflow(WorkflowWithProgress): + @workflow.run + async def run(self, args: ProgressArg) -> None: + await execute_activity( + _ProgressAct.hello_sync_act, + args=[args], + task_queue=TestTaskQueue.PROGRESS_SYNC, + start_to_close_timeout=_TIMEOUT, + ) + await execute_activity( + _ProgressAct.hello_async_act, + args=[args], + task_queue=TestTaskQueue.PROGRESS_ASYNC, + start_to_close_timeout=_TIMEOUT, + ) + + +@workflow.defn(name="heartbeat") +class _TestHeartbeatWorkflow(WorkflowWithProgress): + @workflow.run + async def run(self) -> None: + await execute_activity( + sleep_for_act, + arg=1, + task_queue=TestTaskQueue.HEARTBEAT, + start_to_close_timeout=_TIMEOUT, + heartbeat_timeout=timedelta(milliseconds=100), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + + +@workflow.defn(name="no-heartbeat") +class _TestNoHeartbeatWorkflow(WorkflowWithProgress): + @workflow.run + async def run(self) -> None: + await execute_activity( + sleep_for_act, + arg=1, + task_queue=TestTaskQueue.NO_HEARTBEAT, + start_to_close_timeout=_TIMEOUT, + heartbeat_timeout=timedelta(milliseconds=100), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + + +@pytest.fixture(scope="session") +async def test_wf_worker( + test_temporal_client_session: TemporalClient, +) -> AsyncGenerator[None, None]: + client = test_temporal_client_session + worker_id = f"test-interceptor-worker-{uuid.uuid4()}" + interceptors = [TraceContextInterceptor()] + wfs = [_TestTraceContentWorkflow, _TestProgressWorkflow] + worker = Worker( + client, + identity=worker_id, + workflows=wfs, + interceptors=interceptors, + task_queue=TestTaskQueue.WORKFLOWS, + ) + async with worker: + yield + + @pytest.fixture(scope="session") -async def test_interceptor_worker( +async def test_trace_worker( test_temporal_client_session: TemporalClient, ) -> AsyncGenerator[None, None]: client = test_temporal_client_session worker_id = f"test-interceptor-worker-{uuid.uuid4()}" + activities = [ctx_test_act] interceptors = [TraceContextInterceptor()] worker = Worker( client, identity=worker_id, - activities=[ctx_test_act], - workflows=[_TestTraceContentWorkflow], + activities=activities, + interceptors=interceptors, + task_queue=TestTaskQueue.TRACE, + ) + async with worker: + yield + + +@pytest.fixture(scope="session") +async def test_progress_interceptor_sync_worker( + test_temporal_client_session: TemporalClient, event_loop: asyncio.AbstractEventLoop +) -> AsyncGenerator[None, None]: + client = test_temporal_client_session + worker_id = f"test-progress-sync-worker-{uuid.uuid4()}" + interceptors = [ProgressInterceptor()] + act = _ProgressAct(test_temporal_client_session, event_loop) + worker = Worker( + client, + identity=worker_id, + activities=[act.hello_sync_act], + activity_executor=ThreadPoolExecutor(), interceptors=interceptors, - task_queue=_TEST_CTX_QUEUE, + task_queue=TestTaskQueue.PROGRESS_SYNC, + ) + async with worker: + yield + + +@pytest.fixture(scope="session") +async def test_progress_interceptor_async_worker( + test_temporal_client_session: TemporalClient, event_loop: asyncio.AbstractEventLoop +) -> AsyncGenerator[None, None]: + client = test_temporal_client_session + worker_id = f"test-progress-async-worker-{uuid.uuid4()}" + interceptors = [ProgressInterceptor()] + act = _ProgressAct(test_temporal_client_session, event_loop) + worker = Worker( + client, + identity=worker_id, + activities=[act.hello_async_act], + interceptors=interceptors, + task_queue=TestTaskQueue.PROGRESS_ASYNC, + ) + async with worker: + yield + + +@pytest.fixture(scope="session") +async def test_heartbeat_interceptor_worker( + test_temporal_client_session: TemporalClient, +) -> AsyncGenerator[None, None]: + client = test_temporal_client_session + worker_id = f"test-heartbeat-worker-{uuid.uuid4()}" + interceptors = [HeartbeatInterceptor()] + worker = Worker( + client, + identity=worker_id, + workflows=[_TestHeartbeatWorkflow], + activities=[sleep_for_act], + interceptors=interceptors, + task_queue=TestTaskQueue.HEARTBEAT, + ) + async with worker: + yield + + +@pytest.fixture(scope="session") +async def test_heartbeat_interceptor_no_heartbeat_worker( + test_temporal_client_session: TemporalClient, +) -> AsyncGenerator[None, None]: + client = test_temporal_client_session + worker_id = f"test-heartbeat-worker-{uuid.uuid4()}" + worker = Worker( + client, + identity=worker_id, + workflows=[_TestNoHeartbeatWorkflow], + activities=[sleep_for_act], + task_queue=TestTaskQueue.NO_HEARTBEAT, ) async with worker: yield async def test_trace_context_interceptor( - test_interceptor_worker, # noqa: ANN001, ARG001 + test_wf_worker, # noqa: ANN001, ARG001 + test_trace_worker, # noqa: ANN001, ARG001 test_worker_config: WorkerConfig, ) -> None: # Given @@ -110,7 +321,7 @@ async def test_trace_context_interceptor( wf_id = f"wf-test-interceptor-{uuid.uuid4()}" # When res = await client.execute_workflow( - _TestTraceContentWorkflow, id=wf_id, task_queue=_TEST_CTX_QUEUE + _TestTraceContentWorkflow, id=wf_id, task_queue=TestTaskQueue.WORKFLOWS ) # Then assert len(res) == 3 @@ -121,3 +332,74 @@ async def test_trace_context_interceptor( assert trace_ctx.trace_id == _DUMMY_TRACE_CTX.trace_id assert len(trace_ctx.parent_id) == 16 assert trace_ctx.sampled + + +async def test_progress_interceptor( + test_wf_worker, # noqa: ANN001, ARG001 + test_progress_interceptor_sync_worker, # noqa: ANN001, ARG001 + test_progress_interceptor_async_worker, # noqa: ANN001, ARG001 + test_worker_config: WorkerConfig, +) -> None: + # Given + temporal_config = test_worker_config.temporal + client = await TemporalClient.connect( + target_host=temporal_config.host, + namespace=temporal_config.namespace, + data_converter=PYDANTIC_DATA_CONVERTER, + ) + wf_id = f"wf-test-progress-{uuid.uuid4()}" + # When + await client.execute_workflow( + _TestProgressWorkflow, + args=[ProgressArg(name="world")], + id=wf_id, + task_queue=TestTaskQueue.WORKFLOWS, + ) + # Then + wf = client.get_workflow_handle(workflow_id=wf_id) + search_attributes = (await wf.describe()).search_attributes + max_progress = search_attributes.get("MaxProgress")[0] + assert max_progress == 6.0 + progress = search_attributes.get("Progress")[0] + assert progress == 6.0 + + +async def test_heartbeat_interceptor( + test_heartbeat_interceptor_worker, # noqa: ANN001, ARG001 + test_worker_config: WorkerConfig, +) -> None: + # Given + temporal_config = test_worker_config.temporal + client = await TemporalClient.connect( + target_host=temporal_config.host, + namespace=temporal_config.namespace, + data_converter=PYDANTIC_DATA_CONVERTER, + ) + wf_id = f"wf-test-heartbeat-{uuid.uuid4()}" + # When + res = await client.execute_workflow( + _TestHeartbeatWorkflow, id=wf_id, task_queue=TestTaskQueue.HEARTBEAT + ) + assert res is None + + +async def test_heartbeat_interceptor_should_fail_when_no_heartbeat( + test_heartbeat_interceptor_no_heartbeat_worker, # noqa: ANN001, ARG001 + test_worker_config: WorkerConfig, +) -> None: + # Given + temporal_config = test_worker_config.temporal + client = await TemporalClient.connect( + target_host=temporal_config.host, + namespace=temporal_config.namespace, + data_converter=PYDANTIC_DATA_CONVERTER, + ) + wf_id = f"wf-test-heartbeat-{uuid.uuid4()}" + # When + with pytest.raises(WorkflowFailureError) as ctx: + await client.execute_workflow( + _TestNoHeartbeatWorkflow, id=wf_id, task_queue=TestTaskQueue.NO_HEARTBEAT + ) + cause = ctx.value.cause.__cause__ + assert isinstance(cause, temporalio_exceptions.TimeoutError) + assert "Heartbeat timeout" in cause.args[0] diff --git a/worker-template/worker_template/activities.py b/worker-template/worker_template/activities.py index d96d94b..4cbe33a 100644 --- a/worker-template/worker_template/activities.py +++ b/worker-template/worker_template/activities.py @@ -156,8 +156,6 @@ async def create_classification_batches( progress: ProgressRateHandler | None = None, logger: logging.Logger, ) -> list[list[str]]: - if not isinstance(config, ClassificationConfig): - config = ClassificationConfig.model_validate(config) # Retrieve unprocessed docs. model = config.model unclassified = _get_unclassified( @@ -207,10 +205,6 @@ async def translate_docs( if config is None: config = TranslationConfig() - # TODO: this should not happen - if not isinstance(config, TranslationConfig): - config = TranslationConfig.model_validate(config) - n_docs = len(docs) if not n_docs: return 0 @@ -284,8 +278,6 @@ async def classify_docs( if config is None: config = ClassificationConfig() # TODO: fix this, we should have a ClassificationConfig hered - if not isinstance(config, ClassificationConfig): - config = ClassificationConfig.model_validate(config) n_docs = len(docs) model = config.model diff --git a/workers/asr-worker/asr_worker/activities.py b/workers/asr-worker/asr_worker/activities.py index 240fd43..9af7fa7 100644 --- a/workers/asr-worker/asr_worker/activities.py +++ b/workers/asr-worker/asr_worker/activities.py @@ -4,7 +4,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, Iterable from itertools import tee from pathlib import Path -from typing import Any, cast +from typing import Annotated, Any, cast from caul.objects import ASRResult, PreprocessedInput from caul.tasks import ( @@ -20,7 +20,7 @@ Document, FilesystemDocument, ) -from datashare_python.types_ import ProgressRateHandler, RawProgressHandler +from datashare_python.types_ import ProgressRateHandler, RawProgressHandler, Weight from datashare_python.utils import ( ActivityWithProgress, activity_defn, @@ -74,9 +74,16 @@ class ASRActivities(ActivityWithProgress): - @activity_defn(name=SEARCH_AUDIOS_ACTIVITY, progress_weight=_SEARCH_AUDIOS_WEIGHT) + @activity_defn(name=SEARCH_AUDIOS_ACTIVITY) async def search_audio_paths( - self, project: str, query: dict[str, Any], batch_size: int + self, + project: str, + query: dict[str, Any], + batch_size: int, + *, + progress: Annotated[ # noqa: ARG002 + ProgressRateHandler | None, Weight(value=_SEARCH_AUDIOS_WEIGHT) + ] = None, ) -> list[Path]: es_client = lifespan_es_client() worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) @@ -96,9 +103,16 @@ async def search_audio_paths( ] return batch_paths - @activity_defn(name=PREPROCESS_ACTIVITY, progress_weight=_PREPROCESS_WEIGHT) + @activity_defn(name=PREPROCESS_ACTIVITY) def preprocess( - self, audio_batch: Path, project: str, config: ParakeetPreprocessorConfig + self, + audio_batch: Path, + project: str, + config: ParakeetPreprocessorConfig, + *, + progress: Annotated[ # noqa: ARG002 + ProgressRateHandler | None, Weight(value=_PREPROCESS_WEIGHT) + ] = None, ) -> list[Path]: # TODO: this shouldn't be necessary, fix this bug worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) @@ -118,14 +132,16 @@ def preprocess( batches = [p.relative_to(workdir) for p in batch_paths] return batches - @activity_defn(name=RUN_INFERENCE_ACTIVITY, progress_weight=_INFERENCE_WEIGHT) + @activity_defn(name=RUN_INFERENCE_ACTIVITY) async def infer( self, preprocessed_inputs: list[Path], project: str, config: InferenceRunnerConfig, *, - progress: ProgressRateHandler | None = None, + progress: Annotated[ # noqa: ARG002 + ProgressRateHandler | None, Weight(value=_INFERENCE_WEIGHT) + ] = None, ) -> list[Path]: # TODO: fix this temporal by, we shouldn't have to reload config = _INFERENCE_CONFIG_TYPE_ADAPTER.validate_python(config) @@ -153,7 +169,7 @@ async def infer( inference_res = [p.relative_to(workdir) async for p in inference_res] return inference_res - @activity_defn(name=POSTPROCESS_ACTIVITY, progress_weight=_BASE_WEIGHT) + @activity_defn(name=POSTPROCESS_ACTIVITY) def postprocess( self, inference_results: list[Path], @@ -161,10 +177,10 @@ def postprocess( config: ParakeetPostprocessorConfig, project: str, *, - progress: ProgressRateHandler | None = None, + progress: Annotated[ # noqa: ARG002 + ProgressRateHandler | None, Weight(value=_BASE_WEIGHT) + ] = None, ) -> int: - # TODO: this shouldn't be necessary, fix this bug - config = ParakeetPostprocessorConfig.model_validate(config) worker_config = cast(ASRWorkerConfig, lifespan_worker_config()) workdir = worker_config.workdir audio_batch = workdir / audio_batch diff --git a/workers/extract-worker/extract_worker/activities.py b/workers/extract-worker/extract_worker/activities.py index 16a1645..87a288a 100644 --- a/workers/extract-worker/extract_worker/activities.py +++ b/workers/extract-worker/extract_worker/activities.py @@ -20,7 +20,6 @@ write_artifact, ) from extract_core import ( - BasePipelineConfig, InputDoc, OutputFormat, Pipeline, @@ -118,9 +117,6 @@ async def extract_markdown_content( MinerUPipeline, ) - if not isinstance(config, BasePipelineConfig): - config = PIPELINE_CONFIG_TA.validate_python(config) - pipeline = Pipeline.from_config(config) worker_config = cast(ExtractWorkerConfig, lifespan_worker_config()) workdir = worker_config.workdir diff --git a/workers/translation-worker/translation_worker/activities.py b/workers/translation-worker/translation_worker/activities.py index e58cad2..7014db2 100644 --- a/workers/translation-worker/translation_worker/activities.py +++ b/workers/translation-worker/translation_worker/activities.py @@ -159,8 +159,6 @@ async def translate_docs_act( progress: ProgressRateHandler | None = None, # noqa: F821 ) -> int: # TODO: this should not happen - if not isinstance(worker_config, TranslationWorkerConfig): - worker_config = TranslationWorkerConfig.model_validate(worker_config) es_queue = asyncio.Queue() publisher = _translate_and_queue( batches, From a1b9c439049ea97e05cd659622c752c27b69e7d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Mon, 22 Jun 2026 12:20:01 +0200 Subject: [PATCH 3/3] fix(datashare-python): fix workflow args deserialization error --- datashare-python/datashare_python/config.py | 28 +--------- .../datashare_python/interceptors.py | 19 +++---- datashare-python/datashare_python/utils.py | 45 +++++++++++++++- datashare-python/tests/test_interceptors.py | 3 +- datashare-python/tests/test_utils.py | 52 ++++++++++++++++++- 5 files changed, 106 insertions(+), 41 deletions(-) diff --git a/datashare-python/datashare_python/config.py b/datashare-python/datashare_python/config.py index 7adf301..14b0b38 100644 --- a/datashare-python/datashare_python/config.py +++ b/datashare-python/datashare_python/config.py @@ -6,13 +6,6 @@ from icij_common.pydantic_utils import ICIJSettings from pydantic import PrivateAttr from pydantic_settings import SettingsConfigDict -from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter, ToJsonOptions -from temporalio.converter import ( - CompositePayloadConverter, - DataConverter, - DefaultPayloadConverter, - JSONPlainPayloadConverter, -) from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig import datashare_python @@ -20,6 +13,7 @@ from .objects import BaseModel from .task_client import DatashareTaskClient from .types_ import TemporalClient +from .utils import PYDANTIC_DATA_CONVERTER _ALL_LOGGERS = [datashare_python.__name__] @@ -129,23 +123,3 @@ def to_task_client(self) -> DatashareTaskClient: async def to_temporal_client(self) -> TemporalClient: return await self.temporal.to_client() - - -class _PydanticPayloadConverter(CompositePayloadConverter): - def __init__(self) -> None: - json_payload_converter = PydanticJSONPlainPayloadConverter( - ToJsonOptions(exclude_unset=False) - ) - super().__init__( - *( - c - if not isinstance(c, JSONPlainPayloadConverter) - else json_payload_converter - for c in DefaultPayloadConverter.default_encoding_payload_converters - ) - ) - - -PYDANTIC_DATA_CONVERTER = DataConverter( - payload_converter_class=_PydanticPayloadConverter -) diff --git a/datashare-python/datashare_python/interceptors.py b/datashare-python/datashare_python/interceptors.py index 384e56d..36a9ac0 100644 --- a/datashare-python/datashare_python/interceptors.py +++ b/datashare-python/datashare_python/interceptors.py @@ -27,7 +27,6 @@ from temporalio.api.common.v1 import Payload from temporalio.client import WorkflowHandle from temporalio.converter import DataConverter -from temporalio.exceptions import ApplicationError from temporalio.worker import ( ActivityInboundInterceptor, ContinueAsNewInput, @@ -52,10 +51,14 @@ NexusOperationHandle, ) -from .config import PYDANTIC_DATA_CONVERTER from .objects import BaseModel from .types_ import ProgressRateHandler, Weight -from .utils import PROGRESS_HANDLER_ARG, ActivityWithProgress, ProgressSignal +from .utils import ( + PROGRESS_HANDLER_ARG, + PYDANTIC_DATA_CONVERTER, + ActivityWithProgress, + ProgressSignal, +) _TRACEPARENT = "traceparent" _DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter @@ -317,14 +320,8 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A arg_types = _Definition.must_from_callable(input.fn).arg_types arg_types = _without_progress(arg_types) arg_types = arg_types[: len(input.args)] - try: - encoded = await data_converter.encode(input.args) - except Exception as e: - raise ApplicationError("Failed encoding arguments") from e - try: - new_args = await data_converter.decode(encoded, type_hints=arg_types) - except Exception as e: - raise ApplicationError("Failed decoding arguments") from e + encoded = await data_converter.encode(input.args) + new_args = await data_converter.decode(encoded, type_hints=arg_types) new_args.append(progress_handler) else: new_args = [progress_handler] diff --git a/datashare-python/datashare_python/utils.py b/datashare-python/datashare_python/utils.py index a9006e1..64a33fa 100644 --- a/datashare-python/datashare_python/utils.py +++ b/datashare-python/datashare_python/utils.py @@ -4,7 +4,7 @@ import json import os import shutil -from collections.abc import Callable, Coroutine, Iterable +from collections.abc import Callable, Coroutine, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass from datetime import timedelta @@ -17,9 +17,18 @@ import nest_asyncio import temporalio +from pydantic import ValidationError from temporalio import activity, workflow +from temporalio.api.common.v1 import Payload from temporalio.client import Client from temporalio.common import RetryPolicy, SearchAttributeKey +from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter, ToJsonOptions +from temporalio.converter import ( + CompositePayloadConverter, + DataConverter, + DefaultPayloadConverter, + JSONPlainPayloadConverter, +) from temporalio.exceptions import ApplicationError from .constants import METADATA_JSON @@ -441,3 +450,37 @@ def read_jsonl(path: Path) -> Iterable[dict]: line = line.strip() # noqa: PLW2901 if line: yield json.loads(line) + + +class _PydanticPayloadConverter(CompositePayloadConverter): + def __init__(self) -> None: + json_payload_converter = PydanticJSONPlainPayloadConverter( + ToJsonOptions(exclude_unset=False) + ) + super().__init__( + *( + c + if not isinstance(c, JSONPlainPayloadConverter) + else json_payload_converter + for c in DefaultPayloadConverter.default_encoding_payload_converters + ) + ) + + def from_payloads( + self, payloads: Sequence[Payload], type_hints: list[type] | None = None + ) -> list[Any]: + try: + return super().from_payloads(payloads, type_hints) + except (TypeError, ValidationError) as e: + raise fatal_error_from_exception(e) from e + + def to_payloads(self, values: Sequence[Any]) -> list[Payload]: + try: + return super().to_payloads(values) + except (TypeError, ValidationError) as e: + raise fatal_error_from_exception(e) from e + + +PYDANTIC_DATA_CONVERTER = DataConverter( + payload_converter_class=_PydanticPayloadConverter +) diff --git a/datashare-python/tests/test_interceptors.py b/datashare-python/tests/test_interceptors.py index 93ea1e2..4383433 100644 --- a/datashare-python/tests/test_interceptors.py +++ b/datashare-python/tests/test_interceptors.py @@ -19,7 +19,7 @@ from temporalio.common import RetryPolicy with temporalio.workflow.unsafe.imports_passed_through(): - from datashare_python.config import PYDANTIC_DATA_CONVERTER, WorkerConfig + from datashare_python.config import WorkerConfig from datashare_python.interceptors import ( HeartbeatInterceptor, ProgressInterceptor, @@ -28,6 +28,7 @@ get_trace_context, ) from datashare_python.types_ import ProgressRateHandler, TemporalClient, Weight + from datashare_python.utils import PYDANTIC_DATA_CONVERTER from temporalio import workflow from temporalio.client import ( Interceptor, diff --git a/datashare-python/tests/test_utils.py b/datashare-python/tests/test_utils.py index c55ec30..52fea31 100644 --- a/datashare-python/tests/test_utils.py +++ b/datashare-python/tests/test_utils.py @@ -4,7 +4,7 @@ from pathlib import Path import pytest -from datashare_python.objects import DocArtifact +from datashare_python.objects import DatashareModel, DocArtifact from datashare_python.types_ import TemporalClient from datashare_python.utils import activity_defn, positional_args_only, write_artifact from datashare_python.worker import datashare_worker @@ -33,6 +33,15 @@ async def non_retriable() -> None: raise ValueError("non retriable error occurred") +class DeserArg(DatashareModel): + value: str + + +@activity_defn(name="non_retriable") +async def deser_test_act(arg: DeserArg) -> str: + return arg.value + + @workflow.defn(name="non_retriable_workflow") class NonRetriableWorkflow: @workflow.run @@ -44,6 +53,18 @@ async def run(self) -> None: ) +@workflow.defn(name="test-deserialization-error") +class TestDeserializationErrorWorkflow: + @workflow.run + async def run(self, args: DeserArg) -> None: + await workflow.execute_activity( + deser_test_act, + args=[args], + task_queue="deser", + schedule_to_close_timeout=timedelta(seconds=30), + ) + + async def test_retriable(test_temporal_client_session: TemporalClient) -> None: # Given client = test_temporal_client_session @@ -67,6 +88,35 @@ async def test_retriable(test_temporal_client_session: TemporalClient) -> None: assert cause.non_retryable +async def test_deserialization_error( + test_temporal_client_session: TemporalClient, +) -> None: + # Given + wrong_args = {"valueee": "some-value"} + client = test_temporal_client_session + workflow_id = f"workflow_{uuid.uuid4().hex}" + worker_id = f"worker-{uuid.uuid4().hex}" + worker = datashare_worker( + client, + worker_id=worker_id, + task_queue="deser", + workflows=[TestDeserializationErrorWorkflow], + activities=[deser_test_act], + ) + async with worker: + with pytest.raises(WorkflowFailureError) as ctx: + await client.execute_workflow( + TestDeserializationErrorWorkflow.run, + args=[wrong_args], + id=workflow_id, + task_queue="deser", + ) + assert ctx.value.cause.non_retryable + root_cause = ctx.value.cause.__cause__ + assert isinstance(root_cause, ApplicationError) + assert "2 validation errors for DeserArg" in root_cause.message + + def test_write_artifact(tmp_path: Path) -> None: from datashare_python.conftest import TEST_PROJECT # noqa: PLC0415