261 changes: 247 additions & 14 deletions datafusion/spark/src/function/math/ceil.rs
@@ -19,10 +19,13 @@ use std::sync::Arc;

use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array};
use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type};
use datafusion_common::utils::take_function_args;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int32,
};
use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};

/// Spark-compatible `ceil` expression
@@ -36,8 +39,6 @@ use datafusion_expr::{
/// - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256
/// - Spark does not check for decimal overflow; DataFusion errors on overflow
///
/// 2-argument ceil(value, scale) is not yet implemented
/// <https://github.com/apache/datafusion/issues/21560>
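///
/// Illustrative expected results (a sketch mirroring the tests below, assuming
/// Spark semantics):
///
/// - `ceil(3.5)` or `ceil(3.5, 0)` -> `4` as Int64
/// - `ceil(3.1411, 2)` -> `3.15` as Float64 (a non-zero scale preserves the float type)
/// - `ceil(3345.1, -2)` -> `3400.0` as Float64
/// - `ceil(expr, NULL)` -> `NULL`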
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkCeil {
signature: Signature,
@@ -52,8 +53,50 @@ impl Default for SparkCeil {

impl SparkCeil {
pub fn new() -> Self {
let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
let integer = Coercion::new_exact(TypeSignatureClass::Integer);
let decimal_places = Coercion::new_implicit(
TypeSignatureClass::Native(logical_int32()),
vec![TypeSignatureClass::Integer],
NativeType::Int32,
);
let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
let float64 = Coercion::new_implicit(
TypeSignatureClass::Native(logical_float64()),
vec![TypeSignatureClass::Numeric],
NativeType::Float64,
);
Self {
signature: Signature::numeric(1, Volatility::Immutable),
signature: Signature::one_of(
vec![
// ceil(decimal, scale)
TypeSignature::Coercible(vec![
decimal.clone(),
decimal_places.clone(),
]),
// ceil(decimal)
TypeSignature::Coercible(vec![decimal]),
// ceil(integer, scale)
TypeSignature::Coercible(vec![
integer.clone(),
decimal_places.clone(),
]),
// ceil(integer)
TypeSignature::Coercible(vec![integer]),
// ceil(float32, scale)
TypeSignature::Coercible(vec![
float32.clone(),
decimal_places.clone(),
]),
// ceil(float32)
TypeSignature::Coercible(vec![float32]),
// ceil(float64, scale)
TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
// ceil(float64)
TypeSignature::Coercible(vec![float64]),
],
Volatility::Immutable,
),
aliases: vec!["ceiling".to_string()],
}
}
@@ -69,7 +112,10 @@ impl ScalarUDFImpl for SparkCeil {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
let has_scale = arg_types.len() == 2;

match &arg_types[0] {
DataType::Decimal128(p, s) if has_scale => Ok(DataType::Decimal128(*p, *s)),
DataType::Decimal128(p, s) => {
if *s > 0 {
Ok(DataType::Decimal128(decimal128_ceil_precision(*p, *s), 0))
@@ -79,6 +125,8 @@ impl ScalarUDFImpl for SparkCeil {
Ok(DataType::Decimal128(*p, *s))
}
}
DataType::Float32 if has_scale => Ok(DataType::Float32),
DataType::Float64 if has_scale => Ok(DataType::Float64),
dt if matches!(dt, DataType::Float32 | DataType::Float64)
|| dt.is_integer() =>
{
@@ -97,12 +145,67 @@ impl ScalarUDFImpl for SparkCeil {
}
}

/// Extract the scale (decimal places) from the second argument.
/// Returns `Some(0)` if no second argument is provided.
/// Returns `None` if the scale argument is NULL (Spark returns NULL for `ceil(expr, NULL)`).
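/// For example, per the match arms below: `ceil(col, 2)` yields `Some(2)`,
/// `ceil(col)` yields `Some(0)`, and `ceil(col, NULL)` yields `None`.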
fn get_scale(args: &[ColumnarValue]) -> Result<Option<i32>> {
if args.len() < 2 {
return Ok(Some(0));
}

match &args[1] {
ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))),
ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)),
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
i32::try_from(*v).map(Some).map_err(|_| {
(exec_err!("round scale {v} is out of supported i32 range")
as Result<(), _>)
.unwrap_err()
})
}
ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))),
ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => {
i32::try_from(*v).map(Some).map_err(|_| {
(exec_err!("round scale {v} is out of supported i32 range")
as Result<(), _>)
.unwrap_err()
})
}
ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => {
i32::try_from(*v).map(Some).map_err(|_| {
(exec_err!("round scale {v} is out of supported i32 range")
as Result<(), _>)
.unwrap_err()
})
}
ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None),
other => exec_err!("Unsupported type for ceil scale: {}", other.data_type()),
}
}

fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let [input] = take_function_args("ceil", args)?;
if args.is_empty() || args.len() > 2 {
return exec_err!(
"ceil function requires 1 or 2 arguments, got {}",
args.len()
);
}

let scale = match get_scale(args)? {
Some(scale) => scale,
None => {
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
args[0].data_type(),
)?));
}
};
let input = &args[0];

match input {
ColumnarValue::Scalar(value) => spark_ceil_scalar(value),
ColumnarValue::Array(input) => spark_ceil_array(input),
ColumnarValue::Scalar(value) => spark_ceil_scalar(value, scale),
ColumnarValue::Array(input) => spark_ceil_array(input, scale),
}
}

@@ -121,11 +224,21 @@ fn decimal128_ceil_precision(precision: u8, scale: i8) -> u8 {
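// A Decimal128(p, s) value has p - s integer digits; taking the ceiling drops
// the fraction but can carry into one extra digit (e.g. 9.9 -> 10), hence
// p - s + 1, clamped to Decimal128's supported precision range of 1..=38.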
((precision as i64) - (scale as i64) + 1).clamp(1, 38) as u8
}

fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
fn spark_ceil_scalar(value: &ScalarValue, scale: i32) -> Result<ColumnarValue> {
let result = match value {
ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)),
// Floats without scale (scale=0) -> Int64 (original behaviour)
ScalarValue::Float32(v) if scale == 0 => {
ScalarValue::Int64(v.map(|x| x.ceil() as i64))
}
ScalarValue::Float64(v) if scale == 0 => {
ScalarValue::Int64(v.map(|x| x.ceil() as i64))
}
// Floats with scale -> preserve type
ScalarValue::Float32(v) => ScalarValue::Float32(v.map(|x| ceil_float(x, scale))),
ScalarValue::Float64(v) => ScalarValue::Float64(v.map(|x| ceil_float(x, scale))),
// Integers: cast to Int64; the scale argument is a no-op for integer inputs here
v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?,
// Decimal128 with positive scale
ScalarValue::Decimal128(v, p, s) if *s > 0 => {
let new_p = decimal128_ceil_precision(*p, *s);
ScalarValue::Decimal128(v.map(|x| decimal128_ceil(x, *s as u32)), new_p, 0)
@@ -141,18 +254,31 @@ fn spark_ceil_scalar(value: &ScalarValue) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(result))
}

fn spark_ceil_array(input: &Arc<dyn arrow::array::Array>) -> Result<ColumnarValue> {
fn spark_ceil_array(
input: &Arc<dyn arrow::array::Array>,
scale: i32,
) -> Result<ColumnarValue> {
let result = match input.data_type() {
DataType::Float32 => Arc::new(
DataType::Float32 if scale == 0 => Arc::new(
input
.as_primitive::<Float32Type>()
.unary::<_, Int64Type>(|x| x.ceil() as i64),
) as _,
DataType::Float64 => Arc::new(
DataType::Float64 if scale == 0 => Arc::new(
input
.as_primitive::<Float64Type>()
.unary::<_, Int64Type>(|x| x.ceil() as i64),
) as _,
DataType::Float32 => Arc::new(
input
.as_primitive::<Float32Type>()
.unary::<_, Float32Type>(|x| ceil_float(x, scale)),
) as _,
DataType::Float64 => Arc::new(
input
.as_primitive::<Float64Type>()
.unary::<_, Float64Type>(|x| ceil_float(x, scale)),
) as _,
dt if dt.is_integer() => arrow::compute::cast(input, &DataType::Int64)?,
DataType::Decimal128(p, s) if *s > 0 => {
let new_p = decimal128_ceil_precision(*p, *s);
@@ -168,6 +294,22 @@ fn spark_ceil_array(input: &Arc<dyn arrow::array::Array>) -> Result<ColumnarValue>
Ok(ColumnarValue::Array(result))
}

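/// Ceiling of `value` at `scale` decimal places, e.g. `ceil_float(3.1411, 2) == 3.15`:
/// `ceil(value * 10^scale) / 10^scale` for non-negative scales, and
/// `ceil(value / 10^|scale|) * 10^|scale|` for negative ones. If `10^|scale|` is not
/// representable in `T` the factor becomes infinite, so an extreme positive scale
/// returns the value unchanged and an extreme negative scale saturates to zero.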
fn ceil_float<T: num_traits::Float>(value: T, scale: i32) -> T {
if scale >= 0 {
let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
if factor.is_infinite() {
return value;
}
(value * factor).ceil() / factor
} else {
let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity);
if factor.is_infinite() {
return T::zero();
}
(value / factor).ceil() * factor
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -301,4 +443,95 @@ mod tests {
};
assert_eq!(result, ScalarValue::Int64(Some(48)));
}

#[test]
fn test_ceil_float64_scalar_with_positive_scale() {
// ceil(3.1411, 2) → 3.15
let args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(3.1411))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];
let result = match spark_ceil(&args).unwrap() {
ColumnarValue::Scalar(v) => v,
_ => panic!("Expected scalar"),
};
assert_eq!(result, ScalarValue::Float64(Some(3.15)));
}

#[test]
fn test_ceil_float64_scalar_with_negative_scale() {
// ceil(3345.1, -2) → 3400.0
let args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(3345.1))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(-2))),
];
let result = match spark_ceil(&args).unwrap() {
ColumnarValue::Scalar(v) => v,
_ => panic!("Expected scalar"),
};
assert_eq!(result, ScalarValue::Float64(Some(3400.0)));
}

#[test]
fn test_ceil_float64_scalar_with_zero_scale() {
// ceil(3.5, 0) → 4 as Int64 (same as 1-arg behavior)
let args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(3.5))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
];
let result = match spark_ceil(&args).unwrap() {
ColumnarValue::Scalar(v) => v,
_ => panic!("Expected scalar"),
};
assert_eq!(result, ScalarValue::Int64(Some(4)));
}

#[test]
fn test_ceil_float32_scalar_with_scale() {
// ceil(3.1f32, 1) → 3.1 (already exact)
let args = vec![
ColumnarValue::Scalar(ScalarValue::Float32(Some(3.1f32))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
];
let result = match spark_ceil(&args).unwrap() {
ColumnarValue::Scalar(v) => v,
_ => panic!("Expected scalar"),
};
// 3.1f32 ceiling at 1 decimal place stays 3.1
assert_eq!(result, ScalarValue::Float32(Some(3.1f32)));
}

#[test]
fn test_ceil_float64_array_with_scale() {
// ceil([3.1411, -1.001, 0.0, NULL], 2) → [3.15, -1.0, 0.0, NULL]
let input = Float64Array::from(vec![Some(3.1411), Some(-1.001), Some(0.0), None]);
let args = vec![
ColumnarValue::Array(Arc::new(input)),
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
];
let result = spark_ceil(&args).unwrap();
let result = match result {
ColumnarValue::Array(arr) => arr,
_ => panic!("Expected array"),
};
let result = result.as_primitive::<Float64Type>();
assert_eq!(
result,
&Float64Array::from(vec![Some(3.15), Some(-1.0), Some(0.0), None])
);
}

#[test]
fn test_ceil_null_scale() {
// ceil(3.5, NULL) → NULL
let args = vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(3.5))),
ColumnarValue::Scalar(ScalarValue::Int32(None)),
];
let result = match spark_ceil(&args).unwrap() {
ColumnarValue::Scalar(v) => v,
_ => panic!("Expected scalar"),
};
assert!(result.is_null());
}
}
2 changes: 1 addition & 1 deletion datafusion/spark/src/function/math/mod.rs
@@ -53,7 +53,7 @@ pub mod expr_fn {
use datafusion_functions::export_functions;

export_functions!((abs, "Returns abs(expr)", arg1));
export_functions!((ceil, "Returns the ceiling of expr.", arg1));
export_functions!((ceil, "Returns the ceiling of expr.", arg1 arg2));
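// Hypothetical usage of the generated two-argument builder (a sketch, assuming
// export_functions! maps each listed arg name to an `Expr` parameter):
//   ceil(col("x"), lit(2))  // SQL: ceil(x, 2)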
export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1));
export_functions!((
factorial,