Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
"""Main API application"""

import logging
from fastapi import FastAPI
from fastapi import FastAPI, Request
from opentelemetry import trace
from starlette.middleware.base import BaseHTTPMiddleware
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, BatchSpanProcessor, SimpleSpanProcessor
Expand All @@ -20,6 +21,7 @@
from app.routers.task import task

from . import config
from .request_context import set_api_url_base, _api_url_base

logging.basicConfig(
level=logging.INFO,
Expand Down Expand Up @@ -48,6 +50,19 @@

APP = FastAPI(servers=[{"url": config.API_URL_ROOT}], **config.API_CONFIG)


class _ExternalRequestContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
token = _api_url_base.set(None)
try:
set_api_url_base(request)
return await call_next(request)
finally:
_api_url_base.reset(token)


APP.add_middleware(_ExternalRequestContextMiddleware)

if config.OPENTELEMETRY_ENABLED:
FastAPIInstrumentor.instrument_app(APP)

Expand Down
30 changes: 30 additions & 0 deletions app/request_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Per-request URL context derived from forwarding headers. (e.g. for Kong or other API gateways)"""
from contextvars import ContextVar

from fastapi import Request

from . import config

_api_url_base: ContextVar[str | None] = ContextVar("_api_url_base", default=None)


def set_api_url_base(request: Request) -> None:
"""Set the per-request API URL base from forwarding headers."""
host = (request.headers.get("x-forwarded-host") or
request.headers.get("host", "")).split(",")[0].strip()
proto = (request.headers.get("x-forwarded-proto") or
request.url.scheme).split(",")[0].strip()
prefix = (request.headers.get("x-forwarded-prefix")
or request.headers.get("x-script-name")
or "").rstrip("/")
api_url = config.API_URL.strip("/")
if host:
_api_url_base.set(f"{proto}://{host}{prefix}/{api_url}")


def get_url_prefix() -> str:
"""Return the per-request API URL base, or fall back to static config."""
value = _api_url_base.get()
if value:
return value
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}"
10 changes: 5 additions & 5 deletions app/routers/account/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime
from pydantic import Field, computed_field, field_validator

from ... import config
from ...request_context import get_url_prefix
from ...types.base import IRIBaseModel
from ...types.scalars import AllocationUnit

Expand All @@ -26,7 +26,7 @@ def _norm_dt_field(cls, v):
@property
def self_uri(self) -> str:
"""Return the URI for this project resource."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/account/projects/{self.id}"
return f"{get_url_prefix()}/account/projects/{self.id}"


class AllocationEntry(IRIBaseModel):
Expand Down Expand Up @@ -54,13 +54,13 @@ class ProjectAllocation(IRIBaseModel):
@property
def project_uri(self) -> str:
"""Return the URI for the associated project resource."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/account/projects/{self.project_id}"
return f"{get_url_prefix()}/account/projects/{self.project_id}"

@computed_field(description="URI of the associated capability resource")
@property
def capability_uri(self) -> str:
"""Return the URI for the associated capability."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/account/capabilities/{self.capability_id}"
return f"{get_url_prefix()}/account/capabilities/{self.capability_id}"


class UserAllocation(IRIBaseModel):
Expand All @@ -79,4 +79,4 @@ class UserAllocation(IRIBaseModel):
@property
def project_allocation_uri(self) -> str:
"""Return the URI for the associated project allocation."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/account/projects/{self.project_id}/project_allocations/{self.project_allocation_id}"
return f"{get_url_prefix()}/account/projects/{self.project_id}/project_allocations/{self.project_allocation_id}"
6 changes: 3 additions & 3 deletions app/routers/facility/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Facility-related models."""
from pydantic import Field, HttpUrl, computed_field

from ... import config
from ...request_context import get_url_prefix
from ...types.base import NamedObject


Expand All @@ -26,7 +26,7 @@ def _self_path(self) -> str:
@property
def resource_uris(self) -> list[str]:
"""Return the list of resource URIs for this site."""
return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/resources/{resource_id}" for resource_id in self.resource_ids]
return [f"{get_url_prefix()}/status/resources/{resource_id}" for resource_id in self.resource_ids]

@classmethod
def find(cls, items, name=None, description=None, modified_since=None, short_name=None, country_name=None):
Expand All @@ -53,4 +53,4 @@ def _self_path(self) -> str:
@property
def site_uris(self) -> list[str]:
"""Return the list of site URIs for this facility."""
return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/facility/sites/{site_id}" for site_id in self.site_ids]
return [f"{get_url_prefix()}/facility/sites/{site_id}" for site_id in self.site_ids]
14 changes: 7 additions & 7 deletions app/routers/status/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import Field, computed_field, field_validator

from ... import config
from ...request_context import get_url_prefix
from ...types.base import NamedObject


Expand Down Expand Up @@ -43,13 +43,13 @@ def _self_path(self) -> str:
@property
def site_uri(self) -> str:
"""Return the site URI for this resource."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/facility/sites/{self.site_id}"
return f"{get_url_prefix()}/facility/sites/{self.site_id}"

@computed_field(description="The list of capabilities in this resource")
@property
def capability_uris(self) -> list[str]:
"""Return the list of capability URIs for this resource."""
return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/account/capabilities/{e}" for e in self.capability_ids]
return [f"{get_url_prefix()}/account/capabilities/{e}" for e in self.capability_ids]

@classmethod
def find(cls, items, name=None, description=None, modified_since=None, group=None, resource_type=None, current_status=None, capability=None, site_id=None) -> list:
Expand Down Expand Up @@ -89,13 +89,13 @@ def _norm_dt_field(cls, v):
@property
def resource_uri(self) -> str:
"""Return the resource URI for this event."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/resources/{self.resource_id}"
return f"{get_url_prefix()}/status/resources/{self.resource_id}"

@computed_field(description="The event's incident")
@property
def incident_uri(self) -> str | None:
"""Return the incident URI for this event."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/incidents/{self.incident_id}" if self.incident_id else None
return f"{get_url_prefix()}/status/incidents/{self.incident_id}" if self.incident_id else None

@classmethod
def find(cls, items, incident_id=None, name=None, description=None, modified_since=None, resource_id=None, status=None, from_=None, to=None, time_=None) -> list:
Expand Down Expand Up @@ -162,13 +162,13 @@ def _norm_dt_field(cls, v):
@property
def event_uris(self) -> list[str]:
"""Return the list of event URIs for this incident."""
return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/events/{e}" for e in self.event_ids]
return [f"{get_url_prefix()}/status/events/{e}" for e in self.event_ids]

@computed_field(description="The list of resources that may be impacted by this incident")
@property
def resource_uris(self) -> list[str]:
"""Return the list of resource URIs for this incident."""
return [f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/status/resources/{r}" for r in self.resource_ids]
return [f"{get_url_prefix()}/status/resources/{r}" for r in self.resource_ids]

@classmethod
def find(cls, items, name=None, description=None, modified_since=None, status=None, type_=None, from_=None, to=None, time_=None, resource_id=None, resolution=None) -> list:
Expand Down
4 changes: 2 additions & 2 deletions app/routers/task/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import enum
from pydantic import BaseModel, Field, computed_field

from ... import config
from ...request_context import get_url_prefix


class TaskSubmitResponse(BaseModel):
Expand All @@ -13,7 +13,7 @@ class TaskSubmitResponse(BaseModel):
@property
def task_uri(self) -> str:
"""Return the URI for this task."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}/task/{self.task_id}"
return f"{get_url_prefix()}/task/{self.task_id}"


class TaskStatus(str, enum.Enum):
Expand Down
3 changes: 2 additions & 1 deletion app/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator, model_serializer

from .. import config
from ..request_context import get_url_prefix
from .scalars import StrictDateTime


Expand Down Expand Up @@ -59,7 +60,7 @@ def _norm_dt_field(cls, v):
@property
def self_uri(self) -> str:
"""Computed self URI property."""
return f"{config.API_URL_ROOT}{config.API_PREFIX}{config.API_URL}{self._self_path()}"
return f"{get_url_prefix()}{self._self_path()}"

name: str|None = Field(default=None, description="The long name of the object.", example="Perlmutter GPU")
description: str|None = Field(default=None, description="Human-readable description of the object.", example="High-performance GPU compute resource")
Expand Down
Loading