feat: add cosine_distance scalar function#21542

Open
crm26 wants to merge 3 commits into apache:main from crm26:feat/cosine-distance

@crm26 crm26 commented Apr 10, 2026

Summary

  • Adds cosine_distance(array1, array2) / list_cosine_distance — computes cosine distance (1 - cosine similarity) between two numeric arrays
  • Introduces shared vector_math.rs primitives (dot_product_f64, magnitude_f64, convert_to_f64_array) for reuse by follow-on vector functions
  • Returns NULL for zero-magnitude vectors; errors on mismatched lengths
  • Supports List, LargeList, and FixedSizeList with any numeric element type
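
For readers unfamiliar with the metric, here is a minimal sketch of the semantics listed above, written over plain `f64` slices rather than Arrow arrays. The function name, the `Result<Option<f64>, String>` signature, and the error text are illustrative assumptions and are not taken from the PR; `Ok(None)` stands in for SQL NULL.

```rust
/// Cosine distance = 1 - (a . b) / (|a| * |b|).
/// Hypothetical helper for illustration only; not the PR's implementation.
/// `Ok(None)` models SQL NULL for zero-magnitude input,
/// `Err` models the mismatched-length error.
fn cosine_distance(a: &[f64], b: &[f64]) -> Result<Option<f64>, String> {
    if a.len() != b.len() {
        return Err(format!(
            "length mismatch: got {} and {}",
            a.len(),
            b.len()
        ));
    }

    // Accumulate the dot product and both squared magnitudes in one pass.
    let mut dot = 0.0_f64;
    let mut sq_a = 0.0_f64;
    let mut sq_b = 0.0_f64;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    // A zero-magnitude vector has no direction, so the distance is undefined.
    // Assumption: empty slices also land here in this sketch.
    if sq_a == 0.0 || sq_b == 0.0 {
        return Ok(None);
    }

    Ok(Some(1.0 - dot / (sq_a.sqrt() * sq_b.sqrt())))
}
```

Under this definition, orthogonal vectors have distance 1, opposite vectors have distance 2, and identical vectors have distance 0 up to floating-point rounding.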

Part of #21536 — first in a series of split PRs (replacing #21371).

Test plan

  • Unit tests: identical, orthogonal, opposite, 45-degree, zero-magnitude, mismatched-length, NULL, multi-row
  • sqllogictest: cosine_distance.slt covering all edge cases including empty arrays, LargeList, integer coercion, alias, return type
  • Full slt suite (426/426 pass)
  • cargo clippy, cargo fmt, taplo, prettier, cargo machete — all clean

🤖 Generated with Claude Code

Add cosine_distance (and list_cosine_distance alias) to compute cosine
distance between two numeric arrays. Includes shared vector math
primitives in vector_math.rs for reuse by follow-on functions.

Part of apache#21536.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions github-actions bot added the documentation, sqllogictest, and functions labels Apr 10, 2026
Comment thread datafusion/functions-nested/src/cosine_distance.rs Outdated
Comment thread datafusion/functions-nested/src/cosine_distance.rs Outdated
Comment thread datafusion/functions-nested/src/cosine_distance.rs Outdated
Comment thread datafusion/functions-nested/src/cosine_distance.rs Outdated
Comment thread datafusion/functions-nested/src/cosine_distance.rs Outdated
Comment thread datafusion/functions-nested/src/lib.rs Outdated
Comment thread datafusion/functions-nested/src/vector_math.rs Outdated
Addresses review comments on apache#21542:
- Iterate list offsets/values directly instead of per-row ArrayRef downcast
- Remove nested-list unwrap loop (function does not support nested lists)
- Drop convert_to_f64_array wrapper (coerce_types already guarantees Float64)
- Remove duplicate Rust unit tests now covered by SLT
- More descriptive error message for mismatched list lengths
- Delete now-unused vector_math module; inline math into sole caller

Adds SLT coverage for NULL-element-in-list behavior previously tested
only in Rust unit tests.

Part of apache#21536.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Contributor

@Jefffrey Jefffrey left a comment

Looking good

Comment thread datafusion/functions-nested/src/cosine_distance.rs
Comment thread datafusion/functions-nested/src/cosine_distance.rs Outdated
Comment thread datafusion/sqllogictest/test_files/cosine_distance.slt Outdated
Comment thread datafusion/sqllogictest/test_files/cosine_distance.slt Outdated
Author

crm26 commented Apr 18, 2026

Thanks for the detailed review, @Jefffrey. Rework pushed in fc3ee90. Walking through each comment:

1. Iterate via offsets/values, not per-row ArrayRef downcast (cosine_distance.rs:157)
Rewrote general_cosine_distance to downcast list_array.values() once to &Float64Array, then slice by value_offsets() per row. The inner for i in 0..len loop reads directly from the contiguous ScalarBuffer<f64> — no per-row downcast, no Option<f64> unwrapping.

2. Nested-list unwrap loop (cosine_distance.rs:178)
Removed entirely. The function is not intended to support nested lists, and coerce_types rejects anything other than List/LargeList/FixedSizeList of a numeric inner type.

3. Redundant null/Float64 check (cosine_distance.rs:217)
Agreed — coerce_types calls coerced_type_with_base_type_only with Float64 as the base type, which guarantees the inner type is Float64 by the time we hit invoke_with_args. Dropped convert_to_f64_array entirely and replaced with a direct as_float64_array downcast.

4. Ambiguous length-mismatch error wording (cosine_distance.rs:220)
Updated to "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}". Now explicit that the lengths are the list elements' lengths, not the outer array, and includes the observed values.

5. Duplicate Rust unit tests (cosine_distance.rs:235)
Removed the mod tests block. SLT coverage includes orthogonal, identical, opposite, 45-degree, zero-magnitude, mismatched lengths, LargeList, integer coercion, multi-row, alias, empty arrays, no-args, and return-type checks. Added one new SLT case for NULL-element-in-list to preserve that particular behavior the Rust tests were covering.

6. pub mod vector_math (lib.rs:72)
Moot after item 7 below: vector_math.rs is deleted entirely, so the declaration is gone.

7. Inline the math instead of a separate module (vector_math.rs:67)
Agreed — with only one caller, the indirection wasn't paying for itself. Deleted vector_math.rs and inlined dot/magnitude into the tight per-row loop in general_cosine_distance.
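
The offsets/values iteration from item 1, with the math inlined as in item 7, can be sketched roughly as follows. This is a std-only model, not the PR's code: `ListColumn` is a hypothetical stand-in for an Arrow list array (one contiguous values buffer plus per-row offsets), and element-level NULLs are not modeled.

```rust
/// Hypothetical stand-in for an Arrow list array of Float64.
/// Row `i` occupies `values[offsets[i]..offsets[i + 1]]`.
struct ListColumn<'a> {
    offsets: &'a [i32],
    values: &'a [f64],
    /// `true` marks a NULL row (assumption: simplified validity bitmap).
    nulls: &'a [bool],
}

fn cosine_distance_rows(
    lhs: &ListColumn<'_>,
    rhs: &ListColumn<'_>,
) -> Result<Vec<Option<f64>>, String> {
    let rows = lhs.offsets.len().saturating_sub(1);
    if rhs.offsets.len().saturating_sub(1) != rows {
        return Err("inputs must have the same number of rows".to_string());
    }

    let mut out = Vec::with_capacity(rows);
    for row in 0..rows {
        if lhs.nulls[row] || rhs.nulls[row] {
            out.push(None);
            continue;
        }

        // Slice the shared values buffer by offsets; no per-row allocation.
        let a = &lhs.values[lhs.offsets[row] as usize..lhs.offsets[row + 1] as usize];
        let b = &rhs.values[rhs.offsets[row] as usize..rhs.offsets[row + 1] as usize];
        if a.len() != b.len() {
            return Err(format!(
                "cosine_distance requires both list inputs to have the same length, got {} and {}",
                a.len(),
                b.len()
            ));
        }

        // Dot product and both squared magnitudes in a single tight loop.
        let (mut dot, mut sq_a, mut sq_b) = (0.0_f64, 0.0_f64, 0.0_f64);
        for i in 0..a.len() {
            dot += a[i] * b[i];
            sq_a += a[i] * a[i];
            sq_b += b[i] * b[i];
        }

        out.push(if sq_a == 0.0 || sq_b == 0.0 {
            None // zero-magnitude vector: result is NULL
        } else {
            Some(1.0 - dot / (sq_a.sqrt() * sq_b.sqrt()))
        });
    }
    Ok(out)
}
```

The point of the layout is that the values buffer is downcast and borrowed once, outside the row loop, so the inner loop only indexes into plain `f64` slices.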

Full validation matrix (fmt --all, workspace clippy -D warnings, full + sqlite-extended SLT, CLI, doctests, feature-flag spot-checks, extended_tests workspace build, rustdoc, license, typos, machete, generated-doc regen) passed locally before push. Let me know if anything else needs tightening.

Addresses round-2 review comments on apache#21542:
- Widen container variant in coerce_types when inputs mix List and
  LargeList (or FixedSizeList), so mixed-type calls like
  `cosine_distance([1.0, 0.0], arrow_cast([0.0, 1.0], 'LargeList(Float64)'))`
  succeed. Follows the pattern from PR apache#21704 (ArrayConcat).
- Coerce bare NULL inputs to a matching list variant so
  `cosine_distance(NULL, [1.0, 2.0])` returns NULL instead of erroring.
- Drop the `list_cosine_distance` alias — the base name is not
  `array_cosine_distance`, so the `array_X` -> `list_X` convention does
  not apply.
- Expand SLT coverage: mixed-type variants, FixedSizeList inputs,
  Float32 and Int64 inner types, bare NULL in each position, NULL row
  in a multi-row VALUES, and an unsupported-type plan error case.

Dispatch fallthrough in cosine_distance_inner is now unreachable after
the coerce_types widening, changed from exec_err! to internal_err!.

Part of apache#21536.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author

crm26 commented Apr 20, 2026

Thanks @Jefffrey. Round-2 review fixes pushed in ce312cc:

1. Mixed-type inputs (cosine_distance.rs:100)
Adopted the pattern from #21704: coerce_types now inspects the inputs, and when any is LargeList, widens both to LargeList. FixedSizeList is already normalized to List via FixedSizedListToList and then widened if needed. Our return type is scalar Float64 so we inspect the inputs (not return_type) to detect promotion. After widening, the dispatch in cosine_distance_inner only ever sees homogeneous pairs — the fallthrough arm is unreachable, so I switched it from exec_err! to internal_err!.

2. list_cosine_distance alias (cosine_distance.rs:82)
Dropped. The base name isn't array_cosine_distance, so the array_X -> list_X swap convention doesn't apply. Removed the alias field, initialization, and aliases() method. scalar_functions.md regenerated via ./dev/update_function_docs.sh shows the clean 9-line removal.

3. Bare NULL input (cosine_distance.slt:46)
Handled in coerce_types — when an input is Null, it's coerced to a matching list variant of Float64. At runtime, the Arrow cast produces an all-null list array; list_array.is_null(row) is true; the builder appends null. select cosine_distance(NULL, [1.0, 2.0]) now returns NULL instead of erroring.

4. Multi-row NULL coverage (cosine_distance.slt:88)
Applied your suggestion block verbatim — added (make_array(1.0, 0.0), NULL) to the multi-row VALUES with expected NULL at the bottom.

Additional SLT coverage added proactively:

  • Mixed (List, LargeList) in both orders
  • (FixedSizeList, FixedSizeList) and (FixedSizeList, LargeList) mixed
  • Float32 and explicit Int64 inner types (coerced to Float64)
  • Bare NULL in each position and both positions
  • Unsupported non-list input (cosine_distance(1, 2)) plan-error case
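
The container-widening rule from item 1 and the bare-NULL coercion from item 3 can be modeled with a small decision function. This is only a sketch of the rule as described in this comment: `InputType`, `Container`, and `coerce_pair` are hypothetical names, the inner element type (always Float64 after coercion) is left out, and mapping two bare NULLs to `List` is an assumption.

```rust
/// Simplified view of an argument's type before coercion (hypothetical).
#[derive(Debug, Clone, Copy, PartialEq)]
enum InputType {
    Null,
    List,
    LargeList,
    FixedSizeList,
    Other, // any non-list type, e.g. a bare integer
}

/// Container variant both arguments share after coercion (hypothetical).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Container {
    List,
    LargeList,
}

/// Pick one container for both arguments:
/// - any LargeList input widens both sides to LargeList,
/// - FixedSizeList is treated as List,
/// - a bare NULL adopts whatever the other side resolves to,
/// - anything that is not a list is rejected at plan time.
fn coerce_pair(a: InputType, b: InputType) -> Result<[Container; 2], String> {
    let mut any_large = false;
    for t in [a, b] {
        match t {
            InputType::LargeList => any_large = true,
            InputType::List | InputType::FixedSizeList | InputType::Null => {}
            InputType::Other => {
                return Err("cosine_distance expects list inputs".to_string());
            }
        }
    }

    let target = if any_large {
        Container::LargeList
    } else {
        Container::List
    };
    Ok([target; 2])
}
```

Because both arguments always come out with the same container, a dispatch that matches on the pair only needs the two homogeneous arms, which is why a fallthrough arm can be treated as an internal error.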

Full validation run clean (fmt, clippy, full + sqlite-extended SLT, CLI, doctests, feature-flag checks, extended_tests workspace build, rustdoc, license, typos, machete, generated-doc regen). Ready for another look.

Labels

documentation, functions, sqllogictest

2 participants