Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ async def set_gateway_wildcard_domain(
old_domain = gateway.wildcard_domain
if old_domain != wildcard_domain:
gateway.wildcard_domain = wildcard_domain
if gateway.configuration is not None:
conf = get_gateway_configuration(gateway)
conf.domain = wildcard_domain
gateway.configuration = conf.json()
events.emit(
session,
f"Gateway wildcard domain changed {old_domain!r} -> {gateway.wildcard_domain!r}",
Expand Down
48 changes: 47 additions & 1 deletion src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
SSHHostParams,
SSHParams,
)
from dstack._internal.core.models.gateways import GatewayComputeConfiguration, GatewayStatus
from dstack._internal.core.models.gateways import (
GatewayComputeConfiguration,
GatewayConfiguration,
GatewayStatus,
)
from dstack._internal.core.models.health import HealthStatus
from dstack._internal.core.models.instances import (
Disk,
Expand Down Expand Up @@ -641,13 +645,31 @@ async def create_gateway(
status: Optional[GatewayStatus] = GatewayStatus.SUBMITTED,
last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
forbid_new_services: bool = False,
populate_configuration: bool = True,
) -> GatewayModel:
"""
Args:
populate_configuration: whether to populate GatewayModel.configuration.
True - 0.18.2+ gateways, False - legacy pre-0.18.2 gateways. Prefer
testing against both in major test cases.
"""
configuration = None
if populate_configuration:
backend = await session.get(BackendModel, backend_id)
assert backend is not None
configuration = GatewayConfiguration(
name=name,
backend=backend.type,
region=region,
domain=wildcard_domain,
).json()
gateway = GatewayModel(
project_id=project_id,
backend_id=backend_id,
name=name,
region=region,
wildcard_domain=wildcard_domain,
configuration=configuration,
status=status,
last_processed_at=last_processed_at,
forbid_new_services=forbid_new_services,
Expand All @@ -666,7 +688,30 @@ async def create_gateway_compute(
instance_id: Optional[str] = "i-1234567890",
ssh_private_key: str = "",
ssh_public_key: str = "",
populate_configuration: bool = True,
) -> GatewayComputeModel:
"""
Args:
populate_configuration: whether to populate GatewayComputeModel.configuration.
True - 0.18.2+ gateways, False - legacy pre-0.18.2 gateways. Prefer
testing against both in major test cases.
"""
configuration = None
if populate_configuration:
backend_type = BackendType.AWS
if backend_id is not None:
backend = await session.get(BackendModel, backend_id)
assert backend is not None
backend_type = backend.type
configuration = GatewayComputeConfiguration(
project_name="test-project",
instance_name=instance_id or "test-instance",
backend=backend_type,
region=region,
public_ip=True,
ssh_key_pub=ssh_public_key,
certificate=None,
).json()
gateway_compute = GatewayComputeModel(
gateway_id=gateway_id,
backend_id=backend_id,
Expand All @@ -675,6 +720,7 @@ async def create_gateway_compute(
instance_id=instance_id,
ssh_private_key=ssh_private_key,
ssh_public_key=ssh_public_key,
configuration=configuration,
)
session.add(gateway_compute)
await session.commit()
Expand Down
67 changes: 58 additions & 9 deletions src/tests/_internal/server/routers/test_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.parametrize("legacy_compute", [False, True])
@pytest.mark.parametrize("populate_configuration", [True, False])
async def test_list(
self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool
self,
test_db,
session: AsyncSession,
client: AsyncClient,
legacy_compute: bool,
populate_configuration: bool,
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
Expand All @@ -44,13 +50,21 @@ async def test_list(
session=session,
project_id=project.id,
backend_id=backend.id,
populate_configuration=populate_configuration,
)
if legacy_compute:
gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway_compute = await create_gateway_compute(
session=session,
backend_id=backend.id,
populate_configuration=populate_configuration,
)
gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
else:
gateway_compute = await create_gateway_compute(
session=session, backend_id=backend.id, gateway_id=gateway.id
session=session,
backend_id=backend.id,
gateway_id=gateway.id,
populate_configuration=populate_configuration,
)
await session.commit()
response = await client.post(
Expand Down Expand Up @@ -102,8 +116,14 @@ async def test_list(
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.parametrize("legacy_compute", [False, True])
@pytest.mark.parametrize("populate_configuration", [True, False])
async def test_get(
self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool
self,
test_db,
session: AsyncSession,
client: AsyncClient,
legacy_compute: bool,
populate_configuration: bool,
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
Expand All @@ -115,13 +135,21 @@ async def test_get(
session=session,
project_id=project.id,
backend_id=backend.id,
populate_configuration=populate_configuration,
)
if legacy_compute:
gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway_compute = await create_gateway_compute(
session=session,
backend_id=backend.id,
populate_configuration=populate_configuration,
)
gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
else:
gateway_compute = await create_gateway_compute(
session=session, backend_id=backend.id, gateway_id=gateway.id
session=session,
backend_id=backend.id,
gateway_id=gateway.id,
populate_configuration=populate_configuration,
)
await session.commit()
response = await client.post(
Expand Down Expand Up @@ -797,7 +825,10 @@ async def test_only_admin_can_set_default_gateway(

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_set_default_gateway(self, test_db, session: AsyncSession, client: AsyncClient):
@pytest.mark.parametrize("populate_configuration", [True, False])
async def test_set_default_gateway(
self, test_db, session: AsyncSession, client: AsyncClient, populate_configuration: bool
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
Expand All @@ -809,11 +840,13 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client:
project_id=project.id,
backend_id=backend.id,
name="first_gateway",
populate_configuration=populate_configuration,
)
gateway_compute = await create_gateway_compute(
session=session,
backend_id=backend.id,
gateway_id=gateway.id,
populate_configuration=populate_configuration,
)
response = await client.post(
f"/api/project/{project.name}/gateways/set_default",
Expand Down Expand Up @@ -875,11 +908,13 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client:
project_id=project.id,
backend_id=backend.id,
name="second_gateway",
populate_configuration=populate_configuration,
)
await create_gateway_compute(
session=session,
backend_id=backend.id,
gateway_id=second_gateway.id,
populate_configuration=populate_configuration,
)
await clear_events(session)
response = await client.post(
Expand Down Expand Up @@ -1061,8 +1096,13 @@ async def test_only_admin_can_delete(

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.parametrize("populate_configuration", [True, False])
async def test_marks_gateways_to_be_deleted(
self, test_db, session: AsyncSession, client: AsyncClient
self,
test_db,
session: AsyncSession,
client: AsyncClient,
populate_configuration: bool,
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
Expand All @@ -1076,22 +1116,26 @@ async def test_marks_gateways_to_be_deleted(
project_id=project.id,
backend_id=backend_aws.id,
name="gateway-aws",
populate_configuration=populate_configuration,
)
gateway_compute_aws = await create_gateway_compute(
session=session,
backend_id=backend_aws.id,
gateway_id=gateway_aws.id,
populate_configuration=populate_configuration,
)
gateway_gcp = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend_gcp.id,
name="gateway-gcp",
populate_configuration=populate_configuration,
)
gateway_compute_gcp = await create_gateway_compute(
session=session,
backend_id=backend_gcp.id,
gateway_id=gateway_gcp.id,
populate_configuration=populate_configuration,
)
response = await client.post(
f"/api/project/{project.name}/gateways/delete",
Expand Down Expand Up @@ -1183,7 +1227,10 @@ async def test_only_admin_can_set_wildcard_domain(

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: AsyncClient):
@pytest.mark.parametrize("populate_configuration", [True, False])
async def test_set_wildcard_domain(
self, test_db, session: AsyncSession, client: AsyncClient, populate_configuration: bool
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
Expand All @@ -1195,11 +1242,13 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client:
project_id=project.id,
backend_id=backend.id,
wildcard_domain="old.example",
populate_configuration=populate_configuration,
)
gateway_compute = await create_gateway_compute(
session=session,
backend_id=backend.id,
gateway_id=gateway.id,
populate_configuration=populate_configuration,
)
response = await client.post(
f"/api/project/{project.name}/gateways/set_wildcard_domain",
Expand Down
46 changes: 46 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3831,6 +3831,52 @@ async def test_submit_to_correct_proxy(
events = await list_events(session)
assert ("Service registered in gateway" in {e.message for e in events}) == is_gateway

@pytest.mark.asyncio
@pytest.mark.parametrize("populate_configuration", [True, False])
async def test_submit_to_gateway_by_name(
self,
test_db,
session: AsyncSession,
client: AsyncClient,
populate_configuration: bool,
) -> None:
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user, name="test-project")
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
repo = await create_repo(session=session, project_id=project.id)
backend = await create_backend(session=session, project_id=project.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
status=GatewayStatus.RUNNING,
name="my-gateway",
wildcard_domain="my-gateway.example",
populate_configuration=populate_configuration,
)
await create_gateway_compute(
session=session,
backend_id=backend.id,
gateway_id=gateway.id,
populate_configuration=populate_configuration,
)
run_spec = get_service_run_spec(
repo_id=repo.name,
run_name="test-service",
gateway="my-gateway",
)
response = await client.post(
f"/api/project/{project.name}/runs/submit",
headers=get_auth_headers(user.token),
json={"run_spec": run_spec},
)
assert response.status_code == 200
assert response.json()["service"]["url"] == "https://test-service.my-gateway.example"
events = await list_events(session)
assert "Service registered in gateway" in {e.message for e in events}

@pytest.mark.asyncio
async def test_return_error_if_specified_gateway_not_exists(
self, test_db, session: AsyncSession, client: AsyncClient
Expand Down
Loading