diff --git a/datafusion/spark/src/function/math/ceil.rs b/datafusion/spark/src/function/math/ceil.rs
index 5096914a1eba8..a1c3cd58ad5f0 100644
--- a/datafusion/spark/src/function/math/ceil.rs
+++ b/datafusion/spark/src/function/math/ceil.rs
@@ -19,10 +19,13 @@ use std::sync::Arc;
 
 use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array};
 use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type};
-use datafusion_common::utils::take_function_args;
+use datafusion_common::types::{
+    NativeType, logical_float32, logical_float64, logical_int32,
+};
 use datafusion_common::{Result, ScalarValue, exec_err};
 use datafusion_expr::{
-    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+    TypeSignatureClass, Volatility,
 };
 
 /// Spark-compatible `ceil` expression
@@ -36,8 +39,6 @@ use datafusion_expr::{
 /// - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256
 /// - Spark does not check for decimal overflow; DataFusion errors on overflow
 ///
-/// 2-argument ceil(value, scale) is not yet implemented
-///
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SparkCeil {
     signature: Signature,
@@ -52,8 +53,50 @@ impl Default for SparkCeil {
 
 impl SparkCeil {
     pub fn new() -> Self {
+        let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
+        let integer = Coercion::new_exact(TypeSignatureClass::Integer);
+        let decimal_places = Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_int32()),
+            vec![TypeSignatureClass::Integer],
+            NativeType::Int32,
+        );
+        let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
+        let float64 = Coercion::new_implicit(
+            TypeSignatureClass::Native(logical_float64()),
+            vec![TypeSignatureClass::Numeric],
+            NativeType::Float64,
+        );
         Self {
-            signature: Signature::numeric(1, Volatility::Immutable),
+            signature: Signature::one_of(
+                vec![
+                    // ceil(decimal, scale)
+                    TypeSignature::Coercible(vec![
+                        decimal.clone(),
+                        decimal_places.clone(),
+                    ]),
+                    // ceil(decimal)
+                    TypeSignature::Coercible(vec![decimal]),
+                    // ceil(integer, scale)
+                    TypeSignature::Coercible(vec![
+                        integer.clone(),
+                        decimal_places.clone(),
+                    ]),
+                    // ceil(integer)
+                    TypeSignature::Coercible(vec![integer]),
+                    // ceil(float32, scale)
+                    TypeSignature::Coercible(vec![
+                        float32.clone(),
+                        decimal_places.clone(),
+                    ]),
+                    // ceil(float32)
+                    TypeSignature::Coercible(vec![float32]),
+                    // ceil(float64, scale)
+                    TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
+                    // ceil(float64)
+                    TypeSignature::Coercible(vec![float64]),
+                ],
+                Volatility::Immutable,
+            ),
             aliases: vec!["ceiling".to_string()],
         }
     }
@@ -69,7 +112,10 @@ impl ScalarUDFImpl for SparkCeil {
     }
 
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        let has_scale = arg_types.len() == 2;
+
         match &arg_types[0] {
+            DataType::Decimal128(p, s) if has_scale => Ok(DataType::Decimal128(*p, *s)),
             DataType::Decimal128(p, s) => {
                 if *s > 0 {
                     Ok(DataType::Decimal128(decimal128_ceil_precision(*p, *s), 0))
@@ -79,6 +125,8 @@ impl ScalarUDFImpl for SparkCeil {
                     Ok(DataType::Decimal128(*p, *s))
                 }
             }
+            DataType::Float32 if has_scale => Ok(DataType::Float32),
+            DataType::Float64 if has_scale => Ok(DataType::Float64),
             dt if matches!(dt, DataType::Float32 | DataType::Float64)
                 || dt.is_integer() =>
             {
@@ -97,12 +145,67 @@ impl ScalarUDFImpl for SparkCeil {
     }
 }
 
+/// Extract the scale (decimal places) from the second argument.
+/// Returns `Some(0)` if no second argument is provided.
+/// Returns `None` if the scale argument is NULL (Spark returns NULL for `ceil(expr, NULL)`).
+fn get_scale(args: &[ColumnarValue]) -> Result<Option<i32>> {
+    if args.len() < 2 {
+        return Ok(Some(0));
+    }
+
+    match &args[1] {
+        ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))),
+        ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))),
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)),
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => match i32::try_from(*v) {
+            Ok(scale) => Ok(Some(scale)),
+            Err(_) => exec_err!("ceil scale {v} is out of supported i32 range"),
+        },
+        ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))),
+        ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))),
+        ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => match i32::try_from(*v) {
+            Ok(scale) => Ok(Some(scale)),
+            Err(_) => exec_err!("ceil scale {v} is out of supported i32 range"),
+        },
+        ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => match i32::try_from(*v) {
+            Ok(scale) => Ok(Some(scale)),
+            Err(_) => exec_err!("ceil scale {v} is out of supported i32 range"),
+        },
+        ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None),
+        other => exec_err!("Unsupported type for ceil scale: {}", other.data_type()),
+    }
+}
+
 fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    let [input] = take_function_args("ceil", args)?;
+    if args.is_empty() || args.len() > 2 {
+        return exec_err!(
+            "ceil function requires 1 or 2 arguments, got {}",
+            args.len()
+        );
+    }
+
+    let scale = match get_scale(args)? {
+        Some(scale) => scale,
+        None => {
+            return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
+                args[0].data_type(),
+            )?));
+        }
+    };
+    let input = &args[0];
 
     match input {
-        ColumnarValue::Scalar(value) => spark_ceil_scalar(value),
-        ColumnarValue::Array(input) => spark_ceil_array(input),
+        ColumnarValue::Scalar(value) => spark_ceil_scalar(value, scale),
+        ColumnarValue::Array(input) => spark_ceil_array(input, scale),
     }
 }
 
@@ -121,11 +224,21 @@ fn decimal128_ceil_precision(precision: u8, scale: i8) -> u8 {
     ((precision as i64) - (scale as i64) + 1).clamp(1, 38) as u8
 }
 
-fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
+fn spark_ceil_scalar(value: &ScalarValue, scale: i32) -> Result<ColumnarValue> {
     let result = match value {
-        ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
-        ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
+        // Floats without scale (scale=0) -> Int64 (original behavior)
+        ScalarValue::Float32(v) if scale == 0 => {
+            ScalarValue::Int64(v.map(|x| x.ceil() as i64))
+        }
+        ScalarValue::Float64(v) if scale == 0 => {
+            ScalarValue::Int64(v.map(|x| x.ceil() as i64))
+        }
+        // Floats with scale -> preserve type
+        ScalarValue::Float32(v) => ScalarValue::Float32(v.map(|x| ceil_float(x, scale))),
+        ScalarValue::Float64(v) => ScalarValue::Float64(v.map(|x| ceil_float(x, scale))),
+        // Integers: negative scale rounds, positive is a no-op
        v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?,
+        // Decimal128 with positive scale
         ScalarValue::Decimal128(v, p, s) if *s > 0 => {
             let new_p = decimal128_ceil_precision(*p, *s);
             ScalarValue::Decimal128(v.map(|x| decimal128_ceil(x, *s as u32)), new_p, 0)
@@ -141,18 +254,31 @@ fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
     Ok(ColumnarValue::Scalar(result))
 }
 
-fn spark_ceil_array(input: &Arc<dyn Array>) -> Result<ColumnarValue> {
+fn spark_ceil_array(
+    input: &Arc<dyn Array>,
+    scale: i32,
+) -> Result<ColumnarValue> {
     let result = match input.data_type() {
-        DataType::Float32 => Arc::new(
+        DataType::Float32 if scale == 0 => Arc::new(
             input
                 .as_primitive::<Float32Type>()
                 .unary::<_, Int64Type>(|x| x.ceil() as i64),
         ) as _,
-        DataType::Float64 => Arc::new(
+        DataType::Float64 if scale == 0 => Arc::new(
             input
                 .as_primitive::<Float64Type>()
                 .unary::<_, Int64Type>(|x| x.ceil() as i64),
         ) as _,
+        DataType::Float32 => Arc::new(
+            input
+                .as_primitive::<Float32Type>()
+                .unary::<_, Float32Type>(|x| ceil_float(x, scale)),
+        ) as _,
+        DataType::Float64 => Arc::new(
+            input
+                .as_primitive::<Float64Type>()
+                .unary::<_, Float64Type>(|x| ceil_float(x, scale)),
+        ) as _,
         dt if dt.is_integer() => arrow::compute::cast(input, &DataType::Int64)?,
         DataType::Decimal128(p, s) if *s > 0 => {
             let new_p = decimal128_ceil_precision(*p, *s);
@@ -168,6 +294,22 @@ fn spark_ceil_array(input: &Arc<dyn Array>) -> Result<ColumnarValue> {
     Ok(ColumnarValue::Array(result))
 }
 
+fn ceil_float<T: num_traits::Float>(value: T, scale: i32) -> T {
+    if scale >= 0 {
+        let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
+        if factor.is_infinite() {
+            return value;
+        }
+        (value * factor).ceil() / factor
+    } else {
+        let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity);
+        if factor.is_infinite() {
+            return T::zero();
+        }
+        (value / factor).ceil() * factor
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -301,4 +443,95 @@ mod tests {
         };
         assert_eq!(result, ScalarValue::Int64(Some(48)));
     }
+
+    #[test]
+    fn test_ceil_float64_scalar_with_positive_scale() {
+        // ceil(3.1411, 2) → 3.15
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(3.1411))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
+        ];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert_eq!(result, ScalarValue::Float64(Some(3.15)));
+    }
+
+    #[test]
+    fn test_ceil_float64_scalar_with_negative_scale() {
+        // ceil(3345.1, -2) → 3400.0
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(3345.1))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(-2))),
+        ];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert_eq!(result, ScalarValue::Float64(Some(3400.0)));
+    }
+
+    #[test]
+    fn test_ceil_float64_scalar_with_zero_scale() {
+        // ceil(3.5, 0) → 4 as Int64 (same as 1-arg behavior)
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(3.5))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
+        ];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert_eq!(result, ScalarValue::Int64(Some(4)));
+    }
+
+    #[test]
+    fn test_ceil_float32_scalar_with_scale() {
+        // ceil(3.1f32, 1) → 3.1 (already exact)
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Float32(Some(3.1f32))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
+        ];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        // 3.1f32 ceiling at 1 decimal place stays 3.1
+        assert_eq!(result, ScalarValue::Float32(Some(3.1f32)));
+    }
+
+    #[test]
+    fn test_ceil_float64_array_with_scale() {
+        // ceil([3.1411, -1.001, 0.0, NULL], 2) → [3.15, -1.0, 0.0, NULL]
+        let input = Float64Array::from(vec![Some(3.1411), Some(-1.001), Some(0.0), None]);
+        let args = vec![
+            ColumnarValue::Array(Arc::new(input)),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
+        ];
+        let result = spark_ceil(&args).unwrap();
+        let result = match result {
+            ColumnarValue::Array(arr) => arr,
+            _ => panic!("Expected array"),
+        };
+        let result = result.as_primitive::<Float64Type>();
+        assert_eq!(
+            result,
+            &Float64Array::from(vec![Some(3.15), Some(-1.0), Some(0.0), None])
+        );
+    }
+
+    #[test]
+    fn test_ceil_null_scale() {
+        // ceil(3.5, NULL) → NULL
+        let args = vec![
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(3.5))),
+            ColumnarValue::Scalar(ScalarValue::Int32(None)),
+        ];
+        let result = match spark_ceil(&args).unwrap() {
+            ColumnarValue::Scalar(v) => v,
+            _ => panic!("Expected scalar"),
+        };
+        assert!(result.is_null());
+    }
 }
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 3f6a3a686db30..237878508a5d4 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -53,7 +53,7 @@ pub mod expr_fn {
     use datafusion_functions::export_functions;
 
     export_functions!((abs, "Returns abs(expr)", arg1));
-    export_functions!((ceil, "Returns the ceiling of expr.", arg1));
+    export_functions!((ceil, "Returns the ceiling of expr.", arg1 arg2));
     export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
     export_functions!((
        factorial,
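
For reference, the two-argument behavior this patch adds reduces to the usual shift-ceil-unshift identity: for scale s >= 0, ceil(x, s) = ceil(x * 10^s) / 10^s, and for s < 0, ceil(x / 10^-s) * 10^-s. Below is a minimal standalone sketch of that identity, assuming plain f64 and a hypothetical ceil_scaled helper (not part of the patch); the patch's ceil_float generalizes the same logic over Float32/Float64 and additionally guards against overflowing scale factors.

fn ceil_scaled(x: f64, scale: i32) -> f64 {
    // Shift by 10^scale, take the ceiling, then shift back.
    if scale >= 0 {
        let factor = 10.0f64.powi(scale);
        (x * factor).ceil() / factor
    } else {
        let factor = 10.0f64.powi(-scale);
        (x / factor).ceil() * factor
    }
}

fn main() {
    assert_eq!(ceil_scaled(3.1411, 2), 3.15); // mirrors Spark's ceil(3.1411, 2)
    assert_eq!(ceil_scaled(3345.1, -2), 3400.0); // mirrors Spark's ceil(3345.1, -2)
}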