diff --git a/core/src/main/java/org/apache/calcite/rel/core/Intersect.java b/core/src/main/java/org/apache/calcite/rel/core/Intersect.java index e6ea2f6c242c..2de728a74279 100644 --- a/core/src/main/java/org/apache/calcite/rel/core/Intersect.java +++ b/core/src/main/java/org/apache/calcite/rel/core/Intersect.java @@ -22,7 +22,10 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.util.Util; import java.util.Collections; import java.util.List; @@ -79,4 +82,12 @@ protected Intersect(RelInput input) { dRows *= 0.25; return dRows; } + + @Override protected RelDataType deriveRowType() { + // An output column is nullable only if it is nullable in ALL of the inputs. + return ReturnTypes.refineNullabilityForIntersect( + getCluster().getTypeFactory(), + deriveLeastRestrictiveRowType(), + Util.transform(getInputs(), RelNode::getRowType)); + } } diff --git a/core/src/main/java/org/apache/calcite/rel/core/Minus.java b/core/src/main/java/org/apache/calcite/rel/core/Minus.java index 3a68945acc64..c27af940b78f 100644 --- a/core/src/main/java/org/apache/calcite/rel/core/Minus.java +++ b/core/src/main/java/org/apache/calcite/rel/core/Minus.java @@ -23,7 +23,9 @@ import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.ReturnTypes; import java.util.Collections; import java.util.List; @@ -60,4 +62,12 @@ protected Minus(RelInput input) { @Override public double estimateRowCount(RelMetadataQuery mq) { return RelMdUtil.getMinusRowCount(mq, this); } + + @Override protected RelDataType deriveRowType() { + // The nullability of 
the output columns is the same as that of the primary (first) input. + return ReturnTypes.refineNullabilityForExcept( + getCluster().getTypeFactory(), + deriveLeastRestrictiveRowType(), + getInput(0).getRowType()); + } } diff --git a/core/src/main/java/org/apache/calcite/rel/core/SetOp.java b/core/src/main/java/org/apache/calcite/rel/core/SetOp.java index 04e1409e78aa..9af6dd2ca91f 100644 --- a/core/src/main/java/org/apache/calcite/rel/core/SetOp.java +++ b/core/src/main/java/org/apache/calcite/rel/core/SetOp.java @@ -114,6 +114,10 @@ public abstract SetOp copy( } @Override protected RelDataType deriveRowType() { + return deriveLeastRestrictiveRowType(); + } + + protected RelDataType deriveLeastRestrictiveRowType() { final List inputRowTypes = Util.transform(inputs, RelNode::getRowType); final RelDataType rowType = diff --git a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java index a8d2713e0980..394f8e088551 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java @@ -140,6 +140,10 @@ public void onMatchAggregateOnUnion(RelOptRuleCall call) { // Project all but the last added field (e.g. 
count_i{n}) relBuilder.project(skipLast(relBuilder.fields(), branchCount)); + + // ensure the nullabilities of columns in the new relation match those of the intersect output + relBuilder.convert(intersect.getRowType(), false); + call.transformTo(relBuilder.build()); } @@ -206,6 +210,9 @@ public void onMatchAggregatePushdown(RelOptRuleCall call) { // Project all but the last field relBuilder.project(Util.skipLast(relBuilder.fields())); + // ensure the nullabilities of columns in the new relation match those of the intersect output + relBuilder.convert(intersect.getRowType(), false); + // the schema for intersect distinct matches that of the relation, // built here with an extra last column for the count, // which is projected out by the final project we added diff --git a/core/src/main/java/org/apache/calcite/rel/rules/MinusToDistinctRule.java b/core/src/main/java/org/apache/calcite/rel/rules/MinusToDistinctRule.java index 7b14f2e94613..59eaf0b60c25 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/MinusToDistinctRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/MinusToDistinctRule.java @@ -159,6 +159,10 @@ public MinusToDistinctRule(Class minusClass, relBuilder.filter(filters.build()); relBuilder.project(Util.first(relBuilder.fields(), originalFieldCnt)); + + // ensure the nullabilities of columns in the new relation match those of the minus output + relBuilder.convert(minus.getRowType(), false); + call.transformTo(relBuilder.build()); } diff --git a/core/src/main/java/org/apache/calcite/sql/SqlSetOperator.java b/core/src/main/java/org/apache/calcite/sql/SqlSetOperator.java index 6d0f48c4f417..0ce69276f71d 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlSetOperator.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlSetOperator.java @@ -24,6 +24,8 @@ import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.checkerframework.checker.nullness.qual.Nullable; + /** * 
SqlSetOperator represents a relational set theory operator (UNION, INTERSECT, * MINUS). These are binary operators, but with an extra boolean attribute @@ -59,7 +61,7 @@ public SqlSetOperator( int prec, boolean all, SqlReturnTypeInference returnTypeInference, - SqlOperandTypeInference operandTypeInference, + @Nullable SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) { super( name, diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java index 2a26a78929cf..156f73132a32 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java @@ -121,16 +121,20 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable { new SqlSetOperator("UNION ALL", SqlKind.UNION, 12, true); public static final SqlSetOperator EXCEPT = - new SqlSetOperator("EXCEPT", SqlKind.EXCEPT, 12, false); + new SqlSetOperator("EXCEPT", SqlKind.EXCEPT, 12, false, + ReturnTypes.LEAST_RESTRICTIVE_EXCEPT, null, OperandTypes.SET_OP); public static final SqlSetOperator EXCEPT_ALL = - new SqlSetOperator("EXCEPT ALL", SqlKind.EXCEPT, 12, true); + new SqlSetOperator("EXCEPT ALL", SqlKind.EXCEPT, 12, true, + ReturnTypes.LEAST_RESTRICTIVE_EXCEPT, null, OperandTypes.SET_OP); public static final SqlSetOperator INTERSECT = - new SqlSetOperator("INTERSECT", SqlKind.INTERSECT, 14, false); + new SqlSetOperator("INTERSECT", SqlKind.INTERSECT, 14, false, + ReturnTypes.LEAST_RESTRICTIVE_INTERSECT, null, OperandTypes.SET_OP); public static final SqlSetOperator INTERSECT_ALL = - new SqlSetOperator("INTERSECT ALL", SqlKind.INTERSECT, 14, true); + new SqlSetOperator("INTERSECT ALL", SqlKind.INTERSECT, 14, true, + ReturnTypes.LEAST_RESTRICTIVE_INTERSECT, null, OperandTypes.SET_OP); /** * The {@code MULTISET UNION DISTINCT} operator. 
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index 27f1a6fc72a7..7c8680e64336 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -608,6 +608,81 @@ public static SqlCall stripSeparator(SqlCall call) { .leastRestrictive(opBinding.collectOperandTypes()); } + /** + * Refines the nullability of {@code base} using INTERSECT semantics: a + * column is nullable only if it is nullable in every one of the + * {@code inputTypes} (so NOT NULL in any input yields a NOT NULL column). + */ + public static RelDataType refineNullabilityForIntersect( + RelDataTypeFactory typeFactory, + RelDataType base, + List inputTypes) { + final RelDataTypeFactory.Builder builder = + new RelDataTypeFactory.Builder(typeFactory); + final List outputFields = base.getFieldList(); + for (int i = 0; i < outputFields.size(); i++) { + boolean nullable = true; + for (RelDataType inputType : inputTypes) { + nullable &= inputType.getFieldList().get(i).getType().isNullable(); + } + builder.add(outputFields.get(i)).nullable(nullable); + } + return builder.build(); + } + + /** + * Type-inference strategy for INTERSECT. Computes the least restrictive row + * type across all inputs, then refines nullability: a column is NOT NULL if + * it is NOT NULL in at least one input (nullable only if nullable in all). 
+ */ + public static final SqlReturnTypeInference LEAST_RESTRICTIVE_INTERSECT = + andThen(SqlTypeTransforms.FROM_MEASURE_IF::apply, opBinding -> { + final List inputTypes = opBinding.collectOperandTypes(); + final RelDataType base = + opBinding.getTypeFactory().leastRestrictive(inputTypes); + if (base == null) { + return null; + } + return refineNullabilityForIntersect(opBinding.getTypeFactory(), base, inputTypes); + }); + + /** + * Refines the nullability of {@code base} using EXCEPT/MINUS semantics: a + * column's nullability matches that of the first (primary) input. + */ + public static RelDataType refineNullabilityForExcept( + RelDataTypeFactory typeFactory, + RelDataType base, + RelDataType primaryInputType) { + final RelDataTypeFactory.Builder builder = + new RelDataTypeFactory.Builder(typeFactory); + final List outputFields = base.getFieldList(); + final List primaryInputFields = + primaryInputType.getFieldList(); + for (int i = 0; i < outputFields.size(); i++) { + builder.add(outputFields.get(i)) + .nullable(primaryInputFields.get(i).getType().isNullable()); + } + return builder.build(); + } + + /** + * Type-inference strategy for EXCEPT/MINUS. Computes the least restrictive + * row type across all inputs, then refines nullability: a column's + * nullability matches that of the first (primary) input. + */ + public static final SqlReturnTypeInference LEAST_RESTRICTIVE_EXCEPT = + andThen(SqlTypeTransforms.FROM_MEASURE_IF::apply, opBinding -> { + final List inputTypes = opBinding.collectOperandTypes(); + final RelDataType base = + opBinding.getTypeFactory().leastRestrictive(inputTypes); + if (base == null) { + return null; + } + return refineNullabilityForExcept( + opBinding.getTypeFactory(), base, inputTypes.get(0)); + }); + /** * Type-inference strategy for NVL2 function. It returns the least restrictive type * between the second and third operands. 
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index b54e08e26d4f..a09d8f2c3d84 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -131,6 +131,7 @@ import static org.apache.calcite.test.Matchers.hasExpandedTree; import static org.apache.calcite.test.Matchers.hasFieldNames; import static org.apache.calcite.test.Matchers.hasHints; +import static org.apache.calcite.test.Matchers.hasRelDataType; import static org.apache.calcite.test.Matchers.hasTree; import static org.hamcrest.CoreMatchers.allOf; @@ -2346,6 +2347,72 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) { assertThat(root, hasTree(expected)); } + /** Test case for + * [CALCITE-6451] + * Improve Nullability Derivation for Intersect and Minus. This test covers UNION. */ + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testUnionTypeDerivation(boolean all) { + final RelBuilder builder = RelBuilder.create(config().build()); + + RelDataType input1RowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(false) + .add("c", SqlTypeName.BIGINT) + .nullable(true) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + + RelDataType input2RowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(false) + .add("c", SqlTypeName.BIGINT) + .nullable(false) + .add("d", SqlTypeName.BIGINT) + .nullable(false) + .build(); + + RelDataType input3RowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(true) + .add("c", SqlTypeName.BIGINT) + .nullable(false) + .add("d", SqlTypeName.BIGINT) + 
.nullable(true) + .build(); + + RelNode root = + builder + .values(input1RowType) + .values(input2RowType) + .values(input3RowType) + .union(all, 3) + .build(); + + RelDataType expectedRowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(true) + .add("c", SqlTypeName.BIGINT) + .nullable(true) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + assertThat(root.getRowType(), hasRelDataType(expectedRowType)); + } + /** Test case for * [CALCITE-1522] * Fix error message for SetOp with incompatible args. */ @@ -2550,6 +2617,69 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) { assertThat(root, hasTree(expected)); } + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testIntersectTypeDerivation(boolean all) { + final RelBuilder builder = RelBuilder.create(config().build()); + + RelDataType input1RowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(false) + .add("c", SqlTypeName.BIGINT) + .nullable(true) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + + RelDataType input2RowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(true) + .add("b", SqlTypeName.BIGINT) + .nullable(true) + .add("c", SqlTypeName.BIGINT) + .nullable(true) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + + RelDataType input3RowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(true) + .add("c", SqlTypeName.BIGINT) + .nullable(false) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + + RelNode root = + builder + .values(input1RowType) + .values(input2RowType) + .values(input3RowType) + .intersect(all, 3) + .build(); + + 
RelDataType expectedRowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(false) + .add("c", SqlTypeName.BIGINT) + .nullable(false) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + assertThat(root.getRowType(), hasRelDataType(expectedRowType)); + } + @Test void testExcept() { // Equivalent SQL: // SELECT empno FROM emp @@ -2577,6 +2707,55 @@ private static RelNode groupIdRel(RelBuilder builder, boolean extra) { assertThat(root, hasTree(expected)); } + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void testExceptTypeDerivation(boolean all) { + final RelBuilder builder = RelBuilder.create(config().build()); + + RelDataType primaryRowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(false) + .add("c", SqlTypeName.BIGINT) + .nullable(true) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + + RelDataType secondaryRowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(true) + .add("c", SqlTypeName.BIGINT) + .nullable(false) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + + RelNode root = + builder.values(primaryRowType) + .values(secondaryRowType) + .minus(all) + .build(); + + RelDataType expectedRowType = + new RelDataTypeFactory.Builder(builder.getTypeFactory()) + .add("a", SqlTypeName.BIGINT) + .nullable(false) + .add("b", SqlTypeName.BIGINT) + .nullable(false) + .add("c", SqlTypeName.BIGINT) + .nullable(true) + .add("d", SqlTypeName.BIGINT) + .nullable(true) + .build(); + assertThat(root.getRowType(), hasRelDataType(expectedRowType)); + } + /** Tests building a simple join. Also checks {@link RelBuilder#size()} * at every step. 
*/ @Test void testJoin() { diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index ece2ad5258aa..0043fedfd081 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -7302,15 +7302,14 @@ LogicalIntersect(all=[false]) LogicalAggregate(group=[{0}]) LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi]) LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi]) - LogicalProject(ENAME=[CAST($0):VARCHAR]) + LogicalProject(ENAME=[CAST($0):VARCHAR NOT NULL]) LogicalProject(ENAME=[$1]) LogicalFilter(condition=[=($7, 10)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalProject(ENAME=[CAST($0):VARCHAR]) - LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL]) - LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))]) - LogicalTableScan(table=[[CATALOG, SALES, EMP]]) - LogicalProject(ENAME=[CAST($0):VARCHAR]) + LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL]) + LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(ENAME=[CAST($0):VARCHAR NOT NULL]) LogicalProject(ENAME=[$1]) LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]]) ]]> @@ -10436,10 +10435,10 @@ LogicalMinus(all=[false]) hasFieldNames(String fieldNames) { } }; } + + /** + * Creates a Matcher that matches a {@link RelDataType} by comparing its + * {@link RelDataType#getFullTypeString()} with that of {@code relDataType} (string equality). + */ + public static Matcher hasRelDataType(RelDataType relDataType) { + return compose( + IsEqual.equalTo(relDataType.getFullTypeString()), + RelDataType::getFullTypeString); + } + /** * Creates a Matcher that matches a {@link RelNode} if its string * representation, after converting Windows-style line endings ("\r\n")