Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion datafusion/physical-plan/src/spill/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use arrow::record_batch::RecordBatch;
use arrow_data::ArrayDataBuilder;

use datafusion_common::config::SpillCompression;
use datafusion_common::{DataFusionError, Result, exec_datafusion_err};
use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::RecordBatchStream;
use datafusion_execution::disk_manager::RefCountedTempFile;
Expand Down Expand Up @@ -121,6 +121,7 @@ impl SpillReaderStream {
unreachable!()
};

let expected_schema = Arc::clone(&self.schema);
let task = SpawnedTask::spawn_blocking(move || {
let file = BufReader::new(File::open(spill_file.path())?);
// SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
Expand All @@ -130,6 +131,21 @@ impl SpillReaderStream {
StreamReader::try_new(file, None)?.with_skip_validation(true)
};

// Validate the schema read from Arrow IPC file is the same as the
// schema of the current `SpillManager`
let actual_schema = reader.schema();

if actual_schema != expected_schema {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The equality includes metadata. So do we also need to be strict about metadata?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I’m not sure—some metadata may need to be preserved through the round trip.

What do you think about keeping the metadata comparison for now to stay conservative? If it proves unnecessarily strict later, we can relax it.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

return exec_err!(
"Spill file schema mismatch: expected {}, got {}. \
The caller must use the same SpillManager that created the spill file to read it.",
expected_schema,
actual_schema
);
}

// TODO: Same-schema reads from a different SpillManager still pass today.
// Add a SpillManager UID to IPC metadata and validate it here as well.
let next_batch = reader.next().transpose()?;

Ok((reader, next_batch))
Expand Down
101 changes: 99 additions & 2 deletions datafusion/physical-plan/src/spill/spill_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl SpillManager {
}

/// Reads a spill file as a stream. The file must be created by the current
/// `SpillManager`; otherwise behavior is undefined.
/// `SpillManager`; otherwise an error will be returned.
///
/// Output is produced in FIFO order: the batch appended first is read first.
///
Expand Down Expand Up @@ -247,15 +247,112 @@ fn byte_view_data_buffer_size<T: ByteViewType>(array: &GenericByteViewArray<T>)

#[cfg(test)]
mod tests {
use super::SpillManager;
use crate::common::collect;
use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::{
array::{ArrayRef, StringViewArray},
array::{ArrayRef, Int32Array, StringArray, StringViewArray},
record_batch::RecordBatch,
};
use datafusion_common::Result;
use datafusion_execution::runtime_env::RuntimeEnv;
use std::sync::Arc;

/// Constructs a `SpillManager` for tests, backed by the given runtime
/// environment and schema, using a fresh metrics set for partition 0.
fn build_test_spill_manager(
    env: Arc<RuntimeEnv>,
    schema: Arc<Schema>,
) -> SpillManager {
    SpillManager::new(
        env,
        SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
        schema,
    )
}

fn build_writer_batch(schema: Arc<Schema>) -> Result<RecordBatch> {
RecordBatch::try_new(
schema,
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
],
)
.map_err(Into::into)
}

#[tokio::test]
async fn test_read_spill_as_stream_from_another_spill_manager_same_schema()
-> Result<()> {
    // Two SpillManagers sharing one runtime env, holding structurally
    // identical schemas that were built separately (distinct Arcs).
    let env = Arc::new(RuntimeEnv::default());
    let writer_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Utf8, false),
    ]));
    let reader_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Utf8, false),
    ]));

    let writer =
        build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
    let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));

    let batch = build_writer_batch(Arc::clone(&writer_schema))?;
    let spill_file = writer
        .spill_record_batch_and_finish(std::slice::from_ref(&batch), "writer")?
        .unwrap();

    // Reading through a *different* SpillManager with the same schema
    // currently succeeds, because only schema compatibility is validated.
    // This is not a supported usage pattern.
    let stream = reader.read_spill_as_stream(spill_file, None)?;
    assert_eq!(stream.schema(), reader_schema);
    assert_eq!(collect(stream).await?, vec![batch]);

    Ok(())
}

#[tokio::test]
async fn test_read_spill_as_stream_from_another_spill_manager_different_schema()
-> Result<()> {
    let env = Arc::new(RuntimeEnv::default());
    let writer_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("value", DataType::Utf8, false),
    ]));
    // The reader schema differs in both field names and nullability.
    let reader_schema = Arc::new(Schema::new(vec![
        Field::new("other_id", DataType::Int32, true),
        Field::new("other_value", DataType::Utf8, true),
    ]));

    let writer =
        build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
    let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));

    let batch = build_writer_batch(Arc::clone(&writer_schema))?;
    let spill_file = writer
        .spill_record_batch_and_finish(std::slice::from_ref(&batch), "writer")?
        .unwrap();

    // Opening the stream succeeds; the mismatch surfaces when the spill
    // file is actually read during collection.
    let stream = reader.read_spill_as_stream(spill_file, None)?;
    let msg = collect(stream)
        .await
        .expect_err("schema mismatch should fail fast")
        .to_string();
    for needle in ["Spill file schema mismatch", "expected", "got"] {
        assert!(msg.contains(needle), "missing {needle:?} in error: {msg}");
    }

    Ok(())
}

#[test]
fn check_sliced_size_for_string_view_array() -> Result<()> {
let array_length = 50;
Expand Down
Loading