diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs
index 1c900b7579f7..0560cf74e15c 100644
--- a/datafusion/physical-plan/src/spill/mod.rs
+++ b/datafusion/physical-plan/src/spill/mod.rs
@@ -49,7 +49,7 @@ use arrow::record_batch::RecordBatch;
 use arrow_data::ArrayDataBuilder;
 
 use datafusion_common::config::SpillCompression;
-use datafusion_common::{DataFusionError, Result, exec_datafusion_err};
+use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::RecordBatchStream;
 use datafusion_execution::disk_manager::RefCountedTempFile;
@@ -121,6 +121,7 @@ impl SpillReaderStream {
                     unreachable!()
                 };
 
+                let expected_schema = Arc::clone(&self.schema);
                 let task = SpawnedTask::spawn_blocking(move || {
                     let file = BufReader::new(File::open(spill_file.path())?);
                     // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
@@ -130,6 +131,21 @@ impl SpillReaderStream {
                         StreamReader::try_new(file, None)?.with_skip_validation(true)
                     };
 
+                    // Validate the schema read from Arrow IPC file is the same as the
+                    // schema of the current `SpillManager`
+                    let actual_schema = reader.schema();
+
+                    if actual_schema != expected_schema {
+                        return exec_err!(
+                            "Spill file schema mismatch: expected {}, got {}. \
+                            The caller must use the same SpillManager that created the spill file to read it.",
+                            expected_schema,
+                            actual_schema
+                        );
+                    }
+
+                    // TODO: Same-schema reads from a different SpillManager still pass today.
+                    // Add a SpillManager UID to IPC metadata and validate it here as well.
                     let next_batch = reader.next().transpose()?;
 
                     Ok((reader, next_batch))
diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs
index 1664256e6588..365a9f977eac 100644
--- a/datafusion/physical-plan/src/spill/spill_manager.rs
+++ b/datafusion/physical-plan/src/spill/spill_manager.rs
@@ -161,7 +161,7 @@ impl SpillManager {
     }
 
     /// Reads a spill file as a stream. The file must be created by the current
-    /// `SpillManager`; otherwise behavior is undefined.
+    /// `SpillManager`; otherwise an error will be returned.
     ///
     /// Output is produced in FIFO order: the batch appended first is read first.
     ///
@@ -247,15 +247,112 @@ fn byte_view_data_buffer_size(array: &GenericByteViewArray<T>)
 
 #[cfg(test)]
 mod tests {
+    use super::SpillManager;
+    use crate::common::collect;
+    use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
     use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize};
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::{
-        array::{ArrayRef, StringViewArray},
+        array::{ArrayRef, Int32Array, StringArray, StringViewArray},
         record_batch::RecordBatch,
     };
     use datafusion_common::Result;
+    use datafusion_execution::runtime_env::RuntimeEnv;
     use std::sync::Arc;
 
+    fn build_test_spill_manager(
+        env: Arc<RuntimeEnv>,
+        schema: Arc<Schema>,
+    ) -> SpillManager {
+        let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
+        SpillManager::new(env, metrics, schema)
+    }
+
+    fn build_writer_batch(schema: Arc<Schema>) -> Result<RecordBatch> {
+        RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(StringArray::from(vec!["a", "b", "c"])),
+            ],
+        )
+        .map_err(Into::into)
+    }
+
+    #[tokio::test]
+    async fn test_read_spill_as_stream_from_another_spill_manager_same_schema()
+    -> Result<()> {
+        let env = Arc::new(RuntimeEnv::default());
+        let writer_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("value", DataType::Utf8, false),
+        ]));
+        let reader_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("value", DataType::Utf8, false),
+        ]));
+
+        let writer =
+            build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
+        let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));
+        let written_batch = build_writer_batch(Arc::clone(&writer_schema))?;
+
+        let spill_file = writer
+            .spill_record_batch_and_finish(
+                std::slice::from_ref(&written_batch),
+                "writer",
+            )?
+            .unwrap();
+
+        // Same-schema reads through a different SpillManager currently pass
+        // because only schema compatibility is validated. This is not a
+        // supported usage pattern.
+        let stream = reader.read_spill_as_stream(spill_file, None)?;
+        assert_eq!(stream.schema(), reader_schema);
+
+        let batches = collect(stream).await?;
+        assert_eq!(batches, vec![written_batch]);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_read_spill_as_stream_from_another_spill_manager_different_schema()
+    -> Result<()> {
+        let env = Arc::new(RuntimeEnv::default());
+        let writer_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("value", DataType::Utf8, false),
+        ]));
+        let reader_schema = Arc::new(Schema::new(vec![
+            Field::new("other_id", DataType::Int32, true),
+            Field::new("other_value", DataType::Utf8, true),
+        ]));
+
+        let writer =
+            build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
+        let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));
+        let written_batch = build_writer_batch(Arc::clone(&writer_schema))?;
+
+        let spill_file = writer
+            .spill_record_batch_and_finish(
+                std::slice::from_ref(&written_batch),
+                "writer",
+            )?
+            .unwrap();
+
+        let stream = reader.read_spill_as_stream(spill_file, None)?;
+        let err = collect(stream)
+            .await
+            .expect_err("schema mismatch should fail fast");
+        let err = err.to_string();
+        assert!(err.contains("Spill file schema mismatch"));
+        assert!(err.contains("expected"));
+        assert!(err.contains("got"));
+
+        Ok(())
+    }
+
     #[test]
     fn check_sliced_size_for_string_view_array() -> Result<()> {
         let array_length = 50;
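
For readers unfamiliar with the Arrow IPC APIs involved, the following standalone sketch (not part of the patch) illustrates the check that `SpillReaderStream` now performs: the schema recorded in the spilled IPC stream file is compared against the schema the reader expects before any batch is returned. It uses plain `arrow` IPC stream readers and writers outside of DataFusion's `SpillManager`; the file path and the helper name `read_with_expected_schema` are illustrative only.

// Illustrative sketch: mirror the new schema check with plain Arrow IPC stream files.
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

/// Open an IPC stream file and fail fast if its schema differs from `expected`,
/// analogous to the validation added to `SpillReaderStream`.
fn read_with_expected_schema(
    path: &Path,
    expected: &SchemaRef,
) -> Result<Vec<RecordBatch>, ArrowError> {
    let reader = StreamReader::try_new(BufReader::new(File::open(path)?), None)?;
    let actual = reader.schema();
    if actual != *expected {
        return Err(ArrowError::SchemaError(format!(
            "schema mismatch: expected {expected}, got {actual}"
        )));
    }
    // Only read batches once the schema has been validated.
    reader.collect()
}

fn main() -> Result<(), ArrowError> {
    let schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Write one batch to an IPC stream file, much like the spill writer does.
    let path = std::env::temp_dir().join("spill_schema_sketch.arrow");
    let mut writer =
        StreamWriter::try_new(BufWriter::new(File::create(&path)?), &schema)?;
    writer.write(&batch)?;
    writer.finish()?;

    // Reading with the matching schema succeeds ...
    assert_eq!(read_with_expected_schema(&path, &schema)?, vec![batch]);

    // ... while a mismatched expected schema is rejected before any batch is read.
    let other: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("other", DataType::Utf8, true)]));
    assert!(read_with_expected_schema(&path, &other).is_err());

    Ok(())
}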