Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class JobTerminationReason(str, Enum):
FAILED_TO_START_DUE_TO_NO_CAPACITY = "failed_to_start_due_to_no_capacity"
INTERRUPTED_BY_NO_CAPACITY = "interrupted_by_no_capacity"
INSTANCE_UNREACHABLE = "instance_unreachable"
INSTANCE_ACCESS_REVOKED = "instance_access_revoked"
WAITING_INSTANCE_LIMIT_EXCEEDED = "waiting_instance_limit_exceeded"
WAITING_RUNNER_LIMIT_EXCEEDED = "waiting_runner_limit_exceeded"
TERMINATED_BY_USER = "terminated_by_user"
Expand All @@ -158,6 +159,7 @@ def to_status(self) -> JobStatus:
self.FAILED_TO_START_DUE_TO_NO_CAPACITY: JobStatus.FAILED,
self.INTERRUPTED_BY_NO_CAPACITY: JobStatus.FAILED,
self.INSTANCE_UNREACHABLE: JobStatus.FAILED,
self.INSTANCE_ACCESS_REVOKED: JobStatus.FAILED,
self.WAITING_INSTANCE_LIMIT_EXCEEDED: JobStatus.FAILED,
self.WAITING_RUNNER_LIMIT_EXCEEDED: JobStatus.FAILED,
self.TERMINATED_BY_USER: JobStatus.TERMINATED,
Expand Down Expand Up @@ -196,6 +198,7 @@ def to_error(self) -> Optional[str]:
# handled and shown in status_message.
error_mapping = {
JobTerminationReason.INSTANCE_UNREACHABLE: "instance unreachable",
JobTerminationReason.INSTANCE_ACCESS_REVOKED: "instance access revoked",
JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED: "waiting instance limit exceeded",
JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED: "waiting runner limit exceeded",
JobTerminationReason.VOLUME_ERROR: "volume error",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Iterable, Literal, Optional, Sequence, Union

import httpx
from sqlalchemy import and_, func, or_, select, update
from sqlalchemy import and_, exists, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only

Expand Down Expand Up @@ -51,7 +51,9 @@
from dstack._internal.server.background.pipeline_tasks.common import get_provisioning_timeout
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
ExportedFleetModel,
FleetModel,
ImportModel,
InstanceModel,
JobModel,
ProbeModel,
Expand Down Expand Up @@ -309,6 +311,7 @@ class _ProcessContext:
job: Job
job_submission: JobSubmission
job_provisioning_data: Optional[JobProvisioningData]
instance_access_revoked: bool
server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None

@property
Expand Down Expand Up @@ -374,6 +377,7 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce
)
run = run_model_to_run(run_model, include_sensitive=True)
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
instance_access_revoked = await _is_instance_access_revoked(session, job_model)
job_submission = job_model_to_job_submission(job_model)
server_ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance))
return _ProcessContext(
Expand All @@ -383,12 +387,24 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce
job=job,
job_submission=job_submission,
job_provisioning_data=job_submission.job_provisioning_data,
instance_access_revoked=instance_access_revoked,
server_ssh_private_keys=server_ssh_private_keys,
)


async def _process_running_job(context: _ProcessContext) -> _ProcessResult:
result = _ProcessResult()
if context.instance_access_revoked:
_terminate_job(
job_model=context.job_model,
job_update_map=result.job_update_map,
termination_reason=JobTerminationReason.INSTANCE_ACCESS_REVOKED,
termination_reason_message=(
"The instance is no longer imported into the job's project"
),
)
return result

if context.job_provisioning_data is None:
logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model))
_terminate_job(
Expand Down Expand Up @@ -559,6 +575,22 @@ async def _fetch_run_model(
return res.unique().scalar_one()


async def _is_instance_access_revoked(session: AsyncSession, job_model: JobModel) -> bool:
if job_model.instance is None or job_model.instance.project_id == job_model.project_id:
return False
return not (
await session.execute(
select(
exists().where(
ImportModel.project_id == job_model.project_id,
ImportModel.export_id == ExportedFleetModel.export_id,
ExportedFleetModel.fleet_id == job_model.instance.fleet_id,
)
)
)
).scalar()


async def _process_provisioning_status(
context: _ProcessContext,
startup_context: _StartupContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,105 @@ async def test_registers_service_replica_in_gateway_when_running_on_imported_ins
ssh_head_proxy_private_key=None,
)

@pytest.mark.parametrize("job_status", [JobStatus.RUNNING, JobStatus.PULLING])
async def test_terminates_job_when_instance_access_revoked(
self,
test_db,
session: AsyncSession,
worker: JobRunningWorker,
job_status: JobStatus,
):
user = await create_user(session=session)
exporter_project = await create_project(session=session, name="exporter", owner=user)
importer_project = await create_project(session=session, name="importer", owner=user)
fleet = await create_fleet(session=session, project=exporter_project)
instance = await create_instance(
session=session,
project=exporter_project,
status=InstanceStatus.BUSY,
fleet=fleet,
)
repo = await create_repo(session=session, project_id=importer_project.id)
run = await create_run(
session=session,
project=importer_project,
repo=repo,
user=user,
)
job = await create_job(
session=session,
run=run,
status=job_status,
job_provisioning_data=get_job_provisioning_data(dockerized=True),
instance=instance,
instance_assigned=True,
)
# No export created -> the import link no longer exists -> access revoked

await _process_job(session, worker, job)

await session.refresh(job)
assert job.status == JobStatus.TERMINATING
assert job.termination_reason == JobTerminationReason.INSTANCE_ACCESS_REVOKED
events = await list_events(session)
assert len(events) == 1
assert events[0].message == (
f"Job status changed {job_status.upper()} -> TERMINATING."
" Termination reason: INSTANCE_ACCESS_REVOKED"
" (The instance is no longer imported into the job's project)"
)

@pytest.mark.parametrize("job_status", [JobStatus.RUNNING, JobStatus.PULLING])
async def test_does_not_terminate_job_when_instance_access_is_valid(
self,
test_db,
session: AsyncSession,
worker: JobRunningWorker,
ssh_tunnel_mock: Mock,
runner_client_mock: Mock,
job_status: JobStatus,
):
user = await create_user(session=session)
exporter_project = await create_project(session=session, name="exporter", owner=user)
importer_project = await create_project(session=session, name="importer", owner=user)
fleet = await create_fleet(session=session, project=exporter_project)
instance = await create_instance(
session=session,
project=exporter_project,
status=InstanceStatus.BUSY,
fleet=fleet,
)
await create_export(
session=session,
exporter_project=exporter_project,
importer_projects=[importer_project],
exported_fleets=[fleet],
)
repo = await create_repo(session=session, project_id=importer_project.id)
run = await create_run(
session=session,
project=importer_project,
repo=repo,
user=user,
)
job = await create_job(
session=session,
run=run,
status=job_status,
job_provisioning_data=get_job_provisioning_data(dockerized=False),
instance=instance,
instance_assigned=True,
)
runner_client_mock.pull.return_value = PullResponse(
job_states=[], job_logs=[], runner_logs=[], last_updated=0
)

await _process_job(session, worker, job)

await session.refresh(job)
assert job.status == job_status
assert job.termination_reason is None

async def test_apply_skips_probe_insert_when_lock_token_changes_after_processing(
self,
test_db,
Expand Down
Loading