From b20f0e42ba9b83910bfaf6bbdbd3d2eefd24e0e4 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Tue, 10 Mar 2026 05:52:09 -0400 Subject: [PATCH 01/27] Log exceptions instead of silently swallowing them in calculate The axes code path silently discarded all exceptions (`pass`), causing variables like NJ gross income to return null with no error trace. Now logs the full traceback via logging.exception(). Fixes #3322 Co-Authored-By: Claude Opus 4.6 --- changelog.d/fix-silent-exception-swallowing.fixed.md | 1 + policyengine_api/country.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 changelog.d/fix-silent-exception-swallowing.fixed.md diff --git a/changelog.d/fix-silent-exception-swallowing.fixed.md b/changelog.d/fix-silent-exception-swallowing.fixed.md new file mode 100644 index 000000000..4b10062e5 --- /dev/null +++ b/changelog.d/fix-silent-exception-swallowing.fixed.md @@ -0,0 +1 @@ +Log exceptions instead of silently swallowing them during household calculations. diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 4278637d8..430df888c 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,5 +1,6 @@ import importlib import inspect +import logging import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem from typing import Union, Optional @@ -429,11 +430,9 @@ def calculate( entity_result ) except Exception as e: - if "axes" in household: - pass - else: + logging.exception(f"Error computing {variable_name} for {entity_id}") + if "axes" not in household: household[entity_plural][entity_id][variable_name][period] = None - print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) From 8e81079389172b1f1505fe13ba7e3f9064aafed9 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 20:59:21 -0400 Subject: [PATCH 02/27] Add budget window batch economy endpoint --- changelog.d/budget-window-batch.fixed.md | 1 + policyengine_api/routes/economy_routes.py | 67 +++ policyengine_api/services/economy_service.py | 428 ++++++++++++++++++- tests/unit/services/test_economy_service.py | 269 ++++++++++++ 4 files changed, 764 insertions(+), 1 deletion(-) create mode 100644 changelog.d/budget-window-batch.fixed.md diff --git a/changelog.d/budget-window-batch.fixed.md b/changelog.d/budget-window-batch.fixed.md new file mode 100644 index 000000000..9bd03006f --- /dev/null +++ b/changelog.d/budget-window-batch.fixed.md @@ -0,0 +1 @@ +Added a budget-window economy endpoint that batches yearly impact calculations with bounded server-side concurrency and returns aggregated progress plus totals. diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 1807416f2..27fe5125f 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -2,6 +2,7 @@ from policyengine_api.services.economy_service import ( EconomyService, EconomicImpactResult, + BudgetWindowEconomicImpactResult, ) from policyengine_api.utils import get_current_law_policy_id from policyengine_api.utils.payload_validators import validate_country @@ -67,3 +68,69 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int status=200, mimetype="application/json", ) + + +@validate_country +@economy_bp.route( + "//economy//over//budget-window", + methods=["GET"], +) +def get_budget_window_economic_impact( + country_id: str, policy_id: int, baseline_policy_id: int +): + policy_id = int(policy_id or get_current_law_policy_id(country_id)) + baseline_policy_id = int( + baseline_policy_id or get_current_law_policy_id(country_id) + ) + + query_parameters = request.args + options = dict(query_parameters) + options = json.loads(json.dumps(options)) + region = options.pop("region") + dataset = options.pop("dataset", "default") + start_year = options.pop("start_year") + window_size = int(options.pop("window_size")) + + include_district_breakdowns_raw = options.pop( + "include_district_breakdowns", "false" + ) + include_district_breakdowns = include_district_breakdowns_raw.lower() == "true" + if include_district_breakdowns and country_id == "us" and region == "us": + dataset = "national-with-breakdowns" + + target: Literal["general", "cliff"] = options.pop("target", "general") + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) + ) + + result_dict = economic_impact_result.to_dict() + + return Response( + json.dumps( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } + ), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 871c896cc..1a0f9a281 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -28,9 +28,10 @@ import uuid from typing import Literal, Any, Optional, Annotated from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field import numpy as np from enum import Enum +from concurrent.futures import ThreadPoolExecutor load_dotenv() @@ -71,6 +72,7 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 class EconomicImpactSetupOptions(BaseModel): @@ -134,6 +136,79 @@ def error(cls, message: str) -> "EconomicImpactResult": return cls(status=ImpactStatus.ERROR, data=None) +class BudgetWindowEconomicImpactResult(BaseModel): + """ + Model for a batch budget-window economic impact response. + """ + + status: ImpactStatus + data: Optional[dict] = None + progress: Optional[int] = None + completed_years: list[str] = Field(default_factory=list) + computing_years: list[str] = Field(default_factory=list) + queued_years: list[str] = Field(default_factory=list) + message: Optional[str] = None + error: Optional[str] = None + + model_config = {"frozen": True} + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "data": self.data, + "progress": self.progress, + "completed_years": self.completed_years, + "computing_years": self.computing_years, + "queued_years": self.queued_years, + "message": self.message, + "error": self.error, + } + + @classmethod + def completed(cls, data: dict) -> "BudgetWindowEconomicImpactResult": + return cls(status=ImpactStatus.OK, data=data, progress=100) + + @classmethod + def computing( + cls, + *, + progress: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + message: str, + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.COMPUTING, + data=None, + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=message, + ) + + @classmethod + def failed( + cls, + message: str, + *, + completed_years: Optional[list[str]] = None, + computing_years: Optional[list[str]] = None, + queued_years: Optional[list[str]] = None, + ) -> "BudgetWindowEconomicImpactResult": + logger.log_struct({"message": message}, severity="ERROR") + return cls( + status=ImpactStatus.ERROR, + data=None, + completed_years=completed_years or [], + computing_years=computing_years or [], + queued_years=queued_years or [], + message=message, + error=message, + ) + + class EconomyService: """ Service for calculating economic impact of policy reforms; this is connected @@ -296,6 +371,357 @@ def get_economic_impact( print(f"Error getting economic impact: {str(e)}") raise e + def get_budget_window_economic_impact( + self, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, + ) -> BudgetWindowEconomicImpactResult: + try: + if country_id == "us": + region = normalize_us_region(region) + + start_year_int = int(start_year) + if window_size < 1: + raise ValueError("window_size must be at least 1") + + years = [str(start_year_int + index) for index in range(window_size)] + setup_options_by_year = { + year: self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=year, + options=dict(options), + api_version=api_version, + target=target, + ) + for year in years + } + + completed_impacts: dict[str, dict] = {} + computing_years: list[str] = [] + queued_years: list[str] = [] + + for year in years: + result = self._get_existing_economic_impact( + setup_options=setup_options_by_year[year] + ) + + if result is None: + queued_years.append(year) + continue + + if result.status == ImpactStatus.OK: + completed_impacts[year] = self._extract_budget_window_annual_impact( + year=year, impact_data=result.data or {} + ) + continue + + if result.status == ImpactStatus.COMPUTING: + computing_years.append(year) + continue + + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.failed( + result.data.get("message") + if isinstance(result.data, dict) + else f"Budget-window calculation failed for {year}", + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + ) + + available_slots = max(0, max_active_years - len(computing_years)) + years_to_start = queued_years[:available_slots] + remaining_queued_years = queued_years[available_slots:] + + if years_to_start: + max_workers = min(len(years_to_start), max_active_years) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_year_pairs = [ + ( + year, + executor.submit( + self.get_economic_impact, + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=year, + options=dict(options), + api_version=api_version, + target=target, + ), + ) + for year in years_to_start + ] + + for year, future in future_year_pairs: + result = future.result() + + if result.status == ImpactStatus.OK: + completed_impacts[year] = ( + self._extract_budget_window_annual_impact( + year=year, impact_data=result.data or {} + ) + ) + elif result.status == ImpactStatus.COMPUTING: + computing_years.append(year) + else: + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.failed( + f"Budget-window calculation failed for {year}", + completed_years=completed_years, + computing_years=computing_years, + queued_years=remaining_queued_years, + ) + + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + + if len(completed_years) == len(years): + ordered_annual_impacts = [ + completed_impacts[year] + for year in years + if year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.completed( + self._build_budget_window_output( + start_year=start_year, + window_size=window_size, + annual_impacts=ordered_annual_impacts, + ) + ) + + progress = round((len(completed_years) / len(years)) * 100) + return BudgetWindowEconomicImpactResult.computing( + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=remaining_queued_years, + message=self._build_budget_window_progress_message( + completed_years=completed_years, + total_years=len(years), + computing_years=computing_years, + queued_years=remaining_queued_years, + ), + ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e + + def _build_economic_impact_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + ) -> EconomicImpactSetupOptions: + process_id: str = self._create_process_id() + options_hash = "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" + + country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + if country_id == "uk": + country_package_version = None + + return EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } + ) + + def _get_or_create_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> EconomicImpactResult: + logger.log_struct( + { + "message": "Received request for economic impact; checking if already in reform_impacts table", + **setup_options.model_dump(), + }, + severity="INFO", + ) + + most_recent_impact: dict | None = self._get_most_recent_impact( + setup_options=setup_options + ) + + impact_action: ImpactAction = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact(most_recent_impact=most_recent_impact) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact record in db; confirming this is still computing", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if impact_action == ImpactAction.CREATE: + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact(setup_options=setup_options) + + raise ValueError(f"Unexpected impact action: {impact_action}") + + def _get_existing_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> Optional[EconomicImpactResult]: + most_recent_impact = self._get_most_recent_impact(setup_options=setup_options) + if not most_recent_impact: + return None + + status = most_recent_impact.get("status") + if status == ImpactStatus.ERROR.value: + error_message = most_recent_impact.get("message") or ( + f"Economic impact failed for {setup_options.time_period}" + ) + return EconomicImpactResult( + status=ImpactStatus.ERROR, + data={"message": error_message}, + ) + + if status == ImpactStatus.OK.value: + return self._handle_completed_impact(most_recent_impact=most_recent_impact) + + if status == ImpactStatus.COMPUTING.value: + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + raise ValueError(f"Unknown impact status: {status}") + + def _extract_budget_window_annual_impact( + self, year: str, impact_data: dict + ) -> dict[str, Union[str, int, float]]: + budget = impact_data.get("budget", {}) + state_tax_revenue_impact = budget.get("state_tax_revenue_impact", 0) + tax_revenue_impact = budget.get("tax_revenue_impact", 0) + + return { + "year": year, + "taxRevenueImpact": tax_revenue_impact, + "federalTaxRevenueImpact": tax_revenue_impact - state_tax_revenue_impact, + "stateTaxRevenueImpact": state_tax_revenue_impact, + "benefitSpendingImpact": budget.get("benefit_spending_impact", 0), + "budgetaryImpact": budget.get("budgetary_impact", 0), + } + + def _sum_budget_window_annual_impacts(self, annual_impacts: list[dict]) -> dict: + totals = { + "year": "Total", + "taxRevenueImpact": 0, + "federalTaxRevenueImpact": 0, + "stateTaxRevenueImpact": 0, + "benefitSpendingImpact": 0, + "budgetaryImpact": 0, + } + + for annual_impact in annual_impacts: + totals["taxRevenueImpact"] += annual_impact["taxRevenueImpact"] + totals["federalTaxRevenueImpact"] += annual_impact[ + "federalTaxRevenueImpact" + ] + totals["stateTaxRevenueImpact"] += annual_impact["stateTaxRevenueImpact"] + totals["benefitSpendingImpact"] += annual_impact["benefitSpendingImpact"] + totals["budgetaryImpact"] += annual_impact["budgetaryImpact"] + + return totals + + def _build_budget_window_output( + self, *, start_year: str, window_size: int, annual_impacts: list[dict] + ) -> dict: + return { + "kind": "budgetWindow", + "startYear": start_year, + "endYear": str(int(start_year) + window_size - 1), + "windowSize": window_size, + "annualImpacts": annual_impacts, + "totals": self._sum_budget_window_annual_impacts(annual_impacts), + } + + def _build_budget_window_progress_message( + self, + *, + completed_years: list[str], + total_years: int, + computing_years: list[str], + queued_years: list[str], + ) -> str: + completed_count = len(completed_years) + if computing_years: + active_years = ", ".join(computing_years[:2]) + if len(computing_years) > 2: + active_years = f"{active_years} + {len(computing_years) - 2} more" + return f"Scoring {active_years} ({completed_count} of {total_years} complete)..." + + if queued_years: + return f"Queued {queued_years[0]} ({completed_count} of {total_years} complete)..." + + return f"Scoring budget window ({completed_count} of {total_years} complete)..." + def _get_previous_impacts( self, country_id: str, diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d036ab296..aead50710 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -36,6 +36,23 @@ pytest_plugins = ("tests.fixtures.services.economy_service",) +def make_mock_budget_impact_data( + *, + tax_revenue_impact: int, + state_tax_revenue_impact: int, + benefit_spending_impact: int, + budgetary_impact: int, +): + return { + "budget": { + "tax_revenue_impact": tax_revenue_impact, + "state_tax_revenue_impact": state_tax_revenue_impact, + "benefit_spending_impact": benefit_spending_impact, + "budgetary_impact": budgetary_impact, + } + } + + class TestEconomyService: class TestGetEconomicImpact: @pytest.fixture @@ -617,6 +634,258 @@ def test__given_uk_request__preserves_model_version_in_bundle( sim_params = mock_simulation_api.run.call_args[0][0] assert sim_params["_metadata"]["model_version"] == "2.7.8" + class TestGetBudgetWindowEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def base_params(self): + return { + "country_id": MOCK_COUNTRY_ID, + "policy_id": MOCK_POLICY_ID, + "baseline_policy_id": MOCK_BASELINE_POLICY_ID, + "region": MOCK_REGION, + "dataset": MOCK_DATASET, + "start_year": "2026", + "window_size": 3, + "options": MOCK_OPTIONS, + "api_version": MOCK_API_VERSION, + "target": "general", + } + + def test__given_all_years_completed__returns_aggregated_budget_window_result( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + yearly_results = { + "2026": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + "2027": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=120, + state_tax_revenue_impact=30, + benefit_spending_impact=-20, + budgetary_impact=100, + ) + ), + "2028": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=140, + state_tax_revenue_impact=40, + benefit_spending_impact=-30, + budgetary_impact=110, + ) + ), + } + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=lambda setup_options: yearly_results[ + setup_options.time_period + ], + ) as mock_get_existing, + patch.object( + economy_service, "get_economic_impact" + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data["annualImpacts"] == [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + }, + { + "year": "2027", + "taxRevenueImpact": 120, + "federalTaxRevenueImpact": 90, + "stateTaxRevenueImpact": 30, + "benefitSpendingImpact": -20, + "budgetaryImpact": 100, + }, + { + "year": "2028", + "taxRevenueImpact": 140, + "federalTaxRevenueImpact": 100, + "stateTaxRevenueImpact": 40, + "benefitSpendingImpact": -30, + "budgetaryImpact": 110, + }, + ] + assert result.data["totals"] == { + "year": "Total", + "taxRevenueImpact": 360, + "federalTaxRevenueImpact": 270, + "stateTaxRevenueImpact": 90, + "benefitSpendingImpact": -60, + "budgetaryImpact": 300, + } + assert mock_get_existing.call_count == 3 + mock_get_economic_impact.assert_not_called() + + def test__given_missing_years__starts_only_up_to_remaining_active_slots( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + base_params["window_size"] = 5 + + existing_results = { + "2026": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + "2027": EconomicImpactResult.computing(), + "2028": None, + "2029": None, + "2030": None, + } + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=lambda setup_options: existing_results[ + setup_options.time_period + ], + ), + patch.object( + economy_service, + "get_economic_impact", + return_value=EconomicImpactResult.computing(), + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 20 + assert result.completed_years == ["2026"] + assert result.computing_years == ["2027", "2028", "2029"] + assert result.queued_years == ["2030"] + assert "1 of 5 complete" in result.message + assert mock_get_economic_impact.call_count == 2 + started_years = sorted( + call.kwargs["time_period"] + for call in mock_get_economic_impact.call_args_list + ) + assert started_years == ["2028", "2029"] + + def test__given_year_error__returns_budget_window_error( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=[ + EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + EconomicImpactResult( + status=ImpactStatus.ERROR, + data={"message": "Calculation failed for 2027"}, + ), + None, + ], + ), + patch.object( + economy_service, "get_economic_impact" + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Calculation failed for 2027" + assert result.completed_years == ["2026"] + mock_get_economic_impact.assert_not_called() + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): From 14018b74353a3586aaca2ea028270132e9cdc31a Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 21:51:06 -0400 Subject: [PATCH 03/27] Harden budget window batch API --- policyengine_api/routes/economy_routes.py | 116 +++++++++++------- policyengine_api/services/economy_service.py | 68 +++++++--- .../services/reform_impacts_service.py | 3 +- .../test_economy_budget_window_routes.py | 107 ++++++++++++++++ tests/unit/services/test_economy_service.py | 19 ++- 5 files changed, 248 insertions(+), 65 deletions(-) create mode 100644 tests/to_refactor/python/test_economy_budget_window_routes.py diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 27fe5125f..d772697c2 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -14,6 +14,25 @@ economy_service = EconomyService() +def _json_response(payload: dict, status: int = 200) -> Response: + return Response( + json.dumps(payload), + status=status, + mimetype="application/json", + ) + + +def _bad_request_response(message: str) -> Response: + return _json_response( + { + "status": "error", + "message": message, + "result": None, + }, + status=400, + ) + + @economy_bp.route( "//economy//over/", methods=["GET"], @@ -57,24 +76,20 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": None, - "result": result_dict["data"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": None, + "result": result_dict["data"], + } ) -@validate_country @economy_bp.route( "//economy//over//budget-window", methods=["GET"], ) +@validate_country def get_budget_window_economic_impact( country_id: str, policy_id: int, baseline_policy_id: int ): @@ -86,10 +101,23 @@ def get_budget_window_economic_impact( query_parameters = request.args options = dict(query_parameters) options = json.loads(json.dumps(options)) - region = options.pop("region") + region = options.pop("region", None) + if not region: + return _bad_request_response("Missing required query parameter: region") + dataset = options.pop("dataset", "default") - start_year = options.pop("start_year") - window_size = int(options.pop("window_size")) + start_year = options.pop("start_year", None) + if not start_year: + return _bad_request_response("Missing required query parameter: start_year") + + window_size_raw = options.pop("window_size", None) + if window_size_raw is None: + return _bad_request_response("Missing required query parameter: window_size") + + try: + window_size = int(window_size_raw) + except (TypeError, ValueError): + return _bad_request_response("window_size must be an integer") include_district_breakdowns_raw = options.pop( "include_district_breakdowns", "false" @@ -99,38 +127,42 @@ def get_budget_window_economic_impact( dataset = "national-with-breakdowns" target: Literal["general", "cliff"] = options.pop("target", "general") + if target != "general": + return _bad_request_response( + "Budget-window calculations only support target=general" + ) + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) - economic_impact_result: BudgetWindowEconomicImpactResult = ( - economy_service.get_budget_window_economic_impact( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - start_year=start_year, - window_size=window_size, - options=options, - api_version=api_version, - target=target, + try: + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) ) - ) + except ValueError as error: + return _bad_request_response(str(error)) result_dict = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": result_dict["message"], - "result": result_dict["data"], - "progress": result_dict["progress"], - "completed_years": result_dict["completed_years"], - "computing_years": result_dict["computing_years"], - "queued_years": result_dict["queued_years"], - "error": result_dict["error"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 1a0f9a281..f6cee3896 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -32,6 +32,7 @@ import numpy as np from enum import Enum from concurrent.futures import ThreadPoolExecutor +from threading import Lock load_dotenv() @@ -73,6 +74,7 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 +IMPACT_CREATION_LOCK = Lock() class EconomicImpactSetupOptions(BaseModel): @@ -389,6 +391,11 @@ def get_budget_window_economic_impact( if country_id == "us": region = normalize_us_region(region) + if target != "general": + raise ValueError( + "Budget-window calculations only support target='general'" + ) + start_year_int = int(start_year) if window_size < 1: raise ValueError("window_size must be at least 1") @@ -457,16 +464,8 @@ def get_budget_window_economic_impact( ( year, executor.submit( - self.get_economic_impact, - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=year, - options=dict(options), - api_version=api_version, - target=target, + self._get_or_create_economic_impact, + setup_options_by_year[year], ), ) for year in years_to_start @@ -614,14 +613,47 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: - logger.log_struct( - { - "message": "No previous economic impact record found in db; creating new simulation run", - **setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_create_impact(setup_options=setup_options) + with IMPACT_CREATION_LOCK: + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options + ) + impact_action = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db after locking; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact in db after locking; returning progress", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact(setup_options=setup_options) raise ValueError(f"Unexpected impact action: {impact_action}") diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 0f41352f3..cef340c88 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -25,7 +25,8 @@ def get_all_reform_impacts( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " - "options_hash = ? AND api_version = ? AND dataset = ?" + "options_hash = ? AND api_version = ? AND dataset = ? " + "ORDER BY start_time DESC" ) return local_database.query( query, diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py new file mode 100644 index 000000000..10148e973 --- /dev/null +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -0,0 +1,107 @@ +import json +from unittest.mock import Mock, patch + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_rejects_cliff_target( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=10&target=cliff" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "target=general" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_integer_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=abc" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be an integer" == data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_passes_version_to_service( + mock_get_budget_window_economic_impact, rest_client +): + mock_result = Mock() + mock_result.to_dict.return_value = { + "status": "ok", + "message": None, + "data": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2027", + "windowSize": 2, + "annualImpacts": [], + "totals": {}, + }, + "progress": 100, + "completed_years": ["2026", "2027"], + "computing_years": [], + "queued_years": [], + "error": None, + } + mock_get_budget_window_economic_impact.return_value = mock_result + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2&version=1.2.3" + ) + + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + mock_get_budget_window_economic_impact.assert_called_once_with( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="default", + start_year="2026", + window_size=2, + options={}, + api_version="1.2.3", + target="general", + ) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index aead50710..02f1f64b7 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -713,7 +713,7 @@ def make_setup(*, time_period, **_kwargs): ], ) as mock_get_existing, patch.object( - economy_service, "get_economic_impact" + economy_service, "_get_or_create_economic_impact" ) as mock_get_economic_impact, ): result = economy_service.get_budget_window_economic_impact( @@ -809,7 +809,7 @@ def make_setup(*, time_period, **_kwargs): ), patch.object( economy_service, - "get_economic_impact", + "_get_or_create_economic_impact", return_value=EconomicImpactResult.computing(), ) as mock_get_economic_impact, ): @@ -825,7 +825,7 @@ def make_setup(*, time_period, **_kwargs): assert "1 of 5 complete" in result.message assert mock_get_economic_impact.call_count == 2 started_years = sorted( - call.kwargs["time_period"] + call.args[0].time_period for call in mock_get_economic_impact.call_args_list ) assert started_years == ["2028", "2029"] @@ -874,7 +874,7 @@ def make_setup(*, time_period, **_kwargs): ], ), patch.object( - economy_service, "get_economic_impact" + economy_service, "_get_or_create_economic_impact" ) as mock_get_economic_impact, ): result = economy_service.get_budget_window_economic_impact( @@ -886,6 +886,17 @@ def make_setup(*, time_period, **_kwargs): assert result.completed_years == ["2026"] mock_get_economic_impact.assert_not_called() + def test__given_cliff_target__raises_value_error( + self, economy_service, base_params + ): + base_params["target"] = "cliff" + + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + economy_service.get_budget_window_economic_impact(**base_params) + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): From 530c3bd66c2fefb3547bd16e8710a7c0b7b06a28 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 07:42:06 -0400 Subject: [PATCH 04/27] Address budget window review findings --- policyengine_api/services/economy_service.py | 38 +++++-- .../test_economy_budget_window_routes.py | 13 +++ tests/unit/services/test_economy_service.py | 106 ++++++++++++++++++ 3 files changed, 149 insertions(+), 8 deletions(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index f6cee3896..014e1cad7 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -74,6 +74,7 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 +BUDGET_WINDOW_MAX_YEARS = 20 IMPACT_CREATION_LOCK = Lock() @@ -103,6 +104,7 @@ class EconomicImpactResult(BaseModel): status: ImpactStatus data: Optional[dict] = None + message: Optional[str] = None model_config = {"frozen": True} # Make model immutable @@ -135,7 +137,7 @@ def error(cls, message: str) -> "EconomicImpactResult": Create an EconomicImpactResult for an error in the impact calculation. """ logger.log_struct({"message": message}, severity="ERROR") - return cls(status=ImpactStatus.ERROR, data=None) + return cls(status=ImpactStatus.ERROR, data=None, message=message) class BudgetWindowEconomicImpactResult(BaseModel): @@ -397,8 +399,10 @@ def get_budget_window_economic_impact( ) start_year_int = int(start_year) - if window_size < 1: - raise ValueError("window_size must be at least 1") + if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: + raise ValueError( + f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" + ) years = [str(start_year_int + index) for index in range(window_size)] setup_options_by_year = { @@ -445,9 +449,10 @@ def get_budget_window_economic_impact( if completed_year in completed_impacts ] return BudgetWindowEconomicImpactResult.failed( - result.data.get("message") - if isinstance(result.data, dict) - else f"Budget-window calculation failed for {year}", + self._get_economic_impact_error_message( + result=result, + year=year, + ), completed_years=completed_years, computing_years=computing_years, queued_years=queued_years, @@ -489,7 +494,10 @@ def get_budget_window_economic_impact( if completed_year in completed_impacts ] return BudgetWindowEconomicImpactResult.failed( - f"Budget-window calculation failed for {year}", + self._get_economic_impact_error_message( + result=result, + year=year, + ), completed_years=completed_years, computing_years=computing_years, queued_years=remaining_queued_years, @@ -671,7 +679,8 @@ def _get_existing_economic_impact( ) return EconomicImpactResult( status=ImpactStatus.ERROR, - data={"message": error_message}, + data=None, + message=error_message, ) if status == ImpactStatus.OK.value: @@ -685,6 +694,19 @@ def _get_existing_economic_impact( raise ValueError(f"Unknown impact status: {status}") + def _get_economic_impact_error_message( + self, result: EconomicImpactResult, year: str + ) -> str: + if result.message: + return result.message + + if isinstance(result.data, dict): + data_message = result.data.get("message") + if isinstance(data_message, str) and data_message: + return data_message + + return f"Budget-window calculation failed for {year}" + def _extract_budget_window_annual_impact( self, year: str, impact_data: dict ) -> dict[str, Union[str, int, float]]: diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index 10148e973..fca938948 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -58,6 +58,19 @@ def test_budget_window_route_requires_integer_window_size( mock_get_budget_window_economic_impact.assert_not_called() +def test_budget_window_route_rejects_oversized_window(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=999" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be between 1 and" in data["message"] + + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 02f1f64b7..094bd72d5 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -4,6 +4,7 @@ from typing import Literal from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, EconomicImpactSetupOptions, @@ -897,6 +898,104 @@ def test__given_cliff_target__raises_value_error( ): economy_service.get_budget_window_economic_impact(**base_params) + def test__given_oversized_window__raises_value_error( + self, economy_service, base_params + ): + base_params["window_size"] = BUDGET_WINDOW_MAX_YEARS + 1 + + with pytest.raises( + ValueError, + match=(f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}"), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_started_year_error__returns_specific_budget_window_error( + self, economy_service, base_params + ): + with ( + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=[None, None, None], + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + side_effect=[ + EconomicImpactResult.error("Calculation failed for 2026"), + EconomicImpactResult.computing(), + EconomicImpactResult.computing(), + ], + ), + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Calculation failed for 2026" + assert result.completed_years == [] + + def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_logger, + mock_datetime, + mock_numpy_random, + monkeypatch, + ): + cache_version = "e1cache01" + seen_existing_calls = [] + seen_create_calls = [] + + monkeypatch.setattr( + "policyengine_api.services.economy_service.get_economy_impact_cache_version", + lambda country_id, api_version=None: cache_version, + ) + + def fake_get_existing(setup_options): + seen_existing_calls.append( + (setup_options.time_period, setup_options.api_version) + ) + return None + + def fake_get_or_create(setup_options): + seen_create_calls.append( + (setup_options.time_period, setup_options.api_version) + ) + return EconomicImpactResult.computing() + + with ( + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=fake_get_existing, + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + side_effect=fake_get_or_create, + ), + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.COMPUTING + assert seen_existing_calls == [ + ("2026", cache_version), + ("2027", cache_version), + ("2028", cache_version), + ] + assert seen_create_calls == [ + ("2026", cache_version), + ("2027", cache_version), + ("2028", cache_version), + ] + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -1098,6 +1197,7 @@ def test__given_failed_state__returns_error_result( assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_active_state__returns_computing_result( @@ -1174,6 +1274,7 @@ def test__given_modal_failed_state__then_returns_error_result( # Then assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_modal_failed_state_with_error_message__then_includes_error_in_message( @@ -1195,6 +1296,10 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Simulation API execution failed: Simulation timed out" + ) # Verify the error message was passed to the service call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] @@ -1292,6 +1397,7 @@ def test__given_error__creates_correct_instance_and_logs(self): assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Test error message" mock_logger.log_struct.assert_called_once() From 7f715659636bc96036f04484f9d4df5d29190183 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:03:07 -0400 Subject: [PATCH 05/27] Prevent duplicate budget window jobs across workers --- .../endpoints/economy/reform_impact.py | 4 +- policyengine_api/endpoints/simulation.py | 4 +- policyengine_api/services/economy_service.py | 13 +- .../services/reform_impacts_service.py | 91 ++++++++++++-- tests/unit/services/test_economy_service.py | 4 +- .../services/test_reform_impacts_service.py | 117 ++++++++++++++++++ 6 files changed, 216 insertions(+), 17 deletions(-) create mode 100644 tests/unit/services/test_reform_impacts_service.py diff --git a/policyengine_api/endpoints/economy/reform_impact.py b/policyengine_api/endpoints/economy/reform_impact.py index 42795243d..cc778bcef 100644 --- a/policyengine_api/endpoints/economy/reform_impact.py +++ b/policyengine_api/endpoints/economy/reform_impact.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import database def set_comment_on_job( @@ -17,7 +17,7 @@ def set_comment_on_job( "time_period = ? AND options_hash = ? AND dataset = ?" ) - local_database.query( + database.query( query, ( comment, diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index a0d9bd70d..b03941442 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import database """ @@ -42,7 +42,7 @@ def get_simulations( max_results = _DEFAULT_SIMULATION_RESULTS max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS)) - result = local_database.query( + result = database.query( "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", (max_results,), ).fetchall() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 014e1cad7..fc9e5ec36 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -32,7 +32,6 @@ import numpy as np from enum import Enum from concurrent.futures import ThreadPoolExecutor -from threading import Lock load_dotenv() @@ -75,7 +74,6 @@ class ImpactStatus(Enum): COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 BUDGET_WINDOW_MAX_YEARS = 20 -IMPACT_CREATION_LOCK = Lock() class EconomicImpactSetupOptions(BaseModel): @@ -621,7 +619,16 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: - with IMPACT_CREATION_LOCK: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): most_recent_impact = self._get_most_recent_impact( setup_options=setup_options ) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index cef340c88..4a02c2619 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -1,14 +1,89 @@ -from policyengine_api.data import local_database +from contextlib import contextmanager +import hashlib +from threading import Lock +from policyengine_api.data import database import datetime +LOCAL_REFORM_IMPACT_LOCK = Lock() +REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 + + class ReformImpactsService: """ Service for storing and retrieving economy-wide reform impacts; - this is connected to the locally-stored reform_impact table - and no existing route + this is connected to the shared reform_impact table. """ + def _build_lock_name( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) -> str: + raw_key = ( + f"{country_id}:{policy_id}:{baseline_policy_id}:{region}:{dataset}:" + f"{time_period}:{options_hash}:{api_version}" + ) + digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest() + return f"ri:{digest[:61]}" + + @contextmanager + def claim_lock( + self, + *, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + timeout_seconds: int = REFORM_IMPACT_LOCK_TIMEOUT_SECONDS, + ): + if database.local: + with LOCAL_REFORM_IMPACT_LOCK: + yield + return + + lock_name = self._build_lock_name( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options_hash=options_hash, + api_version=api_version, + ) + with database.pool.connect() as conn: + acquired = ( + conn.exec_driver_sql( + "SELECT GET_LOCK(%s, %s) AS acquired", + (lock_name, timeout_seconds), + ) + .mappings() + .first() + ) + if acquired is None or acquired["acquired"] != 1: + raise TimeoutError( + f"Could not acquire reform impact lock for {country_id}/{policy_id}/{time_period}" + ) + + try: + yield + finally: + conn.exec_driver_sql( + "SELECT RELEASE_LOCK(%s) AS released", (lock_name,) + ) + conn.commit() + def get_all_reform_impacts( self, country_id, @@ -28,7 +103,7 @@ def get_all_reform_impacts( "options_hash = ? AND api_version = ? AND dataset = ? " "ORDER BY start_time DESC" ) - return local_database.query( + return database.query( query, ( country_id, @@ -106,7 +181,7 @@ def set_reform_impact( "region, dataset, time_period, options_json, options_hash, status, api_version, " "reform_impact_json, start_time, execution_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ) - local_database.query( + database.query( query, ( country_id, @@ -146,7 +221,7 @@ def delete_reform_impact( "dataset = ? AND status = 'computing'" ) - local_database.query( + database.query( query, ( country_id, @@ -181,7 +256,7 @@ def set_error_reform_impact( "region = ? AND time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ?" ) - local_database.query( + database.query( query, ( "error", @@ -225,7 +300,7 @@ def set_complete_reform_impact( "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND dataset = ? AND execution_id = ?" ) - local_database.query( + database.query( query, ( "ok", diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 094bd72d5..15ef97bf1 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -832,7 +832,7 @@ def make_setup(*, time_period, **_kwargs): assert started_years == ["2028", "2029"] def test__given_year_error__returns_budget_window_error( - self, economy_service, base_params + self, economy_service, base_params, mock_logger ): def make_setup(*, time_period, **_kwargs): return EconomicImpactSetupOptions( @@ -910,7 +910,7 @@ def test__given_oversized_window__raises_value_error( economy_service.get_budget_window_economic_impact(**base_params) def test__given_started_year_error__returns_specific_budget_window_error( - self, economy_service, base_params + self, economy_service, base_params, mock_logger ): with ( patch.object( diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py new file mode 100644 index 000000000..4f44b63a2 --- /dev/null +++ b/tests/unit/services/test_reform_impacts_service.py @@ -0,0 +1,117 @@ +from unittest.mock import MagicMock + +import pytest + +from policyengine_api.services.reform_impacts_service import ReformImpactsService + + +class TestReformImpactsService: + def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 1} + release_result = MagicMock() + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.side_effect = [ + acquired_result, + release_result, + ] + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass + + assert mock_connection.exec_driver_sql.call_count == 2 + + acquire_call = mock_connection.exec_driver_sql.call_args_list[0] + assert acquire_call.args == ( + "SELECT GET_LOCK(%s, %s) AS acquired", + ( + service._build_lock_name( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ), + 5, + ), + ) + assert len(acquire_call.args[1][0]) <= 64 + + release_call = mock_connection.exec_driver_sql.call_args_list[1] + assert release_call.args == ( + "SELECT RELEASE_LOCK(%s) AS released", + (acquire_call.args[1][0],), + ) + mock_connection.commit.assert_called_once() + + def test__given_remote_database_lock_timeout__claim_lock_raises(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 0} + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.return_value = acquired_result + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with pytest.raises( + TimeoutError, + match="Could not acquire reform impact lock", + ): + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass From f2165787811f21659b0c724ca16e7c21740efae9 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:25:18 -0400 Subject: [PATCH 06/27] Harden reform impact claim deduping --- policyengine_api/data/data.py | 20 +- .../endpoints/economy/reform_impact.py | 4 +- policyengine_api/services/economy_service.py | 229 ++++++++++++++---- .../services/reform_impacts_service.py | 39 ++- tests/fixtures/services/economy_service.py | 6 +- tests/unit/data/test_sqlalchemy_v2.py | 32 +++ tests/unit/services/test_economy_service.py | 193 +++++++++++++++ 7 files changed, 466 insertions(+), 57 deletions(-) diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 6b16e713e..1ba0ab57a 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -105,16 +105,20 @@ def _create_pool(self): with open(".dbpw") as f: db_pass = f.read().strip() db_name = "policyengine" - conn = self.connector.connect( - instance_connection_string=instance_connection_name, - driver="pymysql", - db=db_name, - user=db_user, - password=db_pass, - ) + + def get_connection(): + return self.connector.connect( + instance_connection_string=instance_connection_name, + driver="pymysql", + db=db_name, + user=db_user, + password=db_pass, + ) + self.pool = sqlalchemy.create_engine( "mysql+pymysql://", - creator=lambda: conn, + creator=get_connection, + pool_pre_ping=True, ) def _close_pool(self): diff --git a/policyengine_api/endpoints/economy/reform_impact.py b/policyengine_api/endpoints/economy/reform_impact.py index cc778bcef..42795243d 100644 --- a/policyengine_api/endpoints/economy/reform_impact.py +++ b/policyengine_api/endpoints/economy/reform_impact.py @@ -1,4 +1,4 @@ -from policyengine_api.data import database +from policyengine_api.data import local_database def set_comment_on_job( @@ -17,7 +17,7 @@ def set_comment_on_job( "time_period = ? AND options_hash = ? AND dataset = ?" ) - database.query( + local_database.query( query, ( comment, diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index fc9e5ec36..41ba1fd1f 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -74,6 +74,11 @@ class ImpactStatus(Enum): COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 BUDGET_WINDOW_MAX_YEARS = 20 +PENDING_EXECUTION_ID_PREFIX = "pending:" +PROVISIONAL_CLAIM_TTL_SECONDS = 90 +STALE_PROVISIONAL_IMPACT_MESSAGE = ( + "Simulation claim expired before job submission completed" +) class EconomicImpactSetupOptions(BaseModel): @@ -619,56 +624,88 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: - with reform_impacts_service.claim_lock( - country_id=setup_options.country_id, - policy_id=setup_options.reform_policy_id, - baseline_policy_id=setup_options.baseline_policy_id, - region=setup_options.region, - dataset=setup_options.dataset, - time_period=setup_options.time_period, - options_hash=setup_options.options_hash, - api_version=setup_options.api_version, - ): - most_recent_impact = self._get_most_recent_impact( - setup_options=setup_options - ) - impact_action = self._determine_impact_action( - most_recent_impact=most_recent_impact - ) - - if impact_action == ImpactAction.COMPLETED: - logger.log_struct( - { - "message": "Found completed economic impact in db after locking; returning result", - **setup_options.model_dump(), - }, - severity="INFO", + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options ) - return self._handle_completed_impact( + impact_action = self._determine_impact_action( most_recent_impact=most_recent_impact ) - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact in db after locking; returning progress", - **setup_options.model_dump(), - }, - severity="INFO", + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db after locking; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact in db after locking; returning progress", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if self._is_stale_provisional_impact(most_recent_impact): + self._expire_stale_provisional_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + provisional_execution_id = self._build_provisional_execution_id( + setup_options.process_id ) - return self._handle_computing_impact( + self._set_reform_impact_computing( setup_options=setup_options, - most_recent_impact=most_recent_impact, + execution_id=provisional_execution_id, ) - + except TimeoutError: logger.log_struct( { - "message": "No previous economic impact record found in db; creating new simulation run", + "message": "Timed out waiting for economic impact claim lock; re-checking existing claim", **setup_options.model_dump(), }, - severity="INFO", + severity="WARNING", ) - return self._handle_create_impact(setup_options=setup_options) + existing_impact = self._get_existing_economic_impact( + setup_options=setup_options + ) + if existing_impact is not None: + return existing_impact + return EconomicImpactResult.computing() + + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + ) raise ValueError(f"Unexpected impact action: {impact_action}") @@ -694,6 +731,8 @@ def _get_existing_economic_impact( return self._handle_completed_impact(most_recent_impact=most_recent_impact) if status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return None return self._handle_computing_impact( setup_options=setup_options, most_recent_impact=most_recent_impact, @@ -842,6 +881,63 @@ def _get_most_recent_impact( return previous_impacts[0] + def _build_provisional_execution_id(self, process_id: str) -> str: + return f"{PENDING_EXECUTION_ID_PREFIX}{process_id}" + + def _is_provisional_execution_id(self, execution_id: Any) -> bool: + return isinstance(execution_id, str) and execution_id.startswith( + PENDING_EXECUTION_ID_PREFIX + ) + + def _coerce_impact_start_time(self, start_time: Any) -> Optional[datetime.datetime]: + if start_time is None: + return None + + if isinstance(start_time, str): + parsed_start_time = datetime.datetime.fromisoformat(start_time) + elif hasattr(start_time, "tzinfo") and hasattr(start_time, "isoformat"): + parsed_start_time = start_time + else: + return None + + if parsed_start_time.tzinfo is None: + return parsed_start_time.replace(tzinfo=datetime.timezone.utc) + + return parsed_start_time.astimezone(datetime.timezone.utc) + + def _is_stale_provisional_impact(self, impact: dict | None) -> bool: + if not impact: + return False + + if not self._is_provisional_execution_id(impact.get("execution_id")): + return False + + start_time = self._coerce_impact_start_time(impact.get("start_time")) + if start_time is None: + return False + + current_time = datetime.datetime.now(datetime.timezone.utc) + if current_time.tzinfo is None: + current_time = current_time.replace(tzinfo=datetime.timezone.utc) + + claim_age = current_time - start_time + return claim_age.total_seconds() > PROVISIONAL_CLAIM_TTL_SECONDS + + def _expire_stale_provisional_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> None: + execution_id = most_recent_impact.get("execution_id") + if not self._is_provisional_execution_id(execution_id): + return + + self._set_reform_impact_error( + setup_options=setup_options, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=execution_id, + ) + def _determine_impact_action( self, most_recent_impact: dict | None, @@ -853,6 +949,8 @@ def _determine_impact_action( if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: return ImpactAction.COMPLETED elif status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return ImpactAction.CREATE return ImpactAction.COMPUTING else: raise ValueError(f"Unknown impact status: {status}") @@ -936,9 +1034,11 @@ def _handle_computing_impact( setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> EconomicImpactResult: - execution = simulation_api.get_execution_by_id( - most_recent_impact["execution_id"] - ) + execution_id = most_recent_impact["execution_id"] + if self._is_provisional_execution_id(execution_id): + return EconomicImpactResult.computing() + + execution = simulation_api.get_execution_by_id(execution_id) execution_state = simulation_api.get_execution_status(execution) return self._handle_execution_state( execution_state=execution_state, @@ -950,6 +1050,7 @@ def _handle_computing_impact( def _handle_create_impact( self, setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, ) -> EconomicImpactResult: baseline_policy = policy_service.get_policy_json( setup_options.country_id, setup_options.baseline_policy_id @@ -1007,8 +1108,18 @@ def _handle_create_impact( if sim_params.get("time_period") is not None: sim_params["time_period"] = str(sim_params["time_period"]) - sim_api_execution = simulation_api.run(sim_params) - execution_id = simulation_api.get_execution_id(sim_api_execution) + try: + sim_api_execution = simulation_api.run(sim_params) + execution_id = simulation_api.get_execution_id(sim_api_execution) + except Exception as error: + error_message = f"Failed to start simulation API job: {str(error)}" + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=provisional_execution_id, + ) + return EconomicImpactResult.error(message=error_message) + run_id = getattr(sim_api_execution, "run_id", None) or telemetry["run_id"] progress_log = { @@ -1019,9 +1130,10 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - self._set_reform_impact_computing( + self._update_reform_impact_execution_id( setup_options=setup_options, - execution_id=execution_id, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, ) return EconomicImpactResult.computing() @@ -1366,6 +1478,33 @@ def _set_reform_impact_computing( ) raise e + def _update_reform_impact_execution_id( + self, + setup_options: EconomicImpactSetupOptions, + current_execution_id: str, + new_execution_id: str, + ): + try: + reform_impacts_service.update_reform_impact_execution_id( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + current_execution_id=current_execution_id, + new_execution_id=new_execution_id, + ) + except Exception as e: + logger.log_struct( + { + "message": f"Error updating reform impact execution id: {str(e)}", + **setup_options.model_dump(), + } + ) + raise e + def _set_reform_impact_complete( self, setup_options: EconomicImpactSetupOptions, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 4a02c2619..0e07f598e 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -101,7 +101,7 @@ def get_all_reform_impacts( "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND api_version = ? AND dataset = ? " - "ORDER BY start_time DESC" + "ORDER BY start_time DESC, reform_impact_id DESC" ) return database.query( query, @@ -203,6 +203,43 @@ def set_reform_impact( print(f"Error setting reform impact: {str(e)}") raise e + def update_reform_impact_execution_id( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + current_execution_id, + new_execution_id, + ): + try: + query = ( + "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " + "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " + "time_period = ? AND options_hash = ? AND dataset = ? AND " + "execution_id = ? AND status = 'computing'" + ) + database.query( + query, + ( + new_execution_id, + country_id, + policy_id, + baseline_policy_id, + region, + time_period, + options_hash, + dataset, + current_execution_id, + ), + ) + except Exception as e: + print(f"Error updating reform impact execution id: {str(e)}") + raise e + def delete_reform_impact( self, country_id, diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index cf41873ed..b0fcd1a67 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch, MagicMock import json import datetime +from contextlib import nullcontext from policyengine_api.constants import ( MODAL_EXECUTION_STATUS_SUBMITTED, @@ -123,8 +124,10 @@ def mock_reform_impacts_service(): mock_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None + mock_service.update_reform_impact_execution_id.return_value = None mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None + mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() with patch( "policyengine_api.services.economy_service.reform_impacts_service", @@ -187,6 +190,7 @@ def create_mock_reform_impact( reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID, options_hash=MOCK_OPTIONS_HASH, + start_time=None, ): """Helper function to create mock reform impact records.""" default_reform_impact_json = json.dumps( @@ -214,7 +218,7 @@ def create_mock_reform_impact( "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or default_reform_impact_json, "execution_id": execution_id, - "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), + "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None ), diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 3882bb0f7..2ea63f0f0 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -12,6 +12,7 @@ import pytest import sqlalchemy +from unittest.mock import MagicMock from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase @@ -180,3 +181,34 @@ def test_remote_delete(self): db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None + + +class TestRemotePoolCreation: + def test_create_pool_uses_fresh_connection_creator(self, monkeypatch): + first_connection = MagicMock(name="first_connection") + second_connection = MagicMock(name="second_connection") + mock_connector = MagicMock() + mock_connector.connect.side_effect = [first_connection, second_connection] + + captured_kwargs = {} + + def fake_create_engine(url, **kwargs): + captured_kwargs.update(kwargs) + return MagicMock() + + monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password") + monkeypatch.setattr( + "policyengine_api.data.data.Connector", lambda: mock_connector + ) + monkeypatch.setattr( + "policyengine_api.data.data.sqlalchemy.create_engine", + fake_create_engine, + ) + + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db._create_pool() + + creator = captured_kwargs["creator"] + assert creator() is first_connection + assert creator() is second_connection + assert captured_kwargs["pool_pre_ping"] is True diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 15ef97bf1..9c4aafc79 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,3 +1,4 @@ +import datetime import json import pytest from unittest.mock import patch, MagicMock @@ -10,6 +11,9 @@ EconomicImpactSetupOptions, ImpactAction, ImpactStatus, + PENDING_EXECUTION_ID_PREFIX, + PROVISIONAL_CLAIM_TTL_SECONDS, + STALE_PROVISIONAL_IMPACT_MESSAGE, ) from tests.fixtures.services.economy_service import ( MOCK_COUNTRY_ID, @@ -256,6 +260,17 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + current_execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + new_execution_id=MOCK_EXECUTION_ID, + ) def test__given_no_previous_impact__includes_metadata_in_simulation_params( self, @@ -327,6 +342,114 @@ def test__given_no_previous_impact__includes_telemetry_in_simulation_params( mock_logger.log_struct.call_args_list[-1].kwargs["severity"] == "INFO" ) + def test__given_simulation_api_submission_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_simulation_api.run.side_effect = RuntimeError("gateway unavailable") + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: gateway unavailable" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: gateway unavailable", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + + def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_numpy_random, + ): + provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_other", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [provisional_impact], + ] + mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( + "lock busy" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.run.assert_not_called() + + def test__given_stale_provisional_claim__expires_and_recreates_simulation( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + ): + stale_start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) + stale_provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=stale_start_time, + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [stale_provisional_impact], + [stale_provisional_impact], + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_simulation_api.run.assert_called_once() + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, @@ -1093,6 +1216,47 @@ def test__given_no_impacts__returns_none( # Assert assert result is None + class TestGetExistingEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def setup_options(self): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + def test__given_stale_provisional_impact__returns_none( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + ): + stale_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + stale_impact + ] + + result = economy_service._get_existing_economic_impact(setup_options) + + assert result is None + class TestDetermineImpactAction: @pytest.fixture def economy_service(self): @@ -1124,6 +1288,20 @@ def test__given_computing_status__returns_computing(self, economy_service): assert result == ImpactAction.COMPUTING + def test__given_stale_provisional_computing_status__returns_create( + self, economy_service + ): + impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + + result = economy_service._determine_impact_action(impact) + + assert result == ImpactAction.CREATE + def test__given_unknown_status__raises_error(self, economy_service): impact = create_mock_reform_impact(status="unknown") @@ -1212,6 +1390,21 @@ def test__given_active_state__returns_computing_result( assert result.status == ImpactStatus.COMPUTING assert result.data is None + def test__given_provisional_claim__returns_computing_without_polling( + self, economy_service, setup_options, mock_simulation_api, mock_logger + ): + reform_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", + ) + + result = economy_service._handle_computing_impact( + setup_options, reform_impact + ) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.get_execution_by_id.assert_not_called() + def test__given_unknown_state__raises_error( self, economy_service, setup_options ): From 83dd648a52d306c21f7cfb795f6ce27dafa6cf93 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:41:27 -0400 Subject: [PATCH 07/27] Tighten budget window claim recovery --- policyengine_api/api.py | 8 ++- policyengine_api/data/data.py | 1 + policyengine_api/services/economy_service.py | 62 +++++++++++++++---- .../services/reform_impacts_service.py | 3 +- tests/fixtures/services/economy_service.py | 2 +- tests/unit/services/test_economy_service.py | 36 +++++++++++ 6 files changed, 95 insertions(+), 17 deletions(-) diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..eb3eba9ee 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,6 +4,7 @@ import time import sys +import os start_time = time.time() @@ -157,8 +158,11 @@ def log_timing(message): app.register_blueprint(user_profile_bp) log_timing("User profile routes registered") -app.route("/simulations", methods=["GET"])(get_simulations) -log_timing("Simulations endpoint registered") +if os.environ.get("FLASK_DEBUG") == "1": + app.route("/simulations", methods=["GET"])(get_simulations) + log_timing("Simulations endpoint registered") +else: + log_timing("Simulations endpoint skipped outside debug mode") app.register_blueprint(tracer_analysis_bp) log_timing("Tracer analysis routes registered") diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 1ba0ab57a..058f2e714 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -19,6 +19,7 @@ class _ResultProxy: Provides fetchone()/fetchall() with dict-like row access.""" def __init__(self, cursor_result): + self.rowcount = getattr(cursor_result, "rowcount", -1) try: # Use .mappings() so rows behave like dicts self._rows = list(cursor_result.mappings()) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 41ba1fd1f..68d3f9580 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -667,10 +667,10 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) + stale_provisional_execution_id = None if self._is_stale_provisional_impact(most_recent_impact): - self._expire_stale_provisional_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, + stale_provisional_execution_id = most_recent_impact.get( + "execution_id" ) provisional_execution_id = self._build_provisional_execution_id( @@ -680,6 +680,11 @@ def _get_or_create_economic_impact( setup_options=setup_options, execution_id=provisional_execution_id, ) + if stale_provisional_execution_id: + self._expire_stale_provisional_impact( + setup_options=setup_options, + execution_id=stale_provisional_execution_id, + ) except TimeoutError: logger.log_struct( { @@ -926,9 +931,8 @@ def _is_stale_provisional_impact(self, impact: dict | None) -> bool: def _expire_stale_provisional_impact( self, setup_options: EconomicImpactSetupOptions, - most_recent_impact: dict, + execution_id: str, ) -> None: - execution_id = most_recent_impact.get("execution_id") if not self._is_provisional_execution_id(execution_id): return @@ -1130,11 +1134,40 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - self._update_reform_impact_execution_id( - setup_options=setup_options, - current_execution_id=provisional_execution_id, - new_execution_id=execution_id, - ) + try: + updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + except Exception as error: + logger.log_struct( + { + "message": "Failed to promote provisional reform impact row; inserting replacement tracking row", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "error": str(error), + }, + severity="WARNING", + ) + updated_rows = 0 + + if updated_rows != 1: + logger.log_struct( + { + "message": "Provisional reform impact row was not updated; inserting replacement tracking row", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "updated_rows": updated_rows, + }, + severity="WARNING", + ) + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, + ) return EconomicImpactResult.computing() @@ -1454,6 +1487,9 @@ def _set_reform_impact_computing( In the reform_impact table, set the status of the impact to "computing". """ try: + start_time = datetime.datetime.now(datetime.timezone.utc).replace( + tzinfo=None + ) reform_impacts_service.set_reform_impact( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, @@ -1466,7 +1502,7 @@ def _set_reform_impact_computing( status=ImpactStatus.COMPUTING.value, api_version=setup_options.api_version, reform_impact_json=json.dumps({}), - start_time=datetime.datetime.now(), + start_time=start_time, execution_id=execution_id, ) except Exception as e: @@ -1483,9 +1519,9 @@ def _update_reform_impact_execution_id( setup_options: EconomicImpactSetupOptions, current_execution_id: str, new_execution_id: str, - ): + ) -> int | None: try: - reform_impacts_service.update_reform_impact_execution_id( + return reform_impacts_service.update_reform_impact_execution_id( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, baseline_policy_id=setup_options.baseline_policy_id, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 0e07f598e..58175dadf 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -222,7 +222,7 @@ def update_reform_impact_execution_id( "time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ? AND status = 'computing'" ) - database.query( + result = database.query( query, ( new_execution_id, @@ -236,6 +236,7 @@ def update_reform_impact_execution_id( current_execution_id, ), ) + return getattr(result, "rowcount", None) except Exception as e: print(f"Error updating reform impact execution id: {str(e)}") raise e diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index b0fcd1a67..d5c8da047 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -124,7 +124,7 @@ def mock_reform_impacts_service(): mock_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None - mock_service.update_reform_impact_execution_id.return_value = None + mock_service.update_reform_impact_execution_id.return_value = 1 mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 9c4aafc79..d0a8b0847 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -260,6 +260,10 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert any( + call.args == (datetime.timezone.utc,) + for call in mock_datetime.now.call_args_list + ) mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( country_id=MOCK_COUNTRY_ID, policy_id=MOCK_POLICY_ID, @@ -450,6 +454,38 @@ def test__given_stale_provisional_claim__expires_and_recreates_simulation( mock_reform_impacts_service.set_reform_impact.assert_called_once() mock_simulation_api.run.assert_called_once() + def test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tracking_row( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 2 + first_insert = mock_reform_impacts_service.set_reform_impact.call_args_list[ + 0 + ] + second_insert = ( + mock_reform_impacts_service.set_reform_impact.call_args_list[1] + ) + assert ( + first_insert.kwargs["execution_id"] + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, From 1be87712b006ede5444ed6374acee72042a09bbb Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:55:41 -0400 Subject: [PATCH 08/27] Prevent stale claim takeover --- policyengine_api/services/economy_service.py | 90 ++++++++++++++++++-- tests/unit/endpoints/test_simulation.py | 16 ++++ tests/unit/services/test_economy_service.py | 39 +++++++++ 3 files changed, 137 insertions(+), 8 deletions(-) create mode 100644 tests/unit/endpoints/test_simulation.py diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 68d3f9580..a4aaf1caf 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -1154,22 +1154,96 @@ def _handle_create_impact( updated_rows = 0 if updated_rows != 1: + self._recover_failed_execution_id_promotion( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + execution_id=execution_id, + updated_rows=updated_rows, + ) + + return EconomicImpactResult.computing() + + def _recover_failed_execution_id_promotion( + self, + *, + setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, + execution_id: str, + updated_rows: int | None, + ) -> None: + logger.log_struct( + { + "message": "Provisional reform impact row was not updated; checking whether tracking has already been superseded", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "updated_rows": updated_rows, + }, + severity="WARNING", + ) + + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options + ) + if most_recent_impact is not None: + impact_status = most_recent_impact.get("status") + tracked_execution_id = most_recent_impact.get("execution_id") + if tracked_execution_id == execution_id: + return + + if ( + impact_status == ImpactStatus.COMPUTING.value + and tracked_execution_id == provisional_execution_id + ): + retry_updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + if retry_updated_rows == 1: + return + elif impact_status in ( + ImpactStatus.OK.value, + ImpactStatus.COMPUTING.value, + ): + logger.log_struct( + { + "message": "Skipping replacement tracking row because another claim is already authoritative", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "tracked_execution_id": tracked_execution_id, + "tracked_status": impact_status, + }, + severity="WARNING", + ) + return + + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, + ) + except TimeoutError: logger.log_struct( { - "message": "Provisional reform impact row was not updated; inserting replacement tracking row", + "message": "Timed out while recovering failed provisional promotion; leaving the newer claim authoritative", **setup_options.model_dump(), "execution_id": execution_id, "provisional_execution_id": provisional_execution_id, - "updated_rows": updated_rows, }, severity="WARNING", ) - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=execution_id, - ) - - return EconomicImpactResult.computing() def _setup_sim_options( self, diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py new file mode 100644 index 000000000..c29837eec --- /dev/null +++ b/tests/unit/endpoints/test_simulation.py @@ -0,0 +1,16 @@ +from unittest.mock import MagicMock, patch + +from policyengine_api.endpoints.simulation import get_simulations + + +def test_get_simulations_reads_from_shared_database(): + mock_database = MagicMock() + mock_database.query.return_value.fetchall.return_value = [{"id": 1}] + + with patch("policyengine_api.endpoints.simulation.database", mock_database): + result = get_simulations() + + mock_database.query.assert_called_once_with( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT 100", + ) + assert result == {"result": [{"id": 1}]} diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d0a8b0847..7d431453c 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -486,6 +486,45 @@ def test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tra ) assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID + def test__given_provisional_promotion_updates_zero_rows_but_newer_claim_exists__does_not_insert_fallback( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + replacement_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_replacement", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [], + [replacement_impact], + ] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 1 + inserted_execution_id = ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + ) + assert ( + inserted_execution_id + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, From c663e73d3738ec1b6e269f4f4cc8cf9117e7aa24 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 12:41:29 -0400 Subject: [PATCH 09/27] Backfill reform impact schema lazily --- .../services/reform_impacts_service.py | 43 ++++++++++++ .../services/test_reform_impacts_service.py | 66 +++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 58175dadf..3f1d9cfa1 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -6,6 +6,7 @@ LOCAL_REFORM_IMPACT_LOCK = Lock() +REFORM_IMPACT_SCHEMA_LOCK = Lock() REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 @@ -15,6 +16,42 @@ class ReformImpactsService: this is connected to the shared reform_impact table. """ + def __init__(self): + self._schema_checked = False + + def _ensure_remote_schema(self) -> None: + if database.local or self._schema_checked: + return + + with REFORM_IMPACT_SCHEMA_LOCK: + if self._schema_checked: + return + + existing_columns = { + row["Field"] + for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() + } + required_columns = { + "execution_id": ( + "ALTER TABLE reform_impact " + "ADD COLUMN execution_id VARCHAR(255) NULL" + ), + "end_time": ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL" + ), + } + + for column_name, alter_query in required_columns.items(): + if column_name in existing_columns: + continue + try: + database.query(alter_query) + except Exception as error: + if "Duplicate column name" not in str(error): + raise + + self._schema_checked = True + def _build_lock_name( self, country_id, @@ -96,6 +133,7 @@ def get_all_reform_impacts( api_version, ): try: + self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " @@ -176,6 +214,7 @@ def set_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "INSERT INTO reform_impact (country_id, reform_policy_id, baseline_policy_id, " "region, dataset, time_period, options_json, options_hash, status, api_version, " @@ -216,6 +255,7 @@ def update_reform_impact_execution_id( new_execution_id, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " @@ -252,6 +292,7 @@ def delete_reform_impact( options_hash, ): try: + self._ensure_remote_schema() query = ( "DELETE FROM reform_impact WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -288,6 +329,7 @@ def set_error_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ? WHERE " "country_id = ? AND reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -332,6 +374,7 @@ def set_complete_reform_impact( execution_id, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ?, " "reform_impact_json = ? WHERE country_id = ? AND reform_policy_id = ? AND " diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py index 4f44b63a2..4456dffad 100644 --- a/tests/unit/services/test_reform_impacts_service.py +++ b/tests/unit/services/test_reform_impacts_service.py @@ -6,6 +6,72 @@ class TestReformImpactsService: + def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + ] + alter_execution_result = MagicMock() + alter_end_time_result = MagicMock() + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.side_effect = [ + show_columns_result, + alter_execution_result, + alter_end_time_result, + ] + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + assert mock_database.query.call_args_list[0].args == ( + "SHOW COLUMNS FROM reform_impact", + ) + assert mock_database.query.call_args_list[1].args == ( + "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + ) + assert mock_database.query.call_args_list[2].args == ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", + ) + + def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alter( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + {"Field": "execution_id"}, + {"Field": "end_time"}, + ] + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.return_value = show_columns_result + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + mock_database.query.assert_called_once_with("SHOW COLUMNS FROM reform_impact") + def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): service = ReformImpactsService() From 7b1c3adeeb11cee5c105fcbdeed5b6ce7091486d Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 09:18:36 -0400 Subject: [PATCH 10/27] Backfill reform impact dataset column --- policyengine_api/services/reform_impacts_service.py | 4 ++++ tests/unit/services/test_reform_impacts_service.py | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 3f1d9cfa1..a090516fa 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -32,6 +32,10 @@ def _ensure_remote_schema(self) -> None: for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() } required_columns = { + "dataset": ( + "ALTER TABLE reform_impact " + "ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'" + ), "execution_id": ( "ALTER TABLE reform_impact " "ADD COLUMN execution_id VARCHAR(255) NULL" diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py index 4456dffad..106cf8757 100644 --- a/tests/unit/services/test_reform_impacts_service.py +++ b/tests/unit/services/test_reform_impacts_service.py @@ -17,6 +17,7 @@ def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( {"Field": "status"}, {"Field": "start_time"}, ] + alter_dataset_result = MagicMock() alter_execution_result = MagicMock() alter_end_time_result = MagicMock() @@ -24,6 +25,7 @@ def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( mock_database.local = False mock_database.query.side_effect = [ show_columns_result, + alter_dataset_result, alter_execution_result, alter_end_time_result, ] @@ -39,9 +41,12 @@ def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( "SHOW COLUMNS FROM reform_impact", ) assert mock_database.query.call_args_list[1].args == ( - "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + "ALTER TABLE reform_impact ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'", ) assert mock_database.query.call_args_list[2].args == ( + "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + ) + assert mock_database.query.call_args_list[3].args == ( "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", ) @@ -55,6 +60,7 @@ def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alt {"Field": "reform_impact_id"}, {"Field": "status"}, {"Field": "start_time"}, + {"Field": "dataset"}, {"Field": "execution_id"}, {"Field": "end_time"}, ] From f7baae4121d7eac7f5ea01c37183bdde415e4ce7 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 09:46:20 -0400 Subject: [PATCH 11/27] Address budget window review findings --- policyengine_api/data/__init__.py | 7 +- policyengine_api/data/data.py | 8 ++ policyengine_api/endpoints/simulation.py | 2 +- policyengine_api/openapi_spec.yaml | 132 +++++++++++++++++++ policyengine_api/services/economy_service.py | 55 ++++++-- tests/unit/endpoints/test_simulation.py | 7 +- tests/unit/services/test_economy_service.py | 37 +++++- 7 files changed, 234 insertions(+), 14 deletions(-) diff --git a/policyengine_api/data/__init__.py b/policyengine_api/data/__init__.py index 15673afdb..94703ee36 100644 --- a/policyengine_api/data/__init__.py +++ b/policyengine_api/data/__init__.py @@ -1 +1,6 @@ -from .data import PolicyEngineDatabase, database, local_database +from .data import ( + PolicyEngineDatabase, + database, + get_remote_database, + local_database, +) diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 058f2e714..78cdb5459 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -264,3 +264,11 @@ def initialize(self): database = PolicyEngineDatabase(local=False, initialize=False) local_database = PolicyEngineDatabase(local=True, initialize=False) +remote_database = None + + +def get_remote_database() -> PolicyEngineDatabase: + global remote_database + if remote_database is None: + remote_database = PolicyEngineDatabase(local=False, initialize=False) + return remote_database diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index b03941442..1f8fa8662 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import database +from policyengine_api.data import get_remote_database """ diff --git a/policyengine_api/openapi_spec.yaml b/policyengine_api/openapi_spec.yaml index a49268c8c..77daadc9e 100644 --- a/policyengine_api/openapi_spec.yaml +++ b/policyengine_api/openapi_spec.yaml @@ -660,6 +660,138 @@ paths: type: string message: type: string + /{country_id}/economy/{policy_id}/over/{baseline_policy_id}/budget-window: + get: + summary: Calculate budget-window economic impacts + operationId: get_budget_window_economic_impact + description: Calculate annual and total budget impacts for a policy over a multi-year budget window. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + - name: policy_id + in: path + description: The reform policy ID. + required: true + schema: + type: string + - name: baseline_policy_id + in: path + description: The baseline policy ID. + required: true + schema: + type: string + - name: region + in: query + description: The sub-national region. + required: true + schema: + type: string + - name: start_year + in: query + description: First year in the budget window. + required: true + schema: + type: string + - name: window_size + in: query + description: Number of years to include in the budget window. + required: true + schema: + type: integer + - name: dataset + in: query + description: Dataset selection. + required: false + schema: + type: string + default: default + - name: version + in: query + description: API version number. + required: false + schema: + type: string + - name: include_district_breakdowns + in: query + description: Whether to include congressional district breakdowns for US national simulations. + required: false + schema: + type: boolean + default: false + - name: target + in: query + description: Impact target. Budget-window calculations only support general impacts. + required: false + schema: + type: string + default: general + responses: + 200: + description: Budget-window economic impact, progress, or error state. + content: + application/json: + schema: + type: object + properties: + status: + type: string + enum: + - ok + - computing + - error + message: + type: string + nullable: true + result: + type: object + nullable: true + progress: + type: integer + nullable: true + completed_years: + type: array + items: + type: string + computing_years: + type: array + items: + type: string + queued_years: + type: array + items: + type: string + error: + type: string + nullable: true + 400: + description: Invalid budget-window request. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + result: + type: object + nullable: true + 404: + description: Invalid country ID. + content: + text/html: + schema: + type: object + properties: + status: + type: string + message: + type: string /{country_id}/analysis: post: summary: Get or trigger policy analysis diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index a4aaf1caf..89c3b9767 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -58,6 +58,7 @@ class ImpactAction(Enum): COMPLETED = "completed" COMPUTING = "computing" CREATE = "create" + ERROR = "error" class ImpactStatus(Enum): @@ -623,6 +624,19 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + if impact_action == ImpactAction.CREATE: try: with reform_impacts_service.claim_lock( @@ -667,6 +681,19 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db after locking; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + stale_provisional_execution_id = None if self._is_stale_provisional_impact(most_recent_impact): stale_provisional_execution_id = most_recent_impact.get( @@ -723,13 +750,9 @@ def _get_existing_economic_impact( status = most_recent_impact.get("status") if status == ImpactStatus.ERROR.value: - error_message = most_recent_impact.get("message") or ( - f"Economic impact failed for {setup_options.time_period}" - ) - return EconomicImpactResult( - status=ImpactStatus.ERROR, - data=None, - message=error_message, + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) if status == ImpactStatus.OK.value: @@ -950,8 +973,10 @@ def _determine_impact_action( return ImpactAction.CREATE status = most_recent_impact.get("status") - if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: + if status == ImpactStatus.OK.value: return ImpactAction.COMPLETED + elif status == ImpactStatus.ERROR.value: + return ImpactAction.ERROR elif status == ImpactStatus.COMPUTING.value: if self._is_stale_provisional_impact(most_recent_impact): return ImpactAction.CREATE @@ -1033,6 +1058,20 @@ def _handle_completed_impact( ) ) + def _handle_error_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> EconomicImpactResult: + error_message = most_recent_impact.get("message") or ( + f"Economic impact failed for {setup_options.time_period}" + ) + return EconomicImpactResult( + status=ImpactStatus.ERROR, + data=None, + message=error_message, + ) + def _handle_computing_impact( self, setup_options: EconomicImpactSetupOptions, diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py index c29837eec..a9013a056 100644 --- a/tests/unit/endpoints/test_simulation.py +++ b/tests/unit/endpoints/test_simulation.py @@ -3,11 +3,14 @@ from policyengine_api.endpoints.simulation import get_simulations -def test_get_simulations_reads_from_shared_database(): +def test_get_simulations_reads_from_remote_database(): mock_database = MagicMock() mock_database.query.return_value.fetchall.return_value = [{"id": 1}] - with patch("policyengine_api.endpoints.simulation.database", mock_database): + with patch( + "policyengine_api.endpoints.simulation.get_remote_database", + return_value=mock_database, + ): result = get_simulations() mock_database.query.assert_called_once_with( diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 7d431453c..8fe21efa9 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -147,6 +147,39 @@ def test__given_legacy_completed_impact__refreshes_cache( ) mock_simulation_api.run.assert_called_once() + def test__given_error_impact__returns_error_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + error_impact = create_mock_reform_impact( + status="error", + reform_impact_json=json.dumps({}), + ) + error_impact["message"] = "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ + error_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.data is None + assert result.message == "Failed to start simulation API job" + ( + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once() + ) + mock_simulation_api.run.assert_not_called() + def test__given_computing_impact_with_succeeded_execution__returns_completed_result( self, economy_service, @@ -1349,12 +1382,12 @@ def test__given_ok_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_error_status__returns_completed(self, economy_service): + def test__given_error_status__returns_error(self, economy_service): impact = create_mock_reform_impact(status="error") result = economy_service._determine_impact_action(impact) - assert result == ImpactAction.COMPLETED + assert result == ImpactAction.ERROR def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") From 5edef997ae60c7a6c66251ed71f09f7c22dcf104 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 10:38:28 -0400 Subject: [PATCH 12/27] Mark pre-submission setup failures as errors --- policyengine_api/services/economy_service.py | 102 +++++++++---------- tests/unit/services/test_economy_service.py | 41 ++++++++ 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 89c3b9767..7e2e3cc12 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -1095,63 +1095,63 @@ def _handle_create_impact( setup_options: EconomicImpactSetupOptions, provisional_execution_id: str, ) -> EconomicImpactResult: - baseline_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.baseline_policy_id - ) - reform_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.reform_policy_id - ) + try: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.baseline_policy_id + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.reform_policy_id + ) - sim_config: SimulationOptions = self._setup_sim_options( - country_id=setup_options.country_id, - reform_policy=reform_policy, - baseline_policy=baseline_policy, - region=setup_options.region, - time_period=setup_options.time_period, - dataset=setup_options.dataset, - scope="macro", - include_cliffs=setup_options.target == "cliff", - model_version=setup_options.model_version, - data_version=setup_options.data_version, - ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=setup_options.time_period, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=setup_options.target == "cliff", + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) - sim_params = sim_config.model_dump(mode="json") - telemetry = self._build_simulation_telemetry( - setup_options=setup_options, - sim_config=sim_params, - ) + sim_params = sim_config.model_dump(mode="json") + telemetry = self._build_simulation_telemetry( + setup_options=setup_options, + sim_config=sim_params, + ) - logger.log_struct( - { - "message": "Setting up sim API job", - "run_id": telemetry["run_id"], - **setup_options.model_dump(), - } - ) + logger.log_struct( + { + "message": "Setting up sim API job", + "run_id": telemetry["run_id"], + **setup_options.model_dump(), + } + ) - # Preserve both legacy metadata and the new telemetry envelope. - sim_params["_metadata"] = { - "reform_policy_id": setup_options.reform_policy_id, - "baseline_policy_id": setup_options.baseline_policy_id, - "process_id": setup_options.process_id, - "model_version": setup_options.model_version, - "policyengine_version": setup_options.policyengine_version, - "data_version": setup_options.data_version, - "dataset": setup_options.dataset, - "resolved_app_name": setup_options.runtime_app_name, - } - sim_params["_telemetry"] = telemetry + # Preserve both legacy metadata and the new telemetry envelope. + sim_params["_metadata"] = { + "reform_policy_id": setup_options.reform_policy_id, + "baseline_policy_id": setup_options.baseline_policy_id, + "process_id": setup_options.process_id, + "model_version": setup_options.model_version, + "policyengine_version": setup_options.policyengine_version, + "data_version": setup_options.data_version, + "dataset": setup_options.dataset, + "resolved_app_name": setup_options.runtime_app_name, + } + sim_params["_telemetry"] = telemetry - # The simulation gateway (policyengine-api-v2 PR #458) requires - # ``time_period`` as a string, but the upstream ``policyengine`` - # package (``TimePeriodType = int``) coerces the value to int during - # ``model_validate`` and ``model_dump`` re-emits it as int. Cast back - # to str at the request boundary so the gateway's strict schema - # validates instead of returning 422. - if sim_params.get("time_period") is not None: - sim_params["time_period"] = str(sim_params["time_period"]) + # The simulation gateway (policyengine-api-v2 PR #458) requires + # ``time_period`` as a string, but the upstream ``policyengine`` + # package (``TimePeriodType = int``) coerces the value to int during + # ``model_validate`` and ``model_dump`` re-emits it as int. Cast back + # to str at the request boundary so the gateway's strict schema + # validates instead of returning 422. + if sim_params.get("time_period") is not None: + sim_params["time_period"] = str(sim_params["time_period"]) - try: sim_api_execution = simulation_api.run(sim_params) execution_id = simulation_api.get_execution_id(sim_api_execution) except Exception as error: diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 8fe21efa9..b8353013a 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -416,6 +416,47 @@ def test__given_simulation_api_submission_failure__marks_provisional_claim_error ) mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + def test__given_simulation_setup_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + with patch.object( + economy_service, + "_setup_sim_options", + side_effect=ValueError("Invalid US state: 'zz'"), + ): + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: Invalid US state: 'zz'" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: Invalid US state: 'zz'", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_simulation_api.run.assert_not_called() + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( self, economy_service, From 474513c16db0b453e323dff989ff5ecc171f5d90 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 17 Apr 2026 01:28:29 +0200 Subject: [PATCH 13/27] Adapt budget-window flow to batch simulation API --- policyengine_api/endpoints/simulation.py | 12 +- policyengine_api/libs/simulation_api_modal.py | 103 ++- policyengine_api/services/economy_service.py | 613 ++++++++++-------- .../services/reform_impacts_service.py | 3 +- tests/fixtures/libs/simulation_api_modal.py | 49 ++ tests/fixtures/services/economy_service.py | 34 +- .../test_economy_budget_window_routes.py | 13 + tests/unit/endpoints/test_simulation.py | 3 +- tests/unit/libs/test_simulation_api_modal.py | 105 +++ tests/unit/services/test_economy_service.py | 569 ++++++++-------- uv.lock | 2 +- 11 files changed, 967 insertions(+), 539 deletions(-) diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 1f8fa8662..d300ae80b 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -42,10 +42,14 @@ def get_simulations( max_results = _DEFAULT_SIMULATION_RESULTS max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS)) - result = database.query( - "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", - (max_results,), - ).fetchall() + result = ( + get_remote_database() + .query( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (max_results,), + ) + .fetchall() + ) # Format into [{}] diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 3d7660791..c567de06f 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -7,7 +7,7 @@ import os import sys -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional import httpx @@ -42,6 +42,28 @@ def name(self) -> str: return self.job_id +@dataclass +class ModalBudgetWindowBatchExecution: + """ + Represents a budget-window batch execution in the Modal simulation API. + """ + + batch_job_id: str + status: str + progress: Optional[int] = None + completed_years: list[str] = field(default_factory=list) + running_years: list[str] = field(default_factory=list) + queued_years: list[str] = field(default_factory=list) + failed_years: list[str] = field(default_factory=list) + result: Optional[dict] = None + error: Optional[str] = None + + @property + def name(self) -> str: + """Alias for batch_job_id.""" + return self.batch_job_id + + class SimulationAPIModal: """ HTTP client for the Modal Simulation API. @@ -144,10 +166,51 @@ def run(self, payload: dict) -> ModalSimulationExecution: ) raise + def run_budget_window_batch(self, payload: dict) -> ModalBudgetWindowBatchExecution: + """ + Submit a budget-window batch job to the Modal API. + """ + try: + modal_payload = dict(payload) + if "model_version" in modal_payload: + modal_payload["version"] = modal_payload.pop("model_version") + modal_payload.pop("data_version", None) + + response = self.client.post( + f"{self.base_url}/simulate/economy/budget-window", + json=modal_payload, + ) + response.raise_for_status() + data = response.json() + + logger.log_struct( + { + "message": "Modal budget-window batch submitted", + "batch_job_id": data.get("batch_job_id"), + "status": data.get("status"), + }, + severity="INFO", + ) + + return ModalBudgetWindowBatchExecution( + batch_job_id=data["batch_job_id"], + status=data["status"], + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + except httpx.RequestError as e: logger.log_struct( { - "message": f"Modal API request error: {str(e)}", + "message": f"Modal batch API request error: {str(e)}", "run_id": (payload.get("_telemetry") or {}).get("run_id"), }, severity="ERROR", @@ -226,10 +289,44 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: ) raise + def get_budget_window_batch_by_id( + self, batch_job_id: str + ) -> ModalBudgetWindowBatchExecution: + """ + Poll the Modal API for the current status of a budget-window batch. + """ + try: + response = self.client.get( + f"{self.base_url}/budget-window-jobs/{batch_job_id}" + ) + data = response.json() + + return ModalBudgetWindowBatchExecution( + batch_job_id=batch_job_id, + status=data["status"], + progress=data.get("progress"), + completed_years=data.get("completed_years", []), + running_years=data.get("running_years", []), + queued_years=data.get("queued_years", []), + failed_years=data.get("failed_years", []), + result=data.get("result"), + error=data.get("error"), + ) + + except httpx.HTTPStatusError as e: + logger.log_struct( + { + "message": f"Modal batch API HTTP error polling job {batch_job_id}: {e.response.status_code}", + "response_text": e.response.text[:500], + }, + severity="ERROR", + ) + raise + except httpx.RequestError as e: logger.log_struct( { - "message": f"Modal API request error polling job {job_id}: {str(e)}", + "message": f"Modal batch API request error polling job {batch_job_id}: {str(e)}", }, severity="ERROR", ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 7e2e3cc12..1ec5f969b 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -26,12 +26,11 @@ import datetime import hashlib import uuid -from typing import Literal, Any, Optional, Annotated +from typing import Literal, Any, Optional, Annotated, Union from dotenv import load_dotenv from pydantic import BaseModel, Field import numpy as np from enum import Enum -from concurrent.futures import ThreadPoolExecutor load_dotenv() @@ -73,8 +72,9 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value -BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 -BUDGET_WINDOW_MAX_YEARS = 20 +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 +BUDGET_WINDOW_MAX_YEARS = 75 +BUDGET_WINDOW_MAX_END_YEAR = 2099 PENDING_EXECUTION_ID_PREFIX = "pending:" PROVISIONAL_CLAIM_TTL_SECONDS = 90 STALE_PROVISIONAL_IMPACT_MESSAGE = ( @@ -251,130 +251,22 @@ def get_economic_impact( if country_id == "us": region = normalize_us_region(region) - # Set up logging - process_id: str = self._create_process_id() - - country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - cache_version = get_economy_impact_cache_version(country_id, api_version) - resolved_dataset = self._setup_data( + economic_impact_setup_options = self._build_economic_impact_setup_options( country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, region=region, dataset=dataset, - ) - resolved_model_version = country_package_version - resolved_data_version = self._extract_dataset_version(resolved_dataset) - options_hash = self._build_options_hash( + time_period=time_period, options=options, - model_version=resolved_model_version, - dataset=resolved_dataset, - ) - - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( - { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": resolved_dataset, - "time_period": time_period, - "options": options, - "api_version": cache_version, - "target": target, - "model_version": resolved_model_version, - "policyengine_version": None, - "data_version": resolved_data_version, - "runtime_app_name": None, - "options_hash": options_hash, - } - ) - - # Logging that we've received a request - logger.log_struct( - { - "message": "Received request for economic impact; checking if already in reform_impacts table", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", + api_version=api_version, + target=target, ) - most_recent_impact: dict | None = self._get_most_recent_impact( + return self._get_or_create_economic_impact( setup_options=economic_impact_setup_options, ) - if most_recent_impact and self._should_refresh_cached_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ): - most_recent_impact = self._get_most_recent_impact( - economic_impact_setup_options - ) - if ( - not most_recent_impact - or most_recent_impact.get("options_hash") - != economic_impact_setup_options.options_hash - ): - most_recent_impact = None - - impact_action: ImpactAction = self._determine_impact_action( - most_recent_impact=most_recent_impact, - ) - - if impact_action == ImpactAction.COMPLETED: - logger.log_struct( - { - "message": "Found completed economic impact in db; returning result", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_completed_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) - - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact record in db; confirming this is still computing", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_computing_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) - - if impact_action == ImpactAction.CREATE: - if economic_impact_setup_options.runtime_app_name is None: - ( - economic_impact_setup_options.runtime_app_name, - economic_impact_setup_options.model_version, - ) = simulation_api.resolve_app_name( - country_id, - economic_impact_setup_options.model_version, - ) - economic_impact_setup_options.options_hash = self._build_options_hash( - options=options, - model_version=economic_impact_setup_options.model_version, - dataset=resolved_dataset, - data_version=resolved_data_version, - runtime_app_name=economic_impact_setup_options.runtime_app_name, - ) - logger.log_struct( - { - "message": "No previous economic impact record found in db; creating new simulation run", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_create_impact( - setup_options=economic_impact_setup_options, - ) - - raise ValueError(f"Unexpected impact action: {impact_action}") - except Exception as e: print(f"Error getting economic impact: {str(e)}") raise e @@ -407,142 +299,273 @@ def get_budget_window_economic_impact( raise ValueError( f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" ) - - years = [str(start_year_int + index) for index in range(window_size)] - setup_options_by_year = { - year: self._build_economic_impact_setup_options( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=year, - options=dict(options), - api_version=api_version, - target=target, + end_year = start_year_int + window_size - 1 + if end_year > BUDGET_WINDOW_MAX_END_YEAR: + raise ValueError( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" ) - for year in years - } - completed_impacts: dict[str, dict] = {} - computing_years: list[str] = [] - queued_years: list[str] = [] + start_year = str(start_year_int) + years = self._build_budget_window_years( + start_year=start_year, + window_size=window_size, + ) + tracking_setup_options = self._build_budget_window_tracking_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) - for year in years: - result = self._get_existing_economic_impact( - setup_options=setup_options_by_year[year] + most_recent_impact = self._get_budget_window_tracking_impact( + tracking_setup_options + ) + if most_recent_impact is None: + self._start_budget_window_batch( + setup_options=tracking_setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_active_years, + ) + return self._build_budget_window_computing_result( + total_years=len(years), + completed_years=[], + computing_years=[], + queued_years=years, + progress=0, ) - if result is None: - queued_years.append(year) - continue + return self._get_budget_window_result_from_tracking_impact( + setup_options=tracking_setup_options, + most_recent_impact=most_recent_impact, + total_years=len(years), + queued_years_on_submit=years, + ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e - if result.status == ImpactStatus.OK: - completed_impacts[year] = self._extract_budget_window_annual_impact( - year=year, impact_data=result.data or {} - ) - continue + def _build_budget_window_years( + self, + *, + start_year: str, + window_size: int, + ) -> list[str]: + start_year_int = int(start_year) + return [str(start_year_int + index) for index in range(window_size)] - if result.status == ImpactStatus.COMPUTING: - computing_years.append(year) - continue + def _build_budget_window_tracking_time_period( + self, + *, + start_year: str, + window_size: int, + ) -> str: + return f"budget_window:{start_year}:{window_size}" - completed_years = [ - completed_year - for completed_year in years - if completed_year in completed_impacts - ] + def _build_budget_window_tracking_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"], + ) -> EconomicImpactSetupOptions: + return self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=self._build_budget_window_tracking_time_period( + start_year=start_year, + window_size=window_size, + ), + options=dict(options), + api_version=api_version, + target=target, + ) + + def _build_budget_window_batch_payload( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> dict[str, Any]: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.baseline_policy_id, + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, + setup_options.reform_policy_id, + ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=start_year, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=False, + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) + sim_params = sim_config.model_dump() + sim_params.pop("time_period", None) + sim_params["start_year"] = start_year + sim_params["window_size"] = window_size + sim_params["max_parallel"] = max_parallel + sim_params["target"] = setup_options.target + return sim_params + + def _get_budget_window_tracking_impact( + self, + setup_options: EconomicImpactSetupOptions, + ) -> dict | None: + return self._get_exact_reform_impact(setup_options) + + def _start_budget_window_batch( + self, + *, + setup_options: EconomicImpactSetupOptions, + start_year: str, + window_size: int, + max_parallel: int, + ) -> None: + sim_params = self._build_budget_window_batch_payload( + setup_options=setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_parallel, + ) + + logger.log_struct( + { + "message": "Submitting budget-window batch job", + **setup_options.model_dump(), + "start_year": start_year, + "window_size": window_size, + "max_parallel": max_parallel, + }, + severity="INFO", + ) + + batch_execution = simulation_api.run_budget_window_batch(sim_params) + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=batch_execution.batch_job_id, + ) + + def _get_budget_window_result_from_tracking_impact( + self, + *, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + total_years: int, + queued_years_on_submit: list[str], + ) -> BudgetWindowEconomicImpactResult: + impact_status = most_recent_impact.get("status") + if impact_status == ImpactStatus.OK.value: + return BudgetWindowEconomicImpactResult.completed( + json.loads(most_recent_impact["reform_impact_json"]) + ) + + execution_id = most_recent_impact.get("execution_id") + if not execution_id: + return BudgetWindowEconomicImpactResult.failed( + most_recent_impact.get("message") + or "Budget-window batch tracking row is missing execution_id", + queued_years=queued_years_on_submit, + ) + + try: + batch_execution = simulation_api.get_budget_window_batch_by_id(execution_id) + except Exception: + if impact_status == ImpactStatus.ERROR.value: return BudgetWindowEconomicImpactResult.failed( - self._get_economic_impact_error_message( - result=result, - year=year, - ), - completed_years=completed_years, - computing_years=computing_years, - queued_years=queued_years, + most_recent_impact.get("message") or "Budget-window batch failed", + queued_years=queued_years_on_submit, ) + raise - available_slots = max(0, max_active_years - len(computing_years)) - years_to_start = queued_years[:available_slots] - remaining_queued_years = queued_years[available_slots:] - - if years_to_start: - max_workers = min(len(years_to_start), max_active_years) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_year_pairs = [ - ( - year, - executor.submit( - self._get_or_create_economic_impact, - setup_options_by_year[year], - ), - ) - for year in years_to_start - ] - - for year, future in future_year_pairs: - result = future.result() - - if result.status == ImpactStatus.OK: - completed_impacts[year] = ( - self._extract_budget_window_annual_impact( - year=year, impact_data=result.data or {} - ) - ) - elif result.status == ImpactStatus.COMPUTING: - computing_years.append(year) - else: - completed_years = [ - completed_year - for completed_year in years - if completed_year in completed_impacts - ] - return BudgetWindowEconomicImpactResult.failed( - self._get_economic_impact_error_message( - result=result, - year=year, - ), - completed_years=completed_years, - computing_years=computing_years, - queued_years=remaining_queued_years, - ) - - completed_years = [ - completed_year - for completed_year in years - if completed_year in completed_impacts - ] + if batch_execution.status in EXECUTION_STATUSES_SUCCESS: + result = batch_execution.result or {} + self._set_reform_impact_complete( + setup_options=setup_options, + reform_impact_json=json.dumps(result), + execution_id=execution_id, + ) + return BudgetWindowEconomicImpactResult.completed(result) - if len(completed_years) == len(years): - ordered_annual_impacts = [ - completed_impacts[year] - for year in years - if year in completed_impacts - ] - return BudgetWindowEconomicImpactResult.completed( - self._build_budget_window_output( - start_year=start_year, - window_size=window_size, - annual_impacts=ordered_annual_impacts, - ) - ) + if batch_execution.status in EXECUTION_STATUSES_FAILURE: + error_message = batch_execution.error or ( + most_recent_impact.get("message") or "Budget-window batch failed" + ) + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=execution_id, + ) + return BudgetWindowEconomicImpactResult.failed( + error_message, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + ) + + if batch_execution.status in EXECUTION_STATUSES_PENDING: + return self._build_budget_window_computing_result( + total_years=total_years, + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years, + progress=batch_execution.progress, + ) + + raise ValueError( + f"Unexpected budget-window batch execution state: {batch_execution.status}" + ) - progress = round((len(completed_years) / len(years)) * 100) - return BudgetWindowEconomicImpactResult.computing( - progress=progress, + def _build_budget_window_computing_result( + self, + *, + total_years: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + progress: Optional[int] = None, + ) -> BudgetWindowEconomicImpactResult: + resolved_progress = progress + if resolved_progress is None: + resolved_progress = round((len(completed_years) / total_years) * 100) + + return BudgetWindowEconomicImpactResult.computing( + progress=resolved_progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=self._build_budget_window_progress_message( completed_years=completed_years, + total_years=total_years, computing_years=computing_years, - queued_years=remaining_queued_years, - message=self._build_budget_window_progress_message( - completed_years=completed_years, - total_years=len(years), - computing_years=computing_years, - queued_years=remaining_queued_years, - ), - ) - except Exception as e: - print(f"Error getting budget-window economic impact: {str(e)}") - raise e + queued_years=queued_years, + ), + ) def _build_economic_impact_setup_options( self, @@ -558,11 +581,19 @@ def _build_economic_impact_setup_options( target: Literal["general", "cliff"] = "general", ) -> EconomicImpactSetupOptions: process_id: str = self._create_process_id() - options_hash = "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" - + cache_version = get_economy_impact_cache_version(country_id, api_version) country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - if country_id == "uk": - country_package_version = None + resolved_dataset = self._setup_data( + country_id=country_id, + region=region, + dataset=dataset, + ) + resolved_data_version = self._extract_dataset_version(resolved_dataset) + options_hash = self._build_options_hash( + options=options, + model_version=country_package_version, + dataset=resolved_dataset, + ) return EconomicImpactSetupOptions.model_validate( { @@ -571,13 +602,15 @@ def _build_economic_impact_setup_options( "reform_policy_id": policy_id, "baseline_policy_id": baseline_policy_id, "region": region, - "dataset": dataset, + "dataset": resolved_dataset, "time_period": time_period, "options": options, - "api_version": api_version, + "api_version": cache_version, "target": target, "model_version": country_package_version, - "data_version": get_dataset_version(country_id), + "policyengine_version": None, + "data_version": resolved_data_version, + "runtime_app_name": None, "options_hash": options_hash, } ) @@ -597,6 +630,17 @@ def _get_or_create_economic_impact( setup_options=setup_options ) + if most_recent_impact and self._should_refresh_cached_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ): + most_recent_impact = self._get_most_recent_impact(setup_options) + if ( + not most_recent_impact + or most_recent_impact.get("options_hash") != setup_options.options_hash + ): + most_recent_impact = None + impact_action: ImpactAction = self._determine_impact_action( most_recent_impact=most_recent_impact ) @@ -609,7 +653,10 @@ def _get_or_create_economic_impact( }, severity="INFO", ) - return self._handle_completed_impact(most_recent_impact=most_recent_impact) + return self._handle_completed_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) if impact_action == ImpactAction.COMPUTING: logger.log_struct( @@ -638,6 +685,7 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: + self._resolve_runtime_bundle_for_setup_options(setup_options) try: with reform_impacts_service.claim_lock( country_id=setup_options.country_id, @@ -649,7 +697,7 @@ def _get_or_create_economic_impact( options_hash=setup_options.options_hash, api_version=setup_options.api_version, ): - most_recent_impact = self._get_most_recent_impact( + most_recent_impact = self._get_exact_reform_impact( setup_options=setup_options ) impact_action = self._determine_impact_action( @@ -665,7 +713,8 @@ def _get_or_create_economic_impact( severity="INFO", ) return self._handle_completed_impact( - most_recent_impact=most_recent_impact + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) if impact_action == ImpactAction.COMPUTING: @@ -741,10 +790,32 @@ def _get_or_create_economic_impact( raise ValueError(f"Unexpected impact action: {impact_action}") + def _resolve_runtime_bundle_for_setup_options( + self, + setup_options: EconomicImpactSetupOptions, + ) -> None: + if setup_options.runtime_app_name is None: + ( + setup_options.runtime_app_name, + setup_options.model_version, + ) = simulation_api.resolve_app_name( + setup_options.country_id, + setup_options.model_version, + ) + + setup_options.options_hash = self._build_options_hash( + options=setup_options.options, + model_version=setup_options.model_version, + dataset=setup_options.dataset, + data_version=setup_options.data_version, + policyengine_version=setup_options.policyengine_version, + runtime_app_name=setup_options.runtime_app_name, + ) + def _get_existing_economic_impact( self, setup_options: EconomicImpactSetupOptions ) -> Optional[EconomicImpactResult]: - most_recent_impact = self._get_most_recent_impact(setup_options=setup_options) + most_recent_impact = self._get_exact_reform_impact(setup_options=setup_options) if not most_recent_impact: return None @@ -756,7 +827,10 @@ def _get_existing_economic_impact( ) if status == ImpactStatus.OK.value: - return self._handle_completed_impact(most_recent_impact=most_recent_impact) + return self._handle_completed_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) if status == ImpactStatus.COMPUTING.value: if self._is_stale_provisional_impact(most_recent_impact): @@ -878,8 +952,37 @@ def _get_previous_impacts( api_version, ) ) + if previous_impacts: + return previous_impacts + + return reform_impacts_service.get_all_reform_impacts( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) - return previous_impacts + def _get_exact_reform_impact( + self, + setup_options: EconomicImpactSetupOptions, + ) -> dict | None: + previous_impacts = reform_impacts_service.get_all_reform_impacts( + setup_options.country_id, + setup_options.reform_policy_id, + setup_options.baseline_policy_id, + setup_options.region, + setup_options.dataset, + setup_options.time_period, + setup_options.options_hash, + setup_options.api_version, + ) + if not previous_impacts: + return None + return previous_impacts[0] def _get_most_recent_impact( self, @@ -1232,7 +1335,7 @@ def _recover_failed_execution_id_promotion( options_hash=setup_options.options_hash, api_version=setup_options.api_version, ): - most_recent_impact = self._get_most_recent_impact( + most_recent_impact = self._get_exact_reform_impact( setup_options=setup_options ) if most_recent_impact is not None: @@ -1364,7 +1467,7 @@ def _should_refresh_cached_impact( setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> bool: - if most_recent_impact.get("status") == ImpactStatus.COMPUTING.value: + if most_recent_impact.get("status") != ImpactStatus.OK.value: return False cached_result = self._extract_cached_result(most_recent_impact) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index a090516fa..91928495e 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -175,6 +175,7 @@ def get_all_reform_impacts_by_options_hash_prefix( api_version, ): try: + self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id, options_hash FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " @@ -182,7 +183,7 @@ def get_all_reform_impacts_by_options_hash_prefix( "(options_hash = ? OR options_hash LIKE ? ESCAPE '\\') AND api_version = ? AND dataset = ? " "ORDER BY CASE WHEN options_hash = ? THEN 0 ELSE 1 END, start_time DESC" ) - return local_database.query( + return database.query( query, ( country_id, diff --git a/tests/fixtures/libs/simulation_api_modal.py b/tests/fixtures/libs/simulation_api_modal.py index 6d514a7e5..a9b4ce45e 100644 --- a/tests/fixtures/libs/simulation_api_modal.py +++ b/tests/fixtures/libs/simulation_api_modal.py @@ -19,6 +19,7 @@ # Mock data constants MOCK_MODAL_JOB_ID = "fc-abc123xyz" MOCK_RUN_ID = "run-abc123xyz" +MOCK_BATCH_JOB_ID = "fc-batch123xyz" MOCK_MODAL_BASE_URL = "https://test-modal-api.modal.run" MOCK_SIMULATION_PAYLOAD = { @@ -87,6 +88,54 @@ MOCK_HEALTH_RESPONSE = {"status": "healthy"} +MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS = { + "batch_job_id": MOCK_BATCH_JOB_ID, + "status": MODAL_EXECUTION_STATUS_SUBMITTED, + "poll_url": f"/budget-window-jobs/{MOCK_BATCH_JOB_ID}", + "country": "us", + "version": "1.459.0", +} + +MOCK_BATCH_POLL_RESPONSE_RUNNING = { + "status": MODAL_EXECUTION_STATUS_RUNNING, + "progress": 33, + "completed_years": ["2026"], + "running_years": ["2027"], + "queued_years": ["2028"], + "failed_years": [], + "result": None, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_COMPLETE = { + "status": MODAL_EXECUTION_STATUS_COMPLETE, + "progress": 100, + "completed_years": ["2026", "2027", "2028"], + "running_years": [], + "queued_years": [], + "failed_years": [], + "result": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + }, + "error": None, +} + +MOCK_BATCH_POLL_RESPONSE_FAILED = { + "status": MODAL_EXECUTION_STATUS_FAILED, + "progress": 33, + "completed_years": ["2026"], + "running_years": [], + "queued_years": ["2028"], + "failed_years": ["2027"], + "result": None, + "error": "Budget window failed", +} + def create_mock_httpx_response( status_code: int = 200, diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index d5c8da047..f29eb88b9 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -141,6 +141,7 @@ def mock_simulation_api(): """Mock SimulationAPIModal with all required methods.""" mock_api = MagicMock() mock_execution = create_mock_modal_execution() + mock_batch_execution = create_mock_budget_window_batch_execution() mock_api._setup_sim_options.return_value = MOCK_SIM_CONFIG mock_api.run.return_value = mock_execution @@ -152,6 +153,8 @@ def mock_simulation_api(): mock_api.get_execution_by_id.return_value = mock_execution mock_api.get_execution_status.return_value = MODAL_EXECUTION_STATUS_RUNNING mock_api.get_execution_result.return_value = MOCK_REFORM_IMPACT_DATA + mock_api.run_budget_window_batch.return_value = mock_batch_execution + mock_api.get_budget_window_batch_by_id.return_value = mock_batch_execution with patch( "policyengine_api.services.economy_service.simulation_api", mock_api @@ -191,6 +194,8 @@ def create_mock_reform_impact( execution_id=MOCK_MODAL_JOB_ID, options_hash=MOCK_OPTIONS_HASH, start_time=None, + time_period=MOCK_TIME_PERIOD, + message=None, ): """Helper function to create mock reform impact records.""" default_reform_impact_json = json.dumps( @@ -212,12 +217,13 @@ def create_mock_reform_impact( "baseline_policy_id": MOCK_BASELINE_POLICY_ID, "region": MOCK_REGION, "dataset": MOCK_RESOLVED_DATASET, - "time_period": MOCK_TIME_PERIOD, + "time_period": time_period, "options_hash": options_hash, "status": status, "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or default_reform_impact_json, "execution_id": execution_id, + "message": message, "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 12, 5, 0) if status == "ok" else None @@ -263,6 +269,32 @@ def create_mock_modal_execution( return mock_execution +def create_mock_budget_window_batch_execution( + batch_job_id=MOCK_MODAL_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + progress=None, + completed_years=None, + running_years=None, + queued_years=None, + failed_years=None, + result=None, + error=None, +): + """Helper function to create mock batch execution objects.""" + mock_execution = MagicMock() + mock_execution.batch_job_id = batch_job_id + mock_execution.name = batch_job_id + mock_execution.status = status + mock_execution.progress = progress + mock_execution.completed_years = completed_years or [] + mock_execution.running_years = running_years or [] + mock_execution.queued_years = queued_years or [] + mock_execution.failed_years = failed_years or [] + mock_execution.result = result + mock_execution.error = error + return mock_execution + + @pytest.fixture def mock_simulation_api_modal(): """Mock SimulationAPIModal with all required methods.""" diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index fca938948..f185f8c28 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -71,6 +71,19 @@ def test_budget_window_route_rejects_oversized_window(rest_client): assert "window_size must be between 1 and" in data["message"] +def test_budget_window_route_rejects_end_year_after_2099(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2090&window_size=20" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "budget-window end_year must be 2099 or earlier" == data["message"] + + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py index a9013a056..e2936de11 100644 --- a/tests/unit/endpoints/test_simulation.py +++ b/tests/unit/endpoints/test_simulation.py @@ -14,6 +14,7 @@ def test_get_simulations_reads_from_remote_database(): result = get_simulations() mock_database.query.assert_called_once_with( - "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT 100", + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (100,), ) assert result == {"result": [{"id": 1}]} diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 26b321135..c80ab373e 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -21,6 +21,7 @@ os.environ.setdefault("FLASK_DEBUG", "1") from policyengine_api.libs.simulation_api_modal import ( # noqa: E402 + ModalBudgetWindowBatchExecution, ModalSimulationExecution, SimulationAPIModal, ) @@ -32,6 +33,7 @@ ) from tests.fixtures.libs.simulation_api_modal import ( # noqa: E402 MOCK_MODAL_JOB_ID, + MOCK_BATCH_JOB_ID, MOCK_MODAL_BASE_URL, MOCK_SIMULATION_PAYLOAD, MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY, @@ -44,6 +46,10 @@ MOCK_POLL_RESPONSE_COMPLETE, MOCK_POLL_RESPONSE_FAILED, MOCK_HEALTH_RESPONSE, + MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + MOCK_BATCH_POLL_RESPONSE_RUNNING, + MOCK_BATCH_POLL_RESPONSE_COMPLETE, + MOCK_BATCH_POLL_RESPONSE_FAILED, create_mock_httpx_response, ) @@ -117,6 +123,18 @@ def test__given_failed_execution__then_error_attribute_populated(self): assert execution.result is None +class TestModalBudgetWindowBatchExecution: + """Tests for the ModalBudgetWindowBatchExecution dataclass.""" + + def test__given_batch_job_id__then_name_returns_batch_job_id(self): + execution = ModalBudgetWindowBatchExecution( + batch_job_id=MOCK_BATCH_JOB_ID, + status=MODAL_EXECUTION_STATUS_SUBMITTED, + ) + + assert execution.name == MOCK_BATCH_JOB_ID + + class TestSimulationAPIModal: """Tests for the SimulationAPIModal class.""" @@ -322,6 +340,40 @@ def test__given_country_and_version__then_returns_registered_app( assert app_name == MOCK_RESOLVED_APP_NAME assert resolved_version == "1.459.0" + class TestRunBudgetWindowBatch: + def test__given_valid_payload__then_returns_batch_execution( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + ) + api = SimulationAPIModal() + + execution = api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_SUBMITTED + call_args = mock_httpx_client.post.call_args + assert "/simulate/economy/budget-window" in call_args[0][0] + + def test__given_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_response = create_mock_httpx_response( + status_code=400, + json_data={"error": "Invalid request"}, + ) + mock_httpx_client.post.return_value = mock_response + api = SimulationAPIModal() + + with pytest.raises(httpx.HTTPStatusError): + api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + class TestGetExecutionById: def test__given_running_job__then_returns_running_status( self, @@ -416,6 +468,59 @@ def test__given_unexpected_http_error__then_raises_exception( with pytest.raises(httpx.HTTPStatusError): api.get_execution_by_id(MOCK_MODAL_JOB_ID) + class TestGetBudgetWindowBatchById: + def test__given_running_batch__then_returns_running_status( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_POLL_RESPONSE_RUNNING, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.batch_job_id == MOCK_BATCH_JOB_ID + assert execution.status == MODAL_EXECUTION_STATUS_RUNNING + assert execution.completed_years == ["2026"] + assert execution.running_years == ["2027"] + assert execution.queued_years == ["2028"] + + def test__given_complete_batch__then_returns_result( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data=MOCK_BATCH_POLL_RESPONSE_COMPLETE, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_COMPLETE + assert execution.result == MOCK_BATCH_POLL_RESPONSE_COMPLETE["result"] + + def test__given_failed_batch__then_returns_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=500, + json_data=MOCK_BATCH_POLL_RESPONSE_FAILED, + ) + api = SimulationAPIModal() + + execution = api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + assert execution.status == MODAL_EXECUTION_STATUS_FAILED + assert execution.failed_years == ["2027"] + assert execution.error == "Budget window failed" + class TestGetExecutionId: def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index b8353013a..fd9c64416 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,10 +1,65 @@ import datetime import json +import sys import pytest from unittest.mock import patch, MagicMock from typing import Literal +from types import ModuleType + +try: + from policyengine.simulation import SimulationOptions # noqa: F401 +except ModuleNotFoundError: + policyengine_module = sys.modules.setdefault( + "policyengine", ModuleType("policyengine") + ) + simulation_module = ModuleType("policyengine.simulation") + utils_module = ModuleType("policyengine.utils") + data_module = ModuleType("policyengine.utils.data") + datasets_module = ModuleType("policyengine.utils.data.datasets") + + class _StubSimulationOptions: + def __init__(self, payload): + self._payload = payload + + @classmethod + def model_validate(cls, payload): + return cls(payload) + + def model_dump(self): + return dict(self._payload) + + simulation_module.SimulationOptions = _StubSimulationOptions + policyengine_module.simulation = simulation_module + + def _stub_get_default_dataset(country, region): + if country == "us": + if region == "us": + return "gs://policyengine-us-data/enhanced_cps_2024.h5" + if region == "state/ca": + return "gs://policyengine-us-data/states/CA.h5" + if region == "state/ut": + return "gs://policyengine-us-data/states/UT.h5" + if region == "place/NJ-57000": + return "gs://policyengine-us-data/states/NJ.h5" + if region == "congressional_district/CA-37": + return "gs://policyengine-us-data/districts/CA-37.h5" + if country == "uk" and region == "uk": + return "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" + raise ValueError( + f"Error getting default dataset for country={country}, region={region}: unsupported in test stub" + ) + + datasets_module.get_default_dataset = _stub_get_default_dataset + data_module.datasets = datasets_module + utils_module.data = data_module + policyengine_module.utils = utils_module + sys.modules["policyengine.simulation"] = simulation_module + sys.modules["policyengine.utils"] = utils_module + sys.modules["policyengine.utils.data"] = data_module + sys.modules["policyengine.utils.data.datasets"] = datasets_module from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_END_YEAR, BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, @@ -35,6 +90,7 @@ MOCK_REFORM_IMPACT_DATA, MOCK_RESOLVED_DATASET, MOCK_RESOLVED_APP_NAME, + create_mock_budget_window_batch_execution, create_mock_reform_impact, ) @@ -175,9 +231,7 @@ def test__given_error_impact__returns_error_result( assert result.status == ImpactStatus.ERROR assert result.data is None assert result.message == "Failed to start simulation API job" - ( - mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once() - ) + mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once() mock_simulation_api.run.assert_not_called() def test__given_computing_impact_with_succeeded_execution__returns_completed_result( @@ -302,7 +356,7 @@ def test__given_no_previous_impact__creates_new_simulation( policy_id=MOCK_POLICY_ID, baseline_policy_id=MOCK_BASELINE_POLICY_ID, region=MOCK_REGION, - dataset=MOCK_DATASET, + dataset=MOCK_RESOLVED_DATASET, time_period=MOCK_TIME_PERIOD, options_hash=MOCK_OPTIONS_HASH, current_execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", @@ -408,7 +462,7 @@ def test__given_simulation_api_submission_failure__marks_provisional_claim_error policy_id=MOCK_POLICY_ID, baseline_policy_id=MOCK_BASELINE_POLICY_ID, region=MOCK_REGION, - dataset=MOCK_DATASET, + dataset=MOCK_RESOLVED_DATASET, time_period=MOCK_TIME_PERIOD, options_hash=MOCK_OPTIONS_HASH, message="Failed to start simulation API job: gateway unavailable", @@ -448,7 +502,7 @@ def test__given_simulation_setup_failure__marks_provisional_claim_error( policy_id=MOCK_POLICY_ID, baseline_policy_id=MOCK_BASELINE_POLICY_ID, region=MOCK_REGION, - dataset=MOCK_DATASET, + dataset=MOCK_RESOLVED_DATASET, time_period=MOCK_TIME_PERIOD, options_hash=MOCK_OPTIONS_HASH, message="Failed to start simulation API job: Invalid US state: 'zz'", @@ -519,7 +573,7 @@ def test__given_stale_provisional_claim__expires_and_recreates_simulation( policy_id=MOCK_POLICY_ID, baseline_policy_id=MOCK_BASELINE_POLICY_ID, region=MOCK_REGION, - dataset=MOCK_DATASET, + dataset=MOCK_RESOLVED_DATASET, time_period=MOCK_TIME_PERIOD, options_hash=MOCK_OPTIONS_HASH, message=STALE_PROVISIONAL_IMPACT_MESSAGE, @@ -909,7 +963,15 @@ def test__given_uk_request__preserves_model_version_in_bundle( class TestGetBudgetWindowEconomicImpact: @pytest.fixture - def economy_service(self): + def economy_service( + self, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_logger, + mock_datetime, + mock_numpy_random, + ): return EconomyService() @pytest.fixture @@ -927,237 +989,208 @@ def base_params(self): "target": "general", } - def test__given_all_years_completed__returns_aggregated_budget_window_result( - self, economy_service, base_params + def test__given_no_tracking_row__submits_parent_batch_and_returns_queued_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, ): - def make_setup(*, time_period, **_kwargs): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=time_period, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, - ) + batch_execution = create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="submitted", + ) + mock_simulation_api.run_budget_window_batch.return_value = batch_execution - yearly_results = { - "2026": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=100, - state_tax_revenue_impact=20, - benefit_spending_impact=-10, - budgetary_impact=90, - ) - ), - "2027": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=120, - state_tax_revenue_impact=30, - benefit_spending_impact=-20, - budgetary_impact=100, - ) - ), - "2028": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=140, - state_tax_revenue_impact=40, - benefit_spending_impact=-30, - budgetary_impact=110, - ) - ), - } + result = economy_service.get_budget_window_economic_impact(**base_params) - with ( - patch.object( - economy_service, - "_build_economic_impact_setup_options", - side_effect=make_setup, - ), - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=lambda setup_options: yearly_results[ - setup_options.time_period - ], - ) as mock_get_existing, - patch.object( - economy_service, "_get_or_create_economic_impact" - ) as mock_get_economic_impact, - ): - result = economy_service.get_budget_window_economic_impact( - **base_params - ) + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 0 + assert result.completed_years == [] + assert result.computing_years == [] + assert result.queued_years == ["2026", "2027", "2028"] + assert "Queued 2026" in result.message + mock_simulation_api.run_budget_window_batch.assert_called_once() + submitted_payload = ( + mock_simulation_api.run_budget_window_batch.call_args.args[0] + ) + assert submitted_payload["start_year"] == "2026" + assert submitted_payload["window_size"] == 3 + assert submitted_payload["max_parallel"] == 20 + assert submitted_payload["target"] == "general" + assert "time_period" not in submitted_payload + mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + == "fc-budget-123" + ) - assert result.status == ImpactStatus.OK - assert result.progress == 100 - assert result.data["annualImpacts"] == [ - { - "year": "2026", + def test__given_completed_tracking_row__returns_completed_batch_result( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + "stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + } + ], + "totals": { + "year": "Total", "taxRevenueImpact": 100, "federalTaxRevenueImpact": 80, "stateTaxRevenueImpact": 20, "benefitSpendingImpact": -10, "budgetaryImpact": 90, }, - { - "year": "2027", - "taxRevenueImpact": 120, - "federalTaxRevenueImpact": 90, - "stateTaxRevenueImpact": 30, - "benefitSpendingImpact": -20, - "budgetaryImpact": 100, - }, - { - "year": "2028", - "taxRevenueImpact": 140, - "federalTaxRevenueImpact": 100, - "stateTaxRevenueImpact": 40, - "benefitSpendingImpact": -30, - "budgetaryImpact": 110, - }, - ] - assert result.data["totals"] == { - "year": "Total", - "taxRevenueImpact": 360, - "federalTaxRevenueImpact": 270, - "stateTaxRevenueImpact": 90, - "benefitSpendingImpact": -60, - "budgetaryImpact": 300, } - assert mock_get_existing.call_count == 3 - mock_get_economic_impact.assert_not_called() - - def test__given_missing_years__starts_only_up_to_remaining_active_slots( - self, economy_service, base_params - ): - def make_setup(*, time_period, **_kwargs): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=time_period, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="ok", + execution_id="fc-budget-123", + reform_impact_json=json.dumps(completed_result), + time_period="budget_window:2026:3", ) + ] - base_params["window_size"] = 5 + result = economy_service.get_budget_window_economic_impact(**base_params) - existing_results = { - "2026": EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=100, - state_tax_revenue_impact=20, - benefit_spending_impact=-10, - budgetary_impact=90, - ) - ), - "2027": EconomicImpactResult.computing(), - "2028": None, - "2029": None, - "2030": None, - } + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data == completed_result + mock_simulation_api.get_budget_window_batch_by_id.assert_not_called() - with ( - patch.object( - economy_service, - "_build_economic_impact_setup_options", - side_effect=make_setup, - ), - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=lambda setup_options: existing_results[ - setup_options.time_period - ], - ), - patch.object( - economy_service, - "_get_or_create_economic_impact", - return_value=EconomicImpactResult.computing(), - ) as mock_get_economic_impact, - ): - result = economy_service.get_budget_window_economic_impact( - **base_params + def test__given_running_tracking_row__returns_running_batch_progress( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="running", + progress=33, + completed_years=["2026"], + running_years=["2027"], + queued_years=["2028"], ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.COMPUTING - assert result.progress == 20 + assert result.progress == 33 assert result.completed_years == ["2026"] - assert result.computing_years == ["2027", "2028", "2029"] - assert result.queued_years == ["2030"] - assert "1 of 5 complete" in result.message - assert mock_get_economic_impact.call_count == 2 - started_years = sorted( - call.args[0].time_period - for call in mock_get_economic_impact.call_args_list - ) - assert started_years == ["2028", "2029"] - - def test__given_year_error__returns_budget_window_error( - self, economy_service, base_params, mock_logger + assert result.computing_years == ["2027"] + assert result.queued_years == ["2028"] + assert "1 of 3 complete" in result.message + + def test__given_completed_batch_poll__persists_result_and_returns_completed( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, ): - def make_setup(*, time_period, **_kwargs): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=time_period, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + } + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027", "2028"], + result=completed_result, + ) + ) - with ( - patch.object( - economy_service, - "_build_economic_impact_setup_options", - side_effect=make_setup, - ), - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=[ - EconomicImpactResult.completed( - make_mock_budget_impact_data( - tax_revenue_impact=100, - state_tax_revenue_impact=20, - benefit_spending_impact=-10, - budgetary_impact=90, - ) - ), - EconomicImpactResult( - status=ImpactStatus.ERROR, - data={"message": "Calculation failed for 2027"}, - ), - None, - ], - ), - patch.object( - economy_service, "_get_or_create_economic_impact" - ) as mock_get_economic_impact, - ): - result = economy_service.get_budget_window_economic_impact( - **base_params + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + assert result.data == completed_result + mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() + call_kwargs = ( + mock_reform_impacts_service.set_complete_reform_impact.call_args.kwargs + ) + assert call_kwargs["execution_id"] == "fc-budget-123" + assert json.loads(call_kwargs["reform_impact_json"]) == completed_result + + def test__given_failed_batch_poll__persists_error_and_returns_failed( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="computing", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="failed", + progress=33, + completed_years=["2026"], + queued_years=["2028"], + failed_years=["2027"], + error="Budget window failed for 2027", + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.ERROR - assert result.error == "Calculation failed for 2027" + assert result.error == "Budget window failed for 2027" assert result.completed_years == ["2026"] - mock_get_economic_impact.assert_not_called() + assert result.computing_years == [] + assert result.queued_years == ["2028"] + mock_reform_impacts_service.set_error_reform_impact.assert_called_once() + assert ( + mock_reform_impacts_service.set_error_reform_impact.call_args.kwargs[ + "execution_id" + ] + == "fc-budget-123" + ) def test__given_cliff_target__raises_value_error( self, economy_service, base_params @@ -1181,32 +1214,44 @@ def test__given_oversized_window__raises_value_error( ): economy_service.get_budget_window_economic_impact(**base_params) - def test__given_started_year_error__returns_specific_budget_window_error( - self, economy_service, base_params, mock_logger + def test__given_end_year_after_2099__raises_value_error( + self, economy_service, base_params ): - with ( - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=[None, None, None], - ), - patch.object( - economy_service, - "_get_or_create_economic_impact", - side_effect=[ - EconomicImpactResult.error("Calculation failed for 2026"), - EconomicImpactResult.computing(), - EconomicImpactResult.computing(), - ], + base_params["start_year"] = "2090" + base_params["window_size"] = 20 + + with pytest.raises( + ValueError, + match=( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" ), ): - result = economy_service.get_budget_window_economic_impact( - **base_params + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_failed_tracking_row_and_unavailable_batch__returns_stored_error( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + create_mock_reform_impact( + status="error", + execution_id="fc-budget-123", + time_period="budget_window:2026:3", + message="Stored batch failure", ) + ] + mock_simulation_api.get_budget_window_batch_by_id.side_effect = Exception( + "batch lookup failed" + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.ERROR - assert result.error == "Calculation failed for 2026" - assert result.completed_years == [] + assert result.error == "Stored batch failure" + assert result.queued_years == ["2026", "2027", "2028"] def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( self, @@ -1214,59 +1259,37 @@ def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_windo base_params, mock_country_package_versions, mock_get_dataset_version, + mock_reform_impacts_service, + mock_simulation_api, mock_logger, mock_datetime, mock_numpy_random, monkeypatch, ): cache_version = "e1cache01" - seen_existing_calls = [] - seen_create_calls = [] monkeypatch.setattr( "policyengine_api.services.economy_service.get_economy_impact_cache_version", lambda country_id, api_version=None: cache_version, ) - - def fake_get_existing(setup_options): - seen_existing_calls.append( - (setup_options.time_period, setup_options.api_version) - ) - return None - - def fake_get_or_create(setup_options): - seen_create_calls.append( - (setup_options.time_period, setup_options.api_version) - ) - return EconomicImpactResult.computing() - - with ( - patch.object( - economy_service, - "_get_existing_economic_impact", - side_effect=fake_get_existing, - ), - patch.object( - economy_service, - "_get_or_create_economic_impact", - side_effect=fake_get_or_create, - ), - ): - result = economy_service.get_budget_window_economic_impact( - **base_params - ) + result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.COMPUTING - assert seen_existing_calls == [ - ("2026", cache_version), - ("2027", cache_version), - ("2028", cache_version), - ] - assert seen_create_calls == [ - ("2026", cache_version), - ("2027", cache_version), - ("2028", cache_version), - ] + mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() + assert ( + mock_reform_impacts_service.get_all_reform_impacts.call_args.args[5] + == "budget_window:2026:3" + ) + assert ( + mock_reform_impacts_service.get_all_reform_impacts.call_args.args[7] + == cache_version + ) + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "api_version" + ] + == cache_version + ) class TestGetPreviousImpacts: @pytest.fixture diff --git a/uv.lock b/uv.lock index 8bb11c5e4..778b61515 100644 --- a/uv.lock +++ b/uv.lock @@ -2622,7 +2622,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd9 [[package]] name = "policyengine-api" -version = "3.40.7" +version = "3.40.8" source = { editable = "." } dependencies = [ { name = "anthropic" }, From 324b29f1ccfeb4b67f8af9989a118e33937c9228 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 18:48:03 +0200 Subject: [PATCH 14/27] Fix lint after budget-window rebase --- policyengine_api/country.py | 2 +- tests/unit/data/test_sqlalchemy_v2.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 430df888c..81331fb8f 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -429,7 +429,7 @@ def calculate( household[entity_plural][entity_id][variable_name][period] = ( entity_result ) - except Exception as e: + except Exception: logging.exception(f"Error computing {variable_name} for {entity_id}") if "axes" not in household: household[entity_plural][entity_id][variable_name][period] = None diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 2ea63f0f0..e873521a2 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -10,7 +10,6 @@ (dict(row) and row["key"]). """ -import pytest import sqlalchemy from unittest.mock import MagicMock From 9c3d10f4ea79d6ae86f2212d081943bf32f85e58 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 19:56:48 +0200 Subject: [PATCH 15/27] Address budget-window review feedback --- policyengine_api/libs/simulation_api_modal.py | 2 + .../services/budget_window_cache.py | 167 +++++++++++++++ policyengine_api/services/economy_service.py | 202 ++++++------------ .../services/reform_impacts_service.py | 48 ----- tests/fixtures/services/economy_service.py | 20 ++ tests/unit/libs/test_simulation_api_modal.py | 14 ++ .../unit/services/test_budget_window_cache.py | 72 +++++++ tests/unit/services/test_economy_service.py | 165 ++++++-------- .../services/test_reform_impacts_service.py | 77 ++----- 9 files changed, 426 insertions(+), 341 deletions(-) create mode 100644 policyengine_api/services/budget_window_cache.py create mode 100644 tests/unit/services/test_budget_window_cache.py diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index c567de06f..833c2684f 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -299,6 +299,8 @@ def get_budget_window_batch_by_id( response = self.client.get( f"{self.base_url}/budget-window-jobs/{batch_job_id}" ) + if response.status_code not in (200, 202, 500): + response.raise_for_status() data = response.json() return ModalBudgetWindowBatchExecution( diff --git a/policyengine_api/services/budget_window_cache.py b/policyengine_api/services/budget_window_cache.py new file mode 100644 index 000000000..dbfc74fb8 --- /dev/null +++ b/policyengine_api/services/budget_window_cache.py @@ -0,0 +1,167 @@ +import hashlib +import json +import os +from typing import Any + +import redis + +from policyengine_api.gcp_logging import logger + +BUDGET_WINDOW_CACHE_PREFIX = "budget_window:v1" +BUDGET_WINDOW_STARTING_PREFIX = "starting:" +BUDGET_WINDOW_STARTING_TTL_SECONDS = int( + os.environ.get("BUDGET_WINDOW_STARTING_TTL_SECONDS", "300") +) +BUDGET_WINDOW_BATCH_TTL_SECONDS = int( + os.environ.get("BUDGET_WINDOW_BATCH_TTL_SECONDS", "86400") +) +BUDGET_WINDOW_RESULT_TTL_SECONDS = int( + os.environ.get("BUDGET_WINDOW_RESULT_TTL_SECONDS", "2592000") +) + + +class BudgetWindowCache: + """Redis-backed cache and in-flight mapping for budget-window requests.""" + + def __init__(self, client: redis.Redis | None = None): + self._client = client + + @property + def client(self) -> redis.Redis: + if self._client is None: + self._client = redis.Redis( + host=os.environ.get("CACHE_REDIS_HOST", "127.0.0.1"), + port=int(os.environ.get("CACHE_REDIS_PORT", "6379")), + db=int(os.environ.get("CACHE_REDIS_DB", "0")), + decode_responses=True, + socket_connect_timeout=1, + socket_timeout=1, + ) + return self._client + + def build_key( + self, + *, + country_id: str, + reform_policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options_hash: str | None, + api_version: str, + ) -> str: + key_payload = { + "country_id": country_id, + "reform_policy_id": reform_policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options_hash": options_hash, + "api_version": api_version, + } + encoded = json.dumps( + key_payload, + sort_keys=True, + separators=(",", ":"), + default=str, + ) + digest = hashlib.sha256(encoded.encode("utf-8")).hexdigest() + return f"{BUDGET_WINDOW_CACHE_PREFIX}:{country_id}:{digest}" + + def _result_key(self, cache_key: str) -> str: + return f"{cache_key}:result" + + def _batch_key(self, cache_key: str) -> str: + return f"{cache_key}:batch_job_id" + + def _handle_cache_error(self, operation: str, error: Exception) -> None: + logger.log_struct( + { + "message": f"Budget-window Redis cache {operation} failed", + "error": str(error), + }, + severity="WARNING", + ) + + def get_completed_result(self, cache_key: str) -> dict[str, Any] | None: + try: + payload = self.client.get(self._result_key(cache_key)) + except Exception as error: + self._handle_cache_error("read", error) + raise + + if not payload: + return None + + try: + result = json.loads(payload) + except (TypeError, ValueError) as error: + self._handle_cache_error("decode", error) + return None + + return result if isinstance(result, dict) else None + + def set_completed_result(self, cache_key: str, result: dict[str, Any]) -> None: + try: + self.client.set( + self._result_key(cache_key), + json.dumps(result), + ex=BUDGET_WINDOW_RESULT_TTL_SECONDS, + ) + except Exception as error: + self._handle_cache_error("write result", error) + + def get_batch_job_id(self, cache_key: str) -> str | None: + try: + value = self.client.get(self._batch_key(cache_key)) + except Exception as error: + self._handle_cache_error("read batch id", error) + raise + + if not isinstance(value, str) or not value: + return None + if value.startswith(BUDGET_WINDOW_STARTING_PREFIX): + return None + return value + + def claim_batch_start(self, cache_key: str, claim_token: str) -> bool: + try: + claimed = self.client.set( + self._batch_key(cache_key), + f"{BUDGET_WINDOW_STARTING_PREFIX}{claim_token}", + nx=True, + ex=BUDGET_WINDOW_STARTING_TTL_SECONDS, + ) + except Exception as error: + self._handle_cache_error("claim", error) + raise + + return bool(claimed) + + def store_batch_job_id(self, cache_key: str, batch_job_id: str) -> None: + try: + self.client.set( + self._batch_key(cache_key), + batch_job_id, + ex=BUDGET_WINDOW_BATCH_TTL_SECONDS, + ) + except Exception as error: + self._handle_cache_error("write batch id", error) + raise + + def clear_starting_claim(self, cache_key: str, claim_token: str) -> None: + try: + batch_key = self._batch_key(cache_key) + value = self.client.get(batch_key) + if value == f"{BUDGET_WINDOW_STARTING_PREFIX}{claim_token}": + self.client.delete(batch_key) + except Exception as error: + self._handle_cache_error("clear claim", error) + + def clear_batch_job_id(self, cache_key: str) -> None: + try: + self.client.delete(self._batch_key(cache_key)) + except Exception as error: + self._handle_cache_error("clear batch id", error) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 1ec5f969b..0b16ffe99 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -2,6 +2,7 @@ from policyengine_api.services.reform_impacts_service import ( ReformImpactsService, ) +from policyengine_api.services.budget_window_cache import BudgetWindowCache from policyengine_api.constants import ( COUNTRY_PACKAGE_VERSIONS, EXECUTION_STATUSES_SUCCESS, @@ -26,7 +27,7 @@ import datetime import hashlib import uuid -from typing import Literal, Any, Optional, Annotated, Union +from typing import Literal, Any, Optional, Annotated from dotenv import load_dotenv from pydantic import BaseModel, Field import numpy as np @@ -37,6 +38,7 @@ policy_service = PolicyService() reform_impacts_service = ReformImpactsService() simulation_api = simulation_api_modal +budget_window_cache = BudgetWindowCache() def get_policyengine_version() -> str | None: @@ -310,7 +312,7 @@ def get_budget_window_economic_impact( start_year=start_year, window_size=window_size, ) - tracking_setup_options = self._build_budget_window_tracking_setup_options( + setup_options = self._build_budget_window_setup_options( country_id=country_id, policy_id=policy_id, baseline_policy_id=baseline_policy_id, @@ -322,30 +324,43 @@ def get_budget_window_economic_impact( api_version=api_version, target=target, ) + cache_key = self._build_budget_window_cache_key(setup_options) - most_recent_impact = self._get_budget_window_tracking_impact( - tracking_setup_options - ) - if most_recent_impact is None: - self._start_budget_window_batch( - setup_options=tracking_setup_options, - start_year=start_year, - window_size=window_size, - max_parallel=max_active_years, - ) - return self._build_budget_window_computing_result( + cached_result = budget_window_cache.get_completed_result(cache_key) + if cached_result is not None: + return BudgetWindowEconomicImpactResult.completed(cached_result) + + batch_job_id = budget_window_cache.get_batch_job_id(cache_key) + if batch_job_id: + return self._get_budget_window_result_from_batch_job_id( + batch_job_id=batch_job_id, + cache_key=cache_key, total_years=len(years), - completed_years=[], - computing_years=[], - queued_years=years, - progress=0, + queued_years_on_submit=years, ) - return self._get_budget_window_result_from_tracking_impact( - setup_options=tracking_setup_options, - most_recent_impact=most_recent_impact, + claim_token = setup_options.process_id + if budget_window_cache.claim_batch_start(cache_key, claim_token): + try: + batch_execution = self._start_budget_window_batch( + setup_options=setup_options, + start_year=start_year, + window_size=window_size, + max_parallel=max_active_years, + ) + budget_window_cache.store_batch_job_id( + cache_key, batch_execution.batch_job_id + ) + except Exception: + budget_window_cache.clear_starting_claim(cache_key, claim_token) + raise + + return self._build_budget_window_computing_result( total_years=len(years), - queued_years_on_submit=years, + completed_years=[], + computing_years=[], + queued_years=years, + progress=0, ) except Exception as e: print(f"Error getting budget-window economic impact: {str(e)}") @@ -360,7 +375,7 @@ def _build_budget_window_years( start_year_int = int(start_year) return [str(start_year_int + index) for index in range(window_size)] - def _build_budget_window_tracking_time_period( + def _build_budget_window_time_period( self, *, start_year: str, @@ -368,7 +383,7 @@ def _build_budget_window_tracking_time_period( ) -> str: return f"budget_window:{start_year}:{window_size}" - def _build_budget_window_tracking_setup_options( + def _build_budget_window_setup_options( self, *, country_id: str, @@ -388,7 +403,7 @@ def _build_budget_window_tracking_setup_options( baseline_policy_id=baseline_policy_id, region=region, dataset=dataset, - time_period=self._build_budget_window_tracking_time_period( + time_period=self._build_budget_window_time_period( start_year=start_year, window_size=window_size, ), @@ -397,6 +412,21 @@ def _build_budget_window_tracking_setup_options( target=target, ) + def _build_budget_window_cache_key( + self, + setup_options: EconomicImpactSetupOptions, + ) -> str: + return budget_window_cache.build_key( + country_id=setup_options.country_id, + reform_policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ) + def _build_budget_window_batch_payload( self, *, @@ -433,12 +463,6 @@ def _build_budget_window_batch_payload( sim_params["target"] = setup_options.target return sim_params - def _get_budget_window_tracking_impact( - self, - setup_options: EconomicImpactSetupOptions, - ) -> dict | None: - return self._get_exact_reform_impact(setup_options) - def _start_budget_window_batch( self, *, @@ -446,7 +470,7 @@ def _start_budget_window_batch( start_year: str, window_size: int, max_parallel: int, - ) -> None: + ): sim_params = self._build_budget_window_batch_payload( setup_options=setup_options, start_year=start_year, @@ -465,67 +489,31 @@ def _start_budget_window_batch( severity="INFO", ) - batch_execution = simulation_api.run_budget_window_batch(sim_params) - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=batch_execution.batch_job_id, - ) + return simulation_api.run_budget_window_batch(sim_params) - def _get_budget_window_result_from_tracking_impact( + def _get_budget_window_result_from_batch_job_id( self, *, - setup_options: EconomicImpactSetupOptions, - most_recent_impact: dict, + batch_job_id: str, + cache_key: str, total_years: int, queued_years_on_submit: list[str], ) -> BudgetWindowEconomicImpactResult: - impact_status = most_recent_impact.get("status") - if impact_status == ImpactStatus.OK.value: - return BudgetWindowEconomicImpactResult.completed( - json.loads(most_recent_impact["reform_impact_json"]) - ) - - execution_id = most_recent_impact.get("execution_id") - if not execution_id: - return BudgetWindowEconomicImpactResult.failed( - most_recent_impact.get("message") - or "Budget-window batch tracking row is missing execution_id", - queued_years=queued_years_on_submit, - ) - - try: - batch_execution = simulation_api.get_budget_window_batch_by_id(execution_id) - except Exception: - if impact_status == ImpactStatus.ERROR.value: - return BudgetWindowEconomicImpactResult.failed( - most_recent_impact.get("message") or "Budget-window batch failed", - queued_years=queued_years_on_submit, - ) - raise + batch_execution = simulation_api.get_budget_window_batch_by_id(batch_job_id) if batch_execution.status in EXECUTION_STATUSES_SUCCESS: result = batch_execution.result or {} - self._set_reform_impact_complete( - setup_options=setup_options, - reform_impact_json=json.dumps(result), - execution_id=execution_id, - ) + budget_window_cache.set_completed_result(cache_key, result) + budget_window_cache.clear_batch_job_id(cache_key) return BudgetWindowEconomicImpactResult.completed(result) if batch_execution.status in EXECUTION_STATUSES_FAILURE: - error_message = batch_execution.error or ( - most_recent_impact.get("message") or "Budget-window batch failed" - ) - self._set_reform_impact_error( - setup_options=setup_options, - message=error_message, - execution_id=execution_id, - ) + error_message = batch_execution.error or "Budget-window batch failed" return BudgetWindowEconomicImpactResult.failed( error_message, completed_years=batch_execution.completed_years, computing_years=batch_execution.running_years, - queued_years=batch_execution.queued_years, + queued_years=batch_execution.queued_years or queued_years_on_submit, ) if batch_execution.status in EXECUTION_STATUSES_PENDING: @@ -842,68 +830,6 @@ def _get_existing_economic_impact( raise ValueError(f"Unknown impact status: {status}") - def _get_economic_impact_error_message( - self, result: EconomicImpactResult, year: str - ) -> str: - if result.message: - return result.message - - if isinstance(result.data, dict): - data_message = result.data.get("message") - if isinstance(data_message, str) and data_message: - return data_message - - return f"Budget-window calculation failed for {year}" - - def _extract_budget_window_annual_impact( - self, year: str, impact_data: dict - ) -> dict[str, Union[str, int, float]]: - budget = impact_data.get("budget", {}) - state_tax_revenue_impact = budget.get("state_tax_revenue_impact", 0) - tax_revenue_impact = budget.get("tax_revenue_impact", 0) - - return { - "year": year, - "taxRevenueImpact": tax_revenue_impact, - "federalTaxRevenueImpact": tax_revenue_impact - state_tax_revenue_impact, - "stateTaxRevenueImpact": state_tax_revenue_impact, - "benefitSpendingImpact": budget.get("benefit_spending_impact", 0), - "budgetaryImpact": budget.get("budgetary_impact", 0), - } - - def _sum_budget_window_annual_impacts(self, annual_impacts: list[dict]) -> dict: - totals = { - "year": "Total", - "taxRevenueImpact": 0, - "federalTaxRevenueImpact": 0, - "stateTaxRevenueImpact": 0, - "benefitSpendingImpact": 0, - "budgetaryImpact": 0, - } - - for annual_impact in annual_impacts: - totals["taxRevenueImpact"] += annual_impact["taxRevenueImpact"] - totals["federalTaxRevenueImpact"] += annual_impact[ - "federalTaxRevenueImpact" - ] - totals["stateTaxRevenueImpact"] += annual_impact["stateTaxRevenueImpact"] - totals["benefitSpendingImpact"] += annual_impact["benefitSpendingImpact"] - totals["budgetaryImpact"] += annual_impact["budgetaryImpact"] - - return totals - - def _build_budget_window_output( - self, *, start_year: str, window_size: int, annual_impacts: list[dict] - ) -> dict: - return { - "kind": "budgetWindow", - "startYear": start_year, - "endYear": str(int(start_year) + window_size - 1), - "windowSize": window_size, - "annualImpacts": annual_impacts, - "totals": self._sum_budget_window_annual_impacts(annual_impacts), - } - def _build_budget_window_progress_message( self, *, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 91928495e..d2aad8b3a 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -6,7 +6,6 @@ LOCAL_REFORM_IMPACT_LOCK = Lock() -REFORM_IMPACT_SCHEMA_LOCK = Lock() REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 @@ -16,46 +15,6 @@ class ReformImpactsService: this is connected to the shared reform_impact table. """ - def __init__(self): - self._schema_checked = False - - def _ensure_remote_schema(self) -> None: - if database.local or self._schema_checked: - return - - with REFORM_IMPACT_SCHEMA_LOCK: - if self._schema_checked: - return - - existing_columns = { - row["Field"] - for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() - } - required_columns = { - "dataset": ( - "ALTER TABLE reform_impact " - "ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'" - ), - "execution_id": ( - "ALTER TABLE reform_impact " - "ADD COLUMN execution_id VARCHAR(255) NULL" - ), - "end_time": ( - "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL" - ), - } - - for column_name, alter_query in required_columns.items(): - if column_name in existing_columns: - continue - try: - database.query(alter_query) - except Exception as error: - if "Duplicate column name" not in str(error): - raise - - self._schema_checked = True - def _build_lock_name( self, country_id, @@ -137,7 +96,6 @@ def get_all_reform_impacts( api_version, ): try: - self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " @@ -175,7 +133,6 @@ def get_all_reform_impacts_by_options_hash_prefix( api_version, ): try: - self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id, options_hash FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " @@ -219,7 +176,6 @@ def set_reform_impact( execution_id: str, ): try: - self._ensure_remote_schema() query = ( "INSERT INTO reform_impact (country_id, reform_policy_id, baseline_policy_id, " "region, dataset, time_period, options_json, options_hash, status, api_version, " @@ -260,7 +216,6 @@ def update_reform_impact_execution_id( new_execution_id, ): try: - self._ensure_remote_schema() query = ( "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " @@ -297,7 +252,6 @@ def delete_reform_impact( options_hash, ): try: - self._ensure_remote_schema() query = ( "DELETE FROM reform_impact WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -334,7 +288,6 @@ def set_error_reform_impact( execution_id: str, ): try: - self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ? WHERE " "country_id = ? AND reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -379,7 +332,6 @@ def set_complete_reform_impact( execution_id, ): try: - self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ?, " "reform_impact_json = ? WHERE country_id = ? AND reform_policy_id = ? AND " diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index f29eb88b9..3ad61da1f 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -162,6 +162,26 @@ def mock_simulation_api(): yield mock +@pytest.fixture +def mock_budget_window_cache(): + """Mock Redis-backed budget-window cache.""" + mock_cache = MagicMock() + mock_cache.build_key.return_value = "budget-window-cache-key" + mock_cache.get_completed_result.return_value = None + mock_cache.get_batch_job_id.return_value = None + mock_cache.claim_batch_start.return_value = True + mock_cache.store_batch_job_id.return_value = None + mock_cache.clear_starting_claim.return_value = None + mock_cache.set_completed_result.return_value = None + mock_cache.clear_batch_job_id.return_value = None + + with patch( + "policyengine_api.services.economy_service.budget_window_cache", + mock_cache, + ) as mock: + yield mock + + @pytest.fixture def mock_logger(): """Mock logger.""" diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index c80ab373e..a2b4cdcdb 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -521,6 +521,20 @@ def test__given_failed_batch__then_returns_error( assert execution.failed_years == ["2027"] assert execution.error == "Budget window failed" + def test__given_unexpected_http_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=404, + json_data={"detail": "Budget-window job not found"}, + ) + api = SimulationAPIModal() + + with pytest.raises(httpx.HTTPStatusError): + api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + class TestGetExecutionId: def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given diff --git a/tests/unit/services/test_budget_window_cache.py b/tests/unit/services/test_budget_window_cache.py new file mode 100644 index 000000000..ad23d26aa --- /dev/null +++ b/tests/unit/services/test_budget_window_cache.py @@ -0,0 +1,72 @@ +from policyengine_api.services.budget_window_cache import BudgetWindowCache + + +class FakeRedis: + def __init__(self): + self.values = {} + + def get(self, key): + return self.values.get(key) + + def set(self, key, value, nx=False, ex=None): + if nx and key in self.values: + return False + self.values[key] = value + return True + + def delete(self, key): + self.values.pop(key, None) + + +def test_build_key_is_stable_for_request_identity(): + cache = BudgetWindowCache(client=FakeRedis()) + + first = cache.build_key( + country_id="us", + reform_policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="budget_window:2026:10", + options_hash="[option=value]", + api_version="e1cache01", + ) + second = cache.build_key( + country_id="us", + reform_policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="budget_window:2026:10", + options_hash="[option=value]", + api_version="e1cache01", + ) + + assert first == second + assert first.startswith("budget_window:v1:us:") + + +def test_claim_batch_start_allows_one_starter(): + cache = BudgetWindowCache(client=FakeRedis()) + + assert cache.claim_batch_start("budget_window:v1:us:key", "process-1") is True + assert cache.claim_batch_start("budget_window:v1:us:key", "process-2") is False + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + +def test_store_batch_job_id_replaces_starting_claim(): + cache = BudgetWindowCache(client=FakeRedis()) + cache.claim_batch_start("budget_window:v1:us:key", "process-1") + + cache.store_batch_job_id("budget_window:v1:us:key", "fc-parent") + + assert cache.get_batch_job_id("budget_window:v1:us:key") == "fc-parent" + + +def test_completed_result_round_trips(): + cache = BudgetWindowCache(client=FakeRedis()) + result = {"kind": "budgetWindow", "totals": {"budgetaryImpact": 10}} + + cache.set_completed_result("budget_window:v1:us:key", result) + + assert cache.get_completed_result("budget_window:v1:us:key") == result diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index fd9c64416..f1f72bf87 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -989,12 +989,13 @@ def base_params(self): "target": "general", } - def test__given_no_tracking_row__submits_parent_batch_and_returns_queued_result( + def test__given_no_cached_batch__submits_parent_batch_and_returns_queued_result( self, economy_service, base_params, mock_reform_impacts_service, mock_simulation_api, + mock_budget_window_cache, ): batch_execution = create_mock_budget_window_batch_execution( batch_job_id="fc-budget-123", @@ -1019,20 +1020,20 @@ def test__given_no_tracking_row__submits_parent_batch_and_returns_queued_result( assert submitted_payload["max_parallel"] == 20 assert submitted_payload["target"] == "general" assert "time_period" not in submitted_payload - mock_reform_impacts_service.set_reform_impact.assert_called_once() - assert ( - mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ - "execution_id" - ] - == "fc-budget-123" + mock_budget_window_cache.claim_batch_start.assert_called_once_with( + "budget-window-cache-key", MOCK_PROCESS_ID ) + mock_budget_window_cache.store_batch_job_id.assert_called_once_with( + "budget-window-cache-key", "fc-budget-123" + ) + mock_reform_impacts_service.set_reform_impact.assert_not_called() - def test__given_completed_tracking_row__returns_completed_batch_result( + def test__given_completed_cached_result__returns_completed_batch_result( self, economy_service, base_params, - mock_reform_impacts_service, mock_simulation_api, + mock_budget_window_cache, ): completed_result = { "kind": "budgetWindow", @@ -1058,14 +1059,9 @@ def test__given_completed_tracking_row__returns_completed_batch_result( "budgetaryImpact": 90, }, } - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - create_mock_reform_impact( - status="ok", - execution_id="fc-budget-123", - reform_impact_json=json.dumps(completed_result), - time_period="budget_window:2026:3", - ) - ] + mock_budget_window_cache.get_completed_result.return_value = ( + completed_result + ) result = economy_service.get_budget_window_economic_impact(**base_params) @@ -1073,21 +1069,16 @@ def test__given_completed_tracking_row__returns_completed_batch_result( assert result.progress == 100 assert result.data == completed_result mock_simulation_api.get_budget_window_batch_by_id.assert_not_called() + mock_simulation_api.run_budget_window_batch.assert_not_called() - def test__given_running_tracking_row__returns_running_batch_progress( + def test__given_cached_batch_id__returns_running_batch_progress( self, economy_service, base_params, - mock_reform_impacts_service, mock_simulation_api, + mock_budget_window_cache, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - create_mock_reform_impact( - status="computing", - execution_id="fc-budget-123", - time_period="budget_window:2026:3", - ) - ] + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" mock_simulation_api.get_budget_window_batch_by_id.return_value = ( create_mock_budget_window_batch_execution( batch_job_id="fc-budget-123", @@ -1107,13 +1098,16 @@ def test__given_running_tracking_row__returns_running_batch_progress( assert result.computing_years == ["2027"] assert result.queued_years == ["2028"] assert "1 of 3 complete" in result.message + mock_simulation_api.get_budget_window_batch_by_id.assert_called_once_with( + "fc-budget-123" + ) - def test__given_completed_batch_poll__persists_result_and_returns_completed( + def test__given_completed_batch_poll__caches_result_and_returns_completed( self, economy_service, base_params, - mock_reform_impacts_service, mock_simulation_api, + mock_budget_window_cache, ): completed_result = { "kind": "budgetWindow", @@ -1123,13 +1117,7 @@ def test__given_completed_batch_poll__persists_result_and_returns_completed( "annualImpacts": [], "totals": {}, } - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - create_mock_reform_impact( - status="computing", - execution_id="fc-budget-123", - time_period="budget_window:2026:3", - ) - ] + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" mock_simulation_api.get_budget_window_batch_by_id.return_value = ( create_mock_budget_window_batch_execution( batch_job_id="fc-budget-123", @@ -1144,27 +1132,21 @@ def test__given_completed_batch_poll__persists_result_and_returns_completed( assert result.status == ImpactStatus.OK assert result.data == completed_result - mock_reform_impacts_service.set_complete_reform_impact.assert_called_once() - call_kwargs = ( - mock_reform_impacts_service.set_complete_reform_impact.call_args.kwargs + mock_budget_window_cache.set_completed_result.assert_called_once_with( + "budget-window-cache-key", completed_result + ) + mock_budget_window_cache.clear_batch_job_id.assert_called_once_with( + "budget-window-cache-key" ) - assert call_kwargs["execution_id"] == "fc-budget-123" - assert json.loads(call_kwargs["reform_impact_json"]) == completed_result - def test__given_failed_batch_poll__persists_error_and_returns_failed( + def test__given_failed_batch_poll__returns_failed( self, economy_service, base_params, - mock_reform_impacts_service, mock_simulation_api, + mock_budget_window_cache, ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - create_mock_reform_impact( - status="computing", - execution_id="fc-budget-123", - time_period="budget_window:2026:3", - ) - ] + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" mock_simulation_api.get_budget_window_batch_by_id.return_value = ( create_mock_budget_window_batch_execution( batch_job_id="fc-budget-123", @@ -1184,12 +1166,40 @@ def test__given_failed_batch_poll__persists_error_and_returns_failed( assert result.completed_years == ["2026"] assert result.computing_years == [] assert result.queued_years == ["2028"] - mock_reform_impacts_service.set_error_reform_impact.assert_called_once() - assert ( - mock_reform_impacts_service.set_error_reform_impact.call_args.kwargs[ - "execution_id" - ] - == "fc-budget-123" + mock_budget_window_cache.set_completed_result.assert_not_called() + + def test__given_existing_start_claim__does_not_submit_duplicate_batch( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.claim_batch_start.return_value = False + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 0 + assert result.queued_years == ["2026", "2027", "2028"] + mock_simulation_api.run_budget_window_batch.assert_not_called() + + def test__given_batch_submission_fails__clears_start_claim( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_simulation_api.run_budget_window_batch.side_effect = RuntimeError( + "submit failed" + ) + + with pytest.raises(RuntimeError, match="submit failed"): + economy_service.get_budget_window_economic_impact(**base_params) + + mock_budget_window_cache.clear_starting_claim.assert_called_once_with( + "budget-window-cache-key", MOCK_PROCESS_ID ) def test__given_cliff_target__raises_value_error( @@ -1228,39 +1238,14 @@ def test__given_end_year_after_2099__raises_value_error( ): economy_service.get_budget_window_economic_impact(**base_params) - def test__given_failed_tracking_row_and_unavailable_batch__returns_stored_error( - self, - economy_service, - base_params, - mock_reform_impacts_service, - mock_simulation_api, - ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - create_mock_reform_impact( - status="error", - execution_id="fc-budget-123", - time_period="budget_window:2026:3", - message="Stored batch failure", - ) - ] - mock_simulation_api.get_budget_window_batch_by_id.side_effect = Exception( - "batch lookup failed" - ) - - result = economy_service.get_budget_window_economic_impact(**base_params) - - assert result.status == ImpactStatus.ERROR - assert result.error == "Stored batch failure" - assert result.queued_years == ["2026", "2027", "2028"] - def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( self, economy_service, base_params, mock_country_package_versions, mock_get_dataset_version, - mock_reform_impacts_service, mock_simulation_api, + mock_budget_window_cache, mock_logger, mock_datetime, mock_numpy_random, @@ -1275,21 +1260,9 @@ def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_windo result = economy_service.get_budget_window_economic_impact(**base_params) assert result.status == ImpactStatus.COMPUTING - mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() - assert ( - mock_reform_impacts_service.get_all_reform_impacts.call_args.args[5] - == "budget_window:2026:3" - ) - assert ( - mock_reform_impacts_service.get_all_reform_impacts.call_args.args[7] - == cache_version - ) - assert ( - mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ - "api_version" - ] - == cache_version - ) + cache_key_kwargs = mock_budget_window_cache.build_key.call_args.kwargs + assert cache_key_kwargs["time_period"] == "budget_window:2026:3" + assert cache_key_kwargs["api_version"] == cache_version class TestGetPreviousImpacts: @pytest.fixture diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py index 106cf8757..ebe671879 100644 --- a/tests/unit/services/test_reform_impacts_service.py +++ b/tests/unit/services/test_reform_impacts_service.py @@ -6,77 +6,36 @@ class TestReformImpactsService: - def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( - self, monkeypatch - ): + def test__given_reform_impact_lookup__does_not_manage_schema(self, monkeypatch): service = ReformImpactsService() - show_columns_result = MagicMock() - show_columns_result.fetchall.return_value = [ - {"Field": "reform_impact_id"}, - {"Field": "status"}, - {"Field": "start_time"}, - ] - alter_dataset_result = MagicMock() - alter_execution_result = MagicMock() - alter_end_time_result = MagicMock() - + select_result = MagicMock() + select_result.fetchall.return_value = [] mock_database = MagicMock() mock_database.local = False - mock_database.query.side_effect = [ - show_columns_result, - alter_dataset_result, - alter_execution_result, - alter_end_time_result, - ] + mock_database.query.return_value = select_result monkeypatch.setattr( "policyengine_api.services.reform_impacts_service.database", mock_database, ) - service._ensure_remote_schema() - - assert mock_database.query.call_args_list[0].args == ( - "SHOW COLUMNS FROM reform_impact", - ) - assert mock_database.query.call_args_list[1].args == ( - "ALTER TABLE reform_impact ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'", - ) - assert mock_database.query.call_args_list[2].args == ( - "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", - ) - assert mock_database.query.call_args_list[3].args == ( - "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", - ) - - def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alter( - self, monkeypatch - ): - service = ReformImpactsService() - - show_columns_result = MagicMock() - show_columns_result.fetchall.return_value = [ - {"Field": "reform_impact_id"}, - {"Field": "status"}, - {"Field": "start_time"}, - {"Field": "dataset"}, - {"Field": "execution_id"}, - {"Field": "end_time"}, - ] - - mock_database = MagicMock() - mock_database.local = False - mock_database.query.return_value = show_columns_result - - monkeypatch.setattr( - "policyengine_api.services.reform_impacts_service.database", - mock_database, + service.get_all_reform_impacts( + "us", + 123, + 456, + "us", + "enhanced_cps", + "2026", + "[option=value]", + "e1cache01", ) - service._ensure_remote_schema() - - mock_database.query.assert_called_once_with("SHOW COLUMNS FROM reform_impact") + mock_database.query.assert_called_once() + query = mock_database.query.call_args.args[0] + assert query.startswith("SELECT reform_impact_json") + assert not query.startswith("ALTER") + assert not query.startswith("SHOW") def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): service = ReformImpactsService() From c968a2e03e255e0812f7f09ce0a76079d161cd88 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 20:09:20 +0200 Subject: [PATCH 16/27] Ignore budget-window version override --- policyengine_api/routes/economy_routes.py | 3 ++- .../to_refactor/python/test_economy_budget_window_routes.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index d772697c2..c6b1a6b2d 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -132,7 +132,8 @@ def get_budget_window_economic_impact( "Budget-window calculations only support target=general" ) - api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + options.pop("version", None) + api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) try: economic_impact_result: BudgetWindowEconomicImpactResult = ( diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index f185f8c28..06c55ad34 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -1,6 +1,8 @@ import json from unittest.mock import Mock, patch +from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" @@ -87,7 +89,7 @@ def test_budget_window_route_rejects_end_year_after_2099(rest_client): @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) -def test_budget_window_route_passes_version_to_service( +def test_budget_window_route_ignores_version_override( mock_get_budget_window_economic_impact, rest_client ): mock_result = Mock() @@ -128,6 +130,6 @@ def test_budget_window_route_passes_version_to_service( start_year="2026", window_size=2, options={}, - api_version="1.2.3", + api_version=COUNTRY_PACKAGE_VERSIONS.get("us"), target="general", ) From ad461c39518fafe080b52d2e73349847bd6ccc2e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 21:21:51 +0200 Subject: [PATCH 17/27] Stop passing budget-window version --- policyengine_api/routes/economy_routes.py | 2 -- policyengine_api/services/economy_service.py | 5 +---- .../to_refactor/python/test_economy_budget_window_routes.py | 3 --- tests/unit/services/test_economy_service.py | 1 - 4 files changed, 1 insertion(+), 10 deletions(-) diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index c6b1a6b2d..b078f0de7 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -133,7 +133,6 @@ def get_budget_window_economic_impact( ) options.pop("version", None) - api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) try: economic_impact_result: BudgetWindowEconomicImpactResult = ( @@ -146,7 +145,6 @@ def get_budget_window_economic_impact( start_year=start_year, window_size=window_size, options=options, - api_version=api_version, target=target, ) ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 0b16ffe99..b455e8ea7 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -283,7 +283,6 @@ def get_budget_window_economic_impact( start_year: str, window_size: int, options: dict, - api_version: str, target: Literal["general", "cliff"] = "general", max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, ) -> BudgetWindowEconomicImpactResult: @@ -321,7 +320,6 @@ def get_budget_window_economic_impact( start_year=start_year, window_size=window_size, options=options, - api_version=api_version, target=target, ) cache_key = self._build_budget_window_cache_key(setup_options) @@ -394,7 +392,6 @@ def _build_budget_window_setup_options( start_year: str, window_size: int, options: dict, - api_version: str, target: Literal["general", "cliff"], ) -> EconomicImpactSetupOptions: return self._build_economic_impact_setup_options( @@ -408,7 +405,7 @@ def _build_budget_window_setup_options( window_size=window_size, ), options=dict(options), - api_version=api_version, + api_version=COUNTRY_PACKAGE_VERSIONS.get(country_id), target=target, ) diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index 06c55ad34..e32e586af 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -1,8 +1,6 @@ import json from unittest.mock import Mock, patch -from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS - @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" @@ -130,6 +128,5 @@ def test_budget_window_route_ignores_version_override( start_year="2026", window_size=2, options={}, - api_version=COUNTRY_PACKAGE_VERSIONS.get("us"), target="general", ) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index f1f72bf87..d3101ac9e 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -985,7 +985,6 @@ def base_params(self): "start_year": "2026", "window_size": 3, "options": MOCK_OPTIONS, - "api_version": MOCK_API_VERSION, "target": "general", } From a64db3f6e18dabe6c91ab29317aa6793eb96275e Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 21:44:50 +0200 Subject: [PATCH 18/27] Align budget-window version handling --- policyengine_api/routes/economy_routes.py | 3 ++- policyengine_api/services/economy_service.py | 5 ++++- .../to_refactor/python/test_economy_budget_window_routes.py | 3 ++- tests/unit/services/test_economy_service.py | 1 + 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index b078f0de7..d772697c2 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -132,7 +132,7 @@ def get_budget_window_economic_impact( "Budget-window calculations only support target=general" ) - options.pop("version", None) + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) try: economic_impact_result: BudgetWindowEconomicImpactResult = ( @@ -145,6 +145,7 @@ def get_budget_window_economic_impact( start_year=start_year, window_size=window_size, options=options, + api_version=api_version, target=target, ) ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index b455e8ea7..0b16ffe99 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -283,6 +283,7 @@ def get_budget_window_economic_impact( start_year: str, window_size: int, options: dict, + api_version: str, target: Literal["general", "cliff"] = "general", max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, ) -> BudgetWindowEconomicImpactResult: @@ -320,6 +321,7 @@ def get_budget_window_economic_impact( start_year=start_year, window_size=window_size, options=options, + api_version=api_version, target=target, ) cache_key = self._build_budget_window_cache_key(setup_options) @@ -392,6 +394,7 @@ def _build_budget_window_setup_options( start_year: str, window_size: int, options: dict, + api_version: str, target: Literal["general", "cliff"], ) -> EconomicImpactSetupOptions: return self._build_economic_impact_setup_options( @@ -405,7 +408,7 @@ def _build_budget_window_setup_options( window_size=window_size, ), options=dict(options), - api_version=COUNTRY_PACKAGE_VERSIONS.get(country_id), + api_version=api_version, target=target, ) diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index e32e586af..f185f8c28 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -87,7 +87,7 @@ def test_budget_window_route_rejects_end_year_after_2099(rest_client): @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) -def test_budget_window_route_ignores_version_override( +def test_budget_window_route_passes_version_to_service( mock_get_budget_window_economic_impact, rest_client ): mock_result = Mock() @@ -128,5 +128,6 @@ def test_budget_window_route_ignores_version_override( start_year="2026", window_size=2, options={}, + api_version="1.2.3", target="general", ) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d3101ac9e..f1f72bf87 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -985,6 +985,7 @@ def base_params(self): "start_year": "2026", "window_size": 3, "options": MOCK_OPTIONS, + "api_version": MOCK_API_VERSION, "target": "general", } From 09ba29a8bd0d5ffc618ff3a34a81010236809dec Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 4 May 2026 19:10:47 +0200 Subject: [PATCH 19/27] Patch remote database accessor in unit tests --- tests/unit/conftest.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5e2a983ad..485778a68 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -111,8 +111,17 @@ def override_database(test_db, monkeypatch): # Patch at the root module level where database is defined import policyengine_api.data + import policyengine_api.data.data as data_module + + def get_test_remote_database(): + return test_db monkeypatch.setattr(policyengine_api.data, "database", test_db) + monkeypatch.setattr( + policyengine_api.data, "get_remote_database", get_test_remote_database + ) + monkeypatch.setattr(data_module, "remote_database", test_db) + monkeypatch.setattr(data_module, "get_remote_database", get_test_remote_database) # Also patch the module-level variable for any existing imports import sys @@ -123,5 +132,11 @@ def override_database(test_db, monkeypatch): monkeypatch.setattr(module, "database", test_db) if hasattr(module, "local_database"): monkeypatch.setattr(module, "local_database", test_db) + if hasattr(module, "remote_database"): + monkeypatch.setattr(module, "remote_database", test_db) + if hasattr(module, "get_remote_database"): + monkeypatch.setattr( + module, "get_remote_database", get_test_remote_database + ) yield test_db From c1c5ea38b1add4ac6d976e325c3b8de0b78dab8f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 4 May 2026 20:47:52 +0200 Subject: [PATCH 20/27] Add budget-window cache test coverage --- .github/workflows/push.yml | 2 +- .../test_live_budget_window_cache.py | 74 +++++ .../test_economy_budget_window_routes.py | 99 ++++-- tests/unit/data/test_sqlalchemy_v2.py | 43 ++- tests/unit/libs/test_simulation_api_modal.py | 93 ++++++ .../unit/services/test_budget_window_cache.py | 171 ++++++++++ tests/unit/services/test_economy_service.py | 243 ++++++++++++++ .../services/test_reform_impacts_service.py | 298 ++++++++++++++++++ 8 files changed, 1002 insertions(+), 21 deletions(-) create mode 100644 tests/integration/test_live_budget_window_cache.py diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 1ff895b02..cb5289d74 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -226,7 +226,7 @@ jobs: - name: Install staging test dependencies run: pip install pytest httpx - name: Run staging smoke test - run: python -m pytest tests/integration/test_live_calculate.py tests/integration/test_live_economy.py -v + run: python -m pytest tests/integration/test_live_calculate.py tests/integration/test_live_economy.py tests/integration/test_live_budget_window_cache.py -v env: API_BASE_URL: ${{ needs.deploy-staging.outputs.url }} STAGING_API_TEST_PROBE_ID: ${{ needs.deploy-staging.outputs.version }} diff --git a/tests/integration/test_live_budget_window_cache.py b/tests/integration/test_live_budget_window_cache.py new file mode 100644 index 000000000..5ea5f4f38 --- /dev/null +++ b/tests/integration/test_live_budget_window_cache.py @@ -0,0 +1,74 @@ +import json +import os +import time +from pathlib import Path + + +INTEGRATION_TIMEOUT_SECONDS = float( + os.environ.get("STAGING_API_TEST_TIMEOUT_SECONDS", "900") +) +INTEGRATION_POLL_INTERVAL_SECONDS = float( + os.environ.get("STAGING_API_TEST_POLL_INTERVAL_SECONDS", "5") +) + + +def _load_reform_payload(filename: str) -> dict: + return json.loads( + (Path(__file__).resolve().parents[1] / "data" / filename).read_text( + encoding="utf-8" + ) + ) + + +def _poll_budget_window(api_client, path: str, params: dict) -> dict: + deadline = time.monotonic() + INTEGRATION_TIMEOUT_SECONDS + + while True: + response = api_client.get(path, params=params) + response.raise_for_status() + payload = response.json() + + if payload["status"] != "computing": + return payload + + assert time.monotonic() < deadline, ( + f"Timed out polling the budget-window route; last response was {payload}" + ) + time.sleep(INTEGRATION_POLL_INTERVAL_SECONDS) + + +def test_live_budget_window_completed_result_cache(api_client, integration_probe_id): + metadata_response = api_client.get("/us/metadata") + metadata_response.raise_for_status() + current_law_id = metadata_response.json()["result"]["current_law_id"] + + policy_response = api_client.post( + "/us/policy", + json=_load_reform_payload("utah_reform.json"), + ) + assert policy_response.status_code in (200, 201) + policy_id = policy_response.json()["result"]["policy_id"] + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "start_year": "2025", + "window_size": 1, + "staging_probe": f"{integration_probe_id}-budget-window-cache", + } + + first_payload = _poll_budget_window(api_client, path, params) + + assert first_payload["status"] == "ok", first_payload + assert first_payload["progress"] == 100, first_payload + assert first_payload["result"] is not None, first_payload + assert first_payload["result"]["kind"] == "budgetWindow", first_payload + assert first_payload["result"]["windowSize"] == 1, first_payload + + second_response = api_client.get(path, params=params) + second_response.raise_for_status() + second_payload = second_response.json() + + assert second_payload["status"] == "ok", second_payload + assert second_payload["progress"] == 100, second_payload + assert second_payload["result"] == first_payload["result"] diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index f185f8c28..e641a2ca8 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -2,6 +2,28 @@ from unittest.mock import Mock, patch +def _mock_budget_window_result(): + mock_result = Mock() + mock_result.to_dict.return_value = { + "status": "ok", + "message": None, + "data": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2027", + "windowSize": 2, + "annualImpacts": [], + "totals": {}, + }, + "progress": 100, + "completed_years": ["2026", "2027"], + "computing_years": [], + "queued_years": [], + "error": None, + } + return mock_result + + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) @@ -21,6 +43,42 @@ def test_budget_window_route_rejects_cliff_target( mock_get_budget_window_economic_impact.assert_not_called() +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_region( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?start_year=2026&window_size=2" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert data["message"] == "Missing required query parameter: region" + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_start_year( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&window_size=2" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert data["message"] == "Missing required query parameter: start_year" + mock_get_budget_window_economic_impact.assert_not_called() + + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) @@ -90,24 +148,7 @@ def test_budget_window_route_rejects_end_year_after_2099(rest_client): def test_budget_window_route_passes_version_to_service( mock_get_budget_window_economic_impact, rest_client ): - mock_result = Mock() - mock_result.to_dict.return_value = { - "status": "ok", - "message": None, - "data": { - "kind": "budgetWindow", - "startYear": "2026", - "endYear": "2027", - "windowSize": 2, - "annualImpacts": [], - "totals": {}, - }, - "progress": 100, - "completed_years": ["2026", "2027"], - "computing_years": [], - "queued_years": [], - "error": None, - } + mock_result = _mock_budget_window_result() mock_get_budget_window_economic_impact.return_value = mock_result response = rest_client.get( @@ -131,3 +172,25 @@ def test_budget_window_route_passes_version_to_service( api_version="1.2.3", target="general", ) + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_uses_breakdown_dataset_for_us_national_request( + mock_get_budget_window_economic_impact, rest_client +): + mock_get_budget_window_economic_impact.return_value = _mock_budget_window_result() + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2" + "&include_district_breakdowns=true" + ) + + assert response.status_code == 200 + mock_get_budget_window_economic_impact.assert_called_once() + assert ( + mock_get_budget_window_economic_impact.call_args.kwargs["dataset"] + == "national-with-breakdowns" + ) diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index e873521a2..0a5887ea8 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -10,10 +10,15 @@ (dict(row) and row["key"]). """ -import sqlalchemy from unittest.mock import MagicMock -from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase +import sqlalchemy + +from policyengine_api.data.data import ( + _ResultProxy, + PolicyEngineDatabase, + get_remote_database as _real_get_remote_database, +) class TestSQLAlchemyVersion: @@ -98,6 +103,15 @@ def test_result_proxy_for_insert_statement(self): assert proxy.fetchone() is None assert proxy.fetchall() == [] + def test_result_proxy_preserves_rowcount_for_write_statement(self): + engine = sqlalchemy.create_engine("sqlite://") + with engine.connect() as conn: + conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)") + result = conn.exec_driver_sql("INSERT INTO test VALUES (1)") + proxy = _ResultProxy(result) + + assert proxy.rowcount == 1 + class TestRemoteQueryPath: """Test the non-local query path that uses SQLAlchemy engine @@ -211,3 +225,28 @@ def fake_create_engine(url, **kwargs): assert creator() is first_connection assert creator() is second_connection assert captured_kwargs["pool_pre_ping"] is True + + def test_get_remote_database_lazily_constructs_and_reuses_remote_database( + self, monkeypatch + ): + created_databases = [] + + class FakeDatabase: + def __init__(self, *, local, initialize): + self.local = local + self.initialize = initialize + created_databases.append(self) + + monkeypatch.setattr("policyengine_api.data.data.remote_database", None) + monkeypatch.setattr( + "policyengine_api.data.data.PolicyEngineDatabase", + FakeDatabase, + ) + + first = _real_get_remote_database() + second = _real_get_remote_database() + + assert first is second + assert len(created_databases) == 1 + assert created_databases[0].local is False + assert created_databases[0].initialize is False diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index a2b4cdcdb..373678764 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -290,6 +290,29 @@ def test__given_telemetry_payload__then_preserves_it_in_post_body( call_args = mock_httpx_client.post.call_args assert call_args[1]["json"]["_telemetry"]["run_id"] == MOCK_RUN_ID + def test__given_model_and_data_versions__then_translates_payload_for_modal( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_SUBMIT_RESPONSE_SUCCESS, + ) + payload = { + **MOCK_SIMULATION_PAYLOAD, + "model_version": "1.459.0", + "data_version": "1.77.0", + } + api = SimulationAPIModal() + + api.run(payload) + + posted_payload = mock_httpx_client.post.call_args.kwargs["json"] + assert posted_payload["version"] == "1.459.0" + assert "model_version" not in posted_payload + assert "data_version" not in posted_payload + def test__given_http_error__then_raises_exception( self, mock_httpx_client, @@ -340,6 +363,25 @@ def test__given_country_and_version__then_returns_registered_app( assert app_name == MOCK_RESOLVED_APP_NAME assert resolved_version == "1.459.0" + def test__given_unknown_version__then_raises_value_error( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.return_value = create_mock_httpx_response( + status_code=200, + json_data={ + "latest": "1.459.0", + "1.459.0": MOCK_RESOLVED_APP_NAME, + }, + ) + api = SimulationAPIModal() + + with pytest.raises( + ValueError, match="Unknown version 9.9.9 for country us" + ): + api.resolve_app_name("us", "9.9.9") + class TestRunBudgetWindowBatch: def test__given_valid_payload__then_returns_batch_execution( self, @@ -359,6 +401,29 @@ def test__given_valid_payload__then_returns_batch_execution( call_args = mock_httpx_client.post.call_args assert "/simulate/economy/budget-window" in call_args[0][0] + def test__given_model_and_data_versions__then_translates_payload_for_modal( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.return_value = create_mock_httpx_response( + status_code=202, + json_data=MOCK_BATCH_SUBMIT_RESPONSE_SUCCESS, + ) + payload = { + **MOCK_SIMULATION_PAYLOAD, + "model_version": "1.459.0", + "data_version": "1.77.0", + } + api = SimulationAPIModal() + + api.run_budget_window_batch(payload) + + posted_payload = mock_httpx_client.post.call_args.kwargs["json"] + assert posted_payload["version"] == "1.459.0" + assert "model_version" not in posted_payload + assert "data_version" not in posted_payload + def test__given_http_error__then_raises_exception( self, mock_httpx_client, @@ -374,6 +439,20 @@ def test__given_http_error__then_raises_exception( with pytest.raises(httpx.HTTPStatusError): api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD) + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.post.side_effect = httpx.RequestError("Connection failed") + api = SimulationAPIModal() + + with pytest.raises(httpx.RequestError): + api.run_budget_window_batch(MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert log_payload["run_id"] == MOCK_RUN_ID + class TestGetExecutionById: def test__given_running_job__then_returns_running_status( self, @@ -535,6 +614,20 @@ def test__given_unexpected_http_error__then_raises_exception( with pytest.raises(httpx.HTTPStatusError): api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") + api = SimulationAPIModal() + + with pytest.raises(httpx.RequestError): + api.get_budget_window_batch_by_id(MOCK_BATCH_JOB_ID) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert MOCK_BATCH_JOB_ID in log_payload["message"] + class TestGetExecutionId: def test__given_execution__then_returns_job_id(self, mock_httpx_client): # Given diff --git a/tests/unit/services/test_budget_window_cache.py b/tests/unit/services/test_budget_window_cache.py index ad23d26aa..233ba4fcd 100644 --- a/tests/unit/services/test_budget_window_cache.py +++ b/tests/unit/services/test_budget_window_cache.py @@ -1,3 +1,7 @@ +from unittest.mock import MagicMock + +import pytest + from policyengine_api.services.budget_window_cache import BudgetWindowCache @@ -18,6 +22,25 @@ def delete(self, key): self.values.pop(key, None) +class RaisingRedis: + def __init__(self, *, method): + self.method = method + + def get(self, key): + if self.method == "get": + raise RuntimeError("redis unavailable") + return None + + def set(self, key, value, nx=False, ex=None): + if self.method == "set": + raise RuntimeError("redis unavailable") + return True + + def delete(self, key): + if self.method == "delete": + raise RuntimeError("redis unavailable") + + def test_build_key_is_stable_for_request_identity(): cache = BudgetWindowCache(client=FakeRedis()) @@ -70,3 +93,151 @@ def test_completed_result_round_trips(): cache.set_completed_result("budget_window:v1:us:key", result) assert cache.get_completed_result("budget_window:v1:us:key") == result + + +def test_get_completed_result_returns_none_for_empty_payload(): + redis_client = FakeRedis() + redis_client.values["budget_window:v1:us:key:result"] = "" + cache = BudgetWindowCache(client=redis_client) + + assert cache.get_completed_result("budget_window:v1:us:key") is None + + +def test_get_completed_result_returns_none_for_invalid_json(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + redis_client = FakeRedis() + redis_client.values["budget_window:v1:us:key:result"] = "{not-json" + cache = BudgetWindowCache(client=redis_client) + + assert cache.get_completed_result("budget_window:v1:us:key") is None + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_get_completed_result_reraises_read_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="get")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.get_completed_result("budget_window:v1:us:key") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_set_completed_result_logs_write_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="set")) + + cache.set_completed_result("budget_window:v1:us:key", {"ok": True}) + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_get_batch_job_id_ignores_empty_non_string_and_starting_values(): + redis_client = FakeRedis() + cache = BudgetWindowCache(client=redis_client) + + redis_client.values["budget_window:v1:us:key:batch_job_id"] = "" + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + redis_client.values["budget_window:v1:us:key:batch_job_id"] = 123 + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + redis_client.values["budget_window:v1:us:key:batch_job_id"] = "starting:process-1" + assert cache.get_batch_job_id("budget_window:v1:us:key") is None + + +def test_get_batch_job_id_reraises_read_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="get")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.get_batch_job_id("budget_window:v1:us:key") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_claim_batch_start_reraises_claim_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="set")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.claim_batch_start("budget_window:v1:us:key", "process-1") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_store_batch_job_id_reraises_write_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="set")) + + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.store_batch_job_id("budget_window:v1:us:key", "fc-parent") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_clear_starting_claim_deletes_only_matching_token(): + redis_client = FakeRedis() + cache = BudgetWindowCache(client=redis_client) + cache.claim_batch_start("budget_window:v1:us:key", "process-1") + + cache.clear_starting_claim("budget_window:v1:us:key", "process-2") + + assert ( + redis_client.values["budget_window:v1:us:key:batch_job_id"] + == "starting:process-1" + ) + + cache.clear_starting_claim("budget_window:v1:us:key", "process-1") + + assert "budget_window:v1:us:key:batch_job_id" not in redis_client.values + + +def test_clear_starting_claim_logs_and_swallows_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="get")) + + cache.clear_starting_claim("budget_window:v1:us:key", "process-1") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" + + +def test_clear_batch_job_id_logs_and_swallows_errors(monkeypatch): + mock_logger = MagicMock() + monkeypatch.setattr( + "policyengine_api.services.budget_window_cache.logger", + mock_logger, + ) + cache = BudgetWindowCache(client=RaisingRedis(method="delete")) + + cache.clear_batch_job_id("budget_window:v1:us:key") + + assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index f1f72bf87..47f5efeaf 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -541,6 +541,112 @@ def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_compu assert result.status == ImpactStatus.COMPUTING mock_simulation_api.run.assert_not_called() + def test__given_claim_lock_timeout_and_no_existing_claim__returns_computing( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [], + ] + mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( + "lock busy" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.run.assert_not_called() + + def test__given_completed_impact_appears_after_lock__returns_cached_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_get_policyengine_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + completed_impact = create_mock_reform_impact(status="ok") + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [completed_impact], + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.OK + mock_simulation_api.run.assert_not_called() + + def test__given_computing_impact_appears_after_lock__returns_progress( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + computing_impact = create_mock_reform_impact(status="computing") + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [computing_impact], + ] + mock_simulation_api.get_execution_status.return_value = "running" + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.run.assert_not_called() + mock_simulation_api.get_execution_by_id.assert_called_once_with( + MOCK_EXECUTION_ID + ) + + def test__given_error_impact_appears_after_lock__returns_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + error_impact = create_mock_reform_impact( + status="error", + message="Failed before lock released", + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [error_impact], + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.message == "Failed before lock released" + mock_simulation_api.run.assert_not_called() + def test__given_stale_provisional_claim__expires_and_recreates_simulation( self, economy_service, @@ -653,6 +759,35 @@ def test__given_provisional_promotion_updates_zero_rows_but_newer_claim_exists__ == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" ) + def test__given_provisional_promotion_raises__inserts_replacement_tracking_row( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.update_reform_impact_execution_id.side_effect = RuntimeError( + "update failed" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 2 + assert ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + == MOCK_EXECUTION_ID + ) + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, @@ -1264,6 +1399,73 @@ def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_windo assert cache_key_kwargs["time_period"] == "budget_window:2026:3" assert cache_key_kwargs["api_version"] == cache_version + def test__given_legacy_us_region__normalizes_before_building_cache_key( + self, + economy_service, + base_params, + mock_reform_impacts_service, + mock_simulation_api, + mock_budget_window_cache, + ): + base_params["region"] = "ca" + base_params["dataset"] = "default" + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_budget_window_cache.build_key.call_args.kwargs["region"] == ( + "state/ca" + ) + + def test__given_unexpected_batch_status__raises_value_error( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="paused", + ) + ) + + with pytest.raises( + ValueError, + match="Unexpected budget-window batch execution state: paused", + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_missing_progress__computes_progress_from_completed_years( + self, economy_service + ): + result = economy_service._build_budget_window_computing_result( + total_years=4, + completed_years=["2026"], + computing_years=["2027", "2028", "2029"], + queued_years=[], + progress=None, + ) + + assert result.progress == 25 + assert result.message == "Scoring 2027, 2028 + 1 more (1 of 4 complete)..." + + def test__given_no_active_or_queued_years__uses_generic_progress_message( + self, economy_service + ): + result = economy_service._build_budget_window_computing_result( + total_years=3, + completed_years=["2026"], + computing_years=[], + queued_years=[], + progress=None, + ) + + assert result.progress == 33 + assert result.message == "Scoring budget window (1 of 3 complete)..." + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -1402,6 +1604,20 @@ def test__given_stale_provisional_impact__returns_none( assert result is None + def test__given_unknown_status__raises_value_error( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + ): + unknown_impact = create_mock_reform_impact(status="mystery") + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + unknown_impact + ] + + with pytest.raises(ValueError, match="Unknown impact status: mystery"): + economy_service._get_existing_economic_impact(setup_options) + class TestDetermineImpactAction: @pytest.fixture def economy_service(self): @@ -1454,6 +1670,33 @@ def test__given_unknown_status__raises_error(self, economy_service): economy_service._determine_impact_action(impact) assert "Unknown impact status: unknown" in str(exc_info.value) + def test__given_stale_provisional_iso_start_time__returns_create( + self, economy_service + ): + impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=( + datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) + ).isoformat(), + ) + + assert ( + economy_service._determine_impact_action(impact) == ImpactAction.CREATE + ) + + def test__given_provisional_without_start_time__is_not_stale( + self, economy_service + ): + impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", + ) + impact["start_time"] = None + + assert economy_service._is_stale_provisional_impact(impact) is False + class TestHandleExecutionState: @pytest.fixture def economy_service(self): diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py index ebe671879..8ec6791c5 100644 --- a/tests/unit/services/test_reform_impacts_service.py +++ b/tests/unit/services/test_reform_impacts_service.py @@ -5,6 +5,32 @@ from policyengine_api.services.reform_impacts_service import ReformImpactsService +LOCK_KWARGS = { + "country_id": "us", + "policy_id": 123, + "baseline_policy_id": 456, + "region": "us", + "dataset": "enhanced_cps", + "time_period": "2026", + "options_hash": "[option=value]", + "api_version": "e1cache01", +} + + +def _mock_database(monkeypatch, *, local=False, query_result=None, side_effect=None): + mock_database = MagicMock() + mock_database.local = local + if side_effect is not None: + mock_database.query.side_effect = side_effect + elif query_result is not None: + mock_database.query.return_value = query_result + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + return mock_database + + class TestReformImpactsService: def test__given_reform_impact_lookup__does_not_manage_schema(self, monkeypatch): service = ReformImpactsService() @@ -146,3 +172,275 @@ def test__given_remote_database_lock_timeout__claim_lock_raises(self, monkeypatc api_version="e1cache01", ): pass + + def test__given_local_database__claim_lock_uses_in_process_lock(self, monkeypatch): + service = ReformImpactsService() + mock_database = _mock_database(monkeypatch, local=True) + + with service.claim_lock(**LOCK_KWARGS): + mock_database.query.assert_not_called() + + mock_database.pool.connect.assert_not_called() + + def test__get_all_reform_impacts__queries_dataset_and_stable_order( + self, monkeypatch + ): + service = ReformImpactsService() + query_result = MagicMock() + query_result.fetchall.return_value = [{"status": "ok"}] + mock_database = _mock_database(monkeypatch, query_result=query_result) + + result = service.get_all_reform_impacts(**LOCK_KWARGS) + + assert result == [{"status": "ok"}] + query, params = mock_database.query.call_args.args + assert "AND dataset = ?" in query + assert "ORDER BY start_time DESC, reform_impact_id DESC" in query + assert params == ( + "us", + 123, + 456, + "us", + "2026", + "[option=value]", + "e1cache01", + "enhanced_cps", + ) + + def test__get_all_reform_impacts_by_options_hash_prefix__prefers_exact_hash( + self, monkeypatch + ): + service = ReformImpactsService() + query_result = MagicMock() + query_result.fetchall.return_value = [{"options_hash": "[option=value]"}] + mock_database = _mock_database(monkeypatch, query_result=query_result) + + result = service.get_all_reform_impacts_by_options_hash_prefix( + **LOCK_KWARGS, + options_hash_prefix="[option=%", + ) + + assert result == [{"options_hash": "[option=value]"}] + query, params = mock_database.query.call_args.args + assert "(options_hash = ? OR options_hash LIKE ? ESCAPE '\\')" in query + assert "ORDER BY CASE WHEN options_hash = ? THEN 0 ELSE 1 END" in query + assert params == ( + "us", + 123, + 456, + "us", + "2026", + "[option=value]", + "[option=%", + "e1cache01", + "enhanced_cps", + "[option=value]", + ) + + def test__set_reform_impact__inserts_tracking_row(self, monkeypatch): + service = ReformImpactsService() + mock_database = _mock_database(monkeypatch) + + service.set_reform_impact( + **LOCK_KWARGS, + options='{"option": "value"}', + status="computing", + reform_impact_json="{}", + start_time="2026-01-01 00:00:00", + execution_id="pending:job-1", + ) + + query, params = mock_database.query.call_args.args + assert query.startswith("INSERT INTO reform_impact") + assert params == ( + "us", + 123, + 456, + "us", + "enhanced_cps", + "2026", + '{"option": "value"}', + "[option=value]", + "computing", + "e1cache01", + "{}", + "2026-01-01 00:00:00", + "pending:job-1", + ) + + def test__update_reform_impact_execution_id__returns_rowcount(self, monkeypatch): + service = ReformImpactsService() + query_result = MagicMock() + query_result.rowcount = 1 + mock_database = _mock_database(monkeypatch, query_result=query_result) + + rowcount = service.update_reform_impact_execution_id( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + current_execution_id="pending:job-1", + new_execution_id="fc-job-1", + ) + + assert rowcount == 1 + query, params = mock_database.query.call_args.args + assert "status = 'computing'" in query + assert params == ( + "fc-job-1", + "us", + 123, + 456, + "us", + "2026", + "[option=value]", + "enhanced_cps", + "pending:job-1", + ) + + def test__delete_reform_impact__only_deletes_computing_rows(self, monkeypatch): + service = ReformImpactsService() + mock_database = _mock_database(monkeypatch) + + service.delete_reform_impact( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + ) + + query, params = mock_database.query.call_args.args + assert "status = 'computing'" in query + assert params == ( + "us", + 123, + 456, + "us", + "2026", + "[option=value]", + "enhanced_cps", + ) + + def test__set_error_reform_impact__updates_status_message_and_execution( + self, monkeypatch + ): + service = ReformImpactsService() + mock_database = _mock_database(monkeypatch) + + service.set_error_reform_impact( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + message="failed", + execution_id="fc-job-1", + ) + + query, params = mock_database.query.call_args.args + assert query.startswith("UPDATE reform_impact SET status = ?, message = ?") + assert params[0] == "error" + assert params[1] == "failed" + assert params[-1] == "fc-job-1" + + def test__set_complete_reform_impact__updates_result_and_execution( + self, monkeypatch + ): + service = ReformImpactsService() + mock_database = _mock_database(monkeypatch) + + service.set_complete_reform_impact( + country_id="us", + reform_policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + reform_impact_json='{"ok": true}', + execution_id="fc-job-1", + ) + + query, params = mock_database.query.call_args.args + assert query.startswith("UPDATE reform_impact SET status = ?, message = ?") + assert params[0] == "ok" + assert params[1] == "Completed" + assert params[3] == '{"ok": true}' + assert params[-1] == "fc-job-1" + + @pytest.mark.parametrize( + "call_service", + [ + lambda service: service.get_all_reform_impacts(**LOCK_KWARGS), + lambda service: service.get_all_reform_impacts_by_options_hash_prefix( + **LOCK_KWARGS, + options_hash_prefix="[option=%", + ), + lambda service: service.set_reform_impact( + **LOCK_KWARGS, + options="{}", + status="computing", + reform_impact_json="{}", + start_time="2026-01-01 00:00:00", + execution_id="pending:job-1", + ), + lambda service: service.update_reform_impact_execution_id( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + current_execution_id="pending:job-1", + new_execution_id="fc-job-1", + ), + lambda service: service.delete_reform_impact( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + ), + lambda service: service.set_error_reform_impact( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + message="failed", + execution_id="fc-job-1", + ), + lambda service: service.set_complete_reform_impact( + country_id="us", + reform_policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + reform_impact_json="{}", + execution_id="fc-job-1", + ), + ], + ) + def test__given_database_error__service_methods_reraise( + self, monkeypatch, call_service + ): + service = ReformImpactsService() + _mock_database(monkeypatch, side_effect=RuntimeError("db down")) + + with pytest.raises(RuntimeError, match="db down"): + call_service(service) From 35e7d2f9d0126749c244160d985c75708ad57542 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 4 May 2026 21:36:34 +0200 Subject: [PATCH 21/27] Harden budget-window cache responses --- policyengine_api/libs/simulation_api_modal.py | 10 +++++ policyengine_api/routes/economy_routes.py | 7 ++- policyengine_api/services/economy_service.py | 45 ++++++++++++++++--- .../test_live_budget_window_cache.py | 1 + .../test_economy_budget_window_routes.py | 21 ++++++++- tests/unit/libs/test_simulation_api_modal.py | 6 ++- tests/unit/services/test_economy_service.py | 39 ++++++++++++++++ 7 files changed, 121 insertions(+), 8 deletions(-) diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 833c2684f..67b60591d 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -166,6 +166,16 @@ def run(self, payload: dict) -> ModalSimulationExecution: ) raise + except httpx.RequestError as e: + logger.log_struct( + { + "message": f"Modal API request error: {str(e)}", + "run_id": (payload.get("_telemetry") or {}).get("run_id"), + }, + severity="ERROR", + ) + raise + def run_budget_window_batch(self, payload: dict) -> ModalBudgetWindowBatchExecution: """ Submit a budget-window batch job to the Modal API. diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index d772697c2..cbecc16cd 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -12,6 +12,7 @@ economy_bp = Blueprint("economy", __name__) economy_service = EconomyService() +BUDGET_WINDOW_CACHE_HEADER = "X-PolicyEngine-Budget-Window-Cache" def _json_response(payload: dict, status: int = 200) -> Response: @@ -154,7 +155,7 @@ def get_budget_window_economic_impact( result_dict = economic_impact_result.to_dict() - return _json_response( + response = _json_response( { "status": result_dict["status"], "message": result_dict["message"], @@ -166,3 +167,7 @@ def get_budget_window_economic_impact( "error": result_dict["error"], } ) + cache_status = getattr(economic_impact_result, "cache_status", None) + if isinstance(cache_status, str) and cache_status: + response.headers[BUDGET_WINDOW_CACHE_HEADER] = cache_status + return response diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 0b16ffe99..455d4ecda 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -159,6 +159,7 @@ class BudgetWindowEconomicImpactResult(BaseModel): queued_years: list[str] = Field(default_factory=list) message: Optional[str] = None error: Optional[str] = None + cache_status: Optional[str] = None model_config = {"frozen": True} @@ -175,8 +176,15 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def completed(cls, data: dict) -> "BudgetWindowEconomicImpactResult": - return cls(status=ImpactStatus.OK, data=data, progress=100) + def completed( + cls, data: dict, *, cache_status: Optional[str] = None + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.OK, + data=data, + progress=100, + cache_status=cache_status, + ) @classmethod def computing( @@ -187,6 +195,7 @@ def computing( computing_years: list[str], queued_years: list[str], message: str, + cache_status: Optional[str] = None, ) -> "BudgetWindowEconomicImpactResult": return cls( status=ImpactStatus.COMPUTING, @@ -196,6 +205,7 @@ def computing( computing_years=computing_years, queued_years=queued_years, message=message, + cache_status=cache_status, ) @classmethod @@ -206,6 +216,7 @@ def failed( completed_years: Optional[list[str]] = None, computing_years: Optional[list[str]] = None, queued_years: Optional[list[str]] = None, + cache_status: Optional[str] = None, ) -> "BudgetWindowEconomicImpactResult": logger.log_struct({"message": message}, severity="ERROR") return cls( @@ -216,6 +227,7 @@ def failed( queued_years=queued_years or [], message=message, error=message, + cache_status=cache_status, ) @@ -328,7 +340,10 @@ def get_budget_window_economic_impact( cached_result = budget_window_cache.get_completed_result(cache_key) if cached_result is not None: - return BudgetWindowEconomicImpactResult.completed(cached_result) + return BudgetWindowEconomicImpactResult.completed( + cached_result, + cache_status="result-hit", + ) batch_job_id = budget_window_cache.get_batch_job_id(cache_key) if batch_job_id: @@ -337,10 +352,13 @@ def get_budget_window_economic_impact( cache_key=cache_key, total_years=len(years), queued_years_on_submit=years, + cache_status="batch-id-hit", ) claim_token = setup_options.process_id + cache_status = "starting-claim-hit" if budget_window_cache.claim_batch_start(cache_key, claim_token): + cache_status = "miss" try: batch_execution = self._start_budget_window_batch( setup_options=setup_options, @@ -361,6 +379,7 @@ def get_budget_window_economic_impact( computing_years=[], queued_years=years, progress=0, + cache_status=cache_status, ) except Exception as e: print(f"Error getting budget-window economic impact: {str(e)}") @@ -498,14 +517,26 @@ def _get_budget_window_result_from_batch_job_id( cache_key: str, total_years: int, queued_years_on_submit: list[str], + cache_status: Optional[str] = None, ) -> BudgetWindowEconomicImpactResult: batch_execution = simulation_api.get_budget_window_batch_by_id(batch_job_id) if batch_execution.status in EXECUTION_STATUSES_SUCCESS: - result = batch_execution.result or {} + result = batch_execution.result + if not isinstance(result, dict) or not result: + return BudgetWindowEconomicImpactResult.failed( + "Budget-window batch completed without a result", + completed_years=batch_execution.completed_years, + computing_years=batch_execution.running_years, + queued_years=batch_execution.queued_years or queued_years_on_submit, + cache_status=cache_status, + ) budget_window_cache.set_completed_result(cache_key, result) budget_window_cache.clear_batch_job_id(cache_key) - return BudgetWindowEconomicImpactResult.completed(result) + return BudgetWindowEconomicImpactResult.completed( + result, + cache_status=cache_status, + ) if batch_execution.status in EXECUTION_STATUSES_FAILURE: error_message = batch_execution.error or "Budget-window batch failed" @@ -514,6 +545,7 @@ def _get_budget_window_result_from_batch_job_id( completed_years=batch_execution.completed_years, computing_years=batch_execution.running_years, queued_years=batch_execution.queued_years or queued_years_on_submit, + cache_status=cache_status, ) if batch_execution.status in EXECUTION_STATUSES_PENDING: @@ -523,6 +555,7 @@ def _get_budget_window_result_from_batch_job_id( computing_years=batch_execution.running_years, queued_years=batch_execution.queued_years, progress=batch_execution.progress, + cache_status=cache_status, ) raise ValueError( @@ -537,6 +570,7 @@ def _build_budget_window_computing_result( computing_years: list[str], queued_years: list[str], progress: Optional[int] = None, + cache_status: Optional[str] = None, ) -> BudgetWindowEconomicImpactResult: resolved_progress = progress if resolved_progress is None: @@ -553,6 +587,7 @@ def _build_budget_window_computing_result( computing_years=computing_years, queued_years=queued_years, ), + cache_status=cache_status, ) def _build_economic_impact_setup_options( diff --git a/tests/integration/test_live_budget_window_cache.py b/tests/integration/test_live_budget_window_cache.py index 5ea5f4f38..820415dc9 100644 --- a/tests/integration/test_live_budget_window_cache.py +++ b/tests/integration/test_live_budget_window_cache.py @@ -72,3 +72,4 @@ def test_live_budget_window_completed_result_cache(api_client, integration_probe assert second_payload["status"] == "ok", second_payload assert second_payload["progress"] == 100, second_payload assert second_payload["result"] == first_payload["result"] + assert second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "result-hit" diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index e641a2ca8..9821502d7 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -2,8 +2,9 @@ from unittest.mock import Mock, patch -def _mock_budget_window_result(): +def _mock_budget_window_result(cache_status=None): mock_result = Mock() + mock_result.cache_status = cache_status mock_result.to_dict.return_value = { "status": "ok", "message": None, @@ -194,3 +195,21 @@ def test_budget_window_route_uses_breakdown_dataset_for_us_national_request( mock_get_budget_window_economic_impact.call_args.kwargs["dataset"] == "national-with-breakdowns" ) + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_sets_cache_status_header( + mock_get_budget_window_economic_impact, rest_client +): + mock_get_budget_window_economic_impact.return_value = _mock_budget_window_result( + cache_status="result-hit" + ) + + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026&window_size=2" + ) + + assert response.status_code == 200 + assert response.headers["X-PolicyEngine-Budget-Window-Cache"] == "result-hit" diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 373678764..8012dfb9b 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -341,7 +341,11 @@ def test__given_network_error__then_raises_exception( # When/Then with pytest.raises(httpx.RequestError): - api.run(MOCK_SIMULATION_PAYLOAD) + api.run(MOCK_SIMULATION_PAYLOAD_WITH_TELEMETRY) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert "Modal API request error" in log_payload["message"] + assert log_payload["run_id"] == MOCK_RUN_ID class TestResolveAppName: def test__given_country_and_version__then_returns_registered_app( diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 47f5efeaf..5f977e5c1 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1145,6 +1145,7 @@ def test__given_no_cached_batch__submits_parent_batch_and_returns_queued_result( assert result.completed_years == [] assert result.computing_years == [] assert result.queued_years == ["2026", "2027", "2028"] + assert result.cache_status == "miss" assert "Queued 2026" in result.message mock_simulation_api.run_budget_window_batch.assert_called_once() submitted_payload = ( @@ -1203,6 +1204,7 @@ def test__given_completed_cached_result__returns_completed_batch_result( assert result.status == ImpactStatus.OK assert result.progress == 100 assert result.data == completed_result + assert result.cache_status == "result-hit" mock_simulation_api.get_budget_window_batch_by_id.assert_not_called() mock_simulation_api.run_budget_window_batch.assert_not_called() @@ -1232,6 +1234,7 @@ def test__given_cached_batch_id__returns_running_batch_progress( assert result.completed_years == ["2026"] assert result.computing_years == ["2027"] assert result.queued_years == ["2028"] + assert result.cache_status == "batch-id-hit" assert "1 of 3 complete" in result.message mock_simulation_api.get_budget_window_batch_by_id.assert_called_once_with( "fc-budget-123" @@ -1267,6 +1270,7 @@ def test__given_completed_batch_poll__caches_result_and_returns_completed( assert result.status == ImpactStatus.OK assert result.data == completed_result + assert result.cache_status == "batch-id-hit" mock_budget_window_cache.set_completed_result.assert_called_once_with( "budget-window-cache-key", completed_result ) @@ -1274,6 +1278,39 @@ def test__given_completed_batch_poll__caches_result_and_returns_completed( "budget-window-cache-key" ) + @pytest.mark.parametrize("malformed_result", [None, {}, []]) + def test__given_completed_batch_without_result__returns_error_without_caching( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + malformed_result, + ): + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027"], + running_years=[], + queued_years=["2028"], + result=malformed_result, + ) + ) + + result = economy_service.get_budget_window_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Budget-window batch completed without a result" + assert result.completed_years == ["2026", "2027"] + assert result.computing_years == [] + assert result.queued_years == ["2028"] + assert result.cache_status == "batch-id-hit" + mock_budget_window_cache.set_completed_result.assert_not_called() + mock_budget_window_cache.clear_batch_job_id.assert_not_called() + def test__given_failed_batch_poll__returns_failed( self, economy_service, @@ -1301,6 +1338,7 @@ def test__given_failed_batch_poll__returns_failed( assert result.completed_years == ["2026"] assert result.computing_years == [] assert result.queued_years == ["2028"] + assert result.cache_status == "batch-id-hit" mock_budget_window_cache.set_completed_result.assert_not_called() def test__given_existing_start_claim__does_not_submit_duplicate_batch( @@ -1317,6 +1355,7 @@ def test__given_existing_start_claim__does_not_submit_duplicate_batch( assert result.status == ImpactStatus.COMPUTING assert result.progress == 0 assert result.queued_years == ["2026", "2027", "2028"] + assert result.cache_status == "starting-claim-hit" mock_simulation_api.run_budget_window_batch.assert_not_called() def test__given_batch_submission_fails__clears_start_claim( From ba93db9739e5130894400c7a5b5f981431138403 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 4 May 2026 22:13:55 +0200 Subject: [PATCH 22/27] Cover budget-window dedupe and failure paths --- Makefile | 2 +- policyengine_api/libs/simulation_api_modal.py | 9 ++ .../services/budget_window_cache.py | 1 + .../test_budget_window_in_flight_dedupe.py | 115 ++++++++++++++++++ .../test_live_budget_window_cache.py | 94 +++++++++++++- tests/unit/libs/test_simulation_api_modal.py | 14 +++ .../unit/services/test_budget_window_cache.py | 5 +- tests/unit/services/test_economy_service.py | 34 ++++++ 8 files changed, 268 insertions(+), 6 deletions(-) create mode 100644 tests/integration/test_budget_window_in_flight_dedupe.py diff --git a/Makefile b/Makefile index 7a3588d95..4fe1df11c 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ test-env-vars: pytest tests/env_variables test: - MAX_HOUSEHOLDS=1000 coverage run -a --branch -m pytest tests/to_refactor tests/unit --disable-pytest-warnings + MAX_HOUSEHOLDS=1000 coverage run -a --branch -m pytest tests/to_refactor tests/unit tests/integration/test_budget_window_in_flight_dedupe.py --disable-pytest-warnings coverage xml -i debug-test: diff --git a/policyengine_api/libs/simulation_api_modal.py b/policyengine_api/libs/simulation_api_modal.py index 67b60591d..0ea8900f9 100644 --- a/policyengine_api/libs/simulation_api_modal.py +++ b/policyengine_api/libs/simulation_api_modal.py @@ -299,6 +299,15 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: ) raise + except httpx.RequestError as e: + logger.log_struct( + { + "message": f"Modal API request error polling job {job_id}: {str(e)}", + }, + severity="ERROR", + ) + raise + def get_budget_window_batch_by_id( self, batch_job_id: str ) -> ModalBudgetWindowBatchExecution: diff --git a/policyengine_api/services/budget_window_cache.py b/policyengine_api/services/budget_window_cache.py index dbfc74fb8..7c19f5921 100644 --- a/policyengine_api/services/budget_window_cache.py +++ b/policyengine_api/services/budget_window_cache.py @@ -112,6 +112,7 @@ def set_completed_result(self, cache_key: str, result: dict[str, Any]) -> None: ) except Exception as error: self._handle_cache_error("write result", error) + raise def get_batch_job_id(self, cache_key: str) -> str | None: try: diff --git a/tests/integration/test_budget_window_in_flight_dedupe.py b/tests/integration/test_budget_window_in_flight_dedupe.py new file mode 100644 index 000000000..7f608ebf1 --- /dev/null +++ b/tests/integration/test_budget_window_in_flight_dedupe.py @@ -0,0 +1,115 @@ +import os +from unittest.mock import MagicMock + +from flask import Flask + +os.environ.setdefault("POLICYENGINE_DB_PASSWORD", "test") + + +class FakeRedis: + def __init__(self): + self.values = {} + + def get(self, key): + return self.values.get(key) + + def set(self, key, value, nx=False, ex=None): + if nx and key in self.values: + return False + self.values[key] = value + return True + + def delete(self, key): + self.values.pop(key, None) + + +def _create_client(economy_bp): + app = Flask(__name__) + app.register_blueprint(economy_bp) + return app.test_client() + + +def test_budget_window_in_flight_dedupe_uses_existing_batch_without_live_db( + monkeypatch, +): + from policyengine_api.libs.simulation_api_modal import ( + ModalBudgetWindowBatchExecution, + ) + from policyengine_api.routes.economy_routes import economy_bp + from policyengine_api.services import economy_service as economy_service_module + from policyengine_api.services.budget_window_cache import BudgetWindowCache + + fake_cache = BudgetWindowCache(client=FakeRedis()) + simulation_api = MagicMock() + reform_impacts_service = MagicMock() + + simulation_api.run_budget_window_batch.return_value = ( + ModalBudgetWindowBatchExecution( + batch_job_id="fc-budget-window-parent", + status="submitted", + ) + ) + simulation_api.get_budget_window_batch_by_id.return_value = ( + ModalBudgetWindowBatchExecution( + batch_job_id="fc-budget-window-parent", + status="running", + progress=50, + completed_years=["2026"], + running_years=["2027"], + queued_years=["2028"], + ) + ) + + monkeypatch.setattr(economy_service_module, "budget_window_cache", fake_cache) + monkeypatch.setattr(economy_service_module, "simulation_api", simulation_api) + monkeypatch.setattr( + economy_service_module, + "reform_impacts_service", + reform_impacts_service, + ) + monkeypatch.setattr( + economy_service_module.EconomyService, + "_build_budget_window_batch_payload", + lambda self, **kwargs: { + "country_id": "us", + "start_year": kwargs["start_year"], + "window_size": kwargs["window_size"], + "max_parallel": kwargs["max_parallel"], + }, + ) + + client = _create_client(economy_bp) + path = "/us/economy/123/over/456/budget-window" + params = { + "region": "us", + "dataset": "hf://policyengine/test.h5@1.0", + "start_year": "2026", + "window_size": "3", + } + + first_response = client.get(path, query_string=params) + first_payload = first_response.get_json() + + assert first_response.status_code == 200 + assert first_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "miss" + assert first_payload["status"] == "computing" + assert first_payload["queued_years"] == ["2026", "2027", "2028"] + + second_response = client.get(path, query_string=params) + second_payload = second_response.get_json() + + assert second_response.status_code == 200 + assert ( + second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "batch-id-hit" + ) + assert second_payload["status"] == "computing" + assert second_payload["progress"] == 50 + assert second_payload["completed_years"] == ["2026"] + assert second_payload["computing_years"] == ["2027"] + assert second_payload["queued_years"] == ["2028"] + + simulation_api.run_budget_window_batch.assert_called_once() + simulation_api.get_budget_window_batch_by_id.assert_called_once_with( + "fc-budget-window-parent" + ) + reform_impacts_service.assert_not_called() diff --git a/tests/integration/test_live_budget_window_cache.py b/tests/integration/test_live_budget_window_cache.py index 820415dc9..9886b92b6 100644 --- a/tests/integration/test_live_budget_window_cache.py +++ b/tests/integration/test_live_budget_window_cache.py @@ -37,17 +37,24 @@ def _poll_budget_window(api_client, path: str, params: dict) -> dict: time.sleep(INTEGRATION_POLL_INTERVAL_SECONDS) -def test_live_budget_window_completed_result_cache(api_client, integration_probe_id): +def _get_current_law_id(api_client) -> int: metadata_response = api_client.get("/us/metadata") metadata_response.raise_for_status() - current_law_id = metadata_response.json()["result"]["current_law_id"] + return metadata_response.json()["result"]["current_law_id"] + +def _create_utah_reform_policy(api_client) -> int: policy_response = api_client.post( "/us/policy", json=_load_reform_payload("utah_reform.json"), ) assert policy_response.status_code in (200, 201) - policy_id = policy_response.json()["result"]["policy_id"] + return policy_response.json()["result"]["policy_id"] + + +def test_live_budget_window_completed_result_cache(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" params = { @@ -73,3 +80,84 @@ def test_live_budget_window_completed_result_cache(api_client, integration_probe assert second_payload["progress"] == 100, second_payload assert second_payload["result"] == first_payload["result"] assert second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "result-hit" + + +def test_live_budget_window_multi_year_run(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "start_year": "2025", + "window_size": 2, + "staging_probe": f"{integration_probe_id}-budget-window-multi-year", + } + + payload = _poll_budget_window(api_client, path, params) + + assert payload["status"] == "ok", payload + assert payload["progress"] == 100, payload + result = payload["result"] + assert result is not None, payload + assert result["kind"] == "budgetWindow", payload + assert result["windowSize"] == 2, payload + assert result["startYear"] == "2025", payload + assert result["endYear"] == "2026", payload + assert [impact["year"] for impact in result["annualImpacts"]] == [ + "2025", + "2026", + ] + assert result["totals"]["year"] == "Total", payload + + +def test_live_budget_window_failed_batch_mapping(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "dataset": "hf://policyengine/nonexistent-budget-window-test.h5@0.0.0", + "start_year": "2025", + "window_size": 1, + "staging_probe": f"{integration_probe_id}-budget-window-failure", + } + + payload = _poll_budget_window(api_client, path, params) + + assert payload["status"] == "error", payload + assert payload["result"] is None, payload + assert payload["error"], payload + assert isinstance(payload["completed_years"], list), payload + assert isinstance(payload["computing_years"], list), payload + assert isinstance(payload["queued_years"], list), payload + + +def test_live_budget_window_in_flight_dedupe(api_client, integration_probe_id): + current_law_id = _get_current_law_id(api_client) + policy_id = _create_utah_reform_policy(api_client) + + path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" + params = { + "region": "ut", + "start_year": "2025", + "window_size": 2, + "staging_probe": f"{integration_probe_id}-budget-window-in-flight", + } + + first_response = api_client.get(path, params=params) + first_response.raise_for_status() + first_payload = first_response.json() + + assert first_payload["status"] == "computing", first_payload + assert first_response.headers["X-PolicyEngine-Budget-Window-Cache"] == "miss" + + second_response = api_client.get(path, params=params) + second_response.raise_for_status() + second_payload = second_response.json() + + assert second_response.headers["X-PolicyEngine-Budget-Window-Cache"] == ( + "batch-id-hit" + ) + assert second_payload["status"] in ("computing", "ok"), second_payload diff --git a/tests/unit/libs/test_simulation_api_modal.py b/tests/unit/libs/test_simulation_api_modal.py index 8012dfb9b..300badee5 100644 --- a/tests/unit/libs/test_simulation_api_modal.py +++ b/tests/unit/libs/test_simulation_api_modal.py @@ -551,6 +551,20 @@ def test__given_unexpected_http_error__then_raises_exception( with pytest.raises(httpx.HTTPStatusError): api.get_execution_by_id(MOCK_MODAL_JOB_ID) + def test__given_network_error__then_raises_exception( + self, + mock_httpx_client, + mock_modal_logger, + ): + mock_httpx_client.get.side_effect = httpx.RequestError("Connection failed") + api = SimulationAPIModal() + + with pytest.raises(httpx.RequestError): + api.get_execution_by_id(MOCK_MODAL_JOB_ID) + + log_payload = mock_modal_logger.log_struct.call_args.args[0] + assert MOCK_MODAL_JOB_ID in log_payload["message"] + class TestGetBudgetWindowBatchById: def test__given_running_batch__then_returns_running_status( self, diff --git a/tests/unit/services/test_budget_window_cache.py b/tests/unit/services/test_budget_window_cache.py index 233ba4fcd..f4677fae7 100644 --- a/tests/unit/services/test_budget_window_cache.py +++ b/tests/unit/services/test_budget_window_cache.py @@ -131,7 +131,7 @@ def test_get_completed_result_reraises_read_errors(monkeypatch): assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" -def test_set_completed_result_logs_write_errors(monkeypatch): +def test_set_completed_result_reraises_write_errors(monkeypatch): mock_logger = MagicMock() monkeypatch.setattr( "policyengine_api.services.budget_window_cache.logger", @@ -139,7 +139,8 @@ def test_set_completed_result_logs_write_errors(monkeypatch): ) cache = BudgetWindowCache(client=RaisingRedis(method="set")) - cache.set_completed_result("budget_window:v1:us:key", {"ok": True}) + with pytest.raises(RuntimeError, match="redis unavailable"): + cache.set_completed_result("budget_window:v1:us:key", {"ok": True}) assert mock_logger.log_struct.call_args.kwargs["severity"] == "WARNING" diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 5f977e5c1..ada3f8e67 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1311,6 +1311,40 @@ def test__given_completed_batch_without_result__returns_error_without_caching( mock_budget_window_cache.set_completed_result.assert_not_called() mock_budget_window_cache.clear_batch_job_id.assert_not_called() + def test__given_completed_batch_cache_write_fails__does_not_clear_batch_id( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + completed_result = { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2028", + "windowSize": 3, + "annualImpacts": [], + "totals": {}, + } + mock_budget_window_cache.get_batch_job_id.return_value = "fc-budget-123" + mock_budget_window_cache.set_completed_result.side_effect = RuntimeError( + "redis unavailable" + ) + mock_simulation_api.get_budget_window_batch_by_id.return_value = ( + create_mock_budget_window_batch_execution( + batch_job_id="fc-budget-123", + status="complete", + progress=100, + completed_years=["2026", "2027", "2028"], + result=completed_result, + ) + ) + + with pytest.raises(RuntimeError, match="redis unavailable"): + economy_service.get_budget_window_economic_impact(**base_params) + + mock_budget_window_cache.clear_batch_job_id.assert_not_called() + def test__given_failed_batch_poll__returns_failed( self, economy_service, From 42cf9ddcc71c5fc8ddca276a137d4bb965a07bc2 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 4 May 2026 22:52:29 +0200 Subject: [PATCH 23/27] Canonicalize budget-window cache options --- policyengine_api/services/economy_service.py | 2 +- tests/unit/services/test_economy_service.py | 53 ++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 455d4ecda..8c5eb7587 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -1391,7 +1391,7 @@ def _build_options_hash( data_version: str | None = None, policyengine_version: str | None = None, ) -> str: - option_pairs = "&".join([f"{k}={v}" for k, v in options.items()]) + option_pairs = "&".join(f"{key}={options[key]}" for key in sorted(options)) bundle_parts = [ f"dataset={dataset}", f"model_version={model_version}", diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index ada3f8e67..d105c21bb 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1472,6 +1472,41 @@ def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_windo assert cache_key_kwargs["time_period"] == "budget_window:2026:3" assert cache_key_kwargs["api_version"] == cache_version + def test__given_reordered_options__uses_same_budget_window_cache_identity( + self, + economy_service, + base_params, + mock_simulation_api, + mock_budget_window_cache, + ): + mock_budget_window_cache.get_completed_result.return_value = { + "kind": "budgetWindow" + } + + economy_service.get_budget_window_economic_impact( + **{ + **base_params, + "options": {"staging_probe": "abc", "analysis_mode": "x"}, + } + ) + first_cache_key_kwargs = dict( + mock_budget_window_cache.build_key.call_args.kwargs + ) + mock_budget_window_cache.build_key.reset_mock() + + economy_service.get_budget_window_economic_impact( + **{ + **base_params, + "options": {"analysis_mode": "x", "staging_probe": "abc"}, + } + ) + second_cache_key_kwargs = dict( + mock_budget_window_cache.build_key.call_args.kwargs + ) + + assert first_cache_key_kwargs == second_cache_key_kwargs + mock_simulation_api.run_budget_window_batch.assert_not_called() + def test__given_legacy_us_region__normalizes_before_building_cache_key( self, economy_service, @@ -2518,3 +2553,21 @@ def test__given_nonexistent_district__raises_value_error(self): with pytest.raises(ValueError) as exc_info: service._validate_us_region("congressional_district/cruft") assert "Invalid congressional district: 'cruft'" in str(exc_info.value) + + +class TestBuildOptionsHash: + def test__given_reordered_options__returns_same_hash(self): + service = EconomyService() + + first = service._build_options_hash( + options={"staging_probe": "abc", "analysis_mode": "x"}, + model_version="1.2.3", + dataset="hf://policyengine/test.h5@1.0", + ) + second = service._build_options_hash( + options={"analysis_mode": "x", "staging_probe": "abc"}, + model_version="1.2.3", + dataset="hf://policyengine/test.h5@1.0", + ) + + assert first == second From 707a8965211313fab5e86df084c6d1a3a85edc08 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 4 May 2026 23:32:25 +0200 Subject: [PATCH 24/27] Narrow budget-window PR scope --- .../fix-silent-exception-swallowing.fixed.md | 1 - policyengine_api/api.py | 8 +- policyengine_api/country.py | 9 +- policyengine_api/data/__init__.py | 7 +- policyengine_api/data/data.py | 29 +- policyengine_api/endpoints/simulation.py | 14 +- policyengine_api/services/economy_service.py | 525 +++-------------- .../services/reform_impacts_service.py | 134 +---- tests/fixtures/services/economy_service.py | 3 - tests/unit/conftest.py | 15 - tests/unit/data/test_sqlalchemy_v2.py | 74 +-- tests/unit/endpoints/test_simulation.py | 20 - tests/unit/services/test_economy_service.py | 534 +----------------- .../services/test_reform_impacts_service.py | 446 --------------- uv.lock | 2 +- 15 files changed, 110 insertions(+), 1711 deletions(-) delete mode 100644 changelog.d/fix-silent-exception-swallowing.fixed.md delete mode 100644 tests/unit/endpoints/test_simulation.py delete mode 100644 tests/unit/services/test_reform_impacts_service.py diff --git a/changelog.d/fix-silent-exception-swallowing.fixed.md b/changelog.d/fix-silent-exception-swallowing.fixed.md deleted file mode 100644 index 4b10062e5..000000000 --- a/changelog.d/fix-silent-exception-swallowing.fixed.md +++ /dev/null @@ -1 +0,0 @@ -Log exceptions instead of silently swallowing them during household calculations. diff --git a/policyengine_api/api.py b/policyengine_api/api.py index eb3eba9ee..112cce9ac 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,7 +4,6 @@ import time import sys -import os start_time = time.time() @@ -158,11 +157,8 @@ def log_timing(message): app.register_blueprint(user_profile_bp) log_timing("User profile routes registered") -if os.environ.get("FLASK_DEBUG") == "1": - app.route("/simulations", methods=["GET"])(get_simulations) - log_timing("Simulations endpoint registered") -else: - log_timing("Simulations endpoint skipped outside debug mode") +app.route("/simulations", methods=["GET"])(get_simulations) +log_timing("Simulations endpoint registered") app.register_blueprint(tracer_analysis_bp) log_timing("Tracer analysis routes registered") diff --git a/policyengine_api/country.py b/policyengine_api/country.py index 81331fb8f..4278637d8 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,6 +1,5 @@ import importlib import inspect -import logging import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem from typing import Union, Optional @@ -429,10 +428,12 @@ def calculate( household[entity_plural][entity_id][variable_name][period] = ( entity_result ) - except Exception: - logging.exception(f"Error computing {variable_name} for {entity_id}") - if "axes" not in household: + except Exception as e: + if "axes" in household: + pass + else: household[entity_plural][entity_id][variable_name][period] = None + print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) diff --git a/policyengine_api/data/__init__.py b/policyengine_api/data/__init__.py index 94703ee36..15673afdb 100644 --- a/policyengine_api/data/__init__.py +++ b/policyengine_api/data/__init__.py @@ -1,6 +1 @@ -from .data import ( - PolicyEngineDatabase, - database, - get_remote_database, - local_database, -) +from .data import PolicyEngineDatabase, database, local_database diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 78cdb5459..6b16e713e 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -19,7 +19,6 @@ class _ResultProxy: Provides fetchone()/fetchall() with dict-like row access.""" def __init__(self, cursor_result): - self.rowcount = getattr(cursor_result, "rowcount", -1) try: # Use .mappings() so rows behave like dicts self._rows = list(cursor_result.mappings()) @@ -106,20 +105,16 @@ def _create_pool(self): with open(".dbpw") as f: db_pass = f.read().strip() db_name = "policyengine" - - def get_connection(): - return self.connector.connect( - instance_connection_string=instance_connection_name, - driver="pymysql", - db=db_name, - user=db_user, - password=db_pass, - ) - + conn = self.connector.connect( + instance_connection_string=instance_connection_name, + driver="pymysql", + db=db_name, + user=db_user, + password=db_pass, + ) self.pool = sqlalchemy.create_engine( "mysql+pymysql://", - creator=get_connection, - pool_pre_ping=True, + creator=lambda: conn, ) def _close_pool(self): @@ -264,11 +259,3 @@ def initialize(self): database = PolicyEngineDatabase(local=False, initialize=False) local_database = PolicyEngineDatabase(local=True, initialize=False) -remote_database = None - - -def get_remote_database() -> PolicyEngineDatabase: - global remote_database - if remote_database is None: - remote_database = PolicyEngineDatabase(local=False, initialize=False) - return remote_database diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index d300ae80b..a0d9bd70d 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import get_remote_database +from policyengine_api.data import local_database """ @@ -42,14 +42,10 @@ def get_simulations( max_results = _DEFAULT_SIMULATION_RESULTS max_results = max(1, min(max_results, _MAX_SIMULATION_RESULTS)) - result = ( - get_remote_database() - .query( - "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", - (max_results,), - ) - .fetchall() - ) + result = local_database.query( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", + (max_results,), + ).fetchall() # Format into [{}] diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 8c5eb7587..67a38087c 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -59,7 +59,6 @@ class ImpactAction(Enum): COMPLETED = "completed" COMPUTING = "computing" CREATE = "create" - ERROR = "error" class ImpactStatus(Enum): @@ -77,11 +76,6 @@ class ImpactStatus(Enum): BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 BUDGET_WINDOW_MAX_YEARS = 75 BUDGET_WINDOW_MAX_END_YEAR = 2099 -PENDING_EXECUTION_ID_PREFIX = "pending:" -PROVISIONAL_CLAIM_TTL_SECONDS = 90 -STALE_PROVISIONAL_IMPACT_MESSAGE = ( - "Simulation claim expired before job submission completed" -) class EconomicImpactSetupOptions(BaseModel): @@ -110,7 +104,6 @@ class EconomicImpactResult(BaseModel): status: ImpactStatus data: Optional[dict] = None - message: Optional[str] = None model_config = {"frozen": True} # Make model immutable @@ -143,7 +136,7 @@ def error(cls, message: str) -> "EconomicImpactResult": Create an EconomicImpactResult for an error in the impact calculation. """ logger.log_struct({"message": message}, severity="ERROR") - return cls(status=ImpactStatus.ERROR, data=None, message=message) + return cls(status=ImpactStatus.ERROR, data=None) class BudgetWindowEconomicImpactResult(BaseModel): @@ -524,6 +517,7 @@ def _get_budget_window_result_from_batch_job_id( if batch_execution.status in EXECUTION_STATUSES_SUCCESS: result = batch_execution.result if not isinstance(result, dict) or not result: + budget_window_cache.clear_batch_job_id(cache_key) return BudgetWindowEconomicImpactResult.failed( "Budget-window batch completed without a result", completed_years=batch_execution.completed_years, @@ -540,6 +534,7 @@ def _get_budget_window_result_from_batch_job_id( if batch_execution.status in EXECUTION_STATUSES_FAILURE: error_message = batch_execution.error or "Budget-window batch failed" + budget_window_cache.clear_batch_job_id(cache_key) return BudgetWindowEconomicImpactResult.failed( error_message, completed_years=batch_execution.completed_years, @@ -694,111 +689,8 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) - if impact_action == ImpactAction.ERROR: - logger.log_struct( - { - "message": "Found failed economic impact in db; returning error", - **setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_error_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - if impact_action == ImpactAction.CREATE: self._resolve_runtime_bundle_for_setup_options(setup_options) - try: - with reform_impacts_service.claim_lock( - country_id=setup_options.country_id, - policy_id=setup_options.reform_policy_id, - baseline_policy_id=setup_options.baseline_policy_id, - region=setup_options.region, - dataset=setup_options.dataset, - time_period=setup_options.time_period, - options_hash=setup_options.options_hash, - api_version=setup_options.api_version, - ): - most_recent_impact = self._get_exact_reform_impact( - setup_options=setup_options - ) - impact_action = self._determine_impact_action( - most_recent_impact=most_recent_impact - ) - - if impact_action == ImpactAction.COMPLETED: - logger.log_struct( - { - "message": "Found completed economic impact in db after locking; returning result", - **setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_completed_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact in db after locking; returning progress", - **setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_computing_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - - if impact_action == ImpactAction.ERROR: - logger.log_struct( - { - "message": "Found failed economic impact in db after locking; returning error", - **setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_error_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - - stale_provisional_execution_id = None - if self._is_stale_provisional_impact(most_recent_impact): - stale_provisional_execution_id = most_recent_impact.get( - "execution_id" - ) - - provisional_execution_id = self._build_provisional_execution_id( - setup_options.process_id - ) - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=provisional_execution_id, - ) - if stale_provisional_execution_id: - self._expire_stale_provisional_impact( - setup_options=setup_options, - execution_id=stale_provisional_execution_id, - ) - except TimeoutError: - logger.log_struct( - { - "message": "Timed out waiting for economic impact claim lock; re-checking existing claim", - **setup_options.model_dump(), - }, - severity="WARNING", - ) - existing_impact = self._get_existing_economic_impact( - setup_options=setup_options - ) - if existing_impact is not None: - return existing_impact - return EconomicImpactResult.computing() - logger.log_struct( { "message": "No previous economic impact record found in db; creating new simulation run", @@ -808,7 +700,6 @@ def _get_or_create_economic_impact( ) return self._handle_create_impact( setup_options=setup_options, - provisional_execution_id=provisional_execution_id, ) raise ValueError(f"Unexpected impact action: {impact_action}") @@ -835,36 +726,6 @@ def _resolve_runtime_bundle_for_setup_options( runtime_app_name=setup_options.runtime_app_name, ) - def _get_existing_economic_impact( - self, setup_options: EconomicImpactSetupOptions - ) -> Optional[EconomicImpactResult]: - most_recent_impact = self._get_exact_reform_impact(setup_options=setup_options) - if not most_recent_impact: - return None - - status = most_recent_impact.get("status") - if status == ImpactStatus.ERROR.value: - return self._handle_error_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - - if status == ImpactStatus.OK.value: - return self._handle_completed_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - - if status == ImpactStatus.COMPUTING.value: - if self._is_stale_provisional_impact(most_recent_impact): - return None - return self._handle_computing_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, - ) - - raise ValueError(f"Unknown impact status: {status}") - def _build_budget_window_progress_message( self, *, @@ -913,37 +774,7 @@ def _get_previous_impacts( api_version, ) ) - if previous_impacts: - return previous_impacts - - return reform_impacts_service.get_all_reform_impacts( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, - ) - - def _get_exact_reform_impact( - self, - setup_options: EconomicImpactSetupOptions, - ) -> dict | None: - previous_impacts = reform_impacts_service.get_all_reform_impacts( - setup_options.country_id, - setup_options.reform_policy_id, - setup_options.baseline_policy_id, - setup_options.region, - setup_options.dataset, - setup_options.time_period, - setup_options.options_hash, - setup_options.api_version, - ) - if not previous_impacts: - return None - return previous_impacts[0] + return previous_impacts def _get_most_recent_impact( self, @@ -973,62 +804,6 @@ def _get_most_recent_impact( return previous_impacts[0] - def _build_provisional_execution_id(self, process_id: str) -> str: - return f"{PENDING_EXECUTION_ID_PREFIX}{process_id}" - - def _is_provisional_execution_id(self, execution_id: Any) -> bool: - return isinstance(execution_id, str) and execution_id.startswith( - PENDING_EXECUTION_ID_PREFIX - ) - - def _coerce_impact_start_time(self, start_time: Any) -> Optional[datetime.datetime]: - if start_time is None: - return None - - if isinstance(start_time, str): - parsed_start_time = datetime.datetime.fromisoformat(start_time) - elif hasattr(start_time, "tzinfo") and hasattr(start_time, "isoformat"): - parsed_start_time = start_time - else: - return None - - if parsed_start_time.tzinfo is None: - return parsed_start_time.replace(tzinfo=datetime.timezone.utc) - - return parsed_start_time.astimezone(datetime.timezone.utc) - - def _is_stale_provisional_impact(self, impact: dict | None) -> bool: - if not impact: - return False - - if not self._is_provisional_execution_id(impact.get("execution_id")): - return False - - start_time = self._coerce_impact_start_time(impact.get("start_time")) - if start_time is None: - return False - - current_time = datetime.datetime.now(datetime.timezone.utc) - if current_time.tzinfo is None: - current_time = current_time.replace(tzinfo=datetime.timezone.utc) - - claim_age = current_time - start_time - return claim_age.total_seconds() > PROVISIONAL_CLAIM_TTL_SECONDS - - def _expire_stale_provisional_impact( - self, - setup_options: EconomicImpactSetupOptions, - execution_id: str, - ) -> None: - if not self._is_provisional_execution_id(execution_id): - return - - self._set_reform_impact_error( - setup_options=setup_options, - message=STALE_PROVISIONAL_IMPACT_MESSAGE, - execution_id=execution_id, - ) - def _determine_impact_action( self, most_recent_impact: dict | None, @@ -1037,13 +812,9 @@ def _determine_impact_action( return ImpactAction.CREATE status = most_recent_impact.get("status") - if status == ImpactStatus.OK.value: + if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: return ImpactAction.COMPLETED - elif status == ImpactStatus.ERROR.value: - return ImpactAction.ERROR elif status == ImpactStatus.COMPUTING.value: - if self._is_stale_provisional_impact(most_recent_impact): - return ImpactAction.CREATE return ImpactAction.COMPUTING else: raise ValueError(f"Unknown impact status: {status}") @@ -1122,30 +893,14 @@ def _handle_completed_impact( ) ) - def _handle_error_impact( - self, - setup_options: EconomicImpactSetupOptions, - most_recent_impact: dict, - ) -> EconomicImpactResult: - error_message = most_recent_impact.get("message") or ( - f"Economic impact failed for {setup_options.time_period}" - ) - return EconomicImpactResult( - status=ImpactStatus.ERROR, - data=None, - message=error_message, - ) - def _handle_computing_impact( self, setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> EconomicImpactResult: - execution_id = most_recent_impact["execution_id"] - if self._is_provisional_execution_id(execution_id): - return EconomicImpactResult.computing() - - execution = simulation_api.get_execution_by_id(execution_id) + execution = simulation_api.get_execution_by_id( + most_recent_impact["execution_id"] + ) execution_state = simulation_api.get_execution_status(execution) return self._handle_execution_state( execution_state=execution_state, @@ -1157,75 +912,65 @@ def _handle_computing_impact( def _handle_create_impact( self, setup_options: EconomicImpactSetupOptions, - provisional_execution_id: str, ) -> EconomicImpactResult: - try: - baseline_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.baseline_policy_id - ) - reform_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.reform_policy_id - ) - - sim_config: SimulationOptions = self._setup_sim_options( - country_id=setup_options.country_id, - reform_policy=reform_policy, - baseline_policy=baseline_policy, - region=setup_options.region, - time_period=setup_options.time_period, - dataset=setup_options.dataset, - scope="macro", - include_cliffs=setup_options.target == "cliff", - model_version=setup_options.model_version, - data_version=setup_options.data_version, - ) + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.baseline_policy_id + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.reform_policy_id + ) - sim_params = sim_config.model_dump(mode="json") - telemetry = self._build_simulation_telemetry( - setup_options=setup_options, - sim_config=sim_params, - ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=setup_options.time_period, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=setup_options.target == "cliff", + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) - logger.log_struct( - { - "message": "Setting up sim API job", - "run_id": telemetry["run_id"], - **setup_options.model_dump(), - } - ) + sim_params = sim_config.model_dump(mode="json") + telemetry = self._build_simulation_telemetry( + setup_options=setup_options, + sim_config=sim_params, + ) - # Preserve both legacy metadata and the new telemetry envelope. - sim_params["_metadata"] = { - "reform_policy_id": setup_options.reform_policy_id, - "baseline_policy_id": setup_options.baseline_policy_id, - "process_id": setup_options.process_id, - "model_version": setup_options.model_version, - "policyengine_version": setup_options.policyengine_version, - "data_version": setup_options.data_version, - "dataset": setup_options.dataset, - "resolved_app_name": setup_options.runtime_app_name, + logger.log_struct( + { + "message": "Setting up sim API job", + "run_id": telemetry["run_id"], + **setup_options.model_dump(), } - sim_params["_telemetry"] = telemetry - - # The simulation gateway (policyengine-api-v2 PR #458) requires - # ``time_period`` as a string, but the upstream ``policyengine`` - # package (``TimePeriodType = int``) coerces the value to int during - # ``model_validate`` and ``model_dump`` re-emits it as int. Cast back - # to str at the request boundary so the gateway's strict schema - # validates instead of returning 422. - if sim_params.get("time_period") is not None: - sim_params["time_period"] = str(sim_params["time_period"]) - - sim_api_execution = simulation_api.run(sim_params) - execution_id = simulation_api.get_execution_id(sim_api_execution) - except Exception as error: - error_message = f"Failed to start simulation API job: {str(error)}" - self._set_reform_impact_error( - setup_options=setup_options, - message=error_message, - execution_id=provisional_execution_id, - ) - return EconomicImpactResult.error(message=error_message) + ) + + # Preserve both legacy metadata and the new telemetry envelope. + sim_params["_metadata"] = { + "reform_policy_id": setup_options.reform_policy_id, + "baseline_policy_id": setup_options.baseline_policy_id, + "process_id": setup_options.process_id, + "model_version": setup_options.model_version, + "policyengine_version": setup_options.policyengine_version, + "data_version": setup_options.data_version, + "dataset": setup_options.dataset, + "resolved_app_name": setup_options.runtime_app_name, + } + sim_params["_telemetry"] = telemetry + + # The simulation gateway (policyengine-api-v2 PR #458) requires + # ``time_period`` as a string, but the upstream ``policyengine`` + # package (``TimePeriodType = int``) coerces the value to int during + # ``model_validate`` and ``model_dump`` re-emits it as int. Cast back + # to str at the request boundary so the gateway's strict schema + # validates instead of returning 422. + if sim_params.get("time_period") is not None: + sim_params["time_period"] = str(sim_params["time_period"]) + + sim_api_execution = simulation_api.run(sim_params) + execution_id = simulation_api.get_execution_id(sim_api_execution) run_id = getattr(sim_api_execution, "run_id", None) or telemetry["run_id"] @@ -1237,116 +982,12 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - try: - updated_rows = self._update_reform_impact_execution_id( - setup_options=setup_options, - current_execution_id=provisional_execution_id, - new_execution_id=execution_id, - ) - except Exception as error: - logger.log_struct( - { - "message": "Failed to promote provisional reform impact row; inserting replacement tracking row", - **setup_options.model_dump(), - "execution_id": execution_id, - "provisional_execution_id": provisional_execution_id, - "error": str(error), - }, - severity="WARNING", - ) - updated_rows = 0 - - if updated_rows != 1: - self._recover_failed_execution_id_promotion( - setup_options=setup_options, - provisional_execution_id=provisional_execution_id, - execution_id=execution_id, - updated_rows=updated_rows, - ) - - return EconomicImpactResult.computing() - - def _recover_failed_execution_id_promotion( - self, - *, - setup_options: EconomicImpactSetupOptions, - provisional_execution_id: str, - execution_id: str, - updated_rows: int | None, - ) -> None: - logger.log_struct( - { - "message": "Provisional reform impact row was not updated; checking whether tracking has already been superseded", - **setup_options.model_dump(), - "execution_id": execution_id, - "provisional_execution_id": provisional_execution_id, - "updated_rows": updated_rows, - }, - severity="WARNING", + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, ) - try: - with reform_impacts_service.claim_lock( - country_id=setup_options.country_id, - policy_id=setup_options.reform_policy_id, - baseline_policy_id=setup_options.baseline_policy_id, - region=setup_options.region, - dataset=setup_options.dataset, - time_period=setup_options.time_period, - options_hash=setup_options.options_hash, - api_version=setup_options.api_version, - ): - most_recent_impact = self._get_exact_reform_impact( - setup_options=setup_options - ) - if most_recent_impact is not None: - impact_status = most_recent_impact.get("status") - tracked_execution_id = most_recent_impact.get("execution_id") - if tracked_execution_id == execution_id: - return - - if ( - impact_status == ImpactStatus.COMPUTING.value - and tracked_execution_id == provisional_execution_id - ): - retry_updated_rows = self._update_reform_impact_execution_id( - setup_options=setup_options, - current_execution_id=provisional_execution_id, - new_execution_id=execution_id, - ) - if retry_updated_rows == 1: - return - elif impact_status in ( - ImpactStatus.OK.value, - ImpactStatus.COMPUTING.value, - ): - logger.log_struct( - { - "message": "Skipping replacement tracking row because another claim is already authoritative", - **setup_options.model_dump(), - "execution_id": execution_id, - "provisional_execution_id": provisional_execution_id, - "tracked_execution_id": tracked_execution_id, - "tracked_status": impact_status, - }, - severity="WARNING", - ) - return - - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=execution_id, - ) - except TimeoutError: - logger.log_struct( - { - "message": "Timed out while recovering failed provisional promotion; leaving the newer claim authoritative", - **setup_options.model_dump(), - "execution_id": execution_id, - "provisional_execution_id": provisional_execution_id, - }, - severity="WARNING", - ) + return EconomicImpactResult.computing() def _setup_sim_options( self, @@ -1428,7 +1069,7 @@ def _should_refresh_cached_impact( setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> bool: - if most_recent_impact.get("status") != ImpactStatus.OK.value: + if most_recent_impact.get("status") == ImpactStatus.COMPUTING.value: return False cached_result = self._extract_cached_result(most_recent_impact) @@ -1664,9 +1305,6 @@ def _set_reform_impact_computing( In the reform_impact table, set the status of the impact to "computing". """ try: - start_time = datetime.datetime.now(datetime.timezone.utc).replace( - tzinfo=None - ) reform_impacts_service.set_reform_impact( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, @@ -1679,7 +1317,7 @@ def _set_reform_impact_computing( status=ImpactStatus.COMPUTING.value, api_version=setup_options.api_version, reform_impact_json=json.dumps({}), - start_time=start_time, + start_time=datetime.datetime.now(), execution_id=execution_id, ) except Exception as e: @@ -1691,33 +1329,6 @@ def _set_reform_impact_computing( ) raise e - def _update_reform_impact_execution_id( - self, - setup_options: EconomicImpactSetupOptions, - current_execution_id: str, - new_execution_id: str, - ) -> int | None: - try: - return reform_impacts_service.update_reform_impact_execution_id( - country_id=setup_options.country_id, - policy_id=setup_options.reform_policy_id, - baseline_policy_id=setup_options.baseline_policy_id, - region=setup_options.region, - dataset=setup_options.dataset, - time_period=setup_options.time_period, - options_hash=setup_options.options_hash, - current_execution_id=current_execution_id, - new_execution_id=new_execution_id, - ) - except Exception as e: - logger.log_struct( - { - "message": f"Error updating reform impact execution id: {str(e)}", - **setup_options.model_dump(), - } - ) - raise e - def _set_reform_impact_complete( self, setup_options: EconomicImpactSetupOptions, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index d2aad8b3a..0f41352f3 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -1,89 +1,14 @@ -from contextlib import contextmanager -import hashlib -from threading import Lock -from policyengine_api.data import database +from policyengine_api.data import local_database import datetime -LOCAL_REFORM_IMPACT_LOCK = Lock() -REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 - - class ReformImpactsService: """ Service for storing and retrieving economy-wide reform impacts; - this is connected to the shared reform_impact table. + this is connected to the locally-stored reform_impact table + and no existing route """ - def _build_lock_name( - self, - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, - ) -> str: - raw_key = ( - f"{country_id}:{policy_id}:{baseline_policy_id}:{region}:{dataset}:" - f"{time_period}:{options_hash}:{api_version}" - ) - digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest() - return f"ri:{digest[:61]}" - - @contextmanager - def claim_lock( - self, - *, - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - api_version, - timeout_seconds: int = REFORM_IMPACT_LOCK_TIMEOUT_SECONDS, - ): - if database.local: - with LOCAL_REFORM_IMPACT_LOCK: - yield - return - - lock_name = self._build_lock_name( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=time_period, - options_hash=options_hash, - api_version=api_version, - ) - with database.pool.connect() as conn: - acquired = ( - conn.exec_driver_sql( - "SELECT GET_LOCK(%s, %s) AS acquired", - (lock_name, timeout_seconds), - ) - .mappings() - .first() - ) - if acquired is None or acquired["acquired"] != 1: - raise TimeoutError( - f"Could not acquire reform impact lock for {country_id}/{policy_id}/{time_period}" - ) - - try: - yield - finally: - conn.exec_driver_sql( - "SELECT RELEASE_LOCK(%s) AS released", (lock_name,) - ) - conn.commit() - def get_all_reform_impacts( self, country_id, @@ -100,10 +25,9 @@ def get_all_reform_impacts( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " - "options_hash = ? AND api_version = ? AND dataset = ? " - "ORDER BY start_time DESC, reform_impact_id DESC" + "options_hash = ? AND api_version = ? AND dataset = ?" ) - return database.query( + return local_database.query( query, ( country_id, @@ -140,7 +64,7 @@ def get_all_reform_impacts_by_options_hash_prefix( "(options_hash = ? OR options_hash LIKE ? ESCAPE '\\') AND api_version = ? AND dataset = ? " "ORDER BY CASE WHEN options_hash = ? THEN 0 ELSE 1 END, start_time DESC" ) - return database.query( + return local_database.query( query, ( country_id, @@ -181,7 +105,7 @@ def set_reform_impact( "region, dataset, time_period, options_json, options_hash, status, api_version, " "reform_impact_json, start_time, execution_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ) - database.query( + local_database.query( query, ( country_id, @@ -203,44 +127,6 @@ def set_reform_impact( print(f"Error setting reform impact: {str(e)}") raise e - def update_reform_impact_execution_id( - self, - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options_hash, - current_execution_id, - new_execution_id, - ): - try: - query = ( - "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " - "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " - "time_period = ? AND options_hash = ? AND dataset = ? AND " - "execution_id = ? AND status = 'computing'" - ) - result = database.query( - query, - ( - new_execution_id, - country_id, - policy_id, - baseline_policy_id, - region, - time_period, - options_hash, - dataset, - current_execution_id, - ), - ) - return getattr(result, "rowcount", None) - except Exception as e: - print(f"Error updating reform impact execution id: {str(e)}") - raise e - def delete_reform_impact( self, country_id, @@ -259,7 +145,7 @@ def delete_reform_impact( "dataset = ? AND status = 'computing'" ) - database.query( + local_database.query( query, ( country_id, @@ -294,7 +180,7 @@ def set_error_reform_impact( "region = ? AND time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ?" ) - database.query( + local_database.query( query, ( "error", @@ -338,7 +224,7 @@ def set_complete_reform_impact( "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND dataset = ? AND execution_id = ?" ) - database.query( + local_database.query( query, ( "ok", diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 3ad61da1f..49202132d 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -2,7 +2,6 @@ from unittest.mock import patch, MagicMock import json import datetime -from contextlib import nullcontext from policyengine_api.constants import ( MODAL_EXECUTION_STATUS_SUBMITTED, @@ -124,10 +123,8 @@ def mock_reform_impacts_service(): mock_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [] mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None - mock_service.update_reform_impact_execution_id.return_value = 1 mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None - mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() with patch( "policyengine_api.services.economy_service.reform_impacts_service", diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 485778a68..5e2a983ad 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -111,17 +111,8 @@ def override_database(test_db, monkeypatch): # Patch at the root module level where database is defined import policyengine_api.data - import policyengine_api.data.data as data_module - - def get_test_remote_database(): - return test_db monkeypatch.setattr(policyengine_api.data, "database", test_db) - monkeypatch.setattr( - policyengine_api.data, "get_remote_database", get_test_remote_database - ) - monkeypatch.setattr(data_module, "remote_database", test_db) - monkeypatch.setattr(data_module, "get_remote_database", get_test_remote_database) # Also patch the module-level variable for any existing imports import sys @@ -132,11 +123,5 @@ def get_test_remote_database(): monkeypatch.setattr(module, "database", test_db) if hasattr(module, "local_database"): monkeypatch.setattr(module, "local_database", test_db) - if hasattr(module, "remote_database"): - monkeypatch.setattr(module, "remote_database", test_db) - if hasattr(module, "get_remote_database"): - monkeypatch.setattr( - module, "get_remote_database", get_test_remote_database - ) yield test_db diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 0a5887ea8..3882bb0f7 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -10,15 +10,10 @@ (dict(row) and row["key"]). """ -from unittest.mock import MagicMock - +import pytest import sqlalchemy -from policyengine_api.data.data import ( - _ResultProxy, - PolicyEngineDatabase, - get_remote_database as _real_get_remote_database, -) +from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase class TestSQLAlchemyVersion: @@ -103,15 +98,6 @@ def test_result_proxy_for_insert_statement(self): assert proxy.fetchone() is None assert proxy.fetchall() == [] - def test_result_proxy_preserves_rowcount_for_write_statement(self): - engine = sqlalchemy.create_engine("sqlite://") - with engine.connect() as conn: - conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)") - result = conn.exec_driver_sql("INSERT INTO test VALUES (1)") - proxy = _ResultProxy(result) - - assert proxy.rowcount == 1 - class TestRemoteQueryPath: """Test the non-local query path that uses SQLAlchemy engine @@ -194,59 +180,3 @@ def test_remote_delete(self): db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None - - -class TestRemotePoolCreation: - def test_create_pool_uses_fresh_connection_creator(self, monkeypatch): - first_connection = MagicMock(name="first_connection") - second_connection = MagicMock(name="second_connection") - mock_connector = MagicMock() - mock_connector.connect.side_effect = [first_connection, second_connection] - - captured_kwargs = {} - - def fake_create_engine(url, **kwargs): - captured_kwargs.update(kwargs) - return MagicMock() - - monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password") - monkeypatch.setattr( - "policyengine_api.data.data.Connector", lambda: mock_connector - ) - monkeypatch.setattr( - "policyengine_api.data.data.sqlalchemy.create_engine", - fake_create_engine, - ) - - db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) - db._create_pool() - - creator = captured_kwargs["creator"] - assert creator() is first_connection - assert creator() is second_connection - assert captured_kwargs["pool_pre_ping"] is True - - def test_get_remote_database_lazily_constructs_and_reuses_remote_database( - self, monkeypatch - ): - created_databases = [] - - class FakeDatabase: - def __init__(self, *, local, initialize): - self.local = local - self.initialize = initialize - created_databases.append(self) - - monkeypatch.setattr("policyengine_api.data.data.remote_database", None) - monkeypatch.setattr( - "policyengine_api.data.data.PolicyEngineDatabase", - FakeDatabase, - ) - - first = _real_get_remote_database() - second = _real_get_remote_database() - - assert first is second - assert len(created_databases) == 1 - assert created_databases[0].local is False - assert created_databases[0].initialize is False diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py deleted file mode 100644 index e2936de11..000000000 --- a/tests/unit/endpoints/test_simulation.py +++ /dev/null @@ -1,20 +0,0 @@ -from unittest.mock import MagicMock, patch - -from policyengine_api.endpoints.simulation import get_simulations - - -def test_get_simulations_reads_from_remote_database(): - mock_database = MagicMock() - mock_database.query.return_value.fetchall.return_value = [{"id": 1}] - - with patch( - "policyengine_api.endpoints.simulation.get_remote_database", - return_value=mock_database, - ): - result = get_simulations() - - mock_database.query.assert_called_once_with( - "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT ?", - (100,), - ) - assert result == {"result": [{"id": 1}]} diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d105c21bb..5b1da4405 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,4 +1,3 @@ -import datetime import json import sys import pytest @@ -66,9 +65,6 @@ def _stub_get_default_dataset(country, region): EconomicImpactSetupOptions, ImpactAction, ImpactStatus, - PENDING_EXECUTION_ID_PREFIX, - PROVISIONAL_CLAIM_TTL_SECONDS, - STALE_PROVISIONAL_IMPACT_MESSAGE, ) from tests.fixtures.services.economy_service import ( MOCK_COUNTRY_ID, @@ -203,37 +199,6 @@ def test__given_legacy_completed_impact__refreshes_cache( ) mock_simulation_api.run.assert_called_once() - def test__given_error_impact__returns_error_result( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_get_policyengine_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - error_impact = create_mock_reform_impact( - status="error", - reform_impact_json=json.dumps({}), - ) - error_impact["message"] = "Failed to start simulation API job" - mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.return_value = [ - error_impact - ] - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.ERROR - assert result.data is None - assert result.message == "Failed to start simulation API job" - mock_reform_impacts_service.get_all_reform_impacts_by_options_hash_prefix.assert_called_once() - mock_simulation_api.run.assert_not_called() - def test__given_computing_impact_with_succeeded_execution__returns_completed_result( self, economy_service, @@ -347,21 +312,6 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() - assert any( - call.args == (datetime.timezone.utc,) - for call in mock_datetime.now.call_args_list - ) - mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( - country_id=MOCK_COUNTRY_ID, - policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_RESOLVED_DATASET, - time_period=MOCK_TIME_PERIOD, - options_hash=MOCK_OPTIONS_HASH, - current_execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", - new_execution_id=MOCK_EXECUTION_ID, - ) def test__given_no_previous_impact__includes_metadata_in_simulation_params( self, @@ -433,361 +383,6 @@ def test__given_no_previous_impact__includes_telemetry_in_simulation_params( mock_logger.log_struct.call_args_list[-1].kwargs["severity"] == "INFO" ) - def test__given_simulation_api_submission_failure__marks_provisional_claim_error( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] - mock_simulation_api.run.side_effect = RuntimeError("gateway unavailable") - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.ERROR - assert ( - result.message - == "Failed to start simulation API job: gateway unavailable" - ) - mock_reform_impacts_service.set_reform_impact.assert_called_once() - mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( - country_id=MOCK_COUNTRY_ID, - policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_RESOLVED_DATASET, - time_period=MOCK_TIME_PERIOD, - options_hash=MOCK_OPTIONS_HASH, - message="Failed to start simulation API job: gateway unavailable", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", - ) - mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() - - def test__given_simulation_setup_failure__marks_provisional_claim_error( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] - with patch.object( - economy_service, - "_setup_sim_options", - side_effect=ValueError("Invalid US state: 'zz'"), - ): - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.ERROR - assert ( - result.message - == "Failed to start simulation API job: Invalid US state: 'zz'" - ) - mock_reform_impacts_service.set_reform_impact.assert_called_once() - mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( - country_id=MOCK_COUNTRY_ID, - policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_RESOLVED_DATASET, - time_period=MOCK_TIME_PERIOD, - options_hash=MOCK_OPTIONS_HASH, - message="Failed to start simulation API job: Invalid US state: 'zz'", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", - ) - mock_simulation_api.run.assert_not_called() - mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() - - def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_numpy_random, - ): - provisional_impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_other", - start_time=datetime.datetime.now(datetime.timezone.utc), - ) - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [], - [provisional_impact], - ] - mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( - "lock busy" - ) - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - mock_simulation_api.run.assert_not_called() - - def test__given_claim_lock_timeout_and_no_existing_claim__returns_computing( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_numpy_random, - ): - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [], - [], - ] - mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( - "lock busy" - ) - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - mock_simulation_api.run.assert_not_called() - - def test__given_completed_impact_appears_after_lock__returns_cached_result( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_get_policyengine_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - completed_impact = create_mock_reform_impact(status="ok") - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [], - [completed_impact], - ] - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.OK - mock_simulation_api.run.assert_not_called() - - def test__given_computing_impact_appears_after_lock__returns_progress( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - computing_impact = create_mock_reform_impact(status="computing") - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [], - [computing_impact], - ] - mock_simulation_api.get_execution_status.return_value = "running" - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - mock_simulation_api.run.assert_not_called() - mock_simulation_api.get_execution_by_id.assert_called_once_with( - MOCK_EXECUTION_ID - ) - - def test__given_error_impact_appears_after_lock__returns_error( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - error_impact = create_mock_reform_impact( - status="error", - message="Failed before lock released", - ) - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [], - [error_impact], - ] - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.ERROR - assert result.message == "Failed before lock released" - mock_simulation_api.run.assert_not_called() - - def test__given_stale_provisional_claim__expires_and_recreates_simulation( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - ): - stale_start_time = datetime.datetime.now( - datetime.timezone.utc - ) - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) - stale_provisional_impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", - start_time=stale_start_time, - ) - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [stale_provisional_impact], - [stale_provisional_impact], - ] - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( - country_id=MOCK_COUNTRY_ID, - policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_RESOLVED_DATASET, - time_period=MOCK_TIME_PERIOD, - options_hash=MOCK_OPTIONS_HASH, - message=STALE_PROVISIONAL_IMPACT_MESSAGE, - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", - ) - mock_reform_impacts_service.set_reform_impact.assert_called_once() - mock_simulation_api.run.assert_called_once() - - def test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tracking_row( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] - mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - assert mock_reform_impacts_service.set_reform_impact.call_count == 2 - first_insert = mock_reform_impacts_service.set_reform_impact.call_args_list[ - 0 - ] - second_insert = ( - mock_reform_impacts_service.set_reform_impact.call_args_list[1] - ) - assert ( - first_insert.kwargs["execution_id"] - == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" - ) - assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID - - def test__given_provisional_promotion_updates_zero_rows_but_newer_claim_exists__does_not_insert_fallback( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - replacement_impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_replacement", - start_time=datetime.datetime.now(datetime.timezone.utc), - ) - mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ - [], - [], - [replacement_impact], - ] - mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - assert mock_reform_impacts_service.set_reform_impact.call_count == 1 - inserted_execution_id = ( - mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ - "execution_id" - ] - ) - assert ( - inserted_execution_id - == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" - ) - - def test__given_provisional_promotion_raises__inserts_replacement_tracking_row( - self, - economy_service, - base_params, - mock_country_package_versions, - mock_get_dataset_version, - mock_policy_service, - mock_reform_impacts_service, - mock_simulation_api, - mock_logger, - mock_datetime, - mock_numpy_random, - ): - mock_reform_impacts_service.get_all_reform_impacts.return_value = [] - mock_reform_impacts_service.update_reform_impact_execution_id.side_effect = RuntimeError( - "update failed" - ) - - result = economy_service.get_economic_impact(**base_params) - - assert result.status == ImpactStatus.COMPUTING - assert mock_reform_impacts_service.set_reform_impact.call_count == 2 - assert ( - mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ - "execution_id" - ] - == MOCK_EXECUTION_ID - ) - def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, @@ -1309,7 +904,9 @@ def test__given_completed_batch_without_result__returns_error_without_caching( assert result.queued_years == ["2028"] assert result.cache_status == "batch-id-hit" mock_budget_window_cache.set_completed_result.assert_not_called() - mock_budget_window_cache.clear_batch_job_id.assert_not_called() + mock_budget_window_cache.clear_batch_job_id.assert_called_once_with( + "budget-window-cache-key" + ) def test__given_completed_batch_cache_write_fails__does_not_clear_batch_id( self, @@ -1374,6 +971,9 @@ def test__given_failed_batch_poll__returns_failed( assert result.queued_years == ["2028"] assert result.cache_status == "batch-id-hit" mock_budget_window_cache.set_completed_result.assert_not_called() + mock_budget_window_cache.clear_batch_job_id.assert_called_once_with( + "budget-window-cache-key" + ) def test__given_existing_start_claim__does_not_submit_duplicate_batch( self, @@ -1671,61 +1271,6 @@ def test__given_no_impacts__returns_none( # Assert assert result is None - class TestGetExistingEconomicImpact: - @pytest.fixture - def economy_service(self): - return EconomyService() - - @pytest.fixture - def setup_options(self): - return EconomicImpactSetupOptions( - process_id=MOCK_PROCESS_ID, - country_id=MOCK_COUNTRY_ID, - reform_policy_id=MOCK_POLICY_ID, - baseline_policy_id=MOCK_BASELINE_POLICY_ID, - region=MOCK_REGION, - dataset=MOCK_DATASET, - time_period=MOCK_TIME_PERIOD, - options=MOCK_OPTIONS, - api_version=MOCK_API_VERSION, - target="general", - options_hash=MOCK_OPTIONS_HASH, - ) - - def test__given_stale_provisional_impact__returns_none( - self, - economy_service, - setup_options, - mock_reform_impacts_service, - ): - stale_impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", - start_time=datetime.datetime.now(datetime.timezone.utc) - - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), - ) - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - stale_impact - ] - - result = economy_service._get_existing_economic_impact(setup_options) - - assert result is None - - def test__given_unknown_status__raises_value_error( - self, - economy_service, - setup_options, - mock_reform_impacts_service, - ): - unknown_impact = create_mock_reform_impact(status="mystery") - mock_reform_impacts_service.get_all_reform_impacts.return_value = [ - unknown_impact - ] - - with pytest.raises(ValueError, match="Unknown impact status: mystery"): - economy_service._get_existing_economic_impact(setup_options) - class TestDetermineImpactAction: @pytest.fixture def economy_service(self): @@ -1743,12 +1288,12 @@ def test__given_ok_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_error_status__returns_error(self, economy_service): + def test__given_error_status__returns_completed(self, economy_service): impact = create_mock_reform_impact(status="error") result = economy_service._determine_impact_action(impact) - assert result == ImpactAction.ERROR + assert result == ImpactAction.COMPLETED def test__given_computing_status__returns_computing(self, economy_service): impact = create_mock_reform_impact(status="computing") @@ -1757,20 +1302,6 @@ def test__given_computing_status__returns_computing(self, economy_service): assert result == ImpactAction.COMPUTING - def test__given_stale_provisional_computing_status__returns_create( - self, economy_service - ): - impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", - start_time=datetime.datetime.now(datetime.timezone.utc) - - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), - ) - - result = economy_service._determine_impact_action(impact) - - assert result == ImpactAction.CREATE - def test__given_unknown_status__raises_error(self, economy_service): impact = create_mock_reform_impact(status="unknown") @@ -1778,33 +1309,6 @@ def test__given_unknown_status__raises_error(self, economy_service): economy_service._determine_impact_action(impact) assert "Unknown impact status: unknown" in str(exc_info.value) - def test__given_stale_provisional_iso_start_time__returns_create( - self, economy_service - ): - impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", - start_time=( - datetime.datetime.now(datetime.timezone.utc) - - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) - ).isoformat(), - ) - - assert ( - economy_service._determine_impact_action(impact) == ImpactAction.CREATE - ) - - def test__given_provisional_without_start_time__is_not_stale( - self, economy_service - ): - impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", - ) - impact["start_time"] = None - - assert economy_service._is_stale_provisional_impact(impact) is False - class TestHandleExecutionState: @pytest.fixture def economy_service(self): @@ -1871,7 +1375,6 @@ def test__given_failed_state__returns_error_result( assert result.status == ImpactStatus.ERROR assert result.data is None - assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_active_state__returns_computing_result( @@ -1886,21 +1389,6 @@ def test__given_active_state__returns_computing_result( assert result.status == ImpactStatus.COMPUTING assert result.data is None - def test__given_provisional_claim__returns_computing_without_polling( - self, economy_service, setup_options, mock_simulation_api, mock_logger - ): - reform_impact = create_mock_reform_impact( - status="computing", - execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", - ) - - result = economy_service._handle_computing_impact( - setup_options, reform_impact - ) - - assert result.status == ImpactStatus.COMPUTING - mock_simulation_api.get_execution_by_id.assert_not_called() - def test__given_unknown_state__raises_error( self, economy_service, setup_options ): @@ -1963,7 +1451,6 @@ def test__given_modal_failed_state__then_returns_error_result( # Then assert result.status == ImpactStatus.ERROR assert result.data is None - assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_modal_failed_state_with_error_message__then_includes_error_in_message( @@ -1985,10 +1472,6 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR - assert ( - result.message - == "Simulation API execution failed: Simulation timed out" - ) # Verify the error message was passed to the service call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] @@ -2086,7 +1569,6 @@ def test__given_error__creates_correct_instance_and_logs(self): assert result.status == ImpactStatus.ERROR assert result.data is None - assert result.message == "Test error message" mock_logger.log_struct.assert_called_once() diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py deleted file mode 100644 index 8ec6791c5..000000000 --- a/tests/unit/services/test_reform_impacts_service.py +++ /dev/null @@ -1,446 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from policyengine_api.services.reform_impacts_service import ReformImpactsService - - -LOCK_KWARGS = { - "country_id": "us", - "policy_id": 123, - "baseline_policy_id": 456, - "region": "us", - "dataset": "enhanced_cps", - "time_period": "2026", - "options_hash": "[option=value]", - "api_version": "e1cache01", -} - - -def _mock_database(monkeypatch, *, local=False, query_result=None, side_effect=None): - mock_database = MagicMock() - mock_database.local = local - if side_effect is not None: - mock_database.query.side_effect = side_effect - elif query_result is not None: - mock_database.query.return_value = query_result - monkeypatch.setattr( - "policyengine_api.services.reform_impacts_service.database", - mock_database, - ) - return mock_database - - -class TestReformImpactsService: - def test__given_reform_impact_lookup__does_not_manage_schema(self, monkeypatch): - service = ReformImpactsService() - - select_result = MagicMock() - select_result.fetchall.return_value = [] - mock_database = MagicMock() - mock_database.local = False - mock_database.query.return_value = select_result - - monkeypatch.setattr( - "policyengine_api.services.reform_impacts_service.database", - mock_database, - ) - - service.get_all_reform_impacts( - "us", - 123, - 456, - "us", - "enhanced_cps", - "2026", - "[option=value]", - "e1cache01", - ) - - mock_database.query.assert_called_once() - query = mock_database.query.call_args.args[0] - assert query.startswith("SELECT reform_impact_json") - assert not query.startswith("ALTER") - assert not query.startswith("SHOW") - - def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): - service = ReformImpactsService() - - acquired_result = MagicMock() - acquired_result.mappings.return_value.first.return_value = {"acquired": 1} - release_result = MagicMock() - - mock_connection = MagicMock() - mock_connection.exec_driver_sql.side_effect = [ - acquired_result, - release_result, - ] - - mock_connection_context = MagicMock() - mock_connection_context.__enter__.return_value = mock_connection - mock_connection_context.__exit__.return_value = False - - mock_pool = MagicMock() - mock_pool.connect.return_value = mock_connection_context - - mock_database = MagicMock() - mock_database.local = False - mock_database.pool = mock_pool - - monkeypatch.setattr( - "policyengine_api.services.reform_impacts_service.database", - mock_database, - ) - - with service.claim_lock( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - api_version="e1cache01", - ): - pass - - assert mock_connection.exec_driver_sql.call_count == 2 - - acquire_call = mock_connection.exec_driver_sql.call_args_list[0] - assert acquire_call.args == ( - "SELECT GET_LOCK(%s, %s) AS acquired", - ( - service._build_lock_name( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - api_version="e1cache01", - ), - 5, - ), - ) - assert len(acquire_call.args[1][0]) <= 64 - - release_call = mock_connection.exec_driver_sql.call_args_list[1] - assert release_call.args == ( - "SELECT RELEASE_LOCK(%s) AS released", - (acquire_call.args[1][0],), - ) - mock_connection.commit.assert_called_once() - - def test__given_remote_database_lock_timeout__claim_lock_raises(self, monkeypatch): - service = ReformImpactsService() - - acquired_result = MagicMock() - acquired_result.mappings.return_value.first.return_value = {"acquired": 0} - - mock_connection = MagicMock() - mock_connection.exec_driver_sql.return_value = acquired_result - - mock_connection_context = MagicMock() - mock_connection_context.__enter__.return_value = mock_connection - mock_connection_context.__exit__.return_value = False - - mock_pool = MagicMock() - mock_pool.connect.return_value = mock_connection_context - - mock_database = MagicMock() - mock_database.local = False - mock_database.pool = mock_pool - - monkeypatch.setattr( - "policyengine_api.services.reform_impacts_service.database", - mock_database, - ) - - with pytest.raises( - TimeoutError, - match="Could not acquire reform impact lock", - ): - with service.claim_lock( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - api_version="e1cache01", - ): - pass - - def test__given_local_database__claim_lock_uses_in_process_lock(self, monkeypatch): - service = ReformImpactsService() - mock_database = _mock_database(monkeypatch, local=True) - - with service.claim_lock(**LOCK_KWARGS): - mock_database.query.assert_not_called() - - mock_database.pool.connect.assert_not_called() - - def test__get_all_reform_impacts__queries_dataset_and_stable_order( - self, monkeypatch - ): - service = ReformImpactsService() - query_result = MagicMock() - query_result.fetchall.return_value = [{"status": "ok"}] - mock_database = _mock_database(monkeypatch, query_result=query_result) - - result = service.get_all_reform_impacts(**LOCK_KWARGS) - - assert result == [{"status": "ok"}] - query, params = mock_database.query.call_args.args - assert "AND dataset = ?" in query - assert "ORDER BY start_time DESC, reform_impact_id DESC" in query - assert params == ( - "us", - 123, - 456, - "us", - "2026", - "[option=value]", - "e1cache01", - "enhanced_cps", - ) - - def test__get_all_reform_impacts_by_options_hash_prefix__prefers_exact_hash( - self, monkeypatch - ): - service = ReformImpactsService() - query_result = MagicMock() - query_result.fetchall.return_value = [{"options_hash": "[option=value]"}] - mock_database = _mock_database(monkeypatch, query_result=query_result) - - result = service.get_all_reform_impacts_by_options_hash_prefix( - **LOCK_KWARGS, - options_hash_prefix="[option=%", - ) - - assert result == [{"options_hash": "[option=value]"}] - query, params = mock_database.query.call_args.args - assert "(options_hash = ? OR options_hash LIKE ? ESCAPE '\\')" in query - assert "ORDER BY CASE WHEN options_hash = ? THEN 0 ELSE 1 END" in query - assert params == ( - "us", - 123, - 456, - "us", - "2026", - "[option=value]", - "[option=%", - "e1cache01", - "enhanced_cps", - "[option=value]", - ) - - def test__set_reform_impact__inserts_tracking_row(self, monkeypatch): - service = ReformImpactsService() - mock_database = _mock_database(monkeypatch) - - service.set_reform_impact( - **LOCK_KWARGS, - options='{"option": "value"}', - status="computing", - reform_impact_json="{}", - start_time="2026-01-01 00:00:00", - execution_id="pending:job-1", - ) - - query, params = mock_database.query.call_args.args - assert query.startswith("INSERT INTO reform_impact") - assert params == ( - "us", - 123, - 456, - "us", - "enhanced_cps", - "2026", - '{"option": "value"}', - "[option=value]", - "computing", - "e1cache01", - "{}", - "2026-01-01 00:00:00", - "pending:job-1", - ) - - def test__update_reform_impact_execution_id__returns_rowcount(self, monkeypatch): - service = ReformImpactsService() - query_result = MagicMock() - query_result.rowcount = 1 - mock_database = _mock_database(monkeypatch, query_result=query_result) - - rowcount = service.update_reform_impact_execution_id( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - current_execution_id="pending:job-1", - new_execution_id="fc-job-1", - ) - - assert rowcount == 1 - query, params = mock_database.query.call_args.args - assert "status = 'computing'" in query - assert params == ( - "fc-job-1", - "us", - 123, - 456, - "us", - "2026", - "[option=value]", - "enhanced_cps", - "pending:job-1", - ) - - def test__delete_reform_impact__only_deletes_computing_rows(self, monkeypatch): - service = ReformImpactsService() - mock_database = _mock_database(monkeypatch) - - service.delete_reform_impact( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - ) - - query, params = mock_database.query.call_args.args - assert "status = 'computing'" in query - assert params == ( - "us", - 123, - 456, - "us", - "2026", - "[option=value]", - "enhanced_cps", - ) - - def test__set_error_reform_impact__updates_status_message_and_execution( - self, monkeypatch - ): - service = ReformImpactsService() - mock_database = _mock_database(monkeypatch) - - service.set_error_reform_impact( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - message="failed", - execution_id="fc-job-1", - ) - - query, params = mock_database.query.call_args.args - assert query.startswith("UPDATE reform_impact SET status = ?, message = ?") - assert params[0] == "error" - assert params[1] == "failed" - assert params[-1] == "fc-job-1" - - def test__set_complete_reform_impact__updates_result_and_execution( - self, monkeypatch - ): - service = ReformImpactsService() - mock_database = _mock_database(monkeypatch) - - service.set_complete_reform_impact( - country_id="us", - reform_policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - reform_impact_json='{"ok": true}', - execution_id="fc-job-1", - ) - - query, params = mock_database.query.call_args.args - assert query.startswith("UPDATE reform_impact SET status = ?, message = ?") - assert params[0] == "ok" - assert params[1] == "Completed" - assert params[3] == '{"ok": true}' - assert params[-1] == "fc-job-1" - - @pytest.mark.parametrize( - "call_service", - [ - lambda service: service.get_all_reform_impacts(**LOCK_KWARGS), - lambda service: service.get_all_reform_impacts_by_options_hash_prefix( - **LOCK_KWARGS, - options_hash_prefix="[option=%", - ), - lambda service: service.set_reform_impact( - **LOCK_KWARGS, - options="{}", - status="computing", - reform_impact_json="{}", - start_time="2026-01-01 00:00:00", - execution_id="pending:job-1", - ), - lambda service: service.update_reform_impact_execution_id( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - current_execution_id="pending:job-1", - new_execution_id="fc-job-1", - ), - lambda service: service.delete_reform_impact( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - ), - lambda service: service.set_error_reform_impact( - country_id="us", - policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - message="failed", - execution_id="fc-job-1", - ), - lambda service: service.set_complete_reform_impact( - country_id="us", - reform_policy_id=123, - baseline_policy_id=456, - region="us", - dataset="enhanced_cps", - time_period="2026", - options_hash="[option=value]", - reform_impact_json="{}", - execution_id="fc-job-1", - ), - ], - ) - def test__given_database_error__service_methods_reraise( - self, monkeypatch, call_service - ): - service = ReformImpactsService() - _mock_database(monkeypatch, side_effect=RuntimeError("db down")) - - with pytest.raises(RuntimeError, match="db down"): - call_service(service) diff --git a/uv.lock b/uv.lock index 778b61515..8bb11c5e4 100644 --- a/uv.lock +++ b/uv.lock @@ -2622,7 +2622,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/a0/f3/eeea7dab690e46cd9 [[package]] name = "policyengine-api" -version = "3.40.8" +version = "3.40.7" source = { editable = "." } dependencies = [ { name = "anthropic" }, From fdda3954675c35834362fe5ef2b45b5e7cf8f900 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 5 May 2026 16:19:02 +0200 Subject: [PATCH 25/27] Update live budget-window test years --- tests/integration/test_live_budget_window_cache.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_live_budget_window_cache.py b/tests/integration/test_live_budget_window_cache.py index 9886b92b6..db57d3c39 100644 --- a/tests/integration/test_live_budget_window_cache.py +++ b/tests/integration/test_live_budget_window_cache.py @@ -59,7 +59,7 @@ def test_live_budget_window_completed_result_cache(api_client, integration_probe path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" params = { "region": "ut", - "start_year": "2025", + "start_year": "2026", "window_size": 1, "staging_probe": f"{integration_probe_id}-budget-window-cache", } @@ -89,7 +89,7 @@ def test_live_budget_window_multi_year_run(api_client, integration_probe_id): path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" params = { "region": "ut", - "start_year": "2025", + "start_year": "2026", "window_size": 2, "staging_probe": f"{integration_probe_id}-budget-window-multi-year", } @@ -102,11 +102,11 @@ def test_live_budget_window_multi_year_run(api_client, integration_probe_id): assert result is not None, payload assert result["kind"] == "budgetWindow", payload assert result["windowSize"] == 2, payload - assert result["startYear"] == "2025", payload - assert result["endYear"] == "2026", payload + assert result["startYear"] == "2026", payload + assert result["endYear"] == "2027", payload assert [impact["year"] for impact in result["annualImpacts"]] == [ - "2025", "2026", + "2027", ] assert result["totals"]["year"] == "Total", payload @@ -119,7 +119,7 @@ def test_live_budget_window_failed_batch_mapping(api_client, integration_probe_i params = { "region": "ut", "dataset": "hf://policyengine/nonexistent-budget-window-test.h5@0.0.0", - "start_year": "2025", + "start_year": "2026", "window_size": 1, "staging_probe": f"{integration_probe_id}-budget-window-failure", } @@ -141,7 +141,7 @@ def test_live_budget_window_in_flight_dedupe(api_client, integration_probe_id): path = f"/us/economy/{policy_id}/over/{current_law_id}/budget-window" params = { "region": "ut", - "start_year": "2025", + "start_year": "2026", "window_size": 2, "staging_probe": f"{integration_probe_id}-budget-window-in-flight", } From 0d0252a8ee9c1deac472106fd1c533739def82cf Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 5 May 2026 16:29:43 +0200 Subject: [PATCH 26/27] Document required Redis runtime --- .env.example | 6 ++++++ README.md | 6 ++++++ gcp/README.md | 2 ++ gcp/policyengine_api/start.sh | 21 ++++++++++++++------- policyengine_api/api.py | 6 ++++-- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/.env.example b/.env.example index 9e0195f99..562132a89 100644 --- a/.env.example +++ b/.env.example @@ -18,3 +18,9 @@ OPENAI_API_KEY=policyengine_openai_api_key # Token for Hugging Face models HUGGING_FACE_TOKEN=policyengine_huggingface_token + +# Redis is required for budget-window economy requests and other API cache paths. +# Local development and App Engine use an in-container/local Redis by default. +CACHE_REDIS_HOST=127.0.0.1 +CACHE_REDIS_PORT=6379 +CACHE_REDIS_DB=0 diff --git a/README.md b/README.md index bb584acee..5aa1d2477 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,8 @@ NOTE: Any output that needs to be calculated will not work. Therefore, only hous ### 6. Testing calculations +Redis is required for API cache paths, including budget-window economy requests. The budget-window endpoint uses Redis for completed-result caching and in-flight batch deduplication; if Redis is unavailable, those requests fail instead of falling back to the database or an in-process cache. + To test anything that utilizes Redis or the API's service workers (e.g. anything that requires society-wide calculations with the policy calculator), you'll also need to complete the following steps: 1. Start Redis @@ -136,6 +138,8 @@ brew install redis redis-server ``` +By default the API connects to Redis at `127.0.0.1:6379`, database `0`. Override this with `CACHE_REDIS_HOST`, `CACHE_REDIS_PORT`, and `CACHE_REDIS_DB` if your local Redis uses different connection settings. + 2. Start the API Run the below @@ -144,6 +148,8 @@ Run the below FLASK_DEBUG=1 python -m flask --app policyengine_api.api run ``` +App Engine staging and production deployments install and start Redis in the API container before Gunicorn starts. + NOTE: Calculations are not possible in the uk app without access to a specific dataset. Expect an error: "ValueError: Invalid response code 404 for url https://api.github.com/repos/policyengine/non-public-microdata/releases/tags/uk-2024-march-efo." ## Testing, Formatting, Changelogging diff --git a/gcp/README.md b/gcp/README.md index 477ab5fc7..a63f65ebd 100644 --- a/gcp/README.md +++ b/gcp/README.md @@ -2,6 +2,8 @@ The deployment actions build Docker images and deploy them to Google App Engine. The docker images themselves are based off a starter image (to save each API docker image having to spend 5 minutes installing the same dependencies). The starter image is the `Dockerfile` in this directory. +The App Engine API image installs `redis-server` and starts it through `gcp/policyengine_api/start.sh`. Redis is required at runtime for budget-window economy request caching and in-flight batch deduplication. The API reads `CACHE_REDIS_HOST`, `CACHE_REDIS_PORT`, and `CACHE_REDIS_DB`, defaulting to `127.0.0.1`, `6379`, and `0`. + To update the starter image: * `python setup.py sdist` to build the python package * `twine upload dist/*` to upload the package to pypi as `policyengine-api` diff --git a/gcp/policyengine_api/start.sh b/gcp/policyengine_api/start.sh index 37abbad46..92818ba81 100644 --- a/gcp/policyengine_api/start.sh +++ b/gcp/policyengine_api/start.sh @@ -1,18 +1,25 @@ #!/bin/sh # Environment variables PORT="${PORT:-8080}" -REDIS_PORT="${REDIS_PORT:-6379}" +CACHE_REDIS_HOST="${CACHE_REDIS_HOST:-127.0.0.1}" +CACHE_REDIS_PORT="${CACHE_REDIS_PORT:-6379}" +CACHE_REDIS_DB="${CACHE_REDIS_DB:-0}" +export CACHE_REDIS_HOST CACHE_REDIS_PORT CACHE_REDIS_DB -# Start the API -gunicorn -b :"$PORT" policyengine_api.api --timeout 300 --workers 5 --preload & - -# Start Redis with configuration for multiple clients -redis-server --protected-mode no \ +# Start Redis with configuration for multiple clients. +redis-server --bind "$CACHE_REDIS_HOST" \ + --port "$CACHE_REDIS_PORT" \ + --protected-mode yes \ --maxclients 10000 \ --timeout 0 & # Wait for Redis to be ready -sleep 2 +until redis-cli -h "$CACHE_REDIS_HOST" -p "$CACHE_REDIS_PORT" ping >/dev/null 2>&1; do + sleep 1 +done + +# Start the API +gunicorn -b :"$PORT" policyengine_api.api --timeout 300 --workers 5 --preload & # Keep the script running and handle shutdown gracefully trap "pkill -P $$; exit 1" INT TERM diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..7d6a67b01 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,6 +4,7 @@ import time import sys +import os start_time = time.time() @@ -89,8 +90,9 @@ def log_timing(message): { "CACHE_TYPE": "RedisCache", "CACHE_KEY_PREFIX": "policyengine", - "CACHE_REDIS_HOST": "127.0.0.1", - "CACHE_REDIS_PORT": 6379, + "CACHE_REDIS_HOST": os.environ.get("CACHE_REDIS_HOST", "127.0.0.1"), + "CACHE_REDIS_PORT": int(os.environ.get("CACHE_REDIS_PORT", "6379")), + "CACHE_REDIS_DB": int(os.environ.get("CACHE_REDIS_DB", "0")), "CACHE_DEFAULT_TIMEOUT": 300, } ) From 55e4f2f72460c79a234340d44c461f2fb9c80f2a Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 5 May 2026 17:57:15 +0200 Subject: [PATCH 27/27] Extract budget-window setup utilities --- policyengine_api/services/economy_service.py | 82 +++----------------- policyengine_api/utils/budget_window.py | 57 ++++++++++++++ tests/unit/utils/test_budget_window.py | 73 +++++++++++++++++ 3 files changed, 141 insertions(+), 71 deletions(-) create mode 100644 policyengine_api/utils/budget_window.py create mode 100644 tests/unit/utils/test_budget_window.py diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 67a38087c..3532244d9 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -21,6 +21,7 @@ normalize_us_region, ) from policyengine_api.data.places import validate_place_code +from policyengine_api.utils import budget_window as budget_window_utils from policyengine.simulation import SimulationOptions from policyengine.utils.data.datasets import get_default_dataset import json @@ -73,9 +74,9 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value -BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 -BUDGET_WINDOW_MAX_YEARS = 75 -BUDGET_WINDOW_MAX_END_YEAR = 2099 +BUDGET_WINDOW_MAX_ACTIVE_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_ACTIVE_YEARS +BUDGET_WINDOW_MAX_YEARS = budget_window_utils.BUDGET_WINDOW_MAX_YEARS +BUDGET_WINDOW_MAX_END_YEAR = budget_window_utils.BUDGET_WINDOW_MAX_END_YEAR class EconomicImpactSetupOptions(BaseModel): @@ -296,36 +297,21 @@ def get_budget_window_economic_impact( if country_id == "us": region = normalize_us_region(region) - if target != "general": - raise ValueError( - "Budget-window calculations only support target='general'" - ) - - start_year_int = int(start_year) - if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: - raise ValueError( - f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" - ) - end_year = start_year_int + window_size - 1 - if end_year > BUDGET_WINDOW_MAX_END_YEAR: - raise ValueError( - f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" - ) - - start_year = str(start_year_int) - years = self._build_budget_window_years( + budget_window_setup = budget_window_utils.build_budget_window_request_setup( start_year=start_year, window_size=window_size, + target=target, ) - setup_options = self._build_budget_window_setup_options( + start_year = budget_window_setup.start_year + years = budget_window_setup.years + setup_options = self._build_economic_impact_setup_options( country_id=country_id, policy_id=policy_id, baseline_policy_id=baseline_policy_id, region=region, dataset=dataset, - start_year=start_year, - window_size=window_size, - options=options, + time_period=budget_window_setup.time_period, + options=dict(options), api_version=api_version, target=target, ) @@ -378,52 +364,6 @@ def get_budget_window_economic_impact( print(f"Error getting budget-window economic impact: {str(e)}") raise e - def _build_budget_window_years( - self, - *, - start_year: str, - window_size: int, - ) -> list[str]: - start_year_int = int(start_year) - return [str(start_year_int + index) for index in range(window_size)] - - def _build_budget_window_time_period( - self, - *, - start_year: str, - window_size: int, - ) -> str: - return f"budget_window:{start_year}:{window_size}" - - def _build_budget_window_setup_options( - self, - *, - country_id: str, - policy_id: int, - baseline_policy_id: int, - region: str, - dataset: str, - start_year: str, - window_size: int, - options: dict, - api_version: str, - target: Literal["general", "cliff"], - ) -> EconomicImpactSetupOptions: - return self._build_economic_impact_setup_options( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=self._build_budget_window_time_period( - start_year=start_year, - window_size=window_size, - ), - options=dict(options), - api_version=api_version, - target=target, - ) - def _build_budget_window_cache_key( self, setup_options: EconomicImpactSetupOptions, diff --git a/policyengine_api/utils/budget_window.py b/policyengine_api/utils/budget_window.py new file mode 100644 index 000000000..6acbd5b3a --- /dev/null +++ b/policyengine_api/utils/budget_window.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from typing import Literal + +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 20 +BUDGET_WINDOW_MAX_YEARS = 75 +BUDGET_WINDOW_MAX_END_YEAR = 2099 + + +@dataclass(frozen=True) +class BudgetWindowRequestSetup: + start_year: str + window_size: int + years: list[str] + time_period: str + + +def build_budget_window_years(*, start_year: str, window_size: int) -> list[str]: + start_year_int = int(start_year) + return [str(start_year_int + index) for index in range(window_size)] + + +def build_budget_window_time_period(*, start_year: str, window_size: int) -> str: + return f"budget_window:{start_year}:{window_size}" + + +def build_budget_window_request_setup( + *, + start_year: str, + window_size: int, + target: Literal["general", "cliff"], +) -> BudgetWindowRequestSetup: + if target != "general": + raise ValueError("Budget-window calculations only support target='general'") + + start_year_int = int(start_year) + if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: + raise ValueError(f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}") + + end_year = start_year_int + window_size - 1 + if end_year > BUDGET_WINDOW_MAX_END_YEAR: + raise ValueError( + f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier" + ) + + normalized_start_year = str(start_year_int) + return BudgetWindowRequestSetup( + start_year=normalized_start_year, + window_size=window_size, + years=build_budget_window_years( + start_year=normalized_start_year, + window_size=window_size, + ), + time_period=build_budget_window_time_period( + start_year=normalized_start_year, + window_size=window_size, + ), + ) diff --git a/tests/unit/utils/test_budget_window.py b/tests/unit/utils/test_budget_window.py new file mode 100644 index 000000000..b74d44092 --- /dev/null +++ b/tests/unit/utils/test_budget_window.py @@ -0,0 +1,73 @@ +import pytest + +from policyengine_api.utils.budget_window import ( + BUDGET_WINDOW_MAX_END_YEAR, + BUDGET_WINDOW_MAX_YEARS, + build_budget_window_request_setup, + build_budget_window_time_period, + build_budget_window_years, +) + + +def test_build_budget_window_years(): + assert build_budget_window_years(start_year="2026", window_size=3) == [ + "2026", + "2027", + "2028", + ] + + +def test_build_budget_window_time_period(): + assert ( + build_budget_window_time_period(start_year="2026", window_size=3) + == "budget_window:2026:3" + ) + + +def test_build_budget_window_request_setup_normalizes_start_year(): + setup = build_budget_window_request_setup( + start_year="02026", + window_size=2, + target="general", + ) + + assert setup.start_year == "2026" + assert setup.window_size == 2 + assert setup.years == ["2026", "2027"] + assert setup.time_period == "budget_window:2026:2" + + +def test_build_budget_window_request_setup_rejects_cliff_target(): + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + build_budget_window_request_setup( + start_year="2026", + window_size=2, + target="cliff", + ) + + +def test_build_budget_window_request_setup_rejects_oversized_window(): + with pytest.raises( + ValueError, + match=f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}", + ): + build_budget_window_request_setup( + start_year="2026", + window_size=BUDGET_WINDOW_MAX_YEARS + 1, + target="general", + ) + + +def test_build_budget_window_request_setup_rejects_end_year_after_max(): + with pytest.raises( + ValueError, + match=f"budget-window end_year must be {BUDGET_WINDOW_MAX_END_YEAR} or earlier", + ): + build_budget_window_request_setup( + start_year="2090", + window_size=20, + target="general", + )