diff --git a/.github/scripts/dispatch_publication_pipeline.sh b/.github/scripts/dispatch_publication_pipeline.sh new file mode 100644 index 000000000..cb9b326ce --- /dev/null +++ b/.github/scripts/dispatch_publication_pipeline.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -euo pipefail + +workflow_file="${PIPELINE_WORKFLOW_FILE:-pipeline.yaml}" +workflow_ref="${PIPELINE_WORKFLOW_REF:-main}" + +if [[ -z "${RUN_ID:-}" ]]; then + echo "RUN_ID is required" >&2 + exit 1 +fi + +if [[ -z "${SOURCE_SHA:-}" ]]; then + echo "SOURCE_SHA is required" >&2 + exit 1 +fi + +gh workflow run "${workflow_file}" \ + --ref "${workflow_ref}" \ + -f run_id="${RUN_ID}" \ + -f source_sha="${SOURCE_SHA}" + +if [[ -n "${GITHUB_STEP_SUMMARY:-}" ]]; then + { + echo "## Pipeline Dispatched" + echo + echo "| Field | Value |" + echo "|-------|-------|" + echo "| Run ID | \`${RUN_ID}\` |" + echo "| Source SHA | \`${SOURCE_SHA}\` |" + echo "| Workflow | \`${workflow_file}\` |" + echo "| Workflow ref | \`${workflow_ref}\` |" + } >> "${GITHUB_STEP_SUMMARY}" +fi diff --git a/.github/scripts/resolve_run_context.py b/.github/scripts/resolve_run_context.py new file mode 100644 index 000000000..0fec76a65 --- /dev/null +++ b/.github/scripts/resolve_run_context.py @@ -0,0 +1,86 @@ +"""Resolve run context for GitHub Actions workflows.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from policyengine_us_data.utils.run_context import ( # noqa: E402 + DEFAULT_MODAL_APP_PREFIX, + RunContext, + build_modal_resource_name, +) + + +def _append_key_values(path_env: str, values: dict[str, str]) -> None: + output_path = os.environ.get(path_env) + if not output_path: + return + with Path(output_path).open("a") as handle: + for key, value in values.items(): + handle.write(f"{key}={value}\n") + + +def main() -> None: + app_prefix = 
os.environ.get("US_DATA_MODAL_APP_PREFIX", DEFAULT_MODAL_APP_PREFIX) + context = RunContext.from_env(modal_app_prefix=app_prefix) + if not context.run_id: + raise RuntimeError( + "Could not resolve run ID. Set US_DATA_RUN_ID or run " + "inside GitHub Actions with GITHUB_RUN_ID." + ) + + pipeline_volume_name = os.environ.get( + "US_DATA_PIPELINE_VOLUME_NAME", + build_modal_resource_name( + context.run_id, + prefix="pipeline-artifacts", + ), + ) + staging_volume_name = os.environ.get( + "US_DATA_STAGING_VOLUME_NAME", + build_modal_resource_name( + context.run_id, + prefix="local-area-staging", + ), + ) + checkpoint_volume_name = os.environ.get( + "US_DATA_CHECKPOINT_VOLUME_NAME", + build_modal_resource_name( + context.run_id, + prefix="data-build-checkpoints", + ), + ) + context = RunContext.from_mapping( + { + **context.to_dict(), + "pipeline_volume_name": pipeline_volume_name, + "staging_volume_name": staging_volume_name, + "checkpoint_volume_name": checkpoint_volume_name, + }, + modal_app_name=context.modal_app_name, + modal_environment=context.modal_environment, + ) + + outputs = { + "run_id": context.run_id, + "modal_app_name": context.modal_app_name, + "modal_environment": context.modal_environment, + "hf_staging_prefix": context.hf_staging_prefix, + "github_run_url": context.github_run_url, + "pipeline_volume_name": context.pipeline_volume_name, + "staging_volume_name": context.staging_volume_name, + "checkpoint_volume_name": context.checkpoint_volume_name, + } + _append_key_values("GITHUB_OUTPUT", outputs) + _append_key_values("GITHUB_ENV", context.export_env()) + print(context.to_json()) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/spawn_modal_pipeline.py b/.github/scripts/spawn_modal_pipeline.py index c29c6b4cf..4cbb0033d 100644 --- a/.github/scripts/spawn_modal_pipeline.py +++ b/.github/scripts/spawn_modal_pipeline.py @@ -1,14 +1,21 @@ import os +import sys from pathlib import Path import modal +_REPO_ROOT = 
Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from policyengine_us_data.utils.run_context import RunContext # noqa: E402 + def _as_bool(value: str) -> bool: return value.lower() == "true" -def _append_summary(function_call_id: str) -> None: +def _append_summary(function_call_id: str, context: RunContext) -> None: summary_path = os.environ.get("GITHUB_STEP_SUMMARY") if not summary_path: return @@ -23,13 +30,20 @@ def _append_summary(function_call_id: str) -> None: f"`{os.environ['EPOCHS']}` / " f"`{os.environ['NATIONAL_EPOCHS']}` |\n" ) + handle.write(f"| Run ID | `{context.run_id}` |\n") + handle.write(f"| Modal app | `{context.modal_app_name}` |\n") + handle.write(f"| Modal environment | `{context.modal_environment}` |\n") + handle.write(f"| HF staging | `{context.hf_staging_prefix}` |\n") + if os.environ.get("SOURCE_SHA"): + handle.write(f"| Source SHA | `{os.environ['SOURCE_SHA']}` |\n") handle.write(f"| Function call ID | `{function_call_id}` |\n\n") handle.write("**[Monitor on Modal Dashboard](https://modal.com/apps)**\n") def main() -> None: - app_name = os.environ.get("MODAL_APP_NAME", "policyengine-us-data-pipeline") - environment_name = os.environ.get("MODAL_ENVIRONMENT") + context = RunContext.from_env() + app_name = context.modal_app_name or "policyengine-us-data-pipeline" + environment_name = context.modal_environment or os.environ.get("MODAL_ENVIRONMENT") kwargs = { "branch": os.environ.get("PIPELINE_BRANCH", "main"), "gpu": os.environ["GPU"], @@ -39,6 +53,11 @@ def main() -> None: "skip_national": _as_bool(os.environ["SKIP_NATIONAL"]), "resume_run_id": os.environ.get("RESUME_RUN_ID") or None, "version_override": os.environ.get("VERSION_OVERRIDE", ""), + "sha_override": os.environ.get("SOURCE_SHA", ""), + "run_id": context.run_id, + "run_context": context.to_dict(), + "modal_app_name": context.modal_app_name, + "modal_environment": context.modal_environment, } if environment_name: 
run_pipeline = modal.Function.from_name( @@ -50,8 +69,14 @@ def main() -> None: run_pipeline = modal.Function.from_name(app_name, "run_pipeline") function_call = run_pipeline.spawn(**kwargs) print("Pipeline spawned.") + print(f"Run ID: {context.run_id}") + print(f"Modal app: {app_name}") + print(f"Modal environment: {environment_name}") + print(f"HF staging prefix: {context.hf_staging_prefix}") + if os.environ.get("SOURCE_SHA"): + print(f"Source SHA: {os.environ['SOURCE_SHA']}") print(f"Function call ID: {function_call.object_id}") - _append_summary(function_call.object_id) + _append_summary(function_call.object_id, context) if __name__ == "__main__": diff --git a/.github/workflows/pipeline.yaml b/.github/workflows/pipeline.yaml index 3ef2f57c8..3c8efef75 100644 --- a/.github/workflows/pipeline.yaml +++ b/.github/workflows/pipeline.yaml @@ -1,8 +1,6 @@ name: Run Pipeline on: - push: - branches: [main] workflow_dispatch: inputs: gpu: @@ -33,19 +31,30 @@ on: description: "Override version (default: read from pyproject.toml)" default: "" type: string + run_id: + description: "Run ID to use across GitHub, Modal, and HF staging" + default: "" + type: string + source_sha: + description: "Exact policyengine-us-data commit SHA to deploy" + default: "" + type: string concurrency: - group: pipeline-main + group: pipeline-${{ github.run_id }}-${{ github.run_attempt }} cancel-in-progress: false jobs: pipeline: runs-on: ubuntu-latest - if: >- - github.event_name == 'workflow_dispatch' || - github.event.head_commit.message == 'Update package version' + env: + MODAL_ENVIRONMENT: main + US_DATA_MODAL_APP_PREFIX: policyengine-us-data-pub + US_DATA_RUN_ID: ${{ inputs.run_id || '' }} steps: - uses: actions/checkout@v4 + with: + ref: ${{ inputs.source_sha || github.sha }} - uses: actions/setup-python@v5 with: @@ -54,6 +63,10 @@ jobs: - name: Install Modal Runner Deps run: pip install modal pandas + - name: Resolve run context + id: run-context + run: python 
.github/scripts/resolve_run_context.py + - name: Deploy and launch pipeline on Modal env: MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} @@ -66,6 +79,7 @@ jobs: SKIP_NATIONAL: ${{ inputs.skip_national || 'false' }} RESUME_RUN_ID: ${{ inputs.resume_run_id || '' }} VERSION_OVERRIDE: ${{ inputs.version_override || '' }} + SOURCE_SHA: ${{ inputs.source_sha || github.sha }} run: | - modal deploy modal_app/pipeline.py + modal deploy --env="${MODAL_ENVIRONMENT}" --name="${MODAL_APP_NAME}" --tag="${RUN_ID}" modal_app/pipeline.py python .github/scripts/spawn_modal_pipeline.py diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml index 005a6eaad..663405e21 100644 --- a/.github/workflows/push.yaml +++ b/.github/workflows/push.yaml @@ -12,28 +12,56 @@ jobs: - run: pip install ruff>=0.9.0 - run: ruff format --check . + run-context: + name: Run context + runs-on: ubuntu-latest + if: github.event.head_commit.message != 'Update package version' + outputs: + run_id: ${{ steps.run-context.outputs.run_id }} + github_run_url: ${{ steps.run-context.outputs.github_run_url }} + env: + MODAL_ENVIRONMENT: main + US_DATA_MODAL_APP_PREFIX: policyengine-us-data-pub + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.14" + - name: Resolve run context + id: run-context + run: python .github/scripts/resolve_run_context.py + # ── Dataset build ─────────────────────────────────────────── build-datasets: name: Build datasets runs-on: ubuntu-latest - needs: lint + needs: + - lint + - run-context if: github.event.head_commit.message != 'Update package version' env: MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} MODAL_ENVIRONMENT: main HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + US_DATA_RUN_ID: ${{ needs.run-context.outputs.run_id }} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.14" - run: pip install modal + - name: Resolve 
run context + id: run-context + env: + US_DATA_MODAL_APP_PREFIX: policyengine-us-data-build + run: python .github/scripts/resolve_run_context.py - name: Build datasets on Modal run: | modal run --env="${MODAL_ENVIRONMENT}" modal_app/data_build.py \ --upload \ - --branch=${{ github.ref_name }} + --branch=${{ github.ref_name }} \ + --run-id="${RUN_ID}" # ── Documentation ────────────────────────────────────────── docs: @@ -67,7 +95,10 @@ jobs: versioning: name: Versioning runs-on: ubuntu-latest + needs: run-context if: github.event.head_commit.message != 'Update package version' + outputs: + version_sha: ${{ steps.version-commit.outputs.sha }} steps: - name: Generate GitHub App token id: app-token @@ -95,6 +126,29 @@ jobs: with: add: "." message: Update package version + - name: Capture version commit + id: version-commit + run: echo "sha=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT" + + # ── Full publication pipeline ─────────────────────────────── + launch-pipeline: + name: Launch publication pipeline + runs-on: ubuntu-latest + needs: + - run-context + - build-datasets + - versioning + if: github.event.head_commit.message != 'Update package version' + permissions: + actions: write + contents: read + steps: + - name: Dispatch pipeline workflow + env: + GH_TOKEN: ${{ github.token }} + RUN_ID: ${{ needs.run-context.outputs.run_id }} + SOURCE_SHA: ${{ needs.versioning.outputs.version_sha }} + run: bash .github/scripts/dispatch_publication_pipeline.sh # ── PyPI publish (version bump commits only) ──────────────── publish: diff --git a/changelog.d/phase-3c-publication-context.changed.md b/changelog.d/phase-3c-publication-context.changed.md new file mode 100644 index 000000000..debd410e9 --- /dev/null +++ b/changelog.d/phase-3c-publication-context.changed.md @@ -0,0 +1 @@ +Add run-scoped publication identity for GitHub, Modal, and Hugging Face staging. 
diff --git a/modal_app/data_build.py b/modal_app/data_build.py index 8b125dbc9..cd5ae9e52 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -1,13 +1,15 @@ import functools +import json import os import shutil import subprocess import sys import threading from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path -from typing import IO, Optional +from typing import IO, Any, Optional import modal @@ -18,21 +20,28 @@ sys.path.insert(0, _p) from modal_app.images import cpu_image as image # noqa: E402 +from policyengine_us_data.utils.run_context import ( # noqa: E402 + resolve_run_id, +) -app = modal.App("policyengine-us-data") +app = modal.App( + os.environ.get("US_DATA_DATA_BUILD_APP_NAME") + or os.environ.get("US_DATA_MODAL_APP_NAME") + or "policyengine-us-data" +) hf_secret = modal.Secret.from_name("huggingface-token") gcp_secret = modal.Secret.from_name("gcp-credentials") # Create persistent volume for checkpoints checkpoint_volume = modal.Volume.from_name( - "data-build-checkpoints", + os.environ.get("US_DATA_CHECKPOINT_VOLUME_NAME", "data-build-checkpoints"), create_if_missing=True, ) # Shared pipeline volume for inter-step artifact transport pipeline_volume = modal.Volume.from_name( - "pipeline-artifacts", + os.environ.get("US_DATA_PIPELINE_VOLUME_NAME", "pipeline-artifacts"), create_if_missing=True, ) PIPELINE_MOUNT = "/pipeline" @@ -40,6 +49,39 @@ VOLUME_MOUNT = "/checkpoints" _volume_lock = threading.Lock() + +@dataclass +class CheckpointStats: + expected_outputs: int = 0 + valid_reused_outputs: int = 0 + recomputed_outputs: int = 0 + invalid_outputs: int = 0 + _lock: Any = field(default_factory=threading.Lock, init=False, repr=False) + + def record( + self, + *, + expected_outputs: int, + valid_reused_outputs: int = 0, + recomputed_outputs: int = 0, + invalid_outputs: int = 0, + ) -> None: + with self._lock: + 
self.expected_outputs += expected_outputs + self.valid_reused_outputs += valid_reused_outputs + self.recomputed_outputs += recomputed_outputs + self.invalid_outputs += invalid_outputs + + def snapshot(self) -> dict[str, int]: + with self._lock: + return { + "expected_outputs": self.expected_outputs, + "valid_reused_outputs": self.valid_reused_outputs, + "recomputed_outputs": self.recomputed_outputs, + "invalid_outputs": self.invalid_outputs, + } + + # Script to output file mapping for checkpointing # Values can be a single file path (str) or a list of file paths SCRIPT_OUTPUTS = { @@ -291,6 +333,7 @@ def run_script_with_checkpoint( args: Optional[list] = None, env: Optional[dict] = None, log_file: IO = None, + checkpoint_stats: CheckpointStats | None = None, ) -> str: """Run script if output not checkpointed, then checkpoint result. @@ -309,6 +352,7 @@ def run_script_with_checkpoint( # Normalize to list if isinstance(output_files, str): output_files = [output_files] + expected_count = len(output_files) # Check if ALL outputs are checkpointed all_checkpointed = all(is_checkpointed(branch, f) for f in output_files) @@ -318,14 +362,29 @@ def run_script_with_checkpoint( for output_file in output_files: restore_from_checkpoint(branch, output_file) print(f"Skipping {script_path} (restored from checkpoint)") + if checkpoint_stats is not None: + checkpoint_stats.record( + expected_outputs=expected_count, + valid_reused_outputs=expected_count, + ) return script_path + missing_or_invalid = sum( + 1 for output_file in output_files if not is_checkpointed(branch, output_file) + ) + # Run the script run_script(script_path, args=args, env=env, log_file=log_file) # Checkpoint all outputs for output_file in output_files: save_checkpoint(branch, output_file, volume) + if checkpoint_stats is not None: + checkpoint_stats.record( + expected_outputs=expected_count, + recomputed_outputs=expected_count, + invalid_outputs=missing_or_invalid, + ) return script_path @@ -336,6 +395,7 @@ def 
run_cps_then_puf_phase( *, env: dict, log_file: IO = None, + checkpoint_stats: CheckpointStats | None = None, ) -> None: """Build CPS before PUF because PUF pension imputation loads CPS_2024.""" for script in (CPS_BUILD_SCRIPT, PUF_BUILD_SCRIPT): @@ -346,6 +406,7 @@ def run_cps_then_puf_phase( volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) @@ -432,6 +493,14 @@ def build_datasets( stage_only: Upload to HF staging only, without promoting a release. """ setup_gcp_credentials() + checkpoint_stats = CheckpointStats() + run_id = run_id or resolve_run_id() + if not run_id: + raise RuntimeError( + "run_id is required. Production data builds must receive the " + "GitHub-created run ID via --run-id or US_DATA_RUN_ID." + ) + os.environ["US_DATA_RUN_ID"] = run_id # Reload volume to see latest checkpoints checkpoint_volume.reload() @@ -506,6 +575,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) else: # Parallel execution based on dependency groups with checkpointing @@ -535,6 +605,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ): script for script, output in group1 } @@ -550,6 +621,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) # SEQUENTIAL: Extended CPS (needs both cps and puf) @@ -561,6 +633,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) # GROUP 3: After extended_cps - run in parallel @@ -580,6 +653,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) ) else: @@ -595,6 +669,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) ) for future in as_completed(phase4_futures): @@ -620,6 +695,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + 
checkpoint_stats=checkpoint_stats, ) ) if not skip_enhanced_cps: @@ -634,6 +710,7 @@ def build_datasets( checkpoint_volume, env=env, log_file=log_file, + checkpoint_stats=checkpoint_stats, ) ) else: @@ -681,6 +758,8 @@ def build_datasets( ) print(" Copied calibration_weights.npy") shutil.copy2(log_path, artifacts_dir / "build_log.txt") + with open(artifacts_dir / "data_build_checkpoint_stats.json", "w") as f: + json.dump(checkpoint_stats.snapshot(), f, indent=2, sort_keys=True) log_file.close() pipeline_volume.commit() print("Pipeline artifacts committed to shared volume") @@ -717,6 +796,11 @@ def main( stage_only: bool = False, run_id: str = "", ): + run_id = run_id or resolve_run_id() + if not run_id: + raise RuntimeError( + "run_id is required. Pass --run-id or run inside GitHub Actions." + ) result = build_datasets.remote( upload=upload, branch=branch, diff --git a/modal_app/h5_test_harness.py b/modal_app/h5_test_harness.py index 2c2cb9ed5..258f28cfc 100644 --- a/modal_app/h5_test_harness.py +++ b/modal_app/h5_test_harness.py @@ -4,6 +4,7 @@ import hashlib import json +import os import shutil import sys from pathlib import Path @@ -20,7 +21,10 @@ from modal_app.local_area import VOLUME_MOUNT, pipeline_volume, staging_volume # noqa: E402 -app = modal.App("policyengine-us-data-h5-test-harness") +app = modal.App( + os.environ.get("US_DATA_H5_HARNESS_APP_NAME") + or "policyengine-us-data-h5-test-harness" +) def _sha256(path: Path) -> str: diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 831231307..88855ac17 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -36,19 +36,22 @@ from policyengine_us_data.calibration.local_h5.partitioning import ( # noqa: E402 partition_weighted_work_items, ) +from policyengine_us_data.utils.run_context import resolve_run_id # noqa: E402 -app = modal.App("policyengine-us-data-local-area") +app = modal.App( + os.environ.get("US_DATA_LOCAL_AREA_APP_NAME") or "policyengine-us-data-local-area" +) 
hf_secret = modal.Secret.from_name("huggingface-token") gcp_secret = modal.Secret.from_name("gcp-credentials") staging_volume = modal.Volume.from_name( - "local-area-staging", + os.environ.get("US_DATA_STAGING_VOLUME_NAME", "local-area-staging"), create_if_missing=True, ) pipeline_volume = modal.Volume.from_name( - "pipeline-artifacts", + os.environ.get("US_DATA_PIPELINE_VOLUME_NAME", "pipeline-artifacts"), create_if_missing=True, ) @@ -830,11 +833,12 @@ def coordinate_publish( version = get_version() + run_id = run_id or resolve_run_id() if not run_id: - from policyengine_us_data.utils.run_id import generate_run_id - - sha = os.environ.get("GIT_COMMIT", "unknown") - run_id = generate_run_id(version, sha) + raise RuntimeError( + "run_id is required. Local-area publishing must receive the " + "GitHub-created run ID from the pipeline." + ) print("=" * 60) print(f"Run ID: {run_id}") @@ -982,6 +986,7 @@ def coordinate_publish( staging_volume.reload() completed = get_completed_from_volume(run_dir) + initially_completed = set(completed) print(f"Found {len(completed)} already-completed items on volume") phase_args = dict( @@ -1033,12 +1038,22 @@ def coordinate_publish( f"Volume preserved for retry." ) + reused_outputs = initially_completed & completed + recomputed_outputs = completed - initially_completed + reuse_measurement = { + "expected_outputs": expected_total, + "valid_reused_outputs": len(reused_outputs), + "recomputed_outputs": len(recomputed_outputs), + "invalid_outputs": max(expected_total - len(completed), 0), + } + if skip_upload: print("\nSkipping upload (--skip-upload flag set)") return { "message": (f"Build complete for version {version}. 
Upload skipped."), "validation_rows": accumulated_validation_rows, "fingerprint": fingerprint, + "reuse_measurement": reuse_measurement, } print("\nValidating staging...") @@ -1075,6 +1090,7 @@ def coordinate_publish( "run_id": run_id, "validation_rows": accumulated_validation_rows, "fingerprint": fingerprint, + "reuse_measurement": reuse_measurement, } @@ -1124,11 +1140,12 @@ def coordinate_national_publish( version = get_version() + run_id = run_id or resolve_run_id() if not run_id: - from policyengine_us_data.utils.run_id import generate_run_id - - sha = os.environ.get("GIT_COMMIT", "unknown") - run_id = generate_run_id(version, sha) + raise RuntimeError( + "run_id is required. National publishing must receive the " + "GitHub-created run ID from the pipeline." + ) print("=" * 60) print(f"Run ID: {run_id}") @@ -1194,6 +1211,7 @@ def coordinate_national_publish( ) run_dir = staging_dir / run_id run_dir.mkdir(parents=True, exist_ok=True) + national_h5 = run_dir / "national" / "US.h5" work_items = [{"type": "national", "id": "US"}] print("Spawning worker for national H5 build...") @@ -1262,6 +1280,12 @@ def coordinate_national_publish( "run_id": run_id, "fingerprint": fingerprint, "national_validation": national_validation_output, + "reuse_measurement": { + "expected_outputs": 1, + "valid_reused_outputs": 0, + "recomputed_outputs": 1, + "invalid_outputs": 0, + }, } print(f"Uploading {national_h5} to HF staging...") @@ -1304,6 +1328,12 @@ def coordinate_national_publish( "run_id": run_id, "fingerprint": fingerprint, "national_validation": national_validation_output, + "reuse_measurement": { + "expected_outputs": 1, + "valid_reused_outputs": 0, + "recomputed_outputs": 1, + "invalid_outputs": 0, + }, } diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py index 2f1e5399e..74c5f764e 100644 --- a/modal_app/pipeline.py +++ b/modal_app/pipeline.py @@ -38,11 +38,9 @@ import sys import time import traceback -from dataclasses import asdict, dataclass, field from datetime 
import datetime, timezone from io import BytesIO from pathlib import Path -from typing import Optional import modal @@ -52,24 +50,63 @@ if _p not in sys.path: sys.path.insert(0, _p) -from modal_app.images import cpu_image as image -from modal_app.resilience import ensure_resume_sha_compatible +from modal_app.images import cpu_image as image # noqa: E402 +from modal_app.resilience import ensure_resume_sha_compatible # noqa: E402 +from modal_app.step_manifests.runtime import ( # noqa: E402 + ArtifactReference, + PIPELINE_MOUNT, + ReuseMeasurement, + RUNS_DIR, + RunMetadata, + STAGING_MOUNT, + StepManifest, + apply_run_context_env as _apply_run_context_env, + artifact_identities as _artifact_identities, + artifacts_dir as _artifacts_dir, + artifacts_dir_for_run, + collect_artifacts, + collect_diagnostics as _collect_diagnostics, + collect_directory_artifacts, + collect_staging_outputs as _collect_staging_outputs, + complete_step_manifest as _complete_step_manifest, + fail_step_manifest as _fail_step_manifest, + mark_step_reused as _mark_step_reused, + metadata_run_fields as _metadata_run_fields, + read_run_meta, + record_step as _record_step, + run_manifest_path, + run_dir as _run_dir, + start_step_manifest as _start_step_manifest, + step_reusable as _step_reusable, + write_run_meta, +) +from policyengine_us_data.utils.run_context import RunContext, resolve_run_id # noqa: E402 +from policyengine_us_data.utils.step_manifest import ( # noqa: E402 + completed_validated_outputs, + read_step_manifest, +) # ── Modal resources ────────────────────────────────────────────── -app = modal.App("policyengine-us-data-pipeline") +app = modal.App( + os.environ.get("US_DATA_PIPELINE_APP_NAME") + or os.environ.get("US_DATA_MODAL_APP_NAME") + or "policyengine-us-data-pipeline" +) hf_secret = modal.Secret.from_name("huggingface-token") gcp_secret = modal.Secret.from_name("gcp-credentials") -pipeline_volume = modal.Volume.from_name("pipeline-artifacts", create_if_missing=True) 
-staging_volume = modal.Volume.from_name("local-area-staging", create_if_missing=True) +pipeline_volume = modal.Volume.from_name( + os.environ.get("US_DATA_PIPELINE_VOLUME_NAME", "pipeline-artifacts"), + create_if_missing=True, +) +staging_volume = modal.Volume.from_name( + os.environ.get("US_DATA_STAGING_VOLUME_NAME", "local-area-staging"), + create_if_missing=True, +) REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" -PIPELINE_MOUNT = "/pipeline" -STAGING_MOUNT = "/staging" -ARTIFACTS_BASE = f"{PIPELINE_MOUNT}/artifacts" -RUNS_DIR = f"{PIPELINE_MOUNT}/runs" def _python_cmd(*args: str) -> list[str]: @@ -77,97 +114,6 @@ def _python_cmd(*args: str) -> list[str]: return [sys.executable, *args] -def artifacts_dir_for_run(run_id: str) -> str: - """Return the run-scoped artifacts directory. - - When run_id is empty, falls back to the flat base path - for backward compatibility with standalone invocations. - """ - if run_id: - return f"{ARTIFACTS_BASE}/{run_id}" - return ARTIFACTS_BASE - - -# ── Run metadata ───────────────────────────────────────────────── - - -@dataclass -class RunMetadata: - """Metadata for a pipeline run. - - Tracks run identity, progress, and diagnostics for - auditability and resume support. 
- """ - - run_id: str - branch: str - sha: str - version: str - start_time: str - status: str # running | completed | failed | promoted - step_timings: dict = field(default_factory=dict) - error: Optional[str] = None - resume_history: list = field(default_factory=list) - fingerprint: Optional[str] = None - regional_fingerprint: Optional[str] = None - - def __post_init__(self) -> None: - if self.regional_fingerprint is None and self.fingerprint is not None: - self.regional_fingerprint = self.fingerprint - if self.fingerprint is None and self.regional_fingerprint is not None: - self.fingerprint = self.regional_fingerprint - - def to_dict(self) -> dict: - data = asdict(self) - if ( - data.get("fingerprint") is None - and data.get("regional_fingerprint") is not None - ): - data["fingerprint"] = data["regional_fingerprint"] - return data - - @classmethod - def from_dict(cls, data: dict) -> "RunMetadata": - data = dict(data) - if ( - data.get("regional_fingerprint") is None - and data.get("fingerprint") is not None - ): - data["regional_fingerprint"] = data["fingerprint"] - return cls(**data) - - -def generate_run_id(version: str, sha: str) -> str: - ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") - return f"{version}_{sha[:8]}_{ts}" - - -def write_run_meta( - meta: RunMetadata, - vol: modal.Volume, -) -> None: - """Write run metadata to the pipeline volume.""" - run_dir = Path(RUNS_DIR) / meta.run_id - run_dir.mkdir(parents=True, exist_ok=True) - meta_path = run_dir / "meta.json" - with open(meta_path, "w") as f: - json.dump(meta.to_dict(), f, indent=2) - vol.commit() - - -def read_run_meta( - run_id: str, - vol: modal.Volume, -) -> RunMetadata: - """Read run metadata from the pipeline volume.""" - vol.reload() - meta_path = Path(RUNS_DIR) / run_id / "meta.json" - if not meta_path.exists(): - raise FileNotFoundError(f"No metadata found for run {run_id} at {meta_path}") - with open(meta_path) as f: - return RunMetadata.from_dict(json.load(f)) - - def 
get_pinned_sha(branch: str) -> str: """Get the current tip SHA for a branch from GitHub.""" result = subprocess.run( @@ -229,84 +175,27 @@ def archive_diagnostics( vol.commit() -def _step_completed(meta: RunMetadata, step: str) -> bool: - """Check if a step is marked completed in metadata.""" - timing = meta.step_timings.get(step, {}) - return timing.get("status") == "completed" - - -def find_resumable_run(branch: str, sha: str, vol: modal.Volume) -> Optional[str]: - """Find an existing running run for the same branch+sha.""" - vol.reload() - runs_dir = Path(RUNS_DIR) - if not runs_dir.exists(): - return None - - best_run_id = None - best_start = "" - - for entry in runs_dir.iterdir(): - if not entry.is_dir(): - continue - meta_path = entry / "meta.json" - if not meta_path.exists(): - continue - try: - with open(meta_path) as f: - data = json.load(f) - if ( - data.get("branch") == branch - and data.get("sha") == sha - and data.get("status") == "running" - ): - start = data.get("start_time", "") - if start > best_start: - best_start = start - best_run_id = data.get("run_id") - except (json.JSONDecodeError, KeyError): - continue - - return best_run_id - - -def _record_step( - meta: RunMetadata, - step: str, - start: float, - vol: modal.Volume, - status: str = "completed", -) -> None: - """Record step timing and status in metadata.""" - meta.step_timings[step] = { - "start": datetime.fromtimestamp(start, tz=timezone.utc).isoformat(), - "end": datetime.now(timezone.utc).isoformat(), - "duration_s": round(time.time() - start, 1), - "status": status, - } - write_run_meta(meta, vol) - - # ── Include other Modal apps ───────────────────────────────────── # app.include() merges functions from other apps into this one, # ensuring Modal mounts their files and registers their functions # (with their GPU/memory/volume configs) in the ephemeral run. # sys.path setup is handled at the top of this file. 
-from modal_app.data_build import app as _data_build_app -from modal_app.data_build import build_datasets +from modal_app.data_build import app as _data_build_app # noqa: E402 +from modal_app.data_build import build_datasets # noqa: E402 app.include(_data_build_app) -from modal_app.remote_calibration_runner import app as _calibration_app -from modal_app.remote_calibration_runner import ( +from modal_app.remote_calibration_runner import app as _calibration_app # noqa: E402 +from modal_app.remote_calibration_runner import ( # noqa: E402 build_package_remote, PACKAGE_GPU_FUNCTIONS, ) app.include(_calibration_app) -from modal_app.local_area import app as _local_area_app -from modal_app.local_area import ( +from modal_app.local_area import app as _local_area_app # noqa: E402 +from modal_app.local_area import ( # noqa: E402 coordinate_publish, coordinate_national_publish, promote_publish, @@ -801,6 +690,11 @@ def run_pipeline( resume_run_id: str = None, clear_checkpoints: bool = False, version_override: str = "", + sha_override: str = "", + run_id: str = "", + run_context: dict | None = None, + modal_app_name: str = "", + modal_environment: str = "", ) -> str: """Run the full pipeline end-to-end. @@ -819,6 +713,13 @@ def run_pipeline( scoped by commit SHA, so stale ones from other commits are cleaned automatically. Use True only to force a full rebuild of the current commit. + sha_override: Exact source SHA deployed by GitHub Actions. When + provided, this is recorded instead of reading the current + branch tip. + run_id: Cross-system run ID created by GitHub. + run_context: Serialized run context from the launcher workflow. + modal_app_name: Deployed Modal app name for this run. + modal_environment: Modal environment used for this run. Returns: The run ID for use with promote. 
@@ -832,20 +733,29 @@ def run_pipeline( os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_path # ── Initialize or resume run ── - sha = get_pinned_sha(branch) + sha = sha_override or get_pinned_sha(branch) version = version_override or get_version_from_branch(branch) + resolved_run_id = resolve_run_id(run_id) + current_run_context = RunContext.from_mapping( + run_context, + run_id=resolved_run_id, + modal_app_name=modal_app_name, + modal_environment=modal_environment, + ) explicit_resume = bool(resume_run_id) - if not resume_run_id: - existing = find_resumable_run(branch, sha, pipeline_volume) - if existing: - print(f"Auto-resuming existing run {existing}") - resume_run_id = existing - if resume_run_id: print(f"Resuming run {resume_run_id}...") meta = read_run_meta(resume_run_id, pipeline_volume) + current_run_context = RunContext.from_mapping( + meta.run_context, + run_id=meta.run_id, + modal_app_name=meta.modal_app_name or current_run_context.modal_app_name, + modal_environment=meta.modal_environment + or current_run_context.modal_environment, + ) + _apply_run_context_env(current_run_context) current_sha = sha sha_match = ensure_resume_sha_compatible( branch=branch, @@ -867,9 +777,24 @@ def run_pipeline( } ) meta.status = "running" + if not meta.run_context: + meta.run_context = current_run_context.to_dict() + meta.modal_app_name = meta.modal_app_name or current_run_context.modal_app_name + meta.modal_environment = ( + meta.modal_environment or current_run_context.modal_environment + ) + meta.hf_staging_prefix = ( + meta.hf_staging_prefix or current_run_context.hf_staging_prefix + ) run_id = resume_run_id else: - run_id = generate_run_id(version, sha) + if not current_run_context.run_id: + raise RuntimeError( + "run_id is required. Production pipeline runs must receive the " + "GitHub-created run ID through workflow_dispatch." 
+ ) + _apply_run_context_env(current_run_context) + run_id = current_run_context.run_id meta = RunMetadata( run_id=run_id, branch=branch, @@ -877,6 +802,7 @@ def run_pipeline( version=version, start_time=datetime.now(timezone.utc).isoformat(), status="running", + **_metadata_run_fields(current_run_context), ) # Create run directory @@ -893,6 +819,10 @@ def run_pipeline( print("PIPELINE RUN") print("=" * 60) print(f" Run ID: {run_id}") + if meta.modal_app_name: + print(f" Modal app: {meta.modal_app_name}") + if meta.hf_staging_prefix: + print(f" HF staging: {meta.hf_staging_prefix}") print(f" Branch: {branch}") print(f" SHA: {sha[:12]}") print(f" Version: {version}") @@ -909,11 +839,44 @@ def run_pipeline( print(f" Resume: skipping {completed}") print("=" * 60) + active_step_manifest: StepManifest | None = None + try: # ── Step 1: Build datasets ── - if not _step_completed(meta, "build_datasets"): + build_dataset_inputs = {"source": {"branch": branch, "sha": sha}} + build_dataset_parameters = { + "upload": True, + "sequential": False, + "clear_checkpoints": clear_checkpoints, + "skip_tests": True, + "skip_enhanced_cps": False, + "run_id": run_id, + } + build_dataset_reuse = _step_reusable( + meta, + "01_build_datasets", + expected_input_identities=build_dataset_inputs, + expected_parameters=build_dataset_parameters, + ) + if build_dataset_reuse.reusable: + _mark_step_reused( + meta, + "01_build_datasets", + build_dataset_reuse, + vol=pipeline_volume, + legacy_step="build_datasets", + ) + print("\n[Step 1/5] Build datasets (skipped - manifest valid)") + else: print("\n[Step 1/5] Building datasets...") step_start = time.time() + active_step_manifest = _start_step_manifest( + meta, + "01_build_datasets", + parameters=build_dataset_parameters, + input_identities=build_dataset_inputs, + vol=pipeline_volume, + ) build_datasets.remote( upload=True, @@ -930,22 +893,83 @@ def run_pipeline( # policy_data.db) are staged to HF in step 4. 
# TODO(#617): When pipeline_artifacts.py lands, # call mirror_to_pipeline() here for audit trail. + dataset_outputs = collect_directory_artifacts( + _artifacts_dir(run_id), + role="output", + ) + checkpoint_stats_path = ( + _artifacts_dir(run_id) / "data_build_checkpoint_stats.json" + ) + checkpoint_stats = ( + json.loads(checkpoint_stats_path.read_text()) + if checkpoint_stats_path.exists() + else {} + ) _record_step( meta, "build_datasets", step_start, pipeline_volume, + step_id="01_build_datasets", + step_manifest=active_step_manifest, + outputs=dataset_outputs, + reuse_measurement=ReuseMeasurement( + expected_outputs=checkpoint_stats.get( + "expected_outputs", len(dataset_outputs) + ), + valid_reused_outputs=checkpoint_stats.get( + "valid_reused_outputs", 0 + ), + recomputed_outputs=checkpoint_stats.get( + "recomputed_outputs", len(dataset_outputs) + ), + invalid_outputs=checkpoint_stats.get("invalid_outputs", 0), + ), ) + active_step_manifest = None print( f" Completed in {meta.step_timings['build_datasets']['duration_s']}s" ) - else: - print("\n[Step 1/5] Build datasets (skipped - completed)") # ── Step 2: Build calibration package ── - if not _step_completed(meta, "build_package"): + package_inputs = _artifact_identities( + { + "dataset": _artifacts_dir(run_id) + / "source_imputed_stratified_extended_cps.h5", + "database": _artifacts_dir(run_id) / "policy_data.db", + } + ) + package_parameters = { + "workers": num_workers, + "n_clones": n_clones, + "target_config": None, + "skip_county": True, + } + package_reuse = _step_reusable( + meta, + "02_build_package", + expected_input_identities=package_inputs, + expected_parameters=package_parameters, + ) + if package_reuse.reusable: + _mark_step_reused( + meta, + "02_build_package", + package_reuse, + vol=pipeline_volume, + legacy_step="build_package", + ) + print("\n[Step 2/5] Build package (skipped - manifest valid)") + else: print("\n[Step 2/5] Building calibration package...") step_start = time.time() + 
active_step_manifest = _start_step_manifest( + meta, + "02_build_package", + parameters=package_parameters, + input_identities=package_inputs, + vol=pipeline_volume, + ) pkg_path = build_package_remote.remote( branch=branch, @@ -960,13 +984,78 @@ def run_pipeline( "build_package", step_start, pipeline_volume, + step_id="02_build_package", + step_manifest=active_step_manifest, + outputs=collect_artifacts( + [_artifacts_dir(run_id) / "calibration_package.pkl"], + missing_ok=True, + ), ) + active_step_manifest = None print(f" Completed in {meta.step_timings['build_package']['duration_s']}s") - else: - print("\n[Step 2/5] Build package (skipped - completed)") # ── Step 3: Fit weights (parallel) ── - if not _step_completed(meta, "fit_weights"): + fit_inputs = _artifact_identities( + { + "calibration_package": _artifacts_dir(run_id) + / "calibration_package.pkl", + } + ) + regional_fit_parameters = { + "gpu": gpu, + "epochs": epochs, + "target_config": "policyengine_us_data/calibration/target_config.yaml", + "beta": 0.65, + "lambda_l0": 1e-7, + "lambda_l2": 1e-8, + "log_freq": 100, + } + national_fit_parameters = { + "gpu": national_gpu, + "epochs": national_epochs, + "target_config": "policyengine_us_data/calibration/target_config.yaml", + "beta": 0.65, + "lambda_l0": 2e-2, + "lambda_l2": 1e-12, + "log_freq": 100, + "skip_national": skip_national, + } + regional_fit_reuse = _step_reusable( + meta, + "03_fit_weights_regional", + expected_input_identities=fit_inputs, + expected_parameters=regional_fit_parameters, + ) + national_fit_reuse = ( + _step_reusable( + meta, + "03_fit_weights_national", + expected_input_identities=fit_inputs, + expected_parameters=national_fit_parameters, + ) + if not skip_national + else None + ) + fit_reusable = regional_fit_reuse.reusable and ( + skip_national or national_fit_reuse.reusable + ) + if fit_reusable: + _mark_step_reused( + meta, + "03_fit_weights_regional", + regional_fit_reuse, + vol=pipeline_volume, + legacy_step="fit_weights", 
+ ) + if national_fit_reuse is not None: + _mark_step_reused( + meta, + "03_fit_weights_national", + national_fit_reuse, + vol=pipeline_volume, + ) + print("\n[Step 3/5] Fit weights (skipped - manifests valid)") + else: print("\n[Step 3/5] Fitting calibration weights...") step_start = time.time() @@ -987,9 +1076,20 @@ def run_pipeline( log_freq=100, ) print(f" → regional fit fc: {regional_handle.object_id}") + regional_fit_manifest = _start_step_manifest( + meta, + "03_fit_weights_regional", + scope="regional", + parameters=regional_fit_parameters, + input_identities=fit_inputs, + modal_call_id=regional_handle.object_id, + vol=pipeline_volume, + ) + active_step_manifest = regional_fit_manifest # Spawn national fit (if enabled) national_handle = None + national_fit_manifest = None if not skip_national: national_func = PACKAGE_GPU_FUNCTIONS[national_gpu] print( @@ -1008,6 +1108,15 @@ def run_pipeline( log_freq=100, ) print(f" → national fit fc: {national_handle.object_id}") + national_fit_manifest = _start_step_manifest( + meta, + "03_fit_weights_national", + scope="national", + parameters=national_fit_parameters, + input_identities=fit_inputs, + modal_call_id=national_handle.object_id, + vol=pipeline_volume, + ) # Collect regional results print(" Waiting for regional fit...") @@ -1038,6 +1147,27 @@ def run_pipeline( pipeline_volume, prefix="", ) + regional_outputs = collect_artifacts( + [ + _artifacts_dir(run_id) / "calibration_weights.npy", + _artifacts_dir(run_id) / "geography_assignment.npz", + _artifacts_dir(run_id) / "unified_run_config.json", + ], + missing_ok=True, + ) + regional_fit_reuse_measurement = ReuseMeasurement( + expected_outputs=len(regional_outputs), + recomputed_outputs=len(regional_outputs), + ) + _complete_step_manifest( + regional_fit_manifest, + outputs=regional_outputs, + diagnostics=_collect_diagnostics(run_id), + reuse_decision="computed", + reuse_measurement=regional_fit_reuse_measurement, + vol=pipeline_volume, + ) + active_step_manifest 
= national_fit_manifest # Collect national results if national_handle is not None: @@ -1067,16 +1197,40 @@ def run_pipeline( pipeline_volume, prefix="national_", ) + national_outputs = collect_artifacts( + [ + _artifacts_dir(run_id) / "national_calibration_weights.npy", + _artifacts_dir(run_id) / "national_geography_assignment.npz", + _artifacts_dir(run_id) / "national_unified_run_config.json", + ], + missing_ok=True, + ) + _complete_step_manifest( + national_fit_manifest, + outputs=national_outputs, + diagnostics=_collect_diagnostics(run_id), + reuse_decision="computed", + reuse_measurement=ReuseMeasurement( + expected_outputs=len(national_outputs), + recomputed_outputs=len(national_outputs), + ), + vol=pipeline_volume, + ) + active_step_manifest = None _record_step( meta, "fit_weights", step_start, pipeline_volume, + step_id="03_fit_weights_regional", + step_manifest=regional_fit_manifest, + outputs=regional_outputs, + diagnostics=_collect_diagnostics(run_id), + reuse_measurement=regional_fit_reuse_measurement, ) + active_step_manifest = None print(f" Completed in {meta.step_timings['fit_weights']['duration_s']}s") - else: - print("\n[Step 3/5] Fit weights (skipped - completed)") # ── Step 4: Build H5s + stage + diagnostics (parallel) ── # 4a. coordinate_publish (regional H5s) @@ -1085,7 +1239,104 @@ def run_pipeline( # 4d. upload_run_diagnostics (calibration diagnostics → HF) # 4e. _write_validation_diagnostics (after H5 builds) # 4f. 
upload_run_diagnostics (validation diagnostics → HF) - if not _step_completed(meta, "publish_and_stage"): + regional_h5_inputs = _artifact_identities( + { + "weights": _artifacts_dir(run_id) / "calibration_weights.npy", + "geography": _artifacts_dir(run_id) / "geography_assignment.npz", + "dataset": _artifacts_dir(run_id) + / "source_imputed_stratified_extended_cps.h5", + "database": _artifacts_dir(run_id) / "policy_data.db", + "run_config": _artifacts_dir(run_id) / "unified_run_config.json", + "calibration_package": _artifacts_dir(run_id) + / "calibration_package.pkl", + } + ) + regional_h5_parameters = { + "num_workers": num_workers, + "n_clones": n_clones, + "validate": True, + "skip_upload": False, + } + national_h5_inputs = _artifact_identities( + { + "weights": _artifacts_dir(run_id) / "national_calibration_weights.npy", + "geography": _artifacts_dir(run_id) + / "national_geography_assignment.npz", + "dataset": _artifacts_dir(run_id) + / "source_imputed_stratified_extended_cps.h5", + "database": _artifacts_dir(run_id) / "policy_data.db", + "run_config": _artifacts_dir(run_id) + / "national_unified_run_config.json", + } + ) + national_h5_parameters = { + "n_clones": n_clones, + "validate": True, + "skip_upload": False, + "skip_national": skip_national, + } + stage_base_inputs = _artifact_identities( + { + "policy_data_db": _artifacts_dir(run_id) / "policy_data.db", + "source_imputed": _artifacts_dir(run_id) + / "source_imputed_stratified_extended_cps.h5", + } + ) + stage_base_parameters = { + "version": version, + "branch": branch, + "run_id": run_id, + } + regional_h5_reuse = _step_reusable( + meta, + "04_build_h5_regional", + expected_input_identities=regional_h5_inputs, + expected_parameters=regional_h5_parameters, + ) + national_h5_reuse = ( + _step_reusable( + meta, + "04_build_h5_national", + expected_input_identities=national_h5_inputs, + expected_parameters=national_h5_parameters, + ) + if not skip_national + else None + ) + stage_base_reuse = 
_step_reusable( + meta, + "04_stage_base_datasets", + expected_input_identities=stage_base_inputs, + expected_parameters=stage_base_parameters, + ) + publish_reusable = ( + regional_h5_reuse.reusable + and (skip_national or national_h5_reuse.reusable) + and stage_base_reuse.reusable + ) + if publish_reusable: + _mark_step_reused( + meta, + "04_build_h5_regional", + regional_h5_reuse, + vol=pipeline_volume, + legacy_step="publish_and_stage", + ) + if national_h5_reuse is not None: + _mark_step_reused( + meta, + "04_build_h5_national", + national_h5_reuse, + vol=pipeline_volume, + ) + _mark_step_reused( + meta, + "04_stage_base_datasets", + stage_base_reuse, + vol=pipeline_volume, + ) + print("\n[Step 4/5] Publish + stage (skipped - manifests valid)") + else: print( "\n[Step 4/5] Building H5s, staging datasets, " "uploading diagnostics (parallel)..." @@ -1106,8 +1357,19 @@ def run_pipeline( ), ) print(f" → coordinate_publish fc: {regional_h5_handle.object_id}") + regional_h5_manifest = _start_step_manifest( + meta, + "04_build_h5_regional", + scope="regional", + parameters=regional_h5_parameters, + input_identities=regional_h5_inputs, + modal_call_id=regional_h5_handle.object_id, + vol=pipeline_volume, + ) + active_step_manifest = regional_h5_manifest national_h5_handle = None + national_h5_manifest = None if not skip_national: print(" Spawning national H5 build...") national_h5_handle = coordinate_national_publish.spawn( @@ -1119,12 +1381,45 @@ def run_pipeline( print( f" → coordinate_national_publish fc: {national_h5_handle.object_id}" ) + national_h5_manifest = _start_step_manifest( + meta, + "04_build_h5_national", + scope="national", + parameters=national_h5_parameters, + input_identities=national_h5_inputs, + modal_call_id=national_h5_handle.object_id, + vol=pipeline_volume, + ) # While H5 builds run, stage base datasets in this container pipeline_volume.reload() print(" Staging base datasets to HF...") + stage_base_manifest = _start_step_manifest( + meta, + 
"04_stage_base_datasets", + parameters=stage_base_parameters, + input_identities=stage_base_inputs, + vol=pipeline_volume, + ) + active_step_manifest = stage_base_manifest stage_base_datasets(run_id, version, branch) + base_stage_outputs = collect_directory_artifacts( + _artifacts_dir(run_id), + patterns=("*.h5", "*.db"), + role="output", + ) + _complete_step_manifest( + stage_base_manifest, + outputs=base_stage_outputs, + reuse_decision="computed", + reuse_measurement=ReuseMeasurement( + expected_outputs=len(base_stage_outputs), + recomputed_outputs=len(base_stage_outputs), + ), + vol=pipeline_volume, + ) + active_step_manifest = regional_h5_manifest # Now wait for H5 builds to finish print(" Waiting for regional H5 build...") @@ -1141,7 +1436,28 @@ def run_pipeline( ): meta.regional_fingerprint = regional_h5_result["fingerprint"] meta.fingerprint = regional_h5_result["fingerprint"] + regional_h5_manifest.input_identities["h5_scope_fingerprint"] = ( + regional_h5_result["fingerprint"] + ) write_run_meta(meta, pipeline_volume) + regional_reuse_measurement = ReuseMeasurement.from_dict( + regional_h5_result.get("reuse_measurement", {}) + if isinstance(regional_h5_result, dict) + else {} + ) + _complete_step_manifest( + regional_h5_manifest, + outputs=_collect_staging_outputs(run_id, scope="regional"), + diagnostics=_collect_diagnostics(run_id), + reuse_decision=( + "partially_reused" + if regional_reuse_measurement.valid_reused_outputs + else "computed" + ), + reuse_measurement=regional_reuse_measurement, + vol=pipeline_volume, + ) + active_step_manifest = national_h5_manifest national_h5_result = None if national_h5_handle is not None: @@ -1153,6 +1469,30 @@ def run_pipeline( else national_h5_result ) print(f" National H5: {national_msg}") + if isinstance(national_h5_result, dict) and national_h5_result.get( + "fingerprint" + ): + national_h5_manifest.input_identities["h5_scope_fingerprint"] = ( + national_h5_result["fingerprint"] + ) + national_reuse_measurement = 
ReuseMeasurement.from_dict( + national_h5_result.get("reuse_measurement", {}) + if isinstance(national_h5_result, dict) + else {} + ) + _complete_step_manifest( + national_h5_manifest, + outputs=_collect_staging_outputs(run_id, scope="national"), + diagnostics=_collect_diagnostics(run_id), + reuse_decision=( + "partially_reused" + if national_reuse_measurement.valid_reused_outputs + else "computed" + ), + reuse_measurement=national_reuse_measurement, + vol=pipeline_volume, + ) + active_step_manifest = None # ── Aggregate validation results ── _write_validation_diagnostics( @@ -1165,20 +1505,54 @@ def run_pipeline( # Upload validation diagnostics (written after H5 builds) print(" Uploading validation diagnostics...") + diagnostics_manifest = _start_step_manifest( + meta, + "04_upload_diagnostics", + parameters={"branch": branch, "run_id": run_id}, + input_identities={ + "diagnostics": [ + artifact.to_dict() for artifact in _collect_diagnostics(run_id) + ] + }, + vol=pipeline_volume, + ) + active_step_manifest = diagnostics_manifest upload_run_diagnostics(run_id, branch) + diagnostic_outputs = _collect_diagnostics(run_id) + _complete_step_manifest( + diagnostics_manifest, + outputs=diagnostic_outputs, + diagnostics=diagnostic_outputs, + reuse_decision="computed", + reuse_measurement=ReuseMeasurement( + expected_outputs=len(diagnostic_outputs), + recomputed_outputs=len(diagnostic_outputs), + ), + vol=pipeline_volume, + ) + active_step_manifest = regional_h5_manifest _record_step( meta, "publish_and_stage", step_start, pipeline_volume, + step_id="04_build_h5_regional", + step_manifest=regional_h5_manifest, + outputs=_collect_staging_outputs(run_id, scope="regional"), + diagnostics=_collect_diagnostics(run_id), + reuse_decision=( + "partially_reused" + if regional_reuse_measurement.valid_reused_outputs + else "computed" + ), + reuse_measurement=regional_reuse_measurement, ) + active_step_manifest = None print( f" Completed in " 
f"{meta.step_timings['publish_and_stage']['duration_s']}s" ) - else: - print("\n[Step 4/5] Publish + stage (skipped - completed)") # ── Step 5: Finalize ── print("\n[Step 5/5] Finalizing run...") @@ -1202,6 +1576,7 @@ def run_pipeline( return run_id except Exception as e: + _fail_step_manifest(active_step_manifest, e, pipeline_volume) meta.status = "failed" meta.error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" write_run_meta(meta, pipeline_volume) @@ -1278,6 +1653,34 @@ def promote_run( print(f"WARNING: Run {run_id} was already promoted. Re-promoting...") version = version or meta.version + promote_inputs = { + "validated_step_outputs": [ + artifact.to_dict() + for artifact in completed_validated_outputs( + _run_dir(run_id), + step_ids=[ + "04_build_h5_regional", + "04_build_h5_national", + "04_stage_base_datasets", + ], + ) + ] + } + if ( + run_manifest_path(_run_dir(run_id)).exists() + and not promote_inputs["validated_step_outputs"] + ): + raise RuntimeError( + "No validated completed step outputs found for release promotion. " + "Run Phase 3c pipeline steps before promoting this run." 
+ ) + promote_manifest = _start_step_manifest( + meta, + "05_promote_release", + parameters={"version": version, "run_id": run_id}, + input_identities=promote_inputs, + vol=pipeline_volume, + ) print("=" * 60) print("PROMOTING PIPELINE RUN") @@ -1364,7 +1767,7 @@ def promote_run( version="{version}", blob_names=blob_names, ) -manifest.pipeline_run_id = "{run_id}" +manifest.run_id = "{run_id}" manifest.diagnostics_path = "calibration/runs/{run_id}/diagnostics/" upload_manifest(manifest) print("Registered version {version} in version_manifest.json") @@ -1384,6 +1787,15 @@ def promote_run( # Update run status meta.status = "promoted" + _complete_step_manifest( + promote_manifest, + outputs=[ + ArtifactReference.from_dict(artifact) + for artifact in promote_inputs["validated_step_outputs"] + ], + reuse_decision="computed", + vol=pipeline_volume, + ) write_run_meta(meta, pipeline_volume) print("\n" + "=" * 60) @@ -1419,6 +1831,7 @@ def pipeline_status( if run_id: meta = read_run_meta(run_id, pipeline_volume) + steps_dir = _run_dir(run_id) / "steps" lines = [ f"Run: {meta.run_id}", f" Branch: {meta.branch}", @@ -1429,8 +1842,19 @@ def pipeline_status( ] if meta.error: lines.append(f" Error: {meta.error[:200]}") + if steps_dir.exists(): + lines.append(" Step manifests:") + for manifest_path in sorted(steps_dir.glob("*.json")): + manifest = read_step_manifest(manifest_path) + duration = ( + manifest.duration_s if manifest.duration_s is not None else "?" 
+ ) + reuse = manifest.reuse_decision + lines.append( + f" {manifest.step_id}: {duration}s ({manifest.status}, {reuse})" + ) if meta.step_timings: - lines.append(" Steps:") + lines.append(" Legacy step timings:") for step, timing in meta.step_timings.items(): dur = timing.get("duration_s", "?") status = timing.get("status", "unknown") @@ -1440,8 +1864,18 @@ def pipeline_status( # List all runs runs = [] for entry in sorted(runs_dir.iterdir()): + manifest_path = entry / "run_manifest.json" meta_path = entry / "meta.json" - if meta_path.exists(): + if manifest_path.exists(): + with open(manifest_path) as f: + data = json.load(f) + runs.append( + f" {data['run_id']}: " + f"{data['status']} " + f"(branch={data['branch']}, " + f"v={data['version']})" + ) + elif meta_path.exists(): with open(meta_path) as f: data = json.load(f) runs.append( @@ -1475,6 +1909,7 @@ def main( skip_national: bool = False, clear_checkpoints: bool = False, version: str = None, + sha_override: str = "", ): """Pipeline entrypoint. 
@@ -1496,6 +1931,8 @@ def main( resume_run_id=resume_run_id, clear_checkpoints=clear_checkpoints, version_override=version or "", + sha_override=sha_override, + run_id=run_id or "", ) print(f"\nPipeline run complete: {result}") diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index f339b68c9..534a30fe4 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -13,10 +13,15 @@ from modal_app.images import gpu_image as image # noqa: E402 -app = modal.App("policyengine-us-data-fit-weights") +app = modal.App( + os.environ.get("US_DATA_FIT_WEIGHTS_APP_NAME") or "policyengine-us-data-fit-weights" +) hf_secret = modal.Secret.from_name("huggingface-token") -pipeline_vol = modal.Volume.from_name("pipeline-artifacts", create_if_missing=True) +pipeline_vol = modal.Volume.from_name( + os.environ.get("US_DATA_PIPELINE_VOLUME_NAME", "pipeline-artifacts"), + create_if_missing=True, +) PIPELINE_MOUNT = "/pipeline" diff --git a/modal_app/step_manifests/__init__.py b/modal_app/step_manifests/__init__.py new file mode 100644 index 000000000..a97087b19 --- /dev/null +++ b/modal_app/step_manifests/__init__.py @@ -0,0 +1 @@ +"""Step-manifest runtime helpers for Modal pipeline orchestration.""" diff --git a/modal_app/step_manifests/runtime.py b/modal_app/step_manifests/runtime.py new file mode 100644 index 000000000..cc40ff00f --- /dev/null +++ b/modal_app/step_manifests/runtime.py @@ -0,0 +1,427 @@ +"""Runtime helpers for Modal pipeline step manifests.""" + +from __future__ import annotations + +import json +import os +import time +from dataclasses import asdict, dataclass, field, fields +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +from policyengine_us_data.utils.run_context import RunContext +from policyengine_us_data.utils.step_manifest import ( + ArtifactReference, + ReuseMeasurement, + RunManifest, + StepManifest, + collect_artifacts, + 
collect_directory_artifacts, + evaluate_step_reuse, + read_step_manifest, + run_manifest_path, + step_manifest_path, + utc_now, + write_run_manifest, + write_step_manifest, +) + +PIPELINE_MOUNT = "/pipeline" +STAGING_MOUNT = "/staging" +ARTIFACTS_BASE = f"{PIPELINE_MOUNT}/artifacts" +RUNS_DIR = f"{PIPELINE_MOUNT}/runs" + +RUN_STEP_IDS = [ + "01_build_datasets", + "02_build_package", + "03_fit_weights_regional", + "03_fit_weights_national", + "04_build_h5_regional", + "04_build_h5_national", + "04_stage_base_datasets", + "04_upload_diagnostics", + "05_promote_release", +] + + +def artifacts_dir_for_run(run_id: str) -> str: + """Return the run-scoped artifacts directory.""" + if run_id: + return f"{ARTIFACTS_BASE}/{run_id}" + return ARTIFACTS_BASE + + +@dataclass +class RunMetadata: + """Metadata for a pipeline run.""" + + run_id: str + branch: str + sha: str + version: str + start_time: str + status: str + step_timings: dict = field(default_factory=dict) + error: Optional[str] = None + resume_history: list = field(default_factory=list) + fingerprint: Optional[str] = None + regional_fingerprint: Optional[str] = None + run_context: dict = field(default_factory=dict) + modal_app_name: Optional[str] = None + modal_environment: Optional[str] = None + hf_staging_prefix: Optional[str] = None + + def __post_init__(self) -> None: + if self.regional_fingerprint is None and self.fingerprint is not None: + self.regional_fingerprint = self.fingerprint + if self.fingerprint is None and self.regional_fingerprint is not None: + self.fingerprint = self.regional_fingerprint + + def to_dict(self) -> dict: + data = asdict(self) + if ( + data.get("fingerprint") is None + and data.get("regional_fingerprint") is not None + ): + data["fingerprint"] = data["regional_fingerprint"] + return data + + @classmethod + def from_dict(cls, data: dict) -> "RunMetadata": + data = dict(data) + if "run_context" not in data and "publication_context" in data: + data["run_context"] = 
data["publication_context"] + if ( + data.get("regional_fingerprint") is None + and data.get("fingerprint") is not None + ): + data["regional_fingerprint"] = data["fingerprint"] + allowed_fields = {field.name for field in fields(cls)} + return cls( + **{key: value for key, value in data.items() if key in allowed_fields} + ) + + +def apply_run_context_env(context: RunContext) -> None: + """Expose run context to subprocess upload helpers.""" + for key, value in context.export_env().items(): + os.environ[key] = value + + +def metadata_run_fields(context: RunContext) -> dict: + return { + "run_context": context.to_dict(), + "modal_app_name": context.modal_app_name, + "modal_environment": context.modal_environment, + "hf_staging_prefix": context.hf_staging_prefix, + } + + +def run_dir(run_id: str) -> Path: + return Path(RUNS_DIR) / run_id + + +def artifacts_dir(run_id: str) -> Path: + return Path(artifacts_dir_for_run(run_id)) + + +def _write_run_manifest(meta: RunMetadata) -> None: + """Write the run-scoped execution ledger.""" + manifest = RunManifest( + run_id=meta.run_id, + branch=meta.branch, + sha=meta.sha, + version=meta.version, + status=meta.status, + started_at=meta.start_time, + run_context=meta.run_context, + modal_app_name=meta.modal_app_name, + modal_environment=meta.modal_environment, + hf_staging_prefix=meta.hf_staging_prefix, + updated_at=utc_now(), + completed_at=utc_now() + if meta.status in {"completed", "failed", "promoted"} + else None, + known_step_ids=RUN_STEP_IDS, + resume_history=meta.resume_history, + error=meta.error, + ) + write_run_manifest(run_manifest_path(run_dir(meta.run_id)), manifest) + + +def write_run_meta(meta: RunMetadata, vol: Any) -> None: + """Write compatibility metadata and the canonical run manifest.""" + destination = run_dir(meta.run_id) + destination.mkdir(parents=True, exist_ok=True) + meta_path = destination / "meta.json" + with open(meta_path, "w") as f: + json.dump(meta.to_dict(), f, indent=2) + 
def read_run_meta(run_id: str, vol: Any) -> RunMetadata:
    """Read run metadata from the pipeline volume."""
    vol.reload()
    meta_path = run_dir(run_id) / "meta.json"
    if meta_path.exists():
        with open(meta_path) as handle:
            return RunMetadata.from_dict(json.load(handle))
    raise FileNotFoundError(f"No metadata found for run {run_id} at {meta_path}")


def step_completed(meta: RunMetadata, step: str) -> bool:
    """Check if a legacy step is marked completed in compatibility metadata."""
    return meta.step_timings.get(step, {}).get("status") == "completed"


def _next_step_attempt(run_id: str, step_id: str) -> int:
    # Attempt numbering continues from any prior manifest; a missing or
    # unreadable manifest restarts the count at 1 (best effort by design).
    manifest_file = step_manifest_path(run_dir(run_id), step_id)
    if manifest_file.exists():
        try:
            return read_step_manifest(manifest_file).attempt + 1
        except Exception:
            pass
    return 1


def start_step_manifest(
    meta: RunMetadata,
    step_id: str,
    *,
    parameters: dict | None = None,
    input_identities: dict | None = None,
    scope: str | None = None,
    modal_call_id: str | None = None,
    vol: Any | None = None,
) -> StepManifest:
    """Create, persist, and return a fresh "running" manifest for a step."""
    manifest = StepManifest(
        run_id=meta.run_id,
        step_id=step_id,
        scope=scope,
        status="running",
        attempt=_next_step_attempt(meta.run_id, step_id),
        started_at=utc_now(),
        branch=meta.branch,
        sha=meta.sha,
        version=meta.version,
        modal_app_name=meta.modal_app_name,
        modal_environment=meta.modal_environment,
        hf_staging_prefix=meta.hf_staging_prefix,
        modal_call_id=modal_call_id,
        parameters=parameters or {},
        input_identities=input_identities or {},
    )
    write_step_manifest(
        step_manifest_path(run_dir(meta.run_id), step_id), manifest
    )
    if vol is not None:
        vol.commit()
    return manifest


def complete_step_manifest(
    manifest: StepManifest,
    *,
    outputs: list[ArtifactReference] | None = None,
    diagnostics: list[ArtifactReference] | None = None,
    reuse_decision: str = "computed",
    reuse_reason: str | None = None,
    reuse_measurement: ReuseMeasurement | None = None,
    status: str = "completed",
    vol: Any | None = None,
) -> StepManifest:
    """Finalize a step manifest, persist it, and return the completed copy."""
    finished = manifest.complete(
        status=status,
        outputs=outputs,
        diagnostics=diagnostics,
        reuse_decision=reuse_decision,
        reuse_reason=reuse_reason,
        reuse_measurement=reuse_measurement,
    )
    target = step_manifest_path(run_dir(finished.run_id), finished.step_id)
    write_step_manifest(target, finished)
    if vol is not None:
        vol.commit()
    return finished


def fail_step_manifest(
    manifest: StepManifest | None,
    exc: BaseException,
    vol: Any,
) -> None:
    """Persist a failed manifest describing ``exc``; no-op without a manifest."""
    if manifest is None:
        return
    failed = manifest.fail(exc)
    write_step_manifest(
        step_manifest_path(run_dir(failed.run_id), failed.step_id), failed
    )
    vol.commit()


def mark_step_reused(
    meta: RunMetadata,
    step_id: str,
    decision,
    *,
    vol: Any,
    legacy_step: str | None = None,
) -> StepManifest:
    """Record that a step's prior outputs were reused verbatim for this run."""
    prior = decision.manifest
    if prior is None:
        raise RuntimeError(f"Cannot reuse {step_id}: missing prior manifest")
    measurement = ReuseMeasurement(
        expected_outputs=len(prior.outputs),
        valid_reused_outputs=len(prior.outputs),
        recomputed_outputs=0,
        invalid_outputs=0,
    )
    reused = StepManifest(
        run_id=prior.run_id,
        step_id=prior.step_id,
        scope=prior.scope,
        status="reused",
        attempt=prior.attempt + 1,
        started_at=utc_now(),
        completed_at=utc_now(),
        duration_s=0.0,
        branch=meta.branch,
        sha=meta.sha,
        version=meta.version,
        # Fall back to the prior manifest's identity when this run's
        # metadata does not carry one.
        modal_app_name=meta.modal_app_name or prior.modal_app_name,
        modal_environment=meta.modal_environment or prior.modal_environment,
        hf_staging_prefix=meta.hf_staging_prefix or prior.hf_staging_prefix,
        modal_app_id=prior.modal_app_id,
        modal_call_id=prior.modal_call_id,
        parameters=prior.parameters,
        input_identities=prior.input_identities,
        outputs=prior.outputs,
        diagnostics=prior.diagnostics,
        reuse_decision="reused",
        reuse_reason=decision.reason,
        reuse_measurement=measurement,
    )
    write_step_manifest(step_manifest_path(run_dir(meta.run_id), step_id), reused)
    # Mirror the reuse into the legacy step-timing table for compatibility.
    meta.step_timings[legacy_step or step_id] = {
        "start": reused.started_at,
        "end": reused.completed_at,
        "duration_s": 0.0,
        "status": "completed",
        "reuse_decision": "reused",
        "reuse_reason": decision.reason,
    }
    write_run_meta(meta, vol)
    return reused


def step_reusable(
    meta: RunMetadata,
    step_id: str,
    *,
    expected_input_identities: dict | None = None,
    expected_parameters: dict | None = None,
) -> object:
    """Evaluate whether a step's prior manifest allows reuse."""
    manifest_file = step_manifest_path(run_dir(meta.run_id), step_id)
    return evaluate_step_reuse(
        manifest_file,
        expected_input_identities=expected_input_identities,
        expected_parameters=expected_parameters,
    )


def artifact_identity(path: str | Path) -> dict:
    """Return the identity triple (path, size, sha256) for one artifact."""
    ref = ArtifactReference.from_path(path)
    return {
        "path": ref.path,
        "size_bytes": ref.size_bytes,
        "sha256": ref.sha256,
    }


def artifact_identities(paths: dict[str, str | Path]) -> dict:
    """Map labels to artifact identities, flagging missing files explicitly."""
    result: dict = {}
    for label, raw_path in paths.items():
        candidate = Path(raw_path)
        if candidate.exists():
            result[label] = artifact_identity(candidate)
        else:
            result[label] = {"path": str(candidate), "missing": True}
    return result


def collect_diagnostics(run_id: str) -> list[ArtifactReference]:
    """Collect diagnostic artifacts written under the run's diagnostics dir."""
    return collect_directory_artifacts(
        run_dir(run_id) / "diagnostics",
        patterns=("*.csv", "*.json", "*.txt"),
        role="diagnostic",
    )


def collect_staging_outputs(run_id: str, *, scope: str) -> list[ArtifactReference]:
    """Collect staged H5 outputs (plus manifest) for one output scope."""
    scoped_root = Path(STAGING_MOUNT) / run_id
    paths: list[Path] = []
    if scope == "regional":
        for subdir in ("states", "districts", "cities"):
            paths.extend(sorted((scoped_root / subdir).glob("*.h5")))
        manifest_file = scoped_root / "manifest.json"
        if manifest_file.exists():
            paths.append(manifest_file)
    elif scope == "national":
        paths.extend(sorted((scoped_root / "national").glob("*.h5")))
    else:
        raise ValueError(f"Unknown H5 output scope: {scope}")
    return collect_artifacts(paths, missing_ok=True)


def record_step(
    meta: RunMetadata,
    step: str,
    start: float,
    vol: Any,
    status: str = "completed",
    *,
    step_id: str | None = None,
    step_manifest: StepManifest | None = None,
    parameters: dict | None = None,
    input_identities: dict | None = None,
    outputs: list[ArtifactReference] | None = None,
    diagnostics: list[ArtifactReference] | None = None,
    reuse_decision: str = "computed",
    reuse_reason: str | None = None,
    reuse_measurement: ReuseMeasurement | None = None,
) -> None:
    """Record step timing/status and complete the step manifest."""
    started_at = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
    meta.step_timings[step] = {
        "start": started_at,
        "end": datetime.now(timezone.utc).isoformat(),
        "duration_s": round(time.time() - start, 1),
        "status": status,
    }
    canonical_step_id = step_id or step
    manifest = step_manifest
    if manifest is None:
        # No caller-supplied manifest: synthesize one so the manifest ledger
        # still records this step's execution.
        manifest = StepManifest(
            run_id=meta.run_id,
            step_id=canonical_step_id,
            status="running",
            attempt=_next_step_attempt(meta.run_id, canonical_step_id),
            started_at=started_at,
            branch=meta.branch,
            sha=meta.sha,
            version=meta.version,
            modal_app_name=meta.modal_app_name,
            modal_environment=meta.modal_environment,
            hf_staging_prefix=meta.hf_staging_prefix,
            parameters=parameters or {},
            input_identities=input_identities or {},
        )
    complete_step_manifest(
        manifest,
        outputs=outputs or [],
        diagnostics=diagnostics or [],
        reuse_decision=reuse_decision,
        reuse_reason=reuse_reason,
        reuse_measurement=reuse_measurement,
        status=status,
    )
    write_run_meta(meta, vol)
upload_to_staging_hf, ) +from policyengine_us_data.utils.run_context import resolve_run_id from policyengine_us_data.utils.dataset_validation import ( DatasetContractError, load_dataset_for_validation, @@ -61,6 +62,10 @@ MAX_HOUSEHOLD_WEIGHT_SUM = 200e6 # 200 million +def _resolve_run_id(run_id: str = "") -> str: + return run_id or resolve_run_id() + + class DatasetValidationError(Exception): """Raised when a dataset fails pre-upload validation.""" @@ -114,6 +119,7 @@ def _collect_staged_dataset_repo_paths( run_id: str = "", ) -> list[str]: api = HfApi() + run_id = _resolve_run_id(run_id) prefix = f"staging/{run_id}" if run_id else "staging" repo_files = set( api.list_repo_files( @@ -145,6 +151,7 @@ def _download_staged_dataset_artifacts( rel_paths: list[str], run_id: str = "", ) -> list[tuple[Path, str]]: + run_id = _resolve_run_id(run_id) staging_prefix = f"staging/{run_id}" if run_id else "staging" downloaded_files = [] for rel_path in rel_paths: @@ -301,6 +308,7 @@ def stage_datasets( version: str | None = None, run_id: str = "", ) -> list[tuple[Path, str]]: + run_id = _resolve_run_id(run_id) version = version or DATA_PACKAGE_VERSION files_with_repo_paths = _collect_existing_dataset_artifacts( require_enhanced_cps=require_enhanced_cps @@ -324,6 +332,7 @@ def promote_datasets( run_id: str = "", files_with_repo_paths: list[tuple[Path, str]] | None = None, ) -> list[str]: + run_id = _resolve_run_id(run_id) version = version or DATA_PACKAGE_VERSION rel_paths = ( [repo_path for _, repo_path in files_with_repo_paths] @@ -407,6 +416,7 @@ def upload_datasets( run_id: str = "", version: str | None = None, ): + run_id = _resolve_run_id(run_id) if stage_only and promote_only: raise ValueError("Choose either stage_only or promote_only, not both.") @@ -480,8 +490,8 @@ def validate_built_datasets(require_enhanced_cps: bool = True): ) parser.add_argument( "--run-id", - default="", - help="Optional staging run ID, for example a CI commit SHA.", + default=resolve_run_id(), + 
help="GitHub-created staging run ID.", ) parser.add_argument( "--version", diff --git a/policyengine_us_data/utils/data_upload.py b/policyengine_us_data/utils/data_upload.py index 0d0eed3fd..25ab7174e 100644 --- a/policyengine_us_data/utils/data_upload.py +++ b/policyengine_us_data/utils/data_upload.py @@ -30,6 +30,10 @@ build_release_manifest, serialize_release_manifest, ) +from policyengine_us_data.utils.run_context import ( + RunContext, + resolve_run_id, +) from policyengine_us_data.utils.trace_tro import ( TRACE_TRO_FILENAME, build_trace_tro_from_release_manifest, @@ -54,6 +58,17 @@ } +def _resolve_staging_run_id(run_id: str = "") -> str: + return run_id or resolve_run_id() + + +def _run_context_for_release() -> dict | None: + run_id = resolve_run_id() + if not run_id: + return None + return RunContext.from_env(run_id=run_id).to_dict() + + def _get_model_package_version( package_name: str = "policyengine-us", ) -> Optional[str]: @@ -275,6 +290,7 @@ def create_release_manifest_commit_operations( model_package_version: Optional[str] = None, model_package_git_sha: Optional[str] = None, model_package_data_build_fingerprint: Optional[str] = None, + run_context: Optional[Dict] = None, existing_manifest: Optional[Dict] = None, ) -> Tuple[Dict, List[CommitOperationAdd]]: manifest = build_release_manifest( @@ -285,6 +301,7 @@ def create_release_manifest_commit_operations( model_package_version=model_package_version, model_package_git_sha=model_package_git_sha, model_package_data_build_fingerprint=model_package_data_build_fingerprint, + run_context=run_context, existing_manifest=existing_manifest, ) manifest_payload = serialize_release_manifest(manifest) @@ -489,6 +506,7 @@ def upload_files_to_hf( model_package_data_build_fingerprint=model_build_metadata[ "data_build_fingerprint" ], + run_context=_run_context_for_release(), existing_manifest=existing_manifest, ) hf_operations.extend(manifest_operations) @@ -691,6 +709,7 @@ def publish_release_manifest_to_hf( 
model_package_data_build_fingerprint=model_build_metadata[ "data_build_fingerprint" ], + run_context=_run_context_for_release(), existing_manifest=existing_manifest, ) parent_commit = get_repo_head_revision( @@ -789,12 +808,28 @@ def upload_to_staging_hf( """ token = os.environ.get("HUGGING_FACE_TOKEN") api = HfApi() - staging_prefix = f"staging/{run_id}" if run_id else "staging" + run_id = _resolve_staging_run_id(run_id) + staging_prefix = _staging_prefix(run_id) + context_payload = None + if run_id: + context_payload = RunContext.from_env(run_id=run_id).to_dict() + context_payload["hf_staging_prefix"] = staging_prefix total_uploaded = 0 for i in range(0, len(files_with_paths), batch_size): batch = files_with_paths[i : i + batch_size] operations = [] + if i == 0 and context_payload is not None: + operations.append( + CommitOperationAdd( + path_in_repo=f"{staging_prefix}/_run_context.json", + path_or_fileobj=BytesIO( + ( + json.dumps(context_payload, indent=2, sort_keys=True) + "\n" + ).encode("utf-8") + ), + ) + ) for local_path, rel_path in batch: local_path = Path(local_path) if not local_path.exists(): @@ -816,11 +851,18 @@ def upload_to_staging_hf( repo_id=hf_repo_name, repo_type=hf_repo_type, token=token, - commit_message=f"Upload batch {i // batch_size + 1} to staging for version {version}", + commit_message=( + f"Upload batch {i // batch_size + 1} to staging " + f"for version {version}" + (f" ({run_id})" if run_id else "") + ), ) - total_uploaded += len(operations) + uploaded_files = len(operations) - ( + 1 if i == 0 and context_payload is not None else 0 + ) + total_uploaded += uploaded_files logging.info( - f"Uploaded batch {i // batch_size + 1}: {len(operations)} files to staging/" + f"Uploaded batch {i // batch_size + 1}: " + f"{uploaded_files} files to {staging_prefix}/" ) logging.info(f"Total: uploaded {total_uploaded} files to staging/ in HuggingFace") @@ -828,6 +870,7 @@ def upload_to_staging_hf( def _staging_prefix(run_id: str = "") -> str: + 
run_id = _resolve_staging_run_id(run_id) return f"staging/{run_id}" if run_id else "staging" @@ -859,6 +902,7 @@ def promote_staging_to_production_hf( """ token = os.environ.get("HUGGING_FACE_TOKEN") api = HfApi() + run_id = _resolve_staging_run_id(run_id) staging_prefix = _staging_prefix(run_id) operations = [] @@ -887,7 +931,10 @@ def promote_staging_to_production_hf( repo_id=hf_repo_name, repo_type=hf_repo_type, token=token, - commit_message=f"Promote {len(files)} files from staging to production for version {version}", + commit_message=( + f"Promote {len(files)} files from staging to production " + f"for version {version}" + (f" ({run_id})" if run_id else "") + ), ) if result.oid == head_before: @@ -927,6 +974,7 @@ def cleanup_staging_hf( """ token = os.environ.get("HUGGING_FACE_TOKEN") api = HfApi() + run_id = _resolve_staging_run_id(run_id) staging_prefix = _staging_prefix(run_id) operations = [] @@ -949,7 +997,10 @@ def cleanup_staging_hf( repo_id=hf_repo_name, repo_type=hf_repo_type, token=token, - commit_message=f"Clean up staging after version {version} promotion", + commit_message=( + f"Clean up staging after version {version} promotion" + + (f" ({run_id})" if run_id else "") + ), ) if result.oid == head_before: @@ -984,6 +1035,7 @@ def upload_from_hf_staging_to_gcs( Number of files uploaded """ token = os.environ.get("HUGGING_FACE_TOKEN") + run_id = _resolve_staging_run_id(run_id) staging_prefix = _staging_prefix(run_id) credentials, project_id = google.auth.default() diff --git a/policyengine_us_data/utils/release_manifest.py b/policyengine_us_data/utils/release_manifest.py index d85f8e8fb..d4d46b9fd 100644 --- a/policyengine_us_data/utils/release_manifest.py +++ b/policyengine_us_data/utils/release_manifest.py @@ -40,6 +40,7 @@ def _base_manifest( model_package_version: str | None, model_package_git_sha: str | None, model_package_data_build_fingerprint: str | None, + run_context: Mapping[str, str] | None, build_id: str, created_at: str, ) -> Dict: @@ 
"""Run identity helpers for US data publication runs.

The run ID is the cross-system correlation key for one candidate publication
attempt. GitHub creates it first, Modal records it while running, and Hugging
Face staging uses it as the staging namespace.
"""

from __future__ import annotations

import hashlib
import json
import os
import re
from dataclasses import asdict, dataclass
from typing import Mapping


RUN_ID_ENV = "US_DATA_RUN_ID"
MODAL_APP_NAME_ENV = "US_DATA_MODAL_APP_NAME"
MODAL_ENVIRONMENT_ENV = "US_DATA_MODAL_ENVIRONMENT"
DEFAULT_MODAL_APP_PREFIX = "policyengine-us-data-pub"
DEFAULT_MODAL_ENVIRONMENT = "main"
DEFAULT_MAX_RESOURCE_NAME_LENGTH = 64


def _slugify(value: str) -> str:
    """Lower-case ``value`` and collapse unsafe characters into single dashes."""
    slug = re.sub(r"[^a-z0-9-]+", "-", value.lower())
    slug = re.sub(r"-+", "-", slug).strip("-")
    return slug


def _truncate_with_digest(value: str, max_length: int) -> str:
    """Truncate ``value`` to ``max_length``, keeping uniqueness via a SHA-1 tail.

    The digest suffix guarantees two long values that share a prefix do not
    collapse to the same truncated name.
    """
    if len(value) <= max_length:
        return value
    digest = hashlib.sha1(value.encode("utf-8")).hexdigest()[:8]
    # Guard against pathological max_length values smaller than the digest
    # suffix, which would otherwise produce a negative slice.
    head_length = max(max_length - len(digest) - 1, 0)
    return f"{value[:head_length].rstrip('-')}-{digest}"


def sanitize_run_id(value: str) -> str:
    """Return a Modal/HF-path-safe run ID.

    Raises:
        ValueError: if ``value`` contains no usable characters.
    """
    slug = _slugify(value)
    if not slug:
        raise ValueError("Run ID cannot be empty")
    return _truncate_with_digest(slug, DEFAULT_MAX_RESOURCE_NAME_LENGTH)


def build_run_id(
    *,
    github_run_id: str,
    github_run_attempt: str,
    github_sha: str,
) -> str:
    """Build a deterministic run ID from GitHub Actions identity.

    Raises:
        ValueError: if ``github_run_id`` is empty.
    """
    if not github_run_id:
        raise ValueError("github_run_id is required")
    attempt = github_run_attempt or "1"
    sha = (github_sha or "unknown")[:8]
    return sanitize_run_id(f"usdata-gha{github_run_id}-a{attempt}-{sha}")


def build_modal_resource_name(
    run_id: str,
    *,
    prefix: str = DEFAULT_MODAL_APP_PREFIX,
    max_length: int = DEFAULT_MAX_RESOURCE_NAME_LENGTH,
) -> str:
    """Build a safe Modal app or volume name from a run ID."""
    return _truncate_with_digest(
        _slugify(f"{prefix}-{sanitize_run_id(run_id)}"),
        max_length,
    )


def staging_prefix(run_id: str = "") -> str:
    """Return the HF staging path prefix, namespaced by run ID when present."""
    return f"staging/{run_id}" if run_id else "staging"


def github_run_url(env: Mapping[str, str]) -> str:
    """Build the GitHub Actions run URL from environment, or "" if unknown."""
    repository = env.get("GITHUB_REPOSITORY", "")
    run_id = env.get("GITHUB_RUN_ID", "")
    if not repository or not run_id:
        return ""
    server_url = env.get("GITHUB_SERVER_URL", "https://github.com")
    return f"{server_url}/{repository}/actions/runs/{run_id}"


def resolve_run_id(
    explicit: str = "",
    *,
    env: Mapping[str, str] | None = None,
) -> str:
    """Resolve the canonical run ID from an explicit value or GitHub context.

    Precedence: ``explicit`` > ``US_DATA_RUN_ID`` > ``RUN_ID`` > a deterministic
    ID derived from ``GITHUB_RUN_ID``/``GITHUB_RUN_ATTEMPT``/``GITHUB_SHA``.
    Returns "" when none is available.
    """
    # Only fall back to the process environment when no mapping was supplied;
    # an explicitly passed empty mapping must NOT leak os.environ (the old
    # ``env or os.environ`` idiom broke test isolation).
    if env is None:
        env = os.environ
    candidate = explicit or env.get(RUN_ID_ENV, "") or env.get("RUN_ID", "")
    if candidate:
        return sanitize_run_id(candidate)
    if env.get("GITHUB_RUN_ID"):
        return build_run_id(
            github_run_id=env.get("GITHUB_RUN_ID", ""),
            github_run_attempt=env.get("GITHUB_RUN_ATTEMPT", "1"),
            github_sha=env.get("GITHUB_SHA", ""),
        )
    return ""


@dataclass(frozen=True)
class RunContext:
    """Cross-system context for one publication run."""

    run_id: str
    modal_app_name: str
    modal_environment: str
    hf_staging_prefix: str
    github_run_url: str = ""
    github_repository: str = ""
    github_workflow: str = ""
    github_ref: str = ""
    github_ref_name: str = ""
    github_sha: str = ""
    github_run_id: str = ""
    github_run_attempt: str = ""
    pipeline_volume_name: str = ""
    staging_volume_name: str = ""
    checkpoint_volume_name: str = ""

    @classmethod
    def from_env(
        cls,
        *,
        run_id: str = "",
        modal_app_name: str = "",
        modal_environment: str = "",
        env: Mapping[str, str] | None = None,
        modal_app_prefix: str = DEFAULT_MODAL_APP_PREFIX,
    ) -> "RunContext":
        """Build a context from explicit overrides plus environment variables."""
        # See resolve_run_id: fall back to os.environ only when env is None,
        # so callers can pass {} for a fully isolated context.
        if env is None:
            env = os.environ
        resolved_run_id = resolve_run_id(run_id, env=env)
        resolved_modal_environment = (
            modal_environment
            or env.get(MODAL_ENVIRONMENT_ENV, "")
            or env.get("MODAL_ENVIRONMENT", "")
            or DEFAULT_MODAL_ENVIRONMENT
        )
        resolved_modal_app_name = (
            modal_app_name
            or env.get(MODAL_APP_NAME_ENV, "")
            or env.get("MODAL_APP_NAME", "")
            or (
                build_modal_resource_name(
                    resolved_run_id,
                    prefix=modal_app_prefix,
                )
                if resolved_run_id
                else ""
            )
        )
        return cls(
            run_id=resolved_run_id,
            modal_app_name=resolved_modal_app_name,
            modal_environment=resolved_modal_environment,
            hf_staging_prefix=staging_prefix(resolved_run_id),
            github_run_url=env.get("US_DATA_GITHUB_RUN_URL", "")
            or github_run_url(env),
            github_repository=env.get("GITHUB_REPOSITORY", ""),
            github_workflow=env.get("GITHUB_WORKFLOW", ""),
            github_ref=env.get("GITHUB_REF", ""),
            github_ref_name=env.get("GITHUB_REF_NAME", ""),
            github_sha=env.get("GITHUB_SHA", ""),
            github_run_id=env.get("GITHUB_RUN_ID", ""),
            github_run_attempt=env.get("GITHUB_RUN_ATTEMPT", ""),
            pipeline_volume_name=env.get("US_DATA_PIPELINE_VOLUME_NAME", ""),
            staging_volume_name=env.get("US_DATA_STAGING_VOLUME_NAME", ""),
            checkpoint_volume_name=env.get("US_DATA_CHECKPOINT_VOLUME_NAME", ""),
        )

    @classmethod
    def from_mapping(
        cls,
        data: Mapping[str, object] | None,
        *,
        env: Mapping[str, str] | None = None,
        run_id: str = "",
        modal_app_name: str = "",
        modal_environment: str = "",
    ) -> "RunContext":
        """Build a context from env, then overlay truthy values from ``data``.

        The legacy key ``publication_id`` is accepted as an alias for
        ``run_id``; the staging prefix is recomputed from the final run ID.
        """
        base = cls.from_env(
            run_id=run_id,
            modal_app_name=modal_app_name,
            modal_environment=modal_environment,
            env=env,
        )
        if not data:
            return base
        merged = asdict(base)
        for key, value in data.items():
            if key == "publication_id":
                key = "run_id"
            if key in merged and value:
                merged[key] = str(value)
        if merged.get("run_id"):
            merged["run_id"] = sanitize_run_id(str(merged["run_id"]))
        merged["hf_staging_prefix"] = staging_prefix(merged["run_id"])
        return cls(**merged)

    def to_dict(self) -> dict[str, str]:
        """Return only the populated fields as a plain dict."""
        return {
            key: value
            for key, value in asdict(self).items()
            if value not in ("", None)
        }

    def to_json(self) -> str:
        """Serialize populated fields as deterministic, newline-terminated JSON."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def export_env(self) -> dict[str, str]:
        """Return environment variables representing this context."""
        values = {
            RUN_ID_ENV: self.run_id,
            "RUN_ID": self.run_id,
            MODAL_APP_NAME_ENV: self.modal_app_name,
            "MODAL_APP_NAME": self.modal_app_name,
            MODAL_ENVIRONMENT_ENV: self.modal_environment,
            "MODAL_ENVIRONMENT": self.modal_environment,
            "US_DATA_HF_STAGING_PREFIX": self.hf_staging_prefix,
            "US_DATA_GITHUB_RUN_URL": self.github_run_url,
        }
        if self.pipeline_volume_name:
            values["US_DATA_PIPELINE_VOLUME_NAME"] = self.pipeline_volume_name
        if self.staging_volume_name:
            values["US_DATA_STAGING_VOLUME_NAME"] = self.staging_volume_name
        if self.checkpoint_volume_name:
            values["US_DATA_CHECKPOINT_VOLUME_NAME"] = self.checkpoint_volume_name
        return {key: value for key, value in values.items() if value}
"""Run-scoped execution manifests for pipeline steps.

Step manifests are execution records: they describe what a pipeline step
read, wrote, reused, invalidated, and failed for one run ID. They are kept
separate from release manifests, which remain the publication contract.
"""

from __future__ import annotations

import hashlib
import json
from dataclasses import asdict, dataclass, field, replace
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence


STEP_MANIFEST_SCHEMA_VERSION = "1"

COMPLETED_STATUSES = frozenset({"completed", "reused", "partially_reused"})
VALID_STATUSES = frozenset(
    {
        "pending",
        "running",
        "completed",
        "failed",
        "reused",
        "partially_reused",
    }
)
VALID_REUSE_DECISIONS = frozenset(
    {
        "computed",
        "reused",
        "partially_reused",
        "invalidated",
        "failed",
        "not_applicable",
    }
)


def utc_now() -> str:
    """Return an ISO-8601 UTC timestamp."""
    return datetime.now(timezone.utc).isoformat()


def canonical_json_dumps(payload: Mapping[str, Any]) -> str:
    """Serialize manifest JSON deterministically (sorted keys, trailing newline)."""
    return json.dumps(payload, indent=2, sort_keys=True) + "\n"


def _drop_none(value: Any) -> Any:
    """Recursively drop dict entries whose value is None; lists are preserved."""
    if isinstance(value, dict):
        return {k: _drop_none(v) for k, v in value.items() if v is not None}
    if isinstance(value, list):
        return [_drop_none(v) for v in value]
    return value


def sha256_file(path: Path) -> str:
    """Compute a file SHA-256 digest as lowercase hex, reading 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _manifest_path(path: Path, *, base_dir: Path | None = None) -> str:
    """Render ``path`` relative to ``base_dir`` when possible, else as given."""
    if base_dir is None:
        return str(path)
    try:
        return str(path.relative_to(base_dir))
    except ValueError:
        return str(path)


@dataclass(frozen=True)
class ArtifactReference:
    """Durable artifact reference recorded in a step manifest."""

    path: str
    size_bytes: int
    sha256: str
    role: str = "output"
    media_type: str | None = None

    @classmethod
    def from_path(
        cls,
        path: str | Path,
        *,
        role: str = "output",
        base_dir: str | Path | None = None,
        manifest_path: str | None = None,
        media_type: str | None = None,
    ) -> "ArtifactReference":
        """Build a reference from a real file, hashing its contents.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if ``path`` is not a regular file.
        """
        artifact_path = Path(path)
        if not artifact_path.exists():
            raise FileNotFoundError(
                f"Cannot record missing artifact: {artifact_path}"
            )
        if not artifact_path.is_file():
            raise ValueError(
                f"Step manifest artifacts must be files: {artifact_path}"
            )
        base = Path(base_dir) if base_dir is not None else None
        return cls(
            path=manifest_path or _manifest_path(artifact_path, base_dir=base),
            size_bytes=artifact_path.stat().st_size,
            sha256=sha256_file(artifact_path),
            role=role,
            media_type=media_type,
        )

    def to_dict(self) -> dict[str, Any]:
        return _drop_none(asdict(self))

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ArtifactReference":
        return cls(
            path=str(data["path"]),
            size_bytes=int(data["size_bytes"]),
            sha256=str(data["sha256"]),
            role=str(data.get("role", "output")),
            media_type=data.get("media_type"),
        )


@dataclass(frozen=True)
class ReuseMeasurement:
    """Measured reuse/recompute counts for one step."""

    expected_outputs: int = 0
    valid_reused_outputs: int = 0
    recomputed_outputs: int = 0
    invalid_outputs: int = 0
    saved_duration_s: float | None = None

    def to_dict(self) -> dict[str, Any]:
        return _drop_none(asdict(self))

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None) -> "ReuseMeasurement":
        data = data or {}
        return cls(
            expected_outputs=int(data.get("expected_outputs", 0)),
            valid_reused_outputs=int(data.get("valid_reused_outputs", 0)),
            recomputed_outputs=int(data.get("recomputed_outputs", 0)),
            invalid_outputs=int(data.get("invalid_outputs", 0)),
            saved_duration_s=data.get("saved_duration_s"),
        )


@dataclass(frozen=True)
class OutputValidation:
    """Result of validating manifest-declared outputs."""

    valid: bool
    reason: str
    missing_outputs: tuple[str, ...] = ()
    checksum_mismatches: tuple[str, ...] = ()


@dataclass(frozen=True)
class StepReuseDecision:
    """Manifest-backed decision about whether a step can be reused."""

    reusable: bool
    reason: str
    manifest: "StepManifest | None" = None
    validation: OutputValidation | None = None


@dataclass
class StepManifest:
    """Execution manifest for one meaningful pipeline step."""

    run_id: str
    step_id: str
    status: str
    attempt: int
    started_at: str
    completed_at: str | None = None
    duration_s: float | None = None
    branch: str | None = None
    sha: str | None = None
    version: str | None = None
    modal_app_name: str | None = None
    modal_environment: str | None = None
    hf_staging_prefix: str | None = None
    scope: str | None = None
    modal_app_id: str | None = None
    modal_call_id: str | None = None
    parameters: dict[str, Any] = field(default_factory=dict)
    input_identities: dict[str, Any] = field(default_factory=dict)
    outputs: list[ArtifactReference] = field(default_factory=list)
    diagnostics: list[ArtifactReference] = field(default_factory=list)
    reuse_decision: str = "not_applicable"
    reuse_reason: str | None = None
    reuse_measurement: ReuseMeasurement = field(default_factory=ReuseMeasurement)
    error: dict[str, Any] | None = None
    schema_version: str = STEP_MANIFEST_SCHEMA_VERSION

    def __post_init__(self) -> None:
        # Validate on every construction, including dataclasses.replace().
        if self.status not in VALID_STATUSES:
            raise ValueError(f"Invalid step manifest status: {self.status}")
        if self.reuse_decision not in VALID_REUSE_DECISIONS:
            raise ValueError(
                f"Invalid step manifest reuse decision: {self.reuse_decision}"
            )

    def _finish_timing(self, completed_at: str | None) -> tuple[str, float]:
        """Resolve the completion timestamp and elapsed duration in seconds."""
        completed = completed_at or utc_now()
        started = datetime.fromisoformat(self.started_at)
        ended = datetime.fromisoformat(completed)
        return completed, round((ended - started).total_seconds(), 1)

    def complete(
        self,
        *,
        completed_at: str | None = None,
        status: str = "completed",
        outputs: Sequence[ArtifactReference] | None = None,
        diagnostics: Sequence[ArtifactReference] | None = None,
        reuse_decision: str = "computed",
        reuse_reason: str | None = None,
        reuse_measurement: ReuseMeasurement | None = None,
    ) -> "StepManifest":
        """Return a completed copy of this manifest (self is unchanged)."""
        completed, duration = self._finish_timing(completed_at)
        # dataclasses.replace keeps all identity fields; only completion
        # state is overridden. error is reset explicitly: a completed
        # manifest must not carry a stale failure record.
        return replace(
            self,
            status=status,
            completed_at=completed,
            duration_s=duration,
            outputs=list(outputs if outputs is not None else self.outputs),
            diagnostics=list(
                diagnostics if diagnostics is not None else self.diagnostics
            ),
            reuse_decision=reuse_decision,
            reuse_reason=reuse_reason,
            reuse_measurement=reuse_measurement or self.reuse_measurement,
            error=None,
        )

    def fail(
        self,
        exc: BaseException,
        *,
        completed_at: str | None = None,
    ) -> "StepManifest":
        """Return a failed copy of this manifest recording ``exc``."""
        completed, duration = self._finish_timing(completed_at)
        return replace(
            self,
            status="failed",
            completed_at=completed,
            duration_s=duration,
            reuse_decision="failed",
            reuse_reason="step_failed",
            error={
                "type": type(exc).__name__,
                "message": str(exc),
            },
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict, dropping None-valued fields."""
        payload = {
            "schema_version": self.schema_version,
            "run_id": self.run_id,
            "step_id": self.step_id,
            "scope": self.scope,
            "status": self.status,
            "attempt": self.attempt,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration_s": self.duration_s,
            "branch": self.branch,
            "sha": self.sha,
            "version": self.version,
            "modal_app_name": self.modal_app_name,
            "modal_environment": self.modal_environment,
            "hf_staging_prefix": self.hf_staging_prefix,
            "modal_app_id": self.modal_app_id,
            "modal_call_id": self.modal_call_id,
            "parameters": self.parameters,
            "input_identities": self.input_identities,
            "outputs": [artifact.to_dict() for artifact in self.outputs],
            "diagnostics": [artifact.to_dict() for artifact in self.diagnostics],
            "reuse_decision": self.reuse_decision,
            "reuse_reason": self.reuse_reason,
            "reuse_measurement": self.reuse_measurement.to_dict(),
            "error": self.error,
        }
        return _drop_none(payload)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "StepManifest":
        """Deserialize a manifest; missing optional keys fall back to defaults."""
        return cls(
            schema_version=str(
                data.get("schema_version", STEP_MANIFEST_SCHEMA_VERSION)
            ),
            run_id=str(data["run_id"]),
            step_id=str(data["step_id"]),
            scope=data.get("scope"),
            status=str(data["status"]),
            attempt=int(data["attempt"]),
            started_at=str(data["started_at"]),
            completed_at=data.get("completed_at"),
            duration_s=data.get("duration_s"),
            branch=data.get("branch"),
            sha=data.get("sha"),
            version=data.get("version"),
            modal_app_name=data.get("modal_app_name"),
            modal_environment=data.get("modal_environment"),
            hf_staging_prefix=data.get("hf_staging_prefix"),
            modal_app_id=data.get("modal_app_id"),
            modal_call_id=data.get("modal_call_id"),
            parameters=dict(data.get("parameters", {})),
            input_identities=dict(data.get("input_identities", {})),
            outputs=[
                ArtifactReference.from_dict(item)
                for item in data.get("outputs", [])
            ],
            diagnostics=[
                ArtifactReference.from_dict(item)
                for item in data.get("diagnostics", [])
            ],
            reuse_decision=str(data.get("reuse_decision", "not_applicable")),
            reuse_reason=data.get("reuse_reason"),
            reuse_measurement=ReuseMeasurement.from_dict(
                data.get("reuse_measurement")
            ),
            error=data.get("error"),
        )

    def to_json(self) -> str:
        return canonical_json_dumps(self.to_dict())


@dataclass
class RunManifest:
    """Run-level execution ledger for step manifests."""

    run_id: str
    branch: str
    sha: str
    version: str
    status: str
    started_at: str
    known_step_ids: list[str]
    run_context: dict[str, Any] = field(default_factory=dict)
    modal_app_name: str | None = None
    modal_environment: str | None = None
    hf_staging_prefix: str | None = None
    updated_at: str | None = None
    completed_at: str | None = None
    resume_history: list[dict[str, Any]] = field(default_factory=list)
    error: str | None = None
    schema_version: str = STEP_MANIFEST_SCHEMA_VERSION

    def to_dict(self) -> dict[str, Any]:
        return _drop_none(asdict(self))

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "RunManifest":
        """Deserialize; the legacy ``publication_context`` key is accepted."""
        return cls(
            schema_version=str(
                data.get("schema_version", STEP_MANIFEST_SCHEMA_VERSION)
            ),
            run_id=str(data["run_id"]),
            branch=str(data["branch"]),
            sha=str(data["sha"]),
            version=str(data["version"]),
            status=str(data["status"]),
            started_at=str(data["started_at"]),
            run_context=dict(
                data.get("run_context") or data.get("publication_context", {})
            ),
            modal_app_name=data.get("modal_app_name"),
            modal_environment=data.get("modal_environment"),
            hf_staging_prefix=data.get("hf_staging_prefix"),
            updated_at=data.get("updated_at"),
            completed_at=data.get("completed_at"),
            known_step_ids=list(data.get("known_step_ids", [])),
            resume_history=list(data.get("resume_history", [])),
            error=data.get("error"),
        )

    def to_json(self) -> str:
        return canonical_json_dumps(self.to_dict())


def run_manifest_path(run_dir: str | Path) -> Path:
    """Path of the run-level manifest inside ``run_dir``."""
    return Path(run_dir) / "run_manifest.json"


def step_manifest_dir(run_dir: str | Path) -> Path:
    """Directory holding per-step manifests inside ``run_dir``."""
    return Path(run_dir) / "steps"


def step_manifest_path(run_dir: str | Path, step_id: str) -> Path:
    """Path of one step's manifest inside ``run_dir``."""
    return step_manifest_dir(run_dir) / f"{step_id}.json"


def write_run_manifest(path: str | Path, manifest: RunManifest) -> None:
    """Write a run manifest, creating parent directories as needed."""
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(manifest.to_json())


def read_run_manifest(path: str | Path) -> RunManifest:
    return RunManifest.from_dict(json.loads(Path(path).read_text()))


def write_step_manifest(path: str | Path, manifest: StepManifest) -> None:
    """Write a step manifest, creating parent directories as needed."""
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(manifest.to_json())


def read_step_manifest(path: str | Path) -> StepManifest:
    return StepManifest.from_dict(json.loads(Path(path).read_text()))


def collect_artifacts(
    paths: Iterable[str | Path],
    *,
    role: str = "output",
    base_dir: str | Path | None = None,
    missing_ok: bool = False,
) -> list[ArtifactReference]:
    """Build references for ``paths``; skip or raise on missing files.

    Raises:
        FileNotFoundError: if a path is missing and ``missing_ok`` is False.
    """
    artifacts: list[ArtifactReference] = []
    for path in paths:
        artifact_path = Path(path)
        if not artifact_path.exists():
            if missing_ok:
                continue
            raise FileNotFoundError(f"Expected artifact does not exist: {path}")
        artifacts.append(
            ArtifactReference.from_path(
                artifact_path,
                role=role,
                base_dir=base_dir,
            )
        )
    return artifacts


def collect_directory_artifacts(
    root: str | Path,
    *,
    patterns: Sequence[str] = ("*",),
    role: str = "output",
    base_dir: str | Path | None = None,
) -> list[ArtifactReference]:
    """Collect file artifacts under ``root`` matching any of ``patterns``."""
    root_path = Path(root)
    if not root_path.exists():
        return []
    paths: list[Path] = []
    for pattern in patterns:
        paths.extend(path for path in root_path.glob(pattern) if path.is_file())
    # Sort for deterministic manifest contents regardless of glob order.
    return [
        ArtifactReference.from_path(path, role=role, base_dir=base_dir)
        for path in sorted(paths)
    ]
role=role, base_dir=base_dir) + for path in sorted(set(paths)) + ] + + +def _resolve_artifact_path(path: str, *, root: str | Path | None = None) -> Path: + artifact_path = Path(path) + if artifact_path.is_absolute() or root is None: + return artifact_path + return Path(root) / artifact_path + + +def _contains_expected_values( + actual: Mapping[str, Any], + expected: Mapping[str, Any], +) -> bool: + """Return True when every expected key/value is present in actual.""" + return all(actual.get(key) == value for key, value in expected.items()) + + +def validate_step_outputs( + manifest: StepManifest, + *, + root: str | Path | None = None, +) -> OutputValidation: + missing: list[str] = [] + mismatches: list[str] = [] + + for artifact in manifest.outputs: + path = _resolve_artifact_path(artifact.path, root=root) + if not path.exists(): + missing.append(artifact.path) + continue + actual_sha = sha256_file(path) + if actual_sha != artifact.sha256: + mismatches.append(artifact.path) + + if missing: + return OutputValidation(False, "missing_output", tuple(missing), ()) + if mismatches: + return OutputValidation(False, "checksum_mismatch", (), tuple(mismatches)) + return OutputValidation(True, "valid") + + +def evaluate_step_reuse( + manifest_path_value: str | Path, + *, + expected_input_identities: Mapping[str, Any] | None = None, + expected_parameters: Mapping[str, Any] | None = None, + output_root: str | Path | None = None, +) -> StepReuseDecision: + path = Path(manifest_path_value) + if not path.exists(): + return StepReuseDecision(False, "missing_manifest") + + manifest = read_step_manifest(path) + if manifest.status not in COMPLETED_STATUSES: + return StepReuseDecision(False, "incomplete_status", manifest=manifest) + + if expected_input_identities is not None and not _contains_expected_values( + manifest.input_identities, dict(expected_input_identities) + ): + return StepReuseDecision(False, "input_changed", manifest=manifest) + + if expected_parameters is not None and 
manifest.parameters != dict( + expected_parameters + ): + return StepReuseDecision(False, "parameter_changed", manifest=manifest) + + validation = validate_step_outputs(manifest, root=output_root) + if not validation.valid: + return StepReuseDecision( + False, + validation.reason, + manifest=manifest, + validation=validation, + ) + + return StepReuseDecision( + True, "prior_success", manifest=manifest, validation=validation + ) + + +def completed_validated_outputs( + run_dir: str | Path, + *, + output_root: str | Path | None = None, + step_ids: Iterable[str] | None = None, +) -> list[ArtifactReference]: + """Return validated outputs from completed step manifests.""" + root = Path(run_dir) + wanted = set(step_ids) if step_ids is not None else None + outputs: list[ArtifactReference] = [] + for manifest_file in sorted(step_manifest_dir(root).glob("*.json")): + manifest = read_step_manifest(manifest_file) + if wanted is not None and manifest.step_id not in wanted: + continue + if manifest.status not in COMPLETED_STATUSES: + continue + validation = validate_step_outputs(manifest, root=output_root) + if not validation.valid: + continue + outputs.extend(manifest.outputs) + return outputs diff --git a/policyengine_us_data/utils/version_manifest.py b/policyengine_us_data/utils/version_manifest.py index c5479307a..992bff54a 100644 --- a/policyengine_us_data/utils/version_manifest.py +++ b/policyengine_us_data/utils/version_manifest.py @@ -93,7 +93,7 @@ class VersionManifest: gcs: GCSVersionInfo special_operation: Optional[str] = None roll_back_version: Optional[str] = None - pipeline_run_id: Optional[str] = None + run_id: Optional[str] = None diagnostics_path: Optional[str] = None policyengine_us: Optional[PolicyEngineUSBuildInfo] = None @@ -108,8 +108,8 @@ def to_dict(self) -> dict[str, Any]: result["special_operation"] = self.special_operation if self.roll_back_version is not None: result["roll_back_version"] = self.roll_back_version - if self.pipeline_run_id is not None: 
- result["pipeline_run_id"] = self.pipeline_run_id + if self.run_id is not None: + result["run_id"] = self.run_id if self.diagnostics_path is not None: result["diagnostics_path"] = self.diagnostics_path if self.policyengine_us is not None: @@ -126,7 +126,7 @@ def from_dict(cls, data: dict[str, Any]) -> "VersionManifest": gcs=GCSVersionInfo.from_dict(data["gcs"]), special_operation=data.get("special_operation"), roll_back_version=data.get("roll_back_version"), - pipeline_run_id=data.get("pipeline_run_id"), + run_id=data.get("run_id") or data.get("pipeline_run_id"), diagnostics_path=data.get("diagnostics_path"), policyengine_us=( PolicyEngineUSBuildInfo.from_dict(data["policyengine_us"]) diff --git a/tests/unit/test_modal_data_build.py b/tests/unit/test_modal_data_build.py index 161fb9998..ded10d69e 100644 --- a/tests/unit/test_modal_data_build.py +++ b/tests/unit/test_modal_data_build.py @@ -35,6 +35,32 @@ def decorator(func): return importlib.import_module("modal_app.data_build") +def test_checkpoint_stats_are_per_instance(): + data_build = _load_data_build_module() + first = data_build.CheckpointStats() + second = data_build.CheckpointStats() + + first.record( + expected_outputs=3, + valid_reused_outputs=1, + recomputed_outputs=2, + invalid_outputs=2, + ) + + assert first.snapshot() == { + "expected_outputs": 3, + "valid_reused_outputs": 1, + "recomputed_outputs": 2, + "invalid_outputs": 2, + } + assert second.snapshot() == { + "expected_outputs": 0, + "valid_reused_outputs": 0, + "recomputed_outputs": 0, + "invalid_outputs": 0, + } + + def test_validate_and_maybe_upload_datasets_validates_before_upload(monkeypatch): data_build = _load_data_build_module() calls = [] @@ -142,6 +168,7 @@ def fake_run_script_with_checkpoint( args=None, env=None, log_file=None, + checkpoint_stats=None, ): calls.append( ( @@ -152,6 +179,7 @@ def fake_run_script_with_checkpoint( args, env, log_file, + checkpoint_stats, ) ) return script_path @@ -179,6 +207,7 @@ def 
fake_run_script_with_checkpoint( None, env, log_file, + None, ), ( data_build.PUF_BUILD_SCRIPT, @@ -188,5 +217,6 @@ def fake_run_script_with_checkpoint( None, env, log_file, + None, ), ] diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 2d126e71f..8c634e043 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -11,13 +11,14 @@ modal = pytest.importorskip("modal") from modal_app.pipeline import ( # noqa: E402 - RunMetadata, _build_diagnostics_upload_script, - _step_completed, - _record_step, - generate_run_id, - write_run_meta, +) +from modal_app.step_manifests.runtime import ( # noqa: E402 + RunMetadata, read_run_meta, + record_step, + step_completed, + write_run_meta, ) @@ -160,34 +161,6 @@ def test_step_timings_default_empty(self): assert meta.step_timings == {} -# -- generate_run_id tests ------------------------------------- - - -class TestGenerateRunId: - def test_format(self): - run_id = generate_run_id("1.72.3", "abc12345deadbeef") - - parts = run_id.split("_") - assert parts[0] == "1.72.3" - assert parts[1] == "abc12345" - assert len(parts) == 4 # version_sha_date_time - - def test_sha_truncated_to_8(self): - run_id = generate_run_id("1.0.0", "abcdef1234567890") - sha_part = run_id.split("_")[1] - assert sha_part == "abcdef12" - assert len(sha_part) == 8 - - def test_unique_ids(self): - id1 = generate_run_id("1.0.0", "abc123") - time.sleep(0.01) - id2 = generate_run_id("1.0.0", "abc123") - # Timestamps should differ (or at least - # the function doesn't reuse) - assert isinstance(id1, str) - assert isinstance(id2, str) - - # -- _step_completed tests ------------------------------------ @@ -207,7 +180,7 @@ def test_completed_step(self): } }, ) - assert _step_completed(meta, "build_datasets") + assert step_completed(meta, "build_datasets") def test_incomplete_step(self): meta = RunMetadata( @@ -224,7 +197,7 @@ def test_incomplete_step(self): } }, ) - assert not _step_completed(meta, "build_datasets") + assert 
not step_completed(meta, "build_datasets") def test_missing_step(self): meta = RunMetadata( @@ -235,14 +208,14 @@ def test_missing_step(self): start_time="now", status="running", ) - assert not _step_completed(meta, "build_datasets") + assert not step_completed(meta, "build_datasets") # -- _record_step tests ---------------------------------------- class TestRecordStep: - def test_records_timing(self): + def test_records_timing(self, tmp_path): meta = RunMetadata( run_id="test", branch="main", @@ -254,16 +227,20 @@ def test_records_timing(self): mock_vol = MagicMock() start = time.time() - 5.0 - with patch("modal_app.pipeline.write_run_meta"): - _record_step(meta, "build_datasets", start, mock_vol) + with ( + patch("modal_app.step_manifests.runtime.RUNS_DIR", str(tmp_path / "runs")), + patch("modal_app.step_manifests.runtime.write_run_meta"), + ): + record_step(meta, "build_datasets", start, mock_vol) timing = meta.step_timings["build_datasets"] assert timing["status"] == "completed" assert timing["duration_s"] >= 5.0 assert "start" in timing assert "end" in timing + assert (tmp_path / "runs" / "test" / "steps" / "build_datasets.json").exists() - def test_records_custom_status(self): + def test_records_custom_status(self, tmp_path): meta = RunMetadata( run_id="test", branch="main", @@ -274,8 +251,11 @@ def test_records_custom_status(self): ) mock_vol = MagicMock() - with patch("modal_app.pipeline.write_run_meta"): - _record_step( + with ( + patch("modal_app.step_manifests.runtime.RUNS_DIR", str(tmp_path / "runs")), + patch("modal_app.step_manifests.runtime.write_run_meta"), + ): + record_step( meta, "build_datasets", time.time(), @@ -304,7 +284,7 @@ def test_write_and_read(self, tmp_path): runs_dir = tmp_path / "runs" with patch( - "modal_app.pipeline.RUNS_DIR", + "modal_app.step_manifests.runtime.RUNS_DIR", str(runs_dir), ): write_run_meta(meta, mock_vol) @@ -323,7 +303,7 @@ def test_read_nonexistent_raises(self): mock_vol = MagicMock() with patch( - 
"modal_app.pipeline.RUNS_DIR", + "modal_app.step_manifests.runtime.RUNS_DIR", "/nonexistent", ): with pytest.raises(FileNotFoundError): diff --git a/tests/unit/test_release_manifest.py b/tests/unit/test_release_manifest.py index 1938a0cda..3a4a3bd0b 100644 --- a/tests/unit/test_release_manifest.py +++ b/tests/unit/test_release_manifest.py @@ -158,6 +158,31 @@ def test_build_release_manifest_merges_existing_release_same_version(tmp_path): assert manifest["artifacts"]["districts/NC-01"]["sha256"] == _sha256(district_bytes) +def test_build_release_manifest_records_run_context(tmp_path): + dataset_path = _write_file( + tmp_path / "enhanced_cps_2024.h5", + b"national-dataset", + ) + + manifest = build_release_manifest( + files_with_repo_paths=[(dataset_path, "enhanced_cps_2024.h5")], + version="1.73.0", + repo_id="policyengine/policyengine-us-data", + run_context={ + "run_id": "usdata-gha123-a1-abcdef12", + "modal_app_name": "policyengine-us-data-pub-usdata-gha123-a1-abcdef12", + "hf_staging_prefix": "staging/usdata-gha123-a1-abcdef12", + }, + created_at="2026-04-10T12:00:00Z", + ) + + assert manifest["build"]["run"] == { + "run_id": "usdata-gha123-a1-abcdef12", + "modal_app_name": "policyengine-us-data-pub-usdata-gha123-a1-abcdef12", + "hf_staging_prefix": "staging/usdata-gha123-a1-abcdef12", + } + + def test_load_release_manifest_from_hf_uses_explicit_revision_when_requested(tmp_path): manifest_path = _write_file( tmp_path / "release_manifest.json", diff --git a/tests/unit/test_run_context.py b/tests/unit/test_run_context.py new file mode 100644 index 000000000..1f86c8b5f --- /dev/null +++ b/tests/unit/test_run_context.py @@ -0,0 +1,90 @@ +from policyengine_us_data.utils.run_context import ( + RunContext, + build_modal_resource_name, + build_run_id, + resolve_run_id, + sanitize_run_id, + staging_prefix, +) + + +def test_run_id_from_github_identity() -> None: + assert ( + build_run_id( + github_run_id="123456789", + github_run_attempt="2", + github_sha="abcdef123456", 
+ ) + == "usdata-gha123456789-a2-abcdef12" + ) + + +def test_run_id_sanitizes_for_modal_and_hf_paths() -> None: + assert sanitize_run_id("Feature/Some PR #12!") == "feature-some-pr-12" + + +def test_modal_resource_name_uses_safe_prefix_and_truncates() -> None: + run_id = "usdata-gha123456789-a1-" + ("a" * 80) + + name = build_modal_resource_name(run_id, prefix="policyengine-us-data-pub") + + assert name.startswith("policyengine-us-data-pub-usdata-gha123456789-a1") + assert len(name) <= 64 + + +def test_resolve_run_id_prefers_explicit_value() -> None: + env = { + "US_DATA_RUN_ID": "from-env", + "GITHUB_RUN_ID": "123", + "GITHUB_RUN_ATTEMPT": "1", + "GITHUB_SHA": "abcdef12", + } + + assert resolve_run_id("Explicit Value", env=env) == "explicit-value" + + +def test_run_context_from_env_records_cross_system_identity() -> None: + env = { + "GITHUB_SERVER_URL": "https://github.com", + "GITHUB_REPOSITORY": "PolicyEngine/policyengine-us-data", + "GITHUB_WORKFLOW": "Run Pipeline", + "GITHUB_REF": "refs/heads/main", + "GITHUB_REF_NAME": "main", + "GITHUB_SHA": "abcdef123456", + "GITHUB_RUN_ID": "123456789", + "GITHUB_RUN_ATTEMPT": "1", + "US_DATA_PIPELINE_VOLUME_NAME": "pipeline-artifacts-test", + "US_DATA_STAGING_VOLUME_NAME": "local-area-staging-test", + "US_DATA_CHECKPOINT_VOLUME_NAME": "data-build-checkpoints-test", + } + + context = RunContext.from_env(env=env) + + assert context.run_id == "usdata-gha123456789-a1-abcdef12" + assert context.modal_app_name == ( + "policyengine-us-data-pub-usdata-gha123456789-a1-abcdef12" + ) + assert context.modal_environment == "main" + assert context.hf_staging_prefix == staging_prefix(context.run_id) + assert context.github_run_url == ( + "https://github.com/PolicyEngine/policyengine-us-data/actions/runs/123456789" + ) + assert context.pipeline_volume_name == "pipeline-artifacts-test" + assert context.staging_volume_name == "local-area-staging-test" + assert context.checkpoint_volume_name == "data-build-checkpoints-test" + + +def 
test_run_context_export_env_includes_modal_and_hf_values() -> None: + context = RunContext.from_env( + env={"US_DATA_RUN_ID": "run-123"}, + modal_app_name="policyengine-us-data-pub-run-123", + modal_environment="main", + ) + + exported = context.export_env() + + assert exported["US_DATA_RUN_ID"] == "run-123" + assert exported["RUN_ID"] == "run-123" + assert exported["MODAL_APP_NAME"] == "policyengine-us-data-pub-run-123" + assert exported["MODAL_ENVIRONMENT"] == "main" + assert exported["US_DATA_HF_STAGING_PREFIX"] == "staging/run-123" diff --git a/tests/unit/test_step_manifest.py b/tests/unit/test_step_manifest.py new file mode 100644 index 000000000..f3671d5d9 --- /dev/null +++ b/tests/unit/test_step_manifest.py @@ -0,0 +1,232 @@ +import json + +from policyengine_us_data.utils.step_manifest import ( + ArtifactReference, + ReuseMeasurement, + RunManifest, + StepManifest, + completed_validated_outputs, + evaluate_step_reuse, + read_step_manifest, + run_manifest_path, + step_manifest_path, + validate_step_outputs, + write_run_manifest, + write_step_manifest, +) + + +def _write(path, content: bytes): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(content) + return path + + +def test_step_manifest_serialization_is_deterministic(tmp_path): + output = _write(tmp_path / "artifacts" / "out.h5", b"dataset") + manifest = StepManifest( + run_id="run-1", + step_id="04_build_h5_regional", + scope="regional", + status="completed", + attempt=1, + started_at="2026-04-30T12:00:00+00:00", + completed_at="2026-04-30T12:00:02+00:00", + duration_s=2.0, + branch="main", + sha="abc123", + version="1.0.0", + parameters={"n_clones": 1, "validate": True}, + input_identities={"h5_scope_fingerprint": "abc123fingerprint"}, + outputs=[ArtifactReference.from_path(output, base_dir=tmp_path)], + reuse_decision="computed", + ) + + first = manifest.to_json() + second = manifest.to_json() + + assert first == second + assert json.loads(first)["input_identities"] == { + 
"h5_scope_fingerprint": "abc123fingerprint" + } + + +def test_evaluate_step_reuse_requires_matching_inputs_parameters_and_outputs(tmp_path): + output = _write(tmp_path / "out.h5", b"dataset") + manifest = StepManifest( + run_id="run-1", + step_id="01_build_datasets", + status="completed", + attempt=1, + started_at="2026-04-30T12:00:00+00:00", + completed_at="2026-04-30T12:00:01+00:00", + parameters={"branch": "main"}, + input_identities={"sha": "abc"}, + outputs=[ArtifactReference.from_path(output, base_dir=tmp_path)], + reuse_decision="computed", + ) + manifest_path = tmp_path / "step.json" + write_step_manifest(manifest_path, manifest) + + decision = evaluate_step_reuse( + manifest_path, + expected_input_identities={"sha": "abc"}, + expected_parameters={"branch": "main"}, + output_root=tmp_path, + ) + + assert decision.reusable is True + assert decision.reason == "prior_success" + + +def test_evaluate_step_reuse_allows_derived_input_identity_fields(tmp_path): + output = _write(tmp_path / "out.h5", b"dataset") + manifest = StepManifest( + run_id="run-1", + step_id="04_build_h5_regional", + status="completed", + attempt=1, + started_at="2026-04-30T12:00:00+00:00", + parameters={"n_clones": 1}, + input_identities={ + "weights": {"sha256": "abc"}, + "h5_scope_fingerprint": "derived-after-run", + }, + outputs=[ArtifactReference.from_path(output, base_dir=tmp_path)], + reuse_decision="computed", + ) + manifest_path = tmp_path / "step.json" + write_step_manifest(manifest_path, manifest) + + decision = evaluate_step_reuse( + manifest_path, + expected_input_identities={"weights": {"sha256": "abc"}}, + expected_parameters={"n_clones": 1}, + output_root=tmp_path, + ) + + assert decision.reusable is True + + +def test_evaluate_step_reuse_recomputes_when_output_checksum_changes(tmp_path): + output = _write(tmp_path / "out.h5", b"dataset") + manifest = StepManifest( + run_id="run-1", + step_id="01_build_datasets", + status="completed", + attempt=1, + 
started_at="2026-04-30T12:00:00+00:00", + outputs=[ArtifactReference.from_path(output, base_dir=tmp_path)], + reuse_decision="computed", + ) + manifest_path = tmp_path / "step.json" + write_step_manifest(manifest_path, manifest) + output.write_bytes(b"changed") + + decision = evaluate_step_reuse(manifest_path, output_root=tmp_path) + + assert decision.reusable is False + assert decision.reason == "checksum_mismatch" + assert decision.validation.checksum_mismatches == ("out.h5",) + + +def test_validate_step_outputs_reports_missing_files(tmp_path): + output = _write(tmp_path / "out.h5", b"dataset") + artifact = ArtifactReference.from_path(output, base_dir=tmp_path) + output.unlink() + manifest = StepManifest( + run_id="run-1", + step_id="01_build_datasets", + status="completed", + attempt=1, + started_at="2026-04-30T12:00:00+00:00", + outputs=[artifact], + reuse_decision="computed", + ) + + validation = validate_step_outputs(manifest, root=tmp_path) + + assert validation.valid is False + assert validation.reason == "missing_output" + assert validation.missing_outputs == ("out.h5",) + + +def test_partial_h5_reuse_counts_are_manifest_fields(tmp_path): + output = _write(tmp_path / "staging" / "run-1" / "districts" / "NC-01.h5", b"h5") + manifest = StepManifest( + run_id="run-1", + step_id="04_build_h5_regional", + scope="regional", + status="partially_reused", + attempt=2, + started_at="2026-04-30T12:00:00+00:00", + outputs=[ArtifactReference.from_path(output)], + reuse_decision="partially_reused", + reuse_reason="prior_success", + reuse_measurement=ReuseMeasurement( + expected_outputs=3, + valid_reused_outputs=1, + recomputed_outputs=2, + invalid_outputs=0, + ), + ) + + data = manifest.to_dict() + + assert data["reuse_measurement"] == { + "expected_outputs": 3, + "valid_reused_outputs": 1, + "recomputed_outputs": 2, + "invalid_outputs": 0, + } + + +def test_completed_validated_outputs_reads_release_candidates_from_steps(tmp_path): + run_dir = tmp_path / "runs" / 
"run-1" + output = _write(tmp_path / "staging" / "run-1" / "states" / "NC.h5", b"h5") + stale = _write(tmp_path / "staging" / "run-1" / "states" / "SC.h5", b"stale") + write_run_manifest( + run_manifest_path(run_dir), + RunManifest( + run_id="run-1", + branch="main", + sha="abc", + version="1.0.0", + status="completed", + started_at="2026-04-30T12:00:00+00:00", + known_step_ids=["04_build_h5_regional"], + ), + ) + write_step_manifest( + step_manifest_path(run_dir, "04_build_h5_regional"), + StepManifest( + run_id="run-1", + step_id="04_build_h5_regional", + status="completed", + attempt=1, + started_at="2026-04-30T12:00:00+00:00", + outputs=[ArtifactReference.from_path(output)], + reuse_decision="computed", + ), + ) + write_step_manifest( + step_manifest_path(run_dir, "04_build_h5_stale"), + StepManifest( + run_id="run-1", + step_id="04_build_h5_stale", + status="completed", + attempt=1, + started_at="2026-04-30T12:00:00+00:00", + outputs=[ArtifactReference.from_path(stale)], + reuse_decision="computed", + ), + ) + stale.write_bytes(b"changed") + + outputs = completed_validated_outputs(run_dir) + + assert [artifact.path for artifact in outputs] == [str(output)] + assert ( + read_step_manifest(step_manifest_path(run_dir, "04_build_h5_regional")).step_id + == "04_build_h5_regional" + ) diff --git a/tests/unit/utils/test_data_upload.py b/tests/unit/utils/test_data_upload.py index 97b5a9dc4..8d39b8155 100644 --- a/tests/unit/utils/test_data_upload.py +++ b/tests/unit/utils/test_data_upload.py @@ -123,7 +123,8 @@ def test_upload_to_staging_hf_accepts_run_id_kwarg(monkeypatch, tmp_path): ) assert n == 1 - assert len(captured_ops) == 1 + assert len(captured_ops) == 2 + assert captured_ops[0].path_in_repo == ("staging/abc123/_run_context.json") def test_upload_to_staging_hf_run_id_scopes_staging_prefix(monkeypatch, tmp_path): @@ -133,6 +134,7 @@ def test_upload_to_staging_hf_run_id_scopes_staging_prefix(monkeypatch, tmp_path data_upload.upload_to_staging_hf(files, 
version="1.73.0", run_id="abc123") assert [op.path_in_repo for op in captured_ops] == [ + "staging/abc123/_run_context.json", "staging/abc123/states/AL.h5", "staging/abc123/states/CA.h5", ] @@ -149,6 +151,19 @@ def test_upload_to_staging_hf_without_run_id_uses_bare_staging_prefix( assert [op.path_in_repo for op in captured_ops] == ["staging/states/AL.h5"] +def test_upload_to_staging_hf_uses_run_id_env(monkeypatch, tmp_path): + monkeypatch.setenv("US_DATA_RUN_ID", "run-123") + data_upload, captured_ops = _install_fake_hf(monkeypatch, tmp_path) + files = _make_files(tmp_path, ["states/AL.h5"]) + + data_upload.upload_to_staging_hf(files, version="1.73.0") + + assert [op.path_in_repo for op in captured_ops] == [ + "staging/run-123/_run_context.json", + "staging/run-123/states/AL.h5", + ] + + def test_promote_staging_to_production_hf_uses_run_scoped_source_only(monkeypatch): data_upload = _load_data_upload_module() commit_operations = [] diff --git a/tests/unit/version_manifest/test_version_manifest.py b/tests/unit/version_manifest/test_version_manifest.py index 1c2eede0f..447048538 100644 --- a/tests/unit/version_manifest/test_version_manifest.py +++ b/tests/unit/version_manifest/test_version_manifest.py @@ -153,14 +153,12 @@ def test_regular_manifest_has_no_special_operation( assert result.special_operation is None assert result.roll_back_version is None - def test_pipeline_run_id_omitted_by_default(self, sample_manifest): + def test_run_id_omitted_by_default(self, sample_manifest): data = sample_manifest.to_dict() - assert "pipeline_run_id" not in data + assert "run_id" not in data assert "diagnostics_path" not in data - def test_pipeline_run_id_included_when_set( - self, sample_generations, sample_hf_info - ): + def test_run_id_included_when_set(self, sample_generations, sample_hf_info): manifest = VersionManifest( version="1.73.0", created_at="2026-03-10T15:00:00Z", @@ -169,14 +167,14 @@ def test_pipeline_run_id_included_when_set( bucket="policyengine-us-data", 
generations=sample_generations, ), - pipeline_run_id="1.73.0_abc12345_20260310", + run_id="usdata-gha123-a1-abc12345", diagnostics_path=("calibration/runs/1.73.0_abc12345_20260310/diagnostics/"), ) data = manifest.to_dict() - assert data["pipeline_run_id"] == ("1.73.0_abc12345_20260310") + assert data["run_id"] == "usdata-gha123-a1-abc12345" assert "diagnostics/" in data["diagnostics_path"] - def test_pipeline_run_id_roundtrip(self, sample_generations, sample_hf_info): + def test_run_id_roundtrip(self, sample_generations, sample_hf_info): manifest = VersionManifest( version="1.73.0", created_at="2026-03-10T15:00:00Z", @@ -185,11 +183,11 @@ def test_pipeline_run_id_roundtrip(self, sample_generations, sample_hf_info): bucket="policyengine-us-data", generations=sample_generations, ), - pipeline_run_id="1.73.0_abc12345_20260310", + run_id="usdata-gha123-a1-abc12345", diagnostics_path="calibration/runs/x/diag/", ) roundtripped = VersionManifest.from_dict(manifest.to_dict()) - assert roundtripped.pipeline_run_id == ("1.73.0_abc12345_20260310") + assert roundtripped.run_id == "usdata-gha123-a1-abc12345" assert roundtripped.diagnostics_path == ("calibration/runs/x/diag/")