diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index ca6995a727..353a8df779 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -185,6 +185,21 @@ with catalog.create_table_transaction(identifier="docs_example.bids", schema=sch txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c") ``` +## Replace a table + +Atomically replace an existing table's schema, partition spec, sort order, location, and properties via `replace_table_transaction`. The table UUID and history (snapshots, schemas, specs, sort orders, metadata log) are preserved; the current snapshot is cleared (the `main` branch ref is removed). Open the transaction with the new definition, stage any additional changes (writes, property updates, schema evolution), and commit — for example, an RTAS (replace-table-as-select) that swaps the schema and writes the new data atomically: + +```python +with catalog.replace_table_transaction(identifier="docs_example.bids", schema=df.schema) as txn: + txn.append(df) +``` + +Field IDs are reused by name from the previous schema; new columns get fresh IDs above `last-column-id`. + +Table properties are *merged* on replace: properties you don't pass are preserved on the table. To remove a property, drop it explicitly within the transaction. + +Pass `format-version` in `properties` to upgrade the table's format version as part of the replace. + ## Register a table To register a table using existing metadata: diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index 95ceaa539f..d8cabc8dcd 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -42,20 +42,25 @@ ) from pyiceberg.io import FileIO, load_file_io from pyiceberg.manifest import ManifestFile -from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec -from pyiceberg.schema import Schema +from pyiceberg.partitioning import ( + UNPARTITIONED_PARTITION_SPEC, + PartitionSpec, + assign_fresh_partition_spec_ids_for_replace, +) +from pyiceberg.schema import Schema, assign_fresh_schema_ids_for_replace from pyiceberg.serializers import ToOutputFile from pyiceberg.table import ( DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, CommitTableResponse, CreateTableTransaction, + ReplaceTableTransaction, StagedTable, Table, TableProperties, ) from pyiceberg.table.locations import load_location_provider from pyiceberg.table.metadata import TableMetadata, TableMetadataV1, new_table_metadata -from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder, assign_fresh_sort_order_ids from pyiceberg.table.update import ( TableRequirement, TableUpdate, @@ -444,6 +449,90 @@ def create_table_if_not_exists( except TableAlreadyExistsError: return self.load_table(identifier) + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + """Create a ReplaceTableTransaction. + + The transaction can be used to stage additional changes (schema evolution, + partition evolution, etc.) before committing. + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): New table schema. + location (str | None): New table location. Defaults to the existing location. + partition_spec (PartitionSpec): New partition spec. + sort_order (SortOrder): New sort order. + properties (Properties): Properties to apply. Merged on top of the existing + table properties: keys present here override existing values; existing keys + not present here are preserved. To remove a property, follow up with a + transaction that removes it explicitly. + + Returns: + ReplaceTableTransaction: A transaction for the replace operation. + + Raises: + NoSuchTableError: If the table does not exist. + """ + existing_table = self.load_table(identifier) + existing_metadata = existing_table.metadata + + raw_format_version = properties.get(TableProperties.FORMAT_VERSION) + if raw_format_version is not None: + try: + requested_format_version = int(raw_format_version) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid format-version property: {raw_format_version!r}") from exc + if requested_format_version < existing_metadata.format_version: + raise ValueError( + f"Cannot downgrade format-version from {existing_metadata.format_version} to {requested_format_version}" + ) + resolved_format_version = requested_format_version + else: + resolved_format_version = existing_metadata.format_version + iceberg_schema = self._convert_schema_if_needed(schema, cast(TableVersion, resolved_format_version)) + iceberg_schema.check_format_version_compatibility(cast(TableVersion, resolved_format_version)) + + fresh_schema, _ = assign_fresh_schema_ids_for_replace( + iceberg_schema, existing_metadata.schema(), existing_metadata.last_column_id + ) + fresh_partition_spec, _ = assign_fresh_partition_spec_ids_for_replace( + partition_spec, + iceberg_schema, + fresh_schema, + existing_metadata.partition_specs, + existing_metadata.last_partition_id, + format_version=existing_metadata.format_version, + current_spec=existing_metadata.spec(), + ) + fresh_sort_order = assign_fresh_sort_order_ids(sort_order, iceberg_schema, fresh_schema) + + resolved_location = location.rstrip("/") if location else existing_metadata.location + if not resolved_location: + raise ValueError("Resolved table location must not be empty") + + staged_table = StagedTable( + identifier=existing_table.name(), + metadata=existing_metadata, + metadata_location=existing_table.metadata_location, + io=existing_table.io, + catalog=self, + ) + return ReplaceTableTransaction( + table=staged_table, + new_schema=fresh_schema, + new_spec=fresh_partition_spec, + new_sort_order=fresh_sort_order, + new_location=resolved_location, + new_properties=properties, + ) + @abstractmethod def load_table(self, identifier: str | Identifier) -> Table: """Load the table's metadata and returns the table instance. diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py index aeb3c72843..06348903af 100644 --- a/pyiceberg/catalog/noop.py +++ b/pyiceberg/catalog/noop.py @@ -28,6 +28,7 @@ from pyiceberg.table import ( CommitTableResponse, CreateTableTransaction, + ReplaceTableTransaction, Table, ) from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder @@ -68,6 +69,18 @@ def create_table_transaction( ) -> CreateTableTransaction: raise NotImplementedError + @override + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + raise NotImplementedError + @override def load_table(self, identifier: str | Identifier) -> Table: raise NotImplementedError diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 39954ef561..74510fbdb1 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -68,13 +68,18 @@ FileIO, load_file_io, ) -from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec, assign_fresh_partition_spec_ids +from pyiceberg.partitioning import ( + UNPARTITIONED_PARTITION_SPEC, + PartitionSpec, + assign_fresh_partition_spec_ids, +) from pyiceberg.schema import Schema, assign_fresh_schema_ids from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, CreateTableTransaction, FileScanTask, + ReplaceTableTransaction, StagedTable, Table, TableIdentifier, @@ -957,6 +962,19 @@ def create_table_transaction( staged_table = self._response_to_staged_table(self.identifier_to_tuple(identifier), table_response) return CreateTableTransaction(staged_table) + @override + @retry(**_RETRY_ARGS) + def replace_table_transaction( + self, + identifier: str | Identifier, + schema: Schema | pa.Schema, + location: str | None = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> ReplaceTableTransaction: + return super().replace_table_transaction(identifier, schema, location, partition_spec, sort_order, properties) + @override @retry(**_RETRY_ARGS) def create_view( diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 3de185d886..b51f37443c 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -335,6 +335,175 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) +def assign_fresh_partition_spec_ids_for_replace( + spec: PartitionSpec, + old_schema: Schema, + fresh_schema: Schema, + existing_specs: list[PartitionSpec], + last_partition_id: int | None, + format_version: int = 2, + current_spec: PartitionSpec | None = None, +) -> tuple[PartitionSpec, int]: + """Assign partition field IDs for a replace operation, reusing IDs from existing specs. + + - For v2+, reuse partition field IDs by `(source_id, transform)` across all existing specs. + New fields get IDs starting from `last_partition_id + 1`. + - For v1, the current spec's fields must be preserved (v1 specs are append-only). Fields + absent from the new spec are carried forward with a `VoidTransform`. Matching new fields + reuse the existing partition field ID; remaining new fields are appended with fresh IDs. + + Args: + spec: The new partition spec to assign IDs to. Its `source_id`s reference `old_schema`. + old_schema: The schema that the new spec's `source_id`s reference. + fresh_schema: The schema with freshly assigned field IDs. + existing_specs: All partition specs from the existing table metadata. + last_partition_id: The current table's `last_partition_id`. + format_version: Table format version. Required to be set to 1 for v1 carry-forward. + current_spec: The current default partition spec. Required when `format_version <= 1`. + + Returns: + A tuple of `(fresh_spec, new_last_partition_id)`. + """ + effective_last_partition_id = last_partition_id if last_partition_id is not None else PARTITION_FIELD_ID_START - 1 + + if format_version <= 1: + if current_spec is None: + raise ValueError("current_spec is required for v1 replace_table") + return _assign_fresh_partition_spec_ids_for_replace_v1( + spec, old_schema, fresh_schema, current_spec, effective_last_partition_id + ) + + # v2+: reuse field IDs by (source_id, transform) across all specs. When the same + # (source_id, transform) appears in multiple specs, prefer the highest field_id. + transform_to_field_id: dict[tuple[int, str], int] = {} + for existing_spec in existing_specs: + for field in existing_spec.fields: + key = (field.source_id, str(field.transform)) + if key not in transform_to_field_id or field.field_id > transform_to_field_id[key]: + transform_to_field_id[key] = field.field_id + + next_id = effective_last_partition_id + partition_fields = [] + for field in spec.fields: + original_column_name = old_schema.find_column_name(field.source_id) + if original_column_name is None: + raise ValueError(f"Could not find in old schema: {field}") + fresh_field = fresh_schema.find_field(original_column_name) + if fresh_field is None: + raise ValueError(f"Could not find field in fresh schema: {original_column_name}") + + validate_partition_name(field.name, field.transform, fresh_field.field_id, fresh_schema, set()) + + key = (fresh_field.field_id, str(field.transform)) + if key in transform_to_field_id: + partition_field_id = transform_to_field_id[key] + else: + next_id += 1 + partition_field_id = next_id + transform_to_field_id[key] = partition_field_id + + partition_fields.append( + PartitionField( + name=field.name, + source_id=fresh_field.field_id, + field_id=partition_field_id, + transform=field.transform, + ) + ) + + # `next_id` starts at `effective_last_partition_id` and only increments, so it is the + # new last partition id. + return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID), next_id + + +def _assign_fresh_partition_spec_ids_for_replace_v1( + spec: PartitionSpec, + old_schema: Schema, + fresh_schema: Schema, + current_spec: PartitionSpec, + effective_last_partition_id: int, +) -> tuple[PartitionSpec, int]: + """v1 branch of `assign_fresh_partition_spec_ids_for_replace`. See parent docstring.""" + # Build (fresh_source_id, transform) → (new_field, fresh_source_id) for the new spec, + # in insertion order so leftover fields keep their declared order on append. + new_field_by_key: dict[tuple[int, str], tuple[PartitionField, int]] = {} + new_field_names: list[str] = [] + for new_field in spec.fields: + col_name = old_schema.find_column_name(new_field.source_id) + if col_name is None: + raise ValueError(f"Could not find in old schema: {new_field}") + fresh_field = fresh_schema.find_field(col_name) + if fresh_field is None: + raise ValueError(f"Could not find field in fresh schema: {col_name}") + validate_partition_name(new_field.name, new_field.transform, fresh_field.field_id, fresh_schema, set()) + key = (fresh_field.field_id, str(new_field.transform)) + new_field_by_key[key] = (new_field, fresh_field.field_id) + new_field_names.append(new_field.name) + + # Walk current spec, carrying forward each field. Matching new fields consume their key; + # missing fields become void transforms. + used_names: set[str] = set(new_field_names) + partition_fields = [] + for cur_field in current_spec.fields: + key = (cur_field.source_id, str(cur_field.transform)) + match = new_field_by_key.pop(key, None) + if match is not None: + new_field, fresh_source_id = match + partition_fields.append( + PartitionField( + name=new_field.name, + source_id=fresh_source_id, + field_id=cur_field.field_id, + transform=new_field.transform, + ) + ) + used_names.add(new_field.name) + else: + void_name = _unique_void_name(cur_field.name, cur_field.field_id, used_names) + used_names.add(void_name) + partition_fields.append( + PartitionField( + name=void_name, + source_id=cur_field.source_id, + field_id=cur_field.field_id, + transform=VoidTransform(), + ) + ) + + # Append remaining new fields at the end with fresh partition IDs. + next_id = effective_last_partition_id + for new_field, fresh_source_id in new_field_by_key.values(): + next_id += 1 + partition_fields.append( + PartitionField( + name=new_field.name, + source_id=fresh_source_id, + field_id=next_id, + transform=new_field.transform, + ) + ) + + # `next_id` starts at `effective_last_partition_id` and only increments, so it is the + # new last partition id. + return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID), next_id + + +def _unique_void_name(base_name: str, field_id: int, used_names: set[str]) -> str: + """Pick a void-transform name that does not collide with already-used names. + + First tries `base_name`; if taken, tries `base_name_{field_id}`; if still taken, + appends `_2`, `_3`, ... until unique. + """ + if base_name not in used_names: + return base_name + candidate = f"{base_name}_{field_id}" + suffix = 2 + while candidate in used_names: + candidate = f"{base_name}_{field_id}_{suffix}" + suffix += 1 + return candidate + + T = TypeVar("T") diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index fd60eb8f94..7ae198c74d 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1380,6 +1380,62 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType: return primitive +class _SetFreshIDsForReplace(_SetFreshIDs): + """Assign fresh IDs for a replace operation, reusing IDs from the base schema by field name. + + For each field in the new schema, if a field with the same full name exists in the + base schema, its ID is reused; otherwise a fresh ID is allocated starting from + last_column_id + 1. + + Note: ID reuse is purely name-based — a field whose name matches but whose type differs + (e.g. `int` → `string`) will reuse the base ID. This is intentional: replace allows + arbitrary schema changes; type compatibility is the caller's responsibility. + """ + + def __init__(self, old_id_to_base_id: dict[int, int], starting_id: int) -> None: + self.old_id_to_new_id: dict[int, int] = {} + self._old_id_to_base_id = old_id_to_base_id + counter = itertools.count(starting_id + 1) + self.next_id_func = lambda: next(counter) + + def _get_and_increment(self, current_id: int) -> int: + if current_id in self._old_id_to_base_id: + new_id = self._old_id_to_base_id[current_id] + else: + new_id = self.next_id_func() + self.old_id_to_new_id[current_id] = new_id + return new_id + + +def assign_fresh_schema_ids_for_replace(schema: Schema, base_schema: Schema, last_column_id: int) -> tuple[Schema, int]: + """Assign fresh IDs to a schema for a replace operation, reusing IDs from the base schema. + + For each field in the new schema, if a field with the same full path name exists + in the base schema, its ID is reused. New fields get IDs starting from + last_column_id + 1. + + Args: + schema: The new schema to assign IDs to. + base_schema: The existing table's current schema (IDs are reused from here by name). + last_column_id: The current table's last_column_id (new IDs start above this). + + Returns: + A tuple of (fresh_schema, new_last_column_id). + """ + base_name_to_id = index_by_name(base_schema) + new_id_to_name = index_name_by_id(schema) + + old_id_to_base_id: dict[int, int] = {} + for old_id, name in new_id_to_name.items(): + if name in base_name_to_id: + old_id_to_base_id[old_id] = base_name_to_id[name] + + visitor = _SetFreshIDsForReplace(old_id_to_base_id, last_column_id) + fresh_schema = pre_order_visit(schema, visitor) + new_last_column_id = max(fresh_schema.highest_field_id, last_column_id) + return fresh_schema, new_last_column_id + + # Implementation copied from Apache Iceberg repo. def make_compatible_name(name: str) -> str: """Make a field name compatible with Avro specification. diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 64ad10050d..8b077403ee 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -58,10 +58,13 @@ AddSchemaUpdate, AddSortOrderUpdate, AssertCreate, + AssertLastAssignedFieldId, + AssertLastAssignedPartitionId, AssertRefSnapshotId, AssertTableUUID, AssignUUIDUpdate, RemovePropertiesUpdate, + RemoveSnapshotRefUpdate, SetCurrentSchemaUpdate, SetDefaultSortOrderUpdate, SetDefaultSpecUpdate, @@ -1106,6 +1109,150 @@ def commit_transaction(self) -> Table: return self._table +class ReplaceTableTransaction(Transaction): + """A transaction that replaces an existing table's schema, spec, sort order, location, and properties. + + The existing table UUID, snapshots, snapshot log, metadata log, and history are preserved. + The "main" branch ref is removed (current-snapshot-id set to -1), and new + schema/spec/sort-order/location/properties are applied. + """ + + def __init__( + self, + table: StagedTable, + new_schema: Schema, + new_spec: PartitionSpec, + new_sort_order: SortOrder, + new_location: str, + new_properties: Properties, + ) -> None: + super().__init__(table, autocommit=False) + self._initial_changes(table.metadata, new_schema, new_spec, new_sort_order, new_location, new_properties) + + def _initial_changes( + self, + table_metadata: TableMetadata, + new_schema: Schema, + new_spec: PartitionSpec, + new_sort_order: SortOrder, + new_location: str, + new_properties: Properties, + ) -> None: + """Set the initial changes that transform the existing table into the replacement. + + Always emits `SetCurrentSchema` / `SetDefaultPartitionSpec` / `SetDefaultSortOrder` + (even when the resulting id is reused) so the request body unambiguously signals a + replace. Bumps `format-version` when the new properties request it. + """ + # Upgrade format-version if requested via properties. + requested_format_version_str = new_properties.get(TableProperties.FORMAT_VERSION) + if requested_format_version_str is not None: + requested_format_version = int(requested_format_version_str) + if requested_format_version > table_metadata.format_version: + self._updates += (UpgradeFormatVersionUpdate(format_version=requested_format_version),) + + # Remove the main branch ref to clear the current snapshot. + self._updates += (RemoveSnapshotRefUpdate(ref_name=MAIN_BRANCH),) + + # Schema: reuse an existing schema_id if structurally identical, else add a new one + # with a fresh schema_id (max + 1, matching UpdateSchema's convention). + existing_schema_id = self._find_matching_schema_id(table_metadata, new_schema) + if existing_schema_id is not None: + self._updates += (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) + else: + next_schema_id = max((s.schema_id for s in table_metadata.schemas), default=-1) + 1 + schema_with_fresh_id = new_schema.model_copy(update={"schema_id": next_schema_id}) + self._updates += ( + AddSchemaUpdate(schema_=schema_with_fresh_id), + SetCurrentSchemaUpdate(schema_id=-1), + ) + + # Partition spec: same reuse-or-add pattern. Assign a fresh spec_id on add to avoid + # collisions with existing specs (AddPartitionSpecUpdate refuses duplicate IDs). + effective_spec = UNPARTITIONED_PARTITION_SPEC if new_spec.is_unpartitioned() else new_spec + existing_spec_id = self._find_matching_spec_id(table_metadata, effective_spec) + if existing_spec_id is not None: + self._updates += (SetDefaultSpecUpdate(spec_id=existing_spec_id),) + else: + next_spec_id = max((s.spec_id for s in table_metadata.partition_specs), default=-1) + 1 + spec_with_fresh_id = PartitionSpec(*effective_spec.fields, spec_id=next_spec_id) + self._updates += ( + AddPartitionSpecUpdate(spec=spec_with_fresh_id), + SetDefaultSpecUpdate(spec_id=-1), + ) + + # Sort order: same reuse-or-add pattern with fresh order_id on add. + effective_sort_order = UNSORTED_SORT_ORDER if new_sort_order.is_unsorted else new_sort_order + existing_order_id = self._find_matching_sort_order_id(table_metadata, effective_sort_order) + if existing_order_id is not None: + self._updates += (SetDefaultSortOrderUpdate(sort_order_id=existing_order_id),) + else: + next_order_id = max((o.order_id for o in table_metadata.sort_orders), default=-1) + 1 + sort_order_with_fresh_id = SortOrder(*effective_sort_order.fields, order_id=next_order_id) + self._updates += ( + AddSortOrderUpdate(sort_order=sort_order_with_fresh_id), + SetDefaultSortOrderUpdate(sort_order_id=-1), + ) + + # Set location if changed. + if new_location != table_metadata.location: + self._updates += (SetLocationUpdate(location=new_location),) + + # Merge properties (SetPropertiesUpdate merges onto existing properties). + # Strip `format-version` so it does not get persisted as a regular property. + persisted_properties = {k: v for k, v in new_properties.items() if k != TableProperties.FORMAT_VERSION} + if persisted_properties: + self._updates += (SetPropertiesUpdate(updates=persisted_properties),) + + @staticmethod + def _find_matching_schema_id(table_metadata: TableMetadata, schema: Schema) -> int | None: + """Find an existing schema structurally equal to the given one, returning its schema_id or None.""" + for existing in table_metadata.schemas: + if existing == schema: + return existing.schema_id + return None + + @staticmethod + def _find_matching_spec_id(table_metadata: TableMetadata, spec: PartitionSpec) -> int | None: + """Find an existing partition spec with the same fields, returning its spec_id or None.""" + for existing in table_metadata.partition_specs: + if existing.fields == spec.fields: + return existing.spec_id + return None + + @staticmethod + def _find_matching_sort_order_id(table_metadata: TableMetadata, sort_order: SortOrder) -> int | None: + """Find an existing sort order with the same fields, returning its order_id or None.""" + for existing in table_metadata.sort_orders: + if existing.fields == sort_order.fields: + return existing.order_id + return None + + def commit_transaction(self) -> Table: + """Commit the changes to the catalog. + + Returns: + The table with the updates applied. + """ + if len(self._updates) > 0: + base = self._table.metadata + requirements: tuple[TableRequirement, ...] = ( + AssertTableUUID(uuid=base.table_uuid), + AssertLastAssignedFieldId(last_assigned_field_id=base.last_column_id), + ) + if base.last_partition_id is not None: + requirements += (AssertLastAssignedPartitionId(last_assigned_partition_id=base.last_partition_id),) + self._table._do_commit( # pylint: disable=W0212 + updates=self._updates, + requirements=requirements, + ) + + self._updates = () + self._requirements = () + + return self._table + + class Namespace(IcebergRootModel[list[str]]): """Reference to one or more levels of a namespace.""" diff --git a/tests/catalog/test_catalog_behaviors.py b/tests/catalog/test_catalog_behaviors.py index b859e2d541..bcc5d22e80 100644 --- a/tests/catalog/test_catalog_behaviors.py +++ b/tests/catalog/test_catalog_behaviors.py @@ -21,6 +21,7 @@ import os from collections.abc import Generator +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -42,11 +43,11 @@ from pyiceberg.io.pyarrow import _dataframe_to_data_files, schema_to_pyarrow from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.table import TableProperties +from pyiceberg.table import Table, TableProperties from pyiceberg.table.snapshots import Operation from pyiceberg.table.sorting import NullOrder, SortDirection, SortField, SortOrder from pyiceberg.table.update import AddSchemaUpdate, SetCurrentSchemaUpdate -from pyiceberg.transforms import IdentityTransform +from pyiceberg.transforms import IdentityTransform, VoidTransform from pyiceberg.typedef import Identifier from pyiceberg.types import BooleanType, IntegerType, LongType, NestedField, StringType @@ -388,6 +389,380 @@ def test_load_table_from_self_identifier( assert table.metadata == loaded_table.metadata +_SIMPLE_SCHEMA = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), +) + + +def _create_simple_table( + catalog: Catalog, + identifier: Identifier, + *, + schema: Schema = _SIMPLE_SCHEMA, + format_version: int = 2, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + properties: dict[str, str] | None = None, +) -> tuple[Identifier, Schema]: + namespace = Catalog.namespace_from(identifier) + catalog.create_namespace_if_not_exists(namespace) + merged_properties = {"format-version": str(format_version), **(properties or {})} + catalog.create_table(identifier, schema=schema, partition_spec=partition_spec, properties=merged_properties) + return identifier, schema + + +def _simple_data(num_rows: int = 2) -> pa.Table: + return pa.Table.from_pydict( + {"id": list(range(num_rows)), "data": [chr(ord("a") + i) for i in range(num_rows)]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + + +_REPLACE_SCHEMA = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), +) + + +def test_replace_transaction(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, original_schema = _create_simple_table(catalog, test_table_identifier) + original = catalog.load_table(test_table_identifier) + original.append(_simple_data()) + original = catalog.load_table(test_table_identifier) + old_snapshot_id = original.current_snapshot().snapshot_id # type: ignore[union-attr] + snapshot_log_before = list(original.metadata.snapshot_log) + assert len(snapshot_log_before) == 1 + + catalog.replace_table_transaction(test_table_identifier, schema=_REPLACE_SCHEMA).commit_transaction() + replaced = catalog.load_table(test_table_identifier) + + # UUID + history preserved, current snapshot cleared, current schema swapped. + assert replaced.metadata.table_uuid == original.metadata.table_uuid + assert replaced.metadata.current_snapshot_id is None + assert {f.name for f in replaced.schema().fields} == {"id", "data", "extra"} + # Old snapshot kept by identity (not just count), and snapshot_log entries from before survive + # in order at the front of the log. + assert any(s.snapshot_id == old_snapshot_id for s in replaced.metadata.snapshots) + assert replaced.metadata.snapshot_log[: len(snapshot_log_before)] == snapshot_log_before + # Old schema is still in the schemas list alongside the new one. + schema_ids = sorted(s.schema_id for s in replaced.metadata.schemas) + assert schema_ids == [0, 1] + assert replaced.metadata.current_schema_id == 1 + # Time-travel back to the pre-replace snapshot returns the rows that were there before. + assert replaced.scan(snapshot_id=old_snapshot_id).to_arrow().equals(_simple_data()) + + +@dataclass +class _ReplaceFixture: + """State produced by `_run_complete_replace`: the table before/after the replace plus + the inputs needed to assert on the result.""" + + original: Table + replaced: Table + new_sort: SortOrder + original_data: pa.Table + old_snapshot_id: int + + +def _run_complete_replace(catalog: Catalog, identifier: Identifier, tmp_path: Path) -> _ReplaceFixture: + """Set up a table, run a full-six-args RTAS replace, and return the handles needed for assertions.""" + _create_simple_table(catalog, identifier, properties={"keep": "yes", "override": "old"}) + catalog.load_table(identifier).append(_simple_data()) + original = catalog.load_table(identifier) + old_snapshot_id = original.current_snapshot().snapshot_id # type: ignore[union-attr] + original_data = original.scan().to_arrow() + + new_location = f"file://{tmp_path}/replaced" + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + ) + new_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + new_sort = SortOrder(SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC)) + new_data = pa.Table.from_pydict( + {"id": [10, 20], "data": ["alice", "bob"], "extra": [True, False]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string()), pa.field("extra", pa.bool_())]), + ) + + with catalog.replace_table_transaction( + identifier, + schema=new_schema, + partition_spec=new_spec, + sort_order=new_sort, + location=new_location, + properties={"override": "new", "added": "v"}, + ) as txn: + txn.append(new_data) + + return _ReplaceFixture( + original=original, + replaced=catalog.load_table(identifier), + new_sort=new_sort, + original_data=original_data, + old_snapshot_id=old_snapshot_id, + ) + + +def test_complete_replace_transaction_applies_new_schema_spec_and_sort( + catalog: Catalog, test_table_identifier: Identifier, tmp_path: Path +) -> None: + fx = _run_complete_replace(catalog, test_table_identifier, tmp_path) + # Identity invariants. + assert fx.replaced.metadata.table_uuid == fx.original.metadata.table_uuid + assert fx.replaced.metadata.location == f"file://{tmp_path}/replaced" + # New schema / spec / sort applied; old entries retained in history. + assert {f.name for f in fx.replaced.schema().fields} == {"id", "data", "extra"} + assert sorted(s.schema_id for s in fx.replaced.metadata.schemas) == [0, 1] + assert fx.replaced.spec().fields[0].source_id == 1 + assert isinstance(fx.replaced.spec().fields[0].transform, IdentityTransform) + assert {s.spec_id for s in fx.replaced.metadata.partition_specs} == {0, 1} + assert fx.replaced.sort_order().fields == fx.new_sort.fields + assert {s.order_id for s in fx.replaced.metadata.sort_orders} == {0, fx.replaced.metadata.default_sort_order_id} + + +def test_complete_replace_transaction_merges_properties( + catalog: Catalog, test_table_identifier: Identifier, tmp_path: Path +) -> None: + fx = _run_complete_replace(catalog, test_table_identifier, tmp_path) + # `keep` is preserved, `override` is updated, `added` is new, and `format-version` does not leak. + assert fx.replaced.properties["keep"] == "yes" + assert fx.replaced.properties["override"] == "new" + assert fx.replaced.properties["added"] == "v" + assert "format-version" not in fx.replaced.properties + + +def test_complete_replace_transaction_rtas_preserves_old_snapshot( + catalog: Catalog, test_table_identifier: Identifier, tmp_path: Path +) -> None: + fx = _run_complete_replace(catalog, test_table_identifier, tmp_path) + # New snapshot exists, has no parent (fresh start), old snapshot is still in the snapshot list. + new_snapshot = fx.replaced.current_snapshot() + assert new_snapshot is not None + assert new_snapshot.snapshot_id != fx.old_snapshot_id + assert new_snapshot.parent_snapshot_id is None + assert any(s.snapshot_id == fx.old_snapshot_id for s in fx.replaced.metadata.snapshots) + assert fx.replaced.scan().to_arrow().num_rows == 2 + # Time-travel back to before the replace returns the original rows from the old schema. + time_travel = fx.replaced.scan(snapshot_id=fx.old_snapshot_id).to_arrow() + assert time_travel.num_rows == fx.original_data.num_rows + assert time_travel.column("id").to_pylist() == fx.original_data.column("id").to_pylist() + + +def test_replace_transaction_requires_table_exists(catalog: Catalog, test_table_identifier: Identifier) -> None: + schema = Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=False)) + with pytest.raises(NoSuchTableError): + catalog.replace_table_transaction(test_table_identifier, schema=schema) + + +def test_replace_table_reuses_schema_id_when_identical(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, base_schema = _create_simple_table(catalog, test_table_identifier) + catalog.replace_table_transaction(test_table_identifier, schema=base_schema).commit_transaction() + replaced = catalog.load_table(test_table_identifier) + # Identical shape -> no new schema appended, current points back at id 0. + assert [s.schema_id for s in replaced.metadata.schemas] == [0] + assert replaced.metadata.current_schema_id == 0 + assert replaced.metadata.last_column_id == 2 + + +def test_replace_table_reuses_partition_spec_and_sort_order_when_identical( + catalog: Catalog, test_table_identifier: Identifier +) -> None: + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + sort = SortOrder(SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC)) + _, schema = _create_simple_table(catalog, test_table_identifier, partition_spec=spec) + # Introduce a sort order then replay both spec and sort — neither should append a new entry. + catalog.replace_table_transaction( + test_table_identifier, schema=schema, partition_spec=spec, sort_order=sort + ).commit_transaction() + sorted_first = catalog.load_table(test_table_identifier) + sorted_order_id = sorted_first.metadata.default_sort_order_id + assert sorted_order_id != 0 + + catalog.replace_table_transaction( + test_table_identifier, schema=schema, partition_spec=spec, sort_order=sort + ).commit_transaction() + replayed = catalog.load_table(test_table_identifier) + assert [s.spec_id for s in replayed.metadata.partition_specs] == [0] + assert replayed.metadata.default_spec_id == 0 + assert replayed.metadata.default_sort_order_id == sorted_order_id + + # Dropping the sort order falls back to the unsorted order_id 0 (also reused, not appended). + catalog.replace_table_transaction(test_table_identifier, schema=schema, partition_spec=spec).commit_transaction() + unsorted = catalog.load_table(test_table_identifier) + assert unsorted.sort_order().is_unsorted + assert unsorted.metadata.default_sort_order_id == 0 + + +@pytest.mark.parametrize("keep_identifier", [True, False], ids=["preserves", "drops"]) +def test_replace_table_identifier_field_ids(catalog: Catalog, test_table_identifier: Identifier, keep_identifier: bool) -> None: + schema = Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + identifier_field_ids=[1], + ) + _create_simple_table(catalog, test_table_identifier, schema=schema) + new_schema = ( + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False), + identifier_field_ids=[1], + ) + if keep_identifier + else Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + ) + catalog.replace_table_transaction(test_table_identifier, schema=new_schema).commit_transaction() + replaced = catalog.load_table(test_table_identifier) + expected = [1] if keep_identifier else [] + assert list(replaced.schema().identifier_field_ids) == expected + + +@pytest.mark.parametrize( + "format_version, expect_void_carry_forward", + [(1, True), (2, False)], + ids=["v1-carries-forward", "v2-drops"], +) +def test_replace_table_partition_field_carry_forward( + catalog: Catalog, + test_table_identifier: Identifier, + format_version: int, + expect_void_carry_forward: bool, +) -> None: + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + _, schema = _create_simple_table(catalog, test_table_identifier, partition_spec=spec, format_version=format_version) + catalog.replace_table_transaction(test_table_identifier, schema=schema).commit_transaction() + replaced = catalog.load_table(test_table_identifier) + new_spec = replaced.spec() + if expect_void_carry_forward: + void_field = next(f for f in new_spec.fields if f.field_id == 1000) + assert isinstance(void_field.transform, VoidTransform) + assert void_field.source_id == 1 + assert void_field.name == "id_part" + else: + assert new_spec.is_unpartitioned() + + +def test_replace_table_upgrades_format_version(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier, format_version=1) + assert catalog.load_table(test_table_identifier).format_version == 1 + + catalog.replace_table_transaction( + test_table_identifier, schema=schema, properties={"format-version": "2"} + ).commit_transaction() + upgraded = catalog.load_table(test_table_identifier) + assert upgraded.format_version == 2 + # `format-version` is a control input, not a persisted property. + assert "format-version" not in upgraded.properties + + +def test_replace_table_keeps_upgraded_format_version_on_subsequent_replace( + catalog: Catalog, test_table_identifier: Identifier +) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier, format_version=1) + catalog.replace_table_transaction( + test_table_identifier, schema=schema, properties={"format-version": "2"} + ).commit_transaction() + new_schema = Schema(*schema.fields, NestedField(field_id=3, name="extra", field_type=BooleanType(), required=False)) + catalog.replace_table_transaction(test_table_identifier, schema=new_schema).commit_transaction() + replayed = catalog.load_table(test_table_identifier) + assert replayed.format_version == 2 + assert {f.name for f in replayed.schema().fields} == {"id", "data", "extra"} + + +@pytest.mark.parametrize( + "properties, location, expected_match", + [ + pytest.param({"format-version": "1"}, None, "Cannot downgrade format-version", id="format-version-downgrade"), + pytest.param({"format-version": "two"}, None, "Invalid format-version property", id="non-numeric-format-version"), + pytest.param({}, "/", "location must not be empty", id="empty-location-after-rstrip"), + ], +) +def test_replace_table_rejects_invalid_inputs( + catalog: Catalog, + test_table_identifier: Identifier, + properties: dict[str, str], + location: str | None, + expected_match: str, +) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier, format_version=2) + with pytest.raises(ValueError, match=expected_match): + catalog.replace_table_transaction(test_table_identifier, schema=schema, properties=properties, location=location) + + +def test_replace_table_inherits_existing_location(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + existing = catalog.load_table(test_table_identifier).metadata.location + catalog.replace_table_transaction(test_table_identifier, schema=schema).commit_transaction() + assert catalog.load_table(test_table_identifier).metadata.location == existing + + +def test_replace_table_uses_explicit_location(catalog: Catalog, test_table_identifier: Identifier, tmp_path: Path) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + new_location = f"file://{tmp_path}/relocated" + catalog.replace_table_transaction(test_table_identifier, schema=schema, location=new_location).commit_transaction() + assert catalog.load_table(test_table_identifier).metadata.location == new_location + + +def test_replace_table_strips_trailing_slash_from_location( + catalog: Catalog, test_table_identifier: Identifier, tmp_path: Path +) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + bare = f"file://{tmp_path}/relocated" + catalog.replace_table_transaction(test_table_identifier, schema=schema, location=bare + "/").commit_transaction() + assert catalog.load_table(test_table_identifier).metadata.location == bare + + +def test_replace_table_transaction_rolls_back_on_failure(catalog: Catalog, test_table_identifier: Identifier) -> None: + _create_simple_table(catalog, test_table_identifier) + catalog.load_table(test_table_identifier).append(_simple_data()) + before = catalog.load_table(test_table_identifier).metadata + + def run_failing_replace() -> None: + with catalog.replace_table_transaction(test_table_identifier, schema=_REPLACE_SCHEMA): + raise RuntimeError("simulated failure inside replace transaction") + + with pytest.raises(RuntimeError, match="simulated failure inside replace transaction"): + run_failing_replace() + + after = catalog.load_table(test_table_identifier).metadata + assert after.table_uuid == before.table_uuid + assert after.current_snapshot_id == before.current_snapshot_id + assert after.current_schema_id == before.current_schema_id + assert len(after.schemas) == len(before.schemas) + + +def test_concurrent_replace_transaction_schema_conflict(catalog: Catalog, test_table_identifier: Identifier) -> None: + _create_simple_table(catalog, test_table_identifier) + txn_a = catalog.replace_table_transaction(test_table_identifier, schema=_REPLACE_SCHEMA) + txn_b = catalog.replace_table_transaction(test_table_identifier, schema=_REPLACE_SCHEMA) + + txn_a.commit_transaction() + after_a = catalog.load_table(test_table_identifier).metadata + with pytest.raises(CommitFailedException, match="last assigned field id"): + txn_b.commit_transaction() + # The failed commit must be a true no-op: no metadata advanced past where `txn_a` left things. + assert catalog.load_table(test_table_identifier).metadata.last_column_id == after_a.last_column_id + + +def test_concurrent_replace_transaction_partition_spec_conflict(catalog: Catalog, test_table_identifier: Identifier) -> None: + _, schema = _create_simple_table(catalog, test_table_identifier) + new_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, name="id_part", transform=IdentityTransform())) + txn_a = catalog.replace_table_transaction(test_table_identifier, schema=schema, partition_spec=new_spec) + txn_b = catalog.replace_table_transaction(test_table_identifier, schema=schema, partition_spec=new_spec) + + txn_a.commit_transaction() + after_a = catalog.load_table(test_table_identifier).metadata + with pytest.raises(CommitFailedException, match="last assigned partition id"): + txn_b.commit_transaction() + # The failed commit must be a true no-op: no metadata advanced past where `txn_a` left things. + assert catalog.load_table(test_table_identifier).metadata.last_partition_id == after_a.last_partition_id + + # Rename table tests diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index df2f96a392..bc50556f1d 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -65,7 +65,7 @@ from pyiceberg.table.sorting import SortField, SortOrder from pyiceberg.transforms import IdentityTransform, TruncateTransform from pyiceberg.typedef import RecursiveDict -from pyiceberg.types import StringType +from pyiceberg.types import BooleanType, IntegerType, NestedField, StringType from pyiceberg.utils.config import Config from pyiceberg.view import View from pyiceberg.view.metadata import ViewMetadata, ViewVersion @@ -2899,3 +2899,91 @@ def test_load_table_without_storage_credentials( ) assert actual.metadata.model_dump() == expected.metadata.model_dump() assert actual == expected + + +def _mock_replace_endpoints( + rest_mock: Mocker, + namespace: str, + table: str, + load_response: dict[str, Any], + commit_response: dict[str, Any], +) -> None: + rest_mock.get( + f"{TEST_URI}v1/namespaces/{namespace}/tables/{table}", + json=load_response, + status_code=200, + request_headers=TEST_HEADERS, + ) + rest_mock.post( + f"{TEST_URI}v1/namespaces/{namespace}/tables/{table}", + json=commit_response, + status_code=200, + request_headers=TEST_HEADERS, + ) + + +def test_replace_table_transaction_wire_payload( + rest_mock: Mocker, + example_table_metadata_with_snapshot_v1_rest_json: dict[str, Any], + example_table_metadata_no_snapshot_v1_rest_json: dict[str, Any], +) -> None: + table_uuid = example_table_metadata_with_snapshot_v1_rest_json["metadata"]["table-uuid"] + example_table_metadata_no_snapshot_v1_rest_json["metadata"]["table-uuid"] = table_uuid + _mock_replace_endpoints( + rest_mock, + "fokko", + "fokko2", + example_table_metadata_with_snapshot_v1_rest_json, + example_table_metadata_no_snapshot_v1_rest_json, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + + new_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField(field_id=3, name="new_col", field_type=BooleanType(), required=False), + ) + catalog.replace_table_transaction(identifier=("fokko", "fokko2"), schema=new_schema).commit_transaction() + request = rest_mock.last_request.json() + + fixture_metadata = example_table_metadata_with_snapshot_v1_rest_json["metadata"] + assert request["requirements"] == [ + {"type": "assert-table-uuid", "uuid": table_uuid}, + {"type": "assert-last-assigned-field-id", "last-assigned-field-id": fixture_metadata["last-column-id"]}, + {"type": "assert-last-assigned-partition-id", "last-assigned-partition-id": fixture_metadata["last-partition-id"]}, + ] + + actions = [u["action"] for u in request["updates"]] + assert sorted(actions) == [ + "add-schema", + "remove-snapshot-ref", + "set-current-schema", + "set-default-sort-order", + "set-default-spec", + ] + updates_by_action = {u["action"]: u for u in request["updates"]} + + assert updates_by_action["remove-snapshot-ref"] == {"action": "remove-snapshot-ref", "ref-name": "main"} + added_schema = updates_by_action["add-schema"]["schema"] + assert {f["name"]: f["id"] for f in added_schema["fields"]} == {"id": 1, "data": 2, "new_col": 3} + # schema-id=-1 is the wire sentinel meaning "the schema we just added in this commit". + assert updates_by_action["set-current-schema"]["schema-id"] == -1 + assert updates_by_action["set-default-spec"]["spec-id"] == fixture_metadata["default-spec-id"] + assert updates_by_action["set-default-sort-order"]["sort-order-id"] == fixture_metadata["default-sort-order-id"] + + +def test_replace_table_transaction_404_raises( + rest_mock: Mocker, +) -> None: + rest_mock.get( + f"{TEST_URI}v1/namespaces/fokko/tables/missing", + json={"error": {"message": "Table not found", "type": "NoSuchTableException", "code": 404}}, + status_code=404, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + with pytest.raises(NoSuchTableError): + catalog.replace_table_transaction( + identifier=("fokko", "missing"), + schema=Schema(NestedField(field_id=1, name="id", field_type=IntegerType(), required=False)), + ) diff --git a/tests/integration/test_catalog.py b/tests/integration/test_catalog.py index 751dbe0479..668ffbb42c 100644 --- a/tests/integration/test_catalog.py +++ b/tests/integration/test_catalog.py @@ -20,6 +20,7 @@ from collections.abc import Generator from pathlib import Path, PosixPath +import pyarrow as pa import pytest from pytest_lazy_fixtures import lf @@ -85,7 +86,15 @@ def sqlite_catalog_file(warehouse: Path) -> Generator[Catalog, None, None]: @pytest.fixture(scope="function") def rest_catalog() -> Generator[Catalog, None, None]: - test_catalog = RestCatalog("rest", uri="http://localhost:8181") + test_catalog = RestCatalog( + "rest", + **{ + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ) yield test_catalog @@ -866,3 +875,36 @@ def test_load_missing_table(test_catalog: Catalog, database_name: str, table_nam with pytest.raises(NoSuchTableError): test_catalog.load_table(identifier) + + +@pytest.mark.integration +@pytest.mark.parametrize("test_catalog", CATALOGS) +def test_replace_table_transaction(test_catalog: Catalog, database_name: str, table_name: str) -> None: + test_catalog.create_namespace(database_name) + identifier = (database_name, table_name) + + old_data = pa.Table.from_pydict( + {"id": [1], "data": ["old"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("data", pa.large_string())]), + ) + original = test_catalog.create_table(identifier, schema=old_data.schema) + original.append(old_data) + old_snapshot_id = test_catalog.load_table(identifier).current_snapshot().snapshot_id # type: ignore[union-attr] + + new_data = pa.Table.from_pydict( + {"id": [10, 20], "name": ["alice", "bob"]}, + schema=pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.large_string())]), + ) + with test_catalog.replace_table_transaction(identifier, schema=new_data.schema) as txn: + txn.append(new_data) + + replaced = test_catalog.load_table(identifier) + assert replaced.metadata.table_uuid == original.metadata.table_uuid + assert replaced.current_snapshot() is not None + assert replaced.current_snapshot().snapshot_id != old_snapshot_id # type: ignore[union-attr] + assert any(s.snapshot_id == old_snapshot_id for s in replaced.metadata.snapshots) + assert replaced.scan().to_arrow().num_rows == 2 + # Time-travel back to the pre-replace snapshot returns the original row. + old_via_time_travel = replaced.scan(snapshot_id=old_snapshot_id).to_arrow() + assert old_via_time_travel.num_rows == 1 + assert old_via_time_travel.column("id").to_pylist() == [1] diff --git a/tests/table/test_partitioning.py b/tests/table/test_partitioning.py index a27046ef30..aa39153200 100644 --- a/tests/table/test_partitioning.py +++ b/tests/table/test_partitioning.py @@ -22,7 +22,12 @@ import pytest from pyiceberg.exceptions import ValidationError -from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec +from pyiceberg.partitioning import ( + UNPARTITIONED_PARTITION_SPEC, + PartitionField, + PartitionSpec, + assign_fresh_partition_spec_ids_for_replace, +) from pyiceberg.schema import Schema from pyiceberg.transforms import ( BucketTransform, @@ -31,6 +36,7 @@ IdentityTransform, MonthTransform, TruncateTransform, + VoidTransform, YearTransform, ) from pyiceberg.typedef import Record @@ -298,3 +304,178 @@ def test_incompatible_transform_source_type() -> None: spec.check_compatible(schema) assert "Invalid source field foo with type int for transform: year" in str(exc.value) + + +_REPLACE_SCHEMA_FOR_PARTITION = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), +) + + +@pytest.mark.parametrize( + "new_spec, existing_specs, last_partition_id, expected_field_id, expected_last_partition_id", + [ + # Reuse-by-identity: same (source_id, IdentityTransform) already in an existing spec. + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id")), + [PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=0)], + 1000, + 1000, + 1000, + id="reuse-identity", + ), + # Reuse-by-(source,bucket): same source_id + same BucketTransform, even under a renamed field. + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=999, transform=BucketTransform(8), name="id_bucket_renamed")), + [ + PartitionSpec( + PartitionField(source_id=1, field_id=1042, transform=BucketTransform(8), name="id_bucket"), spec_id=0 + ) + ], + 1042, + 1042, + 1042, + id="reuse-bucket-under-rename", + ), + # No match: fresh id above last_partition_id. + pytest.param( + PartitionSpec(PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id")), + [PartitionSpec(spec_id=0)], + 999, + 1000, + 1000, + id="new-field-above-last-partition-id", + ), + ], +) +def test_assign_fresh_partition_spec_ids_for_replace_v2( + new_spec: PartitionSpec, + existing_specs: list[PartitionSpec], + last_partition_id: int, + expected_field_id: int, + expected_last_partition_id: int, +) -> None: + fresh_spec, new_last_pid = assign_fresh_partition_spec_ids_for_replace( + new_spec, _REPLACE_SCHEMA_FOR_PARTITION, _REPLACE_SCHEMA_FOR_PARTITION, existing_specs, last_partition_id + ) + assert fresh_spec.fields[0].field_id == expected_field_id + assert new_last_pid == expected_last_partition_id + + +def test_assign_fresh_partition_spec_ids_for_replace_v1_carries_forward_as_void() -> None: + """v1 specs are append-only: a field absent from the new spec is carried forward as void.""" + current_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=0) + # New spec drops "id" entirely, partitioned by "data" instead. + new_spec = PartitionSpec(PartitionField(source_id=2, field_id=999, transform=IdentityTransform(), name="data")) + fresh_spec, new_last_pid = assign_fresh_partition_spec_ids_for_replace( + new_spec, + _REPLACE_SCHEMA_FOR_PARTITION, + _REPLACE_SCHEMA_FOR_PARTITION, + existing_specs=[current_spec], + last_partition_id=1000, + format_version=1, + current_spec=current_spec, + ) + # Two fields: the carried-forward void at field_id=1000, and the new "data" field above it. + fields_by_id = {f.field_id: f for f in fresh_spec.fields} + assert isinstance(fields_by_id[1000].transform, VoidTransform) + assert fields_by_id[1000].name == "id" + assert fields_by_id[1001].name == "data" + assert isinstance(fields_by_id[1001].transform, IdentityTransform) + assert new_last_pid == 1001 + + +def test_assign_fresh_partition_spec_ids_for_replace_v1_renames_void_on_name_collision() -> None: + """When a void field's name collides with a new field's name, a unique suffix is added.""" + current_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="data"), spec_id=0 + ) + # New spec partitions "data" by a different transform — the OLD "data" must be voided + # under a different name to avoid collision with the NEW "data" partition. + new_spec = PartitionSpec(PartitionField(source_id=2, field_id=999, transform=IdentityTransform(), name="data")) + fresh_spec, _ = assign_fresh_partition_spec_ids_for_replace( + new_spec, + _REPLACE_SCHEMA_FOR_PARTITION, + _REPLACE_SCHEMA_FOR_PARTITION, + existing_specs=[current_spec], + last_partition_id=1000, + format_version=1, + current_spec=current_spec, + ) + void_field = next(f for f in fresh_spec.fields if isinstance(f.transform, VoidTransform)) + assert void_field.name != "data", "void name must not collide with active partition name" + assert void_field.name == "data_1000" + + +def test_assign_fresh_partition_spec_ids_for_replace_v1_keeps_field_preserves_id() -> None: + """v1 carry-forward: when a current-spec field is also in the new spec, its field_id is preserved.""" + schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + current_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=0) + # New spec keeps the same (source, transform) on "id" — should reuse field_id=1000, no void emitted. + new_spec = PartitionSpec(PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id")) + fresh_spec, new_last_pid = assign_fresh_partition_spec_ids_for_replace( + new_spec, + schema, + schema, + existing_specs=[current_spec], + last_partition_id=1000, + format_version=1, + current_spec=current_spec, + ) + assert [f.field_id for f in fresh_spec.fields] == [1000] + assert fresh_spec.fields[0].name == "id" + assert isinstance(fresh_spec.fields[0].transform, IdentityTransform) + assert not any(isinstance(f.transform, VoidTransform) for f in fresh_spec.fields) + assert new_last_pid == 1000 + + +def test_assign_fresh_partition_spec_ids_for_replace_v1_void_name_uses_multi_suffix_loop() -> None: + """When `name` and `name_` are both already used, append `_2`, `_3`, ... until unique.""" + # Three columns, one role each: source for the current (about-to-be-voided) partition, + # source for the new partition that collides on the void's preferred name, and source for + # the new partition that collides on the void's first fallback name. + schema = Schema( + NestedField(field_id=1, name="current_source", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="collide_on_name", field_type=IntegerType(), required=False), + NestedField(field_id=3, name="collide_on_fallback", field_type=IntegerType(), required=False), + ) + # Current v1 spec partitions source=1 by bucket(4) at field_id=1000, named "p" — for + # non-identity transforms the partition NAME doesn't have to match the source column name. + current_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(4), name="p"), spec_id=0) + # New spec has two partition fields named "p" and "p_1000" — colliding with both the + # void's preferred name and its first fallback. Both are on different sources, so they + # do not match the current (source=1, bucket[4]) key and the current field becomes void. + new_spec = PartitionSpec( + PartitionField(source_id=2, field_id=997, transform=BucketTransform(4), name="p"), + PartitionField(source_id=3, field_id=998, transform=BucketTransform(4), name="p_1000"), + ) + fresh_spec, _ = assign_fresh_partition_spec_ids_for_replace( + new_spec, + schema, + schema, + existing_specs=[current_spec], + last_partition_id=1000, + format_version=1, + current_spec=current_spec, + ) + void_field = next(f for f in fresh_spec.fields if isinstance(f.transform, VoidTransform)) + assert void_field.name == "p_1000_2" + + +def test_assign_fresh_partition_spec_ids_for_replace_v2_prefers_highest_field_id_for_repeated_key() -> None: + """v2: when the same (source_id, transform) appears across multiple specs, the highest field_id wins.""" + # Two historical specs both partition by (source_id=1, IdentityTransform), with different field_ids. + existing_specs = [ + PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="id"), spec_id=0), + PartitionSpec(PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="id_v2"), spec_id=1), + ] + # New spec uses the same (source, transform) — should reuse the highest historical field_id (1002). + new_spec = PartitionSpec(PartitionField(source_id=1, field_id=999, transform=IdentityTransform(), name="id")) + fresh_spec, new_last_pid = assign_fresh_partition_spec_ids_for_replace( + new_spec, _REPLACE_SCHEMA_FOR_PARTITION, _REPLACE_SCHEMA_FOR_PARTITION, existing_specs, last_partition_id=1002 + ) + assert fresh_spec.fields[0].field_id == 1002 + assert new_last_pid == 1002 diff --git a/tests/test_schema.py b/tests/test_schema.py index 93ddc16202..5f5368fda5 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -26,6 +26,7 @@ Accessor, Schema, _check_schema_compatible, + assign_fresh_schema_ids_for_replace, build_position_accessors, index_by_id, index_by_name, @@ -1815,3 +1816,114 @@ def test_check_schema_compatible_optional_map_field_present() -> None: ) # Should not raise - schemas match _check_schema_compatible(requested_schema, provided_schema) + + +@pytest.mark.parametrize( + "new_fields, expected_ids, expected_last_col_id", + [ + # All columns reused by name: IDs come from base, last_column_id unchanged. + ([("id", IntegerType()), ("data", StringType())], [1, 2], 2), + # Mix of reused and new: new column gets ID above last_column_id. + ([("id", IntegerType()), ("data", StringType()), ("new_col", BooleanType())], [1, 2, 3], 3), + # No column names match: all fresh IDs starting from last_column_id + 1. + ([("x", IntegerType()), ("y", IntegerType())], [3, 4], 4), + ], + ids=[ + "all-reused-keeps-last-col-id", + "new-field-bumps-last-col-id", + "no-name-overlap-bumps-from-base", + ], +) +def test_assign_fresh_schema_ids_for_replace_primitive_fields( + new_fields: list[tuple[str, IcebergType]], expected_ids: list[int], expected_last_col_id: int +) -> None: + """Replace schema field IDs are reused from the base schema by name; new fields get IDs above last_column_id.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + ) + new_schema = Schema( + *( + NestedField(field_id=10 * (i + 1), name=name, field_type=field_type, required=False) + for i, (name, field_type) in enumerate(new_fields) + ) + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 2) + assert [f.field_id for f in fresh.fields] == expected_ids + assert last_col_id == expected_last_col_id + + +def test_assign_fresh_schema_ids_for_replace_with_nested_struct() -> None: + """Test that nested struct field IDs are reused by full path name.""" + base_schema = Schema( + NestedField(field_id=1, name="id", field_type=IntegerType(), required=False), + NestedField( + field_id=2, + name="location", + field_type=StructType( + NestedField(field_id=3, name="lat", field_type=FloatType(), required=False), + NestedField(field_id=4, name="lon", field_type=FloatType(), required=False), + ), + required=False, + ), + ) + new_schema = Schema( + NestedField(field_id=10, name="id", field_type=IntegerType(), required=False), + NestedField( + field_id=20, + name="location", + field_type=StructType( + NestedField(field_id=30, name="lat", field_type=FloatType(), required=False), + NestedField(field_id=40, name="lon", field_type=FloatType(), required=False), + NestedField(field_id=50, name="alt", field_type=FloatType(), required=False), + ), + required=False, + ), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 4) + assert fresh.fields[0].field_id == 1 # id reused + assert fresh.fields[1].field_id == 2 # location reused + loc_fields = fresh.fields[1].field_type.fields + assert loc_fields[0].field_id == 3 # location.lat reused + assert loc_fields[1].field_id == 4 # location.lon reused + assert loc_fields[2].field_id == 5 # location.alt is new + assert last_col_id == 5 + + +def test_assign_fresh_schema_ids_for_replace_with_list_and_map() -> None: + """`element_id`, `key_id`, and `value_id` are reused by name path (e.g. `tags.element`, `m.key`, `m.value`).""" + base_schema = Schema( + NestedField( + field_id=1, + name="tags", + field_type=ListType(element_id=2, element_type=StringType(), element_required=False), + required=False, + ), + NestedField( + field_id=3, + name="m", + field_type=MapType(key_id=4, key_type=StringType(), value_id=5, value_type=IntegerType(), value_required=False), + required=False, + ), + ) + new_schema = Schema( + NestedField( + field_id=10, + name="tags", + field_type=ListType(element_id=20, element_type=StringType(), element_required=False), + required=False, + ), + NestedField( + field_id=30, + name="m", + field_type=MapType(key_id=40, key_type=StringType(), value_id=50, value_type=IntegerType(), value_required=False), + required=False, + ), + ) + fresh, last_col_id = assign_fresh_schema_ids_for_replace(new_schema, base_schema, 5) + assert fresh.fields[0].field_id == 1 + assert fresh.fields[0].field_type.element_id == 2 + assert fresh.fields[1].field_id == 3 + assert fresh.fields[1].field_type.key_id == 4 + assert fresh.fields[1].field_type.value_id == 5 + assert last_col_id == 5