From d7382900a2017cf5ac5b5329499559a55ab53738 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 30 Apr 2026 23:55:46 +0200 Subject: [PATCH 1/8] Fix US program statistics variable mappings --- src/policyengine/outputs/aggregate.py | 112 +++++++++++-- src/policyengine/outputs/change_aggregate.py | 97 ++++++++++-- .../outputs/program_statistics.py | 19 ++- .../tax_benefit_models/us/analysis.py | 87 ++++++++-- .../tax_benefit_models/us/model.py | 1 + tests/test_aggregate.py | 4 +- tests/test_us_program_statistics.py | 148 ++++++++++++++++++ 7 files changed, 418 insertions(+), 50 deletions(-) create mode 100644 tests/test_us_program_statistics.py diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index d014b06c..9e4d3b86 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -1,3 +1,4 @@ +from difflib import get_close_matches from enum import Enum from typing import Any, Optional @@ -10,6 +11,64 @@ class AggregateType(str, Enum): COUNT = "count" +def get_aggregate_variable(simulation: Simulation, variable: str, context: str): + """Return a model variable with an aggregation-specific error message.""" + model_version = simulation.tax_benefit_model_version + try: + return model_version.get_variable(variable) + except ValueError as exc: + candidates = sorted(model_version.variables_by_name) + suggestions = get_close_matches(variable, candidates, n=3, cutoff=0.65) + suggestion_text = ( + f" Did you mean: {', '.join(repr(name) for name in suggestions)}?" + if suggestions + else "" + ) + raise ValueError( + f"{context} references missing variable '{variable}' in " + f"{model_version.model.id} version {model_version.version}." + f"{suggestion_text}" + ) from exc + + +def get_output_entity_data(simulation: Simulation, entity: str, context: str): + """Return output data for an entity with a clear error if it is unavailable.""" + if simulation.output_dataset is None or simulation.output_dataset.data is None: + raise ValueError( + f"{context} requires simulation '{simulation.id}' to have an " + "output dataset before aggregation." + ) + + try: + return getattr(simulation.output_dataset.data, entity) + except AttributeError as exc: + raise ValueError( + f"{context} references entity '{entity}', but simulation " + f"'{simulation.id}' has no output data for that entity." + ) from exc + + +def require_output_column( + data, + variable: str, + entity: str, + simulation: Simulation, + context: str, +) -> None: + """Raise a descriptive error when a known variable was not materialized.""" + if variable in data.columns: + return + + model_version = simulation.tax_benefit_model_version + raise ValueError( + f"{context} variable '{variable}' exists in {model_version.model.id} " + f"version {model_version.version}, but is not present in simulation " + f"'{simulation.id}' output data for entity '{entity}'. Add '{variable}' " + f"to {model_version.__class__.__name__}.entity_variables or pass it via " + "Simulation.extra_variables before running the simulation." + ) + + class Aggregate(Output): simulation: Simulation variable: str @@ -47,35 +106,61 @@ def run(self): elif self.quantile_geq is not None: self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile - # Get variable object - var_obj = next( - v - for v in self.simulation.tax_benefit_model_version.variables - if v.name == self.variable + var_obj = get_aggregate_variable( + self.simulation, self.variable, "Aggregate.variable" ) # Get the target entity data target_entity = self.entity or var_obj.entity - data = getattr(self.simulation.output_dataset.data, target_entity) + data = get_output_entity_data( + self.simulation, target_entity, "Aggregate.entity" + ) # Map variable to target entity if needed if var_obj.entity != target_entity: + source_data = get_output_entity_data( + self.simulation, var_obj.entity, "Aggregate.variable" + ) + require_output_column( + source_data, + self.variable, + var_obj.entity, + self.simulation, + "Aggregate.variable", + ) mapped = self.simulation.output_dataset.data.map_to_entity( var_obj.entity, target_entity, columns=[self.variable] ) series = mapped[self.variable] else: + require_output_column( + data, + self.variable, + target_entity, + self.simulation, + "Aggregate.variable", + ) series = data[self.variable] # Apply filters if self.filter_variable is not None: - filter_var_obj = next( - v - for v in self.simulation.tax_benefit_model_version.variables - if v.name == self.filter_variable + filter_var_obj = get_aggregate_variable( + self.simulation, self.filter_variable, "Aggregate.filter_variable" ) if filter_var_obj.entity != target_entity: + filter_source_data = get_output_entity_data( + self.simulation, + filter_var_obj.entity, + "Aggregate.filter_variable", + ) + require_output_column( + filter_source_data, + self.filter_variable, + filter_var_obj.entity, + self.simulation, + "Aggregate.filter_variable", + ) filter_mapped = self.simulation.output_dataset.data.map_to_entity( filter_var_obj.entity, target_entity, @@ -83,6 +168,13 @@ def run(self): ) filter_series = filter_mapped[self.filter_variable] else: + require_output_column( + data, + self.filter_variable, + target_entity, + self.simulation, + "Aggregate.filter_variable", + ) filter_series = data[self.filter_variable] if self.filter_variable_describes_quantiles: diff --git a/src/policyengine/outputs/change_aggregate.py b/src/policyengine/outputs/change_aggregate.py index 87d2e0d9..f9ea6502 100644 --- a/src/policyengine/outputs/change_aggregate.py +++ b/src/policyengine/outputs/change_aggregate.py @@ -2,6 +2,11 @@ from typing import Any, Optional from policyengine.core import Output, Simulation +from policyengine.outputs.aggregate import ( + get_aggregate_variable, + get_output_entity_data, + require_output_column, +) class ChangeAggregateType(str, Enum): @@ -59,34 +64,75 @@ def run(self): elif self.quantile_geq is not None: self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile - # Get variable object - var_obj = next( - v - for v in self.baseline_simulation.tax_benefit_model_version.variables - if v.name == self.variable + var_obj = get_aggregate_variable( + self.baseline_simulation, self.variable, "ChangeAggregate.variable" ) # Get the target entity data target_entity = self.entity or var_obj.entity - baseline_data = getattr( - self.baseline_simulation.output_dataset.data, target_entity + baseline_data = get_output_entity_data( + self.baseline_simulation, + target_entity, + "ChangeAggregate.baseline_entity", + ) + reform_data = get_output_entity_data( + self.reform_simulation, + target_entity, + "ChangeAggregate.reform_entity", ) - reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity) # Map variable to target entity if needed if var_obj.entity != target_entity: + baseline_source_data = get_output_entity_data( + self.baseline_simulation, + var_obj.entity, + "ChangeAggregate.variable", + ) + reform_source_data = get_output_entity_data( + self.reform_simulation, + var_obj.entity, + "ChangeAggregate.variable", + ) + require_output_column( + baseline_source_data, + self.variable, + var_obj.entity, + self.baseline_simulation, + "ChangeAggregate.variable", + ) + require_output_column( + reform_source_data, + self.variable, + var_obj.entity, + self.reform_simulation, + "ChangeAggregate.variable", + ) baseline_mapped = ( self.baseline_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + var_obj.entity, target_entity, columns=[self.variable] ) ) baseline_series = baseline_mapped[self.variable] reform_mapped = self.reform_simulation.output_dataset.data.map_to_entity( - var_obj.entity, target_entity + var_obj.entity, target_entity, columns=[self.variable] ) reform_series = reform_mapped[self.variable] else: + require_output_column( + baseline_data, + self.variable, + target_entity, + self.baseline_simulation, + "ChangeAggregate.variable", + ) + require_output_column( + reform_data, + self.variable, + target_entity, + self.reform_simulation, + "ChangeAggregate.variable", + ) baseline_series = baseline_data[self.variable] reform_series = reform_data[self.variable] @@ -124,20 +170,41 @@ def run(self): # Apply filter_variable filters if self.filter_variable is not None: - filter_var_obj = next( - v - for v in self.baseline_simulation.tax_benefit_model_version.variables - if v.name == self.filter_variable + filter_var_obj = get_aggregate_variable( + self.baseline_simulation, + self.filter_variable, + "ChangeAggregate.filter_variable", ) if filter_var_obj.entity != target_entity: + filter_source_data = get_output_entity_data( + self.baseline_simulation, + filter_var_obj.entity, + "ChangeAggregate.filter_variable", + ) + require_output_column( + filter_source_data, + self.filter_variable, + filter_var_obj.entity, + self.baseline_simulation, + "ChangeAggregate.filter_variable", + ) filter_mapped = ( self.baseline_simulation.output_dataset.data.map_to_entity( - filter_var_obj.entity, target_entity + filter_var_obj.entity, + target_entity, + columns=[self.filter_variable], ) ) filter_series = filter_mapped[self.filter_variable] else: + require_output_column( + baseline_data, + self.filter_variable, + target_entity, + self.baseline_simulation, + "ChangeAggregate.filter_variable", + ) filter_series = baseline_data[self.filter_variable] if self.filter_variable_describes_quantiles: diff --git a/src/policyengine/outputs/program_statistics.py b/src/policyengine/outputs/program_statistics.py index a48ff8a8..ac029734 100644 --- a/src/policyengine/outputs/program_statistics.py +++ b/src/policyengine/outputs/program_statistics.py @@ -21,6 +21,7 @@ class ProgramStatistics(Output): reform_simulation: Simulation program_name: str entity: str + variable_name: Optional[str] = None is_tax: bool = False # Results populated by run() @@ -34,10 +35,12 @@ class ProgramStatistics(Output): def run(self): """Calculate statistics for this program.""" + variable_name = self.variable_name or self.program_name + # Baseline totals baseline_total = Aggregate( simulation=self.baseline_simulation, - variable=self.program_name, + variable=variable_name, aggregate_type=AggregateType.SUM, entity=self.entity, ) @@ -46,7 +49,7 @@ def run(self): # Reform totals reform_total = Aggregate( simulation=self.reform_simulation, - variable=self.program_name, + variable=variable_name, aggregate_type=AggregateType.SUM, entity=self.entity, ) @@ -55,10 +58,10 @@ def run(self): # Count of recipients/payers (baseline) baseline_count = Aggregate( simulation=self.baseline_simulation, - variable=self.program_name, + variable=variable_name, aggregate_type=AggregateType.COUNT, entity=self.entity, - filter_variable=self.program_name, + filter_variable=variable_name, filter_variable_geq=0.01, ) baseline_count.run() @@ -66,10 +69,10 @@ def run(self): # Count of recipients/payers (reform) reform_count = Aggregate( simulation=self.reform_simulation, - variable=self.program_name, + variable=variable_name, aggregate_type=AggregateType.COUNT, entity=self.entity, - filter_variable=self.program_name, + filter_variable=variable_name, filter_variable_geq=0.01, ) reform_count.run() @@ -78,7 +81,7 @@ def run(self): winners = ChangeAggregate( baseline_simulation=self.baseline_simulation, reform_simulation=self.reform_simulation, - variable=self.program_name, + variable=variable_name, aggregate_type=ChangeAggregateType.COUNT, entity=self.entity, change_geq=0.01 if not self.is_tax else -0.01, @@ -88,7 +91,7 @@ def run(self): losers = ChangeAggregate( baseline_simulation=self.baseline_simulation, reform_simulation=self.reform_simulation, - variable=self.program_name, + variable=variable_name, aggregate_type=ChangeAggregateType.COUNT, entity=self.entity, change_leq=-0.01 if not self.is_tax else 0.01, diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 8b3eefc8..96ce956b 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -27,6 +27,24 @@ calculate_us_poverty_rates, ) +US_PROGRAMS = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "tax_unit", "is_tax": True}, + "state_income_tax": { + "entity": "tax_unit", + "variable_name": "household_state_income_tax", + "is_tax": True, + }, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "person", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + "medicare_cost": {"entity": "person", "is_tax": False}, + "medicaid": {"entity": "person", "is_tax": False}, + "eitc": {"entity": "tax_unit", "is_tax": False}, + "ctc": {"entity": "tax_unit", "is_tax": False}, +} + class PolicyReformAnalysis(BaseModel): """Complete policy reform analysis result.""" @@ -39,6 +57,56 @@ class PolicyReformAnalysis(BaseModel): reform_inequality: Inequality +def _validate_program_statistics_config( + baseline_simulation: Simulation, + reform_simulation: Simulation, +) -> None: + """Validate US program-stat variables before running simulations.""" + missing_variables: set[str] = set() + missing_outputs: set[tuple[str, str, str]] = set() + + simulations = (baseline_simulation, reform_simulation) + for program_name, program_info in US_PROGRAMS.items(): + variable_name = program_info.get("variable_name", program_name) + + for simulation in simulations: + model_version = simulation.tax_benefit_model_version + try: + variable = model_version.get_variable(variable_name) + except ValueError: + missing_variables.add(variable_name) + continue + + resolved_variables = model_version.resolve_entity_variables(simulation) + if variable_name not in resolved_variables.get(variable.entity, []): + missing_outputs.add( + (program_name, variable_name, variable.entity) + ) + + if not missing_variables and not missing_outputs: + return + + lines = ["US program statistics config is invalid:"] + if missing_variables: + lines.append( + "Missing model variables: " + ", ".join(sorted(missing_variables)) + ) + if missing_outputs: + formatted = ", ".join( + f"{program_name} -> {variable_name} on {entity}" + for program_name, variable_name, entity in sorted(missing_outputs) + ) + lines.append( + "Variables not materialized in simulation outputs: " + formatted + ) + lines.append( + "Add them to the model version's entity_variables or pass them " + "via Simulation.extra_variables before running the simulation." + ) + + raise ValueError("\n".join(lines)) + + def economic_impact_analysis( baseline_simulation: Simulation, reform_simulation: Simulation, @@ -55,6 +123,8 @@ def economic_impact_analysis( ``PolicyReformAnalysis`` with decile impacts, program statistics, baseline and reform poverty, and inequality. """ + _validate_program_statistics_config(baseline_simulation, reform_simulation) + baseline_simulation.ensure() reform_simulation.ensure() @@ -71,26 +141,13 @@ def economic_impact_analysis( income_variable="household_net_income", ) - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "payroll_tax": {"entity": "person", "is_tax": True}, - "state_income_tax": {"entity": "tax_unit", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "person", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - "medicare": {"entity": "person", "is_tax": False}, - "medicaid": {"entity": "person", "is_tax": False}, - "eitc": {"entity": "tax_unit", "is_tax": False}, - "ctc": {"entity": "tax_unit", "is_tax": False}, - } - program_statistics = [] - for program_name, program_info in programs.items(): + for program_name, program_info in US_PROGRAMS.items(): stats = ProgramStatistics( baseline_simulation=baseline_simulation, reform_simulation=reform_simulation, program_name=program_name, + variable_name=program_info.get("variable_name", program_name), entity=program_info["entity"], is_tax=program_info["is_tax"], ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index 184dd110..fd04e4fa 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -65,6 +65,7 @@ class PolicyEngineUSLatest(MicrosimulationModelVersion): # Benefits "ssi", "social_security", + "medicare_cost", "medicaid", "unemployment_compensation", ], diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 5b4e8b27..56abb196 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -478,7 +478,7 @@ def test_aggregate_invalid_variable(): variable="nonexistent_variable", aggregate_type=AggregateType.SUM, ) - with pytest.raises(StopIteration): + with pytest.raises(ValueError, match="Aggregate.variable"): agg.run() # Invalid filter variable name should raise error on run() @@ -488,5 +488,5 @@ def test_aggregate_invalid_variable(): aggregate_type=AggregateType.SUM, filter_variable="nonexistent_filter", ) - with pytest.raises(StopIteration): + with pytest.raises(ValueError, match="Aggregate.filter_variable"): agg.run() diff --git a/tests/test_us_program_statistics.py b/tests/test_us_program_statistics.py new file mode 100644 index 00000000..ffc39536 --- /dev/null +++ b/tests/test_us_program_statistics.py @@ -0,0 +1,148 @@ +import pandas as pd +from microdf import MicroDataFrame + +from policyengine.core import Simulation +from policyengine.outputs import ProgramStatistics +from policyengine.tax_benefit_models.us.analysis import ( + US_PROGRAMS, + _validate_program_statistics_config, +) +from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, +) +from policyengine.tax_benefit_models.us.model import us_latest + + +def _microdf(data: dict, weights: str) -> MicroDataFrame: + return MicroDataFrame(pd.DataFrame(data), weights=weights) + + +def _make_us_output_simulation(tmp_path, simulation_id: str, multiplier: float): + data = USYearData( + person=_microdf( + { + "person_id": [1, 2], + "household_id": [1, 2], + "marital_unit_id": [1, 2], + "family_id": [1, 2], + "spm_unit_id": [1, 2], + "tax_unit_id": [1, 2], + "person_weight": [1.0, 2.0], + "ssi": [100.0 * multiplier, 0.0], + "social_security": [0.0, 200.0 * multiplier], + "medicare_cost": [300.0 * multiplier, 0.0], + "medicaid": [0.0, 400.0 * multiplier], + }, + "person_weight", + ), + marital_unit=_microdf( + { + "marital_unit_id": [1, 2], + "marital_unit_weight": [1.0, 2.0], + }, + "marital_unit_weight", + ), + family=_microdf( + { + "family_id": [1, 2], + "family_weight": [1.0, 2.0], + }, + "family_weight", + ), + spm_unit=_microdf( + { + "spm_unit_id": [1, 2], + "spm_unit_weight": [1.0, 2.0], + "snap": [500.0 * multiplier, 0.0], + "tanf": [0.0, 600.0 * multiplier], + }, + "spm_unit_weight", + ), + tax_unit=_microdf( + { + "tax_unit_id": [1, 2], + "tax_unit_weight": [1.0, 2.0], + "income_tax": [700.0 * multiplier, 0.0], + "employee_payroll_tax": [0.0, 800.0 * multiplier], + "household_state_income_tax": [900.0 * multiplier, 0.0], + "eitc": [0.0, 1_000.0 * multiplier], + "ctc": [1_100.0 * multiplier, 0.0], + }, + "tax_unit_weight", + ), + household=_microdf( + { + "household_id": [1, 2], + "household_weight": [1.0, 2.0], + }, + "household_weight", + ), + ) + dataset = PolicyEngineUSDataset( + id=simulation_id, + name=f"{simulation_id} output", + description="Mocked US output dataset for program statistics", + filepath=str(tmp_path / f"{simulation_id}.h5"), + year=2026, + is_output_dataset=True, + data=data, + ) + return Simulation( + id=simulation_id, + dataset=dataset, + tax_benefit_model_version=us_latest, + output_dataset=dataset, + ) + + +def test_us_program_statistics_config_runs_against_mocked_outputs(tmp_path): + baseline = _make_us_output_simulation(tmp_path, "baseline", 1.0) + reform = _make_us_output_simulation(tmp_path, "reform", 2.0) + + _validate_program_statistics_config(baseline, reform) + + results = {} + for program_name, program_info in US_PROGRAMS.items(): + stats = ProgramStatistics( + baseline_simulation=baseline, + reform_simulation=reform, + program_name=program_name, + variable_name=program_info.get("variable_name", program_name), + entity=program_info["entity"], + is_tax=program_info["is_tax"], + ) + stats.run() + results[program_name] = stats + + assert set(results) == set(US_PROGRAMS) + assert results["employee_payroll_tax"].baseline_total == 1_600.0 + assert results["medicare_cost"].baseline_total == 300.0 + assert results["state_income_tax"].variable_name == "household_state_income_tax" + assert results["state_income_tax"].baseline_total == 900.0 + + +def test_us_program_statistics_config_fails_before_simulation_run( + tmp_path, monkeypatch +): + baseline = _make_us_output_simulation(tmp_path, "baseline", 1.0) + reform = _make_us_output_simulation(tmp_path, "reform", 2.0) + + entity_variables = { + entity: list(variables) + for entity, variables in us_latest.entity_variables.items() + } + entity_variables["person"].remove("medicare_cost") + monkeypatch.setattr( + baseline.tax_benefit_model_version, + "entity_variables", + entity_variables, + ) + + try: + _validate_program_statistics_config(baseline, reform) + except ValueError as exc: + assert "US program statistics config is invalid" in str(exc) + assert "medicare_cost" in str(exc) + else: + raise AssertionError("Expected ValueError") From e9a73dc4fbb5d03976a495cdb15a008ae87385aa Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 30 Apr 2026 23:59:23 +0200 Subject: [PATCH 2/8] Format US program statistics validation --- src/policyengine/tax_benefit_models/us/analysis.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 96ce956b..625b13ab 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -79,26 +79,20 @@ def _validate_program_statistics_config( resolved_variables = model_version.resolve_entity_variables(simulation) if variable_name not in resolved_variables.get(variable.entity, []): - missing_outputs.add( - (program_name, variable_name, variable.entity) - ) + missing_outputs.add((program_name, variable_name, variable.entity)) if not missing_variables and not missing_outputs: return lines = ["US program statistics config is invalid:"] if missing_variables: - lines.append( - "Missing model variables: " + ", ".join(sorted(missing_variables)) - ) + lines.append("Missing model variables: " + ", ".join(sorted(missing_variables))) if missing_outputs: formatted = ", ".join( f"{program_name} -> {variable_name} on {entity}" for program_name, variable_name, entity in sorted(missing_outputs) ) - lines.append( - "Variables not materialized in simulation outputs: " + formatted - ) + lines.append("Variables not materialized in simulation outputs: " + formatted) lines.append( "Add them to the model version's entity_variables or pass them " "via Simulation.extra_variables before running the simulation." From fe7863677e405a4cab12e58c3b7810c06c173c64 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 15:06:14 +0200 Subject: [PATCH 3/8] Restore direct US program statistic mappings --- .../outputs/program_statistics.py | 19 +++++++-------- .../tax_benefit_models/us/analysis.py | 23 +++++++------------ .../tax_benefit_models/us/model.py | 1 + tests/test_us_program_statistics.py | 4 +--- 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/src/policyengine/outputs/program_statistics.py b/src/policyengine/outputs/program_statistics.py index ac029734..a48ff8a8 100644 --- a/src/policyengine/outputs/program_statistics.py +++ b/src/policyengine/outputs/program_statistics.py @@ -21,7 +21,6 @@ class ProgramStatistics(Output): reform_simulation: Simulation program_name: str entity: str - variable_name: Optional[str] = None is_tax: bool = False # Results populated by run() @@ -35,12 +34,10 @@ class ProgramStatistics(Output): def run(self): """Calculate statistics for this program.""" - variable_name = self.variable_name or self.program_name - # Baseline totals baseline_total = Aggregate( simulation=self.baseline_simulation, - variable=variable_name, + variable=self.program_name, aggregate_type=AggregateType.SUM, entity=self.entity, ) @@ -49,7 +46,7 @@ def run(self): # Reform totals reform_total = Aggregate( simulation=self.reform_simulation, - variable=variable_name, + variable=self.program_name, aggregate_type=AggregateType.SUM, entity=self.entity, ) @@ -58,10 +55,10 @@ def run(self): # Count of recipients/payers (baseline) baseline_count = Aggregate( simulation=self.baseline_simulation, - variable=variable_name, + variable=self.program_name, aggregate_type=AggregateType.COUNT, entity=self.entity, - filter_variable=variable_name, + filter_variable=self.program_name, filter_variable_geq=0.01, ) baseline_count.run() @@ -69,10 +66,10 @@ def run(self): # Count of recipients/payers (reform) reform_count = Aggregate( simulation=self.reform_simulation, - variable=variable_name, + variable=self.program_name, aggregate_type=AggregateType.COUNT, entity=self.entity, - filter_variable=variable_name, + filter_variable=self.program_name, filter_variable_geq=0.01, ) reform_count.run() @@ -81,7 +78,7 @@ def run(self): winners = ChangeAggregate( baseline_simulation=self.baseline_simulation, reform_simulation=self.reform_simulation, - variable=variable_name, + variable=self.program_name, aggregate_type=ChangeAggregateType.COUNT, entity=self.entity, change_geq=0.01 if not self.is_tax else -0.01, @@ -91,7 +88,7 @@ def run(self): losers = ChangeAggregate( baseline_simulation=self.baseline_simulation, reform_simulation=self.reform_simulation, - variable=variable_name, + variable=self.program_name, aggregate_type=ChangeAggregateType.COUNT, entity=self.entity, change_leq=-0.01 if not self.is_tax else 0.01, diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 625b13ab..e3142586 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -30,11 +30,7 @@ US_PROGRAMS = { "income_tax": {"entity": "tax_unit", "is_tax": True}, "employee_payroll_tax": {"entity": "tax_unit", "is_tax": True}, - "state_income_tax": { - "entity": "tax_unit", - "variable_name": "household_state_income_tax", - "is_tax": True, - }, + "state_income_tax": {"entity": "tax_unit", "is_tax": True}, "snap": {"entity": "spm_unit", "is_tax": False}, "tanf": {"entity": "spm_unit", "is_tax": False}, "ssi": {"entity": "person", "is_tax": False}, @@ -63,23 +59,21 @@ def _validate_program_statistics_config( ) -> None: """Validate US program-stat variables before running simulations.""" missing_variables: set[str] = set() - missing_outputs: set[tuple[str, str, str]] = set() + missing_outputs: set[tuple[str, str]] = set() simulations = (baseline_simulation, reform_simulation) for program_name, program_info in US_PROGRAMS.items(): - variable_name = program_info.get("variable_name", program_name) - for simulation in simulations: model_version = simulation.tax_benefit_model_version try: - variable = model_version.get_variable(variable_name) + variable = model_version.get_variable(program_name) except ValueError: - missing_variables.add(variable_name) + missing_variables.add(program_name) continue resolved_variables = model_version.resolve_entity_variables(simulation) - if variable_name not in resolved_variables.get(variable.entity, []): - missing_outputs.add((program_name, variable_name, variable.entity)) + if program_name not in resolved_variables.get(variable.entity, []): + missing_outputs.add((program_name, variable.entity)) if not missing_variables and not missing_outputs: return @@ -89,8 +83,8 @@ def _validate_program_statistics_config( lines.append("Missing model variables: " + ", ".join(sorted(missing_variables))) if missing_outputs: formatted = ", ".join( - f"{program_name} -> {variable_name} on {entity}" - for program_name, variable_name, entity in sorted(missing_outputs) + f"{program_name} on {entity}" + for program_name, entity in sorted(missing_outputs) ) lines.append("Variables not materialized in simulation outputs: " + formatted) lines.append( @@ -141,7 +135,6 @@ def economic_impact_analysis( baseline_simulation=baseline_simulation, reform_simulation=reform_simulation, program_name=program_name, - variable_name=program_info.get("variable_name", program_name), entity=program_info["entity"], is_tax=program_info["is_tax"], ) diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index fd04e4fa..655e05d6 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -92,6 +92,7 @@ class PolicyEngineUSLatest(MicrosimulationModelVersion): "tax_unit_weight", "income_tax", "employee_payroll_tax", + "state_income_tax", "household_state_income_tax", "eitc", "ctc", diff --git a/tests/test_us_program_statistics.py b/tests/test_us_program_statistics.py index ffc39536..42b7662d 100644 --- a/tests/test_us_program_statistics.py +++ b/tests/test_us_program_statistics.py @@ -65,7 +65,7 @@ def _make_us_output_simulation(tmp_path, simulation_id: str, multiplier: float): "tax_unit_weight": [1.0, 2.0], "income_tax": [700.0 * multiplier, 0.0], "employee_payroll_tax": [0.0, 800.0 * multiplier], - "household_state_income_tax": [900.0 * multiplier, 0.0], + "state_income_tax": [900.0 * multiplier, 0.0], "eitc": [0.0, 1_000.0 * multiplier], "ctc": [1_100.0 * multiplier, 0.0], }, @@ -108,7 +108,6 @@ def test_us_program_statistics_config_runs_against_mocked_outputs(tmp_path): baseline_simulation=baseline, reform_simulation=reform, program_name=program_name, - variable_name=program_info.get("variable_name", program_name), entity=program_info["entity"], is_tax=program_info["is_tax"], ) @@ -118,7 +117,6 @@ def test_us_program_statistics_config_runs_against_mocked_outputs(tmp_path): assert set(results) == set(US_PROGRAMS) assert results["employee_payroll_tax"].baseline_total == 1_600.0 assert results["medicare_cost"].baseline_total == 300.0 - assert results["state_income_tax"].variable_name == "household_state_income_tax" assert results["state_income_tax"].baseline_total == 900.0 From 2726503929bcad65f7222136e33967bd2fe59dc0 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 15:18:19 +0200 Subject: [PATCH 4/8] Factor program statistics validation errors --- .../tax_benefit_models/us/analysis.py | 58 ++++++++++++++----- src/policyengine/utils/__init__.py | 2 + src/policyengine/utils/errors.py | 24 ++++++++ tests/test_errors.py | 19 ++++++ 4 files changed, 88 insertions(+), 15 deletions(-) create mode 100644 src/policyengine/utils/errors.py create mode 100644 tests/test_errors.py diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index e3142586..0e880399 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -26,6 +26,10 @@ Poverty, calculate_us_poverty_rates, ) +from policyengine.utils.errors import ( + create_error, + format_conditional_error_detail, +) US_PROGRAMS = { "income_tax": {"entity": "tax_unit", "is_tax": True}, @@ -53,6 +57,38 @@ class PolicyReformAnalysis(BaseModel): reform_inequality: Inequality +def _format_missing_program_variables(missing_variables: set[str]) -> str | None: + """Format the optional missing-variable detail for program statistics.""" + return format_conditional_error_detail( + "Missing model variables", + missing_variables, + ) + + +def _program_statistics_config_error_message( + missing_variables: set[str], + missing_outputs: set[tuple[str, str]], +) -> str: + lines = ["US program statistics config is invalid:"] + + missing_variables_message = _format_missing_program_variables(missing_variables) + if missing_variables_message is not None: + lines.append(missing_variables_message) + + if missing_outputs: + formatted = ", ".join( + f"{program_name} on {entity}" + for program_name, entity in sorted(missing_outputs) + ) + lines.append("Variables not materialized in simulation outputs: " + formatted) + lines.append( + "Add them to the model version's entity_variables or pass them " + "via Simulation.extra_variables before running the simulation." + ) + + return "\n".join(lines) + + def _validate_program_statistics_config( baseline_simulation: Simulation, reform_simulation: Simulation, @@ -78,21 +114,13 @@ def _validate_program_statistics_config( if not missing_variables and not missing_outputs: return - lines = ["US program statistics config is invalid:"] - if missing_variables: - lines.append("Missing model variables: " + ", ".join(sorted(missing_variables))) - if missing_outputs: - formatted = ", ".join( - f"{program_name} on {entity}" - for program_name, entity in sorted(missing_outputs) - ) - lines.append("Variables not materialized in simulation outputs: " + formatted) - lines.append( - "Add them to the model version's entity_variables or pass them " - "via Simulation.extra_variables before running the simulation." - ) - - raise ValueError("\n".join(lines)) + raise create_error( + ValueError, + _program_statistics_config_error_message( + missing_variables, + missing_outputs, + ), + ) def economic_impact_analysis( diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py index bfbfe10b..d4a9e072 100644 --- a/src/policyengine/utils/__init__.py +++ b/src/policyengine/utils/__init__.py @@ -1,5 +1,7 @@ from .dates import parse_safe_date as parse_safe_date from .design import COLORS as COLORS +from .errors import create_error as create_error +from .errors import format_conditional_error_detail as format_conditional_error_detail from .parameter_labels import build_scale_lookup as build_scale_lookup from .parameter_labels import ( generate_label_for_parameter as generate_label_for_parameter, diff --git a/src/policyengine/utils/errors.py b/src/policyengine/utils/errors.py new file mode 100644 index 00000000..22e92877 --- /dev/null +++ b/src/policyengine/utils/errors.py @@ -0,0 +1,24 @@ +"""Shared helpers for constructing consistent errors.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Optional, TypeVar + +ErrorT = TypeVar("ErrorT", bound=Exception) + + +def create_error(error_type: type[ErrorT], message: str) -> ErrorT: + """Create an exception instance from an error type and message.""" + return error_type(message) + + +def format_conditional_error_detail( + label: str, + values: Iterable[str], +) -> Optional[str]: + """Build a labelled error detail line when ``values`` is non-empty.""" + sorted_values = sorted(set(values)) + if not sorted_values: + return None + return f"{label}: {', '.join(sorted_values)}" diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 00000000..bfefb0d8 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,19 @@ +from policyengine.utils.errors import ( + create_error, + format_conditional_error_detail, +) + + +def test_create_error_returns_requested_error_type(): + error = create_error(ValueError, "Example failure") + + assert isinstance(error, ValueError) + assert str(error) == "Example failure" + + +def test_format_conditional_error_detail(): + assert ( + format_conditional_error_detail("Missing model variables", {"beta", "alpha"}) + == "Missing model variables: alpha, beta" + ) + assert format_conditional_error_detail("Missing model variables", set()) is None From 0f665f208ef23e4e4baa52627b3ce301aa0362cb Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 15:38:39 +0200 Subject: [PATCH 5/8] Update US household snapshots for program outputs --- .../us_married_two_kids_high_income.json | 5 +++++ .../us_single_adult_employment_income.json | 2 ++ .../us_single_adult_no_income.json | 2 ++ .../us_single_parent_one_child.json | 3 +++ 4 files changed, 12 insertions(+) diff --git a/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json b/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json index 1d5e98ca..0a8c2662 100644 --- a/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json +++ b/tests/fixtures/household_calculator_snapshots/us_married_two_kids_high_income.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 0.0, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -38,6 +39,7 @@ "person[1].is_male": 1.0, "person[1].marital_unit_id": 0.0, "person[1].medicaid": 0.0, + "person[1].medicare_cost": 14500.0, "person[1].person_id": 1.0, "person[1].person_weight": 1.0, "person[1].race": 3.0, @@ -55,6 +57,7 @@ "person[2].is_male": 1.0, "person[2].marital_unit_id": 0.0, "person[2].medicaid": 0.0, + "person[2].medicare_cost": 14500.0, "person[2].person_id": 2.0, "person[2].person_weight": 1.0, "person[2].race": 3.0, @@ -72,6 +75,7 @@ "person[3].is_male": 1.0, "person[3].marital_unit_id": 0.0, "person[3].medicaid": 0.0, + "person[3].medicare_cost": 14500.0, "person[3].person_id": 3.0, "person[3].person_weight": 1.0, "person[3].race": 3.0, @@ -92,6 +96,7 @@ "tax_unit.employee_payroll_tax": 21480.0, "tax_unit.household_state_income_tax": 12690.07, "tax_unit.income_tax": 30740.0, + "tax_unit.state_income_tax": 12690.07, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json b/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json index d94660a9..8284c6fc 100644 --- a/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json +++ b/tests/fixtures/household_calculator_snapshots/us_single_adult_employment_income.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 0.0, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -41,6 +42,7 @@ "tax_unit.employee_payroll_tax": 5370.0, "tax_unit.household_state_income_tax": 1602.86, "tax_unit.income_tax": 5020.0, + "tax_unit.state_income_tax": 1602.86, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json b/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json index 258db6f1..b77b54f4 100644 --- a/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json +++ b/tests/fixtures/household_calculator_snapshots/us_single_adult_no_income.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 6439.11, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -41,6 +42,7 @@ "tax_unit.employee_payroll_tax": 0.0, "tax_unit.household_state_income_tax": 0.0, "tax_unit.income_tax": 0.0, + "tax_unit.state_income_tax": 0.0, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } diff --git a/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json b/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json index 78ba7237..46504931 100644 --- a/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json +++ b/tests/fixtures/household_calculator_snapshots/us_single_parent_one_child.json @@ -21,6 +21,7 @@ "person[0].is_male": 1.0, "person[0].marital_unit_id": 0.0, "person[0].medicaid": 0.0, + "person[0].medicare_cost": 14500.0, "person[0].person_id": 0.0, "person[0].person_weight": 1.0, "person[0].race": 3.0, @@ -38,6 +39,7 @@ "person[1].is_male": 1.0, "person[1].marital_unit_id": 0.0, "person[1].medicaid": 3258.31, + "person[1].medicare_cost": 14500.0, "person[1].person_id": 1.0, "person[1].person_weight": 1.0, "person[1].race": 3.0, @@ -58,6 +60,7 @@ "tax_unit.employee_payroll_tax": 3580.0, "tax_unit.household_state_income_tax": 0.0, "tax_unit.income_tax": -2467.62, + "tax_unit.state_income_tax": 0.0, "tax_unit.tax_unit_id": 0.0, "tax_unit.tax_unit_weight": 1.0 } From 21c7197ae1fa905504c533cabb2b1050590bbbd8 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 15:51:42 +0200 Subject: [PATCH 6/8] Tighten aggregate error regression tests --- tests/test_change_aggregate.py | 113 ++++++++++++++++++++++++++++ tests/test_us_program_statistics.py | 12 +-- 2 files changed, 119 insertions(+), 6 deletions(-) diff --git a/tests/test_change_aggregate.py b/tests/test_change_aggregate.py index ea900db6..ec849154 100644 --- a/tests/test_change_aggregate.py +++ b/tests/test_change_aggregate.py @@ -2,6 +2,7 @@ import tempfile import pandas as pd +import pytest from microdf import MicroDataFrame from policyengine.core import ( @@ -18,6 +19,118 @@ ) +def _make_change_aggregate_simulations(tmp_path): + person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2], + "benunit_id": [1, 2], + "household_id": [1, 2], + "age": [30, 40], + "employment_income": [50000, 60000], + "person_weight": [1.0, 1.0], + } + ), + weights="person_weight", + ) + reform_person_df = MicroDataFrame( + pd.DataFrame( + { + "person_id": [1, 2], + "benunit_id": [1, 2], + "household_id": [1, 2], + "age": [30, 40], + "employment_income": [51000, 61000], + "person_weight": [1.0, 1.0], + } + ), + weights="person_weight", + ) + benunit_df = MicroDataFrame( + pd.DataFrame( + { + "benunit_id": [1, 2], + "benunit_weight": [1.0, 1.0], + } + ), + weights="benunit_weight", + ) + household_df = MicroDataFrame( + pd.DataFrame( + { + "household_id": [1, 2], + "household_weight": [1.0, 1.0], + } + ), + weights="household_weight", + ) + + baseline_dataset = PolicyEngineUKDataset( + name="Baseline", + description="Baseline dataset", + filepath=str(tmp_path / "baseline.h5"), + year=2024, + data=UKYearData(person=person_df, benunit=benunit_df, household=household_df), + ) + reform_dataset = PolicyEngineUKDataset( + name="Reform", + description="Reform dataset", + filepath=str(tmp_path / "reform.h5"), + year=2024, + data=UKYearData( + person=reform_person_df, + benunit=benunit_df, + household=household_df, + ), + ) + + baseline_sim = Simulation( + dataset=baseline_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=baseline_dataset, + ) + reform_sim = Simulation( + dataset=reform_dataset, + tax_benefit_model_version=uk_latest, + output_dataset=reform_dataset, + ) + return baseline_sim, reform_sim + + +def test_change_aggregate_invalid_variable(tmp_path): + baseline_sim, reform_sim = _make_change_aggregate_simulations(tmp_path) + + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="not_a_variable", + aggregate_type=ChangeAggregateType.COUNT, + ) + + with pytest.raises(ValueError, match="ChangeAggregate.variable") as exc_info: + agg.run() + + assert "not_a_variable" in str(exc_info.value) + + +def test_change_aggregate_invalid_filter_variable(tmp_path): + baseline_sim, reform_sim = _make_change_aggregate_simulations(tmp_path) + + agg = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="employment_income", + aggregate_type=ChangeAggregateType.COUNT, + filter_variable="not_a_filter_variable", + filter_variable_geq=0, + ) + + with pytest.raises(ValueError, match="ChangeAggregate.filter_variable") as exc_info: + agg.run() + + assert "not_a_filter_variable" in str(exc_info.value) + + def test_change_aggregate_count(): """Test counting people with any change.""" person_df = MicroDataFrame( diff --git a/tests/test_us_program_statistics.py b/tests/test_us_program_statistics.py index 42b7662d..2c5044f8 100644 --- a/tests/test_us_program_statistics.py +++ b/tests/test_us_program_statistics.py @@ -1,4 +1,5 @@ import pandas as pd +import pytest from microdf import MicroDataFrame from policyengine.core import Simulation @@ -137,10 +138,9 @@ def test_us_program_statistics_config_fails_before_simulation_run( entity_variables, ) - try: + with pytest.raises( + ValueError, match="US program statistics config is invalid" + ) as exc_info: _validate_program_statistics_config(baseline, reform) - except ValueError as exc: - assert "US program statistics config is invalid" in str(exc) - assert "medicare_cost" in str(exc) - else: - raise AssertionError("Expected ValueError") + + assert "medicare_cost" in str(exc_info.value) From 0c39e5a0c67e895445cad42583a220299fd34108 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 1 May 2026 15:59:19 +0200 Subject: [PATCH 7/8] Document program statistics count units --- src/policyengine/outputs/program_statistics.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/policyengine/outputs/program_statistics.py b/src/policyengine/outputs/program_statistics.py index a48ff8a8..ccb4f1e1 100644 --- a/src/policyengine/outputs/program_statistics.py +++ b/src/policyengine/outputs/program_statistics.py @@ -13,7 +13,12 @@ class ProgramStatistics(Output): - """Single program's statistics from a policy reform - represents one database row.""" + """Single program's statistics from a policy reform. + + Count fields are reported in the configured entity's units. For example, + a tax-unit variable reports tax-unit recipient/winner/loser counts, while + a person variable reports person counts. + """ model_config = ConfigDict(arbitrary_types_allowed=True) From ed1ab89e911dafb9714215b259b3a26010694d7f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 5 May 2026 20:55:01 +0200 Subject: [PATCH 8/8] Address program statistics review cleanup --- src/policyengine/outputs/aggregate.py | 16 ++++++++++++---- .../tax_benefit_models/us/analysis.py | 8 ++------ src/policyengine/utils/__init__.py | 1 - src/policyengine/utils/errors.py | 9 +-------- tests/test_aggregate.py | 8 ++++++-- tests/test_change_aggregate.py | 6 ++++-- tests/test_errors.py | 12 +----------- 7 files changed, 26 insertions(+), 34 deletions(-) diff --git a/src/policyengine/outputs/aggregate.py b/src/policyengine/outputs/aggregate.py index 9e4d3b86..3f314d5f 100644 --- a/src/policyengine/outputs/aggregate.py +++ b/src/policyengine/outputs/aggregate.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional -from policyengine.core import Output, Simulation +from policyengine.core import Output, Simulation, Variable class AggregateType(str, Enum): @@ -11,7 +11,11 @@ class AggregateType(str, Enum): COUNT = "count" -def get_aggregate_variable(simulation: Simulation, variable: str, context: str): +def get_aggregate_variable( + simulation: Simulation, + variable: str, + context: str, +) -> Variable: """Return a model variable with an aggregation-specific error message.""" model_version = simulation.tax_benefit_model_version try: @@ -31,7 +35,11 @@ def get_aggregate_variable(simulation: Simulation, variable: str, context: str): ) from exc -def get_output_entity_data(simulation: Simulation, entity: str, context: str): +def get_output_entity_data( + simulation: Simulation, + entity: str, + context: str, +) -> Any: """Return output data for an entity with a clear error if it is unavailable.""" if simulation.output_dataset is None or simulation.output_dataset.data is None: raise ValueError( @@ -49,7 +57,7 @@ def get_output_entity_data(simulation: Simulation, entity: str, context: str): def require_output_column( - data, + data: Any, variable: str, entity: str, simulation: Simulation, diff --git a/src/policyengine/tax_benefit_models/us/analysis.py b/src/policyengine/tax_benefit_models/us/analysis.py index 0e880399..7bc1cd52 100644 --- a/src/policyengine/tax_benefit_models/us/analysis.py +++ b/src/policyengine/tax_benefit_models/us/analysis.py @@ -26,10 +26,7 @@ Poverty, calculate_us_poverty_rates, ) -from policyengine.utils.errors import ( - create_error, - format_conditional_error_detail, -) +from policyengine.utils.errors import format_conditional_error_detail US_PROGRAMS = { "income_tax": {"entity": "tax_unit", "is_tax": True}, @@ -114,8 +111,7 @@ def _validate_program_statistics_config( if not missing_variables and not missing_outputs: return - raise create_error( - ValueError, + raise ValueError( _program_statistics_config_error_message( missing_variables, missing_outputs, diff --git a/src/policyengine/utils/__init__.py b/src/policyengine/utils/__init__.py index d4a9e072..8cee3ff2 100644 --- a/src/policyengine/utils/__init__.py +++ b/src/policyengine/utils/__init__.py @@ -1,6 +1,5 @@ from .dates import parse_safe_date as parse_safe_date from .design import COLORS as COLORS -from .errors import create_error as create_error from .errors import format_conditional_error_detail as format_conditional_error_detail from .parameter_labels import build_scale_lookup as build_scale_lookup from .parameter_labels import ( diff --git a/src/policyengine/utils/errors.py b/src/policyengine/utils/errors.py index 22e92877..34213b59 100644 --- a/src/policyengine/utils/errors.py +++ b/src/policyengine/utils/errors.py @@ -3,14 +3,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Optional, TypeVar - -ErrorT = TypeVar("ErrorT", bound=Exception) - - -def create_error(error_type: type[ErrorT], message: str) -> ErrorT: - """Create an exception instance from an error type and message.""" - return error_type(message) +from typing import Optional def format_conditional_error_detail( diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 56abb196..28c29928 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -478,8 +478,10 @@ def test_aggregate_invalid_variable(): variable="nonexistent_variable", aggregate_type=AggregateType.SUM, ) - with pytest.raises(ValueError, match="Aggregate.variable"): + with pytest.raises(ValueError) as exc_info: agg.run() + assert "nonexistent_variable" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) # Invalid filter variable name should raise error on run() agg = Aggregate( @@ -488,5 +490,7 @@ def test_aggregate_invalid_variable(): aggregate_type=AggregateType.SUM, filter_variable="nonexistent_filter", ) - with pytest.raises(ValueError, match="Aggregate.filter_variable"): + with pytest.raises(ValueError) as exc_info: agg.run() + assert "nonexistent_filter" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) diff --git a/tests/test_change_aggregate.py b/tests/test_change_aggregate.py index ec849154..0728b880 100644 --- a/tests/test_change_aggregate.py +++ b/tests/test_change_aggregate.py @@ -107,10 +107,11 @@ def test_change_aggregate_invalid_variable(tmp_path): aggregate_type=ChangeAggregateType.COUNT, ) - with pytest.raises(ValueError, match="ChangeAggregate.variable") as exc_info: + with pytest.raises(ValueError) as exc_info: agg.run() assert "not_a_variable" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) def test_change_aggregate_invalid_filter_variable(tmp_path): @@ -125,10 +126,11 @@ def test_change_aggregate_invalid_filter_variable(tmp_path): filter_variable_geq=0, ) - with pytest.raises(ValueError, match="ChangeAggregate.filter_variable") as exc_info: + with pytest.raises(ValueError) as exc_info: agg.run() assert "not_a_filter_variable" in str(exc_info.value) + assert "references missing variable" in str(exc_info.value) def test_change_aggregate_count(): diff --git a/tests/test_errors.py b/tests/test_errors.py index bfefb0d8..81803b40 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,14 +1,4 @@ -from policyengine.utils.errors import ( - create_error, - format_conditional_error_detail, -) - - -def test_create_error_returns_requested_error_type(): - error = create_error(ValueError, "Example failure") - - assert isinstance(error, ValueError) - assert str(error) == "Example failure" +from policyengine.utils.errors import format_conditional_error_detail def test_format_conditional_error_detail():