Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,24 @@ public Predicate convert(Filter filter, boolean ignoreFailure) {
}
}

private boolean isNaN(Object value) {
    // Only the boxed floating-point types can carry a NaN payload;
    // any other runtime type (including null) is trivially not NaN.
    if (value instanceof Float) {
        return ((Float) value).isNaN();
    }
    if (value instanceof Double) {
        return ((Double) value).isNaN();
    }
    return false;
}

public Predicate convert(Filter filter) {
if (filter instanceof EqualTo) {
EqualTo eq = (EqualTo) filter;
// TODO deal with isNaN
int index = fieldIndex(eq.attribute());
Object literal = convertLiteral(index, eq.value());
if (isNaN(literal)) {
// NaN != NaN, so equality with NaN should never match
return PredicateBuilder.alwaysFalse();
}
return builder.equal(index, literal);
} else if (filter instanceof EqualNullSafe) {
EqualNullSafe eq = (EqualNullSafe) filter;
Expand All @@ -116,21 +128,37 @@ public Predicate convert(Filter filter) {
GreaterThan gt = (GreaterThan) filter;
int index = fieldIndex(gt.attribute());
Object literal = convertLiteral(index, gt.value());
if (isNaN(literal)) {
// Any comparison with NaN is false
return PredicateBuilder.alwaysFalse();
}
return builder.greaterThan(index, literal);
} else if (filter instanceof GreaterThanOrEqual) {
GreaterThanOrEqual gt = (GreaterThanOrEqual) filter;
int index = fieldIndex(gt.attribute());
Object literal = convertLiteral(index, gt.value());
if (isNaN(literal)) {
// Any comparison with NaN is false
return PredicateBuilder.alwaysFalse();
}
return builder.greaterOrEqual(index, literal);
} else if (filter instanceof LessThan) {
LessThan lt = (LessThan) filter;
int index = fieldIndex(lt.attribute());
Object literal = convertLiteral(index, lt.value());
if (isNaN(literal)) {
// Any comparison with NaN is false
return PredicateBuilder.alwaysFalse();
}
return builder.lessThan(index, literal);
} else if (filter instanceof LessThanOrEqual) {
LessThanOrEqual lt = (LessThanOrEqual) filter;
int index = fieldIndex(lt.attribute());
Object literal = convertLiteral(index, lt.value());
if (isNaN(literal)) {
// Any comparison with NaN is false
return PredicateBuilder.alwaysFalse();
}
return builder.lessOrEqual(index, literal);
} else if (filter instanceof In) {
In in = (In) filter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,25 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
}
}

/** Returns true iff `value` is a Float or Double holding NaN; false for every other type. */
private def isNaN(value: Any): Boolean = {
  value match {
    // Delegate to the JDK checks; behaviorally identical to f.isNaN / d.isNaN.
    case f: Float => java.lang.Float.isNaN(f)
    case d: Double => java.lang.Double.isNaN(d)
    case _ => false
  }
}

private def convert(sparkPredicate: SparkPredicate): Predicate = {
sparkPredicate.name() match {
case EQUAL_TO =>
sparkPredicate match {
case BinaryPredicate(transform, literal) =>
// TODO deal with isNaN
builder.equal(transform, literal)
if (isNaN(literal)) {
// NaN != NaN, so equality with NaN should never match
PredicateBuilder.alwaysFalse()
} else {
builder.equal(transform, literal)
}
case _ =>
throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
}
Expand All @@ -70,31 +82,51 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
case GREATER_THAN =>
sparkPredicate match {
case BinaryPredicate(transform, literal) =>
builder.greaterThan(transform, literal)
if (isNaN(literal)) {
// Any comparison with NaN is false
PredicateBuilder.alwaysFalse()
} else {
builder.greaterThan(transform, literal)
}
case _ =>
throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
}

case GREATER_THAN_OR_EQUAL =>
sparkPredicate match {
case BinaryPredicate((transform, literal)) =>
builder.greaterOrEqual(transform, literal)
if (isNaN(literal)) {
// Any comparison with NaN is false
PredicateBuilder.alwaysFalse()
} else {
builder.greaterOrEqual(transform, literal)
}
case _ =>
throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
}

case LESS_THAN =>
sparkPredicate match {
case BinaryPredicate(transform, literal) =>
builder.lessThan(transform, literal)
if (isNaN(literal)) {
// Any comparison with NaN is false
PredicateBuilder.alwaysFalse()
} else {
builder.lessThan(transform, literal)
}
case _ =>
throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
}

case LESS_THAN_OR_EQUAL =>
sparkPredicate match {
case BinaryPredicate(transform, literal) =>
builder.lessOrEqual(transform, literal)
if (isNaN(literal)) {
// Any comparison with NaN is false
PredicateBuilder.alwaysFalse()
} else {
builder.lessOrEqual(transform, literal)
}
case _ =>
throw new UnsupportedOperationException(s"Convert $sparkPredicate is unsupported.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
checkAnswer(sql(s"SELECT float_col from test_tbl WHERE $filter"), Seq(Row(1.0f)))
assert(scanFilesCount(filter) == 1)

// Test NaN handling - equality with NaN should return AlwaysFalse
val nanFilter = "float_col = CAST('NaN' AS FLOAT)"
val nanPredicate = converter.convert(v2Filter(nanFilter)).get
assert(
nanPredicate.equals(PredicateBuilder.alwaysFalse()),
"NaN equality should return AlwaysFalse")

filter = "double_col = 1.0"
actual = converter.convert(v2Filter(filter)).get
assert(actual.equals(builder.equal(6, 1.0d)))
Expand Down Expand Up @@ -507,6 +514,85 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
assert(filesScanned == 4, s"Expected 4 files but scanned $filesScanned files")
}

test("V2Filter: EqualTo with NaN should return AlwaysFalse") {
  // Equality against NaN is unsatisfiable (NaN != NaN), so the converter
  // should collapse the predicate to AlwaysFalse for both float and double.
  val floatNanFilter = "float_col = CAST('NaN' AS FLOAT)"
  val floatNanPredicate = converter.convert(v2Filter(floatNanFilter)).get
  assert(floatNanPredicate.equals(PredicateBuilder.alwaysFalse()))

  // AlwaysFalse must prune every file during planning.
  val floatFilesScanned = scanFilesWithPredicate(floatNanPredicate)
  assert(
    floatFilesScanned == 0,
    s"Expected 0 files for NaN equality but scanned $floatFilesScanned files")

  // Same contract for the double column.
  val doubleNanFilter = "double_col = CAST('NaN' AS DOUBLE)"
  val doubleNanPredicate = converter.convert(v2Filter(doubleNanFilter)).get
  assert(doubleNanPredicate.equals(PredicateBuilder.alwaysFalse()))

  val doubleFilesScanned = scanFilesWithPredicate(doubleNanPredicate)
  assert(
    doubleFilesScanned == 0,
    s"Expected 0 files for NaN equality but scanned $doubleFilesScanned files")
}

test("V2Filter: GreaterThan with NaN should return AlwaysFalse") {
  // Ordered comparison against NaN is always false, so `> NaN` must
  // convert to AlwaysFalse and prune all files for both column types.
  val floatNanFilter = "float_col > CAST('NaN' AS FLOAT)"
  val floatNanPredicate = converter.convert(v2Filter(floatNanFilter)).get
  assert(floatNanPredicate.equals(PredicateBuilder.alwaysFalse()))

  val floatFilesScanned = scanFilesWithPredicate(floatNanPredicate)
  assert(
    floatFilesScanned == 0,
    s"Expected 0 files for NaN comparison but scanned $floatFilesScanned files")

  val doubleNanFilter = "double_col > CAST('NaN' AS DOUBLE)"
  val doubleNanPredicate = converter.convert(v2Filter(doubleNanFilter)).get
  assert(doubleNanPredicate.equals(PredicateBuilder.alwaysFalse()))

  val doubleFilesScanned = scanFilesWithPredicate(doubleNanPredicate)
  assert(
    doubleFilesScanned == 0,
    s"Expected 0 files for NaN comparison but scanned $doubleFilesScanned files")
}

test("V2Filter: LessThanOrEqual with NaN should return AlwaysFalse") {
  // `<= NaN` can never hold, so the converter must emit AlwaysFalse
  // and planning must skip every file.
  val nanFilter = "float_col <= CAST('NaN' AS FLOAT)"
  val nanPredicate = converter.convert(v2Filter(nanFilter)).get
  assert(nanPredicate.equals(PredicateBuilder.alwaysFalse()))

  val filesScanned = scanFilesWithPredicate(nanPredicate)
  assert(
    filesScanned == 0,
    s"Expected 0 files for NaN comparison but scanned $filesScanned files")
}

test("V2Filter: float and double normal operations not affected by NaN handling") {
  // Ordinary (non-NaN) float/double predicates must still convert to the
  // regular leaf predicates and return the expected rows.
  val eqFilter = "float_col = 1.0"
  val eqPredicate = converter.convert(v2Filter(eqFilter)).get
  assert(eqPredicate.equals(builder.equal(5, 1.0f)))
  checkAnswer(sql(s"SELECT float_col FROM test_tbl WHERE $eqFilter"), Seq(Row(1.0f)))

  val gtFilter = "double_col > 2.0"
  val gtPredicate = converter.convert(v2Filter(gtFilter)).get
  assert(gtPredicate.equals(builder.greaterThan(6, 2.0d)))
  checkAnswer(
    sql(s"SELECT double_col FROM test_tbl WHERE $gtFilter ORDER BY double_col"),
    Seq(Row(3.0d), Row(4.0d)))

  val leFilter = "float_col <= 3.0"
  val lePredicate = converter.convert(v2Filter(leFilter)).get
  assert(lePredicate.equals(builder.lessOrEqual(5, 3.0f)))
  checkAnswer(
    sql(s"SELECT float_col FROM test_tbl WHERE $leFilter ORDER BY float_col"),
    Seq(Row(1.0f), Row(2.0f), Row(3.0f)))
}

private def v2Filter(str: String, tableName: String = "test_tbl"): SparkPredicate = {
val condition = sql(s"SELECT * FROM $tableName WHERE $str").queryExecution.optimizedPlan
.collectFirst { case f: Filter => f }
Expand Down