Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 1 addition & 27 deletions datashare-python/datashare_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,14 @@
from icij_common.pydantic_utils import ICIJSettings
from pydantic import PrivateAttr
from pydantic_settings import SettingsConfigDict
from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter, ToJsonOptions
from temporalio.converter import (
CompositePayloadConverter,
DataConverter,
DefaultPayloadConverter,
JSONPlainPayloadConverter,
)
from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig

import datashare_python

from .objects import BaseModel
from .task_client import DatashareTaskClient
from .types_ import TemporalClient
from .utils import PYDANTIC_DATA_CONVERTER

_ALL_LOGGERS = [datashare_python.__name__]

Expand Down Expand Up @@ -129,23 +123,3 @@ def to_task_client(self) -> DatashareTaskClient:

async def to_temporal_client(self) -> TemporalClient:
return await self.temporal.to_client()


class _PydanticPayloadConverter(CompositePayloadConverter):
def __init__(self) -> None:
json_payload_converter = PydanticJSONPlainPayloadConverter(
ToJsonOptions(exclude_unset=False)
)
super().__init__(
*(
c
if not isinstance(c, JSONPlainPayloadConverter)
else json_payload_converter
for c in DefaultPayloadConverter.default_encoding_payload_converters
)
)


PYDANTIC_DATA_CONVERTER = DataConverter(
payload_converter_class=_PydanticPayloadConverter
)
182 changes: 180 additions & 2 deletions datashare-python/datashare_python/interceptors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
import asyncio
import contextlib
import dataclasses
import secrets
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from typing import Annotated, Any, NoReturn, Self, TypeVar
from functools import partial
from inspect import signature
from types import UnionType
from typing import (
Annotated,
Any,
NoReturn,
Self,
TypeVar,
get_args,
get_origin,
get_type_hints,
)

from nexusrpc import InputT, OutputT
from pydantic import Field
from temporalio import activity
from temporalio.activity import _Definition
from temporalio.api.common.v1 import Payload
from temporalio.client import WorkflowHandle
from temporalio.converter import DataConverter
from temporalio.worker import (
ActivityInboundInterceptor,
Expand All @@ -34,6 +52,13 @@
)

from .objects import BaseModel
from .types_ import ProgressRateHandler, Weight
from .utils import (
PROGRESS_HANDLER_ARG,
PYDANTIC_DATA_CONVERTER,
ActivityWithProgress,
ProgressSignal,
)

_TRACEPARENT = "traceparent"
_DEFAULT_PAYLOAD_CONVERTER = DataConverter.default.payload_converter
Expand Down Expand Up @@ -196,3 +221,156 @@ def _with_trace_context_header(
next_ctx.traceparent
)
return new_obj


class ProgressInterceptor(Interceptor):
def intercept_activity(
self,
next: ActivityInboundInterceptor, # noqa: A002
) -> ActivityInboundInterceptor:
return _ProgressInboundInterceptor(next)


def _parse_progress_weight(act_fn: Callable) -> float:
hints = get_type_hints(act_fn, include_extras=True)
hint = hints["progress"]
annotated_progress = get_origin(hint)
if annotated_progress is not Annotated:
return 1.0
annotated_args = get_args(hint)
for ann in annotated_args[1:]:
if isinstance(ann, Weight):
return ann.value
return 1.0


async def progress_handler(
progress: float,
handle: WorkflowHandle,
*,
activity_id: str,
run_id: str,
weight: float = 1.0,
) -> None:
signal = ProgressSignal(
activity_id=activity_id, run_id=run_id, progress=progress, weight=weight
)
await handle.signal("update_progress", signal)


def supports_progress(task_fn: Callable) -> bool:
return any(
param.name == PROGRESS_HANDLER_ARG
for param in signature(task_fn).parameters.values()
)


def _get_progress_handler(act_fn: Callable) -> ProgressRateHandler:
act = getattr(act_fn, "__self__", None)
# Weirdly isinstance doesn't work here
if act is None or not isinstance(act, ActivityWithProgress):
msg = (
f"to support progress, activities should inherit from "
f"{ActivityWithProgress.__name__}."
)
raise TypeError(msg)
weight = _parse_progress_weight(act_fn)
info = activity.info()
run_id = info.workflow_run_id
workflow_id = info.workflow_id
activity_id = activity.info().activity_id
client = act._temporal_client
workflow_handle = client.get_workflow_handle(workflow_id, run_id=run_id)
handler = partial(
progress_handler,
handle=workflow_handle,
run_id=run_id,
activity_id=activity_id,
weight=weight,
)
return handler


def _is_progress(t: type) -> bool:
if t is ProgressRateHandler:
return True
return bool(
isinstance(t, UnionType)
and any(sub_t is ProgressRateHandler for sub_t in get_args(t))
)


def _without_progress(arg_types: list[type] | None) -> list[type] | None:
if arg_types is None:
return None
filtered = [t for t in arg_types if not _is_progress(t)]
return filtered


class _ProgressInboundInterceptor(ActivityInboundInterceptor):
async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A002
if not supports_progress(input.fn):
return await super().execute_activity(input)
# The progress args breaks trigger a bypass of the dataloader:
# https://github.com/temporalio/sdk-python/blob/631ebaf0e20fb214b16589b45627b358048a5d77/temporalio/worker/_activity.py#L600
# we have to force it here again
progress_handler = _get_progress_handler(input.fn)
if input.args:
data_converter = PYDANTIC_DATA_CONVERTER
arg_types = _Definition.must_from_callable(input.fn).arg_types
arg_types = _without_progress(arg_types)
arg_types = arg_types[: len(input.args)]
encoded = await data_converter.encode(input.args)
new_args = await data_converter.decode(encoded, type_hints=arg_types)
new_args.append(progress_handler)
else:
new_args = [progress_handler]
new_input = dataclasses.replace(input, args=new_args)
await progress_handler(0.0)
res = await super().execute_activity(new_input)
await progress_handler(1.0)
return res


class HeartbeatInterceptor(Interceptor):
def __init__(self, n_missed_before_timeout: int = 5):
self._n_missed_before_timeout = n_missed_before_timeout

def intercept_activity(
self,
next: ActivityInboundInterceptor, # noqa: A002
) -> ActivityInboundInterceptor:
return _HeartbeatInboundInterceptor(next, self._n_missed_before_timeout)


async def _heartbeat_every(period: float, *details: Any) -> None:
with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
activity.heartbeat(*details)
while True:
await asyncio.sleep(period)
with contextlib.suppress(RuntimeError, asyncio.TimeoutError):
activity.heartbeat(*details)


class _HeartbeatInboundInterceptor(ActivityInboundInterceptor):
def __init__(
self,
next: ActivityInboundInterceptor, # noqa: A002
n_missed_before_timeout: int = 5,
) -> None:
super().__init__(next)
self._n_missed_before_timeout = n_missed_before_timeout

async def execute_activity(self, input: ExecuteActivityInput) -> Any: # noqa: A002
heartbeat_timeout = activity.info().heartbeat_timeout
heartbeat_task = None
if heartbeat_timeout:
period = heartbeat_timeout.total_seconds() / self._n_missed_before_timeout
heartbeat_task = asyncio.create_task(_heartbeat_every(period))
try:
activity.heartbeat()
return await super().execute_activity(input)
finally:
if heartbeat_task:
heartbeat_task.cancel()
await asyncio.wait([heartbeat_task])
6 changes: 6 additions & 0 deletions datashare-python/datashare_python/types_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Coroutine
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from dataclasses import dataclass
from typing import Protocol

from temporalio.client import Client
Expand All @@ -12,6 +13,11 @@ async def __call__(self, progress_rate: float) -> None:
pass


@dataclass
class Weight:
value: float


class RawProgressHandler(Protocol):
async def __call__(self, iteration: int) -> None:
pass
Expand Down
Loading
Loading