diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 93d7e7258772f..030fa5bbf94ff 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -207,6 +207,13 @@ pub struct NestedLoopJoinExec { /// Each output stream waits on the `OnceAsync` to signal the completion of /// the build(left) side data, and buffer them all for later joining. build_side_data: OnceAsync, + /// Shared left-side spill data for OOM fallback. + /// + /// When `build_side_data` fails with OOM, the first partition to + /// initiate fallback spills the entire left side to disk. Other + /// partitions share the same spill file via this `OnceAsync`, + /// avoiding redundant re-execution of the left child. + left_spill_data: Arc>, /// Information of index and left / right placement of columns column_indices: Vec, /// Projection to apply to the output of the join @@ -290,6 +297,7 @@ impl NestedLoopJoinExecBuilder { join_type, join_schema, build_side_data: Default::default(), + left_spill_data: Arc::new(OnceAsync::default()), column_indices, projection, metrics: Default::default(), @@ -492,6 +500,7 @@ impl NestedLoopJoinExec { right, metrics: ExecutionPlanMetricsSet::new(), build_side_data: Default::default(), + left_spill_data: Arc::new(OnceAsync::default()), cache: Arc::clone(&self.cache), filter: self.filter.clone(), join_type: self.join_type, @@ -655,6 +664,7 @@ impl ExecutionPlan for NestedLoopJoinExec { SpillState::Pending { left_plan: Arc::clone(&self.left), task_context: Arc::clone(&context), + left_spill_data: Arc::clone(&self.left_spill_data), } } else { SpillState::Disabled @@ -863,6 +873,20 @@ enum NLJState { EmitLeftUnmatched, Done, } +/// Shared data for the left-side spill fallback. +/// +/// When the in-memory `OnceFut` path fails with OOM, the first partition +/// spills the entire left side to disk. This struct holds the spill file +/// reference so other partitions can read from the same file. +pub(crate) struct LeftSpillData { + /// SpillManager used to read the spill file (has the left schema) + spill_manager: SpillManager, + /// The spill file containing all left-side batches + spill_file: RefCountedTempFile, + /// Left-side schema + schema: SchemaRef, +} + /// Tracks the state of the memory-limited spill fallback for NLJ. /// /// The NLJ always starts with the standard OnceFut path. If the in-memory @@ -882,6 +906,9 @@ pub(crate) enum SpillState { left_plan: Arc, /// TaskContext for re-execution and SpillManager creation task_context: Arc, + /// Shared OnceAsync for left-side spill data. The first partition + /// to initiate fallback spills the left side; others share the file. + left_spill_data: Arc>, }, /// Fallback has been triggered. Left data is being loaded in chunks @@ -892,16 +919,20 @@ pub(crate) enum SpillState { /// State for active memory-limited spill execution. /// Boxed inside [`SpillState::Active`] to reduce enum size. pub(crate) struct SpillStateActive { - /// Left input stream for incremental buffering - left_stream: SendableRecordBatchStream, + /// Shared future for left-side spill data. All partitions wait on + /// the same future — the first to poll triggers the actual spill. + left_spill_fut: OnceFut, + /// Left input stream for incremental chunk reading (from spill file). + /// None until `left_spill_fut` resolves. + left_stream: Option, + /// Left-side schema (set once `left_spill_fut` resolves) + left_schema: Option, /// Memory reservation for left-side buffering reservation: MemoryReservation, /// Accumulated left batches for the current chunk pending_batches: Vec, - /// Left-side schema (for concat_batches) - left_schema: SchemaRef, /// SpillManager for right-side spilling - spill_manager: SpillManager, + right_spill_manager: SpillManager, /// In-progress spill file for writing right batches during first pass right_spill_in_progress: Option, /// Completed right-side spill file (available after first pass) @@ -1263,19 +1294,18 @@ impl NestedLoopJoinStream { /// Switch from the standard OnceFut path to memory-limited mode. /// - /// Re-executes the left child to get a fresh stream, creates a - /// SpillManager for right-side spilling, and transitions the spill - /// state from `Pending` to `Active`. The next call to - /// `handle_buffering_left` will dispatch to - /// `handle_buffering_left_memory_limited`. + /// Uses the shared `left_spill_data` OnceAsync so that only the first + /// partition to reach this point re-executes the left child and spills + /// it to disk. Other partitions share the same spill file. fn initiate_fallback(&mut self) -> Result<()> { // Take ownership of Pending state - let (left_plan, context) = + let (left_plan, context, left_spill_data) = match std::mem::replace(&mut self.spill_state, SpillState::Disabled) { SpillState::Pending { left_plan, task_context, - } => (left_plan, task_context), + left_spill_data, + } => (left_plan, task_context, left_spill_data), _ => { return internal_err!( "initiate_fallback called in non-Pending spill state" @@ -1283,9 +1313,42 @@ impl NestedLoopJoinStream { } }; - // Re-execute left child to get a fresh stream - let left_stream = left_plan.execute(0, Arc::clone(&context))?; - let left_schema = left_stream.schema(); + // Use OnceAsync to ensure only the first partition spills the left + // side. Other partitions will get the same OnceFut that resolves + // to the shared spill file. + let left_spill_fut = left_spill_data.try_once(|| { + let plan = Arc::clone(&left_plan); + let ctx = Arc::clone(&context); + let spill_metrics = self.metrics.spill_metrics.clone(); + Ok(async move { + let mut stream = plan.execute(0, Arc::clone(&ctx))?; + let schema = stream.schema(); + let left_spill_manager = SpillManager::new( + ctx.runtime_env(), + spill_metrics, + Arc::clone(&schema), + ) + .with_compression_type(ctx.session_config().spill_compression()); + + let result = left_spill_manager + .spill_record_batch_stream_and_return_max_batch_memory( + &mut stream, + "NestedLoopJoin left spill", + ) + .await?; + + match result { + Some((file, _max_batch_memory)) => Ok(LeftSpillData { + spill_manager: left_spill_manager, + spill_file: file, + schema, + }), + None => { + internal_err!("Left side produced no data to spill") + } + } + }) + })?; // Create reservation with can_spill for fair memory allocation let reservation = MemoryConsumer::new("NestedLoopJoinLoad[fallback]".to_string()) @@ -1294,7 +1357,7 @@ impl NestedLoopJoinStream { // Create SpillManager for right-side spilling let right_schema = self.right_data.schema(); - let spill_manager = SpillManager::new( + let right_spill_manager = SpillManager::new( context.runtime_env(), self.metrics.spill_metrics.clone(), right_schema, @@ -1302,11 +1365,12 @@ impl NestedLoopJoinStream { .with_compression_type(context.session_config().spill_compression()); self.spill_state = SpillState::Active(Box::new(SpillStateActive { - left_stream, + left_spill_fut, + left_stream: None, + left_schema: None, reservation, pending_batches: Vec::new(), - left_schema, - spill_manager, + right_spill_manager, right_spill_in_progress: None, right_spill_file: None, right_max_batch_memory: 0, @@ -1378,11 +1442,44 @@ impl NestedLoopJoinStream { ); }; + // On first entry (or after re-entry for a new chunk pass when + // left_stream was consumed), wait for the shared left spill + // future to resolve and then open a stream from the spill file. + if active.left_stream.is_none() { + match active.left_spill_fut.get_shared(cx) { + Poll::Ready(Ok(spill_data)) => { + match spill_data + .spill_manager + .read_spill_as_stream(spill_data.spill_file.clone(), None) + { + Ok(stream) => { + active.left_schema = Some(Arc::clone(&spill_data.schema)); + active.left_stream = Some(stream); + } + Err(e) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + } + } + Poll::Ready(Err(e)) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + Poll::Pending => { + return ControlFlow::Break(Poll::Pending); + } + } + } + + let left_stream = active + .left_stream + .as_mut() + .expect("left_stream must be set after spill future resolves"); + // Poll left stream for more batches. // Note: pending_batches may already contain a batch from the // previous chunk iteration (the batch that triggered the memory limit). loop { - match active.left_stream.poll_next_unpin(cx) { + match left_stream.poll_next_unpin(cx) { Poll::Ready(Some(Ok(batch))) => { if batch.num_rows() == 0 { continue; @@ -1431,13 +1528,18 @@ impl NestedLoopJoinStream { return ControlFlow::Continue(()); } - let merged_batch = - match concat_batches(&active.left_schema, &active.pending_batches) { - Ok(batch) => batch, - Err(e) => { - return ControlFlow::Break(Poll::Ready(Some(Err(e.into())))); - } - }; + let merged_batch = match concat_batches( + active + .left_schema + .as_ref() + .expect("left_schema must be set"), + &active.pending_batches, + ) { + Ok(batch) => batch, + Err(e) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e.into())))); + } + }; active.pending_batches.clear(); // Build visited bitmap if needed for this join type @@ -1472,7 +1574,7 @@ impl NestedLoopJoinStream { // Set up right-side stream for this pass if !active.is_first_right_pass { if let Some(file) = active.right_spill_file.as_ref() { - match active.spill_manager.read_spill_as_stream( + match active.right_spill_manager.read_spill_as_stream( file.clone(), Some(active.right_max_batch_memory), ) { @@ -1487,7 +1589,7 @@ impl NestedLoopJoinStream { } else { // First pass: create InProgressSpillFile for right side match active - .spill_manager + .right_spill_manager .create_in_progress_file("NestedLoopJoin right spill") { Ok(file) => { diff --git a/datafusion/sqllogictest/test_files/nested_loop_join_spill.slt b/datafusion/sqllogictest/test_files/nested_loop_join_spill.slt index 5b383f3edf6cc..7b5da1d4b8e03 100644 --- a/datafusion/sqllogictest/test_files/nested_loop_join_spill.slt +++ b/datafusion/sqllogictest/test_files/nested_loop_join_spill.slt @@ -39,8 +39,9 @@ INNER JOIN generate_series(1, 1) AS t2(v2) 100000 1 100000 # --- Verify spill metrics via EXPLAIN ANALYZE --- -# The NestedLoopJoinExec line should show spill_count=1, confirming -# the memory-limited fallback path was taken and right side was spilled. +# The NestedLoopJoinExec line should show spill_count=2, confirming +# the memory-limited fallback path was taken (left side spilled once, +# right side spilled once). query TT EXPLAIN ANALYZE SELECT count(*) FROM generate_series(1, 100000) AS t1(v1) @@ -50,7 +51,7 @@ INNER JOIN generate_series(1, 1) AS t2(v2) Plan with Metrics 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)], metrics=[] 02)--AggregateExec: mode=Single, gby=[], aggr=[count(Int64(1))], metrics=[] -03)----NestedLoopJoinExec: join_type=Inner, filter=v1@0 + v2@1 > 0, projection=[], metrics=[output_rows=100.0 K, spill_count=1, ] +03)----NestedLoopJoinExec: join_type=Inner, filter=v1@0 + v2@1 > 0, projection=[], metrics=[output_rows=100.0 K, spill_count=2, ] 04)------ProjectionExec: expr=[value@0 as v1], metrics=[] 05)--------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192], metrics=[] 06)------ProjectionExec: expr=[value@0 as v2], metrics=[]