diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index b12ebfed45..1b66345210 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -58,6 +58,12 @@ Expressions that are not 100% Spark-compatible will fall back to Spark by defaul
 `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See the
 [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
 
+### Aggregate Expressions
+
+- **CollectSet**: Comet deduplicates NaN values (it treats all `NaN` values as equal), while Spark treats each NaN as a distinct value.
+  When `spark.comet.exec.strictFloatingPoint=true`, `collect_set` on floating-point types falls back to Spark unless
+  `spark.comet.expression.CollectSet.allowIncompatible=true` is set.
+
 ### Array Expressions
 
 - **ArraysOverlap**: Inconsistent behavior when arrays contain NULL values.
diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md
index c4ab531814..c0f48065bf 100644
--- a/docs/source/user-guide/latest/expressions.md
+++ b/docs/source/user-guide/latest/expressions.md
@@ -203,6 +203,7 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
 | BitXorAgg | | Yes | |
 | BoolAnd | `bool_and` | Yes | |
 | BoolOr | `bool_or` | Yes | |
+| CollectSet | | No | NaN deduplication differs from Spark; see the compatibility guide. |
 | Corr | | Yes | |
 | Count | | Yes | |
 | CovPopulation | | Yes | |
diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index 9d9e8f7017..e6b3ca69ee 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -33,7 +33,7 @@
 - [x] bool_and
 - [x] bool_or
 - [ ] collect_list
-- [ ] collect_set
+- [x] collect_set
 - [ ] corr
 - [x] count
 - [x] count_if
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index ac35925ace..176104a3a5 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -70,6 +70,7 @@ use datafusion_comet_spark_expr::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
     BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SumInteger, ToCsv,
 };
+use datafusion_spark::function::aggregate::collect::SparkCollectSet;
 use iceberg::expr::Bind;
 
 use crate::execution::operators::ExecutionError::GeneralError;
@@ -2266,6 +2267,11 @@ impl PhysicalPlanner {
                 ));
                 Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
             }
+            AggExprStruct::CollectSet(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
+                let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
+                Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
+            }
         }
     }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 369c64a4c7..52b9849e10 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -140,6 +140,7 @@ message AggExpr {
     Stddev stddev = 14;
     Correlation correlation = 15;
     BloomFilterAgg bloomFilterAgg = 16;
+    CollectSet collectSet = 17;
   }
 
   // Optional filter expression for SQL FILTER (WHERE ...) clause.
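The NaN behavior described in the compatibility note above is easy to reproduce. Below is a hedged sketch, not part of the patch: it assumes a local Spark session with the Comet plugin enabled, and the two expected results simply restate the note (Spark keeps each NaN as a distinct element; Comet's native `collect_set` collapses them).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_set

object CollectSetNanDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  // Two NaNs and one ordinary value.
  val df = Seq(Float.NaN, Float.NaN, 1.0f).toDF("v")

  // Per the compatibility note above:
  //   Spark:  [NaN, NaN, 1.0]  -- each NaN kept as a distinct element
  //   Comet:  [NaN, 1.0]       -- NaN == NaN, so the duplicates collapse
  // With spark.comet.exec.strictFloatingPoint=true this query falls back to
  // Spark unless spark.comet.expression.CollectSet.allowIncompatible=true.
  df.agg(collect_set($"v")).show(truncate = false)
}
```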
@@ -248,6 +249,11 @@ message BloomFilterAgg {
   DataType datatype = 4;
 }
 
+message CollectSet {
+  Expr child = 1;
+  DataType datatype = 2;
+}
+
 enum EvalMode {
   LEGACY = 0;
   TRY = 1;
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala b/spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala
index aabe34f13a..3390d86cee 100644
--- a/spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala
@@ -20,7 +20,6 @@
 package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, NullsFirst, NullsLast, SortOrder}
-import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -30,19 +29,8 @@ object CometSortOrder extends CometExpressionSerde[SortOrder] {
 
   override def getSupportLevel(expr: SortOrder): SupportLevel = {
-    def containsFloatingPoint(dt: DataType): Boolean = {
-      dt match {
-        case DataTypes.FloatType | DataTypes.DoubleType => true
-        case ArrayType(elementType, _) => containsFloatingPoint(elementType)
-        case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
-        case MapType(keyType, valueType, _) =>
-          containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
-        case _ => false
-      }
-    }
-
     if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
-      containsFloatingPoint(expr.child.dataType)) {
+      SupportLevel.containsFloatingPoint(expr.child.dataType)) {
       // https://github.com/apache/datafusion-comet/issues/2626
       Incompatible(
         Some(
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 810d9bd7da..41a64a105a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -262,6 +262,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[BitOrAgg] -> CometBitOrAgg,
     classOf[BitXorAgg] -> CometBitXOrAgg,
     classOf[BloomFilterAggregate] -> CometBloomFilterAggregate,
+    classOf[CollectSet] -> CometCollectSet,
     classOf[Corr] -> CometCorr,
     classOf[Count] -> CometCount,
     classOf[CovPopulation] -> CometCovPopulation,
diff --git a/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala b/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
index d5a524077d..cb78c7d2d4 100644
--- a/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet.serde
 
+import org.apache.spark.sql.types._
+
 sealed trait SupportLevel
 
 /**
@@ -40,3 +42,18 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel
 
 /** Comet does not support this feature */
 case class Unsupported(notes: Option[String] = None) extends SupportLevel
+
+object SupportLevel {
+
+  /**
+   * Returns true if the given data type contains FloatType or DoubleType at any nesting level.
+   */
+  def containsFloatingPoint(dt: DataType): Boolean = dt match {
+    case FloatType | DoubleType => true
+    case ArrayType(elementType, _) => containsFloatingPoint(elementType)
+    case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
+    case MapType(keyType, valueType, _) =>
+      containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
+    case _ => false
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 8e58c08740..7d78bbe3e5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -22,7 +22,7 @@ package org.apache.comet.serde
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
 
@@ -664,6 +664,52 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
   }
 }
 
+object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {
+
+  override def getSupportLevel(expr: CollectSet): SupportLevel = {
+    if (COMET_EXEC_STRICT_FLOATING_POINT.get() &&
+      SupportLevel.containsFloatingPoint(expr.children.head.dataType)) {
+      Incompatible(
+        Some(
+          "collect_set on floating-point types is not 100% compatible with Spark " +
+            "(Comet deduplicates NaN values while Spark treats each NaN as distinct), " +
+            s"and Comet is running with ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
" + + s"${CometConf.COMPAT_GUIDE}")) + } else { + Compatible() + } + } + + override def convert( + aggExpr: AggregateExpression, + expr: CollectSet, + inputs: Seq[Attribute], + binding: Boolean, + conf: SQLConf): Option[ExprOuterClass.AggExpr] = { + val child = expr.children.head + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(expr.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val builder = ExprOuterClass.CollectSet.newBuilder() + builder.setChild(childExpr.get) + builder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setCollectSet(builder) + .build()) + } else if (dataType.isEmpty) { + withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child) + None + } else { + withInfo(aggExpr, child) + None + } + } +} + object AggSerde { import org.apache.spark.sql.types._ diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index e05967958a..2a0a1c59b4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -138,16 +138,6 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] { } object CometSortArray extends CometExpressionSerde[SortArray] { - private def containsFloatingPoint(dt: DataType): Boolean = { - dt match { - case FloatType | DoubleType => true - case ArrayType(elementType, _) => containsFloatingPoint(elementType) - case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType)) - case MapType(keyType, valueType, _) => - containsFloatingPoint(keyType) || containsFloatingPoint(valueType) - case _ => false - } - } private def supportedSortArrayElementType( dt: DataType, @@ -173,7 +163,7 @@ object CometSortArray extends CometExpressionSerde[SortArray] { if (!supportedSortArrayElementType(elementType)) { Unsupported(Some(s"Sort on array element type $elementType is not supported")) } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() && - containsFloatingPoint(elementType)) { + SupportLevel.containsFloatingPoint(elementType)) { Incompatible( Some( "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " + diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a9c065d726..a6f6b03330 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -1580,7 +1580,7 @@ object CometObjectHashAggregateExec CometHashAggregateExec( nativeOp, op, - op.output, + adjustOutputForNativeState(op), op.groupingExpressions, op.aggregateExpressions, op.resultExpressions, @@ -1588,6 
       op.child,
       SerializedPlan(None))
   }
+
+  /**
+   * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet),
+   * the Spark-side output declares buffer columns as BinaryType (since Spark serializes state to
+   * binary). However, the native Comet aggregate produces the actual state type (e.g.,
+   * ArrayType(elementType) for CollectSet). This method corrects the output schema to match the
+   * native state types so the shuffle exchange schema is consistent with the actual data.
+   *
+   * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a
+   * case branch here mapping it to the native state type.
+   */
+  private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = {
+    // CometBaseAggregate.doConvert guarantees all expressions share the same mode.
+    val modes = op.aggregateExpressions.map(_.mode).distinct
+    if (modes != Seq(Partial)) {
+      return op.output
+    }
+
+    val numGrouping = op.groupingExpressions.length
+    val output = op.output.toArray
+
+    var bufferIdx = numGrouping
+    for (aggExpr <- op.aggregateExpressions) {
+      val aggFunc = aggExpr.aggregateFunction
+      val bufferAttrs = aggFunc.aggBufferAttributes
+      aggFunc match {
+        case cs: CollectSet =>
+          val elementType = cs.children.head.dataType
+          val nativeStateType = ArrayType(elementType, containsNull = true)
+          output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType)
+        case _ =>
+      }
+      bufferIdx += bufferAttrs.length
+    }
+
+    output.toSeq
+  }
 }
 
 case class CometHashAggregateExec(
diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/collect_set.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/collect_set.sql
new file mode 100644
index 0000000000..cd528272af
--- /dev/null
+++ b/spark/src/test/resources/sql-tests/expressions/aggregate/collect_set.sql
@@ -0,0 +1,300 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+--   http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+
+-- Config: spark.comet.exec.strictFloatingPoint=true
+-- ConfigMatrix: parquet.enable.dictionary=false,true
+
+-- ============================================================
+-- Setup: tables
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_int(i int, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_int VALUES
+  (1, 'a'), (2, 'a'), (1, 'a'), (3, 'a'),
+  (4, 'b'), (4, 'b'), (NULL, 'b'), (5, 'b'),
+  (NULL, 'c'), (NULL, 'c')
+
+statement
+CREATE TABLE cs_src_nulls(val int, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_nulls VALUES
+  (NULL, 'a'), (NULL, 'a'), (NULL, 'b'), (1, 'b')
+
+statement
+CREATE TABLE cs_src_empty(val int) USING parquet
+
+statement
+CREATE TABLE cs_src_single(val int) USING parquet
+
+statement
+INSERT INTO cs_src_single VALUES (42)
+
+statement
+CREATE TABLE cs_src_dupes(val int, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_dupes VALUES (7, 'a'), (7, 'a'), (7, 'a'), (8, 'b'), (9, 'b')
+
+-- ============================================================
+-- Basic: integer dedup (global aggregate, no GROUP BY)
+-- ============================================================
+
+query
+SELECT sort_array(collect_set(i)) FROM cs_src_int
+
+-- ============================================================
+-- GROUP BY: integer dedup per group
+-- ============================================================
+
+query
+SELECT grp, sort_array(collect_set(i)) FROM cs_src_int GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- NULLs: NULLs are ignored; an all-NULL group yields an empty array
+-- ============================================================
+
+query
+SELECT grp, sort_array(collect_set(val)) FROM cs_src_nulls GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Empty table: returns empty array
+-- ============================================================
+
+query
+SELECT sort_array(collect_set(val)) FROM cs_src_empty
+
+-- ============================================================
+-- Single value
+-- ============================================================
+
+query
+SELECT sort_array(collect_set(val)) FROM cs_src_single
+
+-- ============================================================
+-- All duplicates in a group
+-- ============================================================
+
+query
+SELECT grp, sort_array(collect_set(val)) FROM cs_src_dupes GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Boolean (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_bool(v boolean, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_bool VALUES
+  (true, 'a'), (false, 'a'), (true, 'a'), (NULL, 'a'),
+  (NULL, 'b'), (true, 'b')
+
+query
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_bool GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Byte / Short (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_small(b tinyint, s smallint, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_small VALUES
+  (1, 100, 'a'), (2, 200, 'a'), (1, 100, 'a'), (NULL, NULL, 'a'),
+  (3, 300, 'b'), (NULL, 300, 'b')
+
+query
+SELECT grp, sort_array(collect_set(b)) FROM cs_src_small GROUP BY grp ORDER BY grp
+
+query
+SELECT grp, sort_array(collect_set(s)) FROM cs_src_small GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Int / BigInt (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_intbig(i int, bi bigint, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_intbig VALUES
+  (10, 1000000000000, 'a'), (20, 2000000000000, 'a'),
+  (10, 1000000000000, 'a'), (NULL, NULL, 'a'),
+  (30, 3000000000000, 'b'), (30, NULL, 'b')
+
+query
+SELECT grp, sort_array(collect_set(i)) FROM cs_src_intbig GROUP BY grp ORDER BY grp
+
+query
+SELECT grp, sort_array(collect_set(bi)) FROM cs_src_intbig GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Float (with NULLs, NaN, Inf, -Inf, +0, -0)
+-- Comet deduplicates NaN while Spark does not; with
+-- strictFloatingPoint=true collect_set falls back to Spark.
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_float(v float, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_float VALUES
+  (1.5, 'a'), (2.5, 'a'), (1.5, 'a'), (NULL, 'a'),
+  (CAST('NaN' AS FLOAT), 'b'), (CAST('NaN' AS FLOAT), 'b'), (1.0, 'b'),
+  (CAST('Infinity' AS FLOAT), 'c'), (CAST('-Infinity' AS FLOAT), 'c'), (CAST('Infinity' AS FLOAT), 'c'),
+  (CAST(0.0 AS FLOAT), 'd'), (CAST(-0.0 AS FLOAT), 'd'), (1.0, 'd'), (NULL, 'd')
+
+query expect_fallback(not fully compatible with Spark)
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_float GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Double (with NULLs, NaN, Inf, -Inf, +0, -0)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_double(v double, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_double VALUES
+  (1.1, 'a'), (2.2, 'a'), (1.1, 'a'), (NULL, 'a'),
+  (CAST('NaN' AS DOUBLE), 'b'), (CAST('NaN' AS DOUBLE), 'b'), (1.0, 'b'),
+  (CAST('Infinity' AS DOUBLE), 'c'), (CAST('-Infinity' AS DOUBLE), 'c'), (CAST('Infinity' AS DOUBLE), 'c'),
+  (0.0, 'd'), (-0.0, 'd'), (1.0, 'd'), (NULL, 'd')
+
+query expect_fallback(not fully compatible with Spark)
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_double GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- String (with NULLs and empty string)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_string(v string, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_string VALUES
+  ('hello', 'a'), ('world', 'a'), ('hello', 'a'), (NULL, 'a'),
+  ('', 'b'), ('x', 'b'), ('', 'b'), (NULL, 'b')
+
+query
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_string GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Binary (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_binary(v binary, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_binary VALUES
+  (X'CAFE', 'a'), (X'BABE', 'a'), (X'CAFE', 'a'), (NULL, 'a'),
+  (X'', 'b'), (X'FF', 'b'), (NULL, 'b')
+
+query
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_binary GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Decimal (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_decimal(v decimal(10,2), grp string) USING parquet
+
+statement
+INSERT INTO cs_src_decimal VALUES
+  (1.50, 'a'), (2.50, 'a'), (1.50, 'a'), (NULL, 'a'),
+  (0.00, 'b'), (99999999.99, 'b'), (NULL, 'b')
+
+query
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_decimal GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Date (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_date(v date, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_date VALUES
+  (DATE '2024-01-01', 'a'), (DATE '2024-06-15', 'a'), (DATE '2024-01-01', 'a'), (NULL, 'a'),
+  (DATE '1970-01-01', 'b'), (NULL, 'b')
+
+query
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_date GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Timestamp (with NULLs)
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_ts(v timestamp, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_ts VALUES
+  (TIMESTAMP '2024-01-01 00:00:00', 'a'), (TIMESTAMP '2024-06-15 12:30:00', 'a'),
+  (TIMESTAMP '2024-01-01 00:00:00', 'a'), (NULL, 'a'),
+  (TIMESTAMP '1970-01-01 00:00:00', 'b'), (NULL, 'b')
+
+query
+SELECT grp, sort_array(collect_set(v)) FROM cs_src_ts GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Mixed with other aggregates
+-- ============================================================
+
+query
+SELECT grp, sort_array(collect_set(i)), count(*), sum(i)
+FROM cs_src_int GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- Multiple collect_set in the same query
+-- ============================================================
+
+statement
+CREATE TABLE cs_src_multi(a int, b string, grp string) USING parquet
+
+statement
+INSERT INTO cs_src_multi VALUES
+  (1, 'x', 'g1'), (2, 'y', 'g1'), (1, 'x', 'g1'),
+  (3, 'z', 'g2'), (NULL, NULL, 'g2')
+
+query
+SELECT grp, sort_array(collect_set(a)), sort_array(collect_set(b))
+FROM cs_src_multi GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- DISTINCT: semantically redundant but exercises a different
+-- planner path (distinct aggregate handling)
+-- ============================================================
+
+query
+SELECT grp, sort_array(collect_set(DISTINCT i)) FROM cs_src_int GROUP BY grp ORDER BY grp
+
+-- ============================================================
+-- HAVING clause with collect_set
+-- ============================================================
+
+query
+SELECT grp, sort_array(collect_set(i))
+FROM cs_src_int GROUP BY grp HAVING size(collect_set(i)) > 1 ORDER BY grp
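Two hedged sketches follow; neither is part of the patch. The first shows how the shared `SupportLevel.containsFloatingPoint` helper behaves on nested types. The helper body is copied verbatim from the patch into a standalone object so the snippet runs on its own; the assertions follow directly from its match arms.

```scala
import org.apache.spark.sql.types._

object ContainsFloatingPointDemo extends App {
  // Copied from SupportLevel.containsFloatingPoint in the patch above.
  def containsFloatingPoint(dt: DataType): Boolean = dt match {
    case FloatType | DoubleType => true
    case ArrayType(elementType, _) => containsFloatingPoint(elementType)
    case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
    case MapType(keyType, valueType, _) =>
      containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
    case _ => false
  }

  assert(!containsFloatingPoint(IntegerType))
  assert(containsFloatingPoint(DoubleType))
  assert(containsFloatingPoint(ArrayType(FloatType)))            // array<float>
  assert(containsFloatingPoint(MapType(StringType, DoubleType))) // map<string, double>
  assert(containsFloatingPoint(                                  // struct<xs: array<float>>
    StructType(Seq(StructField("xs", ArrayType(FloatType))))))
  assert(!containsFloatingPoint(StructType(Seq(StructField("s", StringType)))))
}
```

The second illustrates the Partial-mode schema correction that `adjustOutputForNativeState` performs, for a hypothetical plan `SELECT grp, collect_set(i) FROM t GROUP BY grp` (the table, column, and buffer names here are made up for the illustration). The column layout assumed, grouping columns first and then each function's buffer columns, restates the doc comment in the patch.

```scala
import org.apache.spark.sql.types._

object NativeStateSchemaDemo extends App {
  // What the Spark-side Partial-mode ObjectHashAggregateExec declares:
  // CollectSet is a TypedImperativeAggregate, so its buffer column is binary.
  val sparkDeclared = StructType(Seq(
    StructField("grp", StringType),
    StructField("buf", BinaryType)))

  // What the native Comet aggregate actually emits for the same plan:
  // the real state type, an array of the input element type.
  val nativeActual = StructType(Seq(
    StructField("grp", StringType),
    StructField("buf", ArrayType(IntegerType, containsNull = true))))

  // adjustOutputForNativeState rewrites the declared buffer column so the
  // shuffle exchange schema matches what the native operator produces.
  println(s"declared: ${sparkDeclared.simpleString}")
  println(s"native:   ${nativeActual.simpleString}")
}
```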