diff --git a/nest/core/database/__init__.py b/nest/core/database/__init__.py index e69de29..c76ee3d 100644 --- a/nest/core/database/__init__.py +++ b/nest/core/database/__init__.py @@ -0,0 +1,19 @@ +from nest.core.database.database_module import ( + DATABASE_ENGINE, + DATABASE_OPTIONS, + DATABASE_SESSION_FACTORY, + DatabaseModule, + DatabaseOptions, + DatabaseService, +) +from nest.core.database.orm_provider import Base + +__all__ = [ + "Base", + "DATABASE_ENGINE", + "DATABASE_OPTIONS", + "DATABASE_SESSION_FACTORY", + "DatabaseModule", + "DatabaseOptions", + "DatabaseService", +] diff --git a/nest/core/database/database_module.py b/nest/core/database/database_module.py new file mode 100644 index 0000000..d1bf1aa --- /dev/null +++ b/nest/core/database/database_module.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass, field +from typing import Any, Dict, Generator, Optional, Type + +from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import Session, sessionmaker + +from nest.common.provider import InjectionToken +from nest.core.decorators.module import Module +from nest.core.database.orm_config import AsyncConfigFactory, ConfigFactory +from nest.core.database.orm_provider import Base + +DATABASE_OPTIONS = InjectionToken( + "DATABASE_OPTIONS", "Normalized DatabaseModule.for_root options" +) +DATABASE_ENGINE = InjectionToken("DATABASE_ENGINE", "SQLAlchemy engine") +DATABASE_SESSION_FACTORY = InjectionToken( + "DATABASE_SESSION_FACTORY", "SQLAlchemy session factory" +) + + +@dataclass(frozen=True) +class DatabaseOptions: + driver: str + config_params: Dict[str, Any] + async_mode: bool = False + engine_params: Dict[str, Any] = field(default_factory=dict) + session_params: Dict[str, Any] = field(default_factory=dict) + create_all: bool = False + base: Type[Any] = Base + + +class DatabaseService: + """Lifecycle-aware SQLAlchemy service registered by DatabaseModule.""" + + def __init__( + self, + options: DatabaseOptions, + engine: Any, + session_factory: Any, + ) -> None: + self.options = options + self.engine = engine + self.session_factory = session_factory + self.Base = options.base + + def on_module_init(self): + if not self.options.create_all: + return None + return self.create_all() + + def on_module_destroy(self): + result = self.engine.dispose() + return result + + def create_all(self): + if self.options.async_mode: + return self._create_all_async() + self.Base.metadata.create_all(bind=self.engine) + return None + + async def _create_all_async(self) -> None: + async with self.engine.begin() as conn: + await conn.run_sync(self.Base.metadata.create_all) + + def drop_all(self): + if self.options.async_mode: + return self._drop_all_async() + self.Base.metadata.drop_all(bind=self.engine) + return None + + async def _drop_all_async(self) -> None: + async with self.engine.begin() as conn: + await conn.run_sync(self.Base.metadata.drop_all) + + def session(self): + if self.options.async_mode: + return self._async_session() + return self._sync_session() + + def get_session(self): + return self.session() + + def get_db(self): + if self.options.async_mode: + return self._async_db() + return self._sync_db() + + @contextmanager + def _sync_session(self) -> Generator[Session, None, None]: + db = self.session_factory() + try: + yield db + except Exception: + db.rollback() + raise + finally: + db.close() + + def _sync_db(self) -> Session: + return self.session_factory() + + @asynccontextmanager + async def _async_session(self) -> AsyncSession: + db = self.session_factory() + try: + yield db + except Exception: + await db.rollback() + raise + finally: + await db.close() + + async def _async_db(self): + db = self.session_factory() + try: + yield db + finally: + await db.close() + + +def create_database_engine(options: DatabaseOptions): + config_factory = AsyncConfigFactory if options.async_mode else ConfigFactory + engine_factory = create_async_engine if options.async_mode else create_engine + config_class = config_factory(db_type=options.driver).get_config() + config_url = config_class(**options.config_params).get_engine_url() + return engine_factory(config_url, **options.engine_params) + + +def create_database_session_factory(options: DatabaseOptions, engine: Any): + if options.async_mode: + session_params = {"expire_on_commit": False, "class_": AsyncSession} + session_params.update(options.session_params) + return async_sessionmaker(engine, **session_params) + return sessionmaker(engine, **options.session_params) + + +def create_database_service( + options: DatabaseOptions, + engine: Any, + session_factory: Any, +) -> DatabaseService: + return DatabaseService(options, engine, session_factory) + + +@Module(imports=[], providers=[], exports=[]) +class DatabaseModule: + @classmethod + def for_root( + cls, + driver: str = "postgresql", + *, + database: Optional[str] = None, + db_name: Optional[str] = None, + config_params: Optional[Dict[str, Any]] = None, + host: Optional[str] = None, + user: Optional[str] = None, + password: Optional[str] = None, + port: Optional[int] = None, + async_mode: bool = False, + engine_params: Optional[Dict[str, Any]] = None, + session_params: Optional[Dict[str, Any]] = None, + create_all: bool = False, + base: Type[Any] = Base, + is_global: bool = True, + **extra_config: Any, + ): + normalized_config = _normalize_config_params( + config_params=config_params, + database=database, + db_name=db_name, + host=host, + user=user, + password=password, + port=port, + extra_config=extra_config, + ) + options = DatabaseOptions( + driver=driver, + config_params=normalized_config, + async_mode=async_mode, + engine_params=engine_params or {}, + session_params=session_params or {}, + create_all=create_all, + base=base, + ) + + providers = [ + {"provide": DATABASE_OPTIONS, "useValue": options}, + { + "provide": DATABASE_ENGINE, + "useFactory": create_database_engine, + "inject": [DATABASE_OPTIONS], + }, + { + "provide": DATABASE_SESSION_FACTORY, + "useFactory": create_database_session_factory, + "inject": [DATABASE_OPTIONS, DATABASE_ENGINE], + }, + { + "provide": DatabaseService, + "useFactory": create_database_service, + "inject": [ + DATABASE_OPTIONS, + DATABASE_ENGINE, + DATABASE_SESSION_FACTORY, + ], + }, + ] + + module_name = _configured_module_name(driver=driver, async_mode=async_mode) + configured_module = type(module_name, (cls,), {}) + setattr(configured_module, "__pynest_database_root__", True) + return Module( + imports=[], + providers=providers, + exports=[ + DATABASE_OPTIONS, + DATABASE_ENGINE, + DATABASE_SESSION_FACTORY, + DatabaseService, + ], + is_global=is_global, + )(configured_module) + + +def _normalize_config_params( + *, + config_params: Optional[Dict[str, Any]], + database: Optional[str], + db_name: Optional[str], + host: Optional[str], + user: Optional[str], + password: Optional[str], + port: Optional[int], + extra_config: Dict[str, Any], +) -> Dict[str, Any]: + normalized = dict(config_params or {}) + + database_name = db_name if db_name is not None else database + if database_name is not None and "db_name" not in normalized: + normalized["db_name"] = database_name + + for key, value in { + "host": host, + "user": user, + "password": password, + "port": port, + }.items(): + if value is not None and key not in normalized: + normalized[key] = value + + for key, value in extra_config.items(): + if value is not None and key not in normalized: + normalized[key] = value + + return normalized + + +def _configured_module_name(driver: str, async_mode: bool) -> str: + prefix = "Async" if async_mode else "" + normalized_driver = "".join(part.capitalize() for part in driver.split("_")) + return f"{prefix}{normalized_driver}DatabaseModule" diff --git a/nest/core/decorators/database.py b/nest/core/decorators/database.py index 5f7ea2c..1583dd2 100644 --- a/nest/core/decorators/database.py +++ b/nest/core/decorators/database.py @@ -10,8 +10,8 @@ def db_request_handler(func): """ Decorator that wraps ORM service methods with timing, logging, and HTTP error - conversion. Session lifecycle (open / commit / rollback / close) is the - responsibility of each service method — use config.get_session() there. + conversion. Session lifecycle (open / commit / rollback / close) is the + responsibility of each service method; use DatabaseService.session() there. """ def wrapper(self, *args, **kwargs): @@ -32,7 +32,7 @@ def wrapper(self, *args, **kwargs): def async_db_request_handler(func): """ Async version of db_request_handler. Session lifecycle is the caller's - responsibility (pass session via Depends or use config.get_session()). + responsibility; use DatabaseService.session() in the service method. """ async def wrapper(*args, **kwargs): diff --git a/nest/core/injector_module.py b/nest/core/injector_module.py index bf32b20..34b96c8 100644 --- a/nest/core/injector_module.py +++ b/nest/core/injector_module.py @@ -32,8 +32,7 @@ class PyNestInjectorModule(InjectorModule): def __init__(self, descriptors: List[ProviderDescriptor]) -> None: self._descriptors = [ - d for d in descriptors - if d.use_factory is None and d.use_existing is None + d for d in descriptors if d.use_factory is None and d.use_existing is None ] def configure(self, binder) -> None: @@ -59,8 +58,14 @@ def build_injector(descriptors: List[ProviderDescriptor]) -> Injector: from injector import InstanceProvider injector = Injector([PyNestInjectorModule(descriptors)]) + provider_counts = {} + last_provider_index = {} + for index, desc in enumerate(descriptors): + key = _to_key(desc.provide) + provider_counts[key] = provider_counts.get(key, 0) + 1 + last_provider_index[key] = index - for desc in descriptors: + for index, desc in enumerate(descriptors): key = _to_key(desc.provide) if desc.use_factory is not None: @@ -73,4 +78,22 @@ def build_injector(descriptors: List[ProviderDescriptor]) -> Injector: existing_instance = injector.get(_to_key(desc.use_existing)) injector.binder.bind(key, to=InstanceProvider(existing_instance)) + elif ( + desc.use_value is not None + and provider_counts[key] > 1 + and last_provider_index[key] == index + ): + injector.binder.bind(key, to=InstanceProvider(desc.use_value)) + + elif ( + desc.use_class is not None + and provider_counts[key] > 1 + and last_provider_index[key] == index + ): + injector.binder.bind( + key, + to=desc.use_class, + scope=_injector_scope(desc.scope), + ) + return injector diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index 120dc2c..69a57f8 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -57,6 +57,7 @@ def __init__(self) -> None: self._lifecycle_shutdown = False self._module_token_factory = ModuleTokenFactory() self._module_compiler = ModuleCompiler(self._module_token_factory) + self._database_root_registered = False # ── Public API ───────────────────────────────────────────────────────────── @@ -74,6 +75,14 @@ def module_compiler(self): def add_module(self, module_class: Type) -> dict: """Compile and register a module and all its imports recursively.""" + if getattr(module_class, "__pynest_database_root__", False): + if self._database_root_registered: + raise RuntimeError( + "Only one DatabaseModule.for_root() can be registered per " + "application. Named database connections are not supported yet." + ) + self._database_root_registered = True + compiled = self._module_compiler.compile(module_class) token = compiled.token @@ -126,6 +135,7 @@ def clear(self) -> None: self._module_instances.clear() self._lifecycle_initialized = False self._lifecycle_shutdown = False + self._database_root_registered = False async def initialize_lifecycle(self) -> None: """Run module init and application bootstrap hooks once.""" diff --git a/pyproject.toml b/pyproject.toml index 4b8da1f..1035b8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ test = [ "beanie>=1.27.0,<2.0.0", "python-dotenv>=1.0.1,<2.0.0", "aiosqlite>=0.19.0,<1.0.0", + "greenlet>=3.1.1,<4.0.0", "websockets>=13.0,<16.0", ] docs = [ diff --git a/tests/test_cli/test_orm_templates.py b/tests/test_cli/test_orm_templates.py new file mode 100644 index 0000000..f8529ea --- /dev/null +++ b/tests/test_cli/test_orm_templates.py @@ -0,0 +1,60 @@ +from nest.cli.templates.postgres_template import AsyncPostgresqlTemplate +from nest.cli.templates.mysql_template import AsyncMySQLTemplate, MySQLTemplate +from nest.cli.templates.sqlite_template import SQLiteTemplate + + +def test_sync_orm_app_template_uses_database_module_for_root(): + template = SQLiteTemplate("book") + + app_file = template.app_file() + config_file = template.config_file() + service_file = template.service_file() + entity_file = template.entity_file() + + assert "from nest.core.database import DatabaseModule" in app_file + assert "from .config import DATABASE_CONFIG" in app_file + assert "DatabaseModule.for_root(**DATABASE_CONFIG)" in app_file + assert "create_all=True" in config_file + assert "config.create_all" not in app_file + assert "OrmProvider" not in config_file + assert "DATABASE_CONFIG = dict(" in config_file + assert "from nest.core.database import DatabaseService" in service_file + assert "def __init__(self, db: DatabaseService):" in service_file + assert "with self.db.session() as session:" in service_file + assert "from src.config import config" not in service_file + assert "from nest.core.database import Base" in entity_file + + +def test_async_orm_template_uses_injected_database_service(): + template = AsyncPostgresqlTemplate("book") + + app_file = template.app_file() + config_file = template.config_file() + service_file = template.service_file() + controller_file = template.controller_file() + + assert "DatabaseModule.for_root(**DATABASE_CONFIG)" in app_file + assert '"async_mode": True' in config_file + assert '"create_all": True' in config_file + assert "AsyncOrmProvider" not in config_file + assert "from nest.core.database import DatabaseService" in service_file + assert "def __init__(self, db: DatabaseService):" in service_file + assert "async with self.db.session() as session:" in service_file + assert "Depends(config.get_db)" not in controller_file + assert "AsyncSession" not in controller_file + + +def test_orm_template_requirements_include_sqlalchemy_runtime(): + sync_requirements = SQLiteTemplate("book").requirements_file() + async_requirements = AsyncPostgresqlTemplate("book").requirements_file() + + assert "sqlalchemy" in sync_requirements.lower() + assert "sqlalchemy" in async_requirements.lower() + + +def test_mysql_orm_templates_default_missing_port_environment_variables(): + sync_config = MySQLTemplate("book").config_file() + async_config = AsyncMySQLTemplate("book").config_file() + + assert 'os.getenv("MYSQL_PORT", 3306)' in sync_config + assert 'os.getenv("MYSQL_PORT", 3306)' in async_config diff --git a/tests/test_core/test_database/test_database_module.py b/tests/test_core/test_database/test_database_module.py new file mode 100644 index 0000000..5424d01 --- /dev/null +++ b/tests/test_core/test_database/test_database_module.py @@ -0,0 +1,301 @@ +import asyncio + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import Column, Integer, String, inspect, select +from sqlalchemy.orm import DeclarativeBase + +from nest.core import Controller, Get, Injectable, Module, PyNestFactory +from nest.core.database import ( + DATABASE_ENGINE, + DATABASE_OPTIONS, + DATABASE_SESSION_FACTORY, + DatabaseModule, + DatabaseOptions, + DatabaseService, +) + + +def test_database_module_for_root_registers_core_providers(tmp_path): + class LocalBase(DeclarativeBase): + pass + + database_name = str(tmp_path / "providers") + configured_database_module = DatabaseModule.for_root( + driver="sqlite", + database=database_name, + base=LocalBase, + create_all=False, + ) + + @Module(imports=[configured_database_module]) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + options = app.container.get(DATABASE_OPTIONS) + engine = app.container.get(DATABASE_ENGINE) + session_factory = app.container.get(DATABASE_SESSION_FACTORY) + database = app.container.get(DatabaseService) + + assert isinstance(options, DatabaseOptions) + assert options.driver == "sqlite" + assert options.config_params == {"db_name": database_name} + assert database.options is options + assert database.engine is engine + assert database.session_factory is session_factory + + asyncio.run(app.close()) + + +def test_database_module_rejects_duplicate_root_registration(tmp_path): + class LocalBase(DeclarativeBase): + pass + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "primary"), + base=LocalBase, + ), + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "secondary"), + base=LocalBase, + ), + ] + ) + class AppModule: + pass + + with pytest.raises(RuntimeError, match="DatabaseModule.for_root"): + PyNestFactory.create(AppModule) + + +def test_database_module_does_not_create_tables_by_default(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class Author(LocalBase): + __tablename__ = "default_authors" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "default-create-all"), + base=LocalBase, + ) + ] + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + database = app.container.get(DatabaseService) + + assert "default_authors" not in inspect(database.engine).get_table_names() + + asyncio.run(app.close()) + + +def test_database_service_runs_sync_lifecycle_hooks(): + events = [] + + class Metadata: + def create_all(self, bind): + events.append(("create_all", bind)) + + class LocalBase: + metadata = Metadata() + + class Engine: + def dispose(self): + events.append("dispose") + + engine = Engine() + options = DatabaseOptions( + driver="sqlite", + config_params={"db_name": "lifecycle"}, + base=LocalBase, + create_all=True, + ) + service = DatabaseService(options, engine, session_factory=lambda: object()) + + service.on_module_init() + service.on_module_destroy() + + assert events == [("create_all", engine), "dispose"] + + +def test_database_service_session_rolls_back_and_closes_on_error(): + events = [] + + class Session: + def rollback(self): + events.append("rollback") + + def close(self): + events.append("close") + + options = DatabaseOptions( + driver="sqlite", + config_params={"db_name": "sessions"}, + create_all=False, + ) + service = DatabaseService(options, engine=object(), session_factory=Session) + + with pytest.raises(ValueError, match="boom"): + with service.session() as session: + assert isinstance(session, Session) + raise ValueError("boom") + + assert events == ["rollback", "close"] + + +def test_database_service_can_be_replaced_by_app_provider(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class FakeDatabaseService: + def session(self): + return "fake-session" + + fake_database = FakeDatabaseService() + + @Injectable + class UsesDatabase: + def __init__(self, db: DatabaseService): + self.db = db + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "replace"), + base=LocalBase, + create_all=False, + ) + ], + providers=[ + {"provide": DatabaseService, "useValue": fake_database}, + UsesDatabase, + ], + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + + assert app.container.get(DatabaseService) is fake_database + assert app.container.get(UsesDatabase).db is fake_database + + asyncio.run(app.close()) + + +def test_database_module_powers_feature_module_through_http_e2e(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class Author(LocalBase): + __tablename__ = "authors" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + + @Injectable + class AuthorService: + def __init__(self, db: DatabaseService): + self.db = db + + def create_and_list(self): + with self.db.session() as session: + session.add(Author(name="Le Guin")) + session.commit() + + with self.db.session() as session: + authors = session.query(Author).order_by(Author.name).all() + return [author.name for author in authors] + + @Controller("/authors", tag="authors") + class AuthorController: + def __init__(self, service: AuthorService): + self.service = service + + @Get("/") + def list_authors(self): + return {"authors": self.service.create_and_list()} + + @Module(controllers=[AuthorController], providers=[AuthorService]) + class AuthorModule: + pass + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "authors"), + base=LocalBase, + create_all=True, + ), + AuthorModule, + ] + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + database = app.container.get(DatabaseService) + + assert "authors" in inspect(database.engine).get_table_names() + + with TestClient(app.get_server()) as client: + response = client.get("/authors") + + assert response.status_code == 200 + assert response.json() == {"authors": ["Le Guin"]} + + +def test_async_database_module_creates_tables_and_runs_queries(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class Author(LocalBase): + __tablename__ = "async_authors" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "async-authors"), + base=LocalBase, + async_mode=True, + create_all=True, + ) + ] + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + database = app.container.get(DatabaseService) + + async def scenario(): + async with database.session() as session: + session.add(Author(name="Butler")) + await session.commit() + + async with database.session() as session: + result = await session.execute(select(Author.name)) + return result.scalars().all() + + assert asyncio.run(scenario()) == ["Butler"] + + asyncio.run(app.close())