diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index ac35925ace..d907611322 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -149,6 +149,25 @@ struct JoinParameters {
     pub join_type: DFJoinType,
 }
+
+/// If `expr` evaluates to `Timestamp(_, Some(_))` against `schema`, wrap it in a
+/// metadata-only cast to `Timestamp(_, None)`. This is required because
+/// DataFusion's `SortMergeJoinExec` comparator only supports timezone-less
+/// timestamp types, while Spark's `TimestampType` serializes as
+/// `Timestamp(µs, "UTC")`. The cast preserves ordering on the same time unit.
+fn strip_timestamp_tz(
+    expr: Arc<dyn PhysicalExpr>,
+    schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+    match expr.data_type(schema)? {
+        DataType::Timestamp(unit, Some(_)) => Ok(Arc::new(CastExpr::new(
+            expr,
+            DataType::Timestamp(unit, None),
+            None,
+        ))),
+        _ => Ok(expr),
+    }
+}
 #[derive(Default)]
 pub struct BinaryExprOptions {
     pub is_integral_div: bool,
 }
@@ -1630,10 +1649,23 @@ impl PhysicalPlanner {
         let left = Arc::clone(&join_params.left.native_plan);
         let right = Arc::clone(&join_params.right.native_plan);
 
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+        let join_on = join_params
+            .join_on
+            .into_iter()
+            .map(|(l, r)| {
+                Ok((
+                    strip_timestamp_tz(l, left_schema.as_ref())?,
+                    strip_timestamp_tz(r, right_schema.as_ref())?,
+                ))
+            })
+            .collect::<Result<Vec<_>, ExecutionError>>()?;
+
         let join = Arc::new(SortMergeJoinExec::try_new(
             Arc::clone(&left),
             Arc::clone(&right),
-            join_params.join_on,
+            join_on,
             join_params.join_filter,
             join_params.join_type,
             sort_options,
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index a9c065d726..109aa3f44f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType}
+import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.io.ChunkedByteBuffer
@@ -2094,7 +2094,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
       case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
           _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
         true
-      case TimestampNTZType => true
+      case TimestampNTZType | _: TimestampType => true
       case _ => false
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 49fbe10c30..b30733181e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Tag
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometSortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
@@ -54,21 +54,119 @@ class CometJoinSuite extends CometTestBase {
       .toSeq)
   }
 
-  test("SortMergeJoin with unsupported key type should fall back to Spark") {
+  test("SortMergeJoin with TimestampType key runs natively") {
     withSQLConf(
       SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
       SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
       withTable("t1", "t2") {
         sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
-        sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")
+        sql(
+          "INSERT OVERWRITE t1 VALUES " +
+            "('a', timestamp'2019-01-01 11:11:11'), " +
+            "('b', timestamp'2020-05-05 05:05:05')")
 
         sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
-        sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")
+        sql(
+          "INSERT OVERWRITE t2 VALUES " +
+            "('a', timestamp'2019-01-01 11:11:11'), " +
+            "('c', timestamp'2021-07-07 07:07:07')")
+
+        checkSparkAnswerAndOperator(
+          sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
+          Seq(classOf[CometSortMergeJoinExec]))
+      }
+    }
+  }
+
+  test("SortMergeJoin with TimestampType key supports outer joins") {
+    withSQLConf(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
+      withTable("t1", "t2") {
+        sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
+        sql(
+          "INSERT OVERWRITE t1 VALUES " +
+            "(1, timestamp'2019-01-01 11:11:11'), " +
+            "(2, timestamp'2020-05-05 05:05:05'), " +
+            "(3, timestamp'2021-07-07 07:07:07')")
+
+        sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
+        sql(
+          "INSERT OVERWRITE t2 VALUES " +
+            "(10, timestamp'2019-01-01 11:11:11'), " +
+            "(20, timestamp'2022-02-02 02:02:02')")
+
+        for (joinType <- Seq("LEFT OUTER", "RIGHT OUTER", "FULL OUTER")) {
+          checkSparkAnswerAndOperator(
+            sql(s"SELECT * FROM t1 $joinType JOIN t2 ON t1.time = t2.time"),
+            Seq(classOf[CometSortMergeJoinExec]))
+        }
+      }
+    }
+  }
 
-        val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
-        val (sparkPlan, cometPlan) = checkSparkAnswer(df)
-        assert(sparkPlan.canonicalized === cometPlan.canonicalized)
+  test("SortMergeJoin with composite (string, timestamp) key runs natively") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
+      withTable("t1", "t2") {
+        sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
+        sql(
+          "INSERT OVERWRITE t1 VALUES " +
+            "('a', timestamp'2019-01-01 11:11:11'), " +
+            "('b', timestamp'2019-01-01 11:11:11'), " +
+            "('a', timestamp'2020-05-05 05:05:05')")
+
+        sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
+        sql(
+          "INSERT OVERWRITE t2 VALUES " +
+            "('a', timestamp'2019-01-01 11:11:11'), " +
+            "('b', timestamp'2020-05-05 05:05:05'), " +
+            "('a', timestamp'2020-05-05 05:05:05')")
+
+        checkSparkAnswerAndOperator(
+          sql(
+            "SELECT * FROM t1 JOIN t2 " +
+              "ON t1.name = t2.name AND t1.time = t2.time"),
+          Seq(classOf[CometSortMergeJoinExec]))
+      }
+    }
+  }
+
+  test("SortMergeJoin with nullable TimestampType key runs natively") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
+      withTable("t1", "t2") {
+        sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
+        sql(
+          "INSERT OVERWRITE t1 VALUES " +
+            "(1, timestamp'2019-01-01 11:11:11'), " +
+            "(2, CAST(NULL AS TIMESTAMP)), " +
+            "(3, timestamp'2020-05-05 05:05:05')")
+
+        sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
+        sql(
+          "INSERT OVERWRITE t2 VALUES " +
+            "(10, timestamp'2019-01-01 11:11:11'), " +
+            "(20, CAST(NULL AS TIMESTAMP)), " +
+            "(30, timestamp'2022-02-02 02:02:02')")
+
+        // Inner join: NULL = NULL must not match in Spark semantics.
+        checkSparkAnswerAndOperator(
+          sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
+          Seq(classOf[CometSortMergeJoinExec]))
+
+        // Full outer join: NULL-keyed rows from both sides surface as unmatched.
+        checkSparkAnswerAndOperator(
+          sql("SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.time = t2.time"),
+          Seq(classOf[CometSortMergeJoinExec]))
       }
     }
   }