diff --git a/docs/source/contributor-guide/adding_a_new_operator.md b/docs/source/contributor-guide/adding_a_new_operator.md index 4317943aa8..82c237f830 100644 --- a/docs/source/contributor-guide/adding_a_new_operator.md +++ b/docs/source/contributor-guide/adding_a_new_operator.md @@ -553,8 +553,14 @@ For operators that run in the JVM: Example pattern from `CometExecRule.scala`: ```scala -case s: ShuffleExchangeExec if nativeShuffleSupported(s) => - CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) +case s: ShuffleExchangeExec => + CometShuffleExchangeExec.shuffleSupported(s) match { + case Some(CometNativeShuffle) => + CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) + case Some(CometColumnarShuffle) => + CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) + case None => s + } ``` ## Common Patterns and Helpers diff --git a/spark/src/main/scala/org/apache/comet/CometFallback.scala b/spark/src/main/scala/org/apache/comet/CometFallback.scala deleted file mode 100644 index 28a4816b66..0000000000 --- a/spark/src/main/scala/org/apache/comet/CometFallback.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet - -import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} - -import org.apache.comet.CometSparkSessionExtensions.withInfo - -/** - * Sticky fallback marker for shuffle / stage nodes. - * - * Comet's shuffle-support predicates (e.g. `CometShuffleExchangeExec.columnarShuffleSupported`) - * run at both initial planning and AQE stage-prep. Some fallback decisions depend on the - * surrounding plan shape - for example, the presence of a DPP scan below a shuffle. Between the - * two passes AQE can reshape that subtree (a completed child stage becomes a - * `ShuffleQueryStageExec`, a `LeafExecNode` whose `children` is empty), so a naive re-evaluation - * can flip the decision. - * - * When a decision is made on the initial-plan pass, the deciding rule records a sticky tag via - * [[markForFallback]]. On subsequent passes, callers short-circuit via [[isMarkedForFallback]] - * and preserve the earlier decision instead of re-deriving it from the current plan shape. - * - * This tag is kept separate from `CometExplainInfo.EXTENSION_INFO` on purpose: the explain tag - * accumulates informational reasons (including rolled-up child reasons), many of which are not a - * full-fallback signal. Treating any presence of explain info as fallback is too coarse and - * breaks legitimate conversions (e.g. a shuffle tagged "Comet native shuffle not enabled" should - * still be eligible for columnar shuffle). The fallback tag exists only for decisions that should - * remain sticky. - */ -object CometFallback { - - val STAGE_FALLBACK_TAG: TreeNodeTag[Set[String]] = - new TreeNodeTag[Set[String]]("CometStageFallback") - - /** - * Mark a node so that subsequent shuffle-support re-evaluations fall back to Spark without - * re-deriving the decision from the (possibly reshaped) subtree. Also records the reason in the - * usual explain channel so it surfaces in extended explain output. - */ - def markForFallback[T <: TreeNode[_]](node: T, reason: String): T = { - val existing = node.getTagValue(STAGE_FALLBACK_TAG).getOrElse(Set.empty[String]) - node.setTagValue(STAGE_FALLBACK_TAG, existing + reason) - withInfo(node, reason) - node - } - - /** True if a prior rule pass marked this node for Spark fallback via [[markForFallback]]. */ - def isMarkedForFallback(node: TreeNode[_]): Boolean = - node.getTagValue(STAGE_FALLBACK_TAG).exists(_.nonEmpty) -} diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 5839570684..b89e57422b 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -200,22 +200,29 @@ object CometSparkSessionExtensions extends Logging { } /** - * Attaches explain information to a TreeNode, rolling up the corresponding information tags - * from any child nodes. For now, we are using this to attach the reasons why certain Spark - * operators or expressions are disabled. + * Record a fallback reason on a `TreeNode` (a Spark operator or expression) explaining why + * Comet cannot accelerate it. Reasons recorded here are surfaced in extended explain output + * (see `ExtendedExplainInfo`) and, when `COMET_LOG_FALLBACK_REASONS` is enabled, logged as + * warnings. The reasons are also rolled up from child nodes so that the operator that remains + * in the Spark plan carries the reasons from its converted-away subtree. + * + * Call this in any code path where Comet decides not to convert a given node - serde `convert` + * methods returning `None`, unsupported data types, disabled configs, etc. Do not use this for + * informational messages that are not fallback reasons: anything tagged here is treated by the + * rules as a signal that the node falls back to Spark. * * @param node - * The node to attach the explain information to. Typically a SparkPlan + * The Spark operator or expression that is falling back to Spark. * @param info - * Information text. Optional, may be null or empty. If not provided, then only information - * from child nodes will be included. + * The fallback reason. Optional, may be null or empty - pass empty only when the call is used + * purely to roll up reasons from `exprs`. * @param exprs - * Child nodes. Information attached in these nodes will be be included in the information - * attached to @node + * Child nodes whose own fallback reasons should be rolled up into `node`. Pass the + * sub-expressions or child operators whose failure caused `node` to fall back. * @tparam T - * The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression + * The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`. * @return - * The node with information (if any) attached + * `node` with fallback reasons attached (as a side effect on its tag map). */ def withInfo[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = { // support existing approach of passing in multiple infos in a newline-delimited string @@ -228,22 +235,24 @@ object CometSparkSessionExtensions extends Logging { } /** - * Attaches explain information to a TreeNode, rolling up the corresponding information tags - * from any child nodes. For now, we are using this to attach the reasons why certain Spark - * operators or expressions are disabled. + * Record one or more fallback reasons on a `TreeNode` and roll up reasons from any child nodes. + * This is the set-valued form of [[withInfo]]; see that overload for the full contract. + * + * Reasons are accumulated (never overwritten) on the node's `EXTENSION_INFO` tag and are + * surfaced in extended explain output. When `COMET_LOG_FALLBACK_REASONS` is enabled, each new + * reason is also emitted as a warning. * * @param node - * The node to attach the explain information to. Typically a SparkPlan + * The Spark operator or expression that is falling back to Spark. * @param info - * Information text. May contain zero or more strings. If not provided, then only information - * from child nodes will be included. + * The fallback reasons for this node. May be empty when the call is used purely to roll up + * child reasons. * @param exprs - * Child nodes. Information attached in these nodes will be be included in the information - * attached to @node + * Child nodes whose own fallback reasons should be rolled up into `node`. * @tparam T - * The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression + * The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`. * @return - * The node with information (if any) attached + * `node` with fallback reasons attached (as a side effect on its tag map). */ def withInfos[T <: TreeNode[_]](node: T, info: Set[String], exprs: T*): T = { if (CometConf.COMET_LOG_FALLBACK_REASONS.get()) { @@ -259,25 +268,27 @@ object CometSparkSessionExtensions extends Logging { } /** - * Attaches explain information to a TreeNode, rolling up the corresponding information tags - * from any child nodes + * Roll up fallback reasons from `exprs` onto `node` without adding a new reason of its own. Use + * this when a parent operator is itself falling back and wants to preserve the reasons recorded + * on its child expressions/operators so they appear together in explain output. * * @param node - * The node to attach the explain information to. Typically a SparkPlan + * The parent operator or expression falling back to Spark. * @param exprs - * Child nodes. Information attached in these nodes will be be included in the information - * attached to @node + * Child nodes whose fallback reasons should be aggregated onto `node`. * @tparam T - * The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression + * The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`. * @return - * The node with information (if any) attached + * `node` with the rolled-up reasons attached (as a side effect on its tag map). */ def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = { withInfos(node, Set.empty, exprs: _*) } /** - * Checks whether a TreeNode has any explain information attached + * True if any fallback reason has been recorded on `node` (via [[withInfo]] / [[withInfos]]). + * Callers that need to short-circuit when a prior rule pass has already decided a node falls + * back can use this as the sticky signal. */ def hasExplainInfo(node: TreeNode[_]): Boolean = { node.getTagValue(CometExplainInfo.EXTENSION_INFO).exists(_.nonEmpty) diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala index d30a1fe788..dbd22cec84 100644 --- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala +++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala @@ -214,6 +214,15 @@ object CometCoverageStats { object CometExplainInfo { val EXTENSION_INFO = new TreeNodeTag[Set[String]]("CometExtensionInfo") + /** + * Records handler class names whose `getSupportLevel` has already returned `Unsupported` or + * `Incompatible` (without `allowIncompat`) on a given operator, so that repeat invocations of + * the same handler on the same node during later rule passes can short-circuit without + * re-running the check. Orthogonal to [[EXTENSION_INFO]]; keyed per handler so other handlers + * on the same node are unaffected. + */ + val FAILED_HANDLERS = new TreeNodeTag[Set[String]]("CometFailedHandlers") + def getActualPlan(node: TreeNode[_]): TreeNode[_] = { node match { case p: AdaptiveSparkPlanExec => getActualPlan(p.executedPlan) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 70983b0599..fac7d0fe92 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -98,17 +98,18 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get() private def applyCometShuffle(plan: SparkPlan): SparkPlan = { - plan.transformUp { - case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) => - // Switch to use Decimal128 regardless of precision, since Arrow native execution - // doesn't support Decimal32 and Decimal64 yet. - conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") - CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) - - case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) => - // Columnar shuffle for regular Spark operators (not Comet) and Comet operators - // (if configured) - CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) + plan.transformUp { case s: ShuffleExchangeExec => + CometShuffleExchangeExec.shuffleSupported(s) match { + case Some(CometNativeShuffle) => + // Switch to use Decimal128 regardless of precision, since Arrow native execution + // doesn't support Decimal32 and Decimal64 yet. + conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") + CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle) + case Some(CometColumnarShuffle) => + CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle) + case None => + s + } } } @@ -479,8 +480,13 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } /** Convert a Spark plan to a Comet plan using the specified serde handler */ - private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = { + private[rules] def convertToComet( + op: SparkPlan, + handler: CometOperatorSerde[_]): Option[SparkPlan] = { val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]] + if (hasFailedHandler(op, handler)) { + return None + } if (isOperatorEnabled(serde, op)) { // For operators that require native children (like writes), check if all data-producing // children are CometNativeExec. This prevents runtime failures when the native operator @@ -521,6 +527,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { handler.getSupportLevel(op) match { case Unsupported(notes) => withInfo(op, notes.getOrElse("")) + recordFailedHandler(op, handler) false case Incompatible(notes) => val allowIncompat = CometConf.isOperatorAllowIncompat(opName) @@ -539,6 +546,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { s"$opName is not fully compatible with Spark$optionalNotes. " + s"To enable it anyway, set $incompatConf=true. " + s"${CometConf.COMPAT_GUIDE}.") + recordFailedHandler(op, handler) false } case Compatible(notes) => @@ -556,6 +564,16 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } } + private def hasFailedHandler(op: SparkPlan, handler: CometOperatorSerde[_]): Boolean = { + op.getTagValue(CometExplainInfo.FAILED_HANDLERS) + .exists(_.contains(handler.getClass.getName)) + } + + private def recordFailedHandler(op: SparkPlan, handler: CometOperatorSerde[_]): Unit = { + val existing = op.getTagValue(CometExplainInfo.FAILED_HANDLERS).getOrElse(Set.empty[String]) + op.setTagValue(CometExplainInfo.FAILED_HANDLERS, existing + handler.getClass.getName) + } + private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = { // Only consider converting leaf nodes to columnar currently, so that all the following // operators can have a chance to be converted to columnar. Leaf operators that output diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 2500a52658..bd6922b9e8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -48,10 +48,9 @@ import org.apache.spark.util.random.XORShiftRandom import com.google.common.base.Objects -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExplainInfo} import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} -import org.apache.comet.CometFallback.{isMarkedForFallback, markForFallback} -import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo} +import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, isCometShuffleManagerEnabled, withInfos} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometShuffleExchangeExec @@ -223,39 +222,90 @@ object CometShuffleExchangeExec with SQLConfHelper { override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = { - if (nativeShuffleSupported(op) || columnarShuffleSupported(op)) { - Compatible() - } else { - Unsupported() - } + if (shuffleSupported(op).isDefined) Compatible() else Unsupported() } override def createExec( nativeOp: OperatorOuterClass.Operator, op: ShuffleExchangeExec): CometNativeExec = { - if (nativeShuffleSupported(op) && op.children.forall(_.isInstanceOf[CometNativeExec])) { - // Switch to use Decimal128 regardless of precision, since Arrow native execution - // doesn't support Decimal32 and Decimal64 yet. - conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") - CometSinkPlaceHolder( - nativeOp, - op, - CometShuffleExchangeExec(op, shuffleType = CometNativeShuffle)) - - } else if (columnarShuffleSupported(op)) { - CometSinkPlaceHolder( - nativeOp, - op, - CometShuffleExchangeExec(op, shuffleType = CometColumnarShuffle)) - } else { - throw new IllegalStateException() + shuffleSupported(op) match { + case Some(CometNativeShuffle) if op.children.forall(_.isInstanceOf[CometNativeExec]) => + // Switch to use Decimal128 regardless of precision, since Arrow native execution + // doesn't support Decimal32 and Decimal64 yet. + conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true") + CometSinkPlaceHolder( + nativeOp, + op, + CometShuffleExchangeExec(op, shuffleType = CometNativeShuffle)) + case Some(CometColumnarShuffle) => + CometSinkPlaceHolder( + nativeOp, + op, + CometShuffleExchangeExec(op, shuffleType = CometColumnarShuffle)) + case Some(CometNativeShuffle) => + // Native was chosen but children are not native - fall through to columnar if possible. + // This can happen when getSupportLevel selected native but a later pass changed the plan. + throw new IllegalStateException( + "shuffleSupported chose native shuffle but children are not all CometNativeExec") + case None => + throw new IllegalStateException() + } + } + + /** + * Decide which Comet shuffle path (if any) can handle this shuffle. Returns `None` if neither + * native nor columnar shuffle can be used; in that case the node is tagged with the combined + * fallback reasons via `withInfos` so subsequent passes short-circuit via `hasExplainInfo`. + * + * This is the single coordination point: the two path-specific predicates + * (`nativeShuffleFailureReasons` / `columnarShuffleFailureReasons`) are pure - they return + * collected reasons but do not tag. Tagging only happens here, and only on total failure. + */ + def shuffleSupported(s: ShuffleExchangeExec): Option[ShuffleType] = { + // Sticky: a prior rule pass (initial planning or an earlier AQE pass) already decided this + // shuffle falls back to Spark and tagged it. Preserve that decision - re-deriving it against + // a possibly-reshaped subtree (e.g. AQE stage-wrapping) can flip the answer and produce + // inconsistent plans across passes (see #3949). + if (hasExplainInfo(s)) return None + + isCometShuffleEnabledReason(s) match { + case Some(reason) => + withInfos(s, Set(reason)) + return None + case None => + } + + // DPP fallback is a combined-path decision: a Comet shuffle wrapped around a stage that + // still contains a DPP scan produces inefficient row<->columnar transitions. Disqualifies + // both paths. + if (CometConf.COMET_DPP_FALLBACK_ENABLED.get() && stageContainsDPPScan(s)) { + withInfos(s, Set("Stage contains a scan with Dynamic Partition Pruning")) + return None + } + + // Native path is only eligible when the child is already a Comet plan; otherwise skip it + // silently (no reason to surface) and let columnar take over. + val nativeReasons: Seq[String] = + if (isCometPlan(s.child)) nativeShuffleFailureReasons(s) else Seq.empty + if (isCometPlan(s.child) && nativeReasons.isEmpty) { + return Some(CometNativeShuffle) + } + + val columnarReasons = columnarShuffleFailureReasons(s) + if (columnarReasons.isEmpty) { + return Some(CometColumnarShuffle) } + + val combined = (nativeReasons ++ columnarReasons).toSet + if (combined.nonEmpty) withInfos(s, combined) + None } /** - * Whether the given Spark partitioning is supported by Comet native shuffle. + * Reasons the native shuffle path cannot handle this shuffle. Empty means native is supported. + * Pure: does not tag the node. */ - def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = { + private def nativeShuffleFailureReasons(s: ShuffleExchangeExec): Seq[String] = { /** * Determine which data types are supported as partition columns in native shuffle. @@ -279,19 +329,6 @@ object CometShuffleExchangeExec false } - /** - * Check if a data type contains a decimal with precision > 18. Such decimals require - * conversion to Java BigDecimal before hashing, which is not supported in native shuffle. - */ - def containsHighPrecisionDecimal(dt: DataType): Boolean = dt match { - case d: DecimalType => d.precision > 18 - case StructType(fields) => fields.exists(f => containsHighPrecisionDecimal(f.dataType)) - case ArrayType(elementType, _) => containsHighPrecisionDecimal(elementType) - case MapType(keyType, valueType, _) => - containsHighPrecisionDecimal(keyType) || containsHighPrecisionDecimal(valueType) - case _ => false - } - /** * Determine which data types are supported as partition columns in native shuffle. * @@ -329,31 +366,19 @@ object CometShuffleExchangeExec false } - // Preserve any prior-pass fallback decision (see `CometFallback`). - if (isMarkedForFallback(s)) { - return false - } - - if (!isCometShuffleEnabledWithInfo(s)) { - return false - } + val reasons = scala.collection.mutable.ListBuffer.empty[String] if (!isCometNativeShuffleMode(s.conf)) { - withInfo(s, "Comet native shuffle not enabled") - return false - } - - if (!isCometPlan(s.child)) { - // we do not need to report a fallback reason if the child plan is not a Comet plan - return false + reasons += "Comet native shuffle not enabled" + return reasons.toSeq } val inputs = s.child.output for (input <- inputs) { if (!supportedSerializableDataType(input.dataType)) { - withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input") - return false + reasons += s"unsupported shuffle data type ${input.dataType} for input $input" + return reasons.toSeq } } @@ -361,76 +386,58 @@ object CometShuffleExchangeExec val conf = SQLConf.get partitioning match { case HashPartitioning(expressions, _) => - var supported = true if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) { - withInfo( - s, - s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled") - supported = false + reasons += + s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled" } for (expr <- expressions) { if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) { - withInfo(s, s"unsupported hash partitioning expression: $expr") - supported = false - // We don't short-circuit in case there is more than one unsupported expression - // to provide info for. + reasons += s"unsupported hash partitioning expression: $expr" } } for (dt <- expressions.map(_.dataType).distinct) { if (!supportedHashPartitioningDataType(dt)) { - withInfo(s, s"unsupported hash partitioning data type for native shuffle: $dt") - supported = false + reasons += s"unsupported hash partitioning data type for native shuffle: $dt" } } - supported case SinglePartition => - // we already checked that the input types are supported - true + // we already checked that the input types are supported case RangePartitioning(orderings, _) => if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) { - withInfo( - s, - s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled") - return false + reasons += + s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled" + return reasons.toSeq } - var supported = true for (o <- orderings) { if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) { - withInfo(s, s"unsupported range partitioning sort order: $o", o) - supported = false - // We don't short-circuit in case there is more than one unsupported expression - // to provide info for. + reasons += s"unsupported range partitioning sort order: $o" + // Roll up fallback reasons recorded on the sort-order expression (e.g. strict + // floating-point sort) so they surface in the shuffle's explain output. + o.getTagValue(CometExplainInfo.EXTENSION_INFO).foreach(reasons ++= _) } } for (dt <- orderings.map(_.dataType).distinct) { if (!supportedRangePartitioningDataType(dt)) { - withInfo(s, s"unsupported range partitioning data type for native shuffle: $dt") - supported = false + reasons += s"unsupported range partitioning data type for native shuffle: $dt" } } - supported case RoundRobinPartitioning(_) => val config = CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_ENABLED if (!config.get(conf)) { - withInfo(s, s"${config.key} is disabled") - return false + reasons += s"${config.key} is disabled" } - // RoundRobin partitioning uses position-based distribution matching Spark's behavior - true case _ => - withInfo( - s, - s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}") - false + reasons += + s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}" } + reasons.toSeq } /** - * Check if JVM-based columnar shuffle (CometColumnarExchange) can be used for this shuffle. JVM - * shuffle is used when the child plan is not a Comet native operator, or when native shuffle - * doesn't support the required partitioning type. + * Reasons the columnar shuffle path cannot handle this shuffle. Empty means columnar is + * supported. Pure: does not tag the node. */ - def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = { + private def columnarShuffleFailureReasons(s: ShuffleExchangeExec): Seq[String] = { /** * Determine which data types are supported as data columns in columnar shuffle. @@ -456,80 +463,55 @@ object CometShuffleExchangeExec false } - // Preserve any prior-pass fallback decision (see `CometFallback`). - if (isMarkedForFallback(s)) { - return false - } - - if (!isCometShuffleEnabledWithInfo(s)) { - return false - } - - if (CometConf.COMET_DPP_FALLBACK_ENABLED.get() && stageContainsDPPScan(s)) { - markForFallback(s, "Stage contains a scan with Dynamic Partition Pruning") - return false - } + val reasons = scala.collection.mutable.ListBuffer.empty[String] if (!isCometJVMShuffleMode(s.conf)) { - withInfo(s, "Comet columnar shuffle not enabled") - return false + reasons += "Comet columnar shuffle not enabled" + return reasons.toSeq } if (isShuffleOperator(s.child)) { - withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator") - return false + reasons += s"Child ${s.child.getClass.getName} is a shuffle operator" + return reasons.toSeq } if (!(!s.child.supportsColumnar || isCometPlan(s.child))) { - withInfo(s, s"Child ${s.child.getClass.getName} is a neither row-based or a Comet operator") - return false + reasons += s"Child ${s.child.getClass.getName} is a neither row-based or a Comet operator" + return reasons.toSeq } val inputs = s.child.output for (input <- inputs) { if (!supportedSerializableDataType(input.dataType)) { - withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input") - return false + reasons += s"unsupported shuffle data type ${input.dataType} for input $input" + return reasons.toSeq } } val partitioning = s.outputPartitioning partitioning match { case HashPartitioning(expressions, _) => - var supported = true for (expr <- expressions) { if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) { - withInfo(s, s"unsupported hash partitioning expression: $expr") - supported = false - // We don't short-circuit in case there is more than one unsupported expression - // to provide info for. + reasons += s"unsupported hash partitioning expression: $expr" } } - supported case SinglePartition => - // we already checked that the input types are supported - true + // we already checked that the input types are supported case RoundRobinPartitioning(_) => - // we already checked that the input types are supported - true + // we already checked that the input types are supported case RangePartitioning(orderings, _) => - var supported = true for (o <- orderings) { if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) { - withInfo(s, s"unsupported range partitioning sort order: $o") - supported = false - // We don't short-circuit in case there is more than one unsupported expression - // to provide info for. + reasons += s"unsupported range partitioning sort order: $o" } } - supported case _ => - withInfo( - s, - s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}") - false + reasons += + s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}" } + reasons.toSeq } private def isCometNativeShuffleMode(conf: SQLConf): Boolean = { @@ -578,17 +560,17 @@ object CometShuffleExchangeExec } } - def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = { + /** + * Reason Comet shuffle is not enabled for this node, or `None` if it is enabled. Pure: does not + * tag the node. + */ + private def isCometShuffleEnabledReason(op: SparkPlan): Option[String] = { if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) { - withInfo( - op, - s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled") - false + Some(s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled") } else if (!isCometShuffleManagerEnabled(op.conf)) { - withInfo(op, s"spark.shuffle.manager is not set to ${classOf[CometShuffleManager].getName}") - false + Some(s"spark.shuffle.manager is not set to ${classOf[CometShuffleManager].getName}") } else { - true + None } } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index cf6f8918f4..216b750d9f 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass, SupportLevel, Unsupported} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} /** @@ -206,6 +208,67 @@ class CometExecRuleSuite extends CometTestBase { } } + test("convertToComet short-circuits repeat call for a handler that already failed") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = createSparkPlan(spark, "SELECT id FROM test_data WHERE id > 0") + val op = sparkPlan.collectFirst { case f: FilterExec => f }.get + + val stub = new StubUnsupportedHandlerA + val rule = CometExecRule(spark) + + assert(rule.convertToComet(op, stub).isEmpty) + assert(stub.calls == 1, s"first call should invoke getSupportLevel once, got ${stub.calls}") + + assert(rule.convertToComet(op, stub).isEmpty) + assert( + stub.calls == 1, + s"second call should short-circuit; getSupportLevel calls = ${stub.calls}") + } + } + + test("convertToComet short-circuit is per-handler (other handlers still run)") { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = createSparkPlan(spark, "SELECT id FROM test_data WHERE id > 0") + val op = sparkPlan.collectFirst { case f: FilterExec => f }.get + + val stubA = new StubUnsupportedHandlerA + val stubB = new StubUnsupportedHandlerB + val rule = CometExecRule(spark) + + assert(rule.convertToComet(op, stubA).isEmpty) + assert(rule.convertToComet(op, stubA).isEmpty) + assert( + stubA.calls == 1, + s"handler A should be short-circuited on repeat, got ${stubA.calls}") + + // Different handler class on the same node must still run even though A already failed. + assert(rule.convertToComet(op, stubB).isEmpty) + assert(stubB.calls == 1, s"handler B (different class) must still run, got ${stubB.calls}") + } + } + + private abstract class CountingStubHandler extends CometOperatorSerde[SparkPlan] { + var calls: Int = 0 + override def enabledConfig: Option[ConfigEntry[Boolean]] = None + override def getSupportLevel(operator: SparkPlan): SupportLevel = { + calls += 1 + Unsupported(Some("stub fallback")) + } + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: Operator*): Option[OperatorOuterClass.Operator] = None + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = + throw new AssertionError("createExec should not be invoked on Unsupported") + } + + private class StubUnsupportedHandlerA extends CountingStubHandler + private class StubUnsupportedHandlerB extends CountingStubHandler + test("CometExecRule should apply shuffle exchange transformations") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala index 5b74b590d2..955d900888 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala @@ -124,7 +124,7 @@ class CometDppFallbackRepro3949Suite extends CometTestBase { fail(s"No ShuffleExchangeExec found in initial plan:\n${initialPlan.treeString}") } - val initialDecision = CometShuffleExchangeExec.columnarShuffleSupported(shuffle) + val initialDecision = CometShuffleExchangeExec.shuffleSupported(shuffle) val initialDppVisible = shuffle.child.exists { case scan: FileSourceScanExec => @@ -135,13 +135,13 @@ class CometDppFallbackRepro3949Suite extends CometTestBase { // Simulate AQE stage prep: wrap the shuffle's child in an opaque LeafExecNode, // matching how `ShuffleQueryStageExec` presents to `.exists` walks (its `children` // is `Seq.empty`). `withNewChildren` preserves tree-node tags, so if the fix is in - // place the sticky CometFallback marker on `shuffle` carries over to - // `postAqeShuffle`, and the decision short-circuits to false. Without the fix, - // the DPP walk re-runs, fails to see the scan, and flips to true. + // place the explain-info tag on `shuffle` carries over to `postAqeShuffle`, and the + // decision short-circuits to None. Without the fix, the DPP walk re-runs, fails to + // see the scan, and flips to Some(...). val hiddenChild = OpaqueStageStub(shuffle.child.output) val postAqeShuffle = shuffle.withNewChildren(Seq(hiddenChild)).asInstanceOf[ShuffleExchangeExec] - val postAqeDecision = CometShuffleExchangeExec.columnarShuffleSupported(postAqeShuffle) + val postAqeDecision = CometShuffleExchangeExec.shuffleSupported(postAqeShuffle) val postAqeDppVisible = postAqeShuffle.child.exists { case scan: FileSourceScanExec => @@ -151,9 +151,9 @@ class CometDppFallbackRepro3949Suite extends CometTestBase { assert(initialDppVisible, "initial child tree should expose DPP scan") assert(!postAqeDppVisible, "stage-wrapped child should hide DPP scan") - assert(!initialDecision, s"expected fall back initially, got $initialDecision") + assert(initialDecision.isEmpty, s"expected fall back initially, got $initialDecision") assert( - !postAqeDecision, + postAqeDecision.isEmpty, s"decision must stay 'fall back' across the AQE-style wrap, got $postAqeDecision") } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala index 7e3ee63502..0374c57068 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala @@ -29,52 +29,36 @@ import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.internal.SQLConf -import org.apache.comet.{CometConf, CometFallback} +import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, withInfo} /** - * Pins the sticky-fallback invariant for Comet shuffle decisions: `nativeShuffleSupported` / - * `columnarShuffleSupported` must return `false` whenever the shuffle already carries a - * `CometFallback` marker from a prior rule pass. + * Pins the sticky-fallback invariant for Comet shuffle decisions: `shuffleSupported` must return + * `None` whenever the shuffle already carries explain info from a prior rule pass. * - * Without this behavior, AQE's stage-prep rule re-evaluation can flip the decision — e.g., + * Without this behavior, AQE's stage-prep rule re-evaluation can flip the decision - e.g., * `stageContainsDPPScan` walks the shuffle's child tree with `.exists`, but a materialized child * stage is wrapped in `ShuffleQueryStageExec` (a `LeafExecNode`) so `.exists` stops at the * wrapper and the DPP scan becomes invisible. That causes the same shuffle to fall back to Spark * at initial planning and then convert to Comet at stage prep, producing plan-shape * inconsistencies across the two passes (suspected mechanism behind #3949). * - * Fallback decisions that must survive AQE replanning use `CometFallback.markForFallback`. The - * shuffle-support predicates check `isMarkedForFallback` at the top and short-circuit. + * The coordinator tags the node with `withInfos` only on total fallback and short-circuits via + * `hasExplainInfo` on subsequent passes. */ class CometShuffleFallbackStickinessSuite extends CometTestBase { - test("both support predicates fall back when the shuffle carries a CometFallback marker") { + test("shuffleSupported returns None when the shuffle already carries explain info") { val shuffle = ShuffleExchangeExec(SinglePartition, SyntheticLeaf(Nil)) - CometFallback.markForFallback(shuffle, "pretend prior pass decided Spark fallback") + withInfo(shuffle, "pretend prior pass decided Spark fallback") withSQLConf(CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "true") { assert( - !CometShuffleExchangeExec.columnarShuffleSupported(shuffle), - "marked shuffle must preserve its prior-pass fallback decision (columnar path)") - assert( - !CometShuffleExchangeExec.nativeShuffleSupported(shuffle), - "marked shuffle must preserve its prior-pass fallback decision (native path)") + CometShuffleExchangeExec.shuffleSupported(shuffle).isEmpty, + "marked shuffle must preserve its prior-pass fallback decision") } } - test("informational explain-info alone does NOT force fallback") { - // A shuffle can accumulate explain info (e.g. 'Comet native shuffle not enabled') as - // informational output from earlier checks without being a full-fallback signal. That - // info must not cause the columnar path to decline. - val shuffle = ShuffleExchangeExec(SinglePartition, SyntheticLeaf(Nil)) - // Note: withInfo, not markForFallback. - org.apache.comet.CometSparkSessionExtensions - .withInfo(shuffle, "Comet native shuffle not enabled") - assert( - !CometFallback.isMarkedForFallback(shuffle), - "explain info alone must not imply a sticky fallback marker") - } - test( "DPP fallback decision is sticky across two invocations even when the child tree changes") { withTempDir { dir => @@ -120,23 +104,21 @@ class CometShuffleFallbackStickinessSuite extends CometTestBase { .collectFirst { case s: ShuffleExchangeExec => s } .getOrElse(fail(s"no shuffle found:\n${initial.treeString}")) - // Pass 1: real DPP subtree visible. Returns false AND marks the shuffle. - val first = CometShuffleExchangeExec.columnarShuffleSupported(shuffle) - assert(!first, "initial pass must fall back (DPP visible)") - assert( - CometFallback.isMarkedForFallback(shuffle), - "fallback marker must be placed on the shuffle") + // Pass 1: real DPP subtree visible. Returns None AND tags the shuffle. + val first = CometShuffleExchangeExec.shuffleSupported(shuffle) + assert(first.isEmpty, "initial pass must fall back (DPP visible)") + assert(hasExplainInfo(shuffle), "fallback reason must be tagged on the shuffle") // Pass 2 simulates AQE stage-prep: replace the child with an opaque leaf that hides - // the DPP subtree from tree walks. A naive `.exists`-based check would flip to true - // here; the sticky marker must keep the decision stable. + // the DPP subtree from tree walks. A naive `.exists`-based check would flip to "convert" + // here; the sticky tag must keep the decision stable. val reshapedShuffle = shuffle .withNewChildren(Seq(SyntheticLeaf(shuffle.child.output))) .asInstanceOf[ShuffleExchangeExec] - val second = CometShuffleExchangeExec.columnarShuffleSupported(reshapedShuffle) + val second = CometShuffleExchangeExec.shuffleSupported(reshapedShuffle) assert( - !second, + second.isEmpty, "second pass must still fall back even though the DPP subtree is now hidden") } }