diff --git a/pyproject.toml b/pyproject.toml index 1d918c6..8cca206 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "nhs_dve" +name = "data-validation-engine" version = "0.6.2" description = "`nhs data validation engine` is a framework used to validate data" authors = ["NHS England "] diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py index 9d6abaa..29e8644 100644 --- a/src/dve/core_engine/backends/base/backend.py +++ b/src/dve/core_engine/backends/base/backend.py @@ -163,7 +163,7 @@ def apply( return entities, get_parent(processing_errors_uri), successful for entity_name, entity in entities.items(): - entities[entity_name] = self.step_implementations.add_row_id(entity) + entities[entity_name] = self.step_implementations.add_record_index(entity) # TODO: Handle entity manager creation errors. entity_manager = EntityManager(entities, reference_data) @@ -172,9 +172,6 @@ def apply( # TODO: and return uri to errors _ = self.step_implementations.apply_rules(working_dir, entity_manager, rule_metadata) - for entity_name, entity in entity_manager.entities.items(): - entity_manager.entities[entity_name] = self.step_implementations.drop_row_id(entity) - return entity_manager.entities, get_parent(dc_feedback_errors_uri), True def process( diff --git a/src/dve/core_engine/backends/base/contract.py b/src/dve/core_engine/backends/base/contract.py index fc7da4d..948ff77 100644 --- a/src/dve/core_engine/backends/base/contract.py +++ b/src/dve/core_engine/backends/base/contract.py @@ -337,9 +337,9 @@ def read_raw_entities( successful = True for entity_name, resource in entity_locations.items(): reader_metadata = contract_metadata.reader_metadata[entity_name] - extension = "." + ( - get_file_suffix(resource) or "" - ).lower() # Already checked that extension supported. + extension = ( + "." + (get_file_suffix(resource) or "").lower() + ) # Already checked that extension supported. 
reader_config = reader_metadata[extension] reader_type = get_reader(reader_config.reader) @@ -369,6 +369,14 @@ def read_raw_entities( return entities, dedup_messages(messages), successful + def add_record_index(self, entity: EntityType, **kwargs) -> EntityType: + """Add a record index to the entity""" + raise NotImplementedError(f"add_record_index not implemented in {self.__class__}") + + def drop_record_index(self, entity: EntityType, **kwargs) -> EntityType: + """Drop a record index from the entity""" + raise NotImplementedError(f"drop_record_index not implemented in {self.__class__}") + @abstractmethod def apply_data_contract( self, diff --git a/src/dve/core_engine/backends/base/reader.py b/src/dve/core_engine/backends/base/reader.py index 54abaa9..ac30111 100644 --- a/src/dve/core_engine/backends/base/reader.py +++ b/src/dve/core_engine/backends/base/reader.py @@ -127,6 +127,14 @@ def read_to_entity_type( return reader_func(self, resource, entity_name, schema) + def add_record_index(self, entity: EntityType, **kwargs) -> EntityType: + """Add a record index to the entity""" + raise NotImplementedError(f"add_record_index not implemented in {self.__class__}") + + def drop_record_index(self, entity: EntityType, **kwargs) -> EntityType: + """Drop a record index from the entity""" + raise NotImplementedError(f"drop_record_index not implemented in {self.__class__}") + def write_parquet( self, entity: EntityType, diff --git a/src/dve/core_engine/backends/base/rules.py b/src/dve/core_engine/backends/base/rules.py index 97a6b4d..b66b3ae 100644 --- a/src/dve/core_engine/backends/base/rules.py +++ b/src/dve/core_engine/backends/base/rules.py @@ -135,15 +135,13 @@ def register_udfs(cls, **kwargs): """Method to register all custom dve functions for use during business rules application""" raise NotImplementedError() - @staticmethod - def add_row_id(entity: EntityType) -> EntityType: - """Add a unique row id field to an entity""" - raise NotImplementedError() + def add_record_index(self, entity: EntityType, **kwargs) -> EntityType: + """Add a record index to the entity""" + raise NotImplementedError(f"add_record_index not implemented in {self.__class__}") - @staticmethod - def drop_row_id(entity: EntityType) -> EntityType: - """Add a unique row id field to an entity""" - raise NotImplementedError() + def drop_record_index(self, entity: EntityType) -> EntityType: + """Drop the record index from an entity""" + raise NotImplementedError(f"drop_record_index not implemented in {self.__class__}") + @classmethod def _raise_notimplemented_error( diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 860f06b..075573d 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -29,6 +29,7 @@ ) from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( duckdb_read_parquet, + duckdb_record_index, duckdb_write_parquet, get_duckdb_type_from_annotation, relation_is_empty, @@ -37,6 +38,7 @@ from dve.core_engine.backends.metadata.contract import DataContractMetadata from dve.core_engine.backends.types import StageSuccessful from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import URI, EntityLocations from 
dve.core_engine.validation import RowValidator, apply_row_validator_helper @@ -54,6 +56,7 @@ def __call__(self, row: pd.Series): return row # no op +@duckdb_record_index @duckdb_write_parquet @duckdb_read_parquet class DuckDBDataContract(BaseDataContract[DuckDBPyRelation]): @@ -144,10 +147,12 @@ def apply_data_contract( fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in entity_fields.values() } + ddb_schema[RECORD_INDEX_COLUMN_NAME] = get_duckdb_type_from_annotation(int) polars_schema: dict[str, PolarsType] = { fld.name: get_polars_type_from_annotation(fld.annotation) for fld in entity_fields.values() } + polars_schema[RECORD_INDEX_COLUMN_NAME] = get_polars_type_from_annotation(int) if relation_is_empty(relation): self.logger.warning(f"+ Empty relation for {entity_name}") empty_df = pl.DataFrame([], schema=polars_schema) # type: ignore # pylint: disable=W0612 @@ -170,6 +175,9 @@ def apply_data_contract( self.logger.info(f"Data contract found {msg_count} issues in {entity_name}") + if not RECORD_INDEX_COLUMN_NAME in relation.columns: + relation = self.add_record_index(relation) + casting_statements = [ ( self.generate_ddb_cast_statement(column, dtype) diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index 843ee40..f5b0fe9 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -12,13 +12,14 @@ import duckdb.typing as ddbtyp import numpy as np -from duckdb import DuckDBPyConnection, DuckDBPyRelation +from duckdb import DuckDBPyConnection, DuckDBPyRelation, StarExpression from duckdb.typing import DuckDBPyType from pandas import DataFrame from pydantic import BaseModel from typing_extensions import Annotated, get_args, get_origin, get_type_hints from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.type_hints import URI from dve.parser.file_handling.service import LocalFilesystemImplementation, _get_implementation @@ -286,3 +287,29 @@ def duckdb_rel_to_dictionaries( cols: tuple[str] = tuple(entity.columns) # type: ignore while rows := entity.fetchmany(batch_size): yield from (dict(zip(cols, rw)) for rw in rows) + + +def _add_duckdb_record_index( + self, entity: DuckDBPyRelation # pylint: disable=W0613 +) -> DuckDBPyRelation: + """Add record index to duckdb relation""" + if RECORD_INDEX_COLUMN_NAME in entity.columns: + return entity + + return entity.select(f"*, row_number() OVER () as {RECORD_INDEX_COLUMN_NAME}") + + +def _drop_duckdb_record_index( + self, entity: DuckDBPyRelation # pylint: disable=W0613 +) -> DuckDBPyRelation: + """Drop record index from duckdb relation""" + if RECORD_INDEX_COLUMN_NAME not in entity.columns: + return entity + return entity.select(StarExpression(exclude=[RECORD_INDEX_COLUMN_NAME])) + + +def duckdb_record_index(cls): + """Class decorator to add record index methods for duckdb implementations""" + setattr(cls, "add_record_index", _add_duckdb_record_index) + setattr(cls, "drop_record_index", _drop_duckdb_record_index) + return cls diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py index 43edb6a..f4c092b 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py +++ 
b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py @@ -6,23 +6,32 @@ import duckdb as ddb import polars as pl -from duckdb import DuckDBPyConnection, DuckDBPyRelation, default_connection, read_csv +from duckdb import ( + DuckDBPyConnection, + DuckDBPyRelation, + StarExpression, + default_connection, + read_csv, +) from pydantic import BaseModel from dve.core_engine.backends.base.reader import BaseFileReader, read_function from dve.core_engine.backends.exceptions import EmptyFileError, MessageBearingError from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( + duckdb_record_index, duckdb_write_parquet, get_duckdb_type_from_annotation, ) from dve.core_engine.backends.implementations.duckdb.types import SQLType from dve.core_engine.backends.readers.utilities import check_csv_header_expected -from dve.core_engine.backends.utilities import get_polars_type_from_annotation +from dve.core_engine.backends.utilities import get_polars_type_from_annotation, polars_record_index +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import URI, EntityName from dve.parser.file_handling import get_content_length +@duckdb_record_index @duckdb_write_parquet class DuckDBCSVReader(BaseFileReader): """A reader for CSV files including the ability to compare the passed model @@ -111,18 +120,19 @@ def read_to_relation( # pylint: disable=unused-argument } reader_options["columns"] = ddb_schema - rel = read_csv(resource, **reader_options) + + rel = self.add_record_index(read_csv(resource, **reader_options, parallel=False)) if self.null_empty_strings: - cleaned_cols = ",".join([ - f"NULLIF(TRIM({c}), '') as {c}" - for c in reader_options["columns"].keys() - ]) + cleaned_cols = ",".join( + [f"NULLIF(TRIM({c}), '') as {c}" for c in reader_options["columns"].keys()] + ) rel = rel.select(cleaned_cols) return rel +@polars_record_index class PolarsToDuckDBCSVReader(DuckDBCSVReader): """ Utilises the polars lazy csv reader which is then converted into a DuckDBPyRelation object. @@ -156,10 +166,19 @@ def read_to_relation( # pylint: disable=unused-argument # there is a raise_if_empty arg for 0.18+. Future reference when upgrading. 
Makes L85 # redundant - df = pl.scan_csv(resource, **reader_options).select(list(polars_types.keys())) # type: ignore # pylint: disable=W0612 + df = self.add_record_index( # pylint: disable=W0612 + pl.scan_csv(resource, **reader_options).select( # type: ignore + list(polars_types.keys()) + ) + ) if self.null_empty_strings: - df = df.select([pl.col(c).str.strip_chars().replace("", None) for c in df.columns]) + pl_exprs = [ + pl.col(c).str.strip_chars().replace("", None) + for c in df.columns + if not c == RECORD_INDEX_COLUMN_NAME + ] + [pl.col(RECORD_INDEX_COLUMN_NAME)] + df = df.select(pl_exprs) return ddb.sql("SELECT * FROM df") @@ -203,8 +222,10 @@ def __init__( def read_to_relation( # pylint: disable=unused-argument self, resource: URI, entity_name: EntityName, schema: type[BaseModel] ) -> DuckDBPyRelation: - entity = super().read_to_relation(resource=resource, entity_name=entity_name, schema=schema) - entity = entity.distinct() + entity: DuckDBPyRelation = super().read_to_relation( + resource=resource, entity_name=entity_name, schema=schema + ) + entity = entity.select(StarExpression(exclude=[RECORD_INDEX_COLUMN_NAME])).distinct() no_records = entity.shape[0] if no_records != 1: @@ -233,4 +254,4 @@ def read_to_relation( # pylint: disable=unused-argument ], ) - return entity + return entity.select(f"*, 1 as {RECORD_INDEX_COLUMN_NAME}") diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/json.py b/src/dve/core_engine/backends/implementations/duckdb/readers/json.py index b1a3ad4..8afb5a4 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/json.py +++ b/src/dve/core_engine/backends/implementations/duckdb/readers/json.py @@ -9,6 +9,7 @@ from dve.core_engine.backends.base.reader import BaseFileReader, read_function from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( + duckdb_record_index, duckdb_write_parquet, get_duckdb_type_from_annotation, ) @@ -16,6 +17,7 @@ from dve.core_engine.type_hints import URI, EntityName +@duckdb_record_index @duckdb_write_parquet class DuckDBJSONReader(BaseFileReader): """A reader for JSON files""" @@ -47,4 +49,6 @@ def read_to_relation( # pylint: disable=unused-argument for fld in schema.__fields__.values() } - return read_json(resource, columns=ddb_schema, format=self._json_format) # type: ignore + return self.add_record_index( + read_json(resource, columns=ddb_schema, format=self._json_format) # type: ignore + ) diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py b/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py index a955946..a10998c 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py +++ b/src/dve/core_engine/backends/implementations/duckdb/readers/xml.py @@ -11,10 +11,15 @@ from dve.core_engine.backends.exceptions import MessageBearingError from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_write_parquet from dve.core_engine.backends.readers.xml import XMLStreamReader -from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model +from dve.core_engine.backends.utilities import ( + get_polars_type_from_annotation, + polars_record_index, + stringify_model, +) from dve.core_engine.type_hints import URI +@polars_record_index @duckdb_write_parquet class DuckDBXMLStreamReader(XMLStreamReader): """A reader for XML files""" @@ -39,7 +44,9 @@ def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseMod for fld in 
stringify_model(schema).__fields__.values() } - _lazy_frame = pl.LazyFrame( - data=self.read_to_py_iterator(resource, entity_name, schema), schema=polars_schema + _lazy_frame = self.add_record_index( + pl.LazyFrame( + data=self.read_to_py_iterator(resource, entity_name, schema), schema=polars_schema + ) ) return self.ddb_connection.sql("select * from _lazy_frame") diff --git a/src/dve/core_engine/backends/implementations/duckdb/rules.py b/src/dve/core_engine/backends/implementations/duckdb/rules.py index e556c6b..7ed775c 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/rules.py +++ b/src/dve/core_engine/backends/implementations/duckdb/rules.py @@ -23,6 +23,7 @@ from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( DDBStruct, duckdb_read_parquet, + duckdb_record_index, duckdb_rel_to_dictionaries, duckdb_write_parquet, get_all_registered_udfs, @@ -51,13 +52,13 @@ SemiJoin, TableUnion, ) -from dve.core_engine.constants import ROWID_COLUMN_NAME from dve.core_engine.functions import implementations as functions from dve.core_engine.message import FeedbackMessage from dve.core_engine.templating import template_object from dve.core_engine.type_hints import Messages +@duckdb_record_index @duckdb_write_parquet @duckdb_read_parquet class DuckDBStepImplementations(BaseStepImplementations[DuckDBPyRelation]): @@ -106,20 +107,6 @@ def register_udfs( # type: ignore connection.sql(_sql) return cls(connection=connection, **kwargs) - @staticmethod - def add_row_id(entity: DuckDBPyRelation) -> DuckDBPyRelation: - """Adds a row identifier to the Relation""" - if ROWID_COLUMN_NAME not in entity.columns: - entity = entity.project(f"*, ROW_NUMBER() OVER () as {ROWID_COLUMN_NAME}") - return entity - - @staticmethod - def drop_row_id(entity: DuckDBPyRelation) -> DuckDBPyRelation: - """Drops the row identiifer from a Relation""" - if ROWID_COLUMN_NAME in entity.columns: - entity = entity.select(StarExpression(exclude=[ROWID_COLUMN_NAME])) - return entity - def add(self, entities: DuckDBEntities, *, config: ColumnAddition) -> Messages: """A transformation step which adds a column to an entity.""" entity: DuckDBPyRelation = entities[config.entity_name] diff --git a/src/dve/core_engine/backends/implementations/spark/backend.py b/src/dve/core_engine/backends/implementations/spark/backend.py index 742e9e3..3999b62 100644 --- a/src/dve/core_engine/backends/implementations/spark/backend.py +++ b/src/dve/core_engine/backends/implementations/spark/backend.py @@ -11,7 +11,7 @@ from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations from dve.core_engine.backends.implementations.spark.spark_helpers import get_type_from_annotation from dve.core_engine.backends.implementations.spark.types import SparkEntities -from dve.core_engine.constants import ROWID_COLUMN_NAME +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger from dve.core_engine.models import SubmissionInfo from dve.core_engine.type_hints import URI, EntityParquetLocations @@ -58,7 +58,7 @@ def write_entities_to_parquet( locations = {} self.logger.info(f"Writing entities to the output location: {cache_prefix}") for entity_name, entity in entities.items(): - entity = entity.drop(ROWID_COLUMN_NAME) + entity = entity.drop(RECORD_INDEX_COLUMN_NAME) self.logger.info(f"Entity: {entity_name}") diff --git a/src/dve/core_engine/backends/implementations/spark/contract.py 
b/src/dve/core_engine/backends/implementations/spark/contract.py index d8078bd..6152ad7 100644 --- a/src/dve/core_engine/backends/implementations/spark/contract.py +++ b/src/dve/core_engine/backends/implementations/spark/contract.py @@ -10,7 +10,7 @@ from pyspark.sql import DataFrame, SparkSession from pyspark.sql import functions as sf from pyspark.sql.functions import col, lit -from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType +from pyspark.sql.types import ArrayType, DataType, LongType, MapType, StructField, StructType from dve.common.error_utils import ( BackgroundMessageWriter, @@ -28,19 +28,21 @@ df_is_empty, get_type_from_annotation, spark_read_parquet, + spark_record_index, spark_write_parquet, ) from dve.core_engine.backends.implementations.spark.types import SparkEntities from dve.core_engine.backends.metadata.contract import DataContractMetadata from dve.core_engine.backends.readers import CSVFileReader from dve.core_engine.backends.types import StageSuccessful -from dve.core_engine.constants import ROWID_COLUMN_NAME +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.type_hints import URI, EntityLocations, EntityName COMPLEX_TYPES: set[type[DataType]] = {StructType, ArrayType, MapType} """Spark types indicating complex types.""" +@spark_record_index @spark_write_parquet @spark_read_parquet class SparkDataContract(BaseDataContract[DataFrame]): @@ -84,6 +86,7 @@ def create_entity_from_py_iterator( schema=get_type_from_annotation(schema), ) + # pylint: disable=R0915 def apply_data_contract( self, working_dir: URI, @@ -100,14 +103,16 @@ def apply_data_contract( successful = True for entity_name, record_df in entities.items(): spark_schema = get_type_from_annotation(contract_metadata.schemas[entity_name]) - + spark_schema.add(StructField(RECORD_INDEX_COLUMN_NAME, LongType())) if df_is_empty(record_df): self.logger.warning(f"+ Empty dataframe for {entity_name}") entities[entity_name] = self.spark_session.createDataFrame( # type: ignore [], schema=spark_schema - ).withColumn(ROWID_COLUMN_NAME, lit(None).cast(StringType())) + ) continue + if not RECORD_INDEX_COLUMN_NAME in record_df.columns: + record_df = self.add_record_index(record_df) if self.debug: # Note, the count will realise the dataframe, so only do this diff --git a/src/dve/core_engine/backends/implementations/spark/readers/csv.py b/src/dve/core_engine/backends/implementations/spark/readers/csv.py index e629517..8c2b137 100644 --- a/src/dve/core_engine/backends/implementations/spark/readers/csv.py +++ b/src/dve/core_engine/backends/implementations/spark/readers/csv.py @@ -12,12 +12,14 @@ from dve.core_engine.backends.exceptions import EmptyFileError from dve.core_engine.backends.implementations.spark.spark_helpers import ( get_type_from_annotation, + spark_record_index, spark_write_parquet, ) from dve.core_engine.type_hints import URI, EntityName from dve.parser.file_handling import get_content_length +@spark_record_index @spark_write_parquet class SparkCSVReader(BaseFileReader): """A Spark reader for CSV files.""" @@ -73,16 +75,15 @@ def read_to_dataframe( "multiLine": self.multi_line, } - df = ( + df = self.add_record_index( self.spark_session.read.format("csv") .options(**kwargs) # type: ignore .load(resource, schema=spark_schema) ) if self.null_empty_strings: - df = df.select(*[ - psf.trim(psf.col(c.name)).alias(c.name) - for c in spark_schema.fields - ]).replace("", None) + df = df.select( + *[psf.trim(psf.col(c.name)).alias(c.name) for c in 
spark_schema.fields] + ).replace("", None) return df diff --git a/src/dve/core_engine/backends/implementations/spark/readers/json.py b/src/dve/core_engine/backends/implementations/spark/readers/json.py index c336ee0..0b4a09f 100644 --- a/src/dve/core_engine/backends/implementations/spark/readers/json.py +++ b/src/dve/core_engine/backends/implementations/spark/readers/json.py @@ -11,12 +11,14 @@ from dve.core_engine.backends.exceptions import EmptyFileError from dve.core_engine.backends.implementations.spark.spark_helpers import ( get_type_from_annotation, + spark_record_index, spark_write_parquet, ) from dve.core_engine.type_hints import URI, EntityName from dve.parser.file_handling import get_content_length +@spark_record_index @spark_write_parquet class SparkJSONReader(BaseFileReader): """A Spark reader for JSON files.""" @@ -59,7 +61,7 @@ def read_to_dataframe( "multiline": self.multi_line, } - return ( + return self.add_record_index( self.spark_session.read.format("json") .options(**kwargs) # type: ignore .load(resource, schema=spark_schema) diff --git a/src/dve/core_engine/backends/implementations/spark/readers/xml.py b/src/dve/core_engine/backends/implementations/spark/readers/xml.py index 30d6756..39433b3 100644 --- a/src/dve/core_engine/backends/implementations/spark/readers/xml.py +++ b/src/dve/core_engine/backends/implementations/spark/readers/xml.py @@ -17,6 +17,7 @@ from dve.core_engine.backends.implementations.spark.spark_helpers import ( df_is_empty, get_type_from_annotation, + spark_record_index, spark_write_parquet, ) from dve.core_engine.backends.readers.xml import BasicXMLFileReader, XMLStreamReader @@ -28,6 +29,7 @@ """The mode to use when parsing XML files with Spark.""" +@spark_record_index @spark_write_parquet class SparkXMLStreamReader(XMLStreamReader): """An XML stream reader that adds a method to read to a dataframe""" @@ -45,12 +47,15 @@ def read_to_dataframe( if not self.spark: self.spark = SparkSession.builder.getOrCreate() # type: ignore spark_schema = get_type_from_annotation(schema) - return self.spark.createDataFrame( # type: ignore - list(self.read_to_py_iterator(resource, entity_name, schema)), - schema=spark_schema, + return self.add_record_index( + self.spark.createDataFrame( # type: ignore + list(self.read_to_py_iterator(resource, entity_name, schema)), + schema=spark_schema, + ) ) +@spark_record_index @spark_write_parquet class SparkXMLReader(BasicXMLFileReader): # pylint: disable=too-many-instance-attributes """A reader for XML files built atop Spark-XML.""" @@ -177,7 +182,7 @@ def read_to_dataframe( df = self._add_missing_columns(df, spark_schema) df = self._sanitise_columns(df) - return df + return self.add_record_index(df) def _add_missing_columns(self, df: DataFrame, fields: Iterable[StructField]) -> DataFrame: for field in fields: diff --git a/src/dve/core_engine/backends/implementations/spark/rules.py b/src/dve/core_engine/backends/implementations/spark/rules.py index 15afa09..5d1cfe0 100644 --- a/src/dve/core_engine/backends/implementations/spark/rules.py +++ b/src/dve/core_engine/backends/implementations/spark/rules.py @@ -15,6 +15,7 @@ get_all_registered_udfs, object_to_spark_literal, spark_read_parquet, + spark_record_index, spark_write_parquet, ) from dve.core_engine.backends.implementations.spark.types import ( @@ -43,13 +44,13 @@ SemiJoin, TableUnion, ) -from dve.core_engine.constants import ROWID_COLUMN_NAME from dve.core_engine.functions import implementations as functions from dve.core_engine.message import FeedbackMessage from 
dve.core_engine.templating import template_object from dve.core_engine.type_hints import Messages +@spark_record_index @spark_write_parquet @spark_read_parquet class SparkStepImplementations(BaseStepImplementations[DataFrame]): @@ -100,18 +101,6 @@ def register_udfs( return cls(spark_session=spark_session, **kwargs) - @staticmethod - def add_row_id(entity: DataFrame) -> DataFrame: - if ROWID_COLUMN_NAME not in entity.columns: - entity = entity.withColumn(ROWID_COLUMN_NAME, sf.expr("uuid()")) - return entity - - @staticmethod - def drop_row_id(entity: DataFrame) -> DataFrame: - if ROWID_COLUMN_NAME in entity.columns: - entity = entity.drop(ROWID_COLUMN_NAME) - return entity - def add(self, entities: SparkEntities, *, config: ColumnAddition) -> Messages: entity: DataFrame = entities[config.entity_name] entity = entity.withColumn(config.column_name, sf.expr(config.expression)) diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index 7cb7b17..4381fdd 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -17,14 +17,16 @@ from delta.exceptions import ConcurrentAppendException, DeltaConcurrentModificationException from pydantic import BaseModel from pydantic.types import ConstrainedDecimal -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql import functions as sf from pyspark.sql import types as st from pyspark.sql.column import Column from pyspark.sql.functions import lit, udf +from pyspark.sql.types import LongType, StructField, StructType from typing_extensions import Annotated, Protocol, TypedDict, get_args, get_origin, get_type_hints from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.type_hints import URI # It would be really nice if there was a more parameterisable @@ -410,3 +412,30 @@ def _inner(*args, **kwargs): return _inner return _wrapper + + +def _add_spark_record_index(self, entity: DataFrame) -> DataFrame: # pylint: disable=W0613 + """Add a record index to spark dataframe""" + if RECORD_INDEX_COLUMN_NAME in entity.columns: + return entity + schema: StructType = entity.schema + schema.add(StructField(RECORD_INDEX_COLUMN_NAME, LongType())) + return ( + entity.rdd.zipWithIndex() + .map(lambda x: Row(**{**x[0].asDict(True), RECORD_INDEX_COLUMN_NAME: x[1] + 1})) + .toDF(schema=schema) + ) + + +def _drop_spark_record_index(self, entity: DataFrame) -> DataFrame: # pylint: disable=W0613 + """Drop record index from spark dataframe""" + if not RECORD_INDEX_COLUMN_NAME in entity.columns: + return entity + return entity.drop(RECORD_INDEX_COLUMN_NAME) + + +def spark_record_index(cls): + """Class decorator to add record index methods for spark implementations""" + setattr(cls, "add_record_index", _add_spark_record_index) + setattr(cls, "drop_record_index", _drop_spark_record_index) + return cls diff --git a/src/dve/core_engine/backends/readers/csv.py b/src/dve/core_engine/backends/readers/csv.py index bc05b58..edd6bf0 100644 --- a/src/dve/core_engine/backends/readers/csv.py +++ b/src/dve/core_engine/backends/readers/csv.py @@ -16,6 +16,7 @@ MissingHeaderError, ) from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from 
dve.core_engine.type_hints import EntityName from dve.parser.file_handling import get_content_length, open_stream from dve.parser.file_handling.implementations.file import file_uri_to_local_path @@ -204,7 +205,9 @@ def read_to_py_iterator( ) coerce_func = partial(self._coerce, field_names=field_names) - yield from map(coerce_func, reader) + for idx, record in enumerate(map(coerce_func, reader), start=1): + record[RECORD_INDEX_COLUMN_NAME] = idx # type: ignore + yield record def write_parquet( # type: ignore self, @@ -223,6 +226,7 @@ def write_parquet( # type: ignore fld.name: get_polars_type_from_annotation(fld.annotation) for fld in stringify_model(schema).__fields__.values() } + polars_schema[RECORD_INDEX_COLUMN_NAME] = get_polars_type_from_annotation(int) pl.LazyFrame(data=entity, schema=polars_schema).sink_parquet( path=target_location, compression="snappy" diff --git a/src/dve/core_engine/backends/readers/xml.py b/src/dve/core_engine/backends/readers/xml.py index e7480f1..4620402 100644 --- a/src/dve/core_engine/backends/readers/xml.py +++ b/src/dve/core_engine/backends/readers/xml.py @@ -14,6 +14,7 @@ from dve.core_engine.backends.exceptions import EmptyFileError from dve.core_engine.backends.readers.xml_linting import run_xmllint from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.loggers import get_logger from dve.core_engine.message import FeedbackMessage from dve.core_engine.type_hints import URI, EntityName @@ -310,7 +311,9 @@ def read_to_py_iterator( raise EmptyFileError(f"File at {resource!r} is empty") with open_stream(resource, "rb") as stream: - yield from self._parse_xml(stream, schema) + for idx, record in enumerate(self._parse_xml(stream, schema), start=1): + record[RECORD_INDEX_COLUMN_NAME] = idx # type: ignore + yield record def write_parquet( # type: ignore self, @@ -329,6 +332,7 @@ def write_parquet( # type: ignore fld.name: get_polars_type_from_annotation(fld.type_) for fld in stringify_model(schema).__fields__.values() } + polars_schema[RECORD_INDEX_COLUMN_NAME] = get_polars_type_from_annotation(int) pl.LazyFrame(data=entity, schema=polars_schema).sink_parquet( path=target_location, compression="snappy", **kwargs ) diff --git a/src/dve/core_engine/backends/utilities.py b/src/dve/core_engine/backends/utilities.py index 9261806..62eb9e2 100644 --- a/src/dve/core_engine/backends/utilities.py +++ b/src/dve/core_engine/backends/utilities.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, create_model from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.type_hints import Messages # We need to rely on a Python typing implementation detail in Python <= 3.7. 
@@ -175,3 +176,24 @@ def get_polars_type_from_annotation(type_annotation: Any) -> PolarsType: if polars_type: return polars_type raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") + + +def _add_polars_record_index(self, entity: pl.LazyFrame) -> pl.LazyFrame: # pylint: disable=W0613 + """Add a record index to polars dataframe""" + if RECORD_INDEX_COLUMN_NAME in entity.columns: + return entity + return entity.with_row_index(name=RECORD_INDEX_COLUMN_NAME, offset=1) + + +def _drop_polars_record_index(self, entity: pl.LazyFrame) -> pl.LazyFrame: # pylint: disable=W0613 + """Drop record index from polars dataframe""" + if not RECORD_INDEX_COLUMN_NAME in entity.columns: + return entity + return entity.drop(RECORD_INDEX_COLUMN_NAME) + + +def polars_record_index(cls): + """Class decorator to add record index methods for polars implementations""" + setattr(cls, "add_record_index", _add_polars_record_index) + setattr(cls, "drop_record_index", _drop_polars_record_index) + return cls diff --git a/src/dve/core_engine/constants.py b/src/dve/core_engine/constants.py index d452c9b..a2a4a65 100644 --- a/src/dve/core_engine/constants.py +++ b/src/dve/core_engine/constants.py @@ -1,7 +1,7 @@ """Constant values used in mutiple places.""" -ROWID_COLUMN_NAME: str = "__rowid__" -"""The name of the column containing the row ID for each entity.""" +RECORD_INDEX_COLUMN_NAME: str = "__record_index__" +"""The name of the column containing the record index for each entity.""" CONTRACT_ERROR_VALUE_FIELD_NAME: str = "__error_value" """The name of the field that can be used to extract the field value that caused diff --git a/src/dve/core_engine/engine.py b/src/dve/core_engine/engine.py index 28a2ac5..c5d1ba9 100644 --- a/src/dve/core_engine/engine.py +++ b/src/dve/core_engine/engine.py @@ -15,7 +15,7 @@ from dve.core_engine.backends.implementations.spark.types import SparkEntities from dve.core_engine.configuration.base import BaseEngineConfig from dve.core_engine.configuration.v1 import V1EngineConfig -from dve.core_engine.constants import ROWID_COLUMN_NAME +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger from dve.core_engine.models import EngineRunValidation, SubmissionInfo from dve.core_engine.type_hints import EntityName, JSONstring @@ -200,7 +200,7 @@ def _write_entity_outputs(self, entities: SparkEntities) -> SparkEntities: self.main_log.info(f"Writing entities to the output location: {self.output_prefix_uri}") for entity_name, entity in entities.items(): - entity = entity.drop(ROWID_COLUMN_NAME) + entity = entity.drop(RECORD_INDEX_COLUMN_NAME) self.main_log.info(f"Entity: {entity_name} {type(entity)}") diff --git a/src/dve/core_engine/message.py b/src/dve/core_engine/message.py index f2a4e52..627ae3a 100644 --- a/src/dve/core_engine/message.py +++ b/src/dve/core_engine/message.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, ValidationError, validator from pydantic.dataclasses import dataclass -from dve.core_engine.constants import CONTRACT_ERROR_VALUE_FIELD_NAME, ROWID_COLUMN_NAME +from dve.core_engine.constants import CONTRACT_ERROR_VALUE_FIELD_NAME, RECORD_INDEX_COLUMN_NAME from dve.core_engine.templating import template_object from dve.core_engine.type_hints import ( EntityName, @@ -116,6 +116,8 @@ class UserMessage: "The offending values" Category: ErrorCategory "The category of error" + RecordIndex: Optional[int] = None + "The record index where the error occurred (if applicable)" @property def 
is_informational(self) -> bool: @@ -187,6 +189,7 @@ class FeedbackMessage: # pylint: disable=too-many-instance-attributes "ErrorMessage", "ErrorCode", "ReportingField", + "RecordIndex", "Value", "Category", ] @@ -224,15 +227,6 @@ def _validate_error_location(cls, value: Any) -> Optional[str]: return str(value) - @validator("record") - def _strip_rowid( # pylint: disable=no-self-argument - cls, value: Optional[dict[str, Any]] - ) -> Optional[dict[str, Any]]: - """Strip the row ID column from the record, if present.""" - if isinstance(value, dict): - value.pop(ROWID_COLUMN_NAME, None) - return value - @property def is_critical(self) -> bool: """Whether the error is unrecoverable.""" @@ -333,6 +327,7 @@ def to_row( error_message, self.error_code, self.reporting_field_name or reporting_field, + (self.record.get(RECORD_INDEX_COLUMN_NAME) if self.record else None), value, self.category, ) diff --git a/src/dve/core_engine/type_hints.py b/src/dve/core_engine/type_hints.py index afb6d9d..3112e28 100644 --- a/src/dve/core_engine/type_hints.py +++ b/src/dve/core_engine/type_hints.py @@ -135,6 +135,8 @@ """The value that caused the error.""" ErrorCategory = Literal["Blank", "Wrong format", "Bad value", "Bad file"] """A string indicating the category of the error.""" +RecordIndex = Optional[int] +"""The record index that the error relates to (if applicable)""" MessageTuple = tuple[ Optional[EntityName], @@ -146,6 +148,7 @@ ErrorMessage, ErrorCode, ReportingField, + RecordIndex, Optional[FieldValue], Optional[ErrorCategory], ] diff --git a/src/dve/metadata_parser/models.py b/src/dve/metadata_parser/models.py index 18cdc68..73e6b5c 100644 --- a/src/dve/metadata_parser/models.py +++ b/src/dve/metadata_parser/models.py @@ -371,6 +371,7 @@ class Config(pyd.BaseConfig): fields = self.aliases # type: ignore anystr_strip_whitespace = True allow_population_by_field_name = True + extra = pyd.Extra.ignore return pyd.create_model( # type: ignore model_name, diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index 46a89c2..b14ada1 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -432,7 +432,9 @@ def apply_data_contract( for path, _ in fh.iter_prefix(read_from): entity_locations[fh.get_file_name(path)] = path - entities[fh.get_file_name(path)] = self.data_contract.read_parquet(path) + entities[fh.get_file_name(path)] = self.data_contract.add_record_index( + self.data_contract.read_parquet(path) + ) key_fields = {model: conf.reporting_fields for model, conf in model_config.items()} @@ -563,8 +565,9 @@ def apply_business_rules( for parquet_uri, _ in fh.iter_prefix(contract): file_name = fh.get_file_name(parquet_uri) - entities[file_name] = self.step_implementations.read_parquet(parquet_uri) # type: ignore - entities[file_name] = self.step_implementations.add_row_id(entities[file_name]) # type: ignore + entities[file_name] = self.step_implementations.add_record_index( # type: ignore + self.step_implementations.read_parquet(parquet_uri) # type: ignore + ) entities[f"Original{file_name}"] = self.step_implementations.read_parquet(parquet_uri) # type: ignore sub_info_entity = ( @@ -742,6 +745,7 @@ def _get_error_dataframes(self, submission_id: str): pl.col("ErrorCode").alias("Error_Code"), # type: ignore pl.col("ReportingField").alias("Data_Item"), # type: ignore pl.col("ErrorMessage").alias("Error"), # type: ignore + pl.col("RecordIndex").alias("Record_Index"), pl.col("Value"), # type: ignore pl.col("Key").alias("ID"), # type: ignore pl.col("Category"), # type: ignore 
diff --git a/src/dve/reporting/error_report.py b/src/dve/reporting/error_report.py index 8852fcb..9e947bf 100644 --- a/src/dve/reporting/error_report.py +++ b/src/dve/reporting/error_report.py @@ -18,6 +18,7 @@ "Error_Code": Utf8(), "Data_Item": Utf8(), "Error": Utf8(), + "Record_Index": pl.UInt32(), "Value": Utf8(), "ID": Utf8(), "Category": Utf8(), diff --git a/tests/features/books.feature b/tests/features/books.feature index f13658a..60cc5db 100644 --- a/tests/features/books.feature +++ b/tests/features/books.feature @@ -4,33 +4,6 @@ Feature: Pipeline tests using the books dataset This tests submissions using nested, complex JSON datasets with arrays, and introduces more complex transformations that require aggregation. - Scenario: Validate complex nested XML data (spark) - Given I submit the books file nested_books.XML for processing - And A spark pipeline is configured with schema file 'nested_books.dischema.json' - And I add initial audit entries for the submission - Then the latest audit record for the submission is marked with processing status file_transformation - When I run the file transformation phase - Then the header entity is stored as a parquet after the file_transformation phase - And the nested_books entity is stored as a parquet after the file_transformation phase - And the latest audit record for the submission is marked with processing status data_contract - When I run the data contract phase - Then there is 1 record rejection from the data_contract phase - And the header entity is stored as a parquet after the data_contract phase - And the nested_books entity is stored as a parquet after the data_contract phase - And the latest audit record for the submission is marked with processing status business_rules - When I run the business rules phase - Then The rules restrict "nested_books" to 3 qualifying records - And The entity "nested_books" contains an entry for "17.85" in column "total_value_of_books" - And the nested_books entity is stored as a parquet after the business_rules phase - And the latest audit record for the submission is marked with processing status error_report - When I run the error report phase - Then An error report is produced - And The statistics entry for the submission shows the following information - | parameter | value | - | record_count | 4 | - | number_record_rejections | 2 | - | number_warnings | 0 | - Scenario: Validate complex nested XML data (duckdb) Given I submit the books file nested_books.XML for processing And A duckdb pipeline is configured with schema file 'nested_books_ddb.dischema.json' diff --git a/tests/features/movies.feature b/tests/features/movies.feature index d737574..fa041ea 100644 --- a/tests/features/movies.feature +++ b/tests/features/movies.feature @@ -21,18 +21,18 @@ Feature: Pipeline tests using the movies dataset When I run the data contract phase Then there are 3 record rejections from the data_contract phase And there are errors with the following details and associated error_count from the data_contract phase - | Entity | ErrorCode | ErrorMessage | error_count | - | movies | BLANKYEAR | year not provided | 1 | - | movies_rename_test | DODGYYEAR | year value (NOT_A_NUMBER) is invalid | 1 | - | movies | DODGYDATE | date_joined value is not valid: daft_date | 1 | + | Entity | ErrorCode | ErrorMessage | RecordIndex | error_count | + | movies | BLANKYEAR | year not provided | 2 | 1 | + | movies_rename_test | DODGYYEAR | year value (NOT_A_NUMBER) is invalid | 1 | 1 | + | movies | DODGYDATE | date_joined value is not 
valid: daft_date | 1 | 1 | And the movies entity is stored as a parquet after the data_contract phase And the latest audit record for the submission is marked with processing status business_rules When I run the business rules phase Then The rules restrict "movies" to 4 qualifying records And there are errors with the following details and associated error_count from the business_rules phase - | ErrorCode | ErrorMessage | error_count | - | LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 1 | - | RUBBISH_SEQUEL | The movie The Greatest Movie Ever has a rubbish sequel | 1 | + | ErrorCode | ErrorMessage | RecordIndex | error_count | + | LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 4 | 1 | + | RUBBISH_SEQUEL | The movie The Greatest Movie Ever has a rubbish sequel | 1 | 1 | And the latest audit record for the submission is marked with processing status error_report When I run the error report phase Then An error report is produced @@ -57,18 +57,18 @@ Feature: Pipeline tests using the movies dataset When I run the data contract phase Then there are 3 record rejections from the data_contract phase And there are errors with the following details and associated error_count from the data_contract phase - | Entity | ErrorCode | ErrorMessage | error_count | - | movies | BLANKYEAR | year not provided | 1 | - | movies_rename_test | DODGYYEAR | year value (NOT_A_NUMBER) is invalid | 1 | - | movies | DODGYDATE | date_joined value is not valid: daft_date | 1 | + | Entity | ErrorCode | ErrorMessage | RecordIndex | error_count | + | movies | BLANKYEAR | year not provided | 2 | 1 | + | movies_rename_test | DODGYYEAR | year value (NOT_A_NUMBER) is invalid | 1 | 1 | + | movies | DODGYDATE | date_joined value is not valid: daft_date | 1 | 1 | And the movies entity is stored as a parquet after the data_contract phase And the latest audit record for the submission is marked with processing status business_rules When I run the business rules phase Then The rules restrict "movies" to 4 qualifying records And there are errors with the following details and associated error_count from the business_rules phase - | ErrorCode | ErrorMessage | error_count | - | LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 1 | - | RUBBISH_SEQUEL | The movie The Greatest Movie Ever has a rubbish sequel | 1 | + | ErrorCode | ErrorMessage | RecordIndex | error_count | + | LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 4 | 1 | + | RUBBISH_SEQUEL | The movie The Greatest Movie Ever has a rubbish sequel | 1 | 1 | And the latest audit record for the submission is marked with processing status error_report When I run the error report phase Then An error report is produced diff --git a/tests/features/steps/utilities.py b/tests/features/steps/utilities.py index aa9adc1..58edc67 100644 --- a/tests/features/steps/utilities.py +++ b/tests/features/steps/utilities.py @@ -23,6 +23,7 @@ "ErrorType", "ErrorLocation", "ErrorMessage", + "RecordIndex", "ReportingField", "Category", ] diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py index 0300808..d382ecb 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py @@ -14,6 +14,7 @@ from dve.core_engine.backends.implementations.duckdb.readers.xml import DuckDBXMLStreamReader from 
dve.core_engine.backends.metadata.contract import DataContractMetadata, ReaderConfig from dve.core_engine.backends.utilities import stringify_model +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.message import UserMessage from dve.core_engine.type_hints import URI from dve.core_engine.validation import RowValidator @@ -93,10 +94,12 @@ def test_duckdb_data_contract_csv(temp_csv_file): data_contract: DuckDBDataContract = DuckDBDataContract(connection) entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta) rel: DuckDBPyRelation = entities.get("test_ds") - assert dict(zip(rel.columns, rel.dtypes)) == { + expected_schema = { fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) for fld in mdl.__fields__.values() } + expected_schema[RECORD_INDEX_COLUMN_NAME] = get_duckdb_type_from_annotation(int) + assert dict(zip(rel.columns, rel.dtypes)) == expected_schema assert not get_resource_exists(feedback_errors_uri) assert stage_successful @@ -195,10 +198,12 @@ def test_duckdb_data_contract_xml(temp_xml_file): fld.name: get_duckdb_type_from_annotation(fld.type_) for fld in header_model.__fields__.values() } + header_expected_schema[RECORD_INDEX_COLUMN_NAME] = get_duckdb_type_from_annotation(int) class_data_expected_schema: Dict[str, DuckDBPyType] = { fld.name: get_duckdb_type_from_annotation(fld.type_) for fld in class_model.__fields__.values() } + class_data_expected_schema[RECORD_INDEX_COLUMN_NAME] = get_duckdb_type_from_annotation(int) class_data_rel: DuckDBPyRelation = entities.get("test_class_info") assert not get_resource_exists(feedback_errors_uri) assert header_rel.count("*").fetchone()[0] == 1 @@ -223,7 +228,7 @@ def test_ddb_data_contract_read_and_write_basic_parquet( "id": "VARCHAR", "datefield": "VARCHAR", "strfield": "VARCHAR", - "datetimefield": "VARCHAR", + "datetimefield": "VARCHAR" } # check processes entity contract_dict = json.loads(contract_meta).get("contract") @@ -266,6 +271,7 @@ def test_ddb_data_contract_read_and_write_basic_parquet( "datefield": "DATE", "strfield": "VARCHAR", "datetimefield": "TIMESTAMP", + RECORD_INDEX_COLUMN_NAME: get_duckdb_type_from_annotation(int) } @@ -282,7 +288,7 @@ def test_ddb_data_contract_read_nested_parquet(nested_all_string_parquet): "id": "VARCHAR", "strfield": "VARCHAR", "datetimefield": "VARCHAR", - "subfield": "STRUCT(id VARCHAR, substrfield VARCHAR, subarrayfield VARCHAR[])[]", + "subfield": "STRUCT(id VARCHAR, substrfield VARCHAR, subarrayfield VARCHAR[])[]" } # check processes entity contract_dict = json.loads(contract_meta).get("contract") @@ -325,6 +331,7 @@ def test_ddb_data_contract_read_nested_parquet(nested_all_string_parquet): "strfield": "VARCHAR", "datetimefield": "TIMESTAMP", "subfield": "STRUCT(id BIGINT, substrfield VARCHAR, subarrayfield DATE[])[]", + RECORD_INDEX_COLUMN_NAME: get_duckdb_type_from_annotation(int) } def test_duckdb_data_contract_custom_error_details(nested_all_string_parquet_w_errors, diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py index 921c9be..70c6b9c 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_data_contract.py @@ -16,6 +16,7 @@ from 
dve.core_engine.backends.implementations.spark.contract import SparkDataContract from dve.core_engine.backends.metadata.contract import DataContractMetadata, ReaderConfig +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.message import UserMessage from dve.core_engine.type_hints import URI from dve.core_engine.validation import RowValidator @@ -89,6 +90,7 @@ def test_spark_data_contract_read_and_write_basic_parquet( StructField("datefield", DateType()), StructField("strfield", StringType()), StructField("datetimefield", TimestampType()), + StructField(RECORD_INDEX_COLUMN_NAME, LongType()) ] ) @@ -173,6 +175,7 @@ def test_spark_data_contract_read_nested_parquet(nested_all_string_parquet): ) ), ), + StructField(RECORD_INDEX_COLUMN_NAME, LongType()) ] ) diff --git a/tests/test_core_engine/test_backends/test_readers/test_csv.py b/tests/test_core_engine/test_backends/test_readers/test_csv.py index 4cd7e07..0737ad2 100644 --- a/tests/test_core_engine/test_backends/test_readers/test_csv.py +++ b/tests/test_core_engine/test_backends/test_readers/test_csv.py @@ -12,6 +12,7 @@ from dve.core_engine.backends.exceptions import EmptyFileError, FieldCountMismatch from dve.core_engine.backends.readers import CSVFileReader +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from ....conftest import get_test_file_path from ....fixtures import temp_dir @@ -25,10 +26,13 @@ def planet_location() -> Iterator[str]: @pytest.fixture(scope="function") def planet_data() -> Iterator[Dict[str, Dict[str, str]]]: - """The planet data, as loaded by Python's default parser.""" + """The expected planet data after reading, as loaded by Python's default parser.""" with get_test_file_path("planets/planets.csv").open("r", encoding="utf-8") as file: reader = csv.DictReader(file) - yield {row["planet"]: row for row in reader} + data = {} + for idx, row in enumerate(reader, start=1): + data[row["planet"]] = {RECORD_INDEX_COLUMN_NAME: idx, **row} + yield data @pytest.fixture(scope="function") @@ -138,7 +142,7 @@ def test_csv_file_get_subset( # Keep only keys in the subset from the source subset_keys = set(PlanetsSubset.__fields__.keys()) for data in planet_data.values(): - to_pop = set(data.keys()) - subset_keys + to_pop = set(data.keys()) - subset_keys - {RECORD_INDEX_COLUMN_NAME} for key in to_pop: del data[key] @@ -160,7 +164,7 @@ def test_csv_file_get_subset_add_missing( # Keep only keys in the subset from the source subset_keys = set(PlanetsSubset.__fields__.keys()) for data in planet_data.values(): - to_pop = set(data.keys()) - subset_keys + to_pop = set(data.keys()) - subset_keys - {RECORD_INDEX_COLUMN_NAME} for key in to_pop: del data[key] data["random_null"] = None # type: ignore @@ -182,7 +186,10 @@ def test_csv_file_filled_from_provided( results = list(reader.read_to_py_iterator(planet_location, "", Planets)) parsed = {row["planet"]: row for row in results} del parsed["planet"] + for rec in parsed.values(): + rec[RECORD_INDEX_COLUMN_NAME] -= 1 assert parsed == planet_data + def test_csv_file_raises_missing_cols(self, planet_location: str): """ @@ -235,7 +242,7 @@ def test_csv_file_can_be_pipe_delimited( """Test that a pipe-delimited CSV file can be parsed.""" reader = CSVFileReader(delimiter="|") results = list(reader.read_to_py_iterator(pipe_delimited_location, "", BasicModel)) - assert results == [{"ColumnA": "1", "ColumnB": "2", "ColumnC": "3"}] + assert results == [{"ColumnA": "1", "ColumnB": "2", "ColumnC": "3", RECORD_INDEX_COLUMN_NAME: 1}] 
@pytest.mark.parametrize(["schema"], [(None,), (Planets,)]) def test_base_csv_reader_parquet_write( @@ -252,5 +259,5 @@ def test_base_csv_reader_parquet_write( reader.write_parquet(entity=entity, target_location=target_location, schema=schema) assert sorted( pd.read_parquet(target_location).to_dict(orient="records"), - key=lambda x: x.get("planet"), - ) == sorted([dict(val) for val in planet_data.values()], key=lambda x: x.get("planet")) + key=lambda x: x.get(RECORD_INDEX_COLUMN_NAME), + ) == sorted([dict(val) for val in planet_data.values()], key=lambda x: x.get(RECORD_INDEX_COLUMN_NAME)) diff --git a/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py index f195a0d..f364045 100644 --- a/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py +++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py @@ -17,6 +17,7 @@ PolarsToDuckDBCSVReader, ) from dve.core_engine.backends.utilities import stringify_model +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from tests.test_core_engine.test_backends.fixtures import duckdb_connection # pylint: disable=C0116 @@ -74,21 +75,25 @@ def test_ddb_csv_reader_all_str(temp_csv_file): rel: DuckDBPyRelation = reader.read_to_entity_type( DuckDBPyRelation, str(uri), "test", stringify_model(mdl) ) - assert rel.columns == header.split(",") - assert dict(zip(rel.columns, rel.dtypes)) == {fld: "VARCHAR" for fld in header.split(",")} - assert rel.fetchall() == [tuple(str(val) for val in rw) for rw in data] + expected_dtypes = {**{fld: "VARCHAR" for fld in header.split(",")}, RECORD_INDEX_COLUMN_NAME: "BIGINT"} + expected_data = [(*[str(val) for val in rw], idx) for idx, rw in enumerate(data, start=1)] + assert rel.columns == header.split(",") + [RECORD_INDEX_COLUMN_NAME] + assert dict(zip(rel.columns, rel.dtypes)) == expected_dtypes + assert rel.fetchall() == expected_data def test_ddb_csv_reader_cast(temp_csv_file): uri, header, data, mdl = temp_csv_file reader = DuckDBCSVReader(header=True, delim=",", connection=default_connection) rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, str(uri), "test", mdl) - assert rel.columns == header.split(",") - assert dict(zip(rel.columns, rel.dtypes)) == { + expected_dtypes = {**{ fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) for fld in mdl.__fields__.values() - } - assert rel.fetchall() == [tuple(rw) for rw in data] + }, RECORD_INDEX_COLUMN_NAME: get_duckdb_type_from_annotation(int)} + expected_data = [(*rw, idx) for idx, rw in enumerate(data, start=1)] + assert rel.columns == header.split(",") + [RECORD_INDEX_COLUMN_NAME] + assert dict(zip(rel.columns, rel.dtypes)) == expected_dtypes + assert rel.fetchall() == expected_data def test_ddb_csv_write_parquet(temp_csv_file): @@ -100,7 +105,7 @@ def test_ddb_csv_write_parquet(temp_csv_file): target_loc: Path = uri.parent.joinpath("test_parquet.parquet").as_posix() reader.write_parquet(rel, target_loc) parquet_rel = reader._connection.read_parquet(target_loc) - assert parquet_rel.df().to_dict(orient="records") == rel.df().to_dict(orient="records") + assert sorted(parquet_rel.df().to_dict(orient="records"), key=lambda x: x.get(RECORD_INDEX_COLUMN_NAME)) == sorted([{**rec, RECORD_INDEX_COLUMN_NAME: idx} for idx, rec in enumerate(rel.df().to_dict(orient="records"), start=1)], key=lambda x: x.get(RECORD_INDEX_COLUMN_NAME)) def test_ddb_csv_read_empty_file(temp_empty_csv_file): diff --git 
a/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py index c326fef..6942c6a 100644 --- a/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py +++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_json.py @@ -13,6 +13,7 @@ ) from dve.core_engine.backends.implementations.duckdb.readers.json import DuckDBJSONReader from dve.core_engine.backends.utilities import stringify_model +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from tests.test_core_engine.test_backends.fixtures import duckdb_connection @@ -59,9 +60,9 @@ def test_ddb_json_reader_all_str(temp_json_file): rel: DuckDBPyRelation = reader.read_to_entity_type( DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl) ) - assert rel.columns == expected_fields - assert dict(zip(rel.columns, rel.dtypes)) == {fld: "VARCHAR" for fld in expected_fields} - assert rel.fetchall() == [tuple(str(val) for val in rw.values()) for rw in data] + assert rel.columns == expected_fields + [RECORD_INDEX_COLUMN_NAME] + assert dict(zip(rel.columns, rel.dtypes)) == {**{fld: "VARCHAR" for fld in expected_fields}, RECORD_INDEX_COLUMN_NAME: "BIGINT"} + assert rel.fetchall() == [(*[str(val) for val in rw.values()], idx) for idx, rw in enumerate(data, start=1)] def test_ddb_json_reader_cast(temp_json_file): @@ -70,15 +71,15 @@ def test_ddb_json_reader_cast(temp_json_file): reader = DuckDBJSONReader() rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri.as_posix(), "test", mdl) - assert rel.columns == expected_fields - assert dict(zip(rel.columns, rel.dtypes)) == { + assert rel.columns == expected_fields + [RECORD_INDEX_COLUMN_NAME] + assert dict(zip(rel.columns, rel.dtypes)) == {**{ fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) for fld in mdl.__fields__.values() - } - assert rel.fetchall() == [tuple(rw.values()) for rw in data] + }, RECORD_INDEX_COLUMN_NAME: "BIGINT"} + assert rel.fetchall() == [(*rw.values(), idx) for idx, rw in enumerate(data, start = 1)] -def test_ddb_csv_write_parquet(temp_json_file): +def test_ddb_json_write_parquet(temp_json_file): uri, _, mdl = temp_json_file reader = DuckDBJSONReader() rel: DuckDBPyRelation = reader.read_to_entity_type( diff --git a/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py index dad5b06..585f7b7 100644 --- a/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py +++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from dve.core_engine.backends.implementations.duckdb.readers.xml import DuckDBXMLStreamReader +from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME @pytest.fixture @@ -19,15 +20,15 @@ def temp_dir(): @pytest.fixture def temp_xml_file(temp_dir: Path): - header_data: Dict[str, str] = { + header_data: list[dict[str, str]] = [{ "school_name": "Meadow Fields", "category": "Primary", "headteacher": "Mrs Smith", - } - class_data: Dict[str, Dict[str, str]] = { + }] + class_data: list[dict[str, Dict[str, str]]] = [{ "year_1": {"class_size": "10", "teacher": "Mrs Armitage"}, "year_2": {"class_size": "12", "teacher": "Mr Barney"}, - } + }] class HeaderModel(BaseModel): school_name: str @@ -44,16 +45,17 @@ class ClassDataModel(BaseModel): root = ET.Element("root") header = ET.SubElement(root, "Header") - for nm, val in header_data.items(): + for nm, val in header_data[0].items(): 
diff --git a/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py
index dad5b06..585f7b7 100644
--- a/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py
+++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py
@@ -9,6 +9,7 @@
 from pydantic import BaseModel

 from dve.core_engine.backends.implementations.duckdb.readers.xml import DuckDBXMLStreamReader
+from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME


 @pytest.fixture
@@ -19,15 +20,15 @@ def temp_dir():
-    header_data: Dict[str, str] = {
+    header_data: list[dict[str, str]] = [{
         "school_name": "Meadow Fields",
         "category": "Primary",
         "headteacher": "Mrs Smith",
-    }
-    class_data: Dict[str, Dict[str, str]] = {
+    }]
+    class_data: list[dict[str, Dict[str, str]]] = [{
         "year_1": {"class_size": "10", "teacher": "Mrs Armitage"},
         "year_2": {"class_size": "12", "teacher": "Mr Barney"},
-    }
+    }]

     class HeaderModel(BaseModel):
         school_name: str
@@ -44,16 +45,17 @@ class ClassDataModel(BaseModel):
     root = ET.Element("root")
     header = ET.SubElement(root, "Header")
-    for nm, val in header_data.items():
+    for nm, val in header_data[0].items():
         _tag = ET.SubElement(header, nm)
         _tag.text = val
-    data = ET.SubElement(root, "ClassData")
-    for nm, val in class_data.items():
-        _parent_tag = ET.SubElement(data, nm)
-        for sub_nm, sub_val in val.items():
-            _child_tag = ET.SubElement(_parent_tag, sub_nm)
-            _child_tag.text = sub_val
+    for dta in class_data:
+        data = ET.SubElement(root, "ClassData")
+        for nm, val in dta.items():
+            _parent_tag = ET.SubElement(data, nm)
+            for sub_nm, sub_val in val.items():
+                _child_tag = ET.SubElement(_parent_tag, sub_nm)
+                _child_tag.text = sub_val

     with open(temp_dir.joinpath("test.xml"), mode="wb") as xml_fle:
         xml_fle.write(ET.tostring(root))
@@ -76,10 +78,12 @@ def test_ddb_xml_reader_all_str(temp_xml_file):
     class_rel: DuckDBPyRelation = class_reader.read_to_relation(
         uri.as_uri(), "class_data", class_data_model
     )
+    expected_header = [{**recs, RECORD_INDEX_COLUMN_NAME: idx} for idx, recs in enumerate(header_data, start=1)]
+    expected_class = [{**recs, RECORD_INDEX_COLUMN_NAME: idx} for idx, recs in enumerate(class_data, start=1)]
     assert header_rel.count("*").fetchone()[0] == 1
-    assert header_rel.df().to_dict("records")[0] == header_data
+    assert header_rel.df().to_dict("records") == expected_header
     assert class_rel.count("*").fetchone()[0] == 1
-    assert class_rel.df().to_dict("records")[0] == class_data
+    assert class_rel.df().to_dict("records") == expected_class


 def test_ddb_xml_reader_write_parquet(temp_xml_file):
diff --git a/tests/test_core_engine/test_backends/test_readers/test_spark_json.py b/tests/test_core_engine/test_backends/test_readers/test_spark_json.py
index 3cbecb8..24674ca 100644
--- a/tests/test_core_engine/test_backends/test_readers/test_spark_json.py
+++ b/tests/test_core_engine/test_backends/test_readers/test_spark_json.py
@@ -7,13 +7,14 @@
 import pytest
 from pydantic import BaseModel
 from pyspark.sql import DataFrame
-from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.sql.types import LongType, StructType, StructField, StringType

 from dve.core_engine.backends.implementations.spark.spark_helpers import (
     get_type_from_annotation,
 )
 from dve.core_engine.backends.implementations.spark.readers.json import SparkJSONReader
 from dve.core_engine.backends.utilities import stringify_model
+from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME


 class SimpleModel(BaseModel):
@@ -54,25 +55,25 @@ class SimpleModel(BaseModel):
 def test_spark_json_reader_all_str(temp_json_file):
     uri, data, mdl = temp_json_file
-    expected_fields = [fld for fld in mdl.__fields__]
+    expected_fields = [fld for fld in mdl.__fields__] + [RECORD_INDEX_COLUMN_NAME]
     reader = SparkJSONReader()
     df: DataFrame = reader.read_to_entity_type(
         DataFrame, uri.as_posix(), "test", stringify_model(mdl)
     )

     assert df.columns == expected_fields
-    assert df.schema == StructType([StructField(nme, StringType()) for nme in expected_fields])
-    assert [rw.asDict() for rw in df.collect()] == [{k: str(v) for k, v in rw.items()} for rw in data]
+    assert df.schema == StructType([StructField(nme, StringType() if nme != RECORD_INDEX_COLUMN_NAME else LongType()) for nme in expected_fields])
+    assert [rw.asDict() for rw in df.collect()] == [{**{k: str(v) for k, v in rw.items()}, RECORD_INDEX_COLUMN_NAME: idx} for idx, rw in enumerate(data, start=1)]


 def test_spark_json_reader_cast(temp_json_file):
     uri, data, mdl = temp_json_file
-    expected_fields = [fld for fld in mdl.__fields__]
+    expected_fields = [fld for fld in mdl.__fields__] + [RECORD_INDEX_COLUMN_NAME]
     reader = SparkJSONReader()
     df: DataFrame = reader.read_to_entity_type(DataFrame, uri.as_posix(), "test", mdl)

     assert df.columns == expected_fields
     assert df.schema == StructType([StructField(fld.name, get_type_from_annotation(fld.annotation))
-                                    for fld in mdl.__fields__.values()])
-    assert [rw.asDict() for rw in df.collect()] == data
+                                    for fld in mdl.__fields__.values()] + [StructField(RECORD_INDEX_COLUMN_NAME, get_type_from_annotation(int))])
+    assert [rw.asDict() for rw in df.collect()] == [{**rw, RECORD_INDEX_COLUMN_NAME: idx} for idx, rw in enumerate(data, start=1)]


 def test_spark_json_write_parquet(spark, temp_json_file):
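On the Spark side these tests expect the same 1-based index surfaced as a LongType column. How SparkJSONReader derives it is not shown in this diff; what follows is a minimal sketch, assuming a row_number over an arbitrary but stable ordering — the helper name and approach are illustrative, not the engine's actual implementation:

    from pyspark.sql import DataFrame, SparkSession, Window
    from pyspark.sql import functions as F

    RECORD_INDEX_COLUMN_NAME = "__record_index__"  # mirrors dve.core_engine.constants

    def add_record_index(df: DataFrame) -> DataFrame:
        # Illustrative sketch only: row_number() is 1-based; cast to long to match the
        # LongType the tests assert. Ordering by monotonically_increasing_id() pulls all
        # rows through one window partition, which is fine for test-sized data but not
        # something to rely on for large frames.
        window = Window.orderBy(F.monotonically_increasing_id())
        return df.withColumn(RECORD_INDEX_COLUMN_NAME, F.row_number().over(window).cast("long"))

    if __name__ == "__main__":
        spark = SparkSession.builder.master("local[1]").getOrCreate()
        df = spark.createDataFrame([("Mercury",), ("Venus",)], ["planet"])
        add_record_index(df).show()
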
diff --git a/tests/test_core_engine/test_message.py b/tests/test_core_engine/test_message.py
index edf89fc..ccb6736 100644
--- a/tests/test_core_engine/test_message.py
+++ b/tests/test_core_engine/test_message.py
@@ -8,20 +8,8 @@
 from pydantic import BaseModel, ValidationError
 import pytest

-from dve.core_engine.constants import ROWID_COLUMN_NAME
 from dve.core_engine.message import DEFAULT_ERROR_DETAIL, DataContractErrorDetail, FeedbackMessage

-
-def test_rowid_column_stripped():
-    """Ensure that the rowID column is stripped from FeedbackMessages."""
-
-    message = FeedbackMessage(
-        entity="entity", record={"key": "value", ROWID_COLUMN_NAME: "some identifier"}
-    )
-
-    assert message.record.get(ROWID_COLUMN_NAME) is None
-
-
 @pytest.mark.parametrize(
     ("derived_column", "expected"),
     [
diff --git a/tests/test_pipeline/pipeline_helpers.py b/tests/test_pipeline/pipeline_helpers.py
index ddd4ef8..b13bef3 100644
--- a/tests/test_pipeline/pipeline_helpers.py
+++ b/tests/test_pipeline/pipeline_helpers.py
@@ -152,6 +152,7 @@ def dodgy_planet_data_after_file_transformation() -> Iterator[Tuple[SubmissionIn
         "numberOfMoons": "-1",
         "hasRingSystem": "false",
         "hasGlobalMagneticField": "sometimes",
+        "__record_index__": "1"
     }
     planet_contract_df = pl.DataFrame(
         planet_contract_data, {k: pl.Utf8() for k in planet_contract_data}
@@ -381,7 +382,8 @@ def error_data_after_business_rules() -> Iterator[Tuple[SubmissionInfo, str]]:
            "ErrorCode": "LONG_ORBIT",
            "ReportingField": "orbitalPeriod",
            "Value": "365.20001220703125",
-           "Category": "Bad value"
+           "Category": "Bad value",
+           "RecordIndex": "1"
        },
        {
            "Entity": "planets",
@@ -394,7 +396,8 @@ def error_data_after_business_rules() -> Iterator[Tuple[SubmissionInfo, str]]:
            "ErrorCode": "STRONG_GRAVITY",
            "ReportingField": "gravity",
            "Value": "9.800000190734863",
-           "Category": "Bad value"
+           "Category": "Bad value",
+           "RecordIndex": "1"
        }
    ]"""
    )
diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py
index 910626a..262d84f 100644
--- a/tests/test_pipeline/test_spark_pipeline.py
+++ b/tests/test_pipeline/test_spark_pipeline.py
@@ -175,6 +175,7 @@ def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
             "ErrorMessage": "is invalid",
             "ErrorCode": "BadValue",
             "ReportingField": "planet",
+            "RecordIndex": "1",
             "Value": "EarthEarthEarthEarthEarthEarthEarthEarthEarth",
             "Category": "Bad value",
         },
@@ -188,6 +189,7 @@ def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
             "ErrorMessage": "is invalid",
             "ErrorCode": "BadValue",
             "ReportingField": "numberOfMoons",
+            "RecordIndex": "1",
             "Value": "-1",
             "Category": "Bad value",
         },
@@ -201,6 +203,7 @@ def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
             "ErrorMessage": "is invalid",
             "ErrorCode": "BadValue",
             "ReportingField": "hasGlobalMagneticField",
+            "RecordIndex": "1",
             "Value": "sometimes",
             "Category": "Bad value",
         },
@@ -347,6 +350,7 @@ def test_apply_business_rules_with_data_errors(  # pylint: disable=redefined-out
             "ReportingField": "orbitalPeriod",
             "Value": "365.20001220703125",
             "Category": "Bad value",
+            "RecordIndex": "1"
         },
         {
             "Entity": "planets",
@@ -360,6 +364,7 @@ def test_apply_business_rules_with_data_errors(  # pylint: disable=redefined-out
             "ReportingField": "gravity",
             "Value": "9.800000190734863",
             "Category": "Bad value",
+            "RecordIndex": "1"
         },
     ]

@@ -504,6 +509,7 @@ def test_error_report_where_report_is_expected(  # pylint: disable=redefined-out
             "Error Code": "LONG_ORBIT",
             "Data Item Submission Name": "orbitalPeriod",
             "Errors and Warnings": "Planet has long orbital period",
+            "Record Index": 1,
             "Value": 365.20001220703125,
             "ID": None,
             "Category": "Bad value",
@@ -516,6 +522,7 @@ def test_error_report_where_report_is_expected(  # pylint: disable=redefined-out
             "Error Code": "STRONG_GRAVITY",
             "Data Item Submission Name": "gravity",
             "Errors and Warnings": "Planet has too strong gravity",
+            "Record Index": 1,
             "Value": 9.800000190734863,
             "ID": None,
             "Category": "Bad value",
diff --git a/tests/testdata/movies/movies_ddb_rule_store.json b/tests/testdata/movies/movies_ddb_rule_store.json
index 843d4fa..6a51fd6 100644
--- a/tests/testdata/movies/movies_ddb_rule_store.json
+++ b/tests/testdata/movies/movies_ddb_rule_store.json
@@ -61,7 +61,7 @@
             "name": "Get median sequel rating",
             "operation": "group_by",
             "entity": "with_sequels",
-            "group_by": "title",
+            "group_by": ["__record_index__", "title"],
             "agg_columns": {
                 "list_aggregate(sequel_rating, 'median')": "median_sequel_rating"
             }
diff --git a/tests/testdata/movies/movies_spark_rule_store.json b/tests/testdata/movies/movies_spark_rule_store.json
index 08ad641..e8204c5 100644
--- a/tests/testdata/movies/movies_spark_rule_store.json
+++ b/tests/testdata/movies/movies_spark_rule_store.json
@@ -63,6 +63,7 @@
             "entity": "with_sequels",
             "columns": {
                 "title": "title",
+                "__record_index__": "__record_index__",
                 "explode(sequel_rating)": "sequel_rating"
             }
         },
@@ -70,7 +71,7 @@
             "name": "Get median sequel rating",
             "operation": "group_by",
             "entity": "with_sequels",
-            "group_by": "title",
+            "group_by": ["__record_index__", "title"],
             "agg_columns": {
                 "percentile_approx(sequel_rating, 0.5)": "median_sequel_rating"
             }
diff --git a/tests/testdata/planets/planets.dischema.json b/tests/testdata/planets/planets.dischema.json
index 7a0387c..b44bb2e 100644
--- a/tests/testdata/planets/planets.dischema.json
+++ b/tests/testdata/planets/planets.dischema.json
@@ -114,8 +114,8 @@
         },
         {
             "entity": "planets",
-            "name": "has_row_id",
-            "expression": "__rowid__ IS NOT NULL"
+            "name": "has_record_index",
+            "expression": "__record_index__ IS NOT NULL"
         },
         {
             "entity": "planets",
diff --git a/tests/testdata/planets/planets_ddb.dischema.json b/tests/testdata/planets/planets_ddb.dischema.json
index 51e6650..0869aad 100644
--- a/tests/testdata/planets/planets_ddb.dischema.json
+++ b/tests/testdata/planets/planets_ddb.dischema.json
@@ -115,7 +115,7 @@
         {
             "entity": "planets",
             "name": "has_row_id",
-            "expression": "__rowid__ IS NOT NULL"
+            "expression": "__record_index__ IS NOT NULL"
         },
         {
             "entity": "planets",
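The rule-store changes above add __record_index__ to the group_by keys (and, for Spark, carry it through the explode projection). A small standalone DuckDB illustration of why, using hypothetical data rather than the movies fixtures: grouping by title alone would collapse two source records that happen to share a title, whereas grouping by the record index as well keeps one aggregated row per input record, so downstream feedback can still report a RecordIndex.

    import duckdb

    # Hypothetical miniature of the "with_sequels" shape: one list of ratings per record.
    duckdb.sql(
        """
        SELECT __record_index__, title, median(sequel_rating) AS median_sequel_rating
        FROM (
            SELECT 1 AS __record_index__, 'Alpha' AS title, unnest([8.5, 6.0]) AS sequel_rating
            UNION ALL
            SELECT 2, 'Alpha', unnest([3.0, 4.0])
        )
        GROUP BY __record_index__, title
        """
    ).show()  # two rows survive, one per source record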