From c0a6690cc1258005703e341b1349712f7a10df29 Mon Sep 17 00:00:00 2001 From: TiM Date: Wed, 6 May 2026 15:20:49 +1200 Subject: [PATCH] perf: defer boto3 + submodule loads to cut cold-start init - boto3/botocore no longer imported at module load. Deferred to first use in secrets_helper, utils, and cloudwatch_metrics (MetricsPublisher lazy-constructs its boto3 client on first publish). - Package __init__ uses PEP 562 __getattr__: submodules (slack_client, es_client, db_client, jwt_auth, powertools_helpers, etc.) only load when their attributes are first accessed. `import nui_shared_utils` drops from ~700ms to ~10ms locally. - powertools_helpers no longer eagerly imports SlackClient. slack_sdk is only pulled when slack_alert_channel is passed to powertools_handler. - Legacy nui_lambda_shared_utils shim now forwards via PEP 562 instead of `from nui_shared_utils import *`, so consumers still on the old import name also benefit from the cold-start savings. Backwards compatibility: - Public API unchanged (verified by 517-test suite + 10 new lazy tests) - Star imports still work via __all__ - Legacy nui_lambda_shared_utils import name still works - Optional integrations (SlackClient etc.) still return None when their dep is not installed --- nui_lambda_shared_utils/__init__.py | 25 +- nui_shared_utils/__init__.py | 461 +++++++++++++------------ nui_shared_utils/cloudwatch_metrics.py | 16 +- nui_shared_utils/powertools_helpers.py | 52 ++- nui_shared_utils/secrets_helper.py | 5 +- nui_shared_utils/utils.py | 162 ++++----- scripts/bench_imports.py | 107 ++++++ tests/conftest.py | 2 +- tests/test_cloudwatch_metrics.py | 26 +- tests/test_lazy_imports.py | 160 +++++++++ tests/test_utils.py | 149 ++++---- 11 files changed, 745 insertions(+), 420 deletions(-) create mode 100644 scripts/bench_imports.py create mode 100644 tests/test_lazy_imports.py diff --git a/nui_lambda_shared_utils/__init__.py b/nui_lambda_shared_utils/__init__.py index c368af7..4d85a19 100644 --- a/nui_lambda_shared_utils/__init__.py +++ b/nui_lambda_shared_utils/__init__.py @@ -4,15 +4,21 @@ This package has been renamed to nui-python-shared-utils. The import name has changed from nui_lambda_shared_utils to nui_shared_utils. -This shim re-exports everything from nui_shared_utils so existing consumers +This shim forwards attribute access to nui_shared_utils so existing consumers continue to work without changes. New code should use: from nui_shared_utils import ... This shim will be removed in the next major version (2.0.0). + +Forwarding is lazy (PEP 562 ``__getattr__``) to preserve the cold-start +optimisation in the underlying package: ``from nui_lambda_shared_utils.jwt_auth +import check_auth`` only imports ``jwt_auth`` and its dependencies, not the +full slack/es/db client surface. """ import warnings +from typing import Any, List warnings.warn( "nui_lambda_shared_utils is deprecated. Use nui_shared_utils instead. " @@ -21,5 +27,18 @@ stacklevel=2, ) -from nui_shared_utils import * # noqa: F401,F403 -from nui_shared_utils import __all__ # noqa: F401 +import nui_shared_utils as _target + +__all__ = list(_target.__all__) + + +def __getattr__(name: str) -> Any: + # Delegate to the new package's lazy resolver. Cache on this module so + # subsequent accesses avoid the round-trip. + value = getattr(_target, name) + globals()[name] = value + return value + + +def __dir__() -> List[str]: + return sorted(set(globals()) | set(__all__)) diff --git a/nui_shared_utils/__init__.py b/nui_shared_utils/__init__.py index 85e3aca..6e952c6 100644 --- a/nui_shared_utils/__init__.py +++ b/nui_shared_utils/__init__.py @@ -1,252 +1,263 @@ """ Enterprise-grade utilities for AWS Lambda functions with Slack, Elasticsearch, and monitoring integrations. + +Public API is resolved lazily via PEP 562 ``__getattr__`` to keep package import +cheap. Submodules and their dependencies (boto3, slack-sdk, elasticsearch, etc.) +are only imported when an attribute is first accessed. """ -# Configuration system -from .config import ( - Config, - get_config, - set_config, - configure, - get_es_host, - get_es_credentials_secret, - get_db_credentials_secret, - get_slack_credentials_secret, -) +import importlib +from typing import TYPE_CHECKING, Any, List -# Core utilities -from .secrets_helper import ( - get_secret, - get_database_credentials, - get_elasticsearch_credentials, - get_slack_credentials, - get_api_key, - clear_cache, -) +# Map public name -> (submodule, attr_name). +# Optional integrations: keep their entries here; on ImportError during lazy +# resolution we return None to preserve historical behaviour where consumers +# could check ``if nui.SlackClient is not None``. +_LAZY_EXPORTS = { + # Configuration system + "Config": ("config", "Config"), + "get_config": ("config", "get_config"), + "set_config": ("config", "set_config"), + "configure": ("config", "configure"), + "get_es_host": ("config", "get_es_host"), + "get_es_credentials_secret": ("config", "get_es_credentials_secret"), + "get_db_credentials_secret": ("config", "get_db_credentials_secret"), + "get_slack_credentials_secret": ("config", "get_slack_credentials_secret"), + # Secrets + "get_secret": ("secrets_helper", "get_secret"), + "get_database_credentials": ("secrets_helper", "get_database_credentials"), + "get_elasticsearch_credentials": ("secrets_helper", "get_elasticsearch_credentials"), + "get_slack_credentials": ("secrets_helper", "get_slack_credentials"), + "get_api_key": ("secrets_helper", "get_api_key"), + "clear_cache": ("secrets_helper", "clear_cache"), + # Common utilities + "resolve_config_value": ("utils", "resolve_config_value"), + "resolve_aws_region": ("utils", "resolve_aws_region"), + "create_aws_client": ("utils", "create_aws_client"), + "handle_client_errors": ("utils", "handle_client_errors"), + "merge_dimensions": ("utils", "merge_dimensions"), + "validate_required_param": ("utils", "validate_required_param"), + "safe_close_connection": ("utils", "safe_close_connection"), + "format_log_context": ("utils", "format_log_context"), + "DEFAULT_AWS_REGION": ("utils", "DEFAULT_AWS_REGION"), + # Base client architecture + "BaseClient": ("base_client", "BaseClient"), + "ServiceHealthMixin": ("base_client", "ServiceHealthMixin"), + "RetryableOperationMixin": ("base_client", "RetryableOperationMixin"), + # Timezone helpers + "nz_time": ("timezone", "nz_time"), + "format_nz_time": ("timezone", "format_nz_time"), + # Slack formatting (no external deps) + "SlackBlockBuilder": ("slack_formatter", "SlackBlockBuilder"), + "format_currency": ("slack_formatter", "format_currency"), + "format_percentage": ("slack_formatter", "format_percentage"), + "format_number": ("slack_formatter", "format_number"), + "format_nz_time_slack": ("slack_formatter", "format_nz_time"), + "format_date_range": ("slack_formatter", "format_date_range"), + "format_daily_header": ("slack_formatter", "format_daily_header"), + "format_weekly_header": ("slack_formatter", "format_weekly_header"), + "format_error_alert": ("slack_formatter", "format_error_alert"), + "SEVERITY_EMOJI": ("slack_formatter", "SEVERITY_EMOJI"), + "STATUS_EMOJI": ("slack_formatter", "STATUS_EMOJI"), + # Error handling + "RetryableError": ("error_handler", "RetryableError"), + "NonRetryableError": ("error_handler", "NonRetryableError"), + "ErrorPatternMatcher": ("error_handler", "ErrorPatternMatcher"), + "ErrorAggregator": ("error_handler", "ErrorAggregator"), + "with_retry": ("error_handler", "with_retry"), + "retry_on_network_error": ("error_handler", "retry_on_network_error"), + "retry_on_db_error": ("error_handler", "retry_on_db_error"), + "retry_on_es_error": ("error_handler", "retry_on_es_error"), + "handle_lambda_error": ("error_handler", "handle_lambda_error"), + "categorize_retryable_error": ("error_handler", "categorize_retryable_error"), + # CloudWatch metrics + "MetricsPublisher": ("cloudwatch_metrics", "MetricsPublisher"), + "MetricAggregator": ("cloudwatch_metrics", "MetricAggregator"), + "StandardMetrics": ("cloudwatch_metrics", "StandardMetrics"), + "TimedMetric": ("cloudwatch_metrics", "TimedMetric"), + "track_lambda_performance": ("cloudwatch_metrics", "track_lambda_performance"), + "create_service_dimensions": ("cloudwatch_metrics", "create_service_dimensions"), + "publish_health_metric": ("cloudwatch_metrics", "publish_health_metric"), + # Log processing (no external deps) + "extract_cloudwatch_logs_from_kinesis": ("log_processors", "extract_cloudwatch_logs_from_kinesis"), + "derive_index_name": ("log_processors", "derive_index_name"), + "CloudWatchLogEvent": ("log_processors", "CloudWatchLogEvent"), + "CloudWatchLogsData": ("log_processors", "CloudWatchLogsData"), + # Lambda context helpers + "get_lambda_environment_info": ("lambda_helpers", "get_lambda_environment_info"), + # Optional: Slack client (slack-sdk) + "SlackClient": ("slack_client", "SlackClient"), + # Optional: Elasticsearch client + query builder + "ElasticsearchClient": ("es_client", "ElasticsearchClient"), + "ESQueryBuilder": ("es_query_builder", "ESQueryBuilder"), + "build_error_rate_query": ("es_query_builder", "build_error_rate_query"), + "build_top_errors_query": ("es_query_builder", "build_top_errors_query"), + "build_response_time_query": ("es_query_builder", "build_response_time_query"), + "build_service_volume_query": ("es_query_builder", "build_service_volume_query"), + "build_user_activity_query": ("es_query_builder", "build_user_activity_query"), + "build_pattern_detection_query": ("es_query_builder", "build_pattern_detection_query"), + "build_tender_participant_query": ("es_query_builder", "build_tender_participant_query"), + # Optional: Database client (pymysql / psycopg2) + "DatabaseClient": ("db_client", "DatabaseClient"), + "PostgreSQLClient": ("db_client", "PostgreSQLClient"), + "get_pool_stats": ("db_client", "get_pool_stats"), + # Optional: AWS Powertools + "get_powertools_logger": ("powertools_helpers", "get_powertools_logger"), + "powertools_handler": ("powertools_helpers", "powertools_handler"), + # Optional: JWT validation (rsa) + "validate_jwt": ("jwt_auth", "validate_jwt"), + "require_auth": ("jwt_auth", "require_auth"), + "check_auth": ("jwt_auth", "check_auth"), + "get_jwt_public_key": ("jwt_auth", "get_jwt_public_key"), + "JWTValidationError": ("jwt_auth", "JWTValidationError"), + "AuthenticationError": ("jwt_auth", "AuthenticationError"), +} -# Common utilities -from .utils import ( - resolve_config_value, - create_aws_client, - handle_client_errors, - merge_dimensions, - validate_required_param, -) +# Submodules that are optional integrations; ImportError during lazy load +# resolves to None instead of propagating, matching pre-1.4 behaviour. +# Includes ``slack_setup`` which is also handled by a special-case branch in +# ``__getattr__`` (it is exposed as a submodule object, not an attribute). +_OPTIONAL_SUBMODULES = { + "slack_client", + "es_client", + "es_query_builder", + "db_client", + "powertools_helpers", + "jwt_auth", + "slack_setup", +} -# Base client architecture -from .base_client import BaseClient, ServiceHealthMixin, RetryableOperationMixin -# Client implementations - only fail if actually used -try: - from .slack_client import SlackClient -except ImportError: - SlackClient = None # type: ignore +def __getattr__(name: str) -> Any: + # ``slack_setup`` is exposed as a submodule attribute (``nui.slack_setup``). + if name == "slack_setup": + try: + mod = importlib.import_module(".slack_setup", __name__) + except ImportError: + if name in _OPTIONAL_SUBMODULES: + mod = None + else: + raise + globals()["slack_setup"] = mod + return mod -try: - from .es_client import ElasticsearchClient -except ImportError: - ElasticsearchClient = None # type: ignore + if name in _LAZY_EXPORTS: + submod_name, attr = _LAZY_EXPORTS[name] + try: + submod = importlib.import_module(f".{submod_name}", __name__) + value = getattr(submod, attr) + except ImportError: + if submod_name in _OPTIONAL_SUBMODULES: + value = None + else: + raise + globals()[name] = value + return value + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -try: - from .db_client import DatabaseClient, PostgreSQLClient, get_pool_stats -except ImportError: - DatabaseClient = None # type: ignore - PostgreSQLClient = None # type: ignore - get_pool_stats = None # type: ignore -from .timezone import nz_time, format_nz_time +def __dir__() -> List[str]: + return sorted(set(globals()) | set(_LAZY_EXPORTS) | {"slack_setup"}) -# Slack formatting utilities (no external dependencies) -from .slack_formatter import ( - SlackBlockBuilder, - format_currency, - format_percentage, - format_number, - format_nz_time as format_nz_time_slack, - format_date_range, - format_daily_header, - format_weekly_header, - format_error_alert, - SEVERITY_EMOJI, - STATUS_EMOJI, -) -# ES query builder - optional import -try: +if TYPE_CHECKING: + # Imported for type checkers / IDE completion only; not executed at runtime. + from .config import ( + Config, + configure, + get_config, + get_db_credentials_secret, + get_es_credentials_secret, + get_es_host, + get_slack_credentials_secret, + set_config, + ) + from .secrets_helper import ( + clear_cache, + get_api_key, + get_database_credentials, + get_elasticsearch_credentials, + get_secret, + get_slack_credentials, + ) + from .utils import ( + DEFAULT_AWS_REGION, + create_aws_client, + format_log_context, + handle_client_errors, + merge_dimensions, + resolve_aws_region, + resolve_config_value, + safe_close_connection, + validate_required_param, + ) + from .base_client import BaseClient, RetryableOperationMixin, ServiceHealthMixin + from .timezone import format_nz_time, nz_time + from .slack_formatter import ( + SEVERITY_EMOJI, + STATUS_EMOJI, + SlackBlockBuilder, + format_currency, + format_daily_header, + format_date_range, + format_error_alert, + format_number, + format_nz_time as format_nz_time_slack, + format_percentage, + format_weekly_header, + ) + from .error_handler import ( + ErrorAggregator, + ErrorPatternMatcher, + NonRetryableError, + RetryableError, + categorize_retryable_error, + handle_lambda_error, + retry_on_db_error, + retry_on_es_error, + retry_on_network_error, + with_retry, + ) + from .cloudwatch_metrics import ( + MetricAggregator, + MetricsPublisher, + StandardMetrics, + TimedMetric, + create_service_dimensions, + publish_health_metric, + track_lambda_performance, + ) + from .log_processors import ( + CloudWatchLogEvent, + CloudWatchLogsData, + derive_index_name, + extract_cloudwatch_logs_from_kinesis, + ) + from .lambda_helpers import get_lambda_environment_info + from .slack_client import SlackClient + from .es_client import ElasticsearchClient from .es_query_builder import ( ESQueryBuilder, build_error_rate_query, - build_top_errors_query, + build_pattern_detection_query, build_response_time_query, build_service_volume_query, - build_user_activity_query, - build_pattern_detection_query, build_tender_participant_query, + build_top_errors_query, + build_user_activity_query, ) -except ImportError: - ESQueryBuilder = None # type: ignore - build_error_rate_query = None # type: ignore - build_top_errors_query = None # type: ignore - build_response_time_query = None # type: ignore - build_service_volume_query = None # type: ignore - build_user_activity_query = None # type: ignore - build_pattern_detection_query = None # type: ignore - build_tender_participant_query = None # type: ignore -from .error_handler import ( - RetryableError, - NonRetryableError, - ErrorPatternMatcher, - ErrorAggregator, - with_retry, - retry_on_network_error, - retry_on_db_error, - retry_on_es_error, - handle_lambda_error, - categorize_retryable_error, -) -from .cloudwatch_metrics import ( - MetricsPublisher, - MetricAggregator, - StandardMetrics, - TimedMetric, - track_lambda_performance, - create_service_dimensions, - publish_health_metric, -) - -# AWS Powertools integration - optional import -try: + from .db_client import DatabaseClient, PostgreSQLClient, get_pool_stats from .powertools_helpers import get_powertools_logger, powertools_handler -except ImportError: - get_powertools_logger = None # type: ignore - powertools_handler = None # type: ignore - -# Log processing utilities (no external dependencies) -from .log_processors import ( - CloudWatchLogEvent, - CloudWatchLogsData, - derive_index_name, - extract_cloudwatch_logs_from_kinesis, -) - -# Lambda context helpers (no external dependencies) -from .lambda_helpers import get_lambda_environment_info - -# JWT authentication - optional import -try: from .jwt_auth import ( - validate_jwt, - require_auth, + AuthenticationError, + JWTValidationError, check_auth, get_jwt_public_key, - JWTValidationError, - AuthenticationError, + require_auth, + validate_jwt, ) -except ImportError: - validate_jwt = None # type: ignore - require_auth = None # type: ignore - check_auth = None # type: ignore - get_jwt_public_key = None # type: ignore - JWTValidationError = None # type: ignore - AuthenticationError = None # type: ignore - -# Slack setup utilities (for CLI usage) - optional import -try: from . import slack_setup -except ImportError: - slack_setup = None # type: ignore -__all__ = [ - # Configuration system - "Config", - "get_config", - "set_config", - "configure", - "get_es_host", - "get_es_credentials_secret", - "get_db_credentials_secret", - "get_slack_credentials_secret", - # Core utilities - "get_secret", - "get_database_credentials", - "get_elasticsearch_credentials", - "get_slack_credentials", - "get_api_key", - "clear_cache", - # Common utilities - "resolve_config_value", - "create_aws_client", - "handle_client_errors", - "merge_dimensions", - "validate_required_param", - # Base client architecture - "BaseClient", - "ServiceHealthMixin", - "RetryableOperationMixin", - # Client implementations - "SlackClient", - "ElasticsearchClient", - "DatabaseClient", - "PostgreSQLClient", - "get_pool_stats", # Legacy compatibility (None) - "nz_time", - "format_nz_time", - "slack_setup", - # Slack formatting - "SlackBlockBuilder", - "format_currency", - "format_percentage", - "format_number", - "format_nz_time_slack", - "format_date_range", - "format_daily_header", - "format_weekly_header", - "format_error_alert", - "SEVERITY_EMOJI", - "STATUS_EMOJI", - # ES query building - "ESQueryBuilder", - "build_error_rate_query", - "build_top_errors_query", - "build_response_time_query", - "build_service_volume_query", - "build_user_activity_query", - "build_pattern_detection_query", - "build_tender_participant_query", - # Error handling - "RetryableError", - "NonRetryableError", - "ErrorPatternMatcher", - "ErrorAggregator", - "with_retry", - "retry_on_network_error", - "retry_on_db_error", - "retry_on_es_error", - "handle_lambda_error", - "categorize_retryable_error", - # CloudWatch metrics - "MetricsPublisher", - "MetricAggregator", - "StandardMetrics", - "TimedMetric", - "track_lambda_performance", - "create_service_dimensions", - "publish_health_metric", - # AWS Powertools integration - "get_powertools_logger", - "powertools_handler", - # Log processing - "extract_cloudwatch_logs_from_kinesis", - "derive_index_name", - "CloudWatchLogEvent", - "CloudWatchLogsData", - # Lambda context helpers - "get_lambda_environment_info", - # JWT authentication - "validate_jwt", - "require_auth", - "check_auth", - "get_jwt_public_key", - "JWTValidationError", - "AuthenticationError", -] + +__all__ = list(sorted(set(_LAZY_EXPORTS) | {"slack_setup"})) diff --git a/nui_shared_utils/cloudwatch_metrics.py b/nui_shared_utils/cloudwatch_metrics.py index 82c9c9e..320b741 100644 --- a/nui_shared_utils/cloudwatch_metrics.py +++ b/nui_shared_utils/cloudwatch_metrics.py @@ -9,8 +9,6 @@ from typing import Dict, List, Optional, Union from datetime import datetime from collections import defaultdict -import boto3 -from botocore.exceptions import ClientError log = logging.getLogger(__name__) @@ -45,9 +43,19 @@ def __init__( self.namespace = namespace self.default_dimensions = dimensions or {} self.auto_flush_size = auto_flush_size - self.client = boto3.client("cloudwatch", region_name=region) + self._region = region + self._client = None self.metric_buffer: List[Dict] = [] + @property + def client(self): + """Lazily construct the boto3 CloudWatch client on first use.""" + if self._client is None: + import boto3 + + self._client = boto3.client("cloudwatch", region_name=self._region) + return self._client + def put_metric( self, metric_name: str, @@ -148,6 +156,8 @@ def flush(self) -> bool: if not self.metric_buffer: return True + from botocore.exceptions import ClientError + try: # CloudWatch allows max 20 metrics per request for i in range(0, len(self.metric_buffer), 20): diff --git a/nui_shared_utils/powertools_helpers.py b/nui_shared_utils/powertools_helpers.py index 9bd733e..0dd5f47 100644 --- a/nui_shared_utils/powertools_helpers.py +++ b/nui_shared_utils/powertools_helpers.py @@ -26,15 +26,34 @@ except ImportError: COLOREDLOGS_AVAILABLE = False -try: - from .slack_client import SlackClient +from .lambda_helpers import get_lambda_environment_info - SLACK_CLIENT_AVAILABLE = True -except ImportError: - SLACK_CLIENT_AVAILABLE = False - SlackClient = None # type: ignore +# SlackClient is loaded lazily on first use to keep this module's import +# cost low for callers that don't enable Slack alerting (it transitively +# pulls in slack_sdk, which is the dominant cost). +SLACK_CLIENT_AVAILABLE = False +SlackClient = None # type: ignore[assignment] -from .lambda_helpers import get_lambda_environment_info + +def _ensure_slack_client_loaded() -> None: + """Lazy-import :class:`SlackClient` and update module-level flags. + + Idempotent: returns immediately if ``SlackClient`` has already been + populated (real import or test mock). + """ + global SLACK_CLIENT_AVAILABLE, SlackClient + if SlackClient is not None: + # Already populated (real import or test mock) — keep the availability + # flag in sync so callers don't see a stale False. + SLACK_CLIENT_AVAILABLE = True + return + try: + from .slack_client import SlackClient as _SC + + SlackClient = _SC + SLACK_CLIENT_AVAILABLE = True + except ImportError: + SLACK_CLIENT_AVAILABLE = False __all__ = ["get_powertools_logger", "powertools_handler"] @@ -99,6 +118,7 @@ def _mock_inject_lambda_context(func=None, **_kwargs): if func is not None: return func return lambda f: f + logger.inject_lambda_context = _mock_inject_lambda_context # type: ignore return logger @@ -194,14 +214,16 @@ def decorator(func: Callable) -> Callable: # Create Slack client if channel provided slack_client = None - if slack_alert_channel and SLACK_CLIENT_AVAILABLE: - try: - slack_client = SlackClient( - account_names=slack_account_names, - account_names_config=slack_account_names_config, - ) - except Exception as e: - logger.warning("Failed to initialize Slack client: %s", e) + if slack_alert_channel: + _ensure_slack_client_loaded() + if SLACK_CLIENT_AVAILABLE and SlackClient is not None: + try: + slack_client = SlackClient( + account_names=slack_account_names, + account_names_config=slack_account_names_config, + ) + except Exception as e: + logger.warning("Failed to initialize Slack client: %s", e) @functools.wraps(func) def wrapper(event: dict, context: Any) -> dict: diff --git a/nui_shared_utils/secrets_helper.py b/nui_shared_utils/secrets_helper.py index 5391781..8b904d5 100644 --- a/nui_shared_utils/secrets_helper.py +++ b/nui_shared_utils/secrets_helper.py @@ -7,8 +7,6 @@ import json import logging from typing import Dict, Optional -import boto3 -from botocore.exceptions import ClientError from .config import get_config @@ -35,6 +33,9 @@ def get_secret(secret_name: str) -> Dict: if secret_name in _secrets_cache: return _secrets_cache[secret_name] + import boto3 + from botocore.exceptions import ClientError + # Create a Secrets Manager client session = boto3.session.Session() client = session.client(service_name="secretsmanager", region_name=session.region_name or "ap-southeast-2") diff --git a/nui_shared_utils/utils.py b/nui_shared_utils/utils.py index b99a7ca..73be4eb 100644 --- a/nui_shared_utils/utils.py +++ b/nui_shared_utils/utils.py @@ -7,53 +7,49 @@ import logging import functools from typing import Union, List, Optional, Any, Dict -import boto3 -from botocore.exceptions import ClientError, NoCredentialsError from .config import get_config log = logging.getLogger(__name__) -# AWS region resolution constants +# AWS region fallback. Used only when no explicit region, env var, config, +# or boto3 session region is available. Override at deploy time via the +# ``AWS_REGION_FALLBACK`` environment variable. DEFAULT_AWS_REGION = "ap-southeast-2" -def resolve_config_value( - param_value: Optional[Any], - env_var_names: Union[str, List[str]], - config_default: Any -) -> Any: +def resolve_config_value(param_value: Optional[Any], env_var_names: Union[str, List[str]], config_default: Any) -> Any: """ Resolve configuration value with priority: param > env vars > config default. - + Args: param_value: Explicitly provided parameter value env_var_names: Environment variable name(s) to check (string or list) config_default: Default value from configuration - + Returns: Resolved configuration value - + Example: host = resolve_config_value( - host_param, - ["ES_HOST", "ELASTICSEARCH_HOST"], + host_param, + ["ES_HOST", "ELASTICSEARCH_HOST"], "localhost:9200" ) """ # Parameter takes highest precedence if param_value is not None: return param_value - + # Check environment variables if isinstance(env_var_names, str): env_var_names = [env_var_names] - + for env_var in env_var_names: value = os.environ.get(env_var) if value is not None: return value - + # Fall back to config default return config_default @@ -61,70 +57,72 @@ def resolve_config_value( def resolve_aws_region(explicit_region: Optional[str] = None) -> str: """ Resolve AWS region with priority: param > env > config > session > default. - + Args: explicit_region: Explicitly provided region - + Returns: AWS region string """ # Explicit parameter wins if explicit_region: return explicit_region - + # Check environment variables - env_region = resolve_config_value( - None, - ["AWS_REGION", "AWS_DEFAULT_REGION"], - None - ) + env_region = resolve_config_value(None, ["AWS_REGION", "AWS_DEFAULT_REGION"], None) if env_region: return env_region - + # Check config config = get_config() - if hasattr(config, 'aws_region') and config.aws_region: + if hasattr(config, "aws_region") and config.aws_region: return config.aws_region - + # Check boto3 session default try: + import boto3 + from botocore.exceptions import NoCredentialsError + session = boto3.session.Session() if session.region_name: return session.region_name - except Exception as e: - log.debug(f"Failed to get session region: {e}") - - # Final fallback - return DEFAULT_AWS_REGION + except ImportError as e: + log.warning(f"boto3 not available for session-based region resolution: {e}") + except NoCredentialsError as e: + log.warning(f"No AWS credentials configured for session-based region resolution: {e}") + + # Final fallback. Operators can override the package default via + # AWS_REGION_FALLBACK without forking or monkey-patching the constant. + return os.environ.get("AWS_REGION_FALLBACK") or DEFAULT_AWS_REGION def create_aws_client(service_name: str, region: Optional[str] = None): """ Create AWS client with consistent region resolution and error handling. - + Args: service_name: AWS service name (e.g., 'secretsmanager', 'cloudwatch') region: Optional explicit region - + Returns: AWS service client - + Raises: NoCredentialsError: When AWS credentials are not configured ClientError: When client creation fails """ resolved_region = resolve_aws_region(region) - + + import boto3 + from botocore.exceptions import ClientError, NoCredentialsError + try: session = boto3.session.Session() - client = session.client( - service_name=service_name, - region_name=resolved_region - ) - + client = session.client(service_name=service_name, region_name=resolved_region) + log.debug(f"Created {service_name} client for region {resolved_region}") return client - + except NoCredentialsError: log.error(f"AWS credentials not configured for {service_name} client") raise @@ -137,65 +135,59 @@ def create_aws_client(service_name: str, region: Optional[str] = None): def handle_client_errors( - default_return: Any = None, - log_context: Optional[Dict[str, Any]] = None, - reraise: bool = False + default_return: Any = None, log_context: Optional[Dict[str, Any]] = None, reraise: bool = False ): """ Decorator for standardized client error handling. - + Args: default_return: Value to return on error (if not reraising) log_context: Additional context for error logging reraise: Whether to re-raise exceptions after logging - + Example: @handle_client_errors(default_return=[]) def search_documents(self, query): # Implementation that might fail return results """ + def decorator(func): @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Any: try: return func(*args, **kwargs) except Exception as e: # Build log context - context = { - "function": func.__name__, - "error_type": type(e).__name__, - "error_message": str(e) - } + context = {"function": func.__name__, "error_type": type(e).__name__, "error_message": str(e)} if log_context: context.update(log_context) - - log.error( - f"{func.__name__} failed: {e}", - exc_info=True, - extra=context - ) - + + log.error(f"{func.__name__} failed: {e}", exc_info=True, extra=context) + if reraise: raise - + return default_return - + return wrapper + return decorator -def merge_dimensions(base_dimensions: Dict[str, str], additional_dimensions: Optional[Dict[str, str]] = None) -> List[Dict[str, str]]: +def merge_dimensions( + base_dimensions: Dict[str, str], additional_dimensions: Optional[Dict[str, str]] = None +) -> List[Dict[str, str]]: """ Merge CloudWatch metric dimensions and format for API. - + Args: base_dimensions: Base dimensions dictionary additional_dimensions: Additional dimensions to merge - + Returns: List of dimension dictionaries formatted for CloudWatch API - + Example: dimensions = merge_dimensions( {"Service": "auth", "Environment": "prod"}, @@ -206,40 +198,37 @@ def merge_dimensions(base_dimensions: Dict[str, str], additional_dimensions: Opt all_dimensions = {**base_dimensions} if additional_dimensions: all_dimensions.update(additional_dimensions) - - return [ - {"Name": str(key), "Value": str(value)} - for key, value in all_dimensions.items() - ] + + return [{"Name": str(key), "Value": str(value)} for key, value in all_dimensions.items()] def validate_required_param(param_value: Any, param_name: str) -> Any: """ Validate that a required parameter is provided. - + Args: param_value: Parameter value to validate param_name: Parameter name for error messages - + Returns: The parameter value if valid - + Raises: ValueError: If parameter is None or empty string """ if param_value is None: raise ValueError(f"{param_name} is required") - + if isinstance(param_value, str) and not param_value.strip(): raise ValueError(f"{param_name} cannot be empty") - + return param_value def safe_close_connection(connection) -> None: """ Safely close a database connection with proper error handling. - + Args: connection: Database connection to close """ @@ -251,29 +240,26 @@ def safe_close_connection(connection) -> None: return if hasattr(connection, "open") and not connection.open: return - + # Generic close connection.close() log.debug("Database connection closed successfully") - + except Exception as e: log.debug(f"Error closing connection (non-fatal): {e}") -def format_log_context( - operation: str, - **context_data -) -> Dict[str, Any]: +def format_log_context(operation: str, **context_data) -> Dict[str, Any]: """ Format consistent logging context for operations. - + Args: operation: Operation name **context_data: Additional context key-value pairs - + Returns: Formatted context dictionary - + Example: context = format_log_context( "database_query", @@ -287,5 +273,5 @@ def format_log_context( "timestamp": time.time(), } context.update(context_data) - - return context \ No newline at end of file + + return context diff --git a/scripts/bench_imports.py b/scripts/bench_imports.py new file mode 100644 index 0000000..c28aff6 --- /dev/null +++ b/scripts/bench_imports.py @@ -0,0 +1,107 @@ +""" +Measure import-time cost of common entry points into ``nui_shared_utils``. + +Runs ``python -X importtime -c "import "`` in a fresh subprocess for each +entry point, parses the cumulative timings emitted on stderr, and prints a +summary table. Use before/after performance changes (lazy imports, dependency +shuffles) to confirm the impact on Lambda cold-start init. + +Usage: + python scripts/bench_imports.py + python scripts/bench_imports.py --runs 3 # average over 3 runs + python scripts/bench_imports.py --out out.json # machine-readable +""" + +from __future__ import annotations + +import argparse +import json +import re +import statistics +import subprocess +import sys +from pathlib import Path + +ENTRY_POINTS = [ + "nui_shared_utils", + "nui_shared_utils.config", + "nui_shared_utils.secrets_helper", + "nui_shared_utils.jwt_auth", + "nui_shared_utils.powertools_helpers", + "nui_shared_utils.cloudwatch_metrics", +] + +# Modules whose cumulative time we want to highlight per run, when present. +WATCH_MODULES = ["boto3", "botocore", "nui_shared_utils", "nui_shared_utils.secrets_helper"] + +LINE_RE = re.compile(r"^import time:\s+(\d+)\s+\|\s+(\d+)\s+\|\s+(.+?)\s*$") + + +def measure(entry: str) -> dict: + """Run a single ``python -X importtime`` invocation and parse the output.""" + proc = subprocess.run( + [sys.executable, "-X", "importtime", "-c", f"import {entry}"], + capture_output=True, + text=True, + check=True, + ) + cumulative_us: dict[str, int] = {} + for line in proc.stderr.splitlines(): + m = LINE_RE.match(line) + if not m: + continue + # Strip leading nesting whitespace from the module name. + module = m.group(3).strip() + cum = int(m.group(2)) + # Last occurrence wins; importtime emits one final line per module load. + cumulative_us[module] = cum + return cumulative_us + + +def average_runs(entry: str, runs: int) -> dict: + samples: list[dict[str, int]] = [measure(entry) for _ in range(runs)] + averaged: dict[str, float] = {} + keys = set().union(*(s.keys() for s in samples)) + for k in keys: + vals = [s.get(k, 0) for s in samples] + averaged[k] = statistics.mean(vals) + return averaged + + +def fmt_ms(microseconds: float) -> str: + return f"{microseconds / 1000:7.1f} ms" + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--runs", type=int, default=1, help="Number of runs per entry point (averaged)") + parser.add_argument("--out", type=Path, default=None, help="Write raw results as JSON to this path") + args = parser.parse_args() + if args.runs < 1: + parser.error(f"--runs must be a positive integer, got {args.runs}") + + print(f"Python: {sys.version.split()[0]} Runs per entry: {args.runs}") + print() + + results: dict[str, dict[str, float]] = {} + for entry in ENTRY_POINTS: + cumulative = average_runs(entry, args.runs) + results[entry] = cumulative + + total = cumulative.get(entry, 0) + print(f"=== import {entry}") + print(f" cumulative: {fmt_ms(total)}") + for watched in WATCH_MODULES: + if watched in cumulative and watched != entry: + print(f" {watched:40s}{fmt_ms(cumulative[watched])}") + print() + + if args.out: + args.out.write_text(json.dumps(results, indent=2, sort_keys=True)) + print(f"Wrote raw results to {args.out}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/conftest.py b/tests/conftest.py index b99e3d7..0ad3134 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ @pytest.fixture def mock_boto3_session(): """Mock boto3 session.""" - with patch("nui_shared_utils.secrets_helper.boto3.session.Session") as mock_session_class: + with patch("boto3.session.Session") as mock_session_class: mock_session = Mock() mock_session_class.return_value = mock_session yield mock_session diff --git a/tests/test_cloudwatch_metrics.py b/tests/test_cloudwatch_metrics.py index ce30cee..94acc16 100644 --- a/tests/test_cloudwatch_metrics.py +++ b/tests/test_cloudwatch_metrics.py @@ -24,7 +24,7 @@ class TestMetricsPublisher: """Tests for MetricsPublisher class.""" - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_init_default_values(self, mock_boto3_client): """Test initialization with default values.""" mock_client = Mock() @@ -40,7 +40,7 @@ def test_init_default_values(self, mock_boto3_client): mock_boto3_client.assert_called_once_with("cloudwatch", region_name=None) - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_init_custom_values(self, mock_boto3_client): """Test initialization with custom values.""" mock_client = Mock() @@ -56,7 +56,7 @@ def test_init_custom_values(self, mock_boto3_client): mock_boto3_client.assert_called_once_with("cloudwatch", region_name="us-east-1") - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_put_metric_basic(self, mock_boto3_client): """Test basic metric publishing.""" mock_client = Mock() @@ -81,7 +81,7 @@ def test_put_metric_basic(self, mock_boto3_client): assert metric["StorageResolution"] == 60 assert "Dimensions" not in metric - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_put_metric_with_dimensions(self, mock_boto3_client): """Test metric publishing with dimensions.""" mock_client = Mock() @@ -102,7 +102,7 @@ def test_put_metric_with_dimensions(self, mock_boto3_client): assert metric["Dimensions"] == expected_dimensions - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_put_metric_with_custom_timestamp(self, mock_boto3_client): """Test metric with custom timestamp.""" mock_client = Mock() @@ -117,7 +117,7 @@ def test_put_metric_with_custom_timestamp(self, mock_boto3_client): assert metric["Timestamp"] == custom_timestamp assert metric["StorageResolution"] == 1 - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_put_metric_auto_flush(self, mock_boto3_client): """Test auto-flush when buffer reaches size limit.""" mock_client = Mock() @@ -135,7 +135,7 @@ def test_put_metric_auto_flush(self, mock_boto3_client): assert len(publisher.metric_buffer) == 0 # Buffer cleared after flush mock_client.put_metric_data.assert_called_once() - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_put_metric_with_statistics(self, mock_boto3_client): """Test metric publishing with statistical values.""" mock_client = Mock() @@ -151,7 +151,7 @@ def test_put_metric_with_statistics(self, mock_boto3_client): assert metric["Unit"] == "Milliseconds" assert metric["StatisticValues"] == {"SampleCount": 5, "Sum": 150, "Minimum": 10, "Maximum": 50} - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_put_metric_with_statistics_empty_values(self, mock_boto3_client): """Test metric with statistics with empty values list.""" mock_client = Mock() @@ -163,7 +163,7 @@ def test_put_metric_with_statistics_empty_values(self, mock_boto3_client): assert len(publisher.metric_buffer) == 0 - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_flush_success(self, mock_boto3_client): """Test successful metrics flush.""" mock_client = Mock() @@ -186,7 +186,7 @@ def test_flush_success(self, mock_boto3_client): assert call_args[1]["Namespace"] == "TestNamespace" assert len(call_args[1]["MetricData"]) == 2 - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_flush_large_batch(self, mock_boto3_client): """Test flush with more than 20 metrics (batch splitting).""" mock_client = Mock() @@ -204,7 +204,7 @@ def test_flush_large_batch(self, mock_boto3_client): assert mock_client.put_metric_data.call_count == 2 assert len(publisher.metric_buffer) == 0 - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_flush_empty_buffer(self, mock_boto3_client): """Test flush with empty buffer.""" mock_client = Mock() @@ -217,7 +217,7 @@ def test_flush_empty_buffer(self, mock_boto3_client): assert result is True mock_client.put_metric_data.assert_not_called() - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_flush_client_error(self, mock_boto3_client): """Test flush with CloudWatch client error.""" mock_client = Mock() @@ -235,7 +235,7 @@ def test_flush_client_error(self, mock_boto3_client): # Buffer should not be cleared on error assert len(publisher.metric_buffer) == 1 - @patch("nui_shared_utils.cloudwatch_metrics.boto3.client") + @patch("boto3.client") def test_context_manager(self, mock_boto3_client): """Test context manager functionality.""" mock_client = Mock() diff --git a/tests/test_lazy_imports.py b/tests/test_lazy_imports.py new file mode 100644 index 0000000..2bce4b7 --- /dev/null +++ b/tests/test_lazy_imports.py @@ -0,0 +1,160 @@ +""" +Regression tests for cold-start import laziness. + +We run each scenario in a fresh subprocess so module-import side effects from +the rest of the test suite cannot mask a regression. +""" + +import subprocess +import sys +import textwrap + + +def _run(script: str) -> str: + """Execute ``script`` in a fresh Python interpreter and return stdout.""" + result = subprocess.run( + [sys.executable, "-c", textwrap.dedent(script)], + capture_output=True, + text=True, + check=True, + timeout=30, + ) + return result.stdout.strip() + + +def test_bare_package_import_does_not_load_boto3(): + """``import nui_shared_utils`` must not transitively import boto3.""" + output = _run(""" + import sys + import nui_shared_utils # noqa: F401 + for mod in ("boto3", "botocore"): + print(f"{mod}={mod in sys.modules}") + """) + assert "boto3=False" in output + assert "botocore=False" in output + + +def test_jwt_auth_import_does_not_load_boto3(): + """``from nui_shared_utils.jwt_auth import check_auth`` must stay lazy. + + Lambdas using only JWT validation (with cached public key) should never + pay the boto3 cold-start cost. + """ + output = _run(""" + import sys + from nui_shared_utils.jwt_auth import check_auth # noqa: F401 + print(f"boto3={'boto3' in sys.modules}") + """) + assert "boto3=False" in output + + +def test_powertools_helpers_import_does_not_load_boto3(): + output = _run(""" + import sys + from nui_shared_utils.powertools_helpers import get_powertools_logger # noqa: F401 + print(f"boto3={'boto3' in sys.modules}") + """) + assert "boto3=False" in output + + +def test_powertools_helpers_does_not_pull_slack_sdk(): + """``powertools_helpers`` must not transitively load slack_sdk. + + Slack alerting is opt-in via ``slack_alert_channel`` on + ``powertools_handler``; consumers using only ``get_powertools_logger`` + should not pay the slack_sdk import cost (~150ms). + """ + output = _run(""" + import sys + from nui_shared_utils.powertools_helpers import get_powertools_logger # noqa: F401 + print(f"slack_sdk={'slack_sdk' in sys.modules}") + print(f"slack_client={'nui_shared_utils.slack_client' in sys.modules}") + """) + assert "slack_sdk=False" in output + assert "slack_client=False" in output + + +def test_legacy_shim_attribute_access_is_lazy(): + """``from nui_lambda_shared_utils import X`` must remain lazy. + + The shim used to do ``from nui_shared_utils import *`` which eagerly + loaded every submodule. After the 1.4 lazy refactor it forwards via + PEP 562 ``__getattr__`` so legacy consumers also benefit from + cold-start savings. + """ + output = _run(""" + import sys + import warnings + warnings.filterwarnings("ignore", category=DeprecationWarning) + from nui_lambda_shared_utils import get_powertools_logger # noqa: F401 + print(f"boto3={'boto3' in sys.modules}") + print(f"slack_sdk={'slack_sdk' in sys.modules}") + print(f"db_client={'nui_shared_utils.db_client' in sys.modules}") + print(f"es_client={'nui_shared_utils.es_client' in sys.modules}") + """) + assert "boto3=False" in output + assert "slack_sdk=False" in output + assert "db_client=False" in output + assert "es_client=False" in output + + +def test_secrets_helper_loads_boto3_only_on_first_call(): + """Importing ``secrets_helper`` is cheap; the first ``get_secret`` call pays. + + Uses ``moto`` to mock AWS so the test stays network-/credential-independent. + ``moto`` is imported after the lazy-import assertion because it pulls + ``boto3`` transitively, which would skew ``after_import``. + """ + output = _run(""" + import sys + from nui_shared_utils import secrets_helper + print(f"after_import={'boto3' in sys.modules}") + # moto imports boto3 transitively; defer until after the lazy-import check. + from moto import mock_aws + with mock_aws(): + try: + secrets_helper.get_secret("nonexistent-secret-for-test") + except Exception: + pass + print(f"after_call={'boto3' in sys.modules}") + """) + assert "after_import=False" in output + assert "after_call=True" in output + + +def test_lazy_attribute_access_returns_real_object(): + """PEP 562 ``__getattr__`` must materialise real objects, not stubs.""" + import nui_shared_utils as nui + + assert "1.50" in nui.format_currency(1.5, "NZD") + assert nui.SlackBlockBuilder().__class__.__name__ == "SlackBlockBuilder" + + +def test_unknown_attribute_raises_attribute_error(): + import pytest + + import nui_shared_utils as nui + + with pytest.raises(AttributeError, match="this_does_not_exist"): + nui.this_does_not_exist # noqa: B018 + + +def test_lazy_attribute_is_cached_after_first_access(): + """Once resolved, the attribute should live in ``globals()`` for fast reuse.""" + output = _run(""" + import nui_shared_utils as nui + first = nui.format_currency + second = nui.format_currency + print(f"same={first is second}") + print(f"in_globals={'format_currency' in vars(nui)}") + """) + assert "same=True" in output + assert "in_globals=True" in output + + +def test_dir_includes_lazy_exports(): + import nui_shared_utils as nui + + names = dir(nui) + for expected in ("SlackClient", "format_currency", "with_retry", "get_secret", "slack_setup"): + assert expected in names, f"{expected!r} missing from dir(nui_shared_utils)" diff --git a/tests/test_utils.py b/tests/test_utils.py index da5911e..9105a72 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,7 +19,7 @@ validate_required_param, safe_close_connection, format_log_context, - DEFAULT_AWS_REGION + DEFAULT_AWS_REGION, ) @@ -81,7 +81,7 @@ def test_explicit_region_takes_precedence(self): result = resolve_aws_region("eu-west-1") assert result == "eu-west-1" - @patch('nui_shared_utils.utils.get_config') + @patch("nui_shared_utils.utils.get_config") def test_env_var_precedence(self, mock_get_config): """Test that AWS_REGION environment variable is used.""" mock_config = Mock() @@ -92,7 +92,7 @@ def test_env_var_precedence(self, mock_get_config): result = resolve_aws_region() assert result == "env-region" - @patch('nui_shared_utils.utils.get_config') + @patch("nui_shared_utils.utils.get_config") def test_aws_default_region_env_var(self, mock_get_config): """Test that AWS_DEFAULT_REGION environment variable is used.""" mock_config = Mock() @@ -103,7 +103,7 @@ def test_aws_default_region_env_var(self, mock_get_config): result = resolve_aws_region() assert result == "default-env-region" - @patch('nui_shared_utils.utils.get_config') + @patch("nui_shared_utils.utils.get_config") def test_config_region_fallback(self, mock_get_config): """Test that config aws_region is used when env vars not set.""" mock_config = Mock() @@ -114,8 +114,8 @@ def test_config_region_fallback(self, mock_get_config): result = resolve_aws_region() assert result == "config-region" - @patch('nui_shared_utils.utils.get_config') - @patch('boto3.session.Session') + @patch("nui_shared_utils.utils.get_config") + @patch("boto3.session.Session") def test_boto3_session_fallback(self, mock_session_class, mock_get_config): """Test that boto3 session region is used as fallback.""" mock_config = Mock() @@ -130,8 +130,8 @@ def test_boto3_session_fallback(self, mock_session_class, mock_get_config): result = resolve_aws_region() assert result == "session-region" - @patch('nui_shared_utils.utils.get_config') - @patch('boto3.session.Session') + @patch("nui_shared_utils.utils.get_config") + @patch("boto3.session.Session") def test_default_region_final_fallback(self, mock_session_class, mock_get_config): """Test that DEFAULT_AWS_REGION is used as final fallback.""" mock_config = Mock() @@ -146,26 +146,61 @@ def test_default_region_final_fallback(self, mock_session_class, mock_get_config result = resolve_aws_region() assert result == DEFAULT_AWS_REGION - @patch('nui_shared_utils.utils.get_config') - @patch('boto3.session.Session') - def test_boto3_session_exception_handling(self, mock_session_class, mock_get_config): - """Test that boto3 session exceptions are handled gracefully.""" + @patch("nui_shared_utils.utils.get_config") + @patch("boto3.session.Session") + def test_boto3_session_no_credentials_handled(self, mock_session_class, mock_get_config): + """NoCredentialsError during session-based region resolution falls back.""" + from botocore.exceptions import NoCredentialsError + mock_config = Mock() mock_config.aws_region = None mock_get_config.return_value = mock_config - mock_session_class.side_effect = Exception("Session error") + mock_session_class.side_effect = NoCredentialsError() with patch.dict(os.environ, {}, clear=True): result = resolve_aws_region() assert result == DEFAULT_AWS_REGION + @patch("nui_shared_utils.utils.get_config") + @patch("boto3.session.Session") + def test_unexpected_session_exception_propagates(self, mock_session_class, mock_get_config): + """Unknown exceptions are surfaced rather than masking real config issues.""" + mock_config = Mock() + mock_config.aws_region = None + mock_get_config.return_value = mock_config + + mock_session_class.side_effect = RuntimeError("unexpected") + + with patch.dict(os.environ, {}, clear=True): + try: + resolve_aws_region() + except RuntimeError as e: + assert "unexpected" in str(e) + else: + raise AssertionError("Expected RuntimeError to propagate") + + @patch("nui_shared_utils.utils.get_config") + @patch("boto3.session.Session") + def test_aws_region_fallback_env_var_overrides_default(self, mock_session_class, mock_get_config): + """AWS_REGION_FALLBACK overrides the package default when no other source applies.""" + mock_config = Mock() + mock_config.aws_region = None + mock_get_config.return_value = mock_config + + mock_session = Mock() + mock_session.region_name = None + mock_session_class.return_value = mock_session + + with patch.dict(os.environ, {"AWS_REGION_FALLBACK": "eu-west-1"}, clear=True): + assert resolve_aws_region() == "eu-west-1" + class TestCreateAwsClient: """Test create_aws_client function.""" - @patch('nui_shared_utils.utils.resolve_aws_region') - @patch('boto3.session.Session') + @patch("nui_shared_utils.utils.resolve_aws_region") + @patch("boto3.session.Session") def test_successful_client_creation(self, mock_session_class, mock_resolve_region): """Test successful AWS client creation.""" mock_resolve_region.return_value = "us-east-1" @@ -178,13 +213,10 @@ def test_successful_client_creation(self, mock_session_class, mock_resolve_regio assert result == mock_client mock_resolve_region.assert_called_once_with("us-west-2") - mock_session.client.assert_called_once_with( - service_name="secretsmanager", - region_name="us-east-1" - ) + mock_session.client.assert_called_once_with(service_name="secretsmanager", region_name="us-east-1") - @patch('nui_shared_utils.utils.resolve_aws_region') - @patch('boto3.session.Session') + @patch("nui_shared_utils.utils.resolve_aws_region") + @patch("boto3.session.Session") def test_no_credentials_error(self, mock_session_class, mock_resolve_region): """Test handling of NoCredentialsError.""" mock_resolve_region.return_value = "us-east-1" @@ -195,23 +227,22 @@ def test_no_credentials_error(self, mock_session_class, mock_resolve_region): with pytest.raises(NoCredentialsError): create_aws_client("secretsmanager") - @patch('nui_shared_utils.utils.resolve_aws_region') - @patch('boto3.session.Session') + @patch("nui_shared_utils.utils.resolve_aws_region") + @patch("boto3.session.Session") def test_client_error(self, mock_session_class, mock_resolve_region): """Test handling of ClientError.""" mock_resolve_region.return_value = "us-east-1" mock_session = Mock() mock_session.client.side_effect = ClientError( - {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, - "CreateClient" + {"Error": {"Code": "AccessDenied", "Message": "Access denied"}}, "CreateClient" ) mock_session_class.return_value = mock_session with pytest.raises(ClientError): create_aws_client("secretsmanager") - @patch('nui_shared_utils.utils.resolve_aws_region') - @patch('boto3.session.Session') + @patch("nui_shared_utils.utils.resolve_aws_region") + @patch("boto3.session.Session") def test_unexpected_error(self, mock_session_class, mock_resolve_region): """Test handling of unexpected errors.""" mock_resolve_region.return_value = "us-east-1" @@ -228,6 +259,7 @@ class TestHandleClientErrors: def test_successful_execution(self): """Test that decorator doesn't interfere with successful execution.""" + @handle_client_errors() def test_func(): return "success" @@ -237,6 +269,7 @@ def test_func(): def test_error_with_default_return(self, caplog): """Test that error returns default value and logs error.""" + @handle_client_errors(default_return="default_value") def test_func(): raise ValueError("Test error") @@ -250,6 +283,7 @@ def test_func(): def test_error_with_reraise(self): """Test that error is reraised when reraise=True.""" + @handle_client_errors(reraise=True) def test_func(): raise ValueError("Test error") @@ -259,10 +293,8 @@ def test_func(): def test_error_with_log_context(self, caplog): """Test that additional log context is included.""" - @handle_client_errors( - default_return=None, - log_context={"service": "test", "operation": "query"} - ) + + @handle_client_errors(default_return=None, log_context={"service": "test", "operation": "query"}) def test_func(): raise ConnectionError("Connection failed") @@ -278,6 +310,7 @@ def test_func(): def test_function_metadata_preserved(self): """Test that original function metadata is preserved.""" + @handle_client_errors() def test_func_with_metadata(): """Test function docstring.""" @@ -288,6 +321,7 @@ def test_func_with_metadata(): def test_function_arguments_passed_through(self): """Test that function arguments are passed through correctly.""" + @handle_client_errors() def test_func(arg1, arg2, kwarg1=None): return f"{arg1}-{arg2}-{kwarg1}" @@ -304,10 +338,7 @@ def test_base_dimensions_only(self): base = {"Service": "auth", "Environment": "prod"} result = merge_dimensions(base) - expected = [ - {"Name": "Service", "Value": "auth"}, - {"Name": "Environment", "Value": "prod"} - ] + expected = [{"Name": "Service", "Value": "auth"}, {"Name": "Environment", "Value": "prod"}] assert len(result) == 2 assert all(dim in result for dim in expected) @@ -320,7 +351,7 @@ def test_base_and_additional_dimensions(self): expected = [ {"Name": "Service", "Value": "auth"}, {"Name": "Version", "Value": "1.2.3"}, - {"Name": "Region", "Value": "us-east-1"} + {"Name": "Region", "Value": "us-east-1"}, ] assert len(result) == 3 assert all(dim in result for dim in expected) @@ -341,10 +372,7 @@ def test_numeric_values_converted_to_strings(self): base = {"Port": 8080, "Count": 42} result = merge_dimensions(base) - expected = [ - {"Name": "Port", "Value": "8080"}, - {"Name": "Count", "Value": "42"} - ] + expected = [{"Name": "Port", "Value": "8080"}, {"Name": "Count", "Value": "42"}] assert len(result) == 2 assert all(dim in result for dim in expected) @@ -474,68 +502,49 @@ def test_connection_without_close_method(self): class TestFormatLogContext: """Test format_log_context function.""" - @patch('time.time') + @patch("time.time") def test_basic_log_context(self, mock_time): """Test basic log context formatting.""" mock_time.return_value = 1234567890.0 result = format_log_context("database_query") - expected = { - "operation": "database_query", - "timestamp": 1234567890.0 - } + expected = {"operation": "database_query", "timestamp": 1234567890.0} assert result == expected - @patch('time.time') + @patch("time.time") def test_log_context_with_additional_data(self, mock_time): """Test log context with additional context data.""" mock_time.return_value = 1234567890.0 - result = format_log_context( - "database_query", - table="users", - query_type="SELECT", - duration_ms=150 - ) + result = format_log_context("database_query", table="users", query_type="SELECT", duration_ms=150) expected = { "operation": "database_query", "timestamp": 1234567890.0, "table": "users", "query_type": "SELECT", - "duration_ms": 150 + "duration_ms": 150, } assert result == expected - @patch('time.time') + @patch("time.time") def test_context_data_overwrites_defaults(self, mock_time): """Test that context data can overwrite default keys.""" mock_time.return_value = 1234567890.0 - result = format_log_context( - "test_operation", - timestamp=9999999999.0, - custom_field="custom_value" - ) + result = format_log_context("test_operation", timestamp=9999999999.0, custom_field="custom_value") # Additional context should supplement and override timestamp - expected = { - "operation": "test_operation", - "timestamp": 9999999999.0, - "custom_field": "custom_value" - } + expected = {"operation": "test_operation", "timestamp": 9999999999.0, "custom_field": "custom_value"} assert result == expected - @patch('time.time') + @patch("time.time") def test_empty_additional_context(self, mock_time): """Test with no additional context data.""" mock_time.return_value = 1234567890.0 result = format_log_context("simple_operation") - expected = { - "operation": "simple_operation", - "timestamp": 1234567890.0 - } - assert result == expected \ No newline at end of file + expected = {"operation": "simple_operation", "timestamp": 1234567890.0} + assert result == expected