Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 111 additions & 11 deletions src/policyengine/outputs/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from difflib import get_close_matches
from enum import Enum
from typing import Any, Optional

from policyengine.core import Output, Simulation
from policyengine.core import Output, Simulation, Variable


class AggregateType(str, Enum):
Expand All @@ -10,6 +11,72 @@ class AggregateType(str, Enum):
COUNT = "count"


def get_aggregate_variable(
    simulation: Simulation,
    variable: str,
    context: str,
) -> Variable:
    """Look up ``variable`` on the simulation's model version.

    Args:
        simulation: Simulation whose ``tax_benefit_model_version`` is queried.
        variable: Name of the variable to resolve.
        context: Label prefixed to the error message (the callers pass the
            name of the field being resolved, e.g. ``"Aggregate.variable"``).

    Returns:
        Whatever ``get_variable`` returns for ``variable``.

    Raises:
        ValueError: If ``get_variable`` raises ``ValueError``. The new error
            names the model id and version and, when close name matches
            exist, lists up to three of them. The original error is chained
            as the cause.
    """
    version_info = simulation.tax_benefit_model_version
    try:
        resolved = version_info.get_variable(variable)
    except ValueError as lookup_error:
        # Sort the known names so the candidate list fed to difflib has a
        # stable order across runs.
        known_names = sorted(version_info.variables_by_name)
        near_misses = get_close_matches(variable, known_names, n=3, cutoff=0.65)

        hint = ""
        if near_misses:
            hint = f" Did you mean: {', '.join(map(repr, near_misses))}?"

        raise ValueError(
            f"{context} references missing variable '{variable}' in "
            f"{version_info.model.id} version {version_info.version}."
            f"{hint}"
        ) from lookup_error
    return resolved


def get_output_entity_data(
    simulation: Simulation,
    entity: str,
    context: str,
) -> Any:
    """Return output data for an entity with a clear error if it is unavailable.

    Args:
        simulation: Simulation whose ``output_dataset.data`` is read.
        entity: Attribute name of the entity table on the output data.
        context: Label prefixed to the error message (the callers pass the
            name of the field being resolved, e.g. ``"Aggregate.entity"``).

    Returns:
        The attribute named ``entity`` on ``simulation.output_dataset.data``.

    Raises:
        ValueError: If the simulation has no output dataset (or the dataset
            has no data), if the output data has no attribute named
            ``entity``, or if that attribute is ``None``.
    """
    output_dataset = simulation.output_dataset
    if output_dataset is None or output_dataset.data is None:
        raise ValueError(
            f"{context} requires simulation '{simulation.id}' to have an "
            "output dataset before aggregation."
        )

    missing_entity_message = (
        f"{context} references entity '{entity}', but simulation "
        f"'{simulation.id}' has no output data for that entity."
    )
    try:
        entity_data = getattr(output_dataset.data, entity)
    except AttributeError as exc:
        raise ValueError(missing_entity_message) from exc

    # An attribute that exists but holds None is just as unusable as a
    # missing one: report it here rather than letting the caller fail later
    # with an opaque AttributeError on ``None.columns``.
    if entity_data is None:
        raise ValueError(missing_entity_message)
    return entity_data


def require_output_column(
    data: Any,
    variable: str,
    entity: str,
    simulation: Simulation,
    context: str,
) -> None:
    """Ensure ``variable`` is one of ``data.columns``, else raise.

    Args:
        data: Entity output data; only its ``columns`` attribute is consulted.
        variable: Column name that must be present.
        entity: Entity name, used only in the error message.
        simulation: Simulation the data came from; supplies the id and the
            model version details quoted in the error message.
        context: Label prefixed to the error message.

    Raises:
        ValueError: If ``variable`` is not in ``data.columns``. The message
            tells the user to add the variable to the model version's
            ``entity_variables`` or to pass it via
            ``Simulation.extra_variables``.
    """
    if variable not in data.columns:
        version_info = simulation.tax_benefit_model_version
        message = (
            f"{context} variable '{variable}' exists in {version_info.model.id} "
            f"version {version_info.version}, but is not present in simulation "
            f"'{simulation.id}' output data for entity '{entity}'. Add '{variable}' "
            f"to {version_info.__class__.__name__}.entity_variables or pass it via "
            "Simulation.extra_variables before running the simulation."
        )
        raise ValueError(message)


class Aggregate(Output):
simulation: Simulation
variable: str
Expand Down Expand Up @@ -47,42 +114,75 @@ def run(self):
elif self.quantile_geq is not None:
self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile

# Get variable object
var_obj = next(
v
for v in self.simulation.tax_benefit_model_version.variables
if v.name == self.variable
var_obj = get_aggregate_variable(
self.simulation, self.variable, "Aggregate.variable"
)

# Get the target entity data
target_entity = self.entity or var_obj.entity
data = getattr(self.simulation.output_dataset.data, target_entity)
data = get_output_entity_data(
self.simulation, target_entity, "Aggregate.entity"
)

# Map variable to target entity if needed
if var_obj.entity != target_entity:
source_data = get_output_entity_data(
self.simulation, var_obj.entity, "Aggregate.variable"
)
require_output_column(
source_data,
self.variable,
var_obj.entity,
self.simulation,
"Aggregate.variable",
)
mapped = self.simulation.output_dataset.data.map_to_entity(
var_obj.entity, target_entity, columns=[self.variable]
)
series = mapped[self.variable]
else:
require_output_column(
data,
self.variable,
target_entity,
self.simulation,
"Aggregate.variable",
)
series = data[self.variable]

# Apply filters
if self.filter_variable is not None:
filter_var_obj = next(
v
for v in self.simulation.tax_benefit_model_version.variables
if v.name == self.filter_variable
filter_var_obj = get_aggregate_variable(
self.simulation, self.filter_variable, "Aggregate.filter_variable"
)

if filter_var_obj.entity != target_entity:
filter_source_data = get_output_entity_data(
self.simulation,
filter_var_obj.entity,
"Aggregate.filter_variable",
)
require_output_column(
filter_source_data,
self.filter_variable,
filter_var_obj.entity,
self.simulation,
"Aggregate.filter_variable",
)
filter_mapped = self.simulation.output_dataset.data.map_to_entity(
filter_var_obj.entity,
target_entity,
columns=[self.filter_variable],
)
filter_series = filter_mapped[self.filter_variable]
else:
require_output_column(
data,
self.filter_variable,
target_entity,
self.simulation,
"Aggregate.filter_variable",
)
filter_series = data[self.filter_variable]

if self.filter_variable_describes_quantiles:
Expand Down
97 changes: 82 additions & 15 deletions src/policyengine/outputs/change_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from typing import Any, Optional

from policyengine.core import Output, Simulation
from policyengine.outputs.aggregate import (
get_aggregate_variable,
get_output_entity_data,
require_output_column,
)


class ChangeAggregateType(str, Enum):
Expand Down Expand Up @@ -59,34 +64,75 @@ def run(self):
elif self.quantile_geq is not None:
self.filter_variable_geq = (self.quantile_geq - 1) / self.quantile

# Get variable object
var_obj = next(
v
for v in self.baseline_simulation.tax_benefit_model_version.variables
if v.name == self.variable
var_obj = get_aggregate_variable(
self.baseline_simulation, self.variable, "ChangeAggregate.variable"
)

# Get the target entity data
target_entity = self.entity or var_obj.entity
baseline_data = getattr(
self.baseline_simulation.output_dataset.data, target_entity
baseline_data = get_output_entity_data(
self.baseline_simulation,
target_entity,
"ChangeAggregate.baseline_entity",
)
reform_data = get_output_entity_data(
self.reform_simulation,
target_entity,
"ChangeAggregate.reform_entity",
)
reform_data = getattr(self.reform_simulation.output_dataset.data, target_entity)

# Map variable to target entity if needed
if var_obj.entity != target_entity:
baseline_source_data = get_output_entity_data(
self.baseline_simulation,
var_obj.entity,
"ChangeAggregate.variable",
)
reform_source_data = get_output_entity_data(
self.reform_simulation,
var_obj.entity,
"ChangeAggregate.variable",
)
require_output_column(
baseline_source_data,
self.variable,
var_obj.entity,
self.baseline_simulation,
"ChangeAggregate.variable",
)
require_output_column(
reform_source_data,
self.variable,
var_obj.entity,
self.reform_simulation,
"ChangeAggregate.variable",
)
baseline_mapped = (
self.baseline_simulation.output_dataset.data.map_to_entity(
var_obj.entity, target_entity
var_obj.entity, target_entity, columns=[self.variable]
)
)
baseline_series = baseline_mapped[self.variable]

reform_mapped = self.reform_simulation.output_dataset.data.map_to_entity(
var_obj.entity, target_entity
var_obj.entity, target_entity, columns=[self.variable]
)
reform_series = reform_mapped[self.variable]
else:
require_output_column(
baseline_data,
self.variable,
target_entity,
self.baseline_simulation,
"ChangeAggregate.variable",
)
require_output_column(
reform_data,
self.variable,
target_entity,
self.reform_simulation,
"ChangeAggregate.variable",
)
baseline_series = baseline_data[self.variable]
reform_series = reform_data[self.variable]

Expand Down Expand Up @@ -124,20 +170,41 @@ def run(self):

# Apply filter_variable filters
if self.filter_variable is not None:
filter_var_obj = next(
v
for v in self.baseline_simulation.tax_benefit_model_version.variables
if v.name == self.filter_variable
filter_var_obj = get_aggregate_variable(
self.baseline_simulation,
self.filter_variable,
"ChangeAggregate.filter_variable",
)

if filter_var_obj.entity != target_entity:
filter_source_data = get_output_entity_data(
self.baseline_simulation,
filter_var_obj.entity,
"ChangeAggregate.filter_variable",
)
require_output_column(
filter_source_data,
self.filter_variable,
filter_var_obj.entity,
self.baseline_simulation,
"ChangeAggregate.filter_variable",
)
filter_mapped = (
self.baseline_simulation.output_dataset.data.map_to_entity(
filter_var_obj.entity, target_entity
filter_var_obj.entity,
target_entity,
columns=[self.filter_variable],
)
)
filter_series = filter_mapped[self.filter_variable]
else:
require_output_column(
baseline_data,
self.filter_variable,
target_entity,
self.baseline_simulation,
"ChangeAggregate.filter_variable",
)
filter_series = baseline_data[self.filter_variable]

if self.filter_variable_describes_quantiles:
Expand Down
7 changes: 6 additions & 1 deletion src/policyengine/outputs/program_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@


class ProgramStatistics(Output):
"""Single program's statistics from a policy reform - represents one database row."""
"""Single program's statistics from a policy reform.

Count fields are reported in the configured entity's units. For example,
a tax-unit variable reports tax-unit recipient/winner/loser counts, while
a person variable reports person counts.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down
Loading