diff --git a/AGENTS.md b/AGENTS.md index 7f63472..20c8626 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -49,6 +49,7 @@ The package uses optional extras to minimize Lambda bundle size: - `database` - MySQL/PostgreSQL drivers - `slack` - Slack SDK - `jwt` - RS256 JWT validation (`rsa` package) +- `snowflake` - Pure-Python Snowflake SQL API client (`snowflake-sql-api`) - `all` - All integrations - `dev` - Development and testing tools diff --git a/README.md b/README.md index 1e25f23..6a55c8e 100644 --- a/README.md +++ b/README.md @@ -65,12 +65,13 @@ Production-ready shared Python utilities for AWS Lambda functions, CLI tools, an pip install nui-python-shared-utils # With specific extras for optional dependencies -pip install nui-python-shared-utils[all] # All integrations +pip install nui-python-shared-utils[all] # Core optional integrations (excludes Snowflake) pip install nui-python-shared-utils[powertools] # AWS Powertools only pip install nui-python-shared-utils[slack] # Slack only pip install nui-python-shared-utils[elasticsearch] # Elasticsearch only pip install nui-python-shared-utils[database] # Database only pip install nui-python-shared-utils[jwt] # JWT authentication only +pip install nui-python-shared-utils[snowflake] # Snowflake SQL API client ``` ### Basic Configuration @@ -198,6 +199,31 @@ async with db.get_connection() as conn: **[→ See full database guide](docs/getting-started/quickstart.md#database-connections)** +### Snowflake (SQL API) + +Pure-Python Snowflake client (no `snowflake-connector-python`), keypair auth via +Secrets Manager, with NUI session defaults (`TIMEZONE=Pacific/Auckland`, role +`NUI_LAMBDA`) that you can override via `timezone=` and `role=`, plus a redacting +query-logging hook. + +```python +from nui_shared_utils import create_snowflake_client + +# Loads account/user/private_key from Secrets Manager ("snowflake-credentials" +# by default; override with SNOWFLAKE_CREDENTIALS_SECRET or secret_name=). 
+# The NUI defaults are overridable, so the client stays generic for any account. +client = create_snowflake_client( + warehouse="COMPUTE_WH", + database="ANALYTICS", + timezone="UTC", # override the Pacific/Auckland default + role="MY_APP_ROLE", # override the NUI_LAMBDA default +) +rows = client.query("SELECT id, name FROM orders WHERE status = ?", ["confirmed"]) +``` + +Sync is the default; use `create_async_snowflake_client(...)` inside a Lambda +already on an event loop. Requires the `[snowflake]` extra. + ### CloudWatch Metrics ```python diff --git a/nui_lambda_shared_utils/snowflake_client.py b/nui_lambda_shared_utils/snowflake_client.py new file mode 100644 index 0000000..02f0952 --- /dev/null +++ b/nui_lambda_shared_utils/snowflake_client.py @@ -0,0 +1,3 @@ +"""Backwards-compatibility shim. Use nui_shared_utils.snowflake_client instead.""" + +from nui_shared_utils.snowflake_client import * # noqa: F401,F403 diff --git a/nui_shared_utils/__init__.py b/nui_shared_utils/__init__.py index 6e952c6..122c556 100644 --- a/nui_shared_utils/__init__.py +++ b/nui_shared_utils/__init__.py @@ -104,6 +104,11 @@ # Optional: AWS Powertools "get_powertools_logger": ("powertools_helpers", "get_powertools_logger"), "powertools_handler": ("powertools_helpers", "powertools_handler"), + # Optional: Snowflake client adapter (snowflake-sql-api) + "create_snowflake_client": ("snowflake_client", "create_snowflake_client"), + "create_async_snowflake_client": ("snowflake_client", "create_async_snowflake_client"), + "get_snowflake_credentials": ("snowflake_client", "get_snowflake_credentials"), + "redacting_query_logger": ("snowflake_client", "redacting_query_logger"), # Optional: JWT validation (rsa) "validate_jwt": ("jwt_auth", "validate_jwt"), "require_auth": ("jwt_auth", "require_auth"), @@ -124,6 +129,7 @@ "db_client", "powertools_helpers", "jwt_auth", + "snowflake_client", "slack_setup", } @@ -248,6 +254,12 @@ def __dir__() -> List[str]: build_user_activity_query, ) from 
.db_client import DatabaseClient, PostgreSQLClient, get_pool_stats + from .snowflake_client import ( + create_async_snowflake_client, + create_snowflake_client, + get_snowflake_credentials, + redacting_query_logger, + ) from .powertools_helpers import get_powertools_logger, powertools_handler from .jwt_auth import ( AuthenticationError, diff --git a/nui_shared_utils/snowflake_client.py b/nui_shared_utils/snowflake_client.py new file mode 100644 index 0000000..6ba6555 --- /dev/null +++ b/nui_shared_utils/snowflake_client.py @@ -0,0 +1,422 @@ +"""NUI adapter for the ``snowflake-sql-api`` pure-Python Snowflake client. + +This is the NUI-opinionated layer over the generic +`snowflake-sql-api <https://github.com/hampsterx/snowflake-sql-api>`_ package. +The generic package stays vendor-neutral (no account/role/timezone defaults); +all NUI specifics live here: + +- **Keypair from Secrets Manager.** Credentials (account, user, PEM private key, + optional passphrase) are loaded from AWS Secrets Manager by default, with + environment-variable and explicit-argument overrides. +- **NUI session defaults.** ``TIMEZONE='Pacific/Auckland'`` and role + ``NUI_LAMBDA`` are applied unless overridden. (``WEEK_START`` is *not* set: the + SQL API only accepts an allow-list of session parameters in a request and + rejects ``WEEK_START`` with HTTP 400. The API is stateless per statement, so + ``ALTER SESSION`` cannot substitute. Set it on the Snowflake user/role default + if a consumer needs Monday-first weeks.) +- **Redacted query logging.** A logging hook records the statement text + (truncated) and the *number* of bind parameters, never the bind values, JWTs, + key material, or secret names. +- **Sync by default.** :func:`create_snowflake_client` returns the synchronous + client; use :func:`create_async_snowflake_client` only for a Lambda already + running on an event loop. 
+ +Install the extra to pull the client in:: + + pip install 'nui-python-shared-utils[snowflake]' + +Using any factory here without that extra raises a clear :class:`ImportError` +rather than failing at package import time. +""" + +import logging +import os +from typing import Any, Callable, Dict, Optional, Sequence, Union + +log = logging.getLogger(__name__) + +# Type alias for the on_query callback: receives SQL text and optional bind params. +QueryHook = Callable[[str, Optional[Sequence[Any]]], None] + +# Optional dependency: the generic snowflake-sql-api client. Guarded so the +# package imports without the [snowflake] extra; factories raise a clear error. +try: + from snowflake_sql_api import AsyncSnowflakeClient, SnowflakeClient + + SNOWFLAKE_SQL_API_AVAILABLE = True +except ModuleNotFoundError as exc: # pragma: no cover - flag is patched in tests + # Only the missing extra itself flips the availability flag. If the package + # is installed but a *dependency of it* is missing, re-raise so the user sees + # the real traceback instead of a misleading "not installed" message. + if exc.name and exc.name.split(".")[0] != "snowflake_sql_api": + raise + AsyncSnowflakeClient = None # type: ignore[assignment,misc] + SnowflakeClient = None # type: ignore[assignment,misc] + SNOWFLAKE_SQL_API_AVAILABLE = False + +__all__ = [ + "DEFAULT_TIMEZONE", + "DEFAULT_ROLE", + "DEFAULT_SECRET_NAME", + "get_snowflake_credentials", + "create_snowflake_client", + "create_async_snowflake_client", + "redacting_query_logger", +] + +#: NUI session defaults. Every one of these is overridable per call. +DEFAULT_TIMEZONE = "Pacific/Auckland" +DEFAULT_ROLE = "NUI_LAMBDA" + +#: Secrets Manager secret name used when none is supplied via argument or the +#: ``SNOWFLAKE_CREDENTIALS_SECRET`` environment variable. +DEFAULT_SECRET_NAME = "snowflake-credentials" + +#: Default ``User-Agent`` sent by NUI clients (overridable via ``user_agent``). 
+_USER_AGENT = "nui-python-shared-utils-snowflake" + +# Secret/credential field aliases. The secret JSON may use any of these. +_ACCOUNT_FIELDS = ("account", "snowflake_account") +_USER_FIELDS = ("user", "username", "snowflake_user") +_PRIVATE_KEY_FIELDS = ("private_key", "privateKey", "snowflake_private_key") +_PASSPHRASE_FIELDS = ( + "private_key_passphrase", + "passphrase", + "private_key_password", +) + + +def _require_snowflake_sql_api() -> None: + """Raise a clear, actionable error when the optional extra is missing.""" + if not SNOWFLAKE_SQL_API_AVAILABLE: + raise ImportError( + "snowflake-sql-api is not installed. Install it with: pip install 'nui-python-shared-utils[snowflake]'" + ) + + +def _first(mapping: Dict[str, Any], keys: Sequence[str]) -> Optional[Any]: + """Return the first present, non-empty value among ``keys`` in ``mapping``.""" + for key in keys: + value = mapping.get(key) + if value: + return value + return None + + +def _as_key_bytes(value: Any) -> bytes: + """Coerce a PEM private key (``str`` or ``bytes``) to ``bytes``.""" + if isinstance(value, bytes): + return value + if isinstance(value, str): + return value.encode("utf-8") + raise ValueError("private_key must be PEM text (str) or bytes") + + +def _resolve_secret_name(secret_name: Optional[str]) -> str: + """Resolve the secret name: explicit arg > env var > default.""" + return secret_name or os.environ.get("SNOWFLAKE_CREDENTIALS_SECRET") or DEFAULT_SECRET_NAME + + +def _env_private_key() -> Optional[bytes]: + """Read key bytes from ``SNOWFLAKE_PRIVATE_KEY`` (inline PEM) or ``SNOWFLAKE_PRIVATE_KEY_PATH`` (file).""" + key_pem = os.environ.get("SNOWFLAKE_PRIVATE_KEY") + if key_pem: + return _as_key_bytes(key_pem) + key_path = os.environ.get("SNOWFLAKE_PRIVATE_KEY_PATH") + if key_path: + with open(key_path, "rb") as handle: + return handle.read() + return None + + +def get_snowflake_credentials( + secret_name: Optional[str] = None, + *, + account: Optional[str] = None, + user: Optional[str] 
= None, + private_key: Optional[Union[bytes, str]] = None, + private_key_passphrase: Optional[str] = None, +) -> Dict[str, Any]: + """Resolve Snowflake keypair credentials. + + Each field is resolved independently with strict **explicit arg > env var > + Secrets Manager** precedence, so an explicit value is never shadowed by a + lower tier: + + - ``account`` / ``user``: explicit arg, else ``SNOWFLAKE_ACCOUNT`` / + ``SNOWFLAKE_USER``, else the secret's ``account`` / ``user`` field. + - ``private_key``: explicit arg (bytes/PEM str), else ``SNOWFLAKE_PRIVATE_KEY`` + (inline PEM) / ``SNOWFLAKE_PRIVATE_KEY_PATH`` (file), else the secret's + ``private_key`` field (PEM text). + - ``private_key_passphrase`` is paired with the key's source: an explicit + passphrase arg always wins, otherwise the passphrase comes from the **same + tier the key came from** (env key -> ``SNOWFLAKE_PRIVATE_KEY_PASSPHRASE``; + secret key -> the secret's passphrase field). An explicit key is therefore + never silently paired with a stale env/secret passphrase. + + Secrets Manager is only contacted when ``account``, ``user`` or the key is + still unresolved after the explicit/env tiers. + + Returns a dict with ``account``, ``user``, ``private_key`` (bytes) and + ``private_key_passphrase`` keys. + """ + resolved_account = account or os.environ.get("SNOWFLAKE_ACCOUNT") + resolved_user = user or os.environ.get("SNOWFLAKE_USER") + + # Resolve key + passphrase together so they always come from one source. 
+ if private_key is not None: + resolved_key: Optional[bytes] = _as_key_bytes(private_key) + resolved_passphrase = private_key_passphrase # explicit arg only + else: + env_key = _env_private_key() + if env_key is not None: + resolved_key = env_key + resolved_passphrase = ( + private_key_passphrase + if private_key_passphrase is not None + else os.environ.get("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE") + ) + else: + resolved_key = None + resolved_passphrase = private_key_passphrase # may be filled from secret + + # Contact Secrets Manager only for what is still missing. + if not (resolved_account and resolved_user and resolved_key is not None): + from .secrets_helper import get_secret + + resolved_name = _resolve_secret_name(secret_name) + secret = get_secret(resolved_name) + + resolved_account = resolved_account or _first(secret, _ACCOUNT_FIELDS) + resolved_user = resolved_user or _first(secret, _USER_FIELDS) + if resolved_key is None: + secret_key = _first(secret, _PRIVATE_KEY_FIELDS) + resolved_key = _as_key_bytes(secret_key) if secret_key else None + if resolved_passphrase is None: + resolved_passphrase = _first(secret, _PASSPHRASE_FIELDS) + + missing = [ + name + for name, value in ( + ("account", resolved_account), + ("user", resolved_user), + ("private_key", resolved_key), + ) + if not value + ] + if missing: + # Name the secret but never its contents. + raise ValueError(f"Snowflake secret '{resolved_name}' is missing required field(s): " + ", ".join(missing)) + + if not resolved_key: + # An explicit empty key (``b""``) or an empty key file would otherwise + # slip past the missing-field check above (which only runs in the secret + # branch). Fail clearly here regardless of the key's source. 
+ raise ValueError("a non-empty private_key is required") + + return { + "account": resolved_account, + "user": resolved_user, + "private_key": resolved_key, + "private_key_passphrase": resolved_passphrase, + } + + +def redacting_query_logger( + logger: Optional[logging.Logger] = None, + *, + level: int = logging.DEBUG, + max_sql_chars: int = 1000, +) -> QueryHook: + """Build an ``on_query`` hook that logs statements without leaking values. + + The returned callback logs the (truncated) SQL text and the *count* of bind + parameters. It never logs the bind values themselves, so user data, secrets, + and PII passed as parameters stay out of the logs. + + Args: + logger: Destination logger. Defaults to this module's logger. + level: Log level for the query record (default ``DEBUG``). + max_sql_chars: Truncate the logged SQL beyond this many characters. + """ + target = logger or log + + def _hook(sql: str, params: Optional[Sequence[Any]]) -> None: + truncated = sql if len(sql) <= max_sql_chars else sql[:max_sql_chars] + "...(truncated)" + param_count = len(params) if params else 0 + target.log( + level, + "snowflake query", + extra={"sql": truncated, "bind_param_count": param_count}, + ) + + return _hook + + +def _resolve_on_query( + on_query: Optional[QueryHook], + logger: Optional[logging.Logger], + log_queries: bool, + log_sql_max_chars: int, +) -> Optional[QueryHook]: + """Pick the query hook: explicit override, redacting logger, or none.""" + if on_query is not None: + return on_query + if log_queries: + return redacting_query_logger(logger, max_sql_chars=log_sql_max_chars) + return None + + +def _client_kwargs( + creds: Dict[str, Any], + *, + role: Optional[str], + warehouse: Optional[str], + database: Optional[str], + schema: Optional[str], + timezone: Optional[str], + parameters: Optional[Dict[str, Any]], + on_query: Optional[QueryHook], + user_agent: Optional[str], + extra: Dict[str, Any], +) -> Dict[str, Any]: + """Shared keyword assembly for both sync and 
async client construction. + + ``TIMEZONE`` rides the client's ``timezone`` argument; any caller-supplied + ``parameters`` pass straight through to the SQL API ``parameters`` map. Note + the SQL API only accepts an allow-list of session parameters there (TIMEZONE + and output-format params are fine; WEEK_START is rejected, see module docs). + """ + # Start from passthrough kwargs, then let adapter-managed keys win, so an + # unexpected ``extra`` key can never silently override a managed one (the + # factories' named params already capture managed keys, so a caller cannot + # reach these via ``**client_kwargs`` today; this guards a future key add). + kwargs: Dict[str, Any] = dict(extra) + kwargs.update( + { + "account": creds["account"], + "user": creds["user"], + "private_key": creds["private_key"], + "private_key_passphrase": creds.get("private_key_passphrase"), + "role": role, + "warehouse": warehouse, + "database": database, + "schema": schema, + "timezone": timezone, + "parameters": dict(parameters) if parameters else None, + "on_query": on_query, + "user_agent": user_agent or _USER_AGENT, + } + ) + return kwargs + + +def create_snowflake_client( + *, + secret_name: Optional[str] = None, + account: Optional[str] = None, + user: Optional[str] = None, + private_key: Optional[Union[bytes, str]] = None, + private_key_passphrase: Optional[str] = None, + role: Optional[str] = DEFAULT_ROLE, + warehouse: Optional[str] = None, + database: Optional[str] = None, + schema: Optional[str] = None, + timezone: Optional[str] = DEFAULT_TIMEZONE, + parameters: Optional[Dict[str, Any]] = None, + logger: Optional[logging.Logger] = None, + log_queries: bool = True, + log_sql_max_chars: int = 1000, + on_query: Optional[QueryHook] = None, + user_agent: Optional[str] = None, + **client_kwargs: Any, +) -> "SnowflakeClient": + """Construct a synchronous Snowflake client with NUI defaults. + + This is the adapter's default entry point. 
Credentials are resolved via + :func:`get_snowflake_credentials`; NUI session defaults (timezone, role) are + applied unless overridden; a redacting query-logging hook is wired in unless + ``log_queries=False`` or a custom ``on_query`` is given. + + Extra keyword arguments (``timeout``, ``statement_timeout``, + ``poll_interval``, ``retry_policy``, ``host`` ...) pass straight through to + :class:`snowflake_sql_api.SnowflakeClient`. + + Returns: + A ready-to-use :class:`snowflake_sql_api.SnowflakeClient`. + """ + _require_snowflake_sql_api() + creds = get_snowflake_credentials( + secret_name, + account=account, + user=user, + private_key=private_key, + private_key_passphrase=private_key_passphrase, + ) + hook = _resolve_on_query(on_query, logger, log_queries, log_sql_max_chars) + kwargs = _client_kwargs( + creds, + role=role, + warehouse=warehouse, + database=database, + schema=schema, + timezone=timezone, + parameters=parameters, + on_query=hook, + user_agent=user_agent, + extra=client_kwargs, + ) + return SnowflakeClient(**kwargs) + + +def create_async_snowflake_client( + *, + secret_name: Optional[str] = None, + account: Optional[str] = None, + user: Optional[str] = None, + private_key: Optional[Union[bytes, str]] = None, + private_key_passphrase: Optional[str] = None, + role: Optional[str] = DEFAULT_ROLE, + warehouse: Optional[str] = None, + database: Optional[str] = None, + schema: Optional[str] = None, + timezone: Optional[str] = DEFAULT_TIMEZONE, + parameters: Optional[Dict[str, Any]] = None, + logger: Optional[logging.Logger] = None, + log_queries: bool = True, + log_sql_max_chars: int = 1000, + on_query: Optional[QueryHook] = None, + user_agent: Optional[str] = None, + **client_kwargs: Any, +) -> "AsyncSnowflakeClient": + """Construct an asynchronous Snowflake client with NUI defaults. + + Async parity with :func:`create_snowflake_client`. Use only inside a Lambda + already running on an event loop; the sync client is the NUI default + otherwise. 
+ + Returns: + A ready-to-use :class:`snowflake_sql_api.AsyncSnowflakeClient`. + """ + _require_snowflake_sql_api() + creds = get_snowflake_credentials( + secret_name, + account=account, + user=user, + private_key=private_key, + private_key_passphrase=private_key_passphrase, + ) + hook = _resolve_on_query(on_query, logger, log_queries, log_sql_max_chars) + kwargs = _client_kwargs( + creds, + role=role, + warehouse=warehouse, + database=database, + schema=schema, + timezone=timezone, + parameters=parameters, + on_query=hook, + user_agent=user_agent, + extra=client_kwargs, + ) + return AsyncSnowflakeClient(**kwargs) diff --git a/pyproject.toml b/pyproject.toml index 00eb2a0..74ce855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,12 @@ powertools = [ "coloredlogs>=15.0", ] jwt = ["rsa>=4.9"] +# Pre-release pin to the public snowflake-sql-api repo (PR #1 merge commit on +# master). Repinned to a PyPI release (>=0.1.0) once that package ships. Kept +# out of `all` so nothing pulls the git dependency implicitly. +snowflake = [ + "snowflake-sql-api @ git+https://github.com/hampsterx/snowflake-sql-api.git@50205b0cb63df008c9c453ee05f0f130dcfe4805", +] all = [ "elasticsearch>=7.17.0,<8.0.0", "pymysql>=1.0.0", diff --git a/requirements-test.txt b/requirements-test.txt index b531f7c..826dde6 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -24,4 +24,7 @@ pymysql>=1.0.0 psycopg2-binary>=2.9.0 slack-sdk>=3.19.0 rsa>=4.9 -cryptography>=41.0.0 \ No newline at end of file +cryptography>=41.0.0 + +# snowflake adapter (pre-release pin; see pyproject [snowflake] extra) +snowflake-sql-api @ git+https://github.com/hampsterx/snowflake-sql-api.git@50205b0cb63df008c9c453ee05f0f130dcfe4805 diff --git a/setup.py b/setup.py index ad9587e..4709b02 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,13 @@ "coloredlogs>=15.0", ], "jwt": ["rsa>=4.9"], + # Pre-release pin to the public snowflake-sql-api repo (PR #1 merge + # commit on master). 
Repinned to a PyPI release (>=0.1.0) once that + # package ships. Kept out of "all" so nothing pulls the git dep + # implicitly. + "snowflake": [ + "snowflake-sql-api @ git+https://github.com/hampsterx/snowflake-sql-api.git@50205b0cb63df008c9c453ee05f0f130dcfe4805", + ], "all": [ "elasticsearch>=7.17.0,<8.0.0", "pymysql>=1.0.0", diff --git a/tests/test_snowflake_client.py b/tests/test_snowflake_client.py new file mode 100644 index 0000000..07d5372 --- /dev/null +++ b/tests/test_snowflake_client.py @@ -0,0 +1,564 @@ +"""Tests for the NUI Snowflake adapter (nui_shared_utils.snowflake_client). + +Covers credential resolution (explicit / env / Secrets Manager via moto), +NUI session defaults, the redacting query-logging hook, sync + async client +construction, and the clear-error path when the optional extra is missing. +""" + +import json +import logging +import os + +import boto3 +import pytest +from moto import mock_aws + +from nui_shared_utils import snowflake_client as sc + +# All tests here are fast and offline (moto mocks Secrets Manager in-process; +# "real" client construction only parses a local key, no network). Mark the +# module ``unit`` to match the rest of the suite. 
+pytestmark = pytest.mark.unit + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="session") +def rsa_private_key_pem() -> bytes: + """A real unencrypted RSA private key (PEM) for end-to-end construction.""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + return key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +@pytest.fixture +def aws_env(monkeypatch): + """Dummy AWS env so boto3/moto have a region and credentials.""" + monkeypatch.setenv("AWS_DEFAULT_REGION", "ap-southeast-2") + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing") + monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing") + monkeypatch.setenv("AWS_SESSION_TOKEN", "testing") + + +@pytest.fixture +def clear_snowflake_env(monkeypatch): + """Ensure SNOWFLAKE_* env does not leak between tests.""" + for var in ( + "SNOWFLAKE_ACCOUNT", + "SNOWFLAKE_USER", + "SNOWFLAKE_PRIVATE_KEY", + "SNOWFLAKE_PRIVATE_KEY_PATH", + "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", + "SNOWFLAKE_CREDENTIALS_SECRET", + ): + monkeypatch.delenv(var, raising=False) + + +@pytest.fixture +def secret_in_sm(aws_env, rsa_private_key_pem): + """Create a Snowflake credentials secret in a mocked Secrets Manager.""" + with mock_aws(): + client = boto3.client("secretsmanager", region_name="ap-southeast-2") + client.create_secret( + Name="snowflake-credentials", + SecretString=json.dumps( + { + "account": "bj72353.ap-southeast-2", + "user": "NUI_SVC", + "private_key": rsa_private_key_pem.decode("utf-8"), + } + ), + ) + yield "snowflake-credentials" + + +# 
--------------------------------------------------------------------------- +# get_snowflake_credentials +# --------------------------------------------------------------------------- + + +class TestGetCredentials: + def test_explicit_args_skip_secrets_manager(self, clear_snowflake_env, mocker, rsa_private_key_pem): + spy = mocker.patch("nui_shared_utils.secrets_helper.get_secret") + creds = sc.get_snowflake_credentials( + account="ab12345", + user="tim", + private_key=rsa_private_key_pem, + private_key_passphrase="hunter2", + ) + assert creds["account"] == "ab12345" + assert creds["user"] == "tim" + assert creds["private_key"] == rsa_private_key_pem + assert creds["private_key_passphrase"] == "hunter2" + spy.assert_not_called() + + def test_explicit_string_key_coerced_to_bytes(self, clear_snowflake_env, rsa_private_key_pem): + creds = sc.get_snowflake_credentials( + account="ab12345", + user="tim", + private_key=rsa_private_key_pem.decode("utf-8"), + ) + assert creds["private_key"] == rsa_private_key_pem + + def test_env_inline_key(self, clear_snowflake_env, monkeypatch, mocker, rsa_private_key_pem): + spy = mocker.patch("nui_shared_utils.secrets_helper.get_secret") + monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "ab12345.ap-southeast-2") + monkeypatch.setenv("SNOWFLAKE_USER", "env_user") + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY", rsa_private_key_pem.decode("utf-8")) + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", "pp") + creds = sc.get_snowflake_credentials() + assert creds["account"] == "ab12345.ap-southeast-2" + assert creds["user"] == "env_user" + assert creds["private_key"] == rsa_private_key_pem + assert creds["private_key_passphrase"] == "pp" + spy.assert_not_called() + + def test_explicit_key_wins_over_stale_env_key(self, clear_snowflake_env, monkeypatch, mocker, rsa_private_key_pem): + """Regression (codex #1): explicit key must beat env, never silently swapped.""" + from cryptography.hazmat.primitives import serialization + from 
cryptography.hazmat.primitives.asymmetric import rsa + + stale = rsa.generate_private_key(public_exponent=65537, key_size=2048).private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + assert stale != rsa_private_key_pem + monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "env_acct") + monkeypatch.setenv("SNOWFLAKE_USER", "env_user") + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY", stale.decode("utf-8")) + spy = mocker.patch("nui_shared_utils.secrets_helper.get_secret") + creds = sc.get_snowflake_credentials(private_key=rsa_private_key_pem) + # Explicit key wins; account/user fall through to env. + assert creds["private_key"] == rsa_private_key_pem + assert creds["account"] == "env_acct" + assert creds["user"] == "env_user" + spy.assert_not_called() + + def test_explicit_key_not_paired_with_stale_env_passphrase( + self, clear_snowflake_env, monkeypatch, rsa_private_key_pem + ): + """An explicit (unencrypted) key must not pick up a stale env passphrase.""" + monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "env_acct") + monkeypatch.setenv("SNOWFLAKE_USER", "env_user") + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", "stale-env-pp") + creds = sc.get_snowflake_credentials(account="a", user="u", private_key=rsa_private_key_pem) + assert creds["private_key_passphrase"] is None + + def test_explicit_passphrase_arg_wins_over_env(self, clear_snowflake_env, monkeypatch, rsa_private_key_pem): + monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "ab12345") + monkeypatch.setenv("SNOWFLAKE_USER", "env_user") + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY", rsa_private_key_pem.decode("utf-8")) + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", "env-pp") + creds = sc.get_snowflake_credentials(private_key_passphrase="arg-pp") + assert creds["private_key_passphrase"] == "arg-pp" + + def test_invalid_key_type_raises(self, clear_snowflake_env): + with pytest.raises(ValueError, match="PEM text"): + 
sc.get_snowflake_credentials(account="a", user="u", private_key=12345) + + def test_empty_explicit_key_raises(self, clear_snowflake_env, mocker): + """An empty explicit key must not slip past validation (codex finding).""" + spy = mocker.patch("nui_shared_utils.secrets_helper.get_secret") + with pytest.raises(ValueError, match="non-empty private_key"): + sc.get_snowflake_credentials(account="a", user="u", private_key=b"") + # account+user present, so the empty key must be caught without an SM call. + spy.assert_not_called() + + def test_env_passphrase_not_used_when_key_from_secret( + self, clear_snowflake_env, monkeypatch, mocker, rsa_private_key_pem + ): + """A secret-sourced key pairs with the secret passphrase, not a stale env one.""" + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", "stale-env-pp") + mocker.patch( + "nui_shared_utils.secrets_helper.get_secret", + return_value={ + "account": "ab12345", + "user": "svc", + "private_key": rsa_private_key_pem.decode("utf-8"), + "private_key_passphrase": "secret-pp", + }, + ) + creds = sc.get_snowflake_credentials("some-secret") + assert creds["private_key_passphrase"] == "secret-pp" + + def test_env_key_path(self, clear_snowflake_env, monkeypatch, tmp_path, rsa_private_key_pem): + key_file = tmp_path / "snowflake.p8" + key_file.write_bytes(rsa_private_key_pem) + monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "ab12345") + monkeypatch.setenv("SNOWFLAKE_USER", "env_user") + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PATH", str(key_file)) + monkeypatch.setenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", "filepp") + creds = sc.get_snowflake_credentials() + assert creds["private_key"] == rsa_private_key_pem + # Env key-path pairs with the env passphrase from the same tier. + assert creds["private_key_passphrase"] == "filepp" + + def test_env_requires_account_user_and_key(self, clear_snowflake_env, monkeypatch, mocker): + # Only account+user but no key -> falls through to Secrets Manager. 
+ monkeypatch.setenv("SNOWFLAKE_ACCOUNT", "ab12345") + monkeypatch.setenv("SNOWFLAKE_USER", "env_user") + spy = mocker.patch( + "nui_shared_utils.secrets_helper.get_secret", + return_value={"account": "x", "user": "y", "private_key": "PEMDATA"}, + ) + sc.get_snowflake_credentials() + spy.assert_called_once() + + def test_secrets_manager(self, clear_snowflake_env, secret_in_sm, rsa_private_key_pem): + # secret_in_sm fixture creates the secret in a live moto context. + creds = sc.get_snowflake_credentials(secret_in_sm) + assert creds["account"] == "bj72353.ap-southeast-2" + assert creds["user"] == "NUI_SVC" + assert creds["private_key"] == rsa_private_key_pem + assert creds["private_key_passphrase"] is None + + def test_secrets_manager_field_aliases(self, clear_snowflake_env, aws_env, mocker, rsa_private_key_pem): + mocker.patch( + "nui_shared_utils.secrets_helper.get_secret", + return_value={ + "snowflake_account": "ab12345", + "username": "aliased_user", + "privateKey": rsa_private_key_pem.decode("utf-8"), + "passphrase": "secretpp", + }, + ) + creds = sc.get_snowflake_credentials("some-secret") + assert creds["account"] == "ab12345" + assert creds["user"] == "aliased_user" + assert creds["private_key"] == rsa_private_key_pem + assert creds["private_key_passphrase"] == "secretpp" + + def test_explicit_account_user_override_secret(self, clear_snowflake_env, mocker, rsa_private_key_pem): + mocker.patch( + "nui_shared_utils.secrets_helper.get_secret", + return_value={ + "account": "secret_account", + "user": "secret_user", + "private_key": rsa_private_key_pem.decode("utf-8"), + }, + ) + creds = sc.get_snowflake_credentials("some-secret", account="override_acct", user="override_user") + assert creds["account"] == "override_acct" + assert creds["user"] == "override_user" + + def test_missing_field_raises_naming_secret_not_contents(self, clear_snowflake_env, mocker, rsa_private_key_pem): + mocker.patch( + "nui_shared_utils.secrets_helper.get_secret", + 
return_value={"account": "ab12345"}, # no user, no private_key + ) + with pytest.raises(ValueError) as exc: + sc.get_snowflake_credentials("my-snowflake-secret") + msg = str(exc.value) + assert "my-snowflake-secret" in msg + assert "user" in msg and "private_key" in msg + # Must not leak any key material in the error. + assert "BEGIN" not in msg + + def test_secret_name_precedence(self, clear_snowflake_env, monkeypatch, mocker, rsa_private_key_pem): + captured = {} + + def fake_get_secret(name): + captured["name"] = name + return {"account": "a", "user": "u", "private_key": rsa_private_key_pem.decode("utf-8")} + + mocker.patch("nui_shared_utils.secrets_helper.get_secret", side_effect=fake_get_secret) + + # Default + sc.get_snowflake_credentials() + assert captured["name"] == "snowflake-credentials" + # Env override + monkeypatch.setenv("SNOWFLAKE_CREDENTIALS_SECRET", "env-secret") + sc.get_snowflake_credentials() + assert captured["name"] == "env-secret" + # Explicit arg wins + sc.get_snowflake_credentials("explicit-secret") + assert captured["name"] == "explicit-secret" + + +# --------------------------------------------------------------------------- +# redacting_query_logger +# --------------------------------------------------------------------------- + + +class TestRedactingLogger: + def test_logs_sql_and_param_count_not_values(self, caplog): + logger = logging.getLogger("nui_shared_utils.snowflake_client") + hook = sc.redacting_query_logger(logger, level=logging.INFO) + with caplog.at_level(logging.INFO, logger="nui_shared_utils.snowflake_client"): + hook("SELECT * FROM orders WHERE id = ?", ["super-secret-pii-value"]) + records = [r for r in caplog.records if r.message == "snowflake query"] + assert len(records) == 1 + rec = records[0] + assert rec.sql == "SELECT * FROM orders WHERE id = ?" + assert rec.bind_param_count == 1 + # The actual bind value must never appear anywhere in the log output. 
+ assert "super-secret-pii-value" not in caplog.text + + def test_param_count_zero_for_none(self, caplog): + hook = sc.redacting_query_logger(level=logging.INFO) + with caplog.at_level(logging.INFO, logger="nui_shared_utils.snowflake_client"): + hook("SELECT 1", None) + rec = [r for r in caplog.records if r.message == "snowflake query"][0] + assert rec.bind_param_count == 0 + + def test_sql_truncated(self, caplog): + hook = sc.redacting_query_logger(level=logging.INFO, max_sql_chars=10) + with caplog.at_level(logging.INFO, logger="nui_shared_utils.snowflake_client"): + hook("SELECT " + "x" * 100, None) + rec = [r for r in caplog.records if r.message == "snowflake query"][0] + assert rec.sql.endswith("...(truncated)") + assert len(rec.sql) == len("...(truncated)") + 10 + + +# --------------------------------------------------------------------------- +# create_snowflake_client / create_async_snowflake_client +# --------------------------------------------------------------------------- + + +class TestCreateClient: + @pytest.fixture + def spy_client(self, mocker): + return mocker.patch("nui_shared_utils.snowflake_client.SnowflakeClient") + + @pytest.fixture + def spy_async_client(self, mocker): + return mocker.patch("nui_shared_utils.snowflake_client.AsyncSnowflakeClient") + + def _base_kwargs(self, rsa_private_key_pem): + return {"account": "ab12345", "user": "tim", "private_key": rsa_private_key_pem} + + def test_nui_defaults_applied(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sc.create_snowflake_client(**self._base_kwargs(rsa_private_key_pem)) + kwargs = spy_client.call_args.kwargs + assert kwargs["timezone"] == "Pacific/Auckland" + assert kwargs["role"] == "NUI_LAMBDA" + # No session parameters are injected by default (WEEK_START is rejected + # by the SQL API; TIMEZONE rides the dedicated `timezone` arg). 
+ assert kwargs["parameters"] is None + assert kwargs["user_agent"] == "nui-python-shared-utils-snowflake" + # Default on_query is the redacting hook. + assert callable(kwargs["on_query"]) + + def test_defaults_overridable(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sc.create_snowflake_client( + role="ANALYST", + timezone="UTC", + warehouse="WH", + database="DB", + schema="SCH", + **self._base_kwargs(rsa_private_key_pem), + ) + kwargs = spy_client.call_args.kwargs + assert kwargs["role"] == "ANALYST" + assert kwargs["timezone"] == "UTC" + assert kwargs["warehouse"] == "WH" + assert kwargs["database"] == "DB" + assert kwargs["schema"] == "SCH" + + def test_caller_parameters_passthrough(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sc.create_snowflake_client( + parameters={"QUERY_TAG": "nui", "DATE_OUTPUT_FORMAT": "YYYY-MM-DD"}, + **self._base_kwargs(rsa_private_key_pem), + ) + params = spy_client.call_args.kwargs["parameters"] + assert params == {"QUERY_TAG": "nui", "DATE_OUTPUT_FORMAT": "YYYY-MM-DD"} + + def test_log_queries_false_disables_hook(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sc.create_snowflake_client(log_queries=False, **self._base_kwargs(rsa_private_key_pem)) + assert spy_client.call_args.kwargs["on_query"] is None + + def test_custom_on_query_overrides(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sentinel = lambda sql, params: None # noqa: E731 + sc.create_snowflake_client(on_query=sentinel, **self._base_kwargs(rsa_private_key_pem)) + assert spy_client.call_args.kwargs["on_query"] is sentinel + + def test_client_kwargs_passthrough(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sc.create_snowflake_client( + statement_timeout=120, + timeout=30.0, + host="custom.example.com", + **self._base_kwargs(rsa_private_key_pem), + ) + kwargs = spy_client.call_args.kwargs + assert kwargs["statement_timeout"] == 120 + assert kwargs["timeout"] == 30.0 + assert kwargs["host"] == 
"custom.example.com" + + def test_custom_user_agent(self, clear_snowflake_env, spy_client, rsa_private_key_pem): + sc.create_snowflake_client(user_agent="my-lambda/1.0", **self._base_kwargs(rsa_private_key_pem)) + assert spy_client.call_args.kwargs["user_agent"] == "my-lambda/1.0" + + def test_async_factory_applies_defaults(self, clear_snowflake_env, spy_async_client, rsa_private_key_pem): + sc.create_async_snowflake_client(**self._base_kwargs(rsa_private_key_pem)) + kwargs = spy_async_client.call_args.kwargs + assert kwargs["timezone"] == "Pacific/Auckland" + assert kwargs["role"] == "NUI_LAMBDA" + assert kwargs["parameters"] is None + assert callable(kwargs["on_query"]) + + def test_async_factory_full_parity(self, clear_snowflake_env, spy_async_client, rsa_private_key_pem): + """Async factory must assemble the same kwargs as sync for a custom config.""" + sc.create_async_snowflake_client( + role="ANALYST", + timezone="UTC", + warehouse="WH", + database="DB", + schema="SCH", + parameters={"QUERY_TAG": "nui"}, + user_agent="my-lambda/1.0", + log_queries=False, + statement_timeout=120, + host="custom.example.com", + **self._base_kwargs(rsa_private_key_pem), + ) + kwargs = spy_async_client.call_args.kwargs + assert kwargs["role"] == "ANALYST" + assert kwargs["timezone"] == "UTC" + assert kwargs["warehouse"] == "WH" + assert kwargs["database"] == "DB" + assert kwargs["schema"] == "SCH" + assert kwargs["parameters"] == {"QUERY_TAG": "nui"} + assert kwargs["user_agent"] == "my-lambda/1.0" + assert kwargs["on_query"] is None + assert kwargs["statement_timeout"] == 120 + assert kwargs["host"] == "custom.example.com" + + def test_parameters_dict_is_copied(self, clear_snowflake_env, rsa_private_key_pem): + """Mutating the caller's parameters dict must not affect the client.""" + params = {"QUERY_TAG": "original"} + client = sc.create_snowflake_client(parameters=params, **self._base_kwargs(rsa_private_key_pem)) + try: + params["QUERY_TAG"] = "mutated" + params["EXTRA"] = 
"added" + assert client._parameters == {"QUERY_TAG": "original"} + finally: + client.close() + + def test_real_sync_client_constructs(self, clear_snowflake_env, rsa_private_key_pem): + """End-to-end: a real SnowflakeClient is built (key parsed, no network).""" + from snowflake_sql_api import SnowflakeClient + + client = sc.create_snowflake_client( + account="ab12345.ap-southeast-2", + user="tim", + private_key=rsa_private_key_pem, + ) + try: + assert isinstance(client, SnowflakeClient) + assert client.role == "NUI_LAMBDA" + assert client.timezone == "Pacific/Auckland" + assert client._parameters == {} + assert client.on_query is not None + finally: + client.close() + + def test_real_async_client_constructs(self, clear_snowflake_env, rsa_private_key_pem): + import asyncio + + from snowflake_sql_api import AsyncSnowflakeClient + + client = sc.create_async_snowflake_client( + account="ab12345.ap-southeast-2", + user="tim", + private_key=rsa_private_key_pem, + ) + try: + assert isinstance(client, AsyncSnowflakeClient) + assert client.role == "NUI_LAMBDA" + assert client.timezone == "Pacific/Auckland" + finally: + asyncio.run(client.aclose()) + + +# --------------------------------------------------------------------------- +# Redaction: no key material in logs end-to-end +# --------------------------------------------------------------------------- + + +class TestRedactionEndToEnd: + def test_no_key_material_in_logs_during_construction( + self, clear_snowflake_env, caplog, mocker, rsa_private_key_pem + ): + passphrase = "TOP-SECRET-PASSPHRASE" + # Encrypted key so both key bytes and passphrase are in play. 
+ from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + enc_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption(passphrase.encode()), + ) + mocker.patch( + "nui_shared_utils.secrets_helper.get_secret", + return_value={ + "account": "ab12345", + "user": "svc", + "private_key": enc_pem.decode("utf-8"), + "private_key_passphrase": passphrase, + }, + ) + with caplog.at_level(logging.DEBUG): + client = sc.create_snowflake_client(secret_name="some-secret") + client._notify("SELECT * FROM t WHERE x = ?", ["sensitive-bind"]) + client.close() + assert passphrase not in caplog.text + assert "BEGIN" not in caplog.text # no PEM body + assert "sensitive-bind" not in caplog.text # no bind value + + +# --------------------------------------------------------------------------- +# Missing extra +# --------------------------------------------------------------------------- + + +class TestMissingExtra: + def test_clear_error_when_unavailable(self, clear_snowflake_env, monkeypatch, rsa_private_key_pem): + monkeypatch.setattr(sc, "SNOWFLAKE_SQL_API_AVAILABLE", False) + with pytest.raises(ImportError) as exc: + sc.create_snowflake_client(account="a", user="u", private_key=rsa_private_key_pem) + assert "nui-python-shared-utils[snowflake]" in str(exc.value) + + def test_async_clear_error_when_unavailable(self, clear_snowflake_env, monkeypatch, rsa_private_key_pem): + monkeypatch.setattr(sc, "SNOWFLAKE_SQL_API_AVAILABLE", False) + with pytest.raises(ImportError) as exc: + sc.create_async_snowflake_client(account="a", user="u", private_key=rsa_private_key_pem) + assert "nui-python-shared-utils[snowflake]" in str(exc.value) + + +# --------------------------------------------------------------------------- +# Package wiring (top-level lazy 
export + back-compat shim) +# --------------------------------------------------------------------------- + + +class TestPackageWiring: + def test_top_level_lazy_exports_resolve(self): + """The PEP-562 lazy exports in __init__ resolve to the real callables.""" + import nui_shared_utils as nui + + assert nui.create_snowflake_client is sc.create_snowflake_client + assert nui.create_async_snowflake_client is sc.create_async_snowflake_client + assert nui.get_snowflake_credentials is sc.get_snowflake_credentials + assert nui.redacting_query_logger is sc.redacting_query_logger + + def test_backwards_compat_shim_reexports(self): + """The nui_lambda_shared_utils shim re-exports the adapter surface.""" + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from nui_lambda_shared_utils import snowflake_client as shim + + assert shim.create_snowflake_client is sc.create_snowflake_client + assert shim.get_snowflake_credentials is sc.get_snowflake_credentials