diff --git a/CHANGELOG.md b/CHANGELOG.md index ac82f11..e69c020 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,23 @@ All notable changes to `atomicmemory` will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.1.2] - 2026-06-15 + +### Security +- `api_url` is now validated against SSRF across all six SDK configs (the three + provider configs, the storage/client configs, and `EntitiesClientConfig`) via + one shared validator. It always rejects link-local / cloud-metadata addresses + (AWS IMDS `169.254.169.254`, IPv6 `fe80::/10`) — including their decimal + (`http://2852039166/`), hex, octal, short-form, and IPv4-mapped-IPv6 + (`::ffff:169.254.169.254`) encodings, which are canonicalized so they cannot + bypass the guard. Loopback / private / reserved IP literals remain allowed by + default — the SDK routinely connects to local and self-hosted cores — and are + rejected only when you opt into strict mode with `allowPrivateNetworks=False`. + Hostnames (incl. the `localhost` default) are intentionally not DNS-resolved + at config time. This matches the Node SDK's posture for cross-SDK parity, and + a reflective enumeration test fails if a new `api_url` config omits the guard. + (FailSafe AGNT-PY-001.) + ## [1.1.1] - 2026-06-11 ### Added diff --git a/CLAUDE.md b/CLAUDE.md index f640cbf..01cf939 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -43,6 +43,7 @@ Before changing code, read the relevant local files first: - Snake_case for Python attributes; Pydantic `Field(alias="apiUrl")` aliases preserve TS camelCase wire format. - Keep public API behavior aligned with `atomicmemory-sdk` where both SDKs expose the same concept. - Prefer integration tests with a real HTTP path for client behavior; use mocks only for narrow transport errors. +- **Cross-cutting controls live at one chokepoint, enumerated and bypass-tested.** When a security/correctness rule must hold for *all* of a category (every config with an `api_url`, every input reaching a sink), apply it through one shared helper, not per-surface — and back it with a **reflective enumeration test** that fails when a new surface lacks it (e.g. `test_every_api_url_config_blocks_imds` discovers every `BaseModel` with an `api_url` field). Tests must exercise the **adversarial bypass** (the encoding, the key, the header), not just the canonical example, and validate against the **downstream consumer's interpretation** (the resolver, Postgres, the server), not your own parser. This is the gap that caused AGNT-PY-001's missed `EntitiesClientConfig` and numeric-IP bypass. ## Pre-commit verification diff --git a/atomicmemory/_version.py b/atomicmemory/_version.py index 76c70d7..d997da5 100644 --- a/atomicmemory/_version.py +++ b/atomicmemory/_version.py @@ -4,4 +4,4 @@ __version__: The current package version string (PEP 440). """ -__version__ = "1.1.1" +__version__ = "1.1.2" diff --git a/atomicmemory/client/atomic_memory_client.py b/atomicmemory/client/atomic_memory_client.py index 2b2eeab..c2a3824 100644 --- a/atomicmemory/client/atomic_memory_client.py +++ b/atomicmemory/client/atomic_memory_client.py @@ -10,7 +10,6 @@ from types import TracebackType from typing import Any -from urllib.parse import urlparse from pydantic import BaseModel, ConfigDict, Field, SecretStr, field_validator, model_validator from pydantic import ValidationError as PydanticValidationError @@ -18,6 +17,7 @@ from atomicmemory.client.async_memory_client import AsyncMemoryClient from atomicmemory.client.memory_client import MemoryClient, MemoryProviderConfigs from atomicmemory.core.errors import ConfigError +from atomicmemory.core.url import validate_api_url from atomicmemory.core.validation import sanitized_pydantic_errors from atomicmemory.entities import AsyncEntitiesClient, EntitiesClient from atomicmemory.entities.client import EntitiesClientConfig @@ -48,17 +48,11 @@ class AtomicMemoryClientConfig(BaseModel): api_key: SecretStr = Field(alias="apiKey") user_id: str = Field(alias="userId") timeout_seconds: float = Field(default=30.0, alias="timeoutSeconds") + allow_private_networks: bool = Field(default=True, alias="allowPrivateNetworks") + """Permit loopback/private/reserved IP literals in ``api_url`` (default True; + set False to harden). Link-local / cloud-metadata stay blocked regardless.""" memory: MemoryNamespaceConfig | None = None - @field_validator("api_url") - @classmethod - def _validate_api_url(cls, value: str) -> str: - stripped = value.strip() - parsed = urlparse(stripped) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError("api_url must be an http(s) URL") - return stripped - @field_validator("api_key", mode="before") @classmethod def _validate_api_key(cls, value: object) -> object: @@ -88,6 +82,7 @@ def _validate_timeout(cls, value: float) -> float: def _require_non_empty(self) -> AtomicMemoryClientConfig: if not self.api_url: raise ValueError("api_url is required") + self.api_url = validate_api_url(self.api_url, allow_private_networks=self.allow_private_networks) # api_key is always truthy as SecretStr; empty string rejected by _validate_api_key above. if not self.user_id: raise ValueError("user_id is required") diff --git a/atomicmemory/core/url.py b/atomicmemory/core/url.py new file mode 100644 index 0000000..09de256 --- /dev/null +++ b/atomicmemory/core/url.py @@ -0,0 +1,114 @@ +"""Shared ``api_url`` validation used by every SDK config boundary. + +Centralizes the rule that an ``api_url`` must be an http(s) URL with a +host, and adds SSRF defense: link-local / cloud-metadata addresses +(notably the ``169.254.169.254`` IMDS endpoint) are always rejected. +Loopback / private / reserved IP literals are *allowed by default* — the +SDK routinely connects to local and self-hosted cores — and only rejected +when the caller opts into strict mode via ``allow_private_networks=False``. +This mirrors the Node SDK's posture for cross-SDK parity. + +Hostnames are intentionally NOT resolved here. Config-time DNS resolution +would be slow, racy, and still bypassable via DNS rebinding, so a literal +hostname (including ``localhost`` and ``metadata.google.internal``) passes +the scheme/host checks. Deployments that must defend against +hostname-based metadata access should pin ``api_url`` to a vetted host. +""" + +from __future__ import annotations + +import ipaddress +import socket +from urllib.parse import urlparse + +_ALLOWED_SCHEMES = frozenset({"http", "https"}) + + +def _parse_ip(host: str) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None: + """Return the parsed IP when ``host`` is an IP literal, else ``None``. + + Covers canonical literals AND the legacy IPv4 encodings the C resolver + (``inet_aton``/``getaddrinfo``) still accepts — decimal (``2852039166``), + hex (``0xA9FEA9FE``), octal (``0251.0376.0251.0376``) and short forms + (``127.1``). Without this they slip through as un-resolved "hostnames" and + defeat the SSRF checks, since the HTTP client resolves them to the real + address (e.g. ``http://2852039166/`` → ``169.254.169.254``). + + Args: + host: The URL host component. + + Returns: + The parsed/canonicalized IP address, or ``None`` when ``host`` is a + genuine (non-numeric) hostname. + """ + try: + return _collapse_mapped(ipaddress.ip_address(host)) + except ValueError: + pass + try: + return ipaddress.IPv4Address(socket.inet_aton(host)) + except (OSError, ValueError): + return None + + +def _collapse_mapped( + ip: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> ipaddress.IPv4Address | ipaddress.IPv6Address: + """Reclassify an IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) as its IPv4. + + ``IPv6Address.is_link_local`` only delegates to the embedded IPv4 on + newer CPython, so on Python 3.10/3.11 ``::ffff:169.254.169.254`` would + otherwise read as a benign global IPv6 and bypass the metadata block. + Collapsing to the embedded IPv4 makes classification deterministic + across all supported interpreters and matches the Node SDK. + + Args: + ip: A parsed IP literal. + + Returns: + The embedded IPv4 when ``ip`` is IPv4-mapped, otherwise ``ip``. + """ + mapped = getattr(ip, "ipv4_mapped", None) + return mapped if mapped is not None else ip + + +def validate_api_url(value: str, *, allow_private_networks: bool = True) -> str: + """Validate and normalize an ``api_url``, guarding against SSRF. + + Args: + value: The candidate URL. + allow_private_networks: Defaults to ``True`` — loopback / private / + reserved IP literals are permitted because the SDK routinely + connects to local and self-hosted cores. Pass ``False`` to reject + those too (hardened multi-tenant deployments). Link-local / + cloud-metadata addresses are rejected regardless of this flag. + + Returns: + The whitespace-stripped URL. + + Raises: + ValueError: If the scheme is not http(s), the host is missing, or + the host is a disallowed IP literal. + """ + stripped = value.strip() + parsed = urlparse(stripped) + if parsed.scheme not in _ALLOWED_SCHEMES or not parsed.netloc: + raise ValueError("api_url must be an http(s) URL") + host = parsed.hostname + if not host: + raise ValueError("api_url must include a host") + + ip = _parse_ip(host) + if ip is None: + return stripped + + if ip.is_link_local: + raise ValueError("api_url must not target a link-local or cloud-metadata address") + if not allow_private_networks and ( + ip.is_loopback or ip.is_private or ip.is_reserved or ip.is_multicast or ip.is_unspecified + ): + raise ValueError( + "api_url must not target a loopback, private, or reserved address; " + "set allow_private_networks=True to permit it" + ) + return stripped diff --git a/atomicmemory/entities/client.py b/atomicmemory/entities/client.py index b18a278..f89fb26 100644 --- a/atomicmemory/entities/client.py +++ b/atomicmemory/entities/client.py @@ -25,12 +25,13 @@ import json from types import TracebackType from typing import Any, TypeVar, cast -from urllib.parse import quote, urlencode, urlparse +from urllib.parse import quote, urlencode import httpx -from pydantic import BaseModel, ConfigDict, Field, SecretStr, field_validator +from pydantic import BaseModel, ConfigDict, Field, SecretStr, field_validator, model_validator from pydantic import ValidationError as PydanticValidationError +from atomicmemory.core.url import validate_api_url from atomicmemory.entities.errors import EntitiesClientError from atomicmemory.entities.types import ( DeleteEntityResult, @@ -63,15 +64,14 @@ class EntitiesClientConfig(BaseModel): api_url: str = Field(alias="apiUrl") api_key: SecretStr = Field(alias="apiKey") timeout_seconds: float = Field(default=30.0, alias="timeoutSeconds") + allow_private_networks: bool = Field(default=True, alias="allowPrivateNetworks") + """Permit loopback/private/reserved IP literals in ``api_url`` (default True; + set False to harden). Link-local / cloud-metadata stay blocked regardless.""" - @field_validator("api_url") - @classmethod - def _validate_api_url(cls, value: str) -> str: - stripped = value.strip() - parsed = urlparse(stripped) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError("api_url must be an http(s) URL") - return stripped + @model_validator(mode="after") + def _validate_api_url(self) -> EntitiesClientConfig: + self.api_url = validate_api_url(self.api_url, allow_private_networks=self.allow_private_networks) + return self @field_validator("api_key", mode="before") @classmethod diff --git a/atomicmemory/providers/atomicmemory/config.py b/atomicmemory/providers/atomicmemory/config.py index 7d87669..f1b4215 100644 --- a/atomicmemory/providers/atomicmemory/config.py +++ b/atomicmemory/providers/atomicmemory/config.py @@ -5,8 +5,9 @@ from __future__ import annotations -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator +from atomicmemory.core.url import validate_api_url from atomicmemory.memory.meta_fact_filter import MetaFactFilterConfig ATOMICMEMORY_DEFAULT_TIMEOUT_SECONDS: float = 30.0 @@ -41,3 +42,12 @@ class AtomicMemoryProviderConfig(BaseModel): meta_fact_filter: MetaFactFilterConfig | None = Field(default=None, alias="metaFactFilter") """Optional opt-in post-retrieval meta-fact filter. Off when unset.""" + + allow_private_networks: bool = Field(default=True, alias="allowPrivateNetworks") + """Permit loopback/private/reserved IP literals in ``api_url`` (default True; + set False to harden). Link-local / cloud-metadata stay blocked regardless.""" + + @model_validator(mode="after") + def _validate_api_url(self) -> AtomicMemoryProviderConfig: + self.api_url = validate_api_url(self.api_url, allow_private_networks=self.allow_private_networks) + return self diff --git a/atomicmemory/providers/hindsight/config.py b/atomicmemory/providers/hindsight/config.py index fba8562..783387b 100644 --- a/atomicmemory/providers/hindsight/config.py +++ b/atomicmemory/providers/hindsight/config.py @@ -9,8 +9,9 @@ from collections.abc import Awaitable, Callable from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator +from atomicmemory.core.url import validate_api_url from atomicmemory.memory.types import IngestInput, Scope HindsightRecallBudget = Literal["low", "mid", "high"] @@ -38,6 +39,14 @@ class HindsightProviderConfig(BaseModel): project_id: str = Field(default=HINDSIGHT_DEFAULT_PROJECT_ID, alias="projectId") default_budget: HindsightRecallBudget | None = Field(default=None, alias="defaultBudget") default_max_tokens: int | None = Field(default=None, alias="defaultMaxTokens") + allow_private_networks: bool = Field(default=True, alias="allowPrivateNetworks") + """Permit loopback/private/reserved IP literals in ``api_url`` (default True; + set False to harden). Link-local / cloud-metadata stay blocked regardless.""" + + @model_validator(mode="after") + def _validate_api_url(self) -> HindsightProviderConfig: + self.api_url = validate_api_url(self.api_url, allow_private_networks=self.allow_private_networks) + return self class HindsightRetainResponse(BaseModel): diff --git a/atomicmemory/providers/mem0/config.py b/atomicmemory/providers/mem0/config.py index 84c81b8..296a8f0 100644 --- a/atomicmemory/providers/mem0/config.py +++ b/atomicmemory/providers/mem0/config.py @@ -5,7 +5,9 @@ from __future__ import annotations -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from atomicmemory.core.url import validate_api_url MEM0_DEFAULT_TIMEOUT_SECONDS: float = 30.0 MEM0_DEFAULT_PATH_PREFIX: str = "/v1" @@ -49,3 +51,12 @@ class Mem0ProviderConfig(BaseModel): org_id: str | None = Field(default=None, alias="orgId") project_id: str | None = Field(default=None, alias="projectId") + + allow_private_networks: bool = Field(default=True, alias="allowPrivateNetworks") + """Permit loopback/private/reserved IP literals in ``api_url`` (default True; + set False to harden). Link-local / cloud-metadata stay blocked regardless.""" + + @model_validator(mode="after") + def _validate_api_url(self) -> Mem0ProviderConfig: + self.api_url = validate_api_url(self.api_url, allow_private_networks=self.allow_private_networks) + return self diff --git a/atomicmemory/storage/types.py b/atomicmemory/storage/types.py index dfff900..e53c65f 100644 --- a/atomicmemory/storage/types.py +++ b/atomicmemory/storage/types.py @@ -8,10 +8,11 @@ from __future__ import annotations from typing import Any, Literal -from urllib.parse import urlparse from pydantic import BaseModel, ConfigDict, Field, SecretStr, field_validator, model_validator +from atomicmemory.core.url import validate_api_url + StorageArtifactStatus = Literal[ "stored", "pending", @@ -42,15 +43,9 @@ class StorageClientConfig(BaseModel): api_key: SecretStr = Field(alias="apiKey") user_id: str = Field(alias="userId") timeout_seconds: float = Field(default=30.0, alias="timeoutSeconds") - - @field_validator("api_url") - @classmethod - def _validate_api_url(cls, value: str) -> str: - stripped = value.strip() - parsed = urlparse(stripped) - if parsed.scheme not in {"http", "https"} or not parsed.netloc: - raise ValueError("api_url must be an http(s) URL") - return stripped + allow_private_networks: bool = Field(default=True, alias="allowPrivateNetworks") + """Permit loopback/private/reserved IP literals in ``api_url`` (default True; + set False to harden). Link-local / cloud-metadata stay blocked regardless.""" @field_validator("api_key", mode="before") @classmethod @@ -81,6 +76,7 @@ def _validate_timeout(cls, value: float) -> float: def _require_non_empty(self) -> StorageClientConfig: if not self.api_url: raise ValueError("api_url is required") + self.api_url = validate_api_url(self.api_url, allow_private_networks=self.allow_private_networks) # api_key is always truthy as SecretStr; empty string rejected by _validate_api_key above. if not self.user_id: raise ValueError("user_id is required") diff --git a/pyproject.toml b/pyproject.toml index d6ca4d6..66a507f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "atomicmemory" -version = "1.1.1" +version = "1.1.2" description = "Python client SDK for AtomicMemory memory and artifact storage." readme = "README.md" requires-python = ">=3.10" diff --git a/tests/core/test_url.py b/tests/core/test_url.py new file mode 100644 index 0000000..ab8860e --- /dev/null +++ b/tests/core/test_url.py @@ -0,0 +1,119 @@ +"""Tests for the shared ``api_url`` SSRF guard. + +Verifies that :func:`atomicmemory.core.url.validate_api_url` enforces an +http(s) scheme, always rejects link-local / cloud-metadata addresses +(e.g. the AWS IMDS endpoint) even when private networks are permitted, +and gates loopback / private / reserved IPs behind the explicit +``allow_private_networks`` opt-in. Hostnames are intentionally not +DNS-resolved at config time. +""" + +from __future__ import annotations + +import pytest + +from atomicmemory.core.url import validate_api_url + + +def test_allows_public_http_and_https() -> None: + assert validate_api_url("https://api.example.com") == "https://api.example.com" + assert validate_api_url("http://core.test:17350") == "http://core.test:17350" + + +def test_allows_hostnames_without_dns_resolution() -> None: + # Literal hostnames (incl. localhost) are not resolved at config time. + assert validate_api_url("http://localhost:17350") == "http://localhost:17350" + + +@pytest.mark.parametrize("bad", ["not-a-url", "ftp://host/x", "file:///etc/passwd", "://no-scheme"]) +def test_rejects_non_http_scheme_or_missing_host(bad: str) -> None: + with pytest.raises(ValueError): + validate_api_url(bad) + + +@pytest.mark.parametrize( + "metadata_url", + [ + "http://169.254.169.254/latest/meta-data/", + "http://[fe80::1]/x", + ], +) +def test_always_rejects_link_local_even_when_private_allowed(metadata_url: str) -> None: + with pytest.raises(ValueError): + validate_api_url(metadata_url, allow_private_networks=True) + + +@pytest.mark.parametrize( + "private_url", + [ + "http://127.0.0.1:17350", + "http://10.0.0.5/api", + "http://192.168.1.10:8080", + "http://172.16.0.1/x", + "http://[::1]:17350", + ], +) +def test_allows_private_and_loopback_ips_by_default(private_url: str) -> None: + # Posture B: the SDK connects to local/self-hosted cores, so these pass. + assert validate_api_url(private_url) == private_url + + +@pytest.mark.parametrize( + "private_url", + [ + "http://127.0.0.1:17350", + "http://10.0.0.5/api", + "http://172.16.0.1/x", + "http://[::1]:17350", + ], +) +def test_rejects_private_and_loopback_ips_when_strict(private_url: str) -> None: + with pytest.raises(ValueError): + validate_api_url(private_url, allow_private_networks=False) + + +@pytest.mark.parametrize( + "mapped_imds", + ["http://[::ffff:169.254.169.254]/", "http://[::ffff:a9fe:a9fe]/"], +) +def test_rejects_ipv4_mapped_ipv6_metadata(mapped_imds: str) -> None: + # ::ffff:169.254.169.254 collapses to the embedded IPv4 and stays blocked + # even with private networks allowed (cross-Python determinism). + with pytest.raises(ValueError): + validate_api_url(mapped_imds, allow_private_networks=True) + + +def test_strips_surrounding_whitespace() -> None: + assert validate_api_url(" https://api.example.com ") == "https://api.example.com" + + +@pytest.mark.parametrize( + "encoded_imds", + [ + "http://2852039166/latest/meta-data/", # decimal 169.254.169.254 + "http://0xA9FEA9FE/", # hex 169.254.169.254 + "http://0251.0376.0251.0376/", # dotted-octal 169.254.169.254 + ], +) +def test_rejects_numeric_encoded_metadata_address(encoded_imds: str) -> None: + # Legacy IPv4 encodings the resolver still accepts must not bypass the + # always-on link-local/metadata block (even with private networks allowed). + with pytest.raises(ValueError): + validate_api_url(encoded_imds, allow_private_networks=True) + + +@pytest.mark.parametrize( + "encoded_private", + [ + "http://2130706433:17350/", # decimal 127.0.0.1 + "http://0x7f000001/", # hex loopback + "http://127.1/", # short-form loopback + "http://0/", # 0.0.0.0 (unspecified) + ], +) +def test_rejects_numeric_encoded_private_when_strict(encoded_private: str) -> None: + # Posture B allows these by default (loopback/unspecified are local cores); + # strict mode must still reject them, encoding notwithstanding. + assert validate_api_url(encoded_private) == encoded_private + with pytest.raises(ValueError): + validate_api_url(encoded_private, allow_private_networks=False) diff --git a/tests/providers/test_config_ssrf.py b/tests/providers/test_config_ssrf.py new file mode 100644 index 0000000..76875fa --- /dev/null +++ b/tests/providers/test_config_ssrf.py @@ -0,0 +1,139 @@ +"""SSRF-guard coverage for every SDK config that accepts ``api_url``. + +Each provider/client config must always reject the AWS IMDS link-local +endpoint (and its encodings), while allowing private/loopback IP literals +by default (posture B — local/self-hosted cores) and rejecting them only +when ``allow_private_networks=False``. This pins the consistency the +FailSafe report (AGNT-PY-001) found missing: previously only the +storage/client configs validated the URL while the provider configs +accepted any string. +""" + +from __future__ import annotations + +import importlib +import pkgutil + +import pytest +from pydantic import BaseModel, ValidationError + +import atomicmemory +from atomicmemory.client.atomic_memory_client import AtomicMemoryClientConfig +from atomicmemory.providers.atomicmemory.config import AtomicMemoryProviderConfig +from atomicmemory.providers.hindsight.config import HindsightProviderConfig +from atomicmemory.providers.mem0.config import Mem0ProviderConfig +from atomicmemory.storage.types import StorageClientConfig + +_IMDS = "http://169.254.169.254/latest/meta-data/" +_LOOPBACK = "http://127.0.0.1:17350" + + +def _client_kwargs(api_url: str, **extra: object) -> dict[str, object]: + return {"apiUrl": api_url, "apiKey": "secret", "userId": "u1", **extra} + + +def test_provider_configs_reject_imds_endpoint() -> None: + for factory in ( + lambda u: AtomicMemoryProviderConfig(apiUrl=u), + lambda u: HindsightProviderConfig(apiUrl=u), + lambda u: Mem0ProviderConfig(apiUrl=u), + ): + with pytest.raises(ValidationError): + factory(_IMDS) + + +def test_client_configs_reject_imds_endpoint() -> None: + with pytest.raises(ValidationError): + StorageClientConfig(**_client_kwargs(_IMDS)) + with pytest.raises(ValidationError): + AtomicMemoryClientConfig(**_client_kwargs(_IMDS)) + + +def test_provider_config_allows_loopback_ip_by_default() -> None: + # Posture B: providers connect to local/self-hosted cores by default. + cfg = AtomicMemoryProviderConfig(apiUrl=_LOOPBACK) + assert cfg.api_url == _LOOPBACK + + +def test_provider_config_rejects_loopback_ip_when_strict() -> None: + with pytest.raises(ValidationError): + AtomicMemoryProviderConfig(apiUrl=_LOOPBACK, allowPrivateNetworks=False) + + +def test_imds_rejected_even_with_private_networks_allowed() -> None: + with pytest.raises(ValidationError): + AtomicMemoryProviderConfig(apiUrl=_IMDS, allowPrivateNetworks=True) + + +def test_localhost_hostname_still_allowed_when_strict() -> None: + # Hostnames are not DNS-resolved, so localhost passes even in strict mode. + cfg = Mem0ProviderConfig(apiUrl="http://localhost:8888", allowPrivateNetworks=False) + assert cfg.api_url == "http://localhost:8888" + + +def test_config_rejects_decimal_encoded_imds() -> None: + # The numeric-encoding bypass must be closed end-to-end through a config, + # not just in the standalone validator: http://2852039166/ == IMDS. + with pytest.raises(ValidationError): + AtomicMemoryProviderConfig(apiUrl="http://2852039166/latest/meta-data/", allowPrivateNetworks=True) + + +def test_entities_config_rejects_imds_literal_and_encoded() -> None: + # EntitiesClientConfig (shared by sync + async entities clients) is the 6th + # api_url config and must enforce the same SSRF guard. + from atomicmemory.entities.client import EntitiesClientConfig + + for url in (_IMDS, "http://2852039166/latest/meta-data/"): + with pytest.raises(ValidationError): + EntitiesClientConfig(apiUrl=url, apiKey="secret") + + +def test_entities_config_allows_loopback_by_default_blocks_when_strict() -> None: + from atomicmemory.entities.client import EntitiesClientConfig + + ok = EntitiesClientConfig(apiUrl=_LOOPBACK, apiKey="secret") + assert ok.api_url == _LOOPBACK + with pytest.raises(ValidationError): + EntitiesClientConfig(apiUrl=_LOOPBACK, apiKey="secret", allowPrivateNetworks=False) + host = EntitiesClientConfig(apiUrl="http://localhost:8888", apiKey="secret") + assert host.api_url == "http://localhost:8888" + + +def _discover_api_url_configs() -> list[type[BaseModel]]: + """Every Pydantic config in the package that exposes an ``api_url`` field. + + Imports the whole ``atomicmemory`` package so a newly added config is + discovered automatically — this is the guard that fails when a future + config forgets the shared SSRF validator (the exact gap AGNT-PY-001 and + its EntitiesClientConfig follow-up were). + """ + for mod in pkgutil.walk_packages(atomicmemory.__path__, "atomicmemory."): + importlib.import_module(mod.name) + found: dict[type[BaseModel], None] = {} + + def walk(cls: type[BaseModel]) -> None: + for sub in cls.__subclasses__(): + if "api_url" in sub.model_fields: + found[sub] = None + walk(sub) + + walk(BaseModel) + return list(found) + + +def _dummy_required_kwargs(model: type[BaseModel]) -> dict[str, object]: + """Minimal valid kwargs (by alias) for every required field except api_url.""" + kwargs: dict[str, object] = {} + for name, field in model.model_fields.items(): + if name == "api_url" or not field.is_required(): + continue + kwargs[field.alias or name] = "x" + return kwargs + + +def test_every_api_url_config_blocks_imds() -> None: + configs = _discover_api_url_configs() + assert len(configs) >= 6, f"expected to discover >= 6 api_url configs, found {len(configs)}: {configs}" + for model in configs: + with pytest.raises(ValidationError): + model(apiUrl=_IMDS, **_dummy_required_kwargs(model)) diff --git a/uv.lock b/uv.lock index aecd770..6fd0559 100644 --- a/uv.lock +++ b/uv.lock @@ -92,7 +92,7 @@ wheels = [ [[package]] name = "atomicmemory" -version = "1.1.1" +version = "1.1.2" source = { editable = "." } dependencies = [ { name = "httpx" },