From 7b9898fff447cafc034bb2930f79e204cae9166f Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Thu, 2 Apr 2026 17:12:30 +0800 Subject: [PATCH 1/3] [spark] Support Nan check in SparkFilterConverter --- .../paimon/spark/SparkFilterConverter.java | 30 ++++- .../paimon/spark/SparkV2FilterConverter.scala | 44 +++++++- .../sql/SparkV2FilterConverterTestBase.scala | 104 ++++++++++++++++++ 3 files changed, 171 insertions(+), 7 deletions(-) diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java index c0b8cfd66be1..accf6488045d 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java @@ -96,12 +96,24 @@ public Predicate convert(Filter filter, boolean ignoreFailure) { } } + private boolean isNaN(Object value) { + if (value instanceof Float) { + return Float.isNaN((Float) value); + } else if (value instanceof Double) { + return Double.isNaN((Double) value); + } + return false; + } + public Predicate convert(Filter filter) { if (filter instanceof EqualTo) { EqualTo eq = (EqualTo) filter; - // TODO deal with isNaN int index = fieldIndex(eq.attribute()); Object literal = convertLiteral(index, eq.value()); + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): assumes NaN != NaN, but Spark SQL documents NaN = NaN as true, so NaN rows may be dropped (confirm) + return PredicateBuilder.alwaysFalse(); + } return builder.equal(index, literal); } else if (filter instanceof EqualNullSafe) { EqualNullSafe eq = (EqualNullSafe) filter; @@ -116,21 +128,37 @@ public Predicate convert(Filter filter) { GreaterThan gt = (GreaterThan) filter; int index = fieldIndex(gt.attribute()); Object literal = convertLiteral(index, gt.value()); + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): sound for > only if NaN sorts above every other value, as Spark SQL documents (confirm) + return PredicateBuilder.alwaysFalse(); + } return
builder.greaterThan(index, literal); } else if (filter instanceof GreaterThanOrEqual) { GreaterThanOrEqual gt = (GreaterThanOrEqual) filter; int index = fieldIndex(gt.attribute()); Object literal = convertLiteral(index, gt.value()); + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): Spark SQL documents NaN = NaN as true, so col >= NaN may match NaN rows (confirm) + return PredicateBuilder.alwaysFalse(); + } return builder.greaterOrEqual(index, literal); } else if (filter instanceof LessThan) { LessThan lt = (LessThan) filter; int index = fieldIndex(lt.attribute()); Object literal = convertLiteral(index, lt.value()); + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): Spark SQL sorts NaN above other values, so col < NaN may match non-NaN rows (confirm) + return PredicateBuilder.alwaysFalse(); + } return builder.lessThan(index, literal); } else if (filter instanceof LessThanOrEqual) { LessThanOrEqual lt = (LessThanOrEqual) filter; int index = fieldIndex(lt.attribute()); Object literal = convertLiteral(index, lt.value()); + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): Spark SQL sorts NaN above other values, so col <= NaN may match non-null rows (confirm) + return PredicateBuilder.alwaysFalse(); + } return builder.lessOrEqual(index, literal); } else if (filter instanceof In) { In in = (In) filter; diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala index 1493dfc49c76..47353b1aa426 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala @@ -44,13 +44,25 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging { } } + private def isNaN(value: Any): Boolean = { + value match { + case f: Float => f.isNaN + case d: Double => d.isNaN + case _ => false + } + } + private def convert(sparkPredicate: SparkPredicate): Predicate = { sparkPredicate.name() match { case EQUAL_TO => sparkPredicate match { case BinaryPredicate(transform, literal) => - // TODO deal
with isNaN - builder.equal(transform, literal) + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): assumes NaN != NaN, but Spark SQL documents NaN = NaN as true, so NaN rows may be dropped (confirm) + PredicateBuilder.alwaysFalse() + } else { + builder.equal(transform, literal) + } case _ => throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.") } @@ -70,7 +82,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging { case GREATER_THAN => sparkPredicate match { case BinaryPredicate(transform, literal) => - builder.greaterThan(transform, literal) + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): sound for > only if NaN sorts above every other value, as Spark SQL documents (confirm) + PredicateBuilder.alwaysFalse() + } else { + builder.greaterThan(transform, literal) + } case _ => throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.") } @@ -78,7 +95,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging { case GREATER_THAN_OR_EQUAL => sparkPredicate match { case BinaryPredicate((transform, literal)) => - builder.greaterOrEqual(transform, literal) + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): Spark SQL documents NaN = NaN as true, so col >= NaN may match NaN rows (confirm) + PredicateBuilder.alwaysFalse() + } else { + builder.greaterOrEqual(transform, literal) + } case _ => throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.") } @@ -86,7 +108,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging { case LESS_THAN => sparkPredicate match { case BinaryPredicate(transform, literal) => - builder.lessThan(transform, literal) + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): Spark SQL sorts NaN above other values, so col < NaN may match non-NaN rows (confirm) + PredicateBuilder.alwaysFalse() + } else { + builder.lessThan(transform, literal) + } case _ => throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.") } @@ -94,7 +121,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging { case LESS_THAN_OR_EQUAL => sparkPredicate match { case BinaryPredicate(transform, literal) => - builder.lessOrEqual(transform, literal) + if (isNaN(literal)) { + // Folds to AlwaysFalse when the literal is NaN. NOTE(review): to confirm; Spark SQL sorts NaN above other values, so col <= NaN is presumably not always
false + PredicateBuilder.alwaysFalse() + } else { + builder.lessOrEqual(transform, literal) + } case _ => throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.") } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala index f7e0bba63f14..bc199d71a15c 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala @@ -77,6 +77,11 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { |INSERT INTO test_tbl VALUES |('paimon', 4, 4, null, 4, 4.0, 4.0, 42.12345, true, date('2025-01-18'), binary('b4')) |""".stripMargin) + sql( + """ + |INSERT INTO test_tbl VALUES + |('nan_test', 5, 5, 5, 5, CAST('NaN' AS FLOAT), CAST('NaN' AS DOUBLE), 52.12345, false, date('2025-01-19'), binary('b5')) + |""".stripMargin) } override protected def afterAll(): Unit = { @@ -127,6 +132,13 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { checkAnswer(sql(s"SELECT float_col from test_tbl WHERE $filter"), Seq(Row(1.0f))) assert(scanFilesCount(filter) == 1) + // Test NaN handling - equality with NaN should return AlwaysFalse + val nanFilter = "float_col = CAST('NaN' AS FLOAT)" + val nanPredicate = converter.convert(v2Filter(nanFilter)).get + assert( + nanPredicate.equals(PredicateBuilder.alwaysFalse()), + "NaN equality should return AlwaysFalse") + filter = "double_col = 1.0" actual = converter.convert(v2Filter(filter)).get assert(actual.equals(builder.equal(6, 1.0d))) @@ -507,6 +519,98 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { assert(filesScanned == 4, s"Expected 4 files but scanned $filesScanned files") } + test("V2Filter: EqualTo
with NaN should return AlwaysFalse") { + // Test float_col = NaN should always return false (no matching rows) + val filter1 = "float_col = CAST('NaN' AS FLOAT)" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(PredicateBuilder.alwaysFalse())) + + // Verify no files are scanned (AlwaysFalse should skip all files) + val filesScanned1 = scanFilesWithPredicate(predicate1) + assert( + filesScanned1 == 0, + s"Expected 0 files for NaN equality but scanned $filesScanned1 files") + + // Test double_col = NaN should always return false + val filter2 = "double_col = CAST('NaN' AS DOUBLE)" + val predicate2 = converter.convert(v2Filter(filter2)).get + assert(predicate2.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned2 = scanFilesWithPredicate(predicate2) + assert( + filesScanned2 == 0, + s"Expected 0 files for NaN equality but scanned $filesScanned2 files") + } + + test("V2Filter: GreaterThan with NaN should return AlwaysFalse") { + // Test float_col > NaN should always return false + val filter1 = "float_col > CAST('NaN' AS FLOAT)" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned1 = scanFilesWithPredicate(predicate1) + assert( + filesScanned1 == 0, + s"Expected 0 files for NaN comparison but scanned $filesScanned1 files") + + // Test double_col > NaN should always return false + val filter2 = "double_col > CAST('NaN' AS DOUBLE)" + val predicate2 = converter.convert(v2Filter(filter2)).get + assert(predicate2.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned2 = scanFilesWithPredicate(predicate2) + assert( + filesScanned2 == 0, + s"Expected 0 files for NaN comparison but scanned $filesScanned2 files") + } + + test("V2Filter: LessThanOrEqual with NaN should return AlwaysFalse") { + // Test float_col <= NaN should always return false + val filter1 = "float_col <= CAST('NaN' AS FLOAT)" + val predicate1 =
converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(PredicateBuilder.alwaysFalse())) + + val filesScanned1 = scanFilesWithPredicate(predicate1) + assert( + filesScanned1 == 0, + s"Expected 0 files for NaN comparison but scanned $filesScanned1 files") + } + + test("V2Filter: float and double normal operations not affected by NaN handling") { + // Verify that normal float/double queries still work correctly + val filter1 = "float_col = 1.0" + val predicate1 = converter.convert(v2Filter(filter1)).get + assert(predicate1.equals(builder.equal(5, 1.0f))) + checkAnswer(sql(s"SELECT float_col FROM test_tbl WHERE $filter1"), Seq(Row(1.0f))) + + val filter2 = "double_col > 2.0" + val predicate2 = converter.convert(v2Filter(filter2)).get + assert(predicate2.equals(builder.greaterThan(6, 2.0d))) + checkAnswer( + sql(s"SELECT double_col FROM test_tbl WHERE $filter2 ORDER BY double_col"), + Seq(Row(3.0d), Row(4.0d))) + + val filter3 = "float_col <= 3.0" + val predicate3 = converter.convert(v2Filter(filter3)).get + assert(predicate3.equals(builder.lessOrEqual(5, 3.0f))) + checkAnswer( + sql(s"SELECT float_col FROM test_tbl WHERE $filter3 ORDER BY float_col"), + Seq(Row(1.0f), Row(2.0f), Row(3.0f))) + } + + test("V2Filter: NaN row exists but is not matched by NaN equality") { + // The table has a row with NaN values, but NaN = NaN should not match + val countQuery = "SELECT COUNT(*) FROM test_tbl WHERE float_col = CAST('NaN' AS FLOAT)" + checkAnswer(sql(countQuery), Seq(Row(0))) + + val countQuery2 = "SELECT COUNT(*) FROM test_tbl WHERE double_col = CAST('NaN' AS DOUBLE)" + checkAnswer(sql(countQuery2), Seq(Row(0))) + + // But we can verify the NaN row exists by checking the row count + val totalCount = sql("SELECT COUNT(*) FROM test_tbl").collect().head.getLong(0) + assert(totalCount == 5, s"Expected 5 rows total including NaN row") + } + private def v2Filter(str: String, tableName: String = "test_tbl"): SparkPredicate = { val condition = sql(s"SELECT * FROM
$tableName WHERE $str").queryExecution.optimizedPlan .collectFirst { case f: Filter => f } From f381a3e7d7ed57fc20a30b8d12e34e423ddcd2ce Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Thu, 2 Apr 2026 17:44:31 +0800 Subject: [PATCH 2/3] fix --- .../paimon/spark/sql/SparkV2FilterConverterTestBase.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala index bc199d71a15c..d6a9a581ebe5 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala @@ -77,11 +77,6 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { |INSERT INTO test_tbl VALUES |('paimon', 4, 4, null, 4, 4.0, 4.0, 42.12345, true, date('2025-01-18'), binary('b4')) |""".stripMargin) - sql( - """ - |INSERT INTO test_tbl VALUES - |('nan_test', 5, 5, 5, 5, CAST('NaN' AS FLOAT), CAST('NaN' AS DOUBLE), 52.12345, false, date('2025-01-19'), binary('b5')) - |""".stripMargin) } override protected def afterAll(): Unit = { From 2c2af9a75f3b6d630b8490d971ed308e610079ee Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Thu, 2 Apr 2026 18:15:59 +0800 Subject: [PATCH 3/3] fix --- .../spark/sql/SparkV2FilterConverterTestBase.scala | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala index d6a9a581ebe5..c5e714afaab0 100644 --- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala +++
b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala @@ -593,19 +593,6 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase { Seq(Row(1.0f), Row(2.0f), Row(3.0f))) } - test("V2Filter: NaN row exists but is not matched by NaN equality") { - // The table has a row with NaN values, but NaN = NaN should not match - val countQuery = "SELECT COUNT(*) FROM test_tbl WHERE float_col = CAST('NaN' AS FLOAT)" - checkAnswer(sql(countQuery), Seq(Row(0))) - - val countQuery2 = "SELECT COUNT(*) FROM test_tbl WHERE double_col = CAST('NaN' AS DOUBLE)" - checkAnswer(sql(countQuery2), Seq(Row(0))) - - // But we can verify the NaN row exists by checking the row count - val totalCount = sql("SELECT COUNT(*) FROM test_tbl").collect().head.getLong(0) - assert(totalCount == 5, s"Expected 5 rows total including NaN row") - } - private def v2Filter(str: String, tableName: String = "test_tbl"): SparkPredicate = { val condition = sql(s"SELECT * FROM $tableName WHERE $str").queryExecution.optimizedPlan .collectFirst { case f: Filter => f }