From 68c1c07f65c313036028ef4c4b30837dab77c967 Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Fri, 10 Apr 2026 16:12:17 -0400 Subject: [PATCH 1/3] feat: add cosine_distance scalar function Add cosine_distance (and list_cosine_distance alias) to compute cosine distance between two numeric arrays. Includes shared vector math primitives in vector_math.rs for reuse by follow-on functions. Part of #21536. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../functions-nested/src/cosine_distance.rs | 322 ++++++++++++++++++ datafusion/functions-nested/src/lib.rs | 4 + .../functions-nested/src/vector_math.rs | 68 ++++ .../test_files/cosine_distance.slt | 109 ++++++ .../source/user-guide/sql/scalar_functions.md | 34 ++ 5 files changed, 537 insertions(+) create mode 100644 datafusion/functions-nested/src/cosine_distance.rs create mode 100644 datafusion/functions-nested/src/vector_math.rs create mode 100644 datafusion/sqllogictest/test_files/cosine_distance.slt diff --git a/datafusion/functions-nested/src/cosine_distance.rs b/datafusion/functions-nested/src/cosine_distance.rs new file mode 100644 index 0000000000000..ab616ec7942d2 --- /dev/null +++ b/datafusion/functions-nested/src/cosine_distance.rs @@ -0,0 +1,322 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for cosine_distance function. + +use crate::utils::make_scalar_function; +use crate::vector_math::{convert_to_f64_array, dot_product_f64, magnitude_f64}; +use arrow::array::{ + Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, +}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, +}; +use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_functions::downcast_arg; +use datafusion_macros::user_doc; +use itertools::Itertools; +use std::sync::Arc; + +make_udf_expr_and_func!( + CosineDistance, + cosine_distance, + array1 array2, + "returns the cosine distance between two numeric arrays.", + cosine_distance_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros.", + syntax_example = "cosine_distance(array1, array2)", + sql_example = r#"```sql +> select cosine_distance([1.0, 0.0], [0.0, 1.0]); ++-----------------------------------------------+ +| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) | ++-----------------------------------------------+ +| 1.0 | ++-----------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CosineDistance { + signature: Signature, + aliases: Vec, +} + +impl Default for CosineDistance { + fn default() -> Self { + Self::new() + } +} + +impl CosineDistance { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_cosine_distance".to_string()], + } + } +} + +impl ScalarUDFImpl for CosineDistance { + fn name(&self) -> &str { + "cosine_distance" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) + } + }); + + arg_types.try_collect() + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(cosine_distance_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn cosine_distance_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("cosine_distance", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_cosine_distance::(args), + (LargeList(_), LargeList(_)) => general_cosine_distance::(args), + (arg_type1, arg_type2) => { + exec_err!( + "cosine_distance does not support types {arg_type1} and {arg_type2}" + ) + } + } +} + +fn general_cosine_distance(arrays: &[ArrayRef]) -> Result { + let list_array1 = as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let result = list_array1 + .iter() + .zip(list_array2.iter()) + .map(|(arr1, arr2)| compute_cosine_distance(arr1, arr2)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Computes the cosine distance between two arrays: 1 - dot(a,b) / (||a|| * ||b||) +fn compute_cosine_distance( + arr1: Option, + arr2: Option, +) -> Result> { + let value1 = match arr1 { + Some(arr) => arr, + None => return Ok(None), + }; + let value2 = match arr2 { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut value1 = value1; + let mut value2 = value2; + + loop { + match value1.data_type() { + List(_) => { + if downcast_arg!(value1, ListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value1, LargeListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, LargeListArray).value(0); + } + _ => break, + } + + match value2.data_type() { + List(_) => { + if downcast_arg!(value2, ListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value2, LargeListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, LargeListArray).value(0); + } + _ => break, + } + } + + if value1.null_count() != 0 || value2.null_count() != 0 { + return Ok(None); + } + + let values1 = convert_to_f64_array(&value1)?; + let values2 = convert_to_f64_array(&value2)?; + + if values1.len() != values2.len() { + return exec_err!("Both arrays must have the same length"); + } + + let dot = dot_product_f64(&values1, &values2); + let mag1 = magnitude_f64(&values1); + let mag2 = magnitude_f64(&values2); + + if mag1 == 0.0 || mag2 == 0.0 { + return Ok(None); + } + + Ok(Some(1.0 - dot / (mag1 * mag2))) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, ListArray}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + fn make_f64_list_array(values: Vec>>>) -> ArrayRef { + let mut flat: Vec> = Vec::new(); + let mut offsets: Vec = vec![0]; + for v in &values { + match v { + Some(inner) => { + flat.extend(inner); + offsets.push(flat.len() as i32); + } + None => { + offsets.push(flat.len() as i32); + } + } + } + let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from( + values.iter().map(|v| v.is_some()).collect::>(), + ); + Arc::new(ListArray::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) + } + + #[test] + fn test_cosine_distance_orthogonal() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(1.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!((result.value(0) - 1.0).abs() < 1e-10); + } + + #[test] + fn test_cosine_distance_identical() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!(result.value(0).abs() < 1e-10); + } + + #[test] + fn test_cosine_distance_opposite() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(-1.0), Some(0.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!((result.value(0) - 2.0).abs() < 1e-10); + } + + #[test] + fn test_cosine_distance_null_array() { + let arr1 = make_f64_list_array(vec![None]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!(result.is_null(0)); + } + + #[test] + fn test_cosine_distance_mismatched_lengths() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]); + assert!(result.is_err()); + } + + #[test] + fn test_cosine_distance_zero_magnitude() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(0.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert!(result.is_null(0)); + } +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 99b25ec96454b..716e72790b70d 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -41,6 +41,7 @@ pub mod array_has; pub mod arrays_zip; pub mod cardinality; pub mod concat; +pub mod cosine_distance; pub mod dimension; pub mod distance; pub mod empty; @@ -68,6 +69,7 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; +pub mod vector_math; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -85,6 +87,7 @@ pub mod expr_fn { pub use super::concat::array_append; pub use super::concat::array_concat; pub use super::concat::array_prepend; + pub use super::cosine_distance::cosine_distance; pub use super::dimension::array_dims; pub use super::dimension::array_ndims; pub use super::distance::array_distance; @@ -150,6 +153,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_any_udf(), empty::array_empty_udf(), length::array_length_udf(), + cosine_distance::cosine_distance_udf(), distance::array_distance_udf(), flatten::flatten_udf(), min_max::array_max_udf(), diff --git a/datafusion/functions-nested/src/vector_math.rs b/datafusion/functions-nested/src/vector_math.rs new file mode 100644 index 0000000000000..02b8772cab915 --- /dev/null +++ b/datafusion/functions-nested/src/vector_math.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared vector math primitives used by cosine_distance, inner_product, +//! array_normalize, and related functions. + +use arrow::array::{ArrayRef, Float64Array}; +use datafusion_common::cast::{ + as_float32_array, as_float64_array, as_int32_array, as_int64_array, +}; +use datafusion_common::{Result, exec_err}; + +/// Converts an array of any numeric type to a Float64Array. +pub fn convert_to_f64_array(array: &ArrayRef) -> Result { + match array.data_type() { + arrow::datatypes::DataType::Float64 => Ok(as_float64_array(array)?.clone()), + arrow::datatypes::DataType::Float32 => { + let array = as_float32_array(array)?; + Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) + } + arrow::datatypes::DataType::Int64 => { + let array = as_int64_array(array)?; + Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) + } + arrow::datatypes::DataType::Int32 => { + let array = as_int32_array(array)?; + Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) + } + _ => exec_err!("Unsupported array type for conversion to Float64Array"), + } +} + +/// Computes dot product: sum(a\[i\] * b\[i\]) +pub fn dot_product_f64(a: &Float64Array, b: &Float64Array) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(v1, v2)| v1.unwrap_or(0.0) * v2.unwrap_or(0.0)) + .sum() +} + +/// Computes sum of squares: sum(a\[i\]^2) +pub fn sum_of_squares_f64(a: &Float64Array) -> f64 { + a.iter() + .map(|v| { + let val = v.unwrap_or(0.0); + val * val + }) + .sum() +} + +/// Computes magnitude (L2 norm): sqrt(sum(a\[i\]^2)) +pub fn magnitude_f64(a: &Float64Array) -> f64 { + sum_of_squares_f64(a).sqrt() +} diff --git a/datafusion/sqllogictest/test_files/cosine_distance.slt b/datafusion/sqllogictest/test_files/cosine_distance.slt new file mode 100644 index 0000000000000..4daba225250be --- /dev/null +++ b/datafusion/sqllogictest/test_files/cosine_distance.slt @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +## cosine_distance + +# Orthogonal vectors: distance = 1.0 +query R +select cosine_distance([1.0, 0.0], [0.0, 1.0]); +---- +1 + +# Identical vectors: distance = 0.0 +query R +select cosine_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]); +---- +0 + +# Opposite vectors: distance = 2.0 +query R +select cosine_distance([1.0, 0.0], [-1.0, 0.0]); +---- +2 + +# 45-degree angle: distance ≈ 0.293 +query R +select round(cosine_distance([1.0, 0.0], [1.0, 1.0]), 3); +---- +0.293 + +# NULL input (bare NULL is not a list type, errors at planning) +query error cosine_distance does not support type +select cosine_distance(NULL, [1.0, 2.0]); + +# NULL in second position +query error cosine_distance does not support type +select cosine_distance([1.0, 2.0], NULL); + +# Zero vector returns NULL (undefined cosine similarity) +query R +select cosine_distance([0.0, 0.0], [1.0, 2.0]); +---- +NULL + +# Mismatched lengths error +query error Both arrays must have the same length +select cosine_distance([1.0, 2.0], [1.0]); + +# LargeList support +query R +select cosine_distance( + arrow_cast([1.0, 0.0], 'LargeList(Float64)'), + arrow_cast([0.0, 1.0], 'LargeList(Float64)') +); +---- +1 + +# Integer arrays (coerced to Float64) +query R +select cosine_distance([1, 0], [0, 1]); +---- +1 + +# Multi-row query +query R +select cosine_distance(column1, column2) from (values + (make_array(1.0, 0.0), make_array(0.0, 1.0)), + (make_array(1.0, 1.0), make_array(1.0, 1.0)), + (make_array(1.0, 0.0), make_array(-1.0, 0.0)) +) as t(column1, column2); +---- +1 +0 +2 + +# list_cosine_distance alias +query R +select list_cosine_distance([1.0, 0.0], [0.0, 1.0]); +---- +1 + +# Empty arrays return NULL (magnitude = 0) +query R +select cosine_distance(arrow_cast(make_array(), 'List(Float64)'), arrow_cast(make_array(), 'List(Float64)')); +---- +NULL + +# No arguments error +query error cosine_distance function requires 2 arguments, got 0 +select cosine_distance(); + +# Return type is Float64 +query RT +select cosine_distance([1.0, 0.0], [0.0, 1.0]), arrow_typeof(cosine_distance([1.0, 0.0], [0.0, 1.0])); +---- +1 Float64 diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d1b80f1f90b8b..84455402571ee 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3279,6 +3279,7 @@ _Alias of [current_date](#current_date)._ - [arrays_overlap](#arrays_overlap) - [arrays_zip](#arrays_zip) - [cardinality](#cardinality) +- [cosine_distance](#cosine_distance) - [empty](#empty) - [flatten](#flatten) - [generate_series](#generate_series) @@ -3287,6 +3288,7 @@ _Alias of [current_date](#current_date)._ - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_contains](#list_contains) +- [list_cosine_distance](#list_cosine_distance) - [list_dims](#list_dims) - [list_distance](#list_distance) - [list_distinct](#list_distinct) @@ -4441,6 +4443,34 @@ cardinality(array) +--------------------------------------+ ``` +### `cosine_distance` + +Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros. + +```sql +cosine_distance(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select cosine_distance([1.0, 0.0], [0.0, 1.0]); ++-----------------------------------------------+ +| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) | ++-----------------------------------------------+ +| 1.0 | ++-----------------------------------------------+ +``` + +#### Aliases + +- list_cosine_distance + ### `empty` Returns 1 for an empty array or 0 for a non-empty array. @@ -4543,6 +4573,10 @@ _Alias of [array_concat](#array_concat)._ _Alias of [array_has](#array_has)._ +### `list_cosine_distance` + +_Alias of [cosine_distance](#cosine_distance)._ + ### `list_dims` _Alias of [array_dims](#array_dims)._ From fc3ee90a80c6ebdb93ecda28553e82193a2594be Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Fri, 17 Apr 2026 19:10:43 -0400 Subject: [PATCH 2/3] feat: rework cosine_distance per review feedback Addresses review comments on #21542: - Iterate list offsets/values directly instead of per-row ArrayRef downcast - Remove nested-list unwrap loop (function does not support nested lists) - Drop convert_to_f64_array wrapper (coerce_types already guarantees Float64) - Remove duplicate Rust unit tests now covered by SLT - More descriptive error message for mismatched list lengths - Delete now-unused vector_math module; inline math into sole caller Adds SLT coverage for NULL-element-in-list behavior previously tested only in Rust unit tests. Part of #21536. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../functions-nested/src/cosine_distance.rs | 212 ++++-------------- datafusion/functions-nested/src/lib.rs | 1 - .../functions-nested/src/vector_math.rs | 68 ------ .../test_files/cosine_distance.slt | 8 +- 4 files changed, 53 insertions(+), 236 deletions(-) delete mode 100644 datafusion/functions-nested/src/vector_math.rs diff --git a/datafusion/functions-nested/src/cosine_distance.rs b/datafusion/functions-nested/src/cosine_distance.rs index ab616ec7942d2..5eccaa4f78fcb 100644 --- a/datafusion/functions-nested/src/cosine_distance.rs +++ b/datafusion/functions-nested/src/cosine_distance.rs @@ -18,22 +18,18 @@ //! [`ScalarUDFImpl`] definitions for cosine_distance function. use crate::utils::make_scalar_function; -use crate::vector_math::{convert_to_f64_array, dot_product_f64, magnitude_f64}; -use arrow::array::{ - Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, -}; +use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait}; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, Null}, }; -use datafusion_common::cast::as_generic_list_array; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; use itertools::Itertools; use std::sync::Arc; @@ -149,174 +145,58 @@ fn general_cosine_distance(arrays: &[ArrayRef]) -> Result(&arrays[0])?; let list_array2 = as_generic_list_array::(&arrays[1])?; - let result = list_array1 - .iter() - .zip(list_array2.iter()) - .map(|(arr1, arr2)| compute_cosine_distance(arr1, arr2)) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} - -/// Computes the cosine distance between two arrays: 1 - dot(a,b) / (||a|| * ||b||) -fn compute_cosine_distance( - arr1: Option, - arr2: Option, -) -> Result> { - let value1 = match arr1 { - Some(arr) => arr, - None => return Ok(None), - }; - let value2 = match arr2 { - Some(arr) => arr, - None => return Ok(None), - }; - - let mut value1 = value1; - let mut value2 = value2; + let values1 = as_float64_array(list_array1.values())?; + let values2 = as_float64_array(list_array2.values())?; + let offsets1 = list_array1.value_offsets(); + let offsets2 = list_array2.value_offsets(); - loop { - match value1.data_type() { - List(_) => { - if downcast_arg!(value1, ListArray).null_count() > 0 { - return Ok(None); - } - value1 = downcast_arg!(value1, ListArray).value(0); - } - LargeList(_) => { - if downcast_arg!(value1, LargeListArray).null_count() > 0 { - return Ok(None); - } - value1 = downcast_arg!(value1, LargeListArray).value(0); - } - _ => break, + let mut builder = Float64Array::builder(list_array1.len()); + for row in 0..list_array1.len() { + if list_array1.is_null(row) || list_array2.is_null(row) { + builder.append_null(); + continue; } - match value2.data_type() { - List(_) => { - if downcast_arg!(value2, ListArray).null_count() > 0 { - return Ok(None); - } - value2 = downcast_arg!(value2, ListArray).value(0); - } - LargeList(_) => { - if downcast_arg!(value2, LargeListArray).null_count() > 0 { - return Ok(None); - } - value2 = downcast_arg!(value2, LargeListArray).value(0); - } - _ => break, + let start1 = offsets1[row].as_usize(); + let end1 = offsets1[row + 1].as_usize(); + let start2 = offsets2[row].as_usize(); + let end2 = offsets2[row + 1].as_usize(); + let len1 = end1 - start1; + let len2 = end2 - start2; + + if len1 != len2 { + return exec_err!( + "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}" + ); } - } - - if value1.null_count() != 0 || value2.null_count() != 0 { - return Ok(None); - } - - let values1 = convert_to_f64_array(&value1)?; - let values2 = convert_to_f64_array(&value2)?; - if values1.len() != values2.len() { - return exec_err!("Both arrays must have the same length"); - } - - let dot = dot_product_f64(&values1, &values2); - let mag1 = magnitude_f64(&values1); - let mag2 = magnitude_f64(&values2); - - if mag1 == 0.0 || mag2 == 0.0 { - return Ok(None); - } - - Ok(Some(1.0 - dot / (mag1 * mag2))) -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::{Float64Array, ListArray}; - use arrow::buffer::OffsetBuffer; - use arrow::datatypes::{DataType, Field}; - use std::sync::Arc; - - fn make_f64_list_array(values: Vec>>>) -> ArrayRef { - let mut flat: Vec> = Vec::new(); - let mut offsets: Vec = vec![0]; - for v in &values { - match v { - Some(inner) => { - flat.extend(inner); - offsets.push(flat.len() as i32); - } - None => { - offsets.push(flat.len() as i32); - } - } + let slice1 = values1.slice(start1, len1); + let slice2 = values2.slice(start2, len2); + if slice1.null_count() != 0 || slice2.null_count() != 0 { + builder.append_null(); + continue; } - let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef; - let field = Arc::new(Field::new_list_field(DataType::Float64, true)); - let offset_buffer = OffsetBuffer::new(offsets.into()); - let null_buffer = arrow::buffer::NullBuffer::from( - values.iter().map(|v| v.is_some()).collect::>(), - ); - Arc::new(ListArray::new( - field, - offset_buffer, - values_array, - Some(null_buffer), - )) - } - - #[test] - fn test_cosine_distance_orthogonal() { - let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); - let arr2 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(1.0)])]); - let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); - assert!((result.value(0) - 1.0).abs() < 1e-10); - } - #[test] - fn test_cosine_distance_identical() { - let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); - let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); - let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); - assert!(result.value(0).abs() < 1e-10); - } - - #[test] - fn test_cosine_distance_opposite() { - let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); - let arr2 = make_f64_list_array(vec![Some(vec![Some(-1.0), Some(0.0)])]); - let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); - assert!((result.value(0) - 2.0).abs() < 1e-10); - } - - #[test] - fn test_cosine_distance_null_array() { - let arr1 = make_f64_list_array(vec![None]); - let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); - let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); - assert!(result.is_null(0)); - } + let vals1 = slice1.values(); + let vals2 = slice2.values(); + + let mut dot = 0.0; + let mut sq1 = 0.0; + let mut sq2 = 0.0; + for i in 0..len1 { + let a = vals1[i]; + let b = vals2[i]; + dot += a * b; + sq1 += a * a; + sq2 += b * b; + } - #[test] - fn test_cosine_distance_mismatched_lengths() { - let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); - let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); - let result = cosine_distance_inner(&[arr1, arr2]); - assert!(result.is_err()); + if sq1 == 0.0 || sq2 == 0.0 { + builder.append_null(); + } else { + builder.append_value(1.0 - dot / (sq1.sqrt() * sq2.sqrt())); + } } - #[test] - fn test_cosine_distance_zero_magnitude() { - let arr1 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(0.0)])]); - let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); - let result = cosine_distance_inner(&[arr1, arr2]).unwrap(); - let result = result.as_any().downcast_ref::().unwrap(); - assert!(result.is_null(0)); - } + Ok(Arc::new(builder.finish()) as ArrayRef) } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 716e72790b70d..e2de215750faa 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -69,7 +69,6 @@ pub mod set_ops; pub mod sort; pub mod string; pub mod utils; -pub mod vector_math; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; diff --git a/datafusion/functions-nested/src/vector_math.rs b/datafusion/functions-nested/src/vector_math.rs deleted file mode 100644 index 02b8772cab915..0000000000000 --- a/datafusion/functions-nested/src/vector_math.rs +++ /dev/null @@ -1,68 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Shared vector math primitives used by cosine_distance, inner_product, -//! array_normalize, and related functions. - -use arrow::array::{ArrayRef, Float64Array}; -use datafusion_common::cast::{ - as_float32_array, as_float64_array, as_int32_array, as_int64_array, -}; -use datafusion_common::{Result, exec_err}; - -/// Converts an array of any numeric type to a Float64Array. -pub fn convert_to_f64_array(array: &ArrayRef) -> Result { - match array.data_type() { - arrow::datatypes::DataType::Float64 => Ok(as_float64_array(array)?.clone()), - arrow::datatypes::DataType::Float32 => { - let array = as_float32_array(array)?; - Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) - } - arrow::datatypes::DataType::Int64 => { - let array = as_int64_array(array)?; - Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) - } - arrow::datatypes::DataType::Int32 => { - let array = as_int32_array(array)?; - Ok(array.iter().map(|v| v.map(|v| v as f64)).collect()) - } - _ => exec_err!("Unsupported array type for conversion to Float64Array"), - } -} - -/// Computes dot product: sum(a\[i\] * b\[i\]) -pub fn dot_product_f64(a: &Float64Array, b: &Float64Array) -> f64 { - a.iter() - .zip(b.iter()) - .map(|(v1, v2)| v1.unwrap_or(0.0) * v2.unwrap_or(0.0)) - .sum() -} - -/// Computes sum of squares: sum(a\[i\]^2) -pub fn sum_of_squares_f64(a: &Float64Array) -> f64 { - a.iter() - .map(|v| { - let val = v.unwrap_or(0.0); - val * val - }) - .sum() -} - -/// Computes magnitude (L2 norm): sqrt(sum(a\[i\]^2)) -pub fn magnitude_f64(a: &Float64Array) -> f64 { - sum_of_squares_f64(a).sqrt() -} diff --git a/datafusion/sqllogictest/test_files/cosine_distance.slt b/datafusion/sqllogictest/test_files/cosine_distance.slt index 4daba225250be..ee72c106b3ec0 100644 --- a/datafusion/sqllogictest/test_files/cosine_distance.slt +++ b/datafusion/sqllogictest/test_files/cosine_distance.slt @@ -56,9 +56,15 @@ select cosine_distance([0.0, 0.0], [1.0, 2.0]); NULL # Mismatched lengths error -query error Both arrays must have the same length +query error cosine_distance requires both list inputs to have the same length select cosine_distance([1.0, 2.0], [1.0]); +# NULL element inside a list returns NULL for that row +query R +select cosine_distance([1.0, 2.0, NULL], [1.0, 2.0, 3.0]); +---- +NULL + # LargeList support query R select cosine_distance( From ce312cc00734b520ec689b6965ccf164a92abf60 Mon Sep 17 00:00:00 2001 From: Christian McArthur Date: Mon, 20 Apr 2026 12:27:38 -0400 Subject: [PATCH 3/3] feat: handle mixed-type and NULL inputs for cosine_distance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses round-2 review comments on #21542: - Widen container variant in coerce_types when inputs mix List and LargeList (or FixedSizeList), so mixed-type calls like `cosine_distance([1.0, 0.0], arrow_cast([0.0, 1.0], 'LargeList(Float64)'))` succeed. Follows the pattern from PR #21704 (ArrayConcat). - Coerce bare NULL inputs to a matching list variant so `cosine_distance(NULL, [1.0, 2.0])` returns NULL instead of erroring. - Drop the `list_cosine_distance` alias — the base name is not `array_cosine_distance`, so the `array_X` -> `list_X` convention does not apply. - Expand SLT coverage: mixed-type variants, FixedSizeList inputs, Float32 and Int64 inner types, bare NULL in each position, NULL row in a multi-row VALUES, and an unsupported-type plan error case. Dispatch fallthrough in cosine_distance_inner is now unreachable after the coerce_types widening, changed from exec_err! to internal_err!. Part of #21536. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../functions-nested/src/cosine_distance.rs | 63 +++++++++------ .../test_files/cosine_distance.slt | 78 +++++++++++++++---- .../source/user-guide/sql/scalar_functions.md | 9 --- 3 files changed, 105 insertions(+), 45 deletions(-) diff --git a/datafusion/functions-nested/src/cosine_distance.rs b/datafusion/functions-nested/src/cosine_distance.rs index 5eccaa4f78fcb..335856075046c 100644 --- a/datafusion/functions-nested/src/cosine_distance.rs +++ b/datafusion/functions-nested/src/cosine_distance.rs @@ -22,16 +22,18 @@ use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait}; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, Null}, + Field, }; use datafusion_common::cast::{as_float64_array, as_generic_list_array}; use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; -use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; +use datafusion_common::{ + Result, exec_err, internal_err, plan_err, utils::take_function_args, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use itertools::Itertools; use std::sync::Arc; make_udf_expr_and_func!( @@ -66,7 +68,6 @@ make_udf_expr_and_func!( #[derive(Debug, PartialEq, Eq, Hash)] pub struct CosineDistance { signature: Signature, - aliases: Vec, } impl Default for CosineDistance { @@ -79,7 +80,6 @@ impl CosineDistance { pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), - aliases: vec!["list_cosine_distance".to_string()], } } } @@ -100,29 +100,48 @@ impl ScalarUDFImpl for CosineDistance { fn coerce_types(&self, arg_types: &[DataType]) -> Result> { let [_, _] = take_function_args(self.name(), arg_types)?; let coercion = Some(&ListCoercion::FixedSizedListToList); - let arg_types = arg_types.iter().map(|arg_type| { - if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { - Ok(coerced_type_with_base_type_only( + + for arg_type in arg_types { + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{} does not support type {arg_type}", self.name()); + } + } + + // If any input is `LargeList`, both sides must be widened to `LargeList` + // so the runtime dispatch in `cosine_distance_inner` sees a homogeneous + // pair. Follows the pattern in `ArrayConcat::coerce_types`. + let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_))); + + let coerced = arg_types + .iter() + .map(|arg_type| { + if matches!(arg_type, Null) { + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + return if any_large_list { + LargeList(field) + } else { + List(field) + }; + } + let coerced = coerced_type_with_base_type_only( arg_type, &DataType::Float64, coercion, - )) - } else { - plan_err!("{} does not support type {arg_type}", self.name()) - } - }); - - arg_types.try_collect() + ); + match coerced { + List(field) if any_large_list => LargeList(field), + other => other, + } + }) + .collect(); + + Ok(coerced) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(cosine_distance_inner)(&args.args) } - fn aliases(&self) -> &[String] { - &self.aliases - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -133,11 +152,9 @@ fn cosine_distance_inner(args: &[ArrayRef]) -> Result { match (array1.data_type(), array2.data_type()) { (List(_), List(_)) => general_cosine_distance::(args), (LargeList(_), LargeList(_)) => general_cosine_distance::(args), - (arg_type1, arg_type2) => { - exec_err!( - "cosine_distance does not support types {arg_type1} and {arg_type2}" - ) - } + (arg_type1, arg_type2) => internal_err!( + "cosine_distance received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), } } diff --git a/datafusion/sqllogictest/test_files/cosine_distance.slt b/datafusion/sqllogictest/test_files/cosine_distance.slt index ee72c106b3ec0..9142aac8cf684 100644 --- a/datafusion/sqllogictest/test_files/cosine_distance.slt +++ b/datafusion/sqllogictest/test_files/cosine_distance.slt @@ -41,13 +41,23 @@ select round(cosine_distance([1.0, 0.0], [1.0, 1.0]), 3); ---- 0.293 -# NULL input (bare NULL is not a list type, errors at planning) -query error cosine_distance does not support type +# Bare NULL input returns NULL +query R select cosine_distance(NULL, [1.0, 2.0]); +---- +NULL -# NULL in second position -query error cosine_distance does not support type +# NULL in second position returns NULL +query R select cosine_distance([1.0, 2.0], NULL); +---- +NULL + +# Both NULL returns NULL +query R +select cosine_distance(NULL, NULL); +---- +NULL # Zero vector returns NULL (undefined cosine similarity) query R @@ -74,29 +84,71 @@ select cosine_distance( ---- 1 -# Integer arrays (coerced to Float64) +# Mixed List + LargeList: widens to LargeList +query R +select cosine_distance([1.0, 0.0], arrow_cast([0.0, 1.0], 'LargeList(Float64)')); +---- +1 + +# Reverse order: LargeList + List also widens +query R +select cosine_distance(arrow_cast([1.0, 0.0], 'LargeList(Float64)'), [0.0, 1.0]); +---- +1 + +# FixedSizeList inputs (coerced to List) +query R +select cosine_distance( + arrow_cast([1.0, 0.0], 'FixedSizeList(2, Float64)'), + arrow_cast([0.0, 1.0], 'FixedSizeList(2, Float64)') +); +---- +1 + +# FixedSizeList + LargeList: widens to LargeList +query R +select cosine_distance( + arrow_cast([1.0, 0.0], 'FixedSizeList(2, Float64)'), + arrow_cast([0.0, 1.0], 'LargeList(Float64)') +); +---- +1 + +# Float32 inner type (coerced to Float64) +query R +select cosine_distance(arrow_cast([1.0, 0.0], 'List(Float32)'), [0.0, 1.0]); +---- +1 + +# Int64 inner type (coerced to Float64) +query R +select cosine_distance(arrow_cast([1, 0], 'List(Int64)'), arrow_cast([0, 1], 'List(Int64)')); +---- +1 + +# Integer literals (coerced to Float64) query R select cosine_distance([1, 0], [0, 1]); ---- 1 -# Multi-row query +# Unsupported non-list input (plan error) +query error cosine_distance does not support type +select cosine_distance(1, 2); + +# Multi-row query with NULL row propagation query R select cosine_distance(column1, column2) from (values (make_array(1.0, 0.0), make_array(0.0, 1.0)), (make_array(1.0, 1.0), make_array(1.0, 1.0)), - (make_array(1.0, 0.0), make_array(-1.0, 0.0)) + (make_array(1.0, 0.0), make_array(-1.0, 0.0)), + (make_array(1.0, 0.0), NULL) ) as t(column1, column2); ---- 1 0 2 - -# list_cosine_distance alias -query R -select list_cosine_distance([1.0, 0.0], [0.0, 1.0]); ----- -1 +NULL # Empty arrays return NULL (magnitude = 0) query R diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 84455402571ee..d88d419749e78 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3288,7 +3288,6 @@ _Alias of [current_date](#current_date)._ - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_contains](#list_contains) -- [list_cosine_distance](#list_cosine_distance) - [list_dims](#list_dims) - [list_distance](#list_distance) - [list_distinct](#list_distinct) @@ -4467,10 +4466,6 @@ cosine_distance(array1, array2) +-----------------------------------------------+ ``` -#### Aliases - -- list_cosine_distance - ### `empty` Returns 1 for an empty array or 0 for a non-empty array. @@ -4573,10 +4568,6 @@ _Alias of [array_concat](#array_concat)._ _Alias of [array_has](#array_has)._ -### `list_cosine_distance` - -_Alias of [cosine_distance](#cosine_distance)._ - ### `list_dims` _Alias of [array_dims](#array_dims)._