Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,25 @@ struct JoinParameters {
pub join_type: DFJoinType,
}

/// If `expr` evaluates to `Timestamp(_, Some(_))` against `schema`, wrap it in a
/// metadata-only cast to `Timestamp(_, None)`. This is required because
/// DataFusion's `SortMergeJoinExec` comparator only supports timezone-less
/// timestamp types, while Spark's `TimestampType` serializes as
/// `Timestamp(µs, "UTC")`. The cast preserves ordering on the same time unit.
fn strip_timestamp_tz(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
    // Only a zoned timestamp needs rewriting; every other type passes through untouched.
    if let DataType::Timestamp(unit, Some(_)) = expr.data_type(schema)? {
        let tz_less = DataType::Timestamp(unit, None);
        Ok(Arc::new(CastExpr::new(expr, tz_less, None)))
    } else {
        Ok(expr)
    }
}

#[derive(Default)]
pub struct BinaryExprOptions {
pub is_integral_div: bool,
Expand Down Expand Up @@ -1630,10 +1649,23 @@ impl PhysicalPlanner {
let left = Arc::clone(&join_params.left.native_plan);
let right = Arc::clone(&join_params.right.native_plan);

let left_schema = left.schema();
let right_schema = right.schema();
let join_on = join_params
.join_on
.into_iter()
.map(|(l, r)| {
Ok((
strip_timestamp_tz(l, left_schema.as_ref())?,
strip_timestamp_tz(r, right_schema.as_ref())?,
))
})
.collect::<Result<Vec<_>, ExecutionError>>()?;

let join = Arc::new(SortMergeJoinExec::try_new(
Arc::clone(&left),
Arc::clone(&right),
join_params.join_on,
join_on,
join_params.join_filter,
join_params.join_type,
sort_options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType}
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.io.ChunkedByteBuffer
Expand Down Expand Up @@ -2094,7 +2094,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
true
case TimestampNTZType => true
case TimestampNTZType | _: TimestampType => true
case _ => false
}

Expand Down
114 changes: 106 additions & 8 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.scalatest.Tag
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
Expand Down Expand Up @@ -54,21 +54,119 @@ class CometJoinSuite extends CometTestBase {
.toSeq)
}

test("SortMergeJoin with unsupported key type should fall back to Spark") {
// Verifies that an equi-join keyed on TIMESTAMP columns executes as a native
// Comet sort-merge join (rather than falling back to Spark). The non-UTC
// session timezone exercises the tz-stripping cast on the native side.
test("SortMergeJoin with TimestampType key runs natively") {
  withSQLConf(
    SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
    withTable("t1", "t2") {
      // One shared key ('2019-01-01 11:11:11') plus one unmatched row per side.
      sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t1 VALUES " +
          "('a', timestamp'2019-01-01 11:11:11'), " +
          "('b', timestamp'2020-05-05 05:05:05')")

      sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t2 VALUES " +
          "('a', timestamp'2019-01-01 11:11:11'), " +
          "('c', timestamp'2021-07-07 07:07:07')")

      checkSparkAnswerAndOperator(
        sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
        Seq(classOf[CometSortMergeJoinExec]))
    }
  }
}

// Verifies native sort-merge join on a TIMESTAMP key for all three outer-join
// flavors; t2 matches exactly one of t1's three keys, so every flavor yields
// unmatched rows on at least one side.
test("SortMergeJoin with TimestampType key supports outer joins") {
  withSQLConf(
    SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
    withTable("t1", "t2") {
      sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t1 VALUES " +
          "(1, timestamp'2019-01-01 11:11:11'), " +
          "(2, timestamp'2020-05-05 05:05:05'), " +
          "(3, timestamp'2021-07-07 07:07:07')")

      sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t2 VALUES " +
          "(10, timestamp'2019-01-01 11:11:11'), " +
          "(20, timestamp'2022-02-02 02:02:02')")

      Seq("LEFT OUTER", "RIGHT OUTER", "FULL OUTER").foreach { joinType =>
        checkSparkAnswerAndOperator(
          sql(s"SELECT * FROM t1 $joinType JOIN t2 ON t1.time = t2.time"),
          Seq(classOf[CometSortMergeJoinExec]))
      }
    }
  }
}

// Verifies native sort-merge join when the join key is composite: a STRING
// column combined with a TIMESTAMP column, including rows that match on only
// one of the two key parts.
// NOTE(review): the three statements that previously preceded this test
// (`val df = ...` / `checkSparkAnswer(df)` / canonicalized-plan assert) were
// orphaned remnants of the deleted fallback test and sat outside any method;
// they have been removed.
test("SortMergeJoin with composite (string, timestamp) key runs natively") {
  withSQLConf(
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
    withTable("t1", "t2") {
      sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t1 VALUES " +
          "('a', timestamp'2019-01-01 11:11:11'), " +
          "('b', timestamp'2019-01-01 11:11:11'), " +
          "('a', timestamp'2020-05-05 05:05:05')")

      sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t2 VALUES " +
          "('a', timestamp'2019-01-01 11:11:11'), " +
          "('b', timestamp'2020-05-05 05:05:05'), " +
          "('a', timestamp'2020-05-05 05:05:05')")

      checkSparkAnswerAndOperator(
        sql(
          "SELECT * FROM t1 JOIN t2 " +
            "ON t1.name = t2.name AND t1.time = t2.time"),
        Seq(classOf[CometSortMergeJoinExec]))
    }
  }
}

// Verifies NULL-key semantics for a native sort-merge join on a TIMESTAMP key:
// NULL keys never match each other in an inner join, and outer joins must
// still surface the NULL-keyed rows as unmatched.
test("SortMergeJoin with nullable TimestampType key runs natively") {
  withSQLConf(
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
    withTable("t1", "t2") {
      sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t1 VALUES " +
          "(1, timestamp'2019-01-01 11:11:11'), " +
          "(2, CAST(NULL AS TIMESTAMP)), " +
          "(3, timestamp'2020-05-05 05:05:05')")

      sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
      sql(
        "INSERT OVERWRITE t2 VALUES " +
          "(10, timestamp'2019-01-01 11:11:11'), " +
          "(20, CAST(NULL AS TIMESTAMP)), " +
          "(30, timestamp'2022-02-02 02:02:02')")

      // Both queries must match Spark's answer and run as a native SMJ.
      def checkNative(query: String): Unit =
        checkSparkAnswerAndOperator(sql(query), Seq(classOf[CometSortMergeJoinExec]))

      // Inner join: NULL = NULL must not match in Spark semantics.
      checkNative("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")

      // Full outer join: NULL-keyed rows from both sides surface as unmatched.
      checkNative("SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.time = t2.time")
    }
  }
}
Expand Down
Loading