diff --git a/Cargo.lock b/Cargo.lock index f500265108ff5..8377a263cd0cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2386,6 +2386,7 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", "itertools 0.14.0", "log", diff --git a/DOC.md b/DOC.md new file mode 100644 index 0000000000000..10a0ab4e19407 --- /dev/null +++ b/DOC.md @@ -0,0 +1,1166 @@ +This PR adds support for lambdas with column capture and the `array_transform` function used to test the lambda implementation. Example usage: + +```sql +CREATE TABLE t as SELECT 2 as n; + +SELECT array_transform([2, 3], v -> v != t.n) from t; + +[false, true] + +-- arbitrally nested lambdas are also supported +SELECT array_transform([[[2, 3]]], m -> array_transform(m, l -> array_transform(l, v -> v*2))); + +[[[4, 6]]] +``` + +Some comments on code snippets of this doc show what value each struct, variant or field would hold after planning the first example above. Some literals are simplified pseudo code + +3 new `Expr` variants are added, `LambdaFunction`, owing a new trait `LambdaUDF`, which is like a `ScalarFunction`/`ScalarUDFImpl` with support for lambdas, `Lambda`, for the lambda body and it's parameters names, and `LambdaVariable`, which is like `Column` but for lambdas parameters. The reasoning why not using `Column` instead is later on this doc. + +Their logical representations: + +```rust +enum Expr { + LambdaFunction(LambdaFunction), // array_transform([2, 3], v -> v != t.n) + Lambda(Lambda), // v -> v != t.n + LambdaVariable(LambdaVariable), // v, of the lambda body: v != t.n + ... 
+} + +// array_transform([2, 3], v -> v != t.n) +struct LambdaFunction { + pub func: Arc, // global instance of array_transform + pub args: Vec, // [Expr::ScalarValue([2, 3]), Expr::Lambda(v -> v != n)] +} + +// v -> v != t.n +struct Lambda { + pub params: Vec, // ["v"] + pub body: Box, // v != n +} + +// v, of the lambda body: v != t.n +struct LambdaVariable { + pub name: String, // "v" + pub field: Option, // Some(Field::new("", DataType::Int32, false)) + pub spans: Spans, +} + +``` + +The example would be planned into a tree like this: + +``` +LambdaFunctionExpression + name: array_transform + children: + 1. ListExpression [2,3] + 2. LambdaExpression + parameters: ["v"] + body: + ComparisonExpression (!=) + left: + LambdaVariableExpression("v", Some(Field::new("", Int32, false))) + right: + ColumnExpression("t.n") +``` + +The physical counterparts definition: + +```rust + +struct LambdaFunctionExpr { + fun: Arc, // global instance of array_transform + name: String, // "array_transform" + args: Vec>, // [LiteralExpr([2, 3], LambdaExpr("v -> v != t.n"))] + return_field: FieldRef, // Field::new("", DataType::new_list(DataType::Boolean, false), false) + config_options: Arc, +} + + +struct LambdaExpr { + params: Vec, // ["v"] + body: Arc, // v -> v != t.n +} + +struct LambdaVariable { + name: String, // "v", of the lambda body: v != t.n + field: FieldRef, // Field::new("", DataType::Int32, false) + value: Option, // reasoning later on +} +``` + +Note: For those who primarly wants to check if this lambda implementation supports their usecase and don't want to spend much time here, it's okay to skip most collapsed blocks, as those serve mostly to help code reviewers, with the exception of `LambdaUDF` and the `array_transform` implementation of `LambdaUDF` relevant methods, collapsed due to their size + +
Physical planning implementation is trivial: + +```rust +fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result> { + let input_schema = input_dfschema.as_arrow(); + + match e { + ... + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options: ... // irrelevant + )?)) + } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => lambda_variable( + name, + Arc::clone(field), + ), + } +} +``` + +
+
+ +The added `LambdaUDF` trait is almost a clone of `ScalarUDFImpl`, with the exception of: +1. `return_field_from_args` and `invoke_with_args`, where now `args.args` is a list of enums with two variants: `Value` or `Lambda` instead of a list of values +2. the addition of `lambdas_parameters`, which returns a `Field` for each parameter supported for every lambda argument, based on the `Field` of the non-lambda arguments +3. the removal of `return_field` and the deprecated ones `is_nullable` and `display_name`. + +
LambdaUDF + +```rust + +trait LambdaUDF { + /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: + /// + /// If it's a value, return None + /// If it's a lambda, return the list of all parameters that that lambda supports + /// based on the Field of the non-lambda arguments + /// + /// Example for array_transform: + /// + /// `array_transform([2, 8], v -> v > 4)` + /// + /// let lambdas_parameters = array_transform.lambdas_parameters(&[ + /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false)))]), // the Field associated with the literal `[2, 8]` + /// ValueOrLambdaParameter::Lambda, // A lambda + /// ]?; + /// + /// assert_eq!( + /// lambdas_parameters, + /// vec![ + /// None, // it's a value, return None + /// // it's a lambda, return it's supported parameters, regardless of how many are actually used + /// Some(vec![ + /// Field::new("", DataType::Int32, false), // the value being transformed, + /// Field::new("", DataType::Int32, false), // the 1-based index being transformed, not used on the example above, but implementations doesn't need to care about it + /// ]) + /// ] + /// ) + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>>; + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + // ... 
omitted methods that are similar in ScalarUDFImpl +} + +pub enum ValueOrLambdaParameter { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda + Lambda, +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`LambdaUDF::return_field_from_args`] for more information +pub struct LambdaReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[ + // ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), + // ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)) + // ]` + pub arg_fields: &'a [ValueOrLambdaField], + /// Is argument `i` to the function a scalar (constant)? + /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], +} + +/// A tagged FieldRef indicating whether it correspond the field of a value or the field of the output of a lambda argument +pub enum ValueOrLambdaField { + /// The FieldRef of a ColumnarValue argument + Value(FieldRef), + /// The return FieldRef of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters + Lambda(FieldRef), +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. 
+pub struct LambdaFunctionArgs { + /// The evaluated arguments to the function + pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the lambda function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +/// A lambda argument to a LambdaFunction +pub struct LambdaFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, + /// but that's implementation detail and should not be relied upon + /// + /// For example, for `array_transform([2], v -> v + t.a + t.b)`, + /// this will be a `RecordBatch` with at least two columns, `t.a` and `t.b` + pub captures: Option, +} + +// An argument to a LambdaUDF +pub enum ValueOrLambda { + Value(ColumnarValue), + Lambda(LambdaFunctionLambdaArg), +} +``` + + +
+ +
array_transform lambdas_parameters implementation + +```rust +impl LambdaUDF for ArrayTransform { + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + // list is the field of [2, 3]: Field::new("", DataType::new_list(DataType::Int32, false), false) + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + // the field of [2, 3] inner values: Field::new("", DataType::Int32, false) + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return exec_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } +} +``` + +
+ +
array_transform return_field_from_args implementation + +```rust +impl LambdaUDF for ArrayTransform { + fn return_field_from_args( + &self, + args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result> { + // [ + // Field::new("", DataType::new_list(DataType::Int32, false), false), + // Field::new("", DataType::Boolean, false), + // ] + let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = + take_function_args(self.name(), args.arg_fields)? + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + // lambda is the return_field of the lambda body + // when evaluated with the parameters from lambdas_parameters + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + lambda.data_type().clone(), + lambda.is_nullable(), + )); + + let return_type = match list.data_type() { + DataType::List(_) => DataType::List(field), + DataType::LargeList(_) => DataType::LargeList(field), + DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), + other => plan_err!("expected list, got {other}"), + }; + + Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) + } +} +``` + +
+ +
array_transform invoke_with_args implementation + + +```rust +impl LambdaUDF for ArrayTransform { + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { + let [list_value, lambda] = take_function_args(self.name(), &args.args)?; + + // list = [2, 3] + // lambda = LambdaFunctionLambdaArg { + // params: vec![Field::new("v", DataType::Int32, false)], + // body: PhysicalExpr("v != t.n"),// the physical expression of the lambda *body*, and not the lambda itself: this is not a LambdaExpr. + // captures: Some(record_batch!("t.n", Int32, [2])) + // } + let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = + (list_value, lambda) + else { + return exec_err!( + "{} expects a value followed by a lambda, got {} and {}", + self.name(), + list_value, + lambda, + ); + }; + + let list_array = list_value.to_array(args.number_rows)?; + let list_values = match list_array.data_type() { + DataType::List(_) => list_array.as_list::().values(), + DataType::LargeList(_) => list_array.as_list::().values(), + DataType::FixedSizeList(_, _) => list_array.as_fixed_size_list().values(), + other => exec_err!("expected list, got {other}") + } + + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of list with mulitple values and removing values of empty lists + // list_indices is not cheap so is important to avoid it when no column is captured + let adjusted_captures = lambda + .captures + .as_ref() + //list_indices return the row_number for each sublist element: [[1, 2], [3], [4]] => [0,0,1,2], not included here + .map(|captures| take_record_batch(captures, &list_indices(&list_array)?)) + .transpose()? 
+ .unwrap_or_else(|| { + RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(list_values.len())), + ) + .unwrap() + }); + + // by using closures, bind_lambda_variables can evaluate only the needed ones avoiding unnecessary computations + let values_param = || Ok(Arc::clone(list_values)); + //elements_indices return the index of each element within its sublist: [[5, 3], [7, 1, 1]] => [1, 2, 1, 2, 3], not included here + let indices_param = || elements_indices(&list_array); + + let binded_body = bind_lambda_variables( + Arc::clone(&lambda.body), + &lambda.params, + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch + let transformed_values = binded_body + .evaluate(&adjusted_captures)? + .into_array(list_values.len())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } +} +``` + +
+ +
How relevant LambdaUDF methods would be called and what they would return during planning and evaluation of the example + + +```rust +// this is called at sql planning +let lambdas_parameters = lambda_udf.lambdas_parameters(&[ + ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), // the Field of the [2, 3] literal + ValueOrLambdaParameter::Lambda, // A unspecified lambda. On the example, v -> v != t.n +])?; + +assert_eq!( + lambdas_parameters, + vec![ + // the [2, 3] argument, not a lambda so no parameters + None, + // the parameters that *can* be declared on the lambda, and not only + // those actually declared: the implementation doesn't need to care + // about it + Some(vec![ + Field::new("", DataType::Int32, false), // the list inner value + Field::new("", DataType::Int32, false), // the 1-based index of the element being transformed + ])] +); + + + +// this is called every time ExprSchemable is called on a LambdaFunction +let return_field = array_transform.return_field_from_args(&LambdaReturnFieldArgs { + arg_fields: &[ + ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), + ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)), // the return_field of the expression "v != t.n" when "v" is of the type returned in lambdas_parameters + ], + scalar_arguments // irrelevant +})?; + +assert_eq!(return_field, Field::new("", DataType::new_list(DataType::Boolean, false), false)); + + + +let value = array_transform.evaluate(&LambdaFunctionArgs { + args: vec![ + ValueOrLambda::Value(List([2, 3])), + ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + params: vec![Field::new("v", DataType::Int32, false)], + body: PhysicalExpr("v != t.n"),// the physical expression of the lambda *body*, and not the lambda itself: this is not a LambdaExpr. 
+ captures: Some(record_batch!("t.n", Int32, [2])) + }), + ], + arg_fields, // same as above + number_rows: 1, + return_field, // same as above + config_options, // irrelevant +})?; + +assert_eq!(value, BooleanArray::from([false, true])) +``` + +
+
+
+ +A pair LambdaUDF/LambdaUDFImpl like ScalarFunction was not used because those exist only [to maintain backwards compatibility with the older API](https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html#api-note) #8045 + +LambdaFunction invocation: + +Instead of evaluating all its arguments as ScalarFunction does, LambdaFunction does the following: + +1. If it's a non-lambda argument, evaluate as usual, and provide the resulting `ColumnarValue` to `LambdaUDF::evaluate` as a `ValueOrLambda::Value` +2. If it's a lambda, construct a `LambdaFunctionLambdaArg` containing the lambda body physical expression and a record batch containing any captured columns as a `ValueOrLambda::Lambda` and provide it to `LambdaUDF::evaluate`. To avoid costly copies of uncaptured columns, we swap them with a `NullArray` while keeping the number of columns on the batch the same, so captured column indices are kept stable across the whole tree. The recent #18329 instead projects-out uncaptured columns and rewrites the expr adjusting column indexes. If that is preferable we can generalize that implementation and use it here too. + +
LambdaFunction evalution + +```rust + +impl PhysicalExpr for LambdaFunctionExpr { + fn evaluate(&self, batch: &RecordBatch) -> Result { + let args = self.args + .map(|arg| { + match arg.as_any().downcast_ref::() { + Some(lambda) => { + // helper method that returns the indices of the captured columns. In the example, the only column available (index 0) is captured, so this would be HashSet(0) + let captures = lambda.captures(); + + let captures = if !captures.is_empty() { + let (fields, columns): (Vec<_>, _) = std::iter::zip( + batch.schema_ref().fields(), + batch.columns(), + ) + .enumerate() + .map(|(column_index, (field, column))| { + if captures.contains(&column_index) { + (Arc::clone(field), Arc::clone(column)) + } else { + ( + Arc::new(Field::new( + field.name(), + DataType::Null, + false, + )), + Arc::new(NullArray::new(column.len())) as _, + ) + } + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + + Some(RecordBatch::try_new(schema, columns)?) + } else { + None + }; + + Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + params, // irrelevant, + body: Arc::clone(lambda.body()), // use the lambda body and not the lambda itself + captures, + })) + } + None => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + } + }) + .collect::>>()?; + + // evaluate the function + let output = self.fun.invoke_with_args(LambdaFunctionArgs { + args, + arg_fields, // irrelevant + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })?; + + Ok(output) + } +} + +``` + +
+
+ +Why `LambdaVariable` and not `Column`: + +Existing tree traversals that operate on columns would break if some column nodes referred to a lambda parameter and not a real column. In the example query, projection pushdown would try to push the lambda parameter "v", which won't exist in table "t". + +Example code of another traversal that would break: + +```rust +fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { + let mut used_columns = HashSet::new(); + expr.apply(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + // if this is a lambda column, this function will break + used_columns.insert(col.index()); + } + Ok(TreeNodeRecursion::Continue) + }); + ... +} +``` + +Furthermore, the implementation of `ExprSchemable` and `PhysicalExpr::return_field` for `Column` expects that the schema it receives as an argument contains an entry for its name, which is not the case for lambda parameters. + +By including a `FieldRef` on `LambdaVariable` that should be resolved either at construction time, as in the sql planner, or later by an `AnalyzerRule`, `ExprSchemable` and `PhysicalExpr::return_field` simply return their own Field: + +
LambdaVariable ExprSchemable and PhysicalExpr::return_field implementation + +```rust +impl ExprSchemable for Expr { + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + let field = match self { + Expr::LambdaVariable(l) => Ok(Arc::clone(&l.field.ok_or_else(|| plan_err!("Unresolved LambdaVariable {}", l.name)))), + ... + }?; + + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) + } + ... +} + +impl PhysicalExpr for LambdaVariable { + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + ... +} +``` + +
+
+ +For reference, [Spark](https://github.com/apache/spark/blob/8b68a172d34d2ed9bd0a2deefcae1840a78143b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L77) and [Substrait](https://substrait.io/expressions/lambda_expressions/#parameter-references) also use a specialized node instead of a regular column + +There are also discussions on making every expr own its type: #18845, #12604 + +
Possible fixes discarded due to complexity, requiring downstream changes and implementation size: + +1. Add a new set of TreeNode methods that provides the set of lambdas parameters names seen during the traversal, so column nodes can be tested if they refer to a regular column or to a lambda parameter. Any downstream user that wants to support lambdas would need use those methods instead of the existing ones. This also would add 1k+ lines to the PR. + +```rust +impl Expr { + pub fn transform_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> {} +} +``` + +How minimize_join_filter would looks like: + + +```rust +fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { + let mut used_columns = HashSet::new(); + expr.apply_with_lambdas_params(|expr, lambdas_params| { + if let Some(col) = expr.as_any().downcast_ref::() { + // dont include lambdas parameters + if !lambdas_params.contains(col.name()) { + used_columns.insert(col.index()); + } + } + Ok(TreeNodeRecursion::Continue) + }) + ... +} +``` + +2. Add a flag to the Column node indicating if it refers to a lambda parameter. Still requires checking for it on existing tree traversals that works on Columns (30+) and also downstream. + +```rust +//logical +struct Column { + pub relation: Option, + pub name: String, + pub spans: Spans, + pub is_lambda_parameter: bool, +} + +//physical +struct Column { + name: String, + index: usize, + is_lambda_parameter: bool, +} +``` + + +How minimize_join_filter would look like: + +```rust +fn minimize_join_filter(expr: Arc, ...) -> JoinFilter { + let mut used_columns = HashSet::new(); + expr.apply(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + // dont include lambdas parameters + if !col.is_lambda_parameter { + used_columns.insert(col.index()); + } + } + Ok(TreeNodeRecursion::Continue) + }) + ... +} +``` + + +1. 
Add a new set of TreeNode methods that provides a schema that includes the lambdas parameters for the scope of the node being visited/transformed: + +```rust +impl Expr { + pub fn transform_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + f: F, + ) -> Result> { ... } + ... other methods +} +``` + +For any given LambdaFunction found during the traversal, a new schema is created for each lambda argument that contains it's parameter, returned from LambdaUDF::lambdas_parameters +How it would look like: + +```rust + +pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { + let mut has_placeholder = false; + // Provide the schema as the first argument. + // Transforming closure receive an adjusted_schema as argument + self.transform_with_schema(schema, |mut expr, adjusted_schema| { + match &mut expr { + // Default to assuming the arguments are the same type + Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { + // use adjusted_schema and not schema. Those expressions may contain + // columns referring to a lambda parameter, which Field would only be + // available in adjusted_schema and not in schema + rewrite_placeholder(left.as_mut(), right.as_ref(), adjusted_schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), adjusted_schema)?; + } + .... + +``` + +2. Make available trought LogicalPlan and ExecutionPlan nodes a schema that includes all lambdas parameters from all expressions owned by the node, and use this schema for tree traversals. 
For nodes which won't own any expression, the regular schema can be returned + + +```rust +impl LogicalPlan { + fn lambda_extended_schema(&self) -> &DFSchema; +} + +trait ExecutionPlan { + fn lambda_extended_schema(&self) -> &DFSchema; +} + +//usage +impl LogicalPlan { + pub fn replace_params_with_values( + self, + param_values: &ParamValues, + ) -> Result { + self.transform_up_with_subqueries(|plan| { + // use plan.lambda_extended_schema() containing lambdas parameters + // instead of plan.schema() which wont + let lambda_extended_schema = Arc::clone(plan.lambda_extended_schema()); + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|e| { + // if this expression is child of lambda and contain columns referring it's parameters + // the lambda_extended_schema already contain them + let (e, has_placeholder) = e.infer_placeholder_types(&lambda_extended_schema)?; + .... + +``` +
+
+ +`LambdaVariable` evaluation, current implementation: + +The physical `LambdaVariable` contains an optional `ColumnarValue` that must be bound for each batch before evaluation with the helper function `bind_lambda_variables`, which rewrites the whole lambda body, binding any variable of the tree. + +
LambdaVariable::evaluate + +```rust +impl PhysicalExpr for LambdaVariable { + fn evaluate(&self, _batch: &RecordBatch) -> Result { + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} unbinded value", self.name)) + } +} +``` + +
+
+ +Unbound: +``` +LambdaExpression + parameters: ["v"] + body: + ComparisonExpression(!=) + left: + LambdaVariableExpression("v", Field::new("", Int32, false), None) + right: + ColumnExpression("n") +``` + +After binding: + +``` +LambdaExpression + parameters: ["v"] + body: + ComparisonExpression(!=) + left: + LambdaVariableExpression("v", Field::new("", Int32, false), Some([2, 3])) + right: + ColumnExpression("n") +``` + +Alternative: + +Make the `LambdaVariable` evaluate its value from the batch passed to `PhysicalExpr::evaluate`, as a regular column does. For that, instead of binding the body, the `LambdaUDF` implementation would merge the captured batch of a lambda with the values of its parameters. So that this happens via an index, as with a regular column, the schema used to plan the physical `LambdaVariable` must contain the lambda parameters. This would be the only place during planning that a schema would contain those parameters. Otherwise it can only get the value from the batch via name instead of index + +1. Add an index to LambdaVariable, similar to Column, and remove the optional value. + +```rust +struct LambdaVariable { + name: String, // "v", of the lambda body: v != t.n + field: FieldRef, // Field::new("", DataType::Int32, false) + index: usize, // 1 +} +``` + +2. Insert the lambda parameters only into the Schema used to do the physical planning, to compute the index of a LambdaVariable + +
how physical planning would look like + +```rust +fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result> { + let input_schema = input_dfschema.as_arrow(); + + match e { + ... + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let args_metadata = args.iter() + .map(|arg| if arg.is::() { + Ok(ValueOrLambdaParameter::Lambda) + } else { + Ok(ValueOrLambdaParameter::Value(arg.to_field(input_dfschema)?)) + }) + .collect()?; + + let lambdas_parameters = func.lambdas_parameters(&args_metadata)?; + + let physical_args = std::iter::zip(args, lambdas_parameters) + .map(|(arg, lambda_parameters)| { + match (arg.downcast_ref::(), lambda_parameters) { + (Some(lambda), Some(lambda_parameters)) => { + let extended_dfschema = merge_schema_and_parameters(input_dfschame, lambda_parameters)?; + + create_physical_expr(body, extended_dfschema, execution_props) + } + (None, None) => create_physical_expr(arg, input_dfschema, execution_props), + (Some(_), None) => plan_err!("lambdas_parameters returned None for a lambda") + (None, Some(_)) => plan_err!("lambdas_parameters returned Some for a non lambda") + } + }) + .collect()?; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options: ... // irrelevant + )?)) + } + } +} +``` + +
+
+ +3. Insert the lambda parameters values into the RecordBatch during the evaluation phase: the LambdaUDF, instead of binding the lambda body variables, inserts it's parameters on the captured RecordBatch it receives on LambdaFunctionLambdaArg. + +How ArrayTransform::invoke_with_args would look like: + +```rust + ... + let values_param = || Ok(Arc::clone(list_values)); + let indices_param = || elements_indices(&list_array); + + let merged_batch = merge_captures_with_params( + adjusted_captures, + &lambda.params, + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch + let transformed_values = lambda.body + .evaluate(&merged_batch)? + .into_array(list_values.len())?; + + ... +``` + +
+ +Why is `LambdaVariable`'s `Field` an `Option`? + +So expr_api users can construct a LambdaVariable just by using its name, without having to set its field. An `AnalyzerRule` will then set the `LambdaVariable` field based on the values returned from `LambdaUDF::lambdas_parameters` of any `LambdaFunction` it finds while traversing down an expr tree. We may include that rule on the default rules list for when the plan/expression tree is transformed by another rule in a way that changes the types of non-lambda arguments of a lambda function, as that may change the types of its lambda parameters, which would render `LambdaVariable` fields out of sync, which the rule would fix. Or, to not increase planning time, we don't include it by default and instruct `expr_api` users to add it manually if needed + + + +```rust +array_transform( + col("my_array"), + lambda( + vec!["current_value"], + 2 * lambda_variable("current_value") + ) +) + +//instead of + +array_transform( + col("my_array"), + lambda( + vec!["current_value"], + 2 * lambda_variable("current_value", Field::new("", DataType::Int32, false)) + ) +) +``` + + +Why set the `LambdaVariable` field during sql planning if it's optional and can be set later via an `AnalyzerRule`? + +Some parts of sql planning check the type/nullability of the already-planned children expressions of the expr being planned, and would error if doing so on an unresolved `LambdaVariable` +Take as example this expression: `array_transform([[0, 1]], v -> v[1])`. `FieldAccess` `v[1]` planning is handled by the `ExprPlanner` `FieldAccessPlanner`, which checks the datatype of `v`, a lambda variable, whose `ExprSchemable` implementation depends on its field being resolved, and not on the `PlannerContext` schema, requiring the sql planner to plan `LambdaVariables` with a resolved field + +
FieldAccessPlanner + +```rust +pub struct FieldAccessPlanner; + +impl ExprPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, // "v[1]" + schema: &DFSchema, + ) -> Result> { + // { "v", "[1]" } + let RawFieldAccessExpr { expr, field_access } = expr; + + match field_access { + ... + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + ... + // ExprSchemable::get_type called + _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + get_field_inner(), + vec![expr, *index], + ), + ))) + } + } + } + } + } +} +``` + +
+
 + + Therefore we can't plan all arguments in a single pass, and must first plan the non-lambda arguments, collect their types and nullability, and pass them to `LambdaUDF::lambdas_parameters`, which will derive the types of its lambda parameters based on the types of its non-lambda arguments, and return them to the planner, which, for each unplanned lambda argument, will create a new `PlannerContext` via `with_lambda_parameters`, which contains a mapping of lambda parameter names to their types. Then, when planning an `ast::Identifier`, it first checks whether a lambda parameter with the given name exists, and if so, plans it into an `Expr::LambdaVariable` with a resolved field, otherwise plans it into a regular `Expr::Column`. + + +
sql planning + + +```rust +struct PlannerContext { + /// The parameters of all lambdas seen so far + lambdas_parameters: HashMap, + // ... omitted fields +} + +impl PlannerContext { + pub fn with_lambda_parameters( + mut self, + arguments: impl IntoIterator, + ) -> Self { + self.lambdas_parameters + .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); + + self + } +} + +// copied from sqlparser +struct LambdaFunction { + pub params: OneOrManyWithParens, // One("v") + pub body: Box, // v != t.n +} + +// copied from sqlparser +enum OneOrManyWithParens { + One(T), // "v" + Many(Vec), +} + +/// the planning would happens as the following: + +enum ExprOrLambda { + Expr(Expr), // planned [2, 3] + Lambda(ast::LambdaFunction), // unplanned v -> v != t.n +} + +impl SqlToRel { + // example function, won't exist + fn plan_array_transform(&self, array_transform: Arc, args: Vec, schema: &DFSchema, planner_context: &mut PlannerContext) -> Result { + let args = args.into_iter() + .map(|arg| match arg { + ast::Expr::LambdaFunction(l) => Ok(ExprOrLambda::Lambda(l)),//skip planning until we plan non lambda args + arg => Ok(ExprOrLambda::Expr( + self.sql_fn_arg_to_logical_expr_with_name( + arg, + schema, + planner_context, + )?, + )) + }) + .collect::>>()?; + + let args_metadata = args.iter() + .map(|arg| match arg { + Expr(expr) => Ok(ValueOrLambda::Value(expr.to_field(schema)?)), + Lambda(_) => Ok(ValueOrLambda::Lambda), + }) + .collect::>>()?; + + let lambdas_parameters = array_transform.lambdas_parameters(&args_metadata)?; + + let args = std::iter::zip(args, lambdas_parameters) + .map(|(arg, lambdas_parameters)| match (arg, lambdas_parameters) { + (ExprOrLambda::Expr(planned_expr), None) => Ok(planned_expr), + (ExprOrLambda::Lambda(unplanned_lambda), Some(lambda_parameters)) => { + let params = + unplanned_lambda.params + .iter() + .map(|p| p.value.clone()) + .collect(); + + let lambda_parameters = lambda_params + .into_iter() + .zip(¶ms) + .map(|(field, name)| 
Arc::new(field.with_name(name))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(( + Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + }), + None, + )) + } + (ExprOrLambda::Expr(planned_expr), Some(lambda_parameters)) => plan_err!("lambdas_parameters returned Some for a value"), + (ExprOrLambda::Lambda(unplanned_lambda), None) => plan_err!("lambdas_parameters returned None for a lambda"), + }) + .collect::>>()?; + + Ok(Expr::LambdaFunction(LambdaFunction { + func: array_transform, + args, + })) + } + + fn sql_identifier_to_expr( + &self, + id: ast::Ident, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + // simplified implementation + if let Some(field) = planner_context.lambdas_parameters.get(id) { + Ok(Expr::LambdaVariable(LambdaVariable { + name: id, // "v" + field, // Field::new("", DataType::Int32, false) + })) + } else { + Ok(Expr::Column(Column::new(id))) + } + } +} + +``` + +
+
 + +`LambdaFunction` `Signature` is non-functional + +Currently, `LambdaUDF::signature` returns the same `Signature` as `ScalarUDF`, but its `type_signature` field is never used, as most variants of the `TypeSignature` enum aren't applicable to a lambda, and no type coercion is applied on its arguments, currently being an implementation responsibility. We should either add lambda-compatible variants to the `TypeSignature` enum, create a new `LambdaTypeSignature` and `LambdaSignature`, or support no automatic type coercion at all on lambda functions. diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs index 1fc9ce24ecbb5..e6341080a9e11 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_frontend.rs @@ -20,8 +20,8 @@ use datafusion::common::{plan_err, TableReference}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ - AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource, - WindowUDF, + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, + TableSource, WindowUDF, }; use datafusion::optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, @@ -153,6 +153,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 82cc36867939e..34fee4eb6bd41 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -86,7 +86,9 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | 
Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { @@ -98,6 +100,16 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { } } } + Expr::LambdaFunction(lambda_function) => { + match lambda_function.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } // TODO other expressions are not handled yet: // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases @@ -553,7 +565,7 @@ mod tests { use super::*; use datafusion_expr::{ - case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF, + case, col, lit, AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; @@ -1064,6 +1076,12 @@ mod tests { unimplemented!() } + fn lambda_functions( + &self, + ) -> &std::collections::HashMap> { + unimplemented!() + } + fn aggregate_functions( &self, ) -> &std::collections::HashMap> { diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 7b145ac3ae21d..3fd0683659caf 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,20 +22,26 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; +use crate::error::{ + _exec_datafusion_err, _exec_err, _internal_datafusion_err, _internal_err, +}; use crate::{Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, }; +use 
arrow::array::{ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, +}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{min, Ordering}; use std::collections::HashSet; +use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; use std::sync::Arc; @@ -939,6 +945,121 @@ pub fn take_function_args( }) } +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +pub fn make_list_array_indices( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut indices = Vec::with_capacity( + offsets.last().unwrap().as_usize() - offsets.first().unwrap().as_usize(), + ); + + for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { + indices.extend(repeat_n( + T::Native::usize_as(i), + end.as_usize() - start.as_usize(), + )); + } + + PrimitiveArray::new(indices.into(), None) +} + +/// [0, 2, 2, 5, 6] -> [0, 1, 0, 1, 2, 0] +pub fn make_list_element_indices( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut indices = + Vec::with_capacity(offsets.last().unwrap().as_usize() - offsets[0].as_usize()); + + for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { + indices.extend( + (0..end.as_usize() - start.as_usize()).map(|i| T::Native::usize_as(i)), + ); + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 0, 1, 1, 2, 2] +pub fn make_fsl_array_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray { + let mut indices = Vec::with_capacity(list_size as usize * array_len); + + for i in 0..array_len { + indices.extend(repeat_n(i as i32, list_size as usize)); + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 1, 0, 1, 0, 1] +pub fn make_fsl_element_indices( + list_size: i32, + 
array_len: usize, +) -> PrimitiveArray { + let mut indices = Vec::with_capacity(list_size as usize * array_len); + + if array_len > 0 { + indices.extend((0..list_size as usize).map(|j| j as i32)); + + for _ in 1..array_len { + indices.extend_from_within(0..list_size as usize); + } + } + + PrimitiveArray::new(indices.into(), None) +} + +pub fn list_values(array: &dyn Array) -> Result<&ArrayRef> { + match array.data_type() { + DataType::List(_) => Ok(array.as_list::().values()), + DataType::LargeList(_) => Ok(array.as_list::().values()), + DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().values()), + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn list_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_array_indices::( + array.as_list().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_array_indices::( + array.as_list().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_array_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn elements_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_element_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 63387c023b11a..97282edf49381 100644 --- 
a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -18,8 +18,7 @@ use datafusion::execution::SessionStateDefaults; use datafusion_common::{not_impl_err, HashSet, Result}; use datafusion_expr::{ - aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, - DocSection, Documentation, ScalarUDF, WindowUDF, + AggregateUDF, DocSection, Documentation, LambdaUDF, ScalarUDF, WindowUDF, aggregate_doc_sections, scalar_doc_sections, window_doc_sections }; use itertools::Itertools; use std::env::args; @@ -303,6 +302,18 @@ impl DocProvider for WindowUDF { } } +impl DocProvider for dyn LambdaUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + #[allow(clippy::borrowed_box)] #[allow(clippy::ptr_arg)] fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 687779787ab50..083ecdaf575af 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -75,6 +75,7 @@ use datafusion_common::{ pub use datafusion_execution::config::SessionConfig; use datafusion_execution::registry::SerializerRegistry; pub use datafusion_execution::TaskContext; +use datafusion_expr::LambdaUDF; pub use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ expr_rewriter::FunctionRewrite, @@ -1786,6 +1787,21 @@ impl FunctionRegistry for SessionContext { fn udwfs(&self) -> HashSet { self.state.read().udwfs() } + + fn udlfs(&self) -> HashSet { + self.state.read().udlfs() + } + + fn udlf(&self, name: &str) -> Result> { + self.state.read().udlf(name) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + self.state.write().register_udlf(udlf) + } } 
/// Create a new task context instance from SessionContext diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..f33f3d3412f4d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -59,7 +59,7 @@ use datafusion_expr::simplify::SimplifyInfo; #[cfg(feature = "sql")] use datafusion_expr::TableSource; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Explain, Expr, ExprSchemable, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -154,6 +154,8 @@ pub struct SessionState { table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, + /// Lambda functions that are registered with the context + lambda_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context @@ -252,6 +254,10 @@ impl Session for SessionState { fn scalar_functions(&self) -> &HashMap> { &self.scalar_functions } + + fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions @@ -921,6 +927,7 @@ pub struct SessionStateBuilder { catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, + lambda_functions: Option>>, aggregate_functions: Option>>, window_functions: Option>>, serializer_registry: Option>, @@ -958,6 +965,7 @@ impl SessionStateBuilder { catalog_list: None, table_functions: None, scalar_functions: None, + lambda_functions: None, aggregate_functions: None, window_functions: None, serializer_registry: None, @@ -1008,6 +1016,7 @@ impl SessionStateBuilder { catalog_list: Some(existing.catalog_list), table_functions: Some(existing.table_functions), scalar_functions: 
Some(existing.scalar_functions.into_values().collect_vec()), + lambda_functions: Some(existing.lambda_functions.into_values().collect_vec()), aggregate_functions: Some( existing.aggregate_functions.into_values().collect_vec(), ), @@ -1048,6 +1057,10 @@ impl SessionStateBuilder { self.scalar_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); + + self.lambda_functions + .get_or_insert_with(Vec::new) + .extend(SessionStateDefaults::default_lambda_functions()); self.aggregate_functions .get_or_insert_with(Vec::new) @@ -1362,6 +1375,7 @@ impl SessionStateBuilder { catalog_list, table_functions, scalar_functions, + lambda_functions, aggregate_functions, window_functions, serializer_registry, @@ -1395,6 +1409,7 @@ impl SessionStateBuilder { }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry @@ -1446,6 +1461,34 @@ impl SessionStateBuilder { } } } + + if let Some(lambda_functions) = lambda_functions { + for udlf in lambda_functions { + let config_options = state.config().options(); + match udlf.with_updated_config(config_options) { + Some(new_udf) => { + if let Err(err) = state.register_udlf(new_udf) { + debug!( + "Failed to re-register updated UDLF '{}': {}", + udlf.name(), + err + ); + } + } + None => match state.register_udlf(Arc::clone(&udlf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDLF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDLF '{}'", udlf.name()); + } + Err(err) => { + debug!("Failed to register UDLF '{}': {}", udlf.name(), err); + } + }, + } + } + } if let Some(aggregate_functions) = aggregate_functions { aggregate_functions.into_iter().for_each(|udaf| { @@ -1661,6 +1704,7 @@ impl Debug for SessionStateBuilder { .field("physical_optimizers", &self.physical_optimizers) 
.field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("lambda_functions", &self.lambda_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .finish() @@ -1755,6 +1799,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -1918,6 +1966,37 @@ impl FunctionRegistry for SessionState { Ok(udwf) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> datafusion_common::Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> datafusion_common::Result>> { + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + + fn deregister_udlf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn register_function_rewrite( &mut self, rewrite: Arc, @@ -1974,6 +2053,7 @@ impl From<&SessionState> for TaskContext { state.session_id.clone(), state.config.clone(), state.scalar_functions.clone(), + state.lambda_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), Arc::clone(&state.runtime_env), @@ -2062,6 +2142,7 @@ mod tests { use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; use datafusion_physical_plan::display::DisplayableExecutionPlan; + use datafusion_session::Session; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use 
std::collections::HashMap; use std::sync::Arc; @@ -2338,6 +2419,10 @@ mod tests { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 62a575541a5d8..54037c0a96f9c 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -36,7 +36,8 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; +use datafusion_functions_nested::array_transform::ArrayTransform; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -112,6 +113,11 @@ impl SessionStateDefaults { functions } + /// returns the list of default [`LambdaUDF`]s + pub fn default_lambda_functions() -> Vec> { + vec![Arc::new(ArrayTransform::new())] + } + /// returns the list of default [`AggregateUDF`]s pub fn default_aggregate_functions() -> Vec> { functions_aggregate::all_default_aggregate_functions() diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d0..44e40143fe171 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -31,8 +31,7 @@ use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - col, lit, AggregateUDF, 
BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, TableSource, WindowUDF, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, LambdaUDF, LogicalPlan, Operator, ScalarUDF, TableSource, WindowUDF, col, lit }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; @@ -217,6 +216,10 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 3b85640804219..31a880e688aeb 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -442,7 +442,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; @@ -488,6 +488,10 @@ mod tests { fn scalar_functions(&self) -> &HashMap> { unimplemented!() } + + fn lambda_functions(&self) -> &HashMap> { + unimplemented!() + } fn aggregate_functions(&self) -> &HashMap> { unimplemented!() diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 08e5b6a5df83a..b0b84bd3cc943 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -415,7 +415,7 @@ mod tests { use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use 
datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ @@ -874,6 +874,10 @@ mod tests { unimplemented!() } + fn lambda_functions(&self) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index c2a6cfe2c833f..70c59b6375943 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -21,7 +21,7 @@ use crate::{ }; use datafusion_common::{internal_datafusion_err, plan_datafusion_err, Result}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, LambdaUDF, WindowUDF}; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; @@ -42,6 +42,8 @@ pub struct TaskContext { session_config: SessionConfig, /// Scalar functions associated with this task context scalar_functions: HashMap>, + /// Lambda functions associated with this task context + lambda_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, /// Window functions associated with this task context @@ -60,6 +62,7 @@ impl Default for TaskContext { task_id: None, session_config: SessionConfig::new(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), runtime, @@ -73,11 +76,13 @@ impl TaskContext { /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s /// /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx + #[allow(clippy::too_many_arguments)] pub fn new( task_id: Option, session_id: String, session_config: SessionConfig, 
scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, runtime: Arc, @@ -87,6 +92,7 @@ impl TaskContext { session_id, session_config, scalar_functions, + lambda_functions, aggregate_functions, window_functions, runtime, @@ -198,6 +204,37 @@ impl FunctionRegistry for TaskContext { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + + fn deregister_udlf( + &mut self, + name: &str, + ) -> Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn expr_planners(&self) -> Vec> { vec![] } @@ -248,6 +285,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); @@ -280,6 +318,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 13160d573ab4d..ef4ba4ef5cdfd 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,6 +27,7 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -398,6 +399,60 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + LambdaFunction(LambdaFunction), + /// Lambda expression + Lambda(Lambda), 
+ LambdaVariable(LambdaVariable), +} + +#[derive(Clone, Eq, PartialOrd, Debug)] +pub struct LambdaFunction { + pub func: Arc, + pub args: Vec, +} + +impl LambdaFunction { + pub fn new(func: Arc, args: Vec) -> Self { + Self { func, args } + } + + pub fn name(&self) -> &str { + self.func.name() + } +} + +impl Hash for LambdaFunction { + fn hash(&self, state: &mut H) { + self.func.hash(state); + self.args.hash(state); + } +} + +impl PartialEq for LambdaFunction { + fn eq(&self, other: &Self) -> bool { + self.func.as_ref() == other.func.as_ref() && self.args == other.args + } +} + +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] +pub struct LambdaVariable { + pub name: String, + pub field: FieldRef, + pub spans: Spans, +} + +impl LambdaVariable { + pub fn new(name: String, field: FieldRef) -> Self { + Self { + name, + field, + spans: Spans::new(), + } + } + + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl Default for Expr { @@ -1211,6 +1266,23 @@ impl GroupingSet { } } +/// Lambda expression. +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct Lambda { + pub params: Vec, + pub body: Box, +} + +impl Lambda { + /// Create a new lambda expression + pub fn new(params: Vec, body: Expr) -> Self { + Self { + params, + body: Box::new(body), + } + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] #[cfg(not(feature = "sql"))] pub struct IlikeSelectItem { @@ -1525,6 +1597,9 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::LambdaFunction { .. } => "LambdaFunction", + Expr::Lambda { .. } => "Lambda", + Expr::LambdaVariable { .. } => "LambdaVariable", } } @@ -2040,6 +2115,7 @@ impl Expr { pub fn short_circuits(&self) -> bool { match self { Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), + Expr::LambdaFunction(LambdaFunction { func, .. }) => func.short_circuits(), Expr::BinaryExpr(BinaryExpr { op, .. 
}) => { matches!(op, Operator::And | Operator::Or) } @@ -2078,7 +2154,9 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Placeholder(..) => false, + | Expr::Placeholder(..) + | Expr::Lambda(..) + | Expr::LambdaVariable(..) => false, } } @@ -2674,6 +2752,20 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::LambdaFunction(LambdaFunction { func, args: _args }) => { + func.hash(state); + } + Expr::Lambda(Lambda { params, body: _ }) => { + params.hash(state); + } + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => { + name.hash(state); + field.hash(state); + } }; } } @@ -2987,6 +3079,22 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.schema_name(args) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {e}") + } + } + } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {body}", display_comma_separated(params)) + } + Expr::LambdaVariable(c) => { + write!(f, "{}", c.name) + } } } } @@ -3167,6 +3275,9 @@ impl Display for SqlDisplay<'_> { } } } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) + } _ => write!(f, "{}", self.0), } } @@ -3474,6 +3585,15 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::LambdaFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) + } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {body}", params.join(", ")) + } + Expr::LambdaVariable(c) => { + write!(f, "{}", c.name) + } } } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9e8d6080b82c8..f3789ca9fd115 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -18,17 +18,23 @@ use super::{Between, Expr, 
Like}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::expr::{FieldMetadata, LambdaVariable}; use crate::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, + fields_with_aggregate_udf, fields_with_window_udf, +}; +use crate::udlf::{LambdaReturnFieldArgs, ValueOrLambdaField}; +use crate::{ + type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, + LogicalPlan, Projection, Subquery, WindowFunctionDefinition, +}; +use arrow::datatypes::FieldRef; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, Field}, }; -use crate::udf::ReturnFieldArgs; -use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; -use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -229,6 +235,14 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::LambdaFunction(_func) => { + let (return_type, _) = self.data_type_and_nullable(schema)?; + Ok(return_type) + } + Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), + Expr::LambdaVariable(LambdaVariable { name: _, field, .. 
}) => { + Ok(field.data_type().clone()) + } } } @@ -347,6 +361,12 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::LambdaFunction(_func) => { + let (_, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(nullable) + } + Expr::Lambda(l) => l.body.nullable(input_schema), + Expr::LambdaVariable(c) => Ok(c.field.is_nullable()), } } @@ -543,6 +563,7 @@ impl ExprSchemable for Expr { .into_iter() .map(|f| (f.data_type().clone(), f)) .unzip(); + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { @@ -573,6 +594,7 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); + let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, @@ -600,11 +622,43 @@ impl ExprSchemable for Expr { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( + | Expr::Unnest(_) + | Expr::Lambda(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::LambdaFunction(func) => { + let arg_fields = func + .args + .iter() + .map(|arg| { + let field = arg.to_field(schema)?.1; + match arg { + Expr::Lambda(_lambda) => { + Ok(ValueOrLambdaField::Lambda(field)) + } + _ => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + let arguments = func.args + .iter() + .map(|e| match e { + Expr::Literal(sv, _) => Some(sv), + _ => None, + }) + .collect::>(); + + let args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + func.func.return_field_from_args(args) + } + Expr::LambdaVariable(c) => Ok(Arc::clone(&c.field)), }?; Ok(( diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 2b7cc9d46ad34..b03bab622a357 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -42,6 +42,7 @@ mod partition_evaluator; mod 
table_source; mod udaf; mod udf; +mod udlf; mod udwf; pub mod arguments; @@ -118,6 +119,10 @@ pub use udaf::{ ReversedUDAF, SetMonotonicity, StatisticsArgs, }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udlf::{ + LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaSignature, + LambdaTypeSignature, LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, +}; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 25a0f83947eee..7696faca0922a 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -22,8 +22,7 @@ use std::sync::Arc; use crate::expr::NullTreatment; use crate::{ - AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, - WindowFunctionDefinition, WindowUDF, + AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF }; use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ @@ -91,6 +90,9 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; + + /// Return the lambda function with a given name, if any + fn get_lambda_meta(&self, name: &str) -> Option>; /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 9554dd68e1758..92aa39d64c98d 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,6 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use 
datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; @@ -30,6 +31,9 @@ pub trait FunctionRegistry { /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; + /// Returns names of all available lambda user defined functions. + fn udlfs(&self) -> HashSet; + /// Returns names of all available aggregate user defined functions. fn udafs(&self) -> HashSet; @@ -40,6 +44,10 @@ pub trait FunctionRegistry { /// `name`. fn udf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined lambda function (udf) named + /// `name`. + fn udlf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. fn udaf(&self, name: &str) -> Result>; @@ -56,6 +64,17 @@ pub trait FunctionRegistry { fn register_udf(&mut self, _udf: Arc) -> Result>> { not_impl_err!("Registering ScalarUDF") } + /// Registers a new [`LambdaUDF`], returning any previously registered + /// implementation. + /// + /// Returns an error (the default) if the function can not be registered, + /// for example if the registry is read only. + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + not_impl_err!("Registering LambdaUDF") + } /// Registers a new [`AggregateUDF`], returning any previously registered /// implementation. /// @@ -85,6 +104,15 @@ pub trait FunctionRegistry { not_impl_err!("Deregistering ScalarUDF") } + /// Deregisters a [`LambdaUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udlf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering LambdaUDF") + } + /// Deregisters a [`AggregateUDF`], returning the implementation that was /// deregistered. 
/// @@ -152,6 +180,8 @@ pub trait SerializerRegistry: Debug + Send + Sync { pub struct MemoryFunctionRegistry { /// Scalar Functions udfs: HashMap>, + /// Lambda Functions + udlfs: HashMap>, /// Aggregate Functions udafs: HashMap>, /// Window Functions @@ -214,4 +244,22 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn udwfs(&self) -> HashSet { self.udwfs.keys().cloned().collect() } + + fn udlfs(&self) -> HashSet { + self.udlfs.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> Result> { + self.udlfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.udlfs.insert(udlf.name().into(), udlf)) + } } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f80608..82179f095937b 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,17 +17,20 @@ //! Tree node implementation for Logical Expressions -use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, +use crate::{ + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, + Cast, GroupingSet, InList, InSubquery, Lambda, LambdaFunction, Like, Placeholder, + ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, + }, + Expr, }; -use crate::Expr; - -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, +use datafusion_common::{ + tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, + }, + Result, }; -use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -77,7 +80,8 @@ impl TreeNode for Expr { | Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Placeholder(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { (left, right).apply_ref_elements(f) } @@ -106,6 +110,8 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::LambdaFunction(LambdaFunction { func: _, args}) => args.apply_elements(f), + Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -127,7 +133,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_, _) => Transformed::no(self), + | Expr::Literal(_, _) + | Expr::LambdaVariable(_) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -311,6 +318,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::LambdaFunction(LambdaFunction { func, args }) => args + .map_elements(f)? + .update_data(|args| Expr::LambdaFunction(LambdaFunction { func, args })), + Expr::Lambda(Lambda { params, body }) => body + .map_elements(f)? + .update_data(|body| Expr::Lambda(Lambda { params, body })), }) } } diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 6664495267129..a003f05c15b5e 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, LambdaUDF, ScalarUDFImpl, WindowUDFImpl}; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; @@ -93,6 +93,18 @@ impl UdfPointer for Arc { } } +impl UdfPointer for Arc { + fn equals(&self, other: &Self::Target) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + impl UdfPointer for Arc { fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool { self.as_ref().dyn_eq(other.as_any()) diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs new file mode 100644 index 0000000000000..b2191e0883d30 --- /dev/null +++ b/datafusion/expr/src/udlf.rs @@ -0,0 +1,736 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`LambdaUDF`]: Lambda User Defined Functions + +use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::{ColumnarValue, Documentation, Expr}; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::signature::{Volatility}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// The types of arguments for which a function has implementations. +/// +/// [`LambdaTypeSignature`] **DOES NOT** define the types that a user query could call the +/// function with. DataFusion will automatically coerce (cast) argument types to +/// one of the supported function signatures, if possible. +/// +/// # Overview +/// Functions typically provide implementations for a small number of different +/// argument [`DataType`]s, rather than all possible combinations. If a user +/// calls a function with arguments that do not match any of the declared types, +/// DataFusion will attempt to automatically coerce (add casts to) function +/// arguments so they match the [`LambdaTypeSignature`]. See the [`type_coercion`] module +/// for more details +/// +/// [`type_coercion`]: crate::type_coercion +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum LambdaTypeSignature { + /// The acceptable signature and coercions rules are special for this + /// function. + /// + /// If this signature is specified, + /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare argument types. 
+ /// + /// [`LambdaUDF::coerce_value_types`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.LambdaUDF.html#method.coerce_value_types + UserDefined, + /// One or more lambdas or arguments with arbitrary types + VariadicAny, + /// The specified number of lambdas or arguments with arbitrary types. + Any(usize), +} + +/// Provides information necessary for calling a lambda function. +/// +/// - [`Volatility`] defines how the output of the function changes with the input. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct LambdaSignature { + /// The data types that the function accepts. See [LambdaTypeSignature] for more information. + pub type_signature: LambdaTypeSignature, + /// The volatility of the function. See [Volatility] for more information. + pub volatility: Volatility, + /// Optional parameter names for the function arguments. + /// + /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => 2)`). + /// + /// Defaults to `None`, meaning only positional arguments are supported. + pub parameter_names: Option>, +} + +impl LambdaSignature { + /// Creates a new LambdaSignature from a given type signature and volatility. + pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> Self { + LambdaSignature { + type_signature, + volatility, + parameter_names: None, + } + } + + /// User-defined coercion rules for the function. + pub fn user_defined(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::UserDefined, + volatility, + parameter_names: None, + } + } + + /// An arbitrary number of lambdas or arguments of any type. 
+ pub fn variadic_any(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::VariadicAny, + volatility, + parameter_names: None, + } + } + + /// A specified number of arguments of any type + pub fn any(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::Any(arg_count), + volatility, + parameter_names: None, + } + } +} + +impl PartialEq for dyn LambdaUDF { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(other.as_any()) + } +} + +impl PartialOrd for dyn LambdaUDF { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), other.name() + ); + Some(cmp) + } +} + +impl Eq for dyn LambdaUDF {} + +impl Hash for dyn LambdaUDF { + fn hash(&self, state: &mut H) { + self.dyn_hash(state) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrLambdaParameter { + /// A value with the given associated data + Value(T), + /// A lambda + Lambda, +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. 
+#[derive(Debug, Clone)] +pub struct LambdaFunctionArgs { + /// The evaluated arguments to the function + pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the lambda function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +/// A lambda argument to a LambdaFunction +#[derive(Clone, Debug)] +pub struct LambdaFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, but that's implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + pub captures: Option, +} + +impl LambdaFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. 
+ pub fn return_type(&self) -> &DataType {
+ self.return_field.data_type()
+ }
+}
+
+// An argument to a LambdaUDF that supports lambdas
+#[derive(Clone, Debug)]
+pub enum ValueOrLambda {
+ Value(ColumnarValue),
+ Lambda(LambdaFunctionLambdaArg),
+}
+
+/// Information about arguments passed to the function
+///
+/// This structure contains metadata about how the function was called
+/// such as the type of the arguments, any scalar arguments and if the
+/// arguments can (ever) be null
+///
+/// See [`LambdaUDF::return_field_from_args`] for more information
+#[derive(Clone, Debug)]
+pub struct LambdaReturnFieldArgs<'a> {
+ /// The data types of the arguments to the function
+ ///
+ /// If argument `i` to the function is a lambda, it will be the field returned by the
+ /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters`
+ ///
+ /// For example, with `array_transform([1], v -> v == 5)`
+ /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]`
+ pub arg_fields: &'a [ValueOrLambdaField],
+ /// Is argument `i` to the function a scalar (constant)?
+ ///
+ /// If the argument `i` is not a scalar, it will be None
+ ///
+ /// For example, if a function is called like `my_function(column_a, 5)`
+ /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
+ pub scalar_arguments: &'a [Option<&'a ScalarValue>],
+}
+
+/// A tagged Field indicating whether it corresponds to a value or a lambda argument
+#[derive(Clone, Debug)]
+pub enum ValueOrLambdaField {
+ /// The Field of a ColumnarValue argument
+ Value(FieldRef),
+ /// The Field of the return of the lambda body when evaluated with the parameters from LambdaUDF::lambdas_parameters
+ Lambda(FieldRef),
+}
+
+/// Trait for implementing user defined lambda functions.
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function. 
+/// +/// See [`advanced_udf.rs`] for a full example with complete implementation and +/// [`LambdaUDF`] for other available options. +/// +/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use std::sync::LazyLock; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, Volatility}; +/// # use datafusion_expr::LambdaUDF; +/// # use datafusion_expr::lambda_doc_sections::DOC_SECTION_MATH; +/// /// This struct for a simple UDF that adds one to an int32 +/// #[derive(Debug, PartialEq, Eq, Hash)] +/// struct AddOne { +/// signature: LambdaSignature, +/// } +/// +/// impl AddOne { +/// fn new() -> Self { +/// Self { +/// signature: LambdaSignature::new(Volatility::Immutable), +/// } +/// } +/// } +/// +/// static DOCUMENTATION: LazyLock = LazyLock::new(|| { +/// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)") +/// .with_argument("arg1", "The int32 number to add one to") +/// .build() +/// }); +/// +/// fn get_doc() -> &'static Documentation { +/// &DOCUMENTATION +/// } +/// +/// /// Implement the LambdaUDF trait for AddOne +/// impl LambdaUDF for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &LambdaSignature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { +/// unimplemented!() +/// } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } 
+/// } +/// +/// // Create a new LambdaUDF from the implementation +/// let add_one = LambdaUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, args: &[Expr]) -> Result { + Ok(format!( + "{}({})", + self.name(), + schema_name_from_exprs_comma_separated_without_space(args)? + )) + } + + /// Returns a [`LambdaSignature`] describing the argument types for which this + /// function has an implementation, and the function's [`Volatility`]. + /// + /// See [`LambdaSignature`] for more details on argument type handling + /// and [`Self::return_type`] for computing the return type. + /// + /// [`Volatility`]: datafusion_expr_common::signature::Volatility + fn signature(&self) -> &LambdaSignature; + + /// Create a new instance of this function with updated configuration. + /// + /// This method is called when configuration options change at runtime + /// (e.g., via `SET` statements) to allow functions that depend on + /// configuration to update themselves accordingly. 
+ /// + /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so + /// this API is not needed for functions where the values may + /// depend on the current options. + /// + /// This API is useful for functions where the return + /// **type** depends on the configuration options, such as the `now()` function + /// which depends on the current timezone. + /// + /// # Arguments + /// + /// * `config` - The updated configuration options + /// + /// # Returns + /// + /// * `Some(LambdaUDF)` - A new instance of this function configured with the new settings + /// * `None` - If this function does not change with new configuration settings (the default) + fn with_updated_config(&self, _config: &ConfigOptions) -> Option> { + None + } + + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// # Notes + /// + /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient, + /// as the result type is typically a deterministic function of the input types + /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly + /// is generally unnecessary unless the return type depends on runtime values. + /// + /// This function can be used for more advanced cases such as: + /// + /// 1. specifying nullability + /// 2. return types based on the **values** of the arguments (rather than + /// their **types**. + /// + /// # Example creating `Field` + /// + /// Note the name of the [`Field`] is ignored, except for structured types such as + /// `DataType::Struct`. 
+ /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::LambdaReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result { + /// // report output is only nullable if any one of the arguments are nullable + /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + /// + /// # Output Type based on Values + /// + /// For example, the following two function calls get the same argument + /// types (something and a `Utf8` string) but return different types based + /// on the value of the second argument: + /// + /// * `arrow_cast(x, 'Int16')` --> `Int16` + /// * `arrow_cast(x, 'Float32')` --> `Float32` + /// + /// # Requirements + /// + /// This function **must** consistently return the same type for the same + /// logical input even if the input is simplified (e.g. it must return the same + /// value for `('foo' | 'bar')` as it does for ('foobar'). + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + + /// Invoke the function returning the appropriate result. + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + + /// Optionally apply per-UDF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). 
The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Arguments + /// * `args`: The arguments of the function + /// * `info`: The necessary information for simplification + /// + /// # Returns + /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE + /// if the function cannot be simplified, the arguments *MUST* be returned + /// unmodified + /// + /// # Notes + /// + /// The returned expression must have the same schema as the original + /// expression, including both the data type and nullability. For example, + /// if the original expression is nullable, the returned expression must + /// also be nullable, otherwise it may lead to schema verification errors + /// later in query planning. + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + Ok(ExprSimplifyResult::Original(args)) + } + + /// Returns true if some of this `exprs` subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination + /// + /// When overriding this function to return `true`, [LambdaUDF::conditional_arguments] can also be + /// overridden to report more accurately which arguments are eagerly evaluated and which ones + /// lazily. + fn short_circuits(&self) -> bool { + false + } + + /// Determines which of the arguments passed to this function are evaluated eagerly + /// and which may be evaluated lazily. + /// + /// If this function returns `None`, all arguments are eagerly evaluated. + /// Returning `None` is a micro optimization that saves a needless `Vec` + /// allocation. 
+ ///
+ /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
+ /// are the arguments that are always evaluated, and `lazy` are the
+ /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
+ /// in some cases).
+ ///
+ /// Implementations must ensure that the two returned `Vec`s are disjoint,
+ /// and that each argument from `args` is present in one of the two `Vec`s.
+ ///
+ /// When overriding this function, [LambdaUDF::short_circuits] must
+ /// be overridden to return `true`.
+ fn conditional_arguments<'a>(
+ &self,
+ args: &'a [Expr],
+ ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
+ if self.short_circuits() {
+ Some((vec![], args.iter().collect()))
+ } else {
+ None
+ }
+ }
+
+ /// Computes the output [`Interval`] for a [`LambdaUDF`], given the input
+ /// intervals.
+ ///
+ /// # Parameters
+ ///
+ /// * `children` are the intervals for the children (inputs) of this function.
+ ///
+ /// # Example
+ ///
+ /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
+ /// then the output interval would be `[0, 3]`.
+ fn evaluate_bounds(&self, _input: &[&Interval]) -> Result {
+ // We cannot assume the input datatype is the same as the output type.
+ Interval::make_unbounded(&DataType::Null)
+ }
+
+ /// Updates bounds for child expressions, given known [`Interval`]s for this
+ /// function.
+ ///
+ /// This function is used to propagate constraints down through an
+ /// expression tree.
+ ///
+ /// # Parameters
+ ///
+ /// * `interval` is the currently known interval for this function.
+ /// * `inputs` are the current intervals for the inputs (children) of this function.
+ ///
+ /// # Returns
+ ///
+ /// A `Vec` of new intervals for the children, in order.
+ ///
+ /// If constraint propagation reveals an infeasibility for any child, returns
+ /// [`None`]. If none of the children intervals change as a result of
+ /// propagation, may return an empty vector instead of cloning `children`. 
+ /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`. + fn propagate_constraints( + &self, + _interval: &Interval, + _inputs: &[&Interval], + ) -> Result>> { + Ok(Some(vec![])) + } + + /// Calculates the [`SortProperties`] of this function based on its children's properties. + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + if !self.preserves_lex_ordering(inputs)? { + return Ok(SortProperties::Unordered); + } + + let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else { + return Ok(SortProperties::Singleton); + }; + + if inputs + .iter() + .skip(1) + .all(|input| &input.sort_properties == first_order) + { + Ok(*first_order) + } else { + Ok(SortProperties::Unordered) + } + } + + /// Returns true if the function preserves lexicographical ordering based on + /// the input ordering. + /// + /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. + fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { + Ok(false) + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` + /// to ensure the argument is converted to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. 
+ fn coerce_value_types( + &self, + _arg_types: &[ValueOrLambdaParameter], + ) -> Result>> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } + + /// Returns the documentation for this Lambda UDF. + /// + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } + + /// Returns the parameters that any lambda supports + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + Ok(vec![None; args.len()]) + } +} + +#[cfg(test)] +mod tests { + use datafusion_expr_common::signature::Volatility; + + use super::*; + use std::hash::DefaultHasher; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestLambdaUDF { + name: &'static str, + field: &'static str, + signature: LambdaSignature, + } + impl LambdaUDF for TestLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn return_field_from_args(&self, _args: LambdaReturnFieldArgs) -> Result { + unimplemented!() + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + unimplemented!() + } + } + + // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd + // must be consistent, so they are tested together. 
+ #[test] + fn test_partial_eq_hash_and_partial_ord() { + // A parameterized function + let f = test_func("foo", "a"); + + // Same like `f`, different instance + let f2 = test_func("foo", "a"); + assert_eq!(&f, &f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Different parameter + let b = test_func("foo", "b"); + assert_ne!(&f, &b); + assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&b), None); + + // Different name + let o = test_func("other", "a"); + assert_ne!(&f, &o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), Some(Ordering::Less)); + + // Different name and parameter + assert_ne!(&b, &o); + assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); + } + + fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(TestLambdaUDF { + name, + field: parameter, + signature: LambdaSignature::variadic_any(Volatility::Immutable), + }) + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index cd733e0a130a9..e7beba8c4b090 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -307,7 +307,10 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::OuterReferenceColumn { .. } => {} + | Expr::OuterReferenceColumn { .. 
} + | Expr::LambdaFunction(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => {} } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 6e0d1048f9697..0299aebdcac47 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -53,6 +53,7 @@ datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs new file mode 100644 index 0000000000000..e0c4ab28c1fef --- /dev/null +++ b/datafusion/functions-nested/src/array_transform.rs @@ -0,0 +1,334 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_transform function. 
+ +use arrow::{ + array::{ + Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, + RecordBatch, RecordBatchOptions, + }, + compute::take_record_batch, + datatypes::{DataType, Field, FieldRef, Schema}, +}; +use datafusion_common::{ + exec_err, + tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }, + utils::{elements_indices, list_indices, list_values, take_function_args}, + HashMap, Result, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, LambdaUDF, + ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, +}; +use datafusion_macros::user_doc; +use datafusion_physical_expr::expressions::{LambdaExpr, LambdaVariable}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::{any::Any, sync::Arc}; + +//make_udf_expr_and_func!( +// ArrayTransform, +// array_transform, +// array lambda, +// "transforms the values of a array", +// array_transform_udf +//); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "transforms the values of a array", + syntax_example = "array_transform(array, x -> x*2)", + sql_example = r#"```sql +> select array_transform([1, 2, 3, 4, 5], x -> x*2); ++-------------------------------------------+ +| array_transform([1, 2, 3, 4, 5], x -> x*2) | ++-------------------------------------------+ +| [2, 4, 6, 8, 10] | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ), + argument(name = "lambda", description = "Lambda") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: LambdaSignature, + aliases: Vec, +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: LambdaSignature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_transform")], + } + } +} + +impl LambdaUDF for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn return_field_from_args( + &self, + args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result> { + let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = + take_function_args(self.name(), args.arg_fields)? + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + //TODO: should metadata be copied into the transformed array? 
+
+ // lambda is the resulting field of executing the lambda body
+ // with the parameters returned in lambdas_parameters
+ let field = Arc::new(Field::new(
+ Field::LIST_FIELD_DEFAULT_NAME,
+ lambda.data_type().clone(),
+ lambda.is_nullable(),
+ ));
+
+ let return_type = match list.data_type() {
+ DataType::List(_) => DataType::List(field),
+ DataType::LargeList(_) => DataType::LargeList(field),
+ DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size),
+ _ => unreachable!(),
+ };
+
+ Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
+ }
+
+ fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result {
+ let [list_value, lambda] = take_function_args(self.name(), &args.args)?;
+
+ let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) =
+ (list_value, lambda)
+ else {
+ return exec_err!(
+ "{} expects a value followed by a lambda, got {:?}",
+ self.name(),
+ &args.args
+ );
+ };
+
+ let list_array = list_value.to_array(args.number_rows)?;
+ let list_values = list_values(&list_array)?;
+
+ // if any column got captured, we need to adjust it to the values arrays,
+ // duplicating values of list with multiple values and removing values of empty lists
+ // list_indices is not cheap so it is important to avoid it when no column is captured
+ let adjusted_captures = lambda
+ .captures
+ .as_ref()
+ .map(|captures| take_record_batch(captures, &list_indices(&list_array)?))
+ .transpose()?;
+
+ // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments
+ // avoiding unnecessary computations
+ let values_param = || Ok(Arc::clone(list_values));
+ let indices_param = || elements_indices(&list_array);
+
+ let binded_body = bind_lambda_variables(
+ Arc::clone(&lambda.body),
+ &lambda.params,
+ &[&values_param, &indices_param],
+ )?;
+
+ // call the transforming expression with the record batch composed of the list values merged with captured columns
+ let transformed_values
= binded_body + .evaluate(&adjusted_captures.unwrap_or_else(|| { + RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(list_values.len())), + ) + .unwrap() + }))? + .into_array(list_values.len())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return exec_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in the case the lambda don't 
specify, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn bind_lambda_variables( + expr: Arc, + params: &[FieldRef], + args: &[&dyn Fn() -> Result], +) -> Result> { + let columns = std::iter::zip(params, args) + .map(|(param, arg)| Ok((param.name().as_str(), (arg()?, 0)))) + .collect::>>()?; + + expr.rewrite(&mut BindLambdaVariable::new(columns)).data() +} + +struct BindLambdaVariable<'a> { + columns: HashMap<&'a str, (ArrayRef, usize)>, +} + +impl<'a> BindLambdaVariable<'a> { + fn new(columns: HashMap<&'a str, (ArrayRef, usize)>) -> Self { + Self { columns } + } +} + +impl TreeNodeRewriter for BindLambdaVariable<'_> { + type Node = Arc; + + fn f_down(&mut self, node: Self::Node) -> Result> { + if let Some(lambda_variable) = node.as_any().downcast_ref::() { + if let Some((value, shadows)) = self.columns.get(lambda_variable.name()) { + if *shadows == 0 { + return Ok(Transformed::yes(Arc::new( + lambda_variable.clone().with_value(Arc::clone(value)), + ))); + } + } + } else if let Some(inner_lambda) = node.as_any().downcast_ref::() { + for param in inner_lambda.params() { + if let Some((_value, shadows)) = self.columns.get_mut(param.as_str()) { + *shadows += 1; + } + } + + if self.columns.values().all(|(_value, shadows)| *shadows > 0) { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); + } + } + + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + if let Some(inner_lambda) = node.as_any().downcast_ref::() { + for param in inner_lambda.params() { + if let Some((_value, shadows)) = 
self.columns.get_mut(param.as_str()) { + *shadows -= 1; + } + } + } + + Ok(Transformed::no(node)) + } +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 3a66e65694768..c93a55cce1a4f 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -37,6 +37,7 @@ pub mod macros; pub mod array_has; +pub mod array_transform; pub mod cardinality; pub mod concat; pub mod dimension; @@ -78,6 +79,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + //pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -145,6 +147,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), + //array_transform::array_transform_udf(), empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index a71e2e87388d5..7f93500b9cfb9 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -248,6 +248,7 @@ mod tests { .iter() .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { args, arg_fields, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4fb0f8553b4ba..8c6c42e7e630c 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, 
Sort, WindowFunction, + InSubquery, LambdaFunction, Like, ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_u use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, - ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprSchemable, Join, LambdaTypeSignature, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, ValueOrLambdaParameter, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -582,6 +583,62 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }); Ok(Transformed::yes(new_expr)) } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.signature().type_signature { + LambdaTypeSignature::UserDefined => { + let args_types = args + .iter() + .map(|arg| match arg { + Expr::Lambda(_) => Ok(ValueOrLambdaParameter::Lambda), + _ => Ok(ValueOrLambdaParameter::Value( + arg.get_type(self.schema)?, + )), + }) + .collect::>>()?; + + let value_types = func.coerce_value_types(&args_types)?; + + if args_types.iter().eq_by(&value_types, |a, b| match (a, b) { + (ValueOrLambdaParameter::Value(_type), None) => false, + (ValueOrLambdaParameter::Value(from), Some(to)) => from == to, + (ValueOrLambdaParameter::Lambda, None) => true, + (ValueOrLambdaParameter::Lambda, Some(_ty)) => false, + }) { + return Ok(Transformed::no(Expr::LambdaFunction( + LambdaFunction::new(func, args), + ))); + } + + let args = std::iter::zip(args, value_types) + .map(|(arg, ty)| match (&arg, ty) { + (Expr::Lambda(_), None) => Ok(arg), + (Expr::Lambda(_), Some(_ty)) => plan_err!("{} coerce_value_types returned 
Some for a lambda argument", func.name()), + (_, Some(ty)) => arg.cast_to(&ty, self.schema), + (_, None) => plan_err!("{} coerce_value_types returned None for a value argument", func.name()), + }) + .collect::>>()?; + + Ok(Transformed::yes(Expr::LambdaFunction(LambdaFunction::new( + func, args, + )))) + } + LambdaTypeSignature::VariadicAny => Ok(Transformed::no( + Expr::LambdaFunction(LambdaFunction::new(func, args)), + )), + LambdaTypeSignature::Any(number) => { + if args.len() != number { + return plan_err!( + "The function '{}' expected {number} arguments but received {}", + func.name(), args.len() + ); + } + + Ok(Transformed::no(Expr::LambdaFunction(LambdaFunction::new( + func, args, + )))) + } + } + } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Alias(_) @@ -597,7 +654,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::OuterReferenceColumn(_, _) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } } } @@ -793,9 +852,11 @@ fn coerce_arguments_for_signature_with_scalar_udf( return Ok(expressions); } - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) + let current_types = expressions.iter() + .map(|e| match e { + Expr::Lambda { .. } => Ok(DataType::Null), + _ => e.get_type(schema), + }) .collect::>>()?; let new_types = data_types_with_scalar_udf(¤t_types, func)?; @@ -803,7 +864,10 @@ fn coerce_arguments_for_signature_with_scalar_udf( expressions .into_iter() .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .map(|(i, expr)| match expr { + lambda @ Expr::Lambda { .. 
} => Ok(lambda), + _ => expr.cast_to(&new_types[i], schema), + }) .collect() } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2510068494591..74e77011b71ed 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -694,7 +694,7 @@ impl CSEController for ExprCSEController<'_> { } fn is_valid(node: &Expr) -> bool { - !node.is_volatile_node() + !node.is_volatile_node() && !matches!(node, Expr::LambdaVariable(_)) } fn is_ignored(&self, node: &Expr) -> bool { @@ -707,6 +707,7 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. } + | Expr::LambdaVariable(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 1c0790b3e3acd..63314b48facfd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -287,7 +287,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::InList { .. } - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), + | Expr::ScalarFunction(_) + | Expr::LambdaFunction(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 05b8c28fadd6c..8e09fbbae48d8 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,15 +17,15 @@ //! 
Expression simplification API +use std::collections::HashSet; +use std::ops::Not; +use std::{borrow::Cow, sync::Arc}; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use std::borrow::Cow; -use std::collections::HashSet; -use std::ops::Not; -use std::sync::Arc; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, @@ -33,8 +33,10 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ - exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, + exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, + ScalarValue, }; +use datafusion_expr::expr::LambdaFunction; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -466,9 +468,11 @@ impl TreeNodeRewriter for Canonicalizer { }; match (left.as_ref(), right.as_ref(), op.swap()) { // - (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) - if right_col > left_col => - { + ( + left_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), + right_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), + Some(swapped_op), + ) if right_col > left_col => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -476,13 +480,15 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: right, - op: swapped_op, - right: left, - }))) - } + ( + Expr::Literal(_, _), + Expr::Column(_) | Expr::LambdaVariable(_), + Some(swapped_op), + ) => Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: right, + op: swapped_op, + right: left, + }))), _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, @@ -649,10 +655,14 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. 
} | Expr::GroupingSet(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => false, + | Expr::Placeholder(_) + | Expr::LambdaVariable(_) => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::LambdaFunction(LambdaFunction { func, .. }) => { + Self::volatility_ok(func.signature().volatility) + } Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) @@ -673,7 +683,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::Case(_) | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::InList { .. } => true, + | Expr::InList { .. } + | Expr::Lambda(_) => true, } } @@ -2005,8 +2016,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -2021,16 +2032,20 @@ fn as_inlist(expr: &'_ Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { - expr: left.clone(), - list: vec![*right.clone()], - negated: false, - })), - (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { - expr: right.clone(), - list: vec![*left.clone()], - negated: false, - })), + (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { + Some(Cow::Owned(InList { + expr: left.clone(), + list: vec![*right.clone()], + negated: false, + })) + } + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { + Some(Cow::Owned(InList { + expr: right.clone(), + list: vec![*left.clone()], + negated: false, + })) + } _ => 
None, } } @@ -2046,16 +2061,20 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { - expr: left, - list: vec![*right], - negated: false, - }), - (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { - expr: right, - list: vec![*left], - negated: false, - }), + (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { + Some(InList { + expr: left, + list: vec![*right], + negated: false, + }) + } + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { + Some(InList { + expr: right, + list: vec![*left], + negated: false, + }) + } _ => None, }, _ => None, diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c0f48b8ebfc40..11a656f2abb4c 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -728,6 +728,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 61cc97dae300e..d1957ae1892ea 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -21,9 +21,10 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{ exec_err, - tree_node::{Transformed, TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult}, Result, ScalarValue, }; use datafusion_functions::core::getfield::GetFieldFunc; diff --git a/datafusion/physical-expr/src/expressions/column.rs 
b/datafusion/physical-expr/src/expressions/column.rs index 9ca464b304306..7040fa2bfc9b4 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -188,6 +188,7 @@ pub fn with_new_schema( ); }; let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) } else { Ok(Transformed::no(expr)) diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs new file mode 100644 index 0000000000000..38b64e3c7f3e1 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Physical lambda expression: [`LambdaExpr`] + +use std::hash::Hash; +use std::sync::Arc; +use std::{any::Any, sync::OnceLock}; + +use crate::expressions::Column; +use crate::physical_expr::PhysicalExpr; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_err, HashSet, Result}; +use datafusion_expr::ColumnarValue; + +/// Represents a lambda with the given parameters names and body +#[derive(Debug, Eq, Clone)] +pub struct LambdaExpr { + params: Vec, + body: Arc, + captures: OnceLock>, +} + +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] +impl PartialEq for LambdaExpr { + fn eq(&self, other: &Self) -> bool { + self.params.eq(&other.params) && self.body.eq(&other.body) + } +} + +impl Hash for LambdaExpr { + fn hash(&self, state: &mut H) { + self.params.hash(state); + self.body.hash(state); + } +} + +impl LambdaExpr { + /// Create a new lambda expression with the given parameters and body + pub fn new(params: Vec, body: Arc) -> Self { + Self { + params, + body, + captures: OnceLock::new(), + } + } + + /// Get the lambda's params names + pub fn params(&self) -> &[String] { + &self.params + } + + /// Get the lambda's body + pub fn body(&self) -> &Arc { + &self.body + } + + pub fn captures(&self) -> &HashSet { + self.captures.get_or_init(|| { + let mut indices = HashSet::new(); + + self.body + .apply(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + indices.insert(column.index()); + } + + Ok(TreeNodeRecursion::Continue) + }) + .expect("closure should be infallibe"); + + indices + }) + } +} + +impl std::fmt::Display for LambdaExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} + +impl PhysicalExpr for LambdaExpr { + fn as_any(&self) -> &dyn 
Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.body.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.body.nullable(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + internal_err!("Lambda::evaluate() should not be called") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self { + params: self.params.clone(), + body: Arc::clone(&children[0]), + captures: OnceLock::new(), + })) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs new file mode 100644 index 0000000000000..305774c3c02da --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Physical lambda column reference: [`LambdaVariable`] + +use std::any::Any; +use std::hash::Hash; +use std::sync::Arc; + +use crate::physical_expr::PhysicalExpr; +use arrow::array::ArrayRef; +use arrow::datatypes::FieldRef; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::{Result, exec_datafusion_err}; +use datafusion_expr::ColumnarValue; + +/// Represents the lambda column with a given name and field +#[derive(Debug, Clone)] +pub struct LambdaVariable { + name: String, + field: FieldRef, + value: Option, +} + +impl Eq for LambdaVariable {} + +impl PartialEq for LambdaVariable { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.field == other.field + } +} + +impl Hash for LambdaVariable { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.field.hash(state); + } +} + +impl LambdaVariable { + /// Create a new lambda column expression + pub fn new(name: &str, field: FieldRef) -> Self { + Self { + name: name.to_owned(), + field, + value: None, + } + } + + /// Get the column's name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the column's field + pub fn field(&self) -> &FieldRef { + &self.field + } + + pub fn with_value(self, value: ArrayRef) -> Self { + Self { + name: self.name, + field: self.field, + value: Some(ColumnarValue::Array(value)), + } + } +} + +impl std::fmt::Display for LambdaVariable { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}@-1", self.name) + } +} + +impl PhysicalExpr for LambdaVariable { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the data type of this expression, given the schema of the input + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.field.data_type().clone()) + } + + /// Decide whether this expression is nullable, given the schema of the input + fn nullable(&self, _input_schema: &Schema) -> Result { + 
Ok(self.field.is_nullable()) + } + + /// Evaluate the expression + fn evaluate(&self, _batch: &RecordBatch) -> Result { + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} missing value", self.name)) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +/// Create a lambda variable expression +pub fn lambda_variable(name: &str, field: FieldRef) -> Result> { + Ok(Arc::new(LambdaVariable::new(name, field))) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d985..990e53fa23b2c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,10 +23,12 @@ mod case; mod cast; mod cast_column; mod column; +mod lambda_variable; mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; +mod lambda; mod like; mod literal; mod negative; @@ -44,11 +46,13 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use cast_column::CastColumnExpr; pub use column::{col, with_new_schema, Column}; +pub use lambda_variable::{lambda_variable, LambdaVariable}; pub use datafusion_expr::utils::format_state_name; pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; +pub use lambda::LambdaExpr; pub use like::{like, LikeExpr}; pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs new file mode 100644 index 0000000000000..97af1f9b13891 --- /dev/null 
+++ b/datafusion/physical-expr/src/lambda_function.rs @@ -0,0 +1,526 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (lambda) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. 
+ +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::expressions::{LambdaExpr, Literal}; +use crate::PhysicalExpr; + +use arrow::array::{Array, NullArray, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion_common::config::{ConfigEntry, ConfigOptions}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::{ + expr_vec_fmt, ColumnarValue, LambdaFunctionArgs, LambdaFunctionLambdaArg, + LambdaReturnFieldArgs, LambdaUDF, ValueOrLambda, ValueOrLambdaField, + ValueOrLambdaParameter, Volatility, +}; + +/// Physical expression of a lambda function +pub struct LambdaFunctionExpr { + fun: Arc, + name: String, + args: Vec>, + return_field: FieldRef, + config_options: Arc, +} + +impl Debug for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("LambdaFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.args) + .field("return_field", &self.return_field) + .finish() + } +} + +impl LambdaFunctionExpr { + /// Create a new Lambda function + pub fn new( + name: &str, + fun: Arc, + args: Vec>, + return_field: FieldRef, + config_options: Arc, + ) -> Self { + Self { + fun, + name: name.to_owned(), + args, + return_field, + config_options, + } + } + + /// Create a new Lambda function + pub fn try_new( + fun: Arc, + args: Vec>, + schema: &Schema, + config_options: Arc, + ) -> Result { + let name = fun.name().to_string(); + let arg_fields = args + .iter() + .map(|e| { + let field = e.return_field(schema)?; + match e.as_any().downcast_ref::() { + Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), + None => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + // TODO: verify that input data types is consistent with function's 
`TypeSignature` + + let arguments = args + .iter() + .map(|e| { + e.as_any() + .downcast_ref::() + .map(|literal| literal.value()) + }) + .collect::>(); + + let ret_args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + let return_field = fun.return_field_from_args(ret_args)?; + + Ok(Self { + fun, + name, + args, + return_field, + config_options, + }) + } + + /// Get the lambda function implementation + pub fn fun(&self) -> &dyn LambdaUDF { + self.fun.as_ref() + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } + + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); + self + } + + pub fn nullable(&self) -> bool { + self.return_field.is_nullable() + } + + pub fn config_options(&self) -> &ConfigOptions { + &self.config_options + } + + /// Given an arbitrary PhysicalExpr attempt to downcast it to a LambdaFunctionExpr + /// and verify that its inner function is of type T. + /// If the downcast fails, or the function is not of type T, returns `None`. + /// Otherwise returns `Some(LambdaFunctionExpr)`. 
+ pub fn try_downcast_func(expr: &dyn PhysicalExpr) -> Option<&LambdaFunctionExpr> + where + T: 'static, + { + match expr.as_any().downcast_ref::() { + Some(lambda_expr) + if lambda_expr.fun().as_any().downcast_ref::().is_some() => + { + Some(lambda_expr) + } + _ => None, + } + } +} + +impl fmt::Display for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) + } +} + +impl PartialEq for LambdaFunctionExpr { + fn eq(&self, o: &Self) -> bool { + if std::ptr::eq(self, o) { + // The equality implementation is somewhat expensive, so let's short-circuit when possible. + return true; + } + let Self { + fun, + name, + args, + return_field, + config_options, + } = self; + fun.eq(&o.fun) + && name.eq(&o.name) + && args.eq(&o.args) + && return_field.eq(&o.return_field) + && (Arc::ptr_eq(config_options, &o.config_options) + || sorted_config_entries(config_options) + == sorted_config_entries(&o.config_options)) + } +} +impl Eq for LambdaFunctionExpr {} +impl Hash for LambdaFunctionExpr { + fn hash(&self, state: &mut H) { + let Self { + fun, + name, + args, + return_field, + config_options: _, // expensive to hash, and often equal + } = self; + fun.hash(state); + name.hash(state); + args.hash(state); + return_field.hash(state); + } +} + +fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { + let mut entries = config_options.entries(); + entries.sort_by(|l, r| l.key.cmp(&r.key)); + entries +} + +impl PhysicalExpr for LambdaFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg_fields = self + .args + .iter() + .map(|e| { + let field = 
e.return_field(batch.schema_ref())?;
+                match e.as_any().downcast_ref::<LambdaExpr>() {
+                    Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)),
+                    None => Ok(ValueOrLambdaField::Value(field)),
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        // Tell the UDF which arguments are lambdas (without evaluating them)
+        // so it can report the parameter fields each lambda should receive
+        let args_metadata = arg_fields
+            .iter()
+            .map(|field| match field {
+                ValueOrLambdaField::Value(field) => {
+                    ValueOrLambdaParameter::Value(Arc::clone(field))
+                }
+                ValueOrLambdaField::Lambda(_field) => ValueOrLambdaParameter::Lambda,
+            })
+            .collect::<Vec<_>>();
+
+        let params = self.fun().lambdas_parameters(&args_metadata)?;
+
+        let args = std::iter::zip(&self.args, params)
+            .map(|(arg, lambda_params)| {
+                match (arg.as_any().downcast_ref::<LambdaExpr>(), lambda_params) {
+                    (Some(lambda), Some(lambda_params)) => {
+                        if lambda.params().len() > lambda_params.len() {
+                            return exec_err!(
+                                "lambda defined {} params but UDF supports only {}",
+                                lambda.params().len(),
+                                lambda_params.len()
+                            );
+                        }
+
+                        let captures = lambda.captures();
+
+                        let params = std::iter::zip(lambda.params(), lambda_params)
+                            .map(|(name, param)| Arc::new(param.with_name(name)))
+                            .collect();
+
+                        // Build a capture batch holding only the columns the
+                        // lambda body references; every other column is
+                        // replaced by a cheap all-null placeholder
+                        let captures = if !captures.is_empty() {
+                            let (fields, columns): (Vec<_>, _) = std::iter::zip(
+                                batch.schema_ref().fields(),
+                                batch.columns(),
+                            )
+                            .enumerate()
+                            .map(|(column_index, (field, column))| {
+                                if captures.contains(&column_index) {
+                                    (Arc::clone(field), Arc::clone(column))
+                                } else {
+                                    (
+                                        Arc::new(Field::new(
+                                            field.name(),
+                                            DataType::Null,
+                                            false,
+                                        )),
+                                        Arc::new(NullArray::new(column.len())) as _,
+                                    )
+                                }
+                            })
+                            .unzip();
+
+                            let schema = Arc::new(Schema::new(fields));
+
+                            Some(RecordBatch::try_new(schema, columns)?)
+                        } else {
+                            None
+                        };
+
+                        Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg {
+                            params,
+                            body: Arc::clone(lambda.body()),
+                            captures,
+                        }))
+                    }
+                    (Some(_lambda), None) => exec_err!(
+                        "{} didn't report the parameters of one of its lambdas",
+                        self.fun.name()
+                    ),
+                    (None, Some(_lambda_params)) => exec_err!(
+                        "{} reported parameters for an argument that is not a lambda",
+                        self.fun.name()
+                    ),
+                    // Plain value argument: evaluate it eagerly as usual
+                    (None, None) => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)),
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let input_empty = args.is_empty();
+        let input_all_scalar = args
+            .iter()
+            .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_))));
+
+        // evaluate the function
+        let output = self.fun.invoke_with_args(LambdaFunctionArgs {
+            args,
+            arg_fields,
+            number_rows: batch.num_rows(),
+            return_field: Arc::clone(&self.return_field),
+            config_options: Arc::clone(&self.config_options),
+        })?;
+
+        if let ColumnarValue::Array(array) = &output {
+            if array.len() != batch.num_rows() {
+                // If the arguments are a non-empty slice of scalar values, we can assume that
+                // returning a one-element array is equivalent to returning a scalar.
+                let preserve_scalar =
+                    array.len() == 1 && !input_empty && input_all_scalar;
+                return if preserve_scalar {
+                    ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
+                } else {
+                    internal_err!("UDF {} returned a different number of rows than expected.
Expected: {}, Got: {}", + self.name, batch.num_rows(), array.len()) + }; + } + } + Ok(output) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + + fn children(&self) -> Vec<&Arc> { + self.args.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(LambdaFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + Arc::clone(&self.config_options), + ))) + } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + self.fun.evaluate_bounds(children) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + self.fun.propagate_constraints(interval, children) + } + + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let sort_properties = self.fun.output_ordering(children)?; + let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?; + let children_range = children + .iter() + .map(|props| &props.range) + .collect::>(); + let range = self.fun().evaluate_bounds(&children_range)?; + + Ok(ExprProperties { + sort_properties, + range, + preserves_lex_ordering, + }) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.name)?; + for (i, expr) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + expr.fmt_sql(f)?; + } + write!(f, ")") + } + + fn is_volatile_node(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +#[cfg(test)] +mod tests { + use std::any::Any; + use std::sync::Arc; + + use super::*; + use crate::expressions::Column; + use crate::LambdaFunctionExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, LambdaSignature}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use 
datafusion_physical_expr_common::physical_expr::is_volatile; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + /// Test helper to create a mock UDF with a specific volatility + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + signature: LambdaSignature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mock_function" + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn return_field_from_args( + &self, + _args: LambdaReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new("", DataType::Int32, false))) + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42)))) + } + } + + #[test] + fn test_lambda_function_volatile_node() { + // Create a volatile UDF + let volatile_udf = Arc::new(MockLambdaUDF { + signature: LambdaSignature::variadic_any(Volatility::Volatile), + }); + + // Create a non-volatile UDF + let stable_udf = Arc::new(MockLambdaUDF { + signature: LambdaSignature::variadic_any(Volatility::Stable), + }); + + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + let config_options = Arc::new(ConfigOptions::new()); + + // Test volatile function + let volatile_expr = LambdaFunctionExpr::try_new( + volatile_udf, + args.clone(), + &schema, + Arc::clone(&config_options), + ) + .unwrap(); + + assert!(volatile_expr.is_volatile_node()); + let volatile_arc: Arc = Arc::new(volatile_expr); + assert!(is_volatile(&volatile_arc)); + + // Test non-volatile function + let stable_expr = + LambdaFunctionExpr::try_new(stable_udf, args, &schema, config_options) + .unwrap(); + + assert!(!stable_expr.is_volatile_node()); + let stable_arc: Arc = Arc::new(stable_expr); + assert!(!is_volatile(&stable_arc)); + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 
aa8c9e50fd71e..a05d24d2ba2c5 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -31,6 +31,7 @@ pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } pub mod async_scalar_function; +pub mod lambda_function; pub mod equivalence; pub mod expressions; pub mod intervals; @@ -70,6 +71,7 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; +pub use lambda_function::LambdaFunctionExpr; pub use simplifier::PhysicalExprSimplifier; pub use utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7790380dffd56..f7be4aedf555e 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,8 @@ use std::sync::Arc; -use crate::ScalarFunctionExpr; +use crate::expressions::{lambda_variable, LambdaExpr}; +use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, @@ -27,10 +28,13 @@ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, + exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, + ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{ + Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, ScalarFunction +}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -317,6 +321,7 @@ pub fn create_physical_expr( Expr::ScalarFunction(ScalarFunction { 
func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; + let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), None => Arc::new(ConfigOptions::default()), @@ -383,6 +388,34 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + let config_options = match execution_props.config_options.as_ref() { + Some(config_options) => Arc::clone(config_options), + None => Arc::new(ConfigOptions::default()), + }; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options, + )?)) + } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => lambda_variable( + name, + Arc::clone(field), + ), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 743d5b99cde95..6ad22671ba847 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -105,6 +105,7 @@ impl ScalarFunctionExpr { .iter() .map(|f| f.data_type().clone()) .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args @@ -115,11 +116,14 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); + let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, }; + let return_field = fun.return_field_from_args(ret_args)?; + Ok(Self { fun, name, @@ -283,6 +287,7 @@ impl 
PhysicalExpr for ScalarFunctionExpr { config_options: Arc::clone(&self.config_options), })?; + if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { // If the arguments are a non-empty slice of scalar values, we can assume that @@ -367,12 +372,19 @@ impl PhysicalExpr for ScalarFunctionExpr { #[cfg(test)] mod tests { + use std::any::Any; + use std::sync::Arc; + use super::*; use crate::expressions::Column; + use crate::ScalarFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::{ScalarFunctionArgs, Volatility}; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; + use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; - use std::any::Any; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 6eab2239015a7..e9060c0f2c986 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,11 +24,11 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; -use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, - WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LambdaSignature, LogicalPlan, + Volatility, WindowUDF, }; use prost::{ bytes::{Bytes, BytesMut}, @@ -167,6 +167,15 @@ impl Serializeable for Expr { ) } + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + datafusion_common::internal_err!( + "register_udlf called in Placeholder Registry!" 
+ ) + } + fn expr_planners(&self) -> Vec> { vec![] } @@ -178,6 +187,51 @@ impl Serializeable for Expr { fn udwfs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } + + fn udlfs(&self) -> std::collections::HashSet { + std::collections::HashSet::default() + } + + fn udlf(&self, name: &str) -> Result> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + name: String, + signature: LambdaSignature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &LambdaSignature { + &self.signature + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + } + + Ok(Arc::new(MockLambdaUDF { + name: name.to_string(), + signature: LambdaSignature::variadic_any(Volatility::Immutable), + })) + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 087e073db21af..98f4928457679 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -67,4 +67,14 @@ impl FunctionRegistry for NoRegistry { fn udwfs(&self) -> HashSet { HashSet::new() } + + fn udlfs(&self) -> HashSet { + HashSet::new() + } + + fn udlf(&self, name: &str) -> Result> { + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'") + } + + } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9644c9f69feae..1122952771c79 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ Statement, WindowUDF, }; use 
datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, LambdaUDF, RecursiveQuery, SkipType, TableSource, Unnest }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -153,6 +153,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udlf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for lambda function {name}") + } + + fn try_encode_udlf(&self, _node: &dyn LambdaUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..e080411b49e95 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,6 +622,22 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::LambdaFunction(func) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udlf(func.func.as_ref(), &mut buf); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name: func.name().to_string(), + fun_definition: (!buf.is_empty()).then_some(buf), + args: serialize_exprs(&func.args, codec)?, + })), + } + } + Expr::Lambda(_) | Expr::LambdaVariable(_) => { + return Err(Error::General( + "Proto serialization error: Lambda not implemented".to_string(), + )) + } }; Ok(expr_node) diff --git a/datafusion/session/src/session.rs b/datafusion/session/src/session.rs index fd033172f224f..625b3fb77a4d8 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -22,7 +22,7 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; 
use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; use parking_lot::{Mutex, RwLock}; use std::any::Any; @@ -109,6 +109,9 @@ pub trait Session: Send + Sync { /// Return reference to scalar_functions fn scalar_functions(&self) -> &HashMap>; + + /// Return reference to lambda_functions + fn lambda_functions(&self) -> &HashMap>; /// Return reference to aggregate_functions fn aggregate_functions(&self) -> &HashMap>; @@ -149,6 +152,7 @@ impl From<&dyn Session> for TaskContext { state.session_id().to_string(), state.config().clone(), state.scalar_functions().clone(), + state.lambda_functions().clone(), state.aggregate_functions().clone(), state.window_functions().clone(), Arc::clone(state.runtime_env()), diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index e272d91d8a70e..b939dabda388d 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -60,7 +60,7 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &arg_fields, - scalar_arguments: &scalar_arguments_refs + scalar_arguments: &scalar_arguments_refs, }); match expected { diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 2c0bb86cd8087..3d2ff0528081c 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result, TableReference}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::WindowUDF; +use datafusion_expr::{LambdaUDF, WindowUDF}; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, 
TableSource, }; @@ -138,6 +138,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 50e479af36204..47f132d065980 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::datatypes::DataType; @@ -22,11 +24,14 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Diagnostic, Result, Span, }; -use datafusion_expr::expr::{ - NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction, +use datafusion_expr::expr::{Lambda, LambdaFunction, ScalarFunction, Unnest}; +use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; +use datafusion_expr::{ + expr, Expr, ExprSchemable, ValueOrLambdaParameter, WindowFrame, + WindowFunctionDefinition, }; -use datafusion_expr::planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}; -use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition}; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, @@ -272,10 +277,145 @@ impl SqlToRel<'_, S> { } } } + + if let Some(fm) = self.context_provider.get_lambda_meta(&name) { + enum ExprOrLambda { + ExprWithName((Expr, Option)), + Lambda(sqlparser::ast::LambdaFunction), + } + + let pairs = args + .into_iter() + .map(|a| match a { + 
FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( + lambda, + ))) => Ok(ExprOrLambda::Lambda(lambda)), + _ => Ok(ExprOrLambda::ExprWithName( + self.sql_fn_arg_to_logical_expr_with_name( + a, + schema, + planner_context, + )?, + )), + }) + .collect::>>()?; + + let metadata = pairs + .iter() + .map(|e| match e { + ExprOrLambda::ExprWithName((expr, _name)) => { + Ok(ValueOrLambdaParameter::Value(expr.to_field(schema)?.1)) + } + ExprOrLambda::Lambda(_lambda_function) => { + Ok(ValueOrLambdaParameter::Lambda) + } + }) + .collect::>>()?; + + let lambdas_parameters = fm.lambdas_parameters(&metadata)?; + + let pairs = pairs + .into_iter() + .zip(lambdas_parameters) + .map(|(e, lambda_parameters)| match (e, lambda_parameters) { + (ExprOrLambda::ExprWithName(expr_with_name), None) => { + Ok(expr_with_name) + } + (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { + if lambda.params.len() > lambda_params.len() { + return plan_err!( + "lambda defined {} params but UDF support only {}", + lambda.params.len(), + lambda_params.len() + ); + } + + let params = + lambda.params.iter().map(|p| p.value.clone()).collect(); + + let lambda_parameters = lambda_params + .into_iter() + .zip(¶ms) + .map(|(f, n)| Arc::new(f.with_name(n))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(( + Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + }), + None, + )) + } + (ExprOrLambda::ExprWithName(_), Some(_)) => plan_err!( + "{} reported parameters for an argument that is not a lambda", + fm.name() + ), + (ExprOrLambda::Lambda(_), None) => plan_err!( + "{} don't reported the parameters of one of it's lambdas", + fm.name() + ), + }) + .collect::>>()?; + + let (args, arg_names): (Vec, Vec>) = + pairs.into_iter().unzip(); + + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + if let Some(param_names) = 
&fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? + } else { + return plan_err!( + "Function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + + // After resolution, all arguments are positional + let inner = LambdaFunction::new(fm, resolved_args); + + if name.eq_ignore_ascii_case(inner.name()) { + return Ok(Expr::LambdaFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. + let arg_names = inner + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::LambdaFunction(inner).alias(verbose_alias)); + } + } + // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let (args, arg_names) = - self.function_args_to_expr_with_names(args, schema, planner_context)?; + let (args, arg_names): (Vec, Vec>) = args + .into_iter() + .map(|a| { + self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) + }) + .collect::>>()? 
+ .into_iter() + .unzip(); let resolved_args = if arg_names.iter().any(|name| name.is_some()) { if let Some(param_names) = &fm.signature().parameter_names { diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 3c57d195ade67..14433e9cf7eba 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -20,6 +20,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, }; +use datafusion_expr::expr::LambdaVariable; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; @@ -53,6 +54,19 @@ impl SqlToRel<'_, S> { // identifier. (e.g. it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); + if let Some(field) = planner_context + .lambdas_parameters() + .get(&normalize_ident) + { + let mut lambda_var = LambdaVariable::new(normalize_ident, Arc::clone(field)); + if self.options.collect_spans { + if let Some(span) = Span::try_from_sqlparser_span(id_span) { + lambda_var.spans_mut().add_span(span); + } + } + return Ok(Expr::LambdaVariable(lambda_var)); + } + // Check for qualified field with unqualified name if let Ok((qualifier, _)) = schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 715a02db8b027..e51f1c04cf157 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1207,7 +1207,7 @@ mod tests { use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_expr::logical_plan::builder::LogicalTableSource; - use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use super::*; @@ -1247,6 +1247,10 @@ mod tests { None } + fn 
get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { match name { "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 7bac0337672dc..8cc7747ffe16b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -26,11 +26,10 @@ use arrow::datatypes::*; use datafusion_common::config::SqlParserOptions; use datafusion_common::datatype::{DataTypeExt, FieldExt}; use datafusion_common::error::add_possible_columns_to_diag; -use datafusion_common::TableReference; use datafusion_common::{ - field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic, - SchemaError, + field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, SchemaError, }; +use datafusion_common::{internal_err, TableReference}; use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; pub use datafusion_expr::planner::ContextProvider; @@ -267,6 +266,8 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, + /// The parameters of all lambdas seen so far + lambdas_parameters: HashMap, } impl Default for PlannerContext { @@ -284,6 +285,7 @@ impl PlannerContext { outer_query_schema: None, outer_from_schema: None, create_table_schema: None, + lambdas_parameters: HashMap::new(), } } @@ -370,6 +372,20 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } + pub fn lambdas_parameters(&self) -> &HashMap { + &self.lambdas_parameters + } + + pub fn with_lambda_parameters( + mut self, + arguments: impl IntoIterator, + ) -> Self { + self.lambdas_parameters + .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); + + self + } + /// Remove the plan of CTE / Subquery for the specified name pub(super) fn remove_cte(&mut self, cte_name: 
&str) { self.ctes.remove(cte_name); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 97f2b58bf8402..3bd669dbec071 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; +use datafusion_expr::expr::{AggregateFunctionParams, LambdaFunction, WindowFunctionParams}; +use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, - Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, - ValueWithSpan, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, + ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; +use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; use std::vec; @@ -527,6 +528,29 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let func_name = func.name(); + + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? 
+ { + return Ok(expr); + } + + self.scalar_function_to_sql(func_name, args) + } + Expr::Lambda(Lambda { params, body }) => { + Ok(ast::Expr::Lambda(ast::LambdaFunction { + params: ast::OneOrManyWithParens::Many( + params.iter().map(|param| param.as_str().into()).collect(), + ), + body: Box::new(self.expr_to_sql_inner(body)?), + })) + } + Expr::LambdaVariable(l) => Ok(ast::Expr::Identifier( + self.new_ident_quoted_if_needs(l.name.clone()), + )), } } diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 5d9fd9f2c3740..6c9ac4bf70046 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -26,7 +26,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; -use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; @@ -53,6 +53,7 @@ impl Display for MockCsvType { #[derive(Default)] pub(crate) struct MockSessionState { scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, type_planner: Option>, @@ -240,6 +241,10 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions.get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions.get(name).cloned() } diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt new file mode 100644 index 0000000000000..af5334a644421 --- /dev/null +++ 
b/datafusion/sqllogictest/test_files/lambda.slt @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Lambda Expressions Tests +############# + +statement ok +set datafusion.sql_parser.dialect = databricks; + +statement ok +CREATE TABLE tt +AS VALUES +([1, 50], 10), +([4, 50], 40); + +statement ok +CREATE TABLE t AS SELECT 1 as f, [ [ [2, 3], [2] ], [ [1] ], [ [] ] ] as v, 1 as n; + +query I? +SELECT t.n, array_transform([], e1 -> t.n) from t; +---- +1 [] + +query ? +SELECT array_transform([1], e1 -> (select n from t)); +---- +[1] + +query ? +SELECT array_transform(t.v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; +---- +[[[0, 1], [0]], [[0]], [[]]] + +query I? +SELECT t.n, array_transform([1, 2], (e) -> n) from t; +---- +1 [1, 1] + +# selection pushdown not working yet +query ? +SELECT array_transform([1, 2], (e) -> n) from t; +---- +[1, 1] + +query ? +SELECT array_transform([1, 2], (e, i) -> i) from t; +---- +[0, 1] + +# type coercion +query ?
+SELECT array_transform([1, 2], (e, i) -> e+i) from t; +---- +[1, 3] + +query TT +EXPLAIN SELECT array_transform([1, 2], (e, i) -> e+i); +---- +logical_plan +01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@-1 + CAST(i@-1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +02)--PlaceholderRowExec + +#cse +query TT +explain select n + 1, array_transform([1], v -> v + n + 1) from t; +---- +logical_plan +01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) +02)--TableScan: t projection=[n] +physical_plan +01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@-1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + +query ? +SELECT array_transform([1,2,3,4,5], v -> 2); +---- +[2, 2, 2, 2, 2] + +query ? +SELECT array_transform([[1,2],[3,4,5]], v -> array_transform(v, v -> v*2)); +---- +[[2, 4], [6, 8, 10]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +---- +[a, aa, aaa, aaaa, aaaaa] + +query ? 
+SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +---- +[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] + +query TT +EXPLAIN SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +logical_plan +01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@-1 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +02)--PlaceholderRowExec + +query ? +SELECT array_transform( + [[1]], + v -> array_concat( + array_transform(v, v -> v), + array_transform(v, v1 -> v1 + v[0]) + ) +); +---- +[[1, NULL]] + +query I?? +SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; +---- +1 [[[2, 3], [2]], [[1]], [[]]] [[1, 1], [1], [1]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + + +# expr simplifier +query TT +EXPLAIN SELECT v = v, array_transform([1], v -> v = v) from t; +---- +logical_plan +01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) +02)--TableScan: t projection=[] +physical_plan +01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@-1 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + + +query error +select array_transform(); +---- +DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got [] + + +query error DataFusion error: Execution error: expected list, got Field \{ "Int64\(1\)": Int64 \} +select array_transform(1, v -> v*2); + +query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda, Value\(Field \{ 
name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)\] +select array_transform(v -> v*2, [1, 2]); + +query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 +SELECT array_transform([1, 2], (e, i, j) -> i) from t; + +#todo: this should error due to duplicate names +query ? +SELECT array_transform([1], (v, v) -> v*2); +---- +[0] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index f80cf43eb81eb..062c1ac03110c 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -30,6 +30,7 @@ pub async fn from_scalar_function( f: &ScalarFunction, input_schema: &DFSchema, ) -> Result { + //TODO: handle lambda functions, as they are also encoded as scalar functions let Some(fn_signature) = consumer .get_extensions() .functions diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index f4e43fd586773..d1112b99536d9 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -152,6 +152,9 @@ pub fn to_substrait_rex( not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs + Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs } } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs 
b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index abb26f6f66822..b2057a9d914f8 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -26,17 +26,34 @@ pub fn from_scalar_function( producer: &mut impl SubstraitProducer, fun: &expr::ScalarFunction, schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +pub fn from_lambda_function( + producer: &mut impl SubstraitProducer, + fun: &expr::LambdaFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +fn from_function( + producer: &mut impl SubstraitProducer, + name: &str, + args: &[Expr], + schema: &DFSchemaRef, ) -> datafusion::common::Result { let mut arguments: Vec = vec![]; - for arg in &fun.args { + for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } - let arguments = custom_argument_handler(fun.name(), arguments); + let arguments = custom_argument_handler(name, arguments); - let function_anchor = producer.register_function(fun.name().to_string()); + let function_anchor = producer.register_function(name.to_string()); #[allow(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index db08e0f7bfd0c..d065bcf41586a 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -17,12 +17,7 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ - from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, - from_case, from_cast, from_column, 
from_distinct, from_empty_relation, from_filter, - from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, from_in_list, from_in_subquery, from_join, from_lambda_function, from_like, from_limit, from_literal, from_projection, from_repartition, from_scalar_function, from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex }; use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; use datafusion::execution::registry::SerializerRegistry; @@ -327,6 +322,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_scalar_function(self, scalar_fn, schema) } + + fn handle_lambda_function( + &mut self, + scalar_fn: &expr::LambdaFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_lambda_function(self, scalar_fn, schema) + } fn handle_aggregate_function( &mut self,