diff --git a/docs/param_decorators.md b/docs/param_decorators.md new file mode 100644 index 0000000..b3a1cf8 --- /dev/null +++ b/docs/param_decorators.md @@ -0,0 +1,171 @@ +# Parameter Decorators + +PyNest supports explicit HTTP parameter decorators for route handlers. Because +Python does not have parameter-level decorator syntax, these decorators are used +as default-value markers in the function signature. + +```python +from nest.common.decorators import Body, Headers, Param, Query +from nest.core import Controller, Get, Post + + +@Controller("/users") +class UsersController: + @Post("/{user_id}") + def update_user( + self, + user_id: int = Param("user_id"), + payload: dict = Body(), + trace_id: str = Headers("x-trace-id", default=None), + ): + return {"user_id": user_id, "payload": payload, "trace_id": trace_id} + + @Get("/") + def list_users(self, page: int = Query("page", default=1)): + return {"page": page} +``` + +## Built-in Decorators + +### Body + +`Body()` injects the full request body. `Body("key")` injects one embedded body +field. + +```python +@Post("/") +def create(self, payload: CreateUserDto = Body()): + return payload + + +@Post("/name") +def create_name(self, name: str = Body("name")): + return {"name": name} +``` + +### Param + +`Param("name")` injects a path parameter. `Param()` injects all path parameters +as a dictionary. + +```python +@Get("/{user_id}/posts/{post_id}") +def get_post( + self, + user_id: int = Param("user_id"), + post_id: int = Param("post_id"), +): + return {"user_id": user_id, "post_id": post_id} + + +@Get("/{user_id}") +def get_user(self, params: dict = Param()): + return params +``` + +### Query + +`Query("name")` injects one query parameter. `Query()` injects all query +parameters as a dictionary. + +```python +@Get("/search") +def search( + self, + q: str = Query("q"), + limit: int = Query("limit", default=20), +): + return {"q": q, "limit": limit} +``` + +### Headers + +`Headers("name")` injects one header. `Headers()` injects all headers as a +dictionary. + +```python +@Get("/me") +def get_me(self, authorization: str = Headers("authorization")): + return {"authorization": authorization} +``` + +### Req, Res, Ip, and HostParam + +`Req()` injects the FastAPI request, `Res()` injects the response object, `Ip()` +injects the client IP address, and `HostParam()` injects the request hostname. + +```python +from fastapi import Request, Response +from nest.common.decorators import HostParam, Ip, Req, Res + + +@Get("/raw") +def raw( + self, + request: Request = Req(), + response: Response = Res(), + ip: str = Ip(), + host: str = HostParam(), +): + response.headers["x-client-ip"] = ip + return {"path": request.url.path, "host": host} +``` + +## Pipes + +Decorators accept one or more pipe callables after the source name. A pipe can be +a callable or a class/instance exposing `transform(value)`. + +```python +def parse_int(value): + return int(value) + + +class TrimPipe: + def transform(self, value): + return value.strip() + + +@Get("/{id}") +def get_item( + self, + item_id=Param("id", parse_int), + q: str = Query("q", TrimPipe), +): + return {"item_id": item_id, "q": q} +``` + +## Custom Decorators + +Use `createParamDecorator(factory)` to build reusable domain-specific parameter +decorators. The factory receives `data` and an `ExecutionContext`. + +```python +from nest.common.decorators import createParamDecorator + + +CurrentUser = createParamDecorator( + lambda data, ctx: ctx.switch_to_http().get_request().state.user +) + +TenantId = createParamDecorator( + lambda data, ctx: ctx.switch_to_http().get_request().headers.get("x-tenant-id") +) + +UserProperty = createParamDecorator( + lambda data, ctx: getattr(ctx.switch_to_http().get_request().state.user, data) +) + + +@Get("/profile") +def profile( + self, + user=CurrentUser(), + email: str = UserProperty("email"), + tenant_id: str = TenantId(), +): + return {"user": user, "email": email, "tenant_id": tenant_id} +``` + +Routes without parameter decorators continue to use FastAPI's normal signature +binding. diff --git a/mkdocs.yml b/mkdocs.yml index cd6a97d..202ee6d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -55,6 +55,7 @@ nav: - CLI Usage: cli.md - Modules: modules.md - Controllers: controllers.md + - Parameter Decorators: param_decorators.md - Providers: providers.md - Guards: guards.md - Exception Filters: exception_filters.md diff --git a/nest/common/__init__.py b/nest/common/__init__.py index 6f1c1a7..6deed60 100644 --- a/nest/common/__init__.py +++ b/nest/common/__init__.py @@ -12,3 +12,15 @@ ArgumentsHost, HttpArgumentsHost, ) +from nest.common.decorators import ( + Body, + ExecutionContext, + Headers, + HostParam, + Ip, + Param, + Query, + Req, + Res, + createParamDecorator, +) diff --git a/nest/common/decorators.py b/nest/common/decorators.py new file mode 100644 index 0000000..b2364cf --- /dev/null +++ b/nest/common/decorators.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import inspect +import keyword +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple + +from fastapi import Body as FastAPIBody +from fastapi import Depends +from fastapi import Header as FastAPIHeader +from fastapi import Path as FastAPIPath +from fastapi import Query as FastAPIQuery +from fastapi import Request, Response +from pydantic import TypeAdapter + + +_MISSING = object() + + +@dataclass(frozen=True) +class ParamMetadata: + source: str + name: Optional[str] = None + data: Any = None + factory: Optional[Callable[[Any, "ExecutionContext"], Any]] = None + pipes: Tuple[Any, ...] = () + default: Any = _MISSING + + +class HttpExecutionContext: + def __init__(self, request: Request, response: Optional[Response] = None): + self._request = request + self._response = response + + def get_request(self) -> Request: + return self._request + + def get_response(self) -> Optional[Response]: + return self._response + + +class ExecutionContext: + def __init__(self, request: Request, response: Optional[Response] = None): + self._request = request + self._response = response + + def switch_to_http(self) -> HttpExecutionContext: + return HttpExecutionContext(self._request, self._response) + + def get_type(self) -> str: + return "http" + + +def Body(key: Optional[str] = None, *pipes: Any, default: Any = _MISSING): + key, pipes = _normalize_name_and_pipes(key, pipes) + return ParamMetadata(source="body", name=key, pipes=pipes, default=default) + + +def Param(name: Optional[str] = None, *pipes: Any, default: Any = _MISSING): + name, pipes = _normalize_name_and_pipes(name, pipes) + return ParamMetadata(source="param", name=name, pipes=pipes, default=default) + + +def Query(name: Optional[str] = None, *pipes: Any, default: Any = _MISSING): + name, pipes = _normalize_name_and_pipes(name, pipes) + return ParamMetadata(source="query", name=name, pipes=pipes, default=default) + + +def Headers(name: Optional[str] = None, *pipes: Any, default: Any = _MISSING): + name, pipes = _normalize_name_and_pipes(name, pipes) + return ParamMetadata(source="headers", name=name, pipes=pipes, default=default) + + +def Req(): + return ParamMetadata(source="request") + + +def Res(): + return ParamMetadata(source="response") + + +def Ip(*pipes: Any, default: Any = _MISSING): + return ParamMetadata(source="ip", pipes=pipes, default=default) + + +def HostParam(name: Optional[str] = None, *pipes: Any, default: Any = _MISSING): + name, pipes = _normalize_name_and_pipes(name, pipes) + return ParamMetadata(source="host", name=name, pipes=pipes, default=default) + + +def createParamDecorator(factory: Callable[[Any, ExecutionContext], Any]) -> Callable: + if not callable(factory): + raise TypeError("createParamDecorator requires a callable factory") + + def decorator(data: Any = None, *pipes: Any, default: Any = _MISSING): + return ParamMetadata( + source="custom", + data=data, + factory=factory, + pipes=pipes, + default=default, + ) + + return decorator + + +def has_param_decorators(endpoint: Callable) -> bool: + signature = inspect.signature(endpoint) + return any( + isinstance(parameter.default, ParamMetadata) + for parameter in signature.parameters.values() + ) + + +def wrap_param_decorators(endpoint: Callable) -> Callable: + signature = inspect.signature(endpoint) + wrapped_parameters = [] + + for parameter in signature.parameters.values(): + if isinstance(parameter.default, ParamMetadata): + dependency = _build_dependency(parameter) + wrapped_parameters.append( + parameter.replace( + annotation=inspect.Parameter.empty, + default=Depends(dependency), + ) + ) + else: + wrapped_parameters.append(parameter) + + wrapper_signature = signature.replace(parameters=wrapped_parameters) + handler_param_names = set(signature.parameters) + + async def wrapper(*args, **kwargs): + call_kwargs = {k: v for k, v in kwargs.items() if k in handler_param_names} + result = endpoint(*args, **call_kwargs) + if inspect.isawaitable(result): + return await result + return result + + wrapper.__name__ = getattr(endpoint, "__name__", "param_decorator_wrapper") + wrapper.__signature__ = wrapper_signature + return wrapper + + +def _build_dependency(parameter: inspect.Parameter) -> Callable: + metadata = parameter.default + annotation = parameter.annotation + + async def dependency(**kwargs): + value = await _resolve_value(metadata, kwargs) + value = await _apply_pipes(value, metadata.pipes) + return _coerce_value(value, annotation) + + dependency.__name__ = f"resolve_{parameter.name}_{metadata.source}" + dependency.__signature__ = _dependency_signature(parameter, metadata) + return dependency + + +async def _resolve_value(metadata: ParamMetadata, kwargs: dict) -> Any: + request = kwargs.get("request") + response = kwargs.get("response") + + if metadata.source == "request": + return request + if metadata.source == "response": + return response + if metadata.source == "param" and metadata.name is None: + return dict(request.path_params) + if metadata.source == "query" and metadata.name is None: + return dict(request.query_params) + if metadata.source == "headers" and metadata.name is None: + return dict(request.headers) + if metadata.source == "ip": + return request.client.host if request.client else None + if metadata.source == "host": + if metadata.name: + return request.path_params.get(metadata.name) + return request.url.hostname + if metadata.source == "custom": + context = ExecutionContext(request, response) + result = metadata.factory(metadata.data, context) + if inspect.isawaitable(result): + return await result + return result + + return _first_source_value(kwargs) + + +def _dependency_signature( + parameter: inspect.Parameter, + metadata: ParamMetadata, +) -> inspect.Signature: + source = metadata.source + annotation = parameter.annotation + + if source == "request": + return inspect.Signature( + parameters=[ + inspect.Parameter( + "request", + inspect.Parameter.KEYWORD_ONLY, + annotation=Request, + ) + ] + ) + + if source == "response": + return inspect.Signature( + parameters=[ + inspect.Parameter( + "response", + inspect.Parameter.KEYWORD_ONLY, + annotation=Response, + ) + ] + ) + + if source in {"param", "query", "headers"} and metadata.name is None: + return _request_only_signature() + + if source in {"ip", "host"}: + return _request_only_signature() + + if source == "custom": + return inspect.Signature( + parameters=[ + inspect.Parameter( + "request", + inspect.Parameter.KEYWORD_ONLY, + annotation=Request, + ), + inspect.Parameter( + "response", + inspect.Parameter.KEYWORD_ONLY, + annotation=Response, + ), + ] + ) + + if source == "body": + name = _source_parameter_name(metadata.name or parameter.name) + default = _default_value(metadata) + fastapi_default = FastAPIBody( + default, + alias=metadata.name, + embed=metadata.name is not None, + ) + return inspect.Signature( + parameters=[ + inspect.Parameter( + name, + inspect.Parameter.KEYWORD_ONLY, + annotation=annotation, + default=fastapi_default, + ) + ] + ) + + if source == "param": + name = _source_parameter_name(metadata.name or parameter.name) + alias = None if name == (metadata.name or parameter.name) else metadata.name + fastapi_default = FastAPIPath(..., alias=alias) + return inspect.Signature( + parameters=[ + inspect.Parameter( + name, + inspect.Parameter.KEYWORD_ONLY, + annotation=annotation, + default=fastapi_default, + ) + ] + ) + + if source == "query": + return _simple_source_signature( + parameter, + metadata, + lambda default, alias: FastAPIQuery(default, alias=alias), + ) + + if source == "headers": + return _simple_source_signature( + parameter, + metadata, + lambda default, alias: FastAPIHeader(default, alias=alias), + ) + + return inspect.Signature() + + +def _simple_source_signature(parameter, metadata, marker_factory) -> inspect.Signature: + source_name = metadata.name or parameter.name + name = _source_parameter_name(source_name) + alias = source_name if name != source_name or metadata.name else None + fastapi_default = marker_factory(_default_value(metadata), alias) + + return inspect.Signature( + parameters=[ + inspect.Parameter( + name, + inspect.Parameter.KEYWORD_ONLY, + annotation=parameter.annotation, + default=fastapi_default, + ) + ] + ) + + +def _request_only_signature() -> inspect.Signature: + return inspect.Signature( + parameters=[ + inspect.Parameter( + "request", + inspect.Parameter.KEYWORD_ONLY, + annotation=Request, + ) + ] + ) + + +def _normalize_name_and_pipes(name: Any, pipes: Tuple[Any, ...]): + if name is not None and not isinstance(name, str): + return None, (name, *pipes) + return name, pipes + + +def _source_parameter_name(name: str) -> str: + if name.isidentifier() and not keyword.iskeyword(name): + return name + return "value" + + +def _default_value(metadata: ParamMetadata) -> Any: + if metadata.default is _MISSING: + return ... + return metadata.default + + +def _first_source_value(kwargs: dict) -> Any: + for key, value in kwargs.items(): + if key not in {"request", "response"}: + return value + return None + + +async def _apply_pipes(value: Any, pipes: Tuple[Any, ...]) -> Any: + for pipe in pipes: + pipe_instance = pipe() if inspect.isclass(pipe) else pipe + if hasattr(pipe_instance, "transform"): + value = pipe_instance.transform(value) + elif callable(pipe_instance): + value = pipe_instance(value) + else: + raise TypeError("Pipe must be callable or expose a transform method") + if inspect.isawaitable(value): + value = await value + return value + + +def _coerce_value(value: Any, annotation: Any) -> Any: + if value is None or annotation in {inspect.Parameter.empty, Any}: + return value + if inspect.isclass(annotation) and isinstance(value, annotation): + return value + return TypeAdapter(annotation).validate_python(value) diff --git a/nest/common/route_resolver.py b/nest/common/route_resolver.py index 6b5fefd..bc5801e 100644 --- a/nest/common/route_resolver.py +++ b/nest/common/route_resolver.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any from fastapi import APIRouter, FastAPI, Request +from nest.common.decorators import has_param_decorators, wrap_param_decorators if TYPE_CHECKING: from nest.core.pynest_container import PyNestContainer @@ -94,6 +95,9 @@ def _add_route( **extra_kwargs, } + if has_param_decorators(bound_method): + route_kwargs["endpoint"] = wrap_param_decorators(bound_method) + if hasattr(original_method, "status_code"): route_kwargs["status_code"] = original_method.status_code @@ -105,7 +109,7 @@ def _add_route( controller_filters = list(getattr(cls, "__filters__", [])) if route_filters or controller_filters: route_kwargs["endpoint"] = _wrap_with_filters( - bound_method, route_filters + controller_filters + route_kwargs["endpoint"], route_filters + controller_filters ) router.add_api_route(**route_kwargs) diff --git a/nest/core/__init__.py b/nest/core/__init__.py index 5690b23..a5cf3a5 100644 --- a/nest/core/__init__.py +++ b/nest/core/__init__.py @@ -1,5 +1,17 @@ from fastapi import Depends +from nest.common.decorators import ( + Body, + ExecutionContext, + Headers, + HostParam, + Ip, + Param, + Query, + Req, + Res, + createParamDecorator, +) from nest.common.provider import InjectionToken, Scope from nest.core.decorators import ( Catch, diff --git a/nest/core/decorators/__init__.py b/nest/core/decorators/__init__.py index c152132..212d9e6 100644 --- a/nest/core/decorators/__init__.py +++ b/nest/core/decorators/__init__.py @@ -1,3 +1,15 @@ +from nest.common.decorators import ( + Body, + ExecutionContext, + Headers, + HostParam, + Ip, + Param, + Query, + Req, + Res, + createParamDecorator, +) from nest.core.decorators.controller import Controller from nest.core.decorators.filters import Catch, UseFilters from nest.core.decorators.http_code import HttpCode diff --git a/tests/test_common/test_param_decorators.py b/tests/test_common/test_param_decorators.py new file mode 100644 index 0000000..c63f1e5 --- /dev/null +++ b/tests/test_common/test_param_decorators.py @@ -0,0 +1,170 @@ +from fastapi import FastAPI, Request, Response +from fastapi.testclient import TestClient + +from nest.common.decorators import ( + Body, + ExecutionContext, + Headers, + HostParam, + Ip, + Param, + Query, + Req, + Res, + createParamDecorator, +) +from nest.common.route_resolver import RoutesResolver +from nest.core.decorators.controller import Controller +from nest.core.decorators.http_method import Get, Post +from nest.core.decorators.module import Module +from nest.core.pynest_container import PyNestContainer + + +def build_client(controller): + @Module(controllers=[controller]) + class TestModule: + pass + + container = PyNestContainer() + container.add_module(TestModule) + container.build() + app = FastAPI() + RoutesResolver(container, app).register_routes() + return TestClient(app) + + +def increment(value): + return int(value) + 1 + + +class UpperPipe: + def transform(self, value): + return value.upper() + + +TenantHeader = createParamDecorator( + lambda data, ctx: ctx.switch_to_http().get_request().headers.get(data) +) + + +ContextAware = createParamDecorator( + lambda data, ctx: { + "is_context": isinstance(ctx, ExecutionContext), + "path": ctx.switch_to_http().get_request().url.path, + } +) + + +@Controller("/decorated") +class DecoratedController: + @Post("/{item_id}") + def create( + self, + item_id=Param("item_id", increment), + q: str = Query("q", UpperPipe), + agent: str = Headers("user-agent"), + body: dict = Body(), + ): + return {"item_id": item_id, "q": q, "agent": agent, "body": body} + + @Post("/body-key") + def body_key(self, name: str = Body("name", UpperPipe)): + return {"name": name} + + @Get("/{item_id}/request-data") + def request_data( + self, + params: dict = Param(), + queries: dict = Query(), + headers: dict = Headers(), + request: Request = Req(), + response: Response = Res(), + ip: str = Ip(), + host: str = HostParam(), + ): + response.headers["x-item-id"] = params["item_id"] + return { + "params": params, + "query": queries["q"], + "has_host_header": "host" in headers, + "path": request.url.path, + "ip": ip, + "host": host, + } + + @Get("/custom") + def custom( + self, + tenant: str = TenantHeader("x-tenant-id"), + context: dict = ContextAware(), + ): + return {"tenant": tenant, "context": context} + + @Get("/implicit/{item_id}") + def implicit(self, item_id: int, q: str): + return {"item_id": item_id, "q": q} + + +def test_param_decorators_extract_builtin_sources_and_apply_pipes(): + client = build_client(DecoratedController) + + response = client.post( + "/decorated/41?q=hello", + json={"title": "PyNest"}, + headers={"user-agent": "pytest"}, + ) + + assert response.status_code == 200 + assert response.json() == { + "item_id": 42, + "q": "HELLO", + "agent": "pytest", + "body": {"title": "PyNest"}, + } + + +def test_body_decorator_can_extract_one_body_key(): + client = build_client(DecoratedController) + + response = client.post("/decorated/body-key", json={"name": "ada"}) + + assert response.status_code == 200 + assert response.json() == {"name": "ADA"} + + +def test_request_response_and_whole_collection_decorators(): + client = build_client(DecoratedController) + + response = client.get("/decorated/abc/request-data?q=search") + + assert response.status_code == 200 + assert response.headers["x-item-id"] == "abc" + assert response.json() == { + "params": {"item_id": "abc"}, + "query": "search", + "has_host_header": True, + "path": "/decorated/abc/request-data", + "ip": "testclient", + "host": "testserver", + } + + +def test_custom_param_decorator_receives_data_and_execution_context(): + client = build_client(DecoratedController) + + response = client.get("/decorated/custom", headers={"x-tenant-id": "acme"}) + + assert response.status_code == 200 + assert response.json() == { + "tenant": "acme", + "context": {"is_context": True, "path": "/decorated/custom"}, + } + + +def test_routes_without_param_decorators_keep_fastapi_binding(): + client = build_client(DecoratedController) + + response = client.get("/decorated/implicit/7?q=plain") + + assert response.status_code == 200 + assert response.json() == {"item_id": 7, "q": "plain"}