Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/user-guide/latest/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ Expressions that are not 100% Spark-compatible will fall back to Spark by defaul
`spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.

### Aggregate Expressions

- **CollectSet**: Comet deduplicates NaN values (treats `NaN == NaN`) while Spark treats each NaN as a distinct value.
When `spark.comet.exec.strictFloatingPoint=true`, `collect_set` on floating-point types falls back to Spark unless
`spark.comet.expression.CollectSet.allowIncompatible=true` is set.

### Array Expressions

- **ArraysOverlap**: Inconsistent behavior when arrays contain NULL values.
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
| BitXorAgg | | Yes | |
| BoolAnd | `bool_and` | Yes | |
| BoolOr | `bool_or` | Yes | |
| CollectSet | | No | NaN dedup differs from Spark. See compatibility guide. |
| Corr | | Yes | |
| Count | | Yes | |
| CovPopulation | | Yes | |
Expand Down
2 changes: 1 addition & 1 deletion docs/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
- [x] bool_and
- [x] bool_or
- [ ] collect_list
- [ ] collect_set
- [x] collect_set
- [ ] corr
- [x] count
- [x] count_if
Expand Down
6 changes: 6 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ use datafusion_comet_spark_expr::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SumInteger, ToCsv,
};
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
use iceberg::expr::Bind;

use crate::execution::operators::ExecutionError::GeneralError;
Expand Down Expand Up @@ -2266,6 +2267,11 @@ impl PhysicalPlanner {
));
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
}
AggExprStruct::CollectSet(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ message AggExpr {
Stddev stddev = 14;
Correlation correlation = 15;
BloomFilterAgg bloomFilterAgg = 16;
CollectSet collectSet = 17;
}

// Optional filter expression for SQL FILTER (WHERE ...) clause.
Expand Down Expand Up @@ -248,6 +249,11 @@ message BloomFilterAgg {
DataType datatype = 4;
}

// Serialized form of Spark's CollectSet aggregate expression.
// Mirrors the shape of the other aggregate messages in this file
// (e.g. BloomFilterAgg above): one input child plus its result type.
message CollectSet {
  // Input expression whose distinct values are accumulated into the set.
  Expr child = 1;
  // Result data type of the aggregate (as serialized by the Scala side).
  DataType datatype = 2;
}

enum EvalMode {
LEGACY = 0;
TRY = 1;
Expand Down
14 changes: 1 addition & 13 deletions spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, NullsFirst, NullsLast, SortOrder}
import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
Expand All @@ -30,19 +29,8 @@ object CometSortOrder extends CometExpressionSerde[SortOrder] {

override def getSupportLevel(expr: SortOrder): SupportLevel = {

def containsFloatingPoint(dt: DataType): Boolean = {
dt match {
case DataTypes.FloatType | DataTypes.DoubleType => true
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
case MapType(keyType, valueType, _) =>
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
case _ => false
}
}

if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
containsFloatingPoint(expr.child.dataType)) {
SupportLevel.containsFloatingPoint(expr.child.dataType)) {
// https://github.com/apache/datafusion-comet/issues/2626
Incompatible(
Some(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[BitOrAgg] -> CometBitOrAgg,
classOf[BitXorAgg] -> CometBitXOrAgg,
classOf[BloomFilterAggregate] -> CometBloomFilterAggregate,
classOf[CollectSet] -> CometCollectSet,
classOf[Corr] -> CometCorr,
classOf[Count] -> CometCount,
classOf[CovPopulation] -> CometCovPopulation,
Expand Down
17 changes: 17 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.comet.serde

import org.apache.spark.sql.types._

sealed trait SupportLevel

/**
Expand All @@ -40,3 +42,18 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel

/** Comet does not support this feature */
case class Unsupported(notes: Option[String] = None) extends SupportLevel

object SupportLevel {

  /**
   * Recursively determines whether a data type is, or contains, a
   * floating-point type (FloatType or DoubleType) at any nesting depth,
   * descending into array elements, struct fields, and map keys/values.
   */
  def containsFloatingPoint(dt: DataType): Boolean = {
    dt match {
      case FloatType =>
        true
      case DoubleType =>
        true
      case ArrayType(element, _) =>
        containsFloatingPoint(element)
      case StructType(fields) =>
        fields.map(_.dataType).exists(containsFloatingPoint)
      case MapType(key, value, _) =>
        Seq(key, value).exists(containsFloatingPoint)
      case _ =>
        false
    }
  }
}
48 changes: 47 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package org.apache.comet.serde
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}

Expand Down Expand Up @@ -664,6 +664,52 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
}
}

/**
 * Serde (support-check + protobuf conversion) for Spark's `CollectSet`
 * aggregate expression, wired into QueryPlanSerde's aggregate registry.
 */
object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {

  /**
   * Reports `Incompatible` when strict floating-point mode is enabled and the
   * input contains a floating-point type at any nesting level, because Comet
   * deduplicates NaN values while Spark treats each NaN as distinct.
   * Otherwise the expression is fully compatible.
   */
  override def getSupportLevel(expr: CollectSet): SupportLevel = {
    if (COMET_EXEC_STRICT_FLOATING_POINT.get() &&
      SupportLevel.containsFloatingPoint(expr.children.head.dataType)) {
      Incompatible(
        Some(
          "collect_set on floating-point types is not 100% compatible with Spark " +
            "(Comet deduplicates NaN values while Spark treats each NaN as distinct), " +
            s"and Comet is running with ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
            s"${CometConf.COMPAT_GUIDE}"))
    } else {
      Compatible()
    }
  }

  /**
   * Converts a `CollectSet` aggregate into its protobuf representation.
   *
   * Returns `None` (with a fallback reason recorded via `withInfo`) when
   * either the child expression or the result data type cannot be serialized.
   */
  override def convert(
      aggExpr: AggregateExpression,
      expr: CollectSet,
      inputs: Seq[Attribute],
      binding: Boolean,
      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
    val child = expr.children.head
    val childExpr = exprToProto(child, inputs, binding)
    val dataType = serializeDataType(expr.dataType)

    if (childExpr.isDefined && dataType.isDefined) {
      val builder = ExprOuterClass.CollectSet.newBuilder()
      builder.setChild(childExpr.get)
      builder.setDatatype(dataType.get)

      Some(
        ExprOuterClass.AggExpr
          .newBuilder()
          .setCollectSet(builder)
          .build())
    } else if (dataType.isEmpty) {
      // Data type failed to serialize: record the specific reason so the
      // fallback explanation surfaces in the query plan.
      withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
      None
    } else {
      // Child expression failed to serialize; withInfo propagates the
      // child's own recorded fallback reason.
      withInfo(aggExpr, child)
      None
    }
  }
}

object AggSerde {
import org.apache.spark.sql.types._

Expand Down
12 changes: 1 addition & 11 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,6 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
}

object CometSortArray extends CometExpressionSerde[SortArray] {
private def containsFloatingPoint(dt: DataType): Boolean = {
dt match {
case FloatType | DoubleType => true
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
case MapType(keyType, valueType, _) =>
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
case _ => false
}
}

private def supportedSortArrayElementType(
dt: DataType,
Expand All @@ -173,7 +163,7 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
if (!supportedSortArrayElementType(elementType)) {
Unsupported(Some(s"Sort on array element type $elementType is not supported"))
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
containsFloatingPoint(elementType)) {
SupportLevel.containsFloatingPoint(elementType)) {
Incompatible(
Some(
"Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
Expand Down
41 changes: 39 additions & 2 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
Expand Down Expand Up @@ -1580,14 +1580,51 @@ object CometObjectHashAggregateExec
CometHashAggregateExec(
nativeOp,
op,
op.output,
adjustOutputForNativeState(op),
op.groupingExpressions,
op.aggregateExpressions,
op.resultExpressions,
op.child.output,
op.child,
SerializedPlan(None))
}

/**
 * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet),
 * the Spark-side output declares buffer columns as BinaryType (since Spark serializes state to
 * binary). However, the native Comet aggregate produces the actual state type (e.g.,
 * ArrayType(elementType) for CollectSet). This method corrects the output schema to match the
 * native state types so the shuffle exchange schema is consistent with the actual data.
 *
 * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a
 * case branch to the `collect` below mapping it to the native state type.
 *
 * @param op the Spark ObjectHashAggregateExec being replaced by a Comet aggregate
 * @return the (possibly adjusted) output attributes for the Comet operator
 */
private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = {
  // CometBaseAggregate.doConvert guarantees all expressions share the same mode,
  // and only Partial mode exposes raw aggregation buffers in the output.
  val modes = op.aggregateExpressions.map(_.mode).distinct
  if (modes != Seq(Partial)) {
    op.output
  } else {
    // Buffer attributes appear in op.output immediately after the grouping
    // columns, laid out per aggregate expression in declaration order.
    // scanLeft yields the starting output index of each aggregate's buffer.
    val numGrouping = op.groupingExpressions.length
    val bufferOffsets = op.aggregateExpressions
      .map(_.aggregateFunction.aggBufferAttributes.length)
      .scanLeft(numGrouping)(_ + _)

    // Output index -> native state type, for aggregate functions whose native
    // buffer representation differs from Spark's BinaryType serialization.
    val overrides = op.aggregateExpressions
      .map(_.aggregateFunction)
      .zip(bufferOffsets)
      .collect { case (cs: CollectSet, offset) =>
        // Native CollectSet state is an array of the input element type.
        offset -> ArrayType(cs.children.head.dataType, containsNull = true)
      }
      .toMap

    // Rewrite only the overridden positions; all other attributes pass through.
    op.output.zipWithIndex.map { case (attr, idx) =>
      overrides.get(idx).map(attr.withDataType).getOrElse(attr)
    }
  }
}
}

case class CometHashAggregateExec(
Expand Down
Loading
Loading