diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index cc220ef17..8dd06af74 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -6,7 +6,7 @@ """ import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional import httpx @@ -31,6 +31,28 @@ def name(self) -> str: return self.job_id +@dataclass +class ModalBudgetWindowBatchExecution: + """ + Represents a budget-window batch execution in the Modal simulation API. + """ + + batch_job_id: str + status: str + progress: Optional[int] = None + completed_years: list[str] = field(default_factory=list) + running_years: list[str] = field(default_factory=list) + queued_years: list[str] = field(default_factory=list) + failed_years: list[str] = field(default_factory=list) + result: Optional[dict] = None + error: Optional[str] = None + + @property + def name(self) -> str: + """Alias for batch_job_id.""" + return self.batch_job_id + + class SimulationAPIModal: """ HTTP client for the Modal Simulation API. @@ -106,10 +128,51 @@ def run(self, payload: dict) -> ModalSimulationExecution: ) raise + def run_budget_window_batch(self, payload: dict) -> ModalBudgetWindowBatchExecution: + """ + Submit a budget-window batch job to the Modal API. + """ + try: + modal_payload = dict(payload) + if "model_version" in modal_payload: + modal_payload["version"] = modal_payload.pop("model_version") + modal_payload.pop("data_version", None) + + response = self.client.post( + f"{self.base_url}/simulate/economy/budget-window", + json=modal_payload, + ) + response.raise_for_status() + data = response.json() + + logger.log_struct( + { + "message": "Modal budget-window batch submitted", + "batch_job_id": data.get("batch_job_id"), + "status": data.get("status"), + }, + severity="INFO", + ) + + return ModalBudgetWindowBatchExecution( + batch_job_id=data["batch_job_id"], + status=data["status"], + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + except httpx.RequestError as e: logger.log_struct( { - "message": f"Modal API request error: {str(e)}", + "message": f"Modal batch API request error: {str(e)}", }, severity="ERROR", ) @@ -168,10 +231,44 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: ) raise + def get_budget_window_batch_by_id( + self, batch_job_id: str + ) -> ModalBudgetWindowBatchExecution: + """ + Poll the Modal API for the current status of a budget-window batch. + """ + try: + response = self.client.get( + f"{self.base_url}/budget-window-jobs/{batch_job_id}" + ) + data = response.json() + + return ModalBudgetWindowBatchExecution( + batch_job_id=batch_job_id, + status=data["status"], + progress=data.get("progress"), + completed_years=data.get("completed_years", []), + running_years=data.get("running_years", []), + queued_years=data.get("queued_years", []), + failed_years=data.get("failed_years", []), + result=data.get("result"), + error=data.get("error"), + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error polling job {batch_job_id}: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + except httpx.RequestError as e: logger.log_struct( { - "message": f"Modal API request error polling job {job_id}: {str(e)}", + "message": f"Modal batch API request error polling job {batch_job_id}: {str(e)}", }, severity="ERROR", ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 165e3f03d..7830a62e8 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -27,7 +27,6 @@ from pydantic import BaseModel, Field import numpy as np from enum import Enum -from concurrent.futures import ThreadPoolExecutor load_dotenv() @@ -59,8 +58,9 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value -BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 -BUDGET_WINDOW_MAX_YEARS = 20 +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 +BUDGET_WINDOW_MAX_YEARS = 75 +BUDGET_WINDOW_MAX_END_YEAR = 2099 PENDING_EXECUTION_ID_PREFIX = "pending:" PROVISIONAL_CLAIM_TTL_SECONDS = 90 STALE_PROVISIONAL_IMPACT_MESSAGE = ( @@ -282,142 +282,265 @@ def get_budget_window_economic_impact( raise ValueError( f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" ) - - years = [str(start_year_int + index) for index in range(window_size)] - setup_options_by_year = { - year: self._build_economic_impact_setup_options( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=year, - options=dict(options), - api_version=api_version, - target=target, + end_year = start_year_int + window_size - 1 + if end_year > BUDGET_WINDOW_MAX_END_YEAR: + raise ValueError( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" ) - for year in years - } - completed_impacts: dict[str, dict] = {} - computing_years: list[str] = [] - queued_years: list[str] = [] + start_year = str(start_year_int) + years = self._build_budget_window_years( + start_year=start_year, + window_size=window_size, + ) + tracking_setup_options = self._build_budget_window_tracking_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) - for year in years: - result = self._get_existing_economic_impact( - setup_options=setup_options_by_year[year] + most_recent_impact = self._get_most_recent_impact(tracking_setup_options) + if most_recent_impact is None: + self._start_budget_window_batch( + setup_options=tracking_setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_active_years, + ) + return self._build_budget_window_computing_result( + total_years=len(years), + completed_years=[], + computing_years=[], + queued_years=years, + progress=0, ) - if result is None: - queued_years.append(year) - continue + return self._get_budget_window_result_from_tracking_impact( + setup_options=tracking_setup_options, + most_recent_impact=most_recent_impact, + total_years=len(years), + queued_years_on_submit=years, + ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e - if result.status == ImpactStatus.OK: - completed_impacts[year] = self._extract_budget_window_annual_impact( - year=year, impact_data=result.data or {} - ) - continue + def _build_budget_window_years( + self, + *, + start_year: str, + window_size: int, + ) -> list[str]: + start_year_int = int(start_year) + return [str(start_year_int + index) for index in range(window_size)] + + def _build_budget_window_tracking_time_period( + self, + *, + start_year: str, + window_size: int, + ) -> str: + return f"budget_window:{start_year}:{window_size}" + + def _build_budget_window_tracking_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"], + ) -> EconomicImpactSetupOptions: + return self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=self._build_budget_window_tracking_time_period( + start_year=start_year, + window_size=window_size, + ), + options=dict(options), + api_version=api_version, + target=target, + ) + + def _build_budget_window_batch_payload( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> dict[str, Any]: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.baseline_policy_id, + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.reform_policy_id, + ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=start_year, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=False, + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) + sim_params = sim_config.model_dump() + sim_params.pop("time_period", None) + sim_params["start_year"] = start_year + sim_params["window_size"] = window_size + sim_params["max_parallel"] = max_parallel + sim_params["target"] = setup_options.target + return sim_params + + def _start_budget_window_batch( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> None: + sim_params = self._build_budget_window_batch_payload( + setup_options=setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_parallel, + ) + + logger.log_struct( + { + "message": "Submitting budget-window batch job", + **setup_options.model_dump(), + "start_year": start_year, + "window_size": window_size, + "max_parallel": max_parallel, + }, + severity="INFO", + ) - if result.status == ImpactStatus.COMPUTING: - computing_years.append(year) - continue + batch_execution = simulation_api.run_budget_window_batch(sim_params) + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=batch_execution.batch_job_id, + ) - completed_years = [ - completed_year - for completed_year in years - if completed_year in completed_impacts - ] + def _get_budget_window_result_from_tracking_impact( + self, + *, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + total_years: int, + queued_years_on_submit: list[str], + ) -> BudgetWindowEconomicImpactResult: + impact_status = most_recent_impact.get("status") + if impact_status == ImpactStatus.OK.value: + return BudgetWindowEconomicImpactResult.completed( + json.loads(most_recent_impact["reform_impact_json"]) + ) + + execution_id = most_recent_impact.get("execution_id") + if not execution_id: + return BudgetWindowEconomicImpactResult.failed( + most_recent_impact.get("message") + or "Budget-window batch tracking row is missing execution_id", + queued_years=queued_years_on_submit, + ) + + try: + batch_execution = simulation_api.get_budget_window_batch_by_id(execution_id) + except Exception: + if impact_status == ImpactStatus.ERROR.value: return BudgetWindowEconomicImpactResult.failed( - self._get_economic_impact_error_message( - result=result, - year=year, - ), - completed_years=completed_years, - computing_years=computing_years, - queued_years=queued_years, + most_recent_impact.get("message") or "Budget-window batch failed", + queued_years=queued_years_on_submit, ) + raise - available_slots = max(0, max_active_years - len(computing_years)) - years_to_start = queued_years[:available_slots] - remaining_queued_years = queued_years[available_slots:] - - if years_to_start: - max_workers = min(len(years_to_start), max_active_years) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_year_pairs = [ - ( - year, - executor.submit( - self._get_or_create_economic_impact, - setup_options_by_year[year], - ), - ) - for year in years_to_start - ] - - for year, future in future_year_pairs: - result = future.result() - - if result.status == ImpactStatus.OK: - completed_impacts[year] = ( - self._extract_budget_window_annual_impact( - year=year, impact_data=result.data or {} - ) - ) - elif result.status == ImpactStatus.COMPUTING: - computing_years.append(year) - else: - completed_years = [ - completed_year - for completed_year in years - if completed_year in completed_impacts - ] - return BudgetWindowEconomicImpactResult.failed( - self._get_economic_impact_error_message( - result=result, - year=year, - ), - completed_years=completed_years, - computing_years=computing_years, - queued_years=remaining_queued_years, - ) - - completed_years = [ - completed_year - for completed_year in years - if completed_year in completed_impacts - ] + if batch_execution.status in EXECUTION_STATUSES_SUCCESS: + result = batch_execution.result or {} + self._set_reform_impact_complete( + setup_options=setup_options, + reform_impact_json=json.dumps(result), + execution_id=execution_id, + ) + return BudgetWindowEconomicImpactResult.completed(result) - if len(completed_years) == len(years): - ordered_annual_impacts = [ - completed_impacts[year] - for year in years - if year in completed_impacts - ] - return BudgetWindowEconomicImpactResult.completed( - self._build_budget_window_output( - start_year=start_year, - window_size=window_size, - annual_impacts=ordered_annual_impacts, - ) - ) + if batch_execution.status in EXECUTION_STATUSES_FAILURE: + error_message = batch_execution.error or ( + most_recent_impact.get("message") or "Budget-window batch failed" + ) + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=execution_id, + ) + return BudgetWindowEconomicImpactResult.failed( + error_message, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + ) + + if batch_execution.status in EXECUTION_STATUSES_PENDING: + return self._build_budget_window_computing_result( + total_years=total_years, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + progress=batch_execution.progress, + ) + + raise ValueError( + f"Unexpected budget-window batch execution state: {batch_execution.status}" + ) - progress = round((len(completed_years) / len(years)) * 100) - return BudgetWindowEconomicImpactResult.computing( - progress=progress, + def _build_budget_window_computing_result( + self, + *, + total_years: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + progress: Optional[int] = None, + ) -> BudgetWindowEconomicImpactResult: + resolved_progress = progress + if resolved_progress is None: + resolved_progress = round((len(completed_years) / total_years) * 100) + + return BudgetWindowEconomicImpactResult.computing( + progress=resolved_progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=self._build_budget_window_progress_message( completed_years=completed_years, + total_years=total_years, computing_years=computing_years, - queued_years=remaining_queued_years, - message=self._build_budget_window_progress_message( - completed_years=completed_years, - total_years=len(years), - computing_years=computing_years, - queued_years=remaining_queued_years, - ), - ) - except Exception as e: - print(f"Error getting budget-window economic impact: {str(e)}") - raise e + queued_years=queued_years, + ), + ) def _build_economic_impact_setup_options( self, diff --git a/pyproject.toml b/pyproject.toml index ab926a5cd..a3cf5ebcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,10 +38,10 @@ dependencies = [ "policyengine_canada==0.96.3", "policyengine-ng==0.5.1", "policyengine-il==0.1.0", - "policyengine_uk==2.39.0", - "policyengine_us==1.633.1", - "policyengine_core>=3.16.6", - "policyengine>=0.7.0", + "policyengine_uk==2.88.0", + "policyengine_us==1.653.3", + "policyengine_core>=3.23.5", + "policyengine>0.12.0,<1", "pydantic", "pymysql", "python-dotenv", diff --git a/tests/fixtures/libs/simulation_api_modal.py b/tests/fixtures/libs/simulation_api_modal.py index 64ce139e7..9abfe4050 100644 --- a/tests/fixtures/libs/simulation_api_modal.py +++ b/tests/fixtures/libs/simulation_api_modal.py @@ -18,6 +18,8 @@ # Mock data constants MOCK_MODAL_JOB_ID = "fc-abc123xyz" +MOCK_RUN_ID = "run-abc123xyz" +MOCK_BATCH_JOB_ID = "fc-batch123xyz" MOCK_MODAL_BASE_URL = "https://test-modal-api.modal.run" MOCK_SIMULATION_PAYLOAD = { @@ -65,6 +67,54 @@ MOCK_HEALTH_RESPONSE = {"status": "healthy"} +MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS = { + "batch_job_id": MOCK_BATCH_JOB_ID, + "status": MODAL_EXECUTION_STATUS_SUBMITTED, + "poll_url": f"/budget-window-jobs/{MOCK_BATCH_JOB_ID}", + "country": "us", + "version": "1.459.0", +} + +MOCK_BATCH_POLL_RESPONSE_RUNNING = { + "status": MODAL_EXECUTION_STATUS_RUNNING, + "progress": 33, + "completed_years": ["2026"], + "running_years": ["2027"], + "queued_years": ["2028"], + "failed_years": [], + "result": None, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_COMPLETE = { + "status": MODAL_EXECUTION_STATUS_COMPLETE, + "progress": 100, + "completed_years": ["2026", "2027", "2028"], + "running_years": [], + "queued_years": [], + "failed_years": [], + "result": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + }, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_FAILED = { + "status": MODAL_EXECUTION_STATUS_FAILED, + "progress": 33, + "completed_years": ["2026"], + "running_years": [], + "queued_years": ["2028"], + "failed_years": ["2027"], + "result": None, + "error": "Budget window failed", +} + def create_mock_httpx_response( status_code: int = 200, diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 14c566772..eac45bf27 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -109,6 +109,7 @@ def mock_simulation_api(): """Mock SimulationAPIModal with all required methods.""" mock_api = MagicMock() mock_execution = create_mock_modal_execution() + mock_batch_execution = create_mock_budget_window_batch_execution() mock_api._setup_sim_options.return_value = MOCK_SIM_CONFIG mock_api.run.return_value = mock_execution @@ -116,6 +117,8 @@ def mock_simulation_api(): mock_api.get_execution_by_id.return_value = mock_execution mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING mock_api.get_execution_result.return_value = MOCK_REFORM_IMPACT_DATA + mock_api.run_budget_window_batch.return_value = mock_batch_execution + mock_api.get_budget_window_batch_by_id.return_value = mock_batch_execution with patch( "policyengine_api.services.economy_service.simulation_api", mock_api @@ -154,6 +157,8 @@ def create_mock_reform_impact( reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID, start_time=None, + time_period=MOCK_TIME_PERIOD, + message=None, ): """Helper function to create mock reform impact records.""" return { @@ -163,12 +168,13 @@ def create_mock_reform_impact( "baseline_policy_id": MOCK_BASELINE_POLICY_ID, "region": MOCK_REGION, "dataset": MOCK_DATASET, - "time_period": MOCK_TIME_PERIOD, + "time_period": time_period, "options_hash": MOCK_OPTIONS_HASH, "status": status, "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), "execution_id": execution_id, + "message": message, "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None @@ -210,6 +216,32 @@ def create_mock_modal_execution( return mock_execution +def create_mock_budget_window_batch_execution( + batch_job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + progress=None, + completed_years=None, + running_years=None, + queued_years=None, + failed_years=None, + result=None, + error=None, +): + """Helper function to create mock batch execution objects.""" + mock_execution = MagicMock() + mock_execution.batch_job_id = batch_job_id + mock_execution.name = batch_job_id + mock_execution.status = status + mock_execution.progress = progress + mock_execution.completed_years = completed_years or [] + mock_execution.running_years = running_years or [] + mock_execution.queued_years = queued_years or [] + mock_execution.failed_years = failed_years or [] + mock_execution.result = result + mock_execution.error = error + return mock_execution + + @pytest.fixture def mock_simulation_api_modal(): """Mock SimulationAPIModal with all required methods.""" diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index fca938948..f185f8c28 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -71,6 +71,19 @@ def test_budget_window_route_rejects_oversized_window(rest_client): assert "window_size must be between 1 and" in data["message"] +def test_budget_window_route_rejects_end_year_after_2099(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2090&window_size=20" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "budget-window end_year must be 2099 or earlier" == data["message"] + + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index d44dde8cb..657c0d914 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -10,6 +10,7 @@ import httpx from policyengine_api.libs.simulation_api_modal import ( + ModalBudgetWindowBatchExecution, SimulationAPIModal, ModalSimulationExecution, ) @@ -21,6 +22,7 @@ ) from tests.fixtures.libs.simulation_api_modal import ( MOCK_MODAL_JOB_ID, + MOCK_BATCH_JOB_ID, MOCK_MODAL_BASE_URL, MOCK_SIMULATION_PAYLOAD, MOCK_SIMULATION_RESULT, @@ -29,6 +31,10 @@ MOCK_POLL_RESPONSE_COMPLETE, MOCK_POLL_RESPONSE_FAILED, MOCK_HEALTH_RESPONSE, + MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + MOCK_BATCH_POLL_RESPONSE_RUNNING, + MOCK_BATCH_POLL_RESPONSE_COMPLETE, + MOCK_BATCH_POLL_RESPONSE_FAILED, create_mock_httpx_response, mock_httpx_client, mock_modal_logger, @@ -86,6 +92,18 @@ def test__given_failed_execution__then_error_attribute_populated(self): assert execution.result is None +class TestModalBudgetWindowBatchExecution: + """Tests for the ModalBudgetWindowBatchExecution dataclass.""" + + def test__given_batch_job_id__then_name_returns_batch_job_id(self): + execution = ModalBudgetWindowBatchExecution( + batch_job_id=MOCK_BATCH_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + ) + + assert execution.name == MOCK_BATCH_JOB_ID + + class TestSimulationAPIModal: """Tests for the SimulationAPIModal class.""" @@ -187,6 +205,40 @@ def test__given_network_error__then_raises_exception( with pytest.raises(httpx.RequestError): api.run(MOCK_SIMULATION_PAYLOAD) + class TestRunBudgetWindowBatch: + def test__given_valid_payload__then_returns_batch_execution( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + ) + api = SimulationAPIModal() + + execution = api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED + call_args = mock_httpx_client.post.call_args + assert "/simulate/economy/budget-window" in call_args[0][0] + + def test__given_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_response = create_mock_httpx_response( + status_code=400, + json_data={"error": "Invalid request"}, + ) + mock_httpx_client.post.return_value = mock_response + api = SimulationAPIModal() + + with pytest.raises(httpx.HTTPStatusError): + api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + class TestGetExecutionById: def test__given_running_job__then_returns_running_status( self, @@ -265,6 +317,59 @@ def test__given_job_id__then_polls_correct_endpoint( call_args = mock_httpx_client.get.call_args assert f"/jobs/{MOCK_MODAL_JOB_ID}" in call_args[0][0] + class TestGetBudgetWindowBatchById: + def test__given_running_batch__then_returns_running_status( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_POLL_RESPONSE_RUNNING, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_RUNNING + assert execution.completed_years == ["2026"] + assert execution.running_years == ["2027"] + assert execution.queued_years == ["2028"] + + def test__given_complete_batch__then_returns_result( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_BATCH_POLL_RESPONSE_COMPLETE, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE + assert execution.result == MOCK_BATCH_POLL_RESPONSE_COMPLETE["result"] + + def test__given_failed_batch__then_returns_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=500, + json_data=MOCK_BATCH_POLL_RESPONSE_FAILED, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_FAILED + assert execution.failed_years == ["2027"] + assert execution.error == "Budget window failed" + class TestGetExecutionId: def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index dbe25ffdf..63d0e8863 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,10 +1,65 @@ import datetime import json +import sys import pytest from unittest.mock import patch, MagicMock from typing import Literal +from types import ModuleType + +try: + from policyengine.simulation import SimulationOptions # noqa: F401 +except ModuleNotFoundError: + policyengine_module = sys.modules.setdefault( + "policyengine", ModuleType("policyengine") + ) + simulation_module = ModuleType("policyengine.simulation") + utils_module = ModuleType("policyengine.utils") + data_module = ModuleType("policyengine.utils.data") + datasets_module = ModuleType("policyengine.utils.data.datasets") + + class _StubSimulationOptions: + def __init__(self, payload): + self._payload = payload + + @classmethod + def model_validate(cls, payload): + return cls(payload) + + def model_dump(self): + return dict(self._payload) + + simulation_module.SimulationOptions = _StubSimulationOptions + policyengine_module.simulation = simulation_module + + def _stub_get_default_dataset(country, region): + if country == "us": + if region == "us": + return "gs://policyengine-us-data/enhanced_cps_2024.h5" + if region == "state/ca": + return "gs://policyengine-us-data/states/CA.h5" + if region == "state/ut": + return "gs://policyengine-us-data/states/UT.h5" + if region == "place/NJ-57000": + return "gs://policyengine-us-data/states/NJ.h5" + if region == "congressional_district/CA-37": + return "gs://policyengine-us-data/districts/CA-37.h5" + if country == "uk" and region == "uk": + return "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" + raise ValueError( + f"Error getting default dataset for country={country}, region={region}: unsupported in test stub" + ) + + datasets_module.get_default_dataset = _stub_get_default_dataset + data_module.datasets = datasets_module + utils_module.data = data_module + policyengine_module.utils = utils_module + sys.modules["policyengine.simulation"] = simulation_module + sys.modules["policyengine.utils"] = utils_module + sys.modules["policyengine.utils.data"] = data_module + sys.modules["policyengine.utils.data.datasets"] = datasets_module from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_END_YEAR, BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, @@ -28,6 +83,7 @@ MOCK_EXECUTION_ID, MOCK_PROCESS_ID, MOCK_REFORM_IMPACT_DATA, + create_mock_budget_window_batch_execution, create_mock_reform_impact, mock_country_package_versions, mock_datetime, @@ -561,7 +617,15 @@ def test__given_exception__raises_error( class TestGetBudgetWindowEconomicImpact: @pytest.fixture - def economy_service(self): + def economy_service( + self, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_logger, + mock_datetime, + mock_numpy_random, + ): return EconomyService() @pytest.fixture @@ -579,237 +643,208 @@ def base_params(self): "target": "general", } - def test__given_all_years_completed__returns_aggregated_budget_window_result( - self, economy_service, base_params + def test__given_no_tracking_row__submits_parent_batch_and_returns_queued_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, ): - def make_setup(*, time_period, **_kwargs): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=time_period, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, - ) + batch_execution = create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="submitted", + ) + mock_simulation_api.run_budget_window_batch.return_value = batch_execution - yearly_results = { - "2026": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=100, - state_tax_revenue_impact=20, - benefit_spending_impact=-10, - budgetary_impact=90, - ) - ), - "2027": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=120, - state_tax_revenue_impact=30, - benefit_spending_impact=-20, - budgetary_impact=100, - ) - ), - "2028": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=140, - state_tax_revenue_impact=40, - benefit_spending_impact=-30, - budgetary_impact=110, - ) - ), - } + result = economy_service.get_budget_window_economic_impact(**base_params) - with ( - patch.object( - economy_service, - "_build_economic_impact_setup_options", - side_effect=make_setup, - ), - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=lambda setup_options: yearly_results[ - setup_options.time_period - ], - ) as mock_get_existing, - patch.object( - economy_service, "_get_or_create_economic_impact" - ) as mock_get_economic_impact, - ): - result = economy_service.get_budget_window_economic_impact( - **base_params - ) + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 0 + assert result.completed_years == [] + assert result.computing_years == [] + assert result.queued_years == ["2026", "2027", "2028"] + assert "Queued 2026" in result.message + mock_simulation_api.run_budget_window_batch.assert_called_once() + submitted_payload = ( + mock_simulation_api.run_budget_window_batch.call_args.args[0] + ) + assert submitted_payload["start_year"] == "2026" + assert submitted_payload["window_size"] == 3 + assert submitted_payload["max_parallel"] == 20 + assert submitted_payload["target"] == "general" + assert "time_period" not in submitted_payload + mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + == "fc-budget-123" + ) - assert result.status == ImpactStatus.OK - assert result.progress == 100 - assert result.data["annualImpacts"] == [ - { - "year": "2026", + def test__given_completed_tracking_row__returns_completed_batch_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + } + ], + "totals": { + "year": "Total", "taxRevenueImpact": 100, "federalTaxRevenueImpact": 80, "stateTaxRevenueImpact": 20, "benefitSpendingImpact": -10, "budgetaryImpact": 90, }, - { - "year": "2027", - "taxRevenueImpact": 120, - "federalTaxRevenueImpact": 90, - "stateTaxRevenueImpact": 30, - "benefitSpendingImpact": -20, - "budgetaryImpact": 100, - }, - { - "year": "2028", - "taxRevenueImpact": 140, - "federalTaxRevenueImpact": 100, - "stateTaxRevenueImpact": 40, - "benefitSpendingImpact": -30, - "budgetaryImpact": 110, - }, - ] - assert result.data["totals"] == { - "year": "Total", - "taxRevenueImpact": 360, - "federalTaxRevenueImpact": 270, - "stateTaxRevenueImpact": 90, - "benefitSpendingImpact": -60, - "budgetaryImpact": 300, } - assert mock_get_existing.call_count == 3 - mock_get_economic_impact.assert_not_called() - - def test__given_missing_years__starts_only_up_to_remaining_active_slots( - self, economy_service, base_params - ): - def make_setup(*, time_period, **_kwargs): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=time_period, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="ok", + execution_id="fc-budget-123", + reform_impact_json=json.dumps(completed_result), + time_period="budget_window:2026:3", ) + ] - base_params["window_size"] = 5 + result = economy_service.get_budget_window_economic_impact(**base_params) - existing_results = { - "2026": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=100, - state_tax_revenue_impact=20, - benefit_spending_impact=-10, - budgetary_impact=90, - ) - ), - "2027": EconomicImpactResult.computing(), - "2028": None, - "2029": None, - "2030": None, - } + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data == completed_result + mock_simulation_api.get_budget_window_batch_by_id.assert_not_called() - with ( - patch.object( - economy_service, - "_build_economic_impact_setup_options", - side_effect=make_setup, - ), - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=lambda setup_options: existing_results[ - setup_options.time_period - ], - ), - patch.object( - economy_service, - "_get_or_create_economic_impact", - return_value=EconomicImpactResult.computing(), - ) as mock_get_economic_impact, - ): - result = economy_service.get_budget_window_economic_impact( - **base_params + def test__given_running_tracking_row__returns_running_batch_progress( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="running", + progress=33, + completed_years=["2026"], + running_years=["2027"], + queued_years=["2028"], + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.COMPUTING - assert result.progress == 20 + assert result.progress == 33 assert result.completed_years == ["2026"] - assert result.computing_years == ["2027", "2028", "2029"] - assert result.queued_years == ["2030"] - assert "1 of 5 complete" in result.message - assert mock_get_economic_impact.call_count == 2 - started_years = sorted( - call.args[0].time_period - for call in mock_get_economic_impact.call_args_list - ) - assert started_years == ["2028", "2029"] - - def test__given_year_error__returns_budget_window_error( - self, economy_service, base_params, mock_logger + assert result.computing_years == ["2027"] + assert result.queued_years == ["2028"] + assert "1 of 3 complete" in result.message + + def test__given_completed_batch_poll__persists_result_and_returns_completed( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, ): - def make_setup(*, time_period, **_kwargs): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=time_period, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + } + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027", "2028"], + result=completed_result, + ) + ) - with ( - patch.object( - economy_service, - "_build_economic_impact_setup_options", - side_effect=make_setup, - ), - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=[ - EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=100, - state_tax_revenue_impact=20, - benefit_spending_impact=-10, - budgetary_impact=90, - ) - ), - EconomicImpactResult( - status=ImpactStatus.ERROR, - data={"message": "Calculation failed for 2027"}, - ), - None, - ], - ), - patch.object( - economy_service, "_get_or_create_economic_impact" - ) as mock_get_economic_impact, - ): - result = economy_service.get_budget_window_economic_impact( - **base_params + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.data == completed_result + mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() + call_kwargs = ( + mock_reform_impacts_service.set_complete_reform_impact.call_args.kwargs + ) + assert call_kwargs["execution_id"] == "fc-budget-123" + assert json.loads(call_kwargs["reform_impact_json"]) == completed_result + + def test__given_failed_batch_poll__persists_error_and_returns_failed( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="failed", + progress=33, + completed_years=["2026"], + queued_years=["2028"], + failed_years=["2027"], + error="Budget window failed for 2027", + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.ERROR - assert result.error == "Calculation failed for 2027" + assert result.error == "Budget window failed for 2027" assert result.completed_years == ["2026"] - mock_get_economic_impact.assert_not_called() + assert result.computing_years == [] + assert result.queued_years == ["2028"] + mock_reform_impacts_service.set_error_reform_impact.assert_called_once() + assert ( + mock_reform_impacts_service.set_error_reform_impact.call_args.kwargs[ + "execution_id" + ] + == "fc-budget-123" + ) def test__given_cliff_target__raises_value_error( self, economy_service, base_params @@ -833,32 +868,44 @@ def test__given_oversized_window__raises_value_error( ): economy_service.get_budget_window_economic_impact(**base_params) - def test__given_started_year_error__returns_specific_budget_window_error( - self, economy_service, base_params, mock_logger + def test__given_end_year_after_2099__raises_value_error( + self, economy_service, base_params ): - with ( - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=[None, None, None], - ), - patch.object( - economy_service, - "_get_or_create_economic_impact", - side_effect=[ - EconomicImpactResult.error("Calculation failed for 2026"), - EconomicImpactResult.computing(), - EconomicImpactResult.computing(), - ], + base_params["start_year"] = "2090" + base_params["window_size"] = 20 + + with pytest.raises( + ValueError, + match=( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" ), ): - result = economy_service.get_budget_window_economic_impact( - **base_params + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_failed_tracking_row_and_unavailable_batch__returns_stored_error( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="error", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + message="Stored batch failure", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.side_effect = Exception( + "batch lookup failed" + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.ERROR - assert result.error == "Calculation failed for 2026" - assert result.completed_years == [] + assert result.error == "Stored batch failure" + assert result.queued_years == ["2026", "2027", "2028"] def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( self, @@ -866,59 +913,37 @@ def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_windo base_params, mock_country_package_versions, mock_get_dataset_version, + mock_reform_impacts_service, + mock_simulation_api, mock_logger, mock_datetime, mock_numpy_random, monkeypatch, ): cache_version = "e1cache01" - seen_existing_calls = [] - seen_create_calls = [] monkeypatch.setattr( "policyengine_api.services.economy_service.get_economy_impact_cache_version", lambda country_id, api_version=None: cache_version, ) - - def fake_get_existing(setup_options): - seen_existing_calls.append( - (setup_options.time_period, setup_options.api_version) - ) - return None - - def fake_get_or_create(setup_options): - seen_create_calls.append( - (setup_options.time_period, setup_options.api_version) - ) - return EconomicImpactResult.computing() - - with ( - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=fake_get_existing, - ), - patch.object( - economy_service, - "_get_or_create_economic_impact", - side_effect=fake_get_or_create, - ), - ): - result = economy_service.get_budget_window_economic_impact( - **base_params - ) + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.COMPUTING - assert seen_existing_calls == [ - ("2026", cache_version), - ("2027", cache_version), - ("2028", cache_version), - ] - assert seen_create_calls == [ - ("2026", cache_version), - ("2027", cache_version), - ("2028", cache_version), - ] + mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() + assert ( + mock_reform_impacts_service.get_all_reform_impacts.call_args.args[5] + == "budget_window:2026:3" + ) + assert ( + mock_reform_impacts_service.get_all_reform_impacts.call_args.args[7] + == cache_version + ) + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "api_version" + ] + == cache_version + ) class TestGetPreviousImpacts: @pytest.fixture