Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nest/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@
Res,
createParamDecorator,
)
from nest.common.interfaces import (
BeforeApplicationShutdown,
OnApplicationBootstrap,
OnApplicationShutdown,
OnModuleDestroy,
OnModuleInit,
)
28 changes: 28 additions & 0 deletions nest/common/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class OnModuleInit(Protocol):
def on_module_init(self) -> Any: ...


@runtime_checkable
class OnApplicationBootstrap(Protocol):
def on_application_bootstrap(self) -> Any: ...


@runtime_checkable
class BeforeApplicationShutdown(Protocol):
def before_application_shutdown(self, signal: Optional[str]) -> Any: ...


@runtime_checkable
class OnModuleDestroy(Protocol):
def on_module_destroy(self) -> Any: ...


@runtime_checkable
class OnApplicationShutdown(Protocol):
def on_application_shutdown(self, signal: Optional[str]) -> Any: ...
68 changes: 67 additions & 1 deletion nest/core/pynest_application.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import asyncio
import inspect
from typing import Any
import signal as signal_module
from contextlib import asynccontextmanager
from typing import Any, Iterable, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
Expand All @@ -18,6 +21,9 @@ class PyNestApp:
def __init__(self, container: PyNestContainer, http_server: FastAPI) -> None:
self.container = container
self.http_server = http_server
self._closed = False
self._closing = False
self._install_lifespan_shutdown()
routes_resolver = RoutesResolver(self.container, self.http_server)
routes_resolver.register_routes()

Expand All @@ -33,6 +39,31 @@ def use(self, middleware: type, **options: Any) -> "PyNestApp":
self.http_server.add_middleware(middleware, **options)
return self

def enable_shutdown_hooks(
self, signals: Optional[Iterable[signal_module.Signals]] = None
) -> "PyNestApp":
"""Register process signal handlers that trigger graceful shutdown."""
shutdown_signals = tuple(
signals or (signal_module.SIGTERM, signal_module.SIGINT)
)
for shutdown_signal in shutdown_signals:
signal_module.signal(
shutdown_signal, self._make_signal_handler(shutdown_signal)
)
return self

async def close(self, signal: Optional[str] = None) -> None:
"""Run graceful application shutdown lifecycle hooks once."""
if self._closed or self._closing:
return

self._closing = True
try:
await self.container.shutdown_lifecycle(signal)
self._closed = True
finally:
self._closing = False

def use_global_filters(self, *filters) -> "PyNestApp":
"""Register one or more exception filters that apply to every route.

Expand Down Expand Up @@ -73,3 +104,38 @@ async def handler(request: Request, exc: Exception):
return result

self.http_server.add_exception_handler(exc_type, handler)

def _make_signal_handler(self, shutdown_signal: signal_module.Signals):
def handler(signum, frame):
self._close_from_signal(self._signal_name(signum or shutdown_signal))

return handler

def _close_from_signal(self, signal_name: str) -> None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
asyncio.run(self.close(signal_name))
return

loop.create_task(self.close(signal_name))

@staticmethod
def _signal_name(signum) -> str:
try:
return signal_module.Signals(signum).name
except ValueError:
return str(signum)

def _install_lifespan_shutdown(self) -> None:
original_lifespan_context = self.http_server.router.lifespan_context

@asynccontextmanager
async def lifespan_context(app: FastAPI):
async with original_lifespan_context(app) as state:
try:
yield state
finally:
await self.close()

self.http_server.router.lifespan_context = lifespan_context
175 changes: 167 additions & 8 deletions nest/core/pynest_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,27 @@
from typing import Any, Dict, List, Optional, Type, Union

from nest.common.exceptions import CircularDependencyException
from nest.common.interfaces import (
BeforeApplicationShutdown,
OnApplicationBootstrap,
OnApplicationShutdown,
OnModuleDestroy,
OnModuleInit,
)
from nest.common.module import CompiledModule, ModuleCompiler, ModuleTokenFactory
from nest.common.provider import InjectionToken, ProviderDescriptor
from nest.core.dependency_graph import DependencyGraph
from nest.core.encapsulation import validate_module_encapsulation
from nest.core.injector_module import build_injector, _to_key

_LIFECYCLE_METHOD_NAMES = (
"on_module_init",
"on_application_bootstrap",
"before_application_shutdown",
"on_module_destroy",
"on_application_shutdown",
)


class ModuleRef:
"""Internal container representation of a registered module."""
Expand All @@ -37,6 +52,9 @@ def __init__(self) -> None:
self._modules: Dict[str, ModuleRef] = {}
self._all_descriptors: List[ProviderDescriptor] = []
self._controller_classes: List[Type] = []
self._module_instances: Dict[str, Any] = {}
self._lifecycle_initialized = False
self._lifecycle_shutdown = False
self._module_token_factory = ModuleTokenFactory()
self._module_compiler = ModuleCompiler(self._module_token_factory)

Expand Down Expand Up @@ -105,20 +123,80 @@ def clear(self) -> None:
self._modules.clear()
self._all_descriptors.clear()
self._controller_classes.clear()
self._module_instances.clear()
self._lifecycle_initialized = False
self._lifecycle_shutdown = False

async def initialize_lifecycle(self) -> None:
"""Run module init and application bootstrap hooks once."""
if self._injector is None:
raise RuntimeError(
"Container not built. Call container.build() before lifecycle hooks."
)
if self._lifecycle_initialized:
return

for module_ref in self._modules.values():
await self._call_hooks(
self._get_module_lifecycle_instances(module_ref),
OnModuleInit,
"on_module_init",
)

await self._call_hooks(
self._get_all_lifecycle_instances(),
OnApplicationBootstrap,
"on_application_bootstrap",
)
self._lifecycle_initialized = True

async def shutdown_lifecycle(self, signal: Optional[str] = None) -> None:
"""Run application shutdown hooks once in graceful shutdown order."""
if self._injector is None:
raise RuntimeError(
"Container not built. Call container.build() before lifecycle hooks."
)
if self._lifecycle_shutdown:
return

modules = list(self._modules.values())
for module_ref in reversed(modules):
await self._call_hooks(
self._get_module_lifecycle_instances(module_ref),
BeforeApplicationShutdown,
"before_application_shutdown",
signal,
)

for module_ref in reversed(modules):
await self._call_hooks(
self._get_module_lifecycle_instances(module_ref),
OnModuleDestroy,
"on_module_destroy",
)

for module_ref in reversed(modules):
await self._call_hooks(
self._get_module_lifecycle_instances(module_ref),
OnApplicationShutdown,
"on_application_shutdown",
signal,
)

self._lifecycle_shutdown = True

# ── Internal ───────────────────────────────────────────────────────────────

def _make_controller_descriptors(self) -> List[ProviderDescriptor]:
from nest.common.provider import Scope

return [
ProviderDescriptor(provide=cls, use_class=cls, scope=Scope.SINGLETON)
for cls in self._controller_classes
]

def _validate_dependency_graph(self) -> None:
"""Build a DAG from all class providers and raise CircularDependencyException on cycles."""
import sys

graph = DependencyGraph()

# Build a name→class lookup from all registered providers so forward refs can be resolved
Expand Down Expand Up @@ -162,9 +240,90 @@ def _validate_dependency_graph(self) -> None:

cycles = graph.detect_cycles()
if cycles:
chain = " → ".join(
getattr(n, "__name__", repr(n)) for n in cycles[0]
)
raise CircularDependencyException(
f"Circular dependency detected: {chain}"
)
chain = " → ".join(getattr(n, "__name__", repr(n)) for n in cycles[0])
raise CircularDependencyException(f"Circular dependency detected: {chain}")

def _get_all_lifecycle_instances(self) -> List[Any]:
instances: List[Any] = []
seen: set[int] = set()
for module_ref in self._modules.values():
for instance in self._get_module_lifecycle_instances(module_ref):
instance_id = id(instance)
if instance_id in seen:
continue
seen.add(instance_id)
instances.append(instance)
return instances

def _get_module_lifecycle_instances(self, module_ref: ModuleRef) -> List[Any]:
instances: List[Any] = []
seen: set[int] = set()

for desc in module_ref.compiled.provider_descriptors:
instance = self.get(desc.provide)
instance_id = id(instance)
if instance_id in seen:
continue
seen.add(instance_id)
instances.append(instance)

module_instance = self._get_module_instance(module_ref)
if module_instance is not None and id(module_instance) not in seen:
instances.append(module_instance)

return instances

def _get_module_instance(self, module_ref: ModuleRef) -> Optional[Any]:
if module_ref.token in self._module_instances:
return self._module_instances[module_ref.token]

if not any(
callable(getattr(module_ref.metatype, name, None))
for name in _LIFECYCLE_METHOD_NAMES
):
return None

instance = self._instantiate_module(module_ref.metatype)
self._module_instances[module_ref.token] = instance
return instance

def _instantiate_module(self, module_class: Type) -> Any:
try:
signature = inspect.signature(module_class.__init__)
except (TypeError, ValueError):
return module_class()

kwargs = {}
for param in list(signature.parameters.values())[1:]:
if param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
continue
if param.annotation is not inspect.Parameter.empty:
kwargs[param.name] = self.get(param.annotation)
elif param.default is inspect.Parameter.empty:
raise RuntimeError(
f"Cannot instantiate module {module_class.__name__}: "
f"constructor parameter {param.name!r} has no type annotation"
)

return module_class(**kwargs)

async def _call_hooks(
self, instances: List[Any], protocol: Type, method_name: str, *args: Any
) -> None:
calls = [
self._call_hook(instance, method_name, *args)
for instance in instances
if isinstance(instance, protocol)
]
if calls:
import asyncio

await asyncio.gather(*calls)

async def _call_hook(self, instance: Any, method_name: str, *args: Any) -> None:
result = getattr(instance, method_name)(*args)
if inspect.isawaitable(result):
await result
25 changes: 25 additions & 0 deletions nest/core/pynest_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import asyncio
import threading
from abc import ABC, abstractmethod
from typing import Type, TypeVar

Expand Down Expand Up @@ -34,10 +36,33 @@ def create(main_module: Type[ModuleType], **kwargs) -> PyNestApp:
container = PyNestContainer()
container.add_module(main_module)
container.build()
PyNestFactory._run_async(container.initialize_lifecycle())

http_server = FastAPI(**kwargs)
return PyNestApp(container, http_server)

@staticmethod
def _create_server(**kwargs) -> FastAPI:
return FastAPI(**kwargs)

@staticmethod
def _run_async(coro):
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)

result = {}

def runner():
try:
result["value"] = asyncio.run(coro)
except BaseException as exc:
result["error"] = exc

thread = threading.Thread(target=runner)
thread.start()
thread.join()
if "error" in result:
raise result["error"]
return result.get("value")
Loading
Loading