diff --git a/.env.example b/.env.example index 9e0195f99..562132a89 100644 --- a/.env.example +++ b/.env.example @@ -18,3 +18,9 @@ OPENAI_API_KEY=policyengine_openai_api_key # Token for Hugging Face models HUGGING_FACE_TOKEN=policyengine_huggingface_token + +# Redis is required for budget-window economy requests and other API cache paths. +# Local development and App Engine use an in-container/local Redis by default. +CACHE_REDIS_HOST=127.0.0.1 +CACHE_REDIS_PORT=6379 +CACHE_REDIS_DB=0 diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 1ff895b02..cb5289d74 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -226,7 +226,7 @@ jobs: - name: Install staging test dependencies run: pip install pytest httpx - name: Run staging smoke test - run: python -m pytest tests/integration/test_live_calculate.py tests/integration/test_live_economy.py -v + run: python -m pytest tests/integration/test_live_calculate.py tests/integration/test_live_economy.py tests/integration/test_live_budget_window_cache.py -v env: API_BASE_URL: ${{ needs.deploy-staging.outputs.url }} STAGING_API_TEST_PROBE_ID: ${{ needs.deploy-staging.outputs.version }} diff --git a/Makefile b/Makefile index 7a3588d95..4fe1df11c 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ test-env-vars: pytest tests/env_variables test: - MAX_HOUSEHOLDS=1000 coverage run -a --branch -m pytest tests/to_refactor tests/unit --disable-pytest-warnings + MAX_HOUSEHOLDS=1000 coverage run -a --branch -m pytest tests/to_refactor tests/unit tests/integration/test_budget_window_in_flight_dedupe.py --disable-pytest-warnings coverage xml -i debug-test: diff --git a/README.md b/README.md index bb584acee..5aa1d2477 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,8 @@ NOTE: Any output that needs to be calculated will not work. Therefore, only hous ### 6. Testing calculations +Redis is required for API cache paths, including budget-window economy requests. The budget-window endpoint uses Redis for completed-result caching and in-flight batch deduplication; if Redis is unavailable, those requests fail instead of falling back to the database or an in-process cache. + To test anything that utilizes Redis or the API's service workers (e.g. anything that requires society-wide calculations with the policy calculator), you'll also need to complete the following steps: 1. Start Redis @@ -136,6 +138,8 @@ brew install redis redis-server ``` +By default the API connects to Redis at `127.0.0.1:6379`, database `0`. Override this with `CACHE_REDIS_HOST`, `CACHE_REDIS_PORT`, and `CACHE_REDIS_DB` if your local Redis uses different connection settings. + 2. Start the API Run the below @@ -144,6 +148,8 @@ Run the below FLASK_DEBUG=1 python -m flask --app policyengine_api.api run ``` +App Engine staging and production deployments install and start Redis in the API container before Gunicorn starts. + NOTE: Calculations are not possible in the uk app without access to a specific dataset. Expect an error: "ValueError: Invalid response code 404 for url https://api.github.com/repos/policyengine/non-public-microdata/releases/tags/uk-2024-march-efo." ## Testing, Formatting, Changelogging diff --git a/changelog.d/budget-window-batch.fixed.md b/changelog.d/budget-window-batch.fixed.md new file mode 100644 index 000000000..9bd03006f --- /dev/null +++ b/changelog.d/budget-window-batch.fixed.md @@ -0,0 +1 @@ +Added a budget-window economy endpoint that batches yearly impact calculations with bounded server-side concurrency and returns aggregated progress plus totals. diff --git a/gcp/README.md b/gcp/README.md index 477ab5fc7..a63f65ebd 100644 --- a/gcp/README.md +++ b/gcp/README.md @@ -2,6 +2,8 @@ The deployment actions build Docker images and deploy them to Google App Engine. The docker images themselves are based off a starter image (to save each API docker image having to spend 5 minutes installing the same dependencies). The starter image is the `Dockerfile` in this directory. +The App Engine API image installs `redis-server` and starts it through `gcp/policyengine_api/start.sh`. Redis is required at runtime for budget-window economy request caching and in-flight batch deduplication. The API reads `CACHE_REDIS_HOST`, `CACHE_REDIS_PORT`, and `CACHE_REDIS_DB`, defaulting to `127.0.0.1`, `6379`, and `0`. + To update the starter image: * `python setup.py sdist` to build the python package * `twine upload dist/*` to upload the package to pypi as `policyengine-api` diff --git a/gcp/policyengine_api/start.sh b/gcp/policyengine_api/start.sh index 37abbad46..92818ba81 100644 --- a/gcp/policyengine_api/start.sh +++ b/gcp/policyengine_api/start.sh @@ -1,18 +1,25 @@ #!/bin/sh # Environment variables PORT="${PORT:-8080}" -REDIS_PORT="${REDIS_PORT:-6379}" +CACHE_REDIS_HOST="${CACHE_REDIS_HOST:-127.0.0.1}" +CACHE_REDIS_PORT="${CACHE_REDIS_PORT:-6379}" +CACHE_REDIS_DB="${CACHE_REDIS_DB:-0}" +export CACHE_REDIS_HOST CACHE_REDIS_PORT CACHE_REDIS_DB -# Start the API -gunicorn -b :"$PORT" policyengine_api.api --timeout 300 --workers 5 --preload & - -# Start Redis with configuration for multiple clients -redis-server --protected-mode no \ +# Start Redis with configuration for multiple clients. +redis-server --bind "$CACHE_REDIS_HOST" \ + --port "$CACHE_REDIS_PORT" \ + --protected-mode yes \ --maxclients 10000 \ --timeout 0 & # Wait for Redis to be ready -sleep 2 +until redis-cli -h "$CACHE_REDIS_HOST" -p "$CACHE_REDIS_PORT" ping >/dev/null 2>&1; do + sleep 1 +done + +# Start the API +gunicorn -b :"$PORT" policyengine_api.api --timeout 300 --workers 5 --preload & # Keep the script running and handle shutdown gracefully trap "pkill -P $$; exit 1" INT TERM diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..7d6a67b01 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,6 +4,7 @@ import time import sys +import os start_time = time.time() @@ -89,8 +90,9 @@ def log_timing(message): { "CACHE_TYPE": "RedisCache", "CACHE_KEY_PREFIX": "policyengine", - "CACHE_REDIS_HOST": "127.0.0.1", - "CACHE_REDIS_PORT": 6379, + "CACHE_REDIS_HOST": os.environ.get("CACHE_REDIS_HOST", "127.0.0.1"), + "CACHE_REDIS_PORT": int(os.environ.get("CACHE_REDIS_PORT", "6379")), + "CACHE_REDIS_DB": int(os.environ.get("CACHE_REDIS_DB", "0")), "CACHE_DEFAULT_TIMEOUT": 300, } ) diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 3d7660791..0ea8900f9 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -7,7 +7,7 @@ import os import sys -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional import httpx @@ -42,6 +42,28 @@ def name(self) -> str: return self.job_id +@dataclass +class ModalBudgetWindowBatchExecution: + """ + Represents a budget-window batch execution in the Modal simulation API. + """ + + batch_job_id: str + status: str + progress: Optional[int] = None + completed_years: list[str] = field(default_factory=list) + running_years: list[str] = field(default_factory=list) + queued_years: list[str] = field(default_factory=list) + failed_years: list[str] = field(default_factory=list) + result: Optional[dict] = None + error: Optional[str] = None + + @property + def name(self) -> str: + """Alias for batch_job_id.""" + return self.batch_job_id + + class SimulationAPIModal: """ HTTP client for the Modal Simulation API. @@ -154,6 +176,57 @@ def run(self, payload: dict) -> ModalSimulationExecution: ) raise + def run_budget_window_batch(self, payload: dict) -> ModalBudgetWindowBatchExecution: + """ + Submit a budget-window batch job to the Modal API. + """ + try: + modal_payload = dict(payload) + if "model_version" in modal_payload: + modal_payload["version"] = modal_payload.pop("model_version") + modal_payload.pop("data_version", None) + + response = self.client.post( + f"{self.base_url}/simulate/economy/budget-window", + json=modal_payload, + ) + response.raise_for_status() + data = response.json() + + logger.log_struct( + { + "message": "Modal budget-window batch submitted", + "batch_job_id": data.get("batch_job_id"), + "status": data.get("status"), + }, + severity="INFO", + ) + + return ModalBudgetWindowBatchExecution( + batch_job_id=data["batch_job_id"], + status=data["status"], + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + + except httpx.RequestError as e: + logger.log_struct( + { + "message": f"Modal batch API request error: {str(e)}", + "run_id": (payload.get("_telemetry") or {}).get("run_id"), + }, + severity="ERROR", + ) + raise + def resolve_app_name( self, country: str, version: Optional[str] = None ) -> tuple[str, str]: @@ -235,6 +308,51 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: ) raise + def get_budget_window_batch_by_id( + self, batch_job_id: str + ) -> ModalBudgetWindowBatchExecution: + """ + Poll the Modal API for the current status of a budget-window batch. + """ + try: + response = self.client.get( + f"{self.base_url}/budget-window-jobs/{batch_job_id}" + ) + if response.status_code not in (200, 202, 500): + response.raise_for_status() + data = response.json() + + return ModalBudgetWindowBatchExecution( + batch_job_id=batch_job_id, + status=data["status"], + progress=data.get("progress"), + completed_years=data.get("completed_years", []), + running_years=data.get("running_years", []), + queued_years=data.get("queued_years", []), + failed_years=data.get("failed_years", []), + result=data.get("result"), + error=data.get("error"), + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error polling job {batch_job_id}: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + + except httpx.RequestError as e: + logger.log_struct( + { + "message": f"Modal batch API request error polling job {batch_job_id}: {str(e)}", + }, + severity="ERROR", + ) + raise + def get_execution_status(self, execution: ModalSimulationExecution) -> str: """ Get the status string from an execution. diff --git a/policyengine_api/openapi_spec.yaml b/policyengine_api/openapi_spec.yaml index a49268c8c..77daadc9e 100644 --- a/policyengine_api/openapi_spec.yaml +++ b/policyengine_api/openapi_spec.yaml @@ -660,6 +660,138 @@ paths: type: string message: type: string + /{country_id}/economy/{policy_id}/over/{baseline_policy_id}/budget-window: + get: + summary: Calculate budget-window economic impacts + operationId: get_budget_window_economic_impact + description: Calculate annual and total budget impacts for a policy over a multi-year budget window. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + - name: policy_id + in: path + description: The reform policy ID. + required: true + schema: + type: string + - name: baseline_policy_id + in: path + description: The baseline policy ID. + required: true + schema: + type: string + - name: region + in: query + description: The sub-national region. + required: true + schema: + type: string + - name: start_year + in: query + description: First year in the budget window. + required: true + schema: + type: string + - name: window_size + in: query + description: Number of years to include in the budget window. + required: true + schema: + type: integer + - name: dataset + in: query + description: Dataset selection. + required: false + schema: + type: string + default: default + - name: version + in: query + description: API version number. + required: false + schema: + type: string + - name: include_district_breakdowns + in: query + description: Whether to include congressional district breakdowns for US national simulations. + required: false + schema: + type: boolean + default: false + - name: target + in: query + description: Impact target. Budget-window calculations only support general impacts. + required: false + schema: + type: string + default: general + responses: + 200: + description: Budget-window economic impact, progress, or error state. + content: + application/json: + schema: + type: object + properties: + status: + type: string + enum: + - ok + - computing + - error + message: + type: string + nullable: true + result: + type: object + nullable: true + progress: + type: integer + nullable: true + completed_years: + type: array + items: + type: string + computing_years: + type: array + items: + type: string + queued_years: + type: array + items: + type: string + error: + type: string + nullable: true + 400: + description: Invalid budget-window request. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + result: + type: object + nullable: true + 404: + description: Invalid country ID. + content: + text/html: + schema: + type: object + properties: + status: + type: string + message: + type: string /{country_id}/analysis: post: summary: Get or trigger policy analysis diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 1807416f2..cbecc16cd 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -2,6 +2,7 @@ from policyengine_api.services.economy_service import ( EconomyService, EconomicImpactResult, + BudgetWindowEconomicImpactResult, ) from policyengine_api.utils import get_current_law_policy_id from policyengine_api.utils.payload_validators import validate_country @@ -11,6 +12,26 @@ economy_bp = Blueprint("economy", __name__) economy_service = EconomyService() +BUDGET_WINDOW_CACHE_HEADER = "X-PolicyEngine-Budget-Window-Cache" + + +def _json_response(payload: dict, status: int = 200) -> Response: + return Response( + json.dumps(payload), + status=status, + mimetype="application/json", + ) + + +def _bad_request_response(message: str) -> Response: + return _json_response( + { + "status": "error", + "message": message, + "result": None, + }, + status=400, + ) @economy_bp.route( @@ -56,14 +77,97 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": None, - "result": result_dict["data"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": None, + "result": result_dict["data"], + } + ) + + +@economy_bp.route( + "//economy//over//budget-window", + methods=["GET"], +) +@validate_country +def get_budget_window_economic_impact( + country_id: str, policy_id: int, baseline_policy_id: int +): + policy_id = int(policy_id or get_current_law_policy_id(country_id)) + baseline_policy_id = int( + baseline_policy_id or get_current_law_policy_id(country_id) + ) + + query_parameters = request.args + options = dict(query_parameters) + options = json.loads(json.dumps(options)) + region = options.pop("region", None) + if not region: + return _bad_request_response("Missing required query parameter: region") + + dataset = options.pop("dataset", "default") + start_year = options.pop("start_year", None) + if not start_year: + return _bad_request_response("Missing required query parameter: start_year") + + window_size_raw = options.pop("window_size", None) + if window_size_raw is None: + return _bad_request_response("Missing required query parameter: window_size") + + try: + window_size = int(window_size_raw) + except (TypeError, ValueError): + return _bad_request_response("window_size must be an integer") + + include_district_breakdowns_raw = options.pop( + "include_district_breakdowns", "false" + ) + include_district_breakdowns = include_district_breakdowns_raw.lower() == "true" + if include_district_breakdowns and country_id == "us" and region == "us": + dataset = "national-with-breakdowns" + + target: Literal["general", "cliff"] = options.pop("target", "general") + if target != "general": + return _bad_request_response( + "Budget-window calculations only support target=general" + ) + + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + + try: + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) + ) + except ValueError as error: + return _bad_request_response(str(error)) + + result_dict = economic_impact_result.to_dict() + + response = _json_response( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } ) + cache_status = getattr(economic_impact_result, "cache_status", None) + if isinstance(cache_status, str) and cache_status: + response.headers[BUDGET_WINDOW_CACHE_HEADER] = cache_status + return response diff --git a/policyengine_api/services/budget_window_cache.py b/policyengine_api/services/budget_window_cache.py new file mode 100644 index 000000000..7c19f5921 --- /dev/null +++ b/policyengine_api/services/budget_window_cache.py @@ -0,0 +1,168 @@ +import hashlib +import json +import os +from typing import Any + +import redis + +from policyengine_api.gcp_logging import logger + +BUDGET_WINDOW_CACHE_PREFIX = "budget_window:v1" +BUDGET_WINDOW_STARTING_PREFIX = "starting:" +BUDGET_WINDOW_STARTING_TTL_SECONDS = int( + os.environ.get("BUDGET_WINDOW_STARTING_TTL_SECONDS", "300") +) +BUDGET_WINDOW_BATCH_TTL_SECONDS = int( + os.environ.get("BUDGET_WINDOW_BATCH_TTL_SECONDS", "86400") +) +BUDGET_WINDOW_RESULT_TTL_SECONDS = int( + os.environ.get("BUDGET_WINDOW_RESULT_TTL_SECONDS", "2592000") +) + + +class BudgetWindowCache: + """Redis-backed cache and in-flight mapping for budget-window requests.""" + + def __init__(self, client: redis.Redis | None = None): + self._client = client + + @property + def client(self) -> redis.Redis: + if self._client is None: + self._client = redis.Redis( + host=os.environ.get("CACHE_REDIS_HOST", "127.0.0.1"), + port=int(os.environ.get("CACHE_REDIS_PORT", "6379")), + db=int(os.environ.get("CACHE_REDIS_DB", "0")), + decode_responses=True, + socket_connect_timeout=1, + socket_timeout=1, + ) + return self._client + + def build_key( + self, + *, + country_id: str, + reform_policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options_hash: str | None, + api_version: str, + ) -> str: + key_payload = { + "country_id": country_id, + "reform_policy_id": reform_policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options_hash": options_hash, + "api_version": api_version, + } + encoded = json.dumps( + key_payload, + sort_keys=True, + separators=(",", ":"), + default=str, + ) + digest = hashlib.sha256(encoded.encode("utf-8")).hexdigest() + return f"{BUDGET_WINDOW_CACHE_PREFIX}:{country_id}:{digest}" + + def _result_key(self, cache_key: str) -> str: + return f"{cache_key}:result" + + def _batch_key(self, cache_key: str) -> str: + return f"{cache_key}:batch_job_id" + + def _handle_cache_error(self, operation: str, error: Exception) -> None: + logger.log_struct( + { + "message": f"Budget-window Redis cache {operation} failed", + "error": str(error), + }, + severity="WARNING", + ) + + def get_completed_result(self, cache_key: str) -> dict[str, Any] | None: + try: + payload = self.client.get(self._result_key(cache_key)) + except Exception as error: + self._handle_cache_error("read", error) + raise + + if not payload: + return None + + try: + result = json.loads(payload) + except (TypeError, ValueError) as error: + self._handle_cache_error("decode", error) + return None + + return result if isinstance(result, dict) else None + + def set_completed_result(self, cache_key: str, result: dict[str, Any]) -> None: + try: + self.client.set( + self._result_key(cache_key), + json.dumps(result), + ex=BUDGET_WINDOW_RESULT_TTL_SECONDS, + ) + except Exception as error: + self._handle_cache_error("write result", error) + raise + + def get_batch_job_id(self, cache_key: str) -> str | None: + try: + value = self.client.get(self._batch_key(cache_key)) + except Exception as error: + self._handle_cache_error("read batch id", error) + raise + + if not isinstance(value, str) or not value: + return None + if value.startswith(BUDGET_WINDOW_STARTING_PREFIX): + return None + return value + + def claim_batch_start(self, cache_key: str, claim_token: str) -> bool: + try: + claimed = self.client.set( + self._batch_key(cache_key), + f"{BUDGET_WINDOW_STARTING_PREFIX}{claim_token}", + nx=True, + ex=BUDGET_WINDOW_STARTING_TTL_SECONDS, + ) + except Exception as error: + self._handle_cache_error("claim", error) + raise + + return bool(claimed) + + def store_batch_job_id(self, cache_key: str, batch_job_id: str) -> None: + try: + self.client.set( + self._batch_key(cache_key), + batch_job_id, + ex=BUDGET_WINDOW_BATCH_TTL_SECONDS, + ) + except Exception as error: + self._handle_cache_error("write batch id", error) + raise + + def clear_starting_claim(self, cache_key: str, claim_token: str) -> None: + try: + batch_key = self._batch_key(cache_key) + value = self.client.get(batch_key) + if value == f"{BUDGET_WINDOW_STARTING_PREFIX}{claim_token}": + self.client.delete(batch_key) + except Exception as error: + self._handle_cache_error("clear claim", error) + + def clear_batch_job_id(self, cache_key: str) -> None: + try: + self.client.delete(self._batch_key(cache_key)) + except Exception as error: + self._handle_cache_error("clear batch id", error) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 871c896cc..3532244d9 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -2,6 +2,7 @@ from policyengine_api.services.reform_impacts_service import ( ReformImpactsService, ) +from policyengine_api.services.budget_window_cache import BudgetWindowCache from policyengine_api.constants import ( COUNTRY_PACKAGE_VERSIONS, EXECUTION_STATUSES_SUCCESS, @@ -20,6 +21,7 @@ normalize_us_region, ) from policyengine_api.data.places import validate_place_code +from policyengine_api.utils import budget_window as budget_window_utils from policyengine.simulation import SimulationOptions from policyengine.utils.data.datasets import get_default_dataset import json @@ -28,7 +30,7 @@ import uuid from typing import Literal, Any, Optional, Annotated from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field import numpy as np from enum import Enum @@ -37,6 +39,7 @@ policy_service = PolicyService() reform_impacts_service = ReformImpactsService() simulation_api = simulation_api_modal +budget_window_cache = BudgetWindowCache() def get_policyengine_version() -> str | None: @@ -71,6 +74,9 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value +BUDGET_WINDOW_MAX_ACTIVE_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_ACTIVE_YEARS +BUDGET_WINDOW_MAX_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_YEARS +BUDGET_WINDOW_MAX_END_YEAR = budget_window_utils.BUDGET_WINDOW_MAX_END_YEAR class EconomicImpactSetupOptions(BaseModel): @@ -134,6 +140,91 @@ def error(cls, message: str) -> "EconomicImpactResult": return cls(status=ImpactStatus.ERROR, data=None) +class BudgetWindowEconomicImpactResult(BaseModel): + """ + Model for a batch budget-window economic impact response. + """ + + status: ImpactStatus + data: Optional[dict] = None + progress: Optional[int] = None + completed_years: list[str] = Field(default_factory=list) + computing_years: list[str] = Field(default_factory=list) + queued_years: list[str] = Field(default_factory=list) + message: Optional[str] = None + error: Optional[str] = None + cache_status: Optional[str] = None + + model_config = {"frozen": True} + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "data": self.data, + "progress": self.progress, + "completed_years": self.completed_years, + "computing_years": self.computing_years, + "queued_years": self.queued_years, + "message": self.message, + "error": self.error, + } + + @classmethod + def completed( + cls, data: dict, *, cache_status: Optional[str] = None + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.OK, + data=data, + progress=100, + cache_status=cache_status, + ) + + @classmethod + def computing( + cls, + *, + progress: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + message: str, + cache_status: Optional[str] = None, + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.COMPUTING, + data=None, + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=message, + cache_status=cache_status, + ) + + @classmethod + def failed( + cls, + message: str, + *, + completed_years: Optional[list[str]] = None, + computing_years: Optional[list[str]] = None, + queued_years: Optional[list[str]] = None, + cache_status: Optional[str] = None, + ) -> "BudgetWindowEconomicImpactResult": + logger.log_struct({"message": message}, severity="ERROR") + return cls( + status=ImpactStatus.ERROR, + data=None, + completed_years=completed_years or [], + computing_years=computing_years or [], + queued_years=queued_years or [], + message=message, + error=message, + cache_status=cache_status, + ) + + class EconomyService: """ Service for calculating economic impact of policy reforms; this is connected @@ -168,133 +259,432 @@ def get_economic_impact( if country_id == "us": region = normalize_us_region(region) - # Set up logging - process_id: str = self._create_process_id() + economic_impact_setup_options = self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, + ) + + return self._get_or_create_economic_impact( + setup_options=economic_impact_setup_options, + ) - country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - cache_version = get_economy_impact_cache_version(country_id, api_version) - resolved_dataset = self._setup_data( + except Exception as e: + print(f"Error getting economic impact: {str(e)}") + raise e + + def get_budget_window_economic_impact( + self, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, + ) -> BudgetWindowEconomicImpactResult: + try: + if country_id == "us": + region = normalize_us_region(region) + + budget_window_setup = budget_window_utils.build_budget_window_request_setup( + start_year=start_year, + window_size=window_size, + target=target, + ) + start_year = budget_window_setup.start_year + years = budget_window_setup.years + setup_options = self._build_economic_impact_setup_options( country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, region=region, dataset=dataset, + time_period=budget_window_setup.time_period, + options=dict(options), + api_version=api_version, + target=target, ) - resolved_model_version = country_package_version - resolved_data_version = self._extract_dataset_version(resolved_dataset) - options_hash = self._build_options_hash( - options=options, - model_version=resolved_model_version, - dataset=resolved_dataset, + cache_key = self._build_budget_window_cache_key(setup_options) + + cached_result = budget_window_cache.get_completed_result(cache_key) + if cached_result is not None: + return BudgetWindowEconomicImpactResult.completed( + cached_result, + cache_status="result-hit", + ) + + batch_job_id = budget_window_cache.get_batch_job_id(cache_key) + if batch_job_id: + return self._get_budget_window_result_from_batch_job_id( + batch_job_id=batch_job_id, + cache_key=cache_key, + total_years=len(years), + queued_years_on_submit=years, + cache_status="batch-id-hit", + ) + + claim_token = setup_options.process_id + cache_status = "starting-claim-hit" + if budget_window_cache.claim_batch_start(cache_key, claim_token): + cache_status = "miss" + try: + batch_execution = self._start_budget_window_batch( + setup_options=setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_active_years, + ) + budget_window_cache.store_batch_job_id( + cache_key, batch_execution.batch_job_id + ) + except Exception: + budget_window_cache.clear_starting_claim(cache_key, claim_token) + raise + + return self._build_budget_window_computing_result( + total_years=len(years), + completed_years=[], + computing_years=[], + queued_years=years, + progress=0, + cache_status=cache_status, ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": resolved_dataset, - "time_period": time_period, - "options": options, - "api_version": cache_version, - "target": target, - "model_version": resolved_model_version, - "policyengine_version": None, - "data_version": resolved_data_version, - "runtime_app_name": None, - "options_hash": options_hash, - } + def _build_budget_window_cache_key( + self, + setup_options: EconomicImpactSetupOptions, + ) -> str: + return budget_window_cache.build_key( + country_id=setup_options.country_id, + reform_policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ) + + def _build_budget_window_batch_payload( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> dict[str, Any]: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.baseline_policy_id, + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.reform_policy_id, + ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=start_year, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=False, + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) + sim_params = sim_config.model_dump() + sim_params.pop("time_period", None) + sim_params["start_year"] = start_year + sim_params["window_size"] = window_size + sim_params["max_parallel"] = max_parallel + sim_params["target"] = setup_options.target + return sim_params + + def _start_budget_window_batch( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ): + sim_params = self._build_budget_window_batch_payload( + setup_options=setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_parallel, + ) + + logger.log_struct( + { + "message": "Submitting budget-window batch job", + **setup_options.model_dump(), + "start_year": start_year, + "window_size": window_size, + "max_parallel": max_parallel, + }, + severity="INFO", + ) + + return simulation_api.run_budget_window_batch(sim_params) + + def _get_budget_window_result_from_batch_job_id( + self, + *, + batch_job_id: str, + cache_key: str, + total_years: int, + queued_years_on_submit: list[str], + cache_status: Optional[str] = None, + ) -> BudgetWindowEconomicImpactResult: + batch_execution = simulation_api.get_budget_window_batch_by_id(batch_job_id) + + if batch_execution.status in EXECUTION_STATUSES_SUCCESS: + result = batch_execution.result + if not isinstance(result, dict) or not result: + budget_window_cache.clear_batch_job_id(cache_key) + return BudgetWindowEconomicImpactResult.failed( + "Budget-window batch completed without a result", + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years or queued_years_on_submit, + cache_status=cache_status, + ) + budget_window_cache.set_completed_result(cache_key, result) + budget_window_cache.clear_batch_job_id(cache_key) + return BudgetWindowEconomicImpactResult.completed( + result, + cache_status=cache_status, ) - # Logging that we've received a request + if batch_execution.status in EXECUTION_STATUSES_FAILURE: + error_message = batch_execution.error or "Budget-window batch failed" + budget_window_cache.clear_batch_job_id(cache_key) + return BudgetWindowEconomicImpactResult.failed( + error_message, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years or queued_years_on_submit, + cache_status=cache_status, + ) + + if batch_execution.status in EXECUTION_STATUSES_PENDING: + return self._build_budget_window_computing_result( + total_years=total_years, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + progress=batch_execution.progress, + cache_status=cache_status, + ) + + raise ValueError( + f"Unexpected budget-window batch execution state: {batch_execution.status}" + ) + + def _build_budget_window_computing_result( + self, + *, + total_years: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + progress: Optional[int] = None, + cache_status: Optional[str] = None, + ) -> BudgetWindowEconomicImpactResult: + resolved_progress = progress + if resolved_progress is None: + resolved_progress = round((len(completed_years) / total_years) * 100) + + return BudgetWindowEconomicImpactResult.computing( + progress=resolved_progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=self._build_budget_window_progress_message( + completed_years=completed_years, + total_years=total_years, + computing_years=computing_years, + queued_years=queued_years, + ), + cache_status=cache_status, + ) + + def _build_economic_impact_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + ) -> EconomicImpactSetupOptions: + process_id: str = self._create_process_id() + cache_version = get_economy_impact_cache_version(country_id, api_version) + country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + resolved_dataset = self._setup_data( + country_id=country_id, + region=region, + dataset=dataset, + ) + resolved_data_version = self._extract_dataset_version(resolved_dataset) + options_hash = self._build_options_hash( + options=options, + model_version=country_package_version, + dataset=resolved_dataset, + ) + + return EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": resolved_dataset, + "time_period": time_period, + "options": options, + "api_version": cache_version, + "target": target, + "model_version": country_package_version, + "policyengine_version": None, + "data_version": resolved_data_version, + "runtime_app_name": None, + "options_hash": options_hash, + } + ) + + def _get_or_create_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> EconomicImpactResult: + logger.log_struct( + { + "message": "Received request for economic impact; checking if already in reform_impacts table", + **setup_options.model_dump(), + }, + severity="INFO", + ) + + most_recent_impact: dict | None = self._get_most_recent_impact( + setup_options=setup_options + ) + + if most_recent_impact and self._should_refresh_cached_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ): + most_recent_impact = self._get_most_recent_impact(setup_options) + if ( + not most_recent_impact + or most_recent_impact.get("options_hash") != setup_options.options_hash + ): + most_recent_impact = None + + impact_action: ImpactAction = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: logger.log_struct( { - "message": "Received request for economic impact; checking if already in reform_impacts table", - **economic_impact_setup_options.model_dump(), + "message": "Found completed economic impact in db; returning result", + **setup_options.model_dump(), }, severity="INFO", ) - - most_recent_impact: dict | None = self._get_most_recent_impact( - setup_options=economic_impact_setup_options, + return self._handle_completed_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) - if most_recent_impact and self._should_refresh_cached_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ): - most_recent_impact = self._get_most_recent_impact( - economic_impact_setup_options - ) - if ( - not most_recent_impact - or most_recent_impact.get("options_hash") - != economic_impact_setup_options.options_hash - ): - most_recent_impact = None - - impact_action: ImpactAction = self._determine_impact_action( + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact record in db; confirming this is still computing", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, most_recent_impact=most_recent_impact, ) - if impact_action == ImpactAction.COMPLETED: - logger.log_struct( - { - "message": "Found completed economic impact in db; returning result", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_completed_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) + if impact_action == ImpactAction.CREATE: + self._resolve_runtime_bundle_for_setup_options(setup_options) + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact( + setup_options=setup_options, + ) - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact record in db; confirming this is still computing", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_computing_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) + raise ValueError(f"Unexpected impact action: {impact_action}") - if impact_action == ImpactAction.CREATE: - if economic_impact_setup_options.runtime_app_name is None: - ( - economic_impact_setup_options.runtime_app_name, - economic_impact_setup_options.model_version, - ) = simulation_api.resolve_app_name( - country_id, - economic_impact_setup_options.model_version, - ) - economic_impact_setup_options.options_hash = self._build_options_hash( - options=options, - model_version=economic_impact_setup_options.model_version, - dataset=resolved_dataset, - data_version=resolved_data_version, - runtime_app_name=economic_impact_setup_options.runtime_app_name, - ) - logger.log_struct( - { - "message": "No previous economic impact record found in db; creating new simulation run", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_create_impact( - setup_options=economic_impact_setup_options, - ) + def _resolve_runtime_bundle_for_setup_options( + self, + setup_options: EconomicImpactSetupOptions, + ) -> None: + if setup_options.runtime_app_name is None: + ( + setup_options.runtime_app_name, + setup_options.model_version, + ) = simulation_api.resolve_app_name( + setup_options.country_id, + setup_options.model_version, + ) - raise ValueError(f"Unexpected impact action: {impact_action}") + setup_options.options_hash = self._build_options_hash( + options=setup_options.options, + model_version=setup_options.model_version, + dataset=setup_options.dataset, + data_version=setup_options.data_version, + policyengine_version=setup_options.policyengine_version, + runtime_app_name=setup_options.runtime_app_name, + ) - except Exception as e: - print(f"Error getting economic impact: {str(e)}") - raise e + def _build_budget_window_progress_message( + self, + *, + completed_years: list[str], + total_years: int, + computing_years: list[str], + queued_years: list[str], + ) -> str: + completed_count = len(completed_years) + if computing_years: + active_years = ", ".join(computing_years[:2]) + if len(computing_years) > 2: + active_years = f"{active_years} + {len(computing_years) - 2} more" + return f"Scoring {active_years} ({completed_count} of {total_years} complete)..." + + if queued_years: + return f"Queued {queued_years[0]} ({completed_count} of {total_years} complete)..." + + return f"Scoring budget window ({completed_count} of {total_years} complete)..." def _get_previous_impacts( self, @@ -324,7 +714,6 @@ def _get_previous_impacts( api_version, ) ) - return previous_impacts def _get_most_recent_impact( @@ -522,6 +911,7 @@ def _handle_create_impact( sim_api_execution = simulation_api.run(sim_params) execution_id = simulation_api.get_execution_id(sim_api_execution) + run_id = getattr(sim_api_execution, "run_id", None) or telemetry["run_id"] progress_log = { @@ -582,7 +972,7 @@ def _build_options_hash( data_version: str | None = None, policyengine_version: str | None = None, ) -> str: - option_pairs = "&".join([f"{k}={v}" for k, v in options.items()]) + option_pairs = "&".join(f"{key}={options[key]}" for key in sorted(options)) bundle_parts = [ f"dataset={dataset}", f"model_version={model_version}", diff --git a/policyengine_api/utils/budget_window.py b/policyengine_api/utils/budget_window.py new file mode 100644 index 000000000..6acbd5b3a --- /dev/null +++ b/policyengine_api/utils/budget_window.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from typing import Literal + +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 +BUDGET_WINDOW_MAX_YEARS = 75 +BUDGET_WINDOW_MAX_END_YEAR = 2099 + + +@dataclass(frozen=True) +class BudgetWindowRequestSetup: + start_year: str + window_size: int + years: list[str] + time_period: str + + +def build_budget_window_years(*, start_year: str, window_size: int) -> list[str]: + start_year_int = int(start_year) + return [str(start_year_int + index) for index in range(window_size)] + + +def build_budget_window_time_period(*, start_year: str, window_size: int) -> str: + return f"budget_window:{start_year}:{window_size}" + + +def build_budget_window_request_setup( + *, + start_year: str, + window_size: int, + target: Literal["general", "cliff"], +) -> BudgetWindowRequestSetup: + if target != "general": + raise ValueError("Budget-window calculations only support target='general'") + + start_year_int = int(start_year) + if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: + raise ValueError(f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}") + + end_year = start_year_int + window_size - 1 + if end_year > BUDGET_WINDOW_MAX_END_YEAR: + raise ValueError( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" + ) + + normalized_start_year = str(start_year_int) + return BudgetWindowRequestSetup( + start_year=normalized_start_year, + window_size=window_size, + years=build_budget_window_years( + start_year=normalized_start_year, + window_size=window_size, + ), + time_period=build_budget_window_time_period( + start_year=normalized_start_year, + window_size=window_size, + ), + ) diff --git a/tests/fixtures/libs/simulation_api_modal.py b/tests/fixtures/libs/simulation_api_modal.py index 6d514a7e5..a9b4ce45e 100644 --- a/tests/fixtures/libs/simulation_api_modal.py +++ b/tests/fixtures/libs/simulation_api_modal.py @@ -19,6 +19,7 @@ # Mock data constants MOCK_MODAL_JOB_ID = "fc-abc123xyz" MOCK_RUN_ID = "run-abc123xyz" +MOCK_BATCH_JOB_ID = "fc-batch123xyz" MOCK_MODAL_BASE_URL = "https://test-modal-api.modal.run" MOCK_SIMULATION_PAYLOAD = { @@ -87,6 +88,54 @@ MOCK_HEALTH_RESPONSE = {"status": "healthy"} +MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS = { + "batch_job_id": MOCK_BATCH_JOB_ID, + "status": MODAL_EXECUTION_STATUS_SUBMITTED, + "poll_url": f"/budget-window-jobs/{MOCK_BATCH_JOB_ID}", + "country": "us", + "version": "1.459.0", +} + +MOCK_BATCH_POLL_RESPONSE_RUNNING = { + "status": MODAL_EXECUTION_STATUS_RUNNING, + "progress": 33, + "completed_years": ["2026"], + "running_years": ["2027"], + "queued_years": ["2028"], + "failed_years": [], + "result": None, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_COMPLETE = { + "status": MODAL_EXECUTION_STATUS_COMPLETE, + "progress": 100, + "completed_years": ["2026", "2027", "2028"], + "running_years": [], + "queued_years": [], + "failed_years": [], + "result": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + }, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_FAILED = { + "status": MODAL_EXECUTION_STATUS_FAILED, + "progress": 33, + "completed_years": ["2026"], + "running_years": [], + "queued_years": ["2028"], + "failed_years": ["2027"], + "result": None, + "error": "Budget window failed", +} + def create_mock_httpx_response( status_code: int = 200, diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index cf41873ed..49202132d 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -138,6 +138,7 @@ def mock_simulation_api(): """Mock SimulationAPIModal with all required methods.""" mock_api = MagicMock() mock_execution = create_mock_modal_execution() + mock_batch_execution = create_mock_budget_window_batch_execution() mock_api._setup_sim_options.return_value = MOCK_SIM_CONFIG mock_api.run.return_value = mock_execution @@ -149,6 +150,8 @@ def mock_simulation_api(): mock_api.get_execution_by_id.return_value = mock_execution mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING mock_api.get_execution_result.return_value = MOCK_REFORM_IMPACT_DATA + mock_api.run_budget_window_batch.return_value = mock_batch_execution + mock_api.get_budget_window_batch_by_id.return_value = mock_batch_execution with patch( "policyengine_api.services.economy_service.simulation_api", mock_api @@ -156,6 +159,26 @@ def mock_simulation_api(): yield mock +@pytest.fixture +def mock_budget_window_cache(): + """Mock Redis-backed budget-window cache.""" + mock_cache = MagicMock() + mock_cache.build_key.return_value = "budget-window-cache-key" + mock_cache.get_completed_result.return_value = None + mock_cache.get_batch_job_id.return_value = None + mock_cache.claim_batch_start.return_value = True + mock_cache.store_batch_job_id.return_value = None + mock_cache.clear_starting_claim.return_value = None + mock_cache.set_completed_result.return_value = None + mock_cache.clear_batch_job_id.return_value = None + + with patch( + "policyengine_api.services.economy_service.budget_window_cache", + mock_cache, + ) as mock: + yield mock + + @pytest.fixture def mock_logger(): """Mock logger.""" @@ -187,6 +210,9 @@ def create_mock_reform_impact( reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID, options_hash=MOCK_OPTIONS_HASH, + start_time=None, + time_period=MOCK_TIME_PERIOD, + message=None, ): """Helper function to create mock reform impact records.""" default_reform_impact_json = json.dumps( @@ -208,13 +234,14 @@ def create_mock_reform_impact( "baseline_policy_id": MOCK_BASELINE_POLICY_ID, "region": MOCK_REGION, "dataset": MOCK_RESOLVED_DATASET, - "time_period": MOCK_TIME_PERIOD, + "time_period": time_period, "options_hash": options_hash, "status": status, "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or default_reform_impact_json, "execution_id": execution_id, - "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), + "message": message, + "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None ), @@ -259,6 +286,32 @@ def create_mock_modal_execution( return mock_execution +def create_mock_budget_window_batch_execution( + batch_job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + progress=None, + completed_years=None, + running_years=None, + queued_years=None, + failed_years=None, + result=None, + error=None, +): + """Helper function to create mock batch execution objects.""" + mock_execution = MagicMock() + mock_execution.batch_job_id = batch_job_id + mock_execution.name = batch_job_id + mock_execution.status = status + mock_execution.progress = progress + mock_execution.completed_years = completed_years or [] + mock_execution.running_years = running_years or [] + mock_execution.queued_years = queued_years or [] + mock_execution.failed_years = failed_years or [] + mock_execution.result = result + mock_execution.error = error + return mock_execution + + @pytest.fixture def mock_simulation_api_modal(): """Mock SimulationAPIModal with all required methods.""" diff --git a/tests/integration/test_budget_window_in_flight_dedupe.py b/tests/integration/test_budget_window_in_flight_dedupe.py new file mode 100644 index 000000000..7f608ebf1 --- /dev/null +++ b/tests/integration/test_budget_window_in_flight_dedupe.py @@ -0,0 +1,115 @@ +import os +from unittest.mock import MagicMock + +from flask import Flask + +os.environ.setdefault("POLICYENGINE_DB_PASSWORD", "test") + + +class FakeRedis: + def __init__(self): + self.values = {} + + def get(self, key): + return self.values.get(key) + + def set(self, key, value, nx=False, ex=None): + if nx and key in self.values: + return False + self.values[key] = value + return True + + def delete(self, key): + self.values.pop(key, None) + + +def _create_client(economy_bp): + app = Flask(__name__) + app.register_blueprint(economy_bp) + return app.test_client() + + +def test_budget_window_in_flight_dedupe_uses_existing_batch_without_live_db( + monkeypatch, +): + from policyengine_api.libs.simulation_api_modal import ( + ModalBudgetWindowBatchExecution, + ) + from policyengine_api.routes.economy_routes import economy_bp + from policyengine_api.services import economy_service as economy_service_module + from policyengine_api.services.budget_window_cache import BudgetWindowCache + + fake_cache = BudgetWindowCache(client=FakeRedis()) + simulation_api = MagicMock() + reform_impacts_service = MagicMock() + + simulation_api.run_budget_window_batch.return_value = ( + ModalBudgetWindowBatchExecution( + batch_job_id="fc-budget-window-parent", + status="submitted", + ) + ) + simulation_api.get_budget_window_batch_by_id.return_value = ( + ModalBudgetWindowBatchExecution( + batch_job_id="fc-budget-window-parent", + status="running", + progress=50, + completed_years=["2026"], + running_years=["2027"], + queued_years=["2028"], + ) + ) + + monkeypatch.setattr(economy_service_module, "budget_window_cache", fake_cache) + monkeypatch.setattr(economy_service_module, "simulation_api", simulation_api) + monkeypatch.setattr( + economy_service_module, + "reform_impacts_service", + reform_impacts_service, + ) + monkeypatch.setattr( + economy_service_module.EconomyService, + "_build_budget_window_batch_payload", + lambda self, **kwargs: { + "country_id": "us", + "start_year": kwargs["start_year"], + "window_size": kwargs["window_size"], + "max_parallel": kwargs["max_parallel"], + }, + ) + + client = _create_client(economy_bp) + path = "/us/economy/123/over/456/budget-window" + params = { + "region": "us", + "dataset": "hf://policyengine/test.h5@1.0", + "start_year": "2026", + "window_size": "3", + } + + first_response = client.get(path, query_string=params) + first_payload = first_response.get_json() + + assert first_response.status_code == 200 + assert first_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "miss" + assert first_payload["status"] == "computing" + assert first_payload["queued_years"] == ["2026", "2027", "2028"] + + second_response = client.get(path, query_string=params) + second_payload = second_response.get_json() + + assert second_response.status_code == 200 + assert ( + second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "batch-id-hit" + ) + assert second_payload["status"] == "computing" + assert second_payload["progress"] == 50 + assert second_payload["completed_years"] == ["2026"] + assert second_payload["computing_years"] == ["2027"] + assert second_payload["queued_years"] == ["2028"] + + simulation_api.run_budget_window_batch.assert_called_once() + simulation_api.get_budget_window_batch_by_id.assert_called_once_with( + "fc-budget-window-parent" + ) + reform_impacts_service.assert_not_called() diff --git a/tests/integration/test_live_budget_window_cache.py b/tests/integration/test_live_budget_window_cache.py new file mode 100644 index 000000000..db57d3c39 --- /dev/null +++ b/tests/integration/test_live_budget_window_cache.py @@ -0,0 +1,163 @@ +import json +import os +import time +from pathlib import Path + + +INTEGRATION_TIMEOUT_SECONDS = float( + os.environ.get("STAGING_API_TEST_TIMEOUT_SECONDS", "900") +) +INTEGRATION_POLL_INTERVAL_SECONDS = float( + os.environ.get("STAGING_API_TEST_POLL_INTERVAL_SECONDS", "5") +) + + +def _load_reform_payload(filename: str) -> dict: + return json.loads( + (Path(__file__).resolve().parents[1] / "data" / filename).read_text( + encoding="utf-8" + ) + ) + + +def _poll_budget_window(api_client, path: str, params: dict) -> dict: + deadline = time.monotonic() + INTEGRATION_TIMEOUT_SECONDS + + while True: + response = api_client.get(path, params=params) + response.raise_for_status() + payload = response.json() + + if payload["status"] != "computing": + return payload + + assert time.monotonic() < deadline, ( + f"Timed out polling the budget-window route; last response was {payload}" + ) + time.sleep(INTEGRATION_POLL_INTERVAL_SECONDS) + + +def _get_current_law_id(api_client) -> int: + metadata_response = api_client.get("/us/metadata") + metadata_response.raise_for_status() + return metadata_response.json()["result"]["current_law_id"] + + +def _create_utah_reform_policy(api_client) -> int: + policy_response = api_client.post( + "/us/policy", + json=_load_reform_payload("utah_reform.json"), + ) + assert policy_response.status_code in (200, 201) + return policy_response.json()["result"]["policy_id"] + + +def test_live_budget_window_completed_result_cache(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "start_year": "2026", + "window_size": 1, + "staging_probe": f"{integration_probe_id}-budget-window-cache", + } + + first_payload = _poll_budget_window(api_client, path, params) + + assert first_payload["status"] == "ok", first_payload + assert first_payload["progress"] == 100, first_payload + assert first_payload["result"] is not None, first_payload + assert first_payload["result"]["kind"] == "budgetWindow", first_payload + assert first_payload["result"]["windowSize"] == 1, first_payload + + second_response = api_client.get(path, params=params) + second_response.raise_for_status() + second_payload = second_response.json() + + assert second_payload["status"] == "ok", second_payload + assert second_payload["progress"] == 100, second_payload + assert second_payload["result"] == first_payload["result"] + assert second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "result-hit" + + +def test_live_budget_window_multi_year_run(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "start_year": "2026", + "window_size": 2, + "staging_probe": f"{integration_probe_id}-budget-window-multi-year", + } + + payload = _poll_budget_window(api_client, path, params) + + assert payload["status"] == "ok", payload + assert payload["progress"] == 100, payload + result = payload["result"] + assert result is not None, payload + assert result["kind"] == "budgetWindow", payload + assert result["windowSize"] == 2, payload + assert result["startYear"] == "2026", payload + assert result["endYear"] == "2027", payload + assert [impact["year"] for impact in result["annualImpacts"]] == [ + "2026", + "2027", + ] + assert result["totals"]["year"] == "Total", payload + + +def test_live_budget_window_failed_batch_mapping(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "dataset": "hf://policyengine/nonexistent-budget-window-test.h5@0.0.0", + "start_year": "2026", + "window_size": 1, + "staging_probe": f"{integration_probe_id}-budget-window-failure", + } + + payload = _poll_budget_window(api_client, path, params) + + assert payload["status"] == "error", payload + assert payload["result"] is None, payload + assert payload["error"], payload + assert isinstance(payload["completed_years"], list), payload + assert isinstance(payload["computing_years"], list), payload + assert isinstance(payload["queued_years"], list), payload + + +def test_live_budget_window_in_flight_dedupe(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "start_year": "2026", + "window_size": 2, + "staging_probe": f"{integration_probe_id}-budget-window-in-flight", + } + + first_response = api_client.get(path, params=params) + first_response.raise_for_status() + first_payload = first_response.json() + + assert first_payload["status"] == "computing", first_payload + assert first_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "miss" + + second_response = api_client.get(path, params=params) + second_response.raise_for_status() + second_payload = second_response.json() + + assert second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == ( + "batch-id-hit" + ) + assert second_payload["status"] in ("computing", "ok"), second_payload diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py new file mode 100644 index 000000000..9821502d7 --- /dev/null +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -0,0 +1,215 @@ +import json +from unittest.mock import Mock, patch + + +def _mock_budget_window_result(cache_status=None): + mock_result = Mock() + mock_result.cache_status = cache_status + mock_result.to_dict.return_value = { + "status": "ok", + "message": None, + "data": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2027", + "windowSize": 2, + "annualImpacts": [], + "totals": {}, + }, + "progress": 100, + "completed_years": ["2026", "2027"], + "computing_years": [], + "queued_years": [], + "error": None, + } + return mock_result + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_rejects_cliff_target( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=10&target=cliff" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "target=general" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_region( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?start_year=2026&window_size=2" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert data["message"] == "Missing required query parameter: region" + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_start_year( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&window_size=2" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert data["message"] == "Missing required query parameter: start_year" + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_integer_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=abc" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be an integer" == data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +def test_budget_window_route_rejects_oversized_window(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=999" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be between 1 and" in data["message"] + + +def test_budget_window_route_rejects_end_year_after_2099(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2090&window_size=20" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "budget-window end_year must be 2099 or earlier" == data["message"] + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_passes_version_to_service( + mock_get_budget_window_economic_impact, rest_client +): + mock_result = _mock_budget_window_result() + mock_get_budget_window_economic_impact.return_value = mock_result + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2&version=1.2.3" + ) + + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + mock_get_budget_window_economic_impact.assert_called_once_with( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="default", + start_year="2026", + window_size=2, + options={}, + api_version="1.2.3", + target="general", + ) + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_uses_breakdown_dataset_for_us_national_request( + mock_get_budget_window_economic_impact, rest_client +): + mock_get_budget_window_economic_impact.return_value = _mock_budget_window_result() + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2" + "&include_district_breakdowns=true" + ) + + assert response.status_code == 200 + mock_get_budget_window_economic_impact.assert_called_once() + assert ( + mock_get_budget_window_economic_impact.call_args.kwargs["dataset"] + == "national-with-breakdowns" + ) + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_sets_cache_status_header( + mock_get_budget_window_economic_impact, rest_client +): + mock_get_budget_window_economic_impact.return_value = _mock_budget_window_result( + cache_status="result-hit" + ) + + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026&window_size=2" + ) + + assert response.status_code == 200 + assert response.headers["X-PolicyEngine-Budget-Window-Cache"] == "result-hit" diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 26b321135..300badee5 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -21,6 +21,7 @@ os.environ.setdefault("FLASK_DEBUG", "1") from policyengine_api.libs.simulation_api_modal import ( # noqa: E402 + ModalBudgetWindowBatchExecution, ModalSimulationExecution, SimulationAPIModal, ) @@ -32,6 +33,7 @@ ) from tests.fixtures.libs.simulation_api_modal import ( # noqa: E402 MOCK_MODAL_JOB_ID, + MOCK_BATCH_JOB_ID, MOCK_MODAL_BASE_URL, MOCK_SIMULATION_PAYLOAD, MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY, @@ -44,6 +46,10 @@ MOCK_POLL_RESPONSE_COMPLETE, MOCK_POLL_RESPONSE_FAILED, MOCK_HEALTH_RESPONSE, + MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + MOCK_BATCH_POLL_RESPONSE_RUNNING, + MOCK_BATCH_POLL_RESPONSE_COMPLETE, + MOCK_BATCH_POLL_RESPONSE_FAILED, create_mock_httpx_response, ) @@ -117,6 +123,18 @@ def test__given_failed_execution__then_error_attribute_populated(self): assert execution.result is None +class TestModalBudgetWindowBatchExecution: + """Tests for the ModalBudgetWindowBatchExecution dataclass.""" + + def test__given_batch_job_id__then_name_returns_batch_job_id(self): + execution = ModalBudgetWindowBatchExecution( + batch_job_id=MOCK_BATCH_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + ) + + assert execution.name == MOCK_BATCH_JOB_ID + + class TestSimulationAPIModal: """Tests for the SimulationAPIModal class.""" @@ -272,6 +290,29 @@ def test__given_telemetry_payload__then_preserves_it_in_post_body( call_args = mock_httpx_client.post.call_args assert call_args[1]["json"]["_telemetry"]["run_id"] == MOCK_RUN_ID + def test__given_model_and_data_versions__then_translates_payload_for_modal( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_SUBMIT_RESPONSE_SUCCESS, + ) + payload = { + **MOCK_SIMULATION_PAYLOAD, + "model_version": "1.459.0", + "data_version": "1.77.0", + } + api = SimulationAPIModal() + + api.run(payload) + + posted_payload = mock_httpx_client.post.call_args.kwargs["json"] + assert posted_payload["version"] == "1.459.0" + assert "model_version" not in posted_payload + assert "data_version" not in posted_payload + def test__given_http_error__then_raises_exception( self, mock_httpx_client, @@ -300,7 +341,11 @@ def test__given_network_error__then_raises_exception( # When/Then with pytest.raises(httpx.RequestError): - api.run(MOCK_SIMULATION_PAYLOAD) + api.run(MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert "Modal API request error" in log_payload["message"] + assert log_payload["run_id"] == MOCK_RUN_ID class TestResolveAppName: def test__given_country_and_version__then_returns_registered_app( @@ -322,6 +367,96 @@ def test__given_country_and_version__then_returns_registered_app( assert app_name == MOCK_RESOLVED_APP_NAME assert resolved_version == "1.459.0" + def test__given_unknown_version__then_raises_value_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data={ + "latest": "1.459.0", + "1.459.0": MOCK_RESOLVED_APP_NAME, + }, + ) + api = SimulationAPIModal() + + with pytest.raises( + ValueError, match="Unknown version 9.9.9 for country us" + ): + api.resolve_app_name("us", "9.9.9") + + class TestRunBudgetWindowBatch: + def test__given_valid_payload__then_returns_batch_execution( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + ) + api = SimulationAPIModal() + + execution = api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED + call_args = mock_httpx_client.post.call_args + assert "/simulate/economy/budget-window" in call_args[0][0] + + def test__given_model_and_data_versions__then_translates_payload_for_modal( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + ) + payload = { + **MOCK_SIMULATION_PAYLOAD, + "model_version": "1.459.0", + "data_version": "1.77.0", + } + api = SimulationAPIModal() + + api.run_budget_window_batch(payload) + + posted_payload = mock_httpx_client.post.call_args.kwargs["json"] + assert posted_payload["version"] == "1.459.0" + assert "model_version" not in posted_payload + assert "data_version" not in posted_payload + + def test__given_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_response = create_mock_httpx_response( + status_code=400, + json_data={"error": "Invalid request"}, + ) + mock_httpx_client.post.return_value = mock_response + api = SimulationAPIModal() + + with pytest.raises(httpx.HTTPStatusError): + api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.side_effect = httpx.RequestError("Connection failed") + api = SimulationAPIModal() + + with pytest.raises(httpx.RequestError): + api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert log_payload["run_id"] == MOCK_RUN_ID + class TestGetExecutionById: def test__given_running_job__then_returns_running_status( self, @@ -416,6 +551,101 @@ def test__given_unexpected_http_error__then_raises_exception( with pytest.raises(httpx.HTTPStatusError): api.get_execution_by_id(MOCK_MODAL_JOB_ID) + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") + api = SimulationAPIModal() + + with pytest.raises(httpx.RequestError): + api.get_execution_by_id(MOCK_MODAL_JOB_ID) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert MOCK_MODAL_JOB_ID in log_payload["message"] + + class TestGetBudgetWindowBatchById: + def test__given_running_batch__then_returns_running_status( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_POLL_RESPONSE_RUNNING, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_RUNNING + assert execution.completed_years == ["2026"] + assert execution.running_years == ["2027"] + assert execution.queued_years == ["2028"] + + def test__given_complete_batch__then_returns_result( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_BATCH_POLL_RESPONSE_COMPLETE, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE + assert execution.result == MOCK_BATCH_POLL_RESPONSE_COMPLETE["result"] + + def test__given_failed_batch__then_returns_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=500, + json_data=MOCK_BATCH_POLL_RESPONSE_FAILED, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_FAILED + assert execution.failed_years == ["2027"] + assert execution.error == "Budget window failed" + + def test__given_unexpected_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=404, + json_data={"detail": "Budget-window job not found"}, + ) + api = SimulationAPIModal() + + with pytest.raises(httpx.HTTPStatusError): + api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") + api = SimulationAPIModal() + + with pytest.raises(httpx.RequestError): + api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert MOCK_BATCH_JOB_ID in log_payload["message"] + class TestGetExecutionId: def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given diff --git a/tests/unit/services/test_budget_window_cache.py b/tests/unit/services/test_budget_window_cache.py new file mode 100644 index 000000000..f4677fae7 --- /dev/null +++ b/tests/unit/services/test_budget_window_cache.py @@ -0,0 +1,244 @@ +from unittest.mock import MagicMock + +import pytest + +from policyengine_api.services.budget_window_cache import BudgetWindowCache + + +class FakeRedis: + def __init__(self): + self.values = {} + + def get(self, key): + return self.values.get(key) + + def set(self, key, value, nx=False, ex=None): + if nx and key in self.values: + return False + self.values[key] = value + return True + + def delete(self, key): + self.values.pop(key, None) + + +class RaisingRedis: + def __init__(self, *, method): + self.method = method + + def get(self, key): + if self.method == "get": + raise RuntimeError("redis unavailable") + return None + + def set(self, key, value, nx=False, ex=None): + if self.method == "set": + raise RuntimeError("redis unavailable") + return True + + def delete(self, key): + if self.method == "delete": + raise RuntimeError("redis unavailable") + + +def test_build_key_is_stable_for_request_identity(): + cache = BudgetWindowCache(client=FakeRedis()) + + first = cache.build_key( + country_id="us", + reform_policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="budget_window:2026:10", + options_hash="[option=value]", + api_version="e1cache01", + ) + second = cache.build_key( + country_id="us", + reform_policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="budget_window:2026:10", + options_hash="[option=value]", + api_version="e1cache01", + ) + + assert first == second + assert first.startswith("budget_window:v1:us:") + + +def test_claim_batch_start_allows_one_starter(): + cache = BudgetWindowCache(client=FakeRedis()) + + assert cache.claim_batch_start("budget_window:v1:us:key", "process-1") is True + assert cache.claim_batch_start("budget_window:v1:us:key", "process-2") is False + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + +def test_store_batch_job_id_replaces_starting_claim(): + cache = BudgetWindowCache(client=FakeRedis()) + cache.claim_batch_start("budget_window:v1:us:key", "process-1") + + cache.store_batch_job_id("budget_window:v1:us:key", "fc-parent") + + assert cache.get_batch_job_id("budget_window:v1:us:key") == "fc-parent" + + +def test_completed_result_round_trips(): + cache = BudgetWindowCache(client=FakeRedis()) + result = {"kind": "budgetWindow", "totals": {"budgetaryImpact": 10}} + + cache.set_completed_result("budget_window:v1:us:key", result) + + assert cache.get_completed_result("budget_window:v1:us:key") == result + + +def test_get_completed_result_returns_none_for_empty_payload(): + redis_client = FakeRedis() + redis_client.values["budget_window:v1:us:key:result"] = "" + cache = BudgetWindowCache(client=redis_client) + + assert cache.get_completed_result("budget_window:v1:us:key") is None + + +def test_get_completed_result_returns_none_for_invalid_json(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + redis_client = FakeRedis() + redis_client.values["budget_window:v1:us:key:result"] = "{not-json" + cache = BudgetWindowCache(client=redis_client) + + assert cache.get_completed_result("budget_window:v1:us:key") is None + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_get_completed_result_reraises_read_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="get")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.get_completed_result("budget_window:v1:us:key") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_set_completed_result_reraises_write_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="set")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.set_completed_result("budget_window:v1:us:key", {"ok": True}) + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_get_batch_job_id_ignores_empty_non_string_and_starting_values(): + redis_client = FakeRedis() + cache = BudgetWindowCache(client=redis_client) + + redis_client.values["budget_window:v1:us:key:batch_job_id"] = "" + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + redis_client.values["budget_window:v1:us:key:batch_job_id"] = 123 + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + redis_client.values["budget_window:v1:us:key:batch_job_id"] = "starting:process-1" + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + +def test_get_batch_job_id_reraises_read_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="get")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.get_batch_job_id("budget_window:v1:us:key") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_claim_batch_start_reraises_claim_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="set")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.claim_batch_start("budget_window:v1:us:key", "process-1") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_store_batch_job_id_reraises_write_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="set")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.store_batch_job_id("budget_window:v1:us:key", "fc-parent") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_clear_starting_claim_deletes_only_matching_token(): + redis_client = FakeRedis() + cache = BudgetWindowCache(client=redis_client) + cache.claim_batch_start("budget_window:v1:us:key", "process-1") + + cache.clear_starting_claim("budget_window:v1:us:key", "process-2") + + assert ( + redis_client.values["budget_window:v1:us:key:batch_job_id"] + == "starting:process-1" + ) + + cache.clear_starting_claim("budget_window:v1:us:key", "process-1") + + assert "budget_window:v1:us:key:batch_job_id" not in redis_client.values + + +def test_clear_starting_claim_logs_and_swallows_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="get")) + + cache.clear_starting_claim("budget_window:v1:us:key", "process-1") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_clear_batch_job_id_logs_and_swallows_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="delete")) + + cache.clear_batch_job_id("budget_window:v1:us:key") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d036ab296..5b1da4405 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,9 +1,65 @@ import json +import sys import pytest from unittest.mock import patch, MagicMock from typing import Literal +from types import ModuleType + +try: + from policyengine.simulation import SimulationOptions # noqa: F401 +except ModuleNotFoundError: + policyengine_module = sys.modules.setdefault( + "policyengine", ModuleType("policyengine") + ) + simulation_module = ModuleType("policyengine.simulation") + utils_module = ModuleType("policyengine.utils") + data_module = ModuleType("policyengine.utils.data") + datasets_module = ModuleType("policyengine.utils.data.datasets") + + class _StubSimulationOptions: + def __init__(self, payload): + self._payload = payload + + @classmethod + def model_validate(cls, payload): + return cls(payload) + + def model_dump(self): + return dict(self._payload) + + simulation_module.SimulationOptions = _StubSimulationOptions + policyengine_module.simulation = simulation_module + + def _stub_get_default_dataset(country, region): + if country == "us": + if region == "us": + return "gs://policyengine-us-data/enhanced_cps_2024.h5" + if region == "state/ca": + return "gs://policyengine-us-data/states/CA.h5" + if region == "state/ut": + return "gs://policyengine-us-data/states/UT.h5" + if region == "place/NJ-57000": + return "gs://policyengine-us-data/states/NJ.h5" + if region == "congressional_district/CA-37": + return "gs://policyengine-us-data/districts/CA-37.h5" + if country == "uk" and region == "uk": + return "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" + raise ValueError( + f"Error getting default dataset for country={country}, region={region}: unsupported in test stub" + ) + + datasets_module.get_default_dataset = _stub_get_default_dataset + data_module.datasets = datasets_module + utils_module.data = data_module + policyengine_module.utils = utils_module + sys.modules["policyengine.simulation"] = simulation_module + sys.modules["policyengine.utils"] = utils_module + sys.modules["policyengine.utils.data"] = data_module + sys.modules["policyengine.utils.data.datasets"] = datasets_module from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_END_YEAR, + BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, EconomicImpactSetupOptions, @@ -30,12 +86,30 @@ MOCK_REFORM_IMPACT_DATA, MOCK_RESOLVED_DATASET, MOCK_RESOLVED_APP_NAME, + create_mock_budget_window_batch_execution, create_mock_reform_impact, ) pytest_plugins = ("tests.fixtures.services.economy_service",) +def make_mock_budget_impact_data( + *, + tax_revenue_impact: int, + state_tax_revenue_impact: int, + benefit_spending_impact: int, + budgetary_impact: int, +): + return { + "budget": { + "tax_revenue_impact": tax_revenue_impact, + "state_tax_revenue_impact": state_tax_revenue_impact, + "benefit_spending_impact": benefit_spending_impact, + "budgetary_impact": budgetary_impact, + } + } + + class TestEconomyService: class TestGetEconomicImpact: @pytest.fixture @@ -617,6 +691,489 @@ def test__given_uk_request__preserves_model_version_in_bundle( sim_params = mock_simulation_api.run.call_args[0][0] assert sim_params["_metadata"]["model_version"] == "2.7.8" + class TestGetBudgetWindowEconomicImpact: + @pytest.fixture + def economy_service( + self, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + return EconomyService() + + @pytest.fixture + def base_params(self): + return { + "country_id": MOCK_COUNTRY_ID, + "policy_id": MOCK_POLICY_ID, + "baseline_policy_id": MOCK_BASELINE_POLICY_ID, + "region": MOCK_REGION, + "dataset": MOCK_DATASET, + "start_year": "2026", + "window_size": 3, + "options": MOCK_OPTIONS, + "api_version": MOCK_API_VERSION, + "target": "general", + } + + def test__given_no_cached_batch__submits_parent_batch_and_returns_queued_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + mock_budget_window_cache, + ): + batch_execution = create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="submitted", + ) + mock_simulation_api.run_budget_window_batch.return_value = batch_execution + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 0 + assert result.completed_years == [] + assert result.computing_years == [] + assert result.queued_years == ["2026", "2027", "2028"] + assert result.cache_status == "miss" + assert "Queued 2026" in result.message + mock_simulation_api.run_budget_window_batch.assert_called_once() + submitted_payload = ( + mock_simulation_api.run_budget_window_batch.call_args.args[0] + ) + assert submitted_payload["start_year"] == "2026" + assert submitted_payload["window_size"] == 3 + assert submitted_payload["max_parallel"] == 20 + assert submitted_payload["target"] == "general" + assert "time_period" not in submitted_payload + mock_budget_window_cache.claim_batch_start.assert_called_once_with( + "budget-window-cache-key", MOCK_PROCESS_ID + ) + mock_budget_window_cache.store_batch_job_id.assert_called_once_with( + "budget-window-cache-key", "fc-budget-123" + ) + mock_reform_impacts_service.set_reform_impact.assert_not_called() + + def test__given_completed_cached_result__returns_completed_batch_result( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + } + ], + "totals": { + "year": "Total", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + }, + } + mock_budget_window_cache.get_completed_result.return_value = ( + completed_result + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data == completed_result + assert result.cache_status == "result-hit" + mock_simulation_api.get_budget_window_batch_by_id.assert_not_called() + mock_simulation_api.run_budget_window_batch.assert_not_called() + + def test__given_cached_batch_id__returns_running_batch_progress( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="running", + progress=33, + completed_years=["2026"], + running_years=["2027"], + queued_years=["2028"], + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 33 + assert result.completed_years == ["2026"] + assert result.computing_years == ["2027"] + assert result.queued_years == ["2028"] + assert result.cache_status == "batch-id-hit" + assert "1 of 3 complete" in result.message + mock_simulation_api.get_budget_window_batch_by_id.assert_called_once_with( + "fc-budget-123" + ) + + def test__given_completed_batch_poll__caches_result_and_returns_completed( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + } + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027", "2028"], + result=completed_result, + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.data == completed_result + assert result.cache_status == "batch-id-hit" + mock_budget_window_cache.set_completed_result.assert_called_once_with( + "budget-window-cache-key", completed_result + ) + mock_budget_window_cache.clear_batch_job_id.assert_called_once_with( + "budget-window-cache-key" + ) + + @pytest.mark.parametrize("malformed_result", [None, {}, []]) + def test__given_completed_batch_without_result__returns_error_without_caching( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + malformed_result, + ): + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027"], + running_years=[], + queued_years=["2028"], + result=malformed_result, + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Budget-window batch completed without a result" + assert result.completed_years == ["2026", "2027"] + assert result.computing_years == [] + assert result.queued_years == ["2028"] + assert result.cache_status == "batch-id-hit" + mock_budget_window_cache.set_completed_result.assert_not_called() + mock_budget_window_cache.clear_batch_job_id.assert_called_once_with( + "budget-window-cache-key" + ) + + def test__given_completed_batch_cache_write_fails__does_not_clear_batch_id( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + } + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_budget_window_cache.set_completed_result.side_effect = RuntimeError( + "redis unavailable" + ) + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027", "2028"], + result=completed_result, + ) + ) + + with pytest.raises(RuntimeError, match="redis unavailable"): + economy_service.get_budget_window_economic_impact(**base_params) + + mock_budget_window_cache.clear_batch_job_id.assert_not_called() + + def test__given_failed_batch_poll__returns_failed( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="failed", + progress=33, + completed_years=["2026"], + queued_years=["2028"], + failed_years=["2027"], + error="Budget window failed for 2027", + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Budget window failed for 2027" + assert result.completed_years == ["2026"] + assert result.computing_years == [] + assert result.queued_years == ["2028"] + assert result.cache_status == "batch-id-hit" + mock_budget_window_cache.set_completed_result.assert_not_called() + mock_budget_window_cache.clear_batch_job_id.assert_called_once_with( + "budget-window-cache-key" + ) + + def test__given_existing_start_claim__does_not_submit_duplicate_batch( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.claim_batch_start.return_value = False + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 0 + assert result.queued_years == ["2026", "2027", "2028"] + assert result.cache_status == "starting-claim-hit" + mock_simulation_api.run_budget_window_batch.assert_not_called() + + def test__given_batch_submission_fails__clears_start_claim( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_simulation_api.run_budget_window_batch.side_effect = RuntimeError( + "submit failed" + ) + + with pytest.raises(RuntimeError, match="submit failed"): + economy_service.get_budget_window_economic_impact(**base_params) + + mock_budget_window_cache.clear_starting_claim.assert_called_once_with( + "budget-window-cache-key", MOCK_PROCESS_ID + ) + + def test__given_cliff_target__raises_value_error( + self, economy_service, base_params + ): + base_params["target"] = "cliff" + + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_oversized_window__raises_value_error( + self, economy_service, base_params + ): + base_params["window_size"] = BUDGET_WINDOW_MAX_YEARS + 1 + + with pytest.raises( + ValueError, + match=(f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}"), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_end_year_after_2099__raises_value_error( + self, economy_service, base_params + ): + base_params["start_year"] = "2090" + base_params["window_size"] = 20 + + with pytest.raises( + ValueError, + match=( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" + ), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_simulation_api, + mock_budget_window_cache, + mock_logger, + mock_datetime, + mock_numpy_random, + monkeypatch, + ): + cache_version = "e1cache01" + + monkeypatch.setattr( + "policyengine_api.services.economy_service.get_economy_impact_cache_version", + lambda country_id, api_version=None: cache_version, + ) + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + cache_key_kwargs = mock_budget_window_cache.build_key.call_args.kwargs + assert cache_key_kwargs["time_period"] == "budget_window:2026:3" + assert cache_key_kwargs["api_version"] == cache_version + + def test__given_reordered_options__uses_same_budget_window_cache_identity( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.get_completed_result.return_value = { + "kind": "budgetWindow" + } + + economy_service.get_budget_window_economic_impact( + **{ + **base_params, + "options": {"staging_probe": "abc", "analysis_mode": "x"}, + } + ) + first_cache_key_kwargs = dict( + mock_budget_window_cache.build_key.call_args.kwargs + ) + mock_budget_window_cache.build_key.reset_mock() + + economy_service.get_budget_window_economic_impact( + **{ + **base_params, + "options": {"analysis_mode": "x", "staging_probe": "abc"}, + } + ) + second_cache_key_kwargs = dict( + mock_budget_window_cache.build_key.call_args.kwargs + ) + + assert first_cache_key_kwargs == second_cache_key_kwargs + mock_simulation_api.run_budget_window_batch.assert_not_called() + + def test__given_legacy_us_region__normalizes_before_building_cache_key( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + mock_budget_window_cache, + ): + base_params["region"] = "ca" + base_params["dataset"] = "default" + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_budget_window_cache.build_key.call_args.kwargs["region"] == ( + "state/ca" + ) + + def test__given_unexpected_batch_status__raises_value_error( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="paused", + ) + ) + + with pytest.raises( + ValueError, + match="Unexpected budget-window batch execution state: paused", + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_missing_progress__computes_progress_from_completed_years( + self, economy_service + ): + result = economy_service._build_budget_window_computing_result( + total_years=4, + completed_years=["2026"], + computing_years=["2027", "2028", "2029"], + queued_years=[], + progress=None, + ) + + assert result.progress == 25 + assert result.message == "Scoring 2027, 2028 + 1 more (1 of 4 complete)..." + + def test__given_no_active_or_queued_years__uses_generic_progress_message( + self, economy_service + ): + result = economy_service._build_budget_window_computing_result( + total_years=3, + completed_years=["2026"], + computing_years=[], + queued_years=[], + progress=None, + ) + + assert result.progress == 33 + assert result.message == "Scoring budget window (1 of 3 complete)..." + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -1478,3 +2035,21 @@ def test__given_nonexistent_district__raises_value_error(self): with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/cruft") assert "Invalid congressional district: 'cruft'" in str(exc_info.value) + + +class TestBuildOptionsHash: + def test__given_reordered_options__returns_same_hash(self): + service = EconomyService() + + first = service._build_options_hash( + options={"staging_probe": "abc", "analysis_mode": "x"}, + model_version="1.2.3", + dataset="hf://policyengine/test.h5@1.0", + ) + second = service._build_options_hash( + options={"analysis_mode": "x", "staging_probe": "abc"}, + model_version="1.2.3", + dataset="hf://policyengine/test.h5@1.0", + ) + + assert first == second diff --git a/tests/unit/utils/test_budget_window.py b/tests/unit/utils/test_budget_window.py new file mode 100644 index 000000000..b74d44092 --- /dev/null +++ b/tests/unit/utils/test_budget_window.py @@ -0,0 +1,73 @@ +import pytest + +from policyengine_api.utils.budget_window import ( + BUDGET_WINDOW_MAX_END_YEAR, + BUDGET_WINDOW_MAX_YEARS, + build_budget_window_request_setup, + build_budget_window_time_period, + build_budget_window_years, +) + + +def test_build_budget_window_years(): + assert build_budget_window_years(start_year="2026", window_size=3) == [ + "2026", + "2027", + "2028", + ] + + +def test_build_budget_window_time_period(): + assert ( + build_budget_window_time_period(start_year="2026", window_size=3) + == "budget_window:2026:3" + ) + + +def test_build_budget_window_request_setup_normalizes_start_year(): + setup = build_budget_window_request_setup( + start_year="02026", + window_size=2, + target="general", + ) + + assert setup.start_year == "2026" + assert setup.window_size == 2 + assert setup.years == ["2026", "2027"] + assert setup.time_period == "budget_window:2026:2" + + +def test_build_budget_window_request_setup_rejects_cliff_target(): + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + build_budget_window_request_setup( + start_year="2026", + window_size=2, + target="cliff", + ) + + +def test_build_budget_window_request_setup_rejects_oversized_window(): + with pytest.raises( + ValueError, + match=f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}", + ): + build_budget_window_request_setup( + start_year="2026", + window_size=BUDGET_WINDOW_MAX_YEARS + 1, + target="general", + ) + + +def test_build_budget_window_request_setup_rejects_end_year_after_max(): + with pytest.raises( + ValueError, + match=f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier", + ): + build_budget_window_request_setup( + start_year="2090", + window_size=20, + target="general", + )