Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datashare-python/datashare_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
METADATA_JSON = "metadata.json"

TIKA_METADATA_RESOURCENAME = "tika_metadata_resourcename"

DEFAULT_SHARED_RESOURCES_SIZE = 1
12 changes: 10 additions & 2 deletions datashare-python/datashare_python/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from asyncio import AbstractEventLoop, iscoroutine
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from contextlib import AbstractContextManager, AsyncExitStack, asynccontextmanager
from contextvars import ContextVar
from copy import deepcopy
from typing import Any
Expand Down Expand Up @@ -108,7 +108,7 @@ async def set_shared_resources(shared: Shared) -> Shared:


# Return shared resources
def shared_resources() -> Shared:
def lifespan_shared_resources() -> Shared:
try:
return SHARED.get()
except LookupError as e:
Expand Down Expand Up @@ -146,3 +146,11 @@ def add_missing_args(fn: Callable, args: dict[str, Any], **kwargs) -> dict[str,
args = deepcopy(args)
args.update(from_kwargs)
return args


# component lifecycle
def component_teardown(_cache_key: str, component: AbstractContextManager) -> None:
if not isinstance(component, AbstractContextManager):
return

component.__exit__(None, None, None)
32 changes: 27 additions & 5 deletions datashare-python/datashare_python/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@
import os
from asyncio import Lock
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from dataclasses import InitVar, dataclass, field
from datetime import UTC, datetime
from enum import StrEnum, unique
from io import BytesIO
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal, Self, TypeVar, cast

import langcodes
from lru import LRU
from pydantic_core import PydanticCustomError, ValidationError, core_schema
from pydantic_core.core_schema import PlainValidatorFunctionSchema
from pydantic_extra_types.language_code import LanguageName
from temporalio import workflow

from .constants import TIKA_METADATA_RESOURCENAME
from .constants import DEFAULT_SHARED_RESOURCES_SIZE, TIKA_METADATA_RESOURCENAME

with workflow.unsafe.imports_passed_through():
from icij_common.es import (
Expand Down Expand Up @@ -372,10 +373,23 @@ def python(cls) -> Self:

@dataclass(frozen=True)
class Shared:
_resources: dict[str, Any] = field(default_factory=dict)
_lock: Lock = field(default_factory=Lock)
cache_size: InitVar[int] = DEFAULT_SHARED_RESOURCES_SIZE
eviction_callback: InitVar[Callable] = None
_resources: LRU = field(init=False, repr=False)
_lock: Lock = field(init=False, repr=False)

def __post_init__(self, cache_size: int, eviction_callback: Callable) -> None:
object.__setattr__(
self, "_resources", LRU(cache_size, callback=eviction_callback)
)
object.__setattr__(self, "_lock", Lock())

def get_resource(
self, key: str, default: Any = None, *, set_if_unavailable: bool = True
) -> Any:
if key not in self._resources and set_if_unavailable:
self.set_resource(key, default)

def get_resource(self, key: str, default: Any = None) -> Any:
return self._resources.get(key, default)

def set_resource(self, key: str, value: Any) -> None:
Expand All @@ -384,6 +398,14 @@ def set_resource(self, key: str, value: Any) -> None:
def pop_resource(self, key: str, default: Any = None) -> Any:
return self._resources.pop(key, default)

async def async_get_resource(
self, key: str, default: Any = None, *, set_if_unavailable: bool = True
) -> Any:
if key not in self._resources and set_if_unavailable:
await self.async_set_resource(key, default)

return self._resources.get(key, default)

async def async_set_resource(self, key: str, value: Any) -> None:
async with self._lock:
self._resources[key] = value
Expand Down
3 changes: 2 additions & 1 deletion datashare-python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "datashare-python"
version = "0.8.24"
version = "0.8.23"
description = "Manage Python tasks and local resources in Datashare"
authors = [
{ name = "Clément Doumouro", email = "cdoumouro@icij.org" },
Expand All @@ -23,6 +23,7 @@ dependencies = [
"temporalio~=1.23",
"typer>=0.15.4,<0.25.1",
"tomlkit~=0.14.0",
"lru-dict~=1.4",
]

[project.urls]
Expand Down
32 changes: 28 additions & 4 deletions datashare-python/tests/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any
from unittest.mock import MagicMock

import pytest
from datashare_python.dependencies import (
add_missing_args,
component_teardown,
lifespan_shared_resources,
set_shared_resources,
shared_resources,
)
from datashare_python.exceptions import DependencyInjectionError
from datashare_python.objects import Shared
Expand All @@ -15,7 +17,7 @@ async def test_set_shared_resources_and_get() -> None:
shared.set_resource("k", "v")
returned = await set_shared_resources(shared)
assert returned is shared
assert shared_resources() is shared
assert lifespan_shared_resources() is shared


async def test_set_shared_resources_overwrites() -> None:
Expand All @@ -25,12 +27,12 @@ async def test_set_shared_resources_overwrites() -> None:
second.set_resource("k", "v2")
await set_shared_resources(first)
await set_shared_resources(second)
assert shared_resources() is second
assert lifespan_shared_resources() is second


def test_shared_resources_raises_before_set() -> None:
with pytest.raises(DependencyInjectionError):
shared_resources()
lifespan_shared_resources()


def test_shared_resources_field_is_frozen() -> None:
Expand Down Expand Up @@ -69,6 +71,28 @@ async def test_async_pop_resource_missing_key_with_default() -> None:
assert await shared.async_pop_resource("missing", None) is None


def test_call_component_teardown_on_key_eviction() -> None:
shared = Shared(1, eviction_callback=component_teardown)

first_resource = MagicMock()
second_resource = MagicMock()

shared.set_resource("first", first_resource)
shared.set_resource("second", second_resource)

assert shared._resources.keys() == ["second"]

first_resource.__exit__.assert_called_once()


def test_dont_set_default_when_set_is_unavailable_is_false() -> None:
shared = Shared()

shared.get_resource("first", default="isn't there", set_if_unavailable=False)

assert "first" not in shared._resources


@pytest.mark.parametrize(
("provided_args", "kwargs", "maybe_output"),
[
Expand Down
Loading
Loading