diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index e81dbf704..8db39a5fd 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -396,6 +396,10 @@ async def set_gateway_wildcard_domain( old_domain = gateway.wildcard_domain if old_domain != wildcard_domain: gateway.wildcard_domain = wildcard_domain + if gateway.configuration is not None: + conf = get_gateway_configuration(gateway) + conf.domain = wildcard_domain + gateway.configuration = conf.json() events.emit( session, f"Gateway wildcard domain changed {old_domain!r} -> {gateway.wildcard_domain!r}", diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 2c0a66be5..f2e633532 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -44,7 +44,11 @@ SSHHostParams, SSHParams, ) -from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus +from dstack._internal.core.models.gateways import ( + GatewayComputeConfiguration, + GatewayConfiguration, + GatewayStatus, +) from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( Disk, @@ -641,13 +645,31 @@ async def create_gateway( status: Optional[GatewayStatus] = GatewayStatus.SUBMITTED, last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), forbid_new_services: bool = False, + populate_configuration: bool = True, ) -> GatewayModel: + """ + Args: + populate_configuration: whether to populate GatewayModel.configuration. + True - 0.18.2+ gateways, False - legacy pre-0.18.2 gateways. Prefer + testing against both in major test cases. + """ + configuration = None + if populate_configuration: + backend = await session.get(BackendModel, backend_id) + assert backend is not None + configuration = GatewayConfiguration( + name=name, + backend=backend.type, + region=region, + domain=wildcard_domain, + ).json() gateway = GatewayModel( project_id=project_id, backend_id=backend_id, name=name, region=region, wildcard_domain=wildcard_domain, + configuration=configuration, status=status, last_processed_at=last_processed_at, forbid_new_services=forbid_new_services, @@ -666,7 +688,30 @@ async def create_gateway_compute( instance_id: Optional[str] = "i-1234567890", ssh_private_key: str = "", ssh_public_key: str = "", + populate_configuration: bool = True, ) -> GatewayComputeModel: + """ + Args: + populate_configuration: whether to populate GatewayComputeModel.configuration. + True - 0.18.2+ gateways, False - legacy pre-0.18.2 gateways. Prefer + testing against both in major test cases. + """ + configuration = None + if populate_configuration: + backend_type = BackendType.AWS + if backend_id is not None: + backend = await session.get(BackendModel, backend_id) + assert backend is not None + backend_type = backend.type + configuration = GatewayComputeConfiguration( + project_name="test-project", + instance_name=instance_id or "test-instance", + backend=backend_type, + region=region, + public_ip=True, + ssh_key_pub=ssh_public_key, + certificate=None, + ).json() gateway_compute = GatewayComputeModel( gateway_id=gateway_id, backend_id=backend_id, @@ -675,6 +720,7 @@ async def create_gateway_compute( instance_id=instance_id, ssh_private_key=ssh_private_key, ssh_public_key=ssh_public_key, + configuration=configuration, ) session.add(gateway_compute) await session.commit() diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 075d1f6d4..13d1966c6 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -31,8 +31,14 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient): @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @pytest.mark.parametrize("legacy_compute", [False, True]) + @pytest.mark.parametrize("populate_configuration", [True, False]) async def test_list( - self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool + self, + test_db, + session: AsyncSession, + client: AsyncClient, + legacy_compute: bool, + populate_configuration: bool, ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) @@ -44,13 +50,21 @@ async def test_list( session=session, project_id=project.id, backend_id=backend.id, + populate_configuration=populate_configuration, ) if legacy_compute: - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + populate_configuration=populate_configuration, + ) gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style else: gateway_compute = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + populate_configuration=populate_configuration, ) await session.commit() response = await client.post( @@ -102,8 +116,14 @@ async def test_list( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @pytest.mark.parametrize("legacy_compute", [False, True]) + @pytest.mark.parametrize("populate_configuration", [True, False]) async def test_get( - self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool + self, + test_db, + session: AsyncSession, + client: AsyncClient, + legacy_compute: bool, + populate_configuration: bool, ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) @@ -115,13 +135,21 @@ async def test_get( session=session, project_id=project.id, backend_id=backend.id, + populate_configuration=populate_configuration, ) if legacy_compute: - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + populate_configuration=populate_configuration, + ) gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style else: gateway_compute = await create_gateway_compute( - session=session, backend_id=backend.id, gateway_id=gateway.id + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + populate_configuration=populate_configuration, ) await session.commit() response = await client.post( @@ -797,7 +825,10 @@ async def test_only_admin_can_set_default_gateway( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_set_default_gateway(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("populate_configuration", [True, False]) + async def test_set_default_gateway( + self, test_db, session: AsyncSession, client: AsyncClient, populate_configuration: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( @@ -809,11 +840,13 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, name="first_gateway", + populate_configuration=populate_configuration, ) gateway_compute = await create_gateway_compute( session=session, backend_id=backend.id, gateway_id=gateway.id, + populate_configuration=populate_configuration, ) response = await client.post( f"/api/project/{project.name}/gateways/set_default", @@ -875,11 +908,13 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, name="second_gateway", + populate_configuration=populate_configuration, ) await create_gateway_compute( session=session, backend_id=backend.id, gateway_id=second_gateway.id, + populate_configuration=populate_configuration, ) await clear_events(session) response = await client.post( @@ -1061,8 +1096,13 @@ async def test_only_admin_can_delete( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize("populate_configuration", [True, False]) async def test_marks_gateways_to_be_deleted( - self, test_db, session: AsyncSession, client: AsyncClient + self, + test_db, + session: AsyncSession, + client: AsyncClient, + populate_configuration: bool, ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) @@ -1076,22 +1116,26 @@ async def test_marks_gateways_to_be_deleted( project_id=project.id, backend_id=backend_aws.id, name="gateway-aws", + populate_configuration=populate_configuration, ) gateway_compute_aws = await create_gateway_compute( session=session, backend_id=backend_aws.id, gateway_id=gateway_aws.id, + populate_configuration=populate_configuration, ) gateway_gcp = await create_gateway( session=session, project_id=project.id, backend_id=backend_gcp.id, name="gateway-gcp", + populate_configuration=populate_configuration, ) gateway_compute_gcp = await create_gateway_compute( session=session, backend_id=backend_gcp.id, gateway_id=gateway_gcp.id, + populate_configuration=populate_configuration, ) response = await client.post( f"/api/project/{project.name}/gateways/delete", @@ -1183,7 +1227,10 @@ async def test_only_admin_can_set_wildcard_domain( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("populate_configuration", [True, False]) + async def test_set_wildcard_domain( + self, test_db, session: AsyncSession, client: AsyncClient, populate_configuration: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( @@ -1195,11 +1242,13 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: project_id=project.id, backend_id=backend.id, wildcard_domain="old.example", + populate_configuration=populate_configuration, ) gateway_compute = await create_gateway_compute( session=session, backend_id=backend.id, gateway_id=gateway.id, + populate_configuration=populate_configuration, ) response = await client.post( f"/api/project/{project.name}/gateways/set_wildcard_domain", diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 82b5c53f8..7fc21ac2f 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -3831,6 +3831,52 @@ async def test_submit_to_correct_proxy( events = await list_events(session) assert ("Service registered in gateway" in {e.message for e in events}) == is_gateway + @pytest.mark.asyncio + @pytest.mark.parametrize("populate_configuration", [True, False]) + async def test_submit_to_gateway_by_name( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + populate_configuration: bool, + ) -> None: + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user, name="test-project") + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + name="my-gateway", + wildcard_domain="my-gateway.example", + populate_configuration=populate_configuration, + ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + populate_configuration=populate_configuration, + ) + run_spec = get_service_run_spec( + repo_id=repo.name, + run_name="test-service", + gateway="my-gateway", + ) + response = await client.post( + f"/api/project/{project.name}/runs/submit", + headers=get_auth_headers(user.token), + json={"run_spec": run_spec}, + ) + assert response.status_code == 200 + assert response.json()["service"]["url"] == "https://test-service.my-gateway.example" + events = await list_events(session) + assert "Service registered in gateway" in {e.message for e in events} + @pytest.mark.asyncio async def test_return_error_if_specified_gateway_not_exists( self, test_db, session: AsyncSession, client: AsyncClient