diff --git a/be/src/exprs/function/function_hamming_distance.cpp b/be/src/exprs/function/function_hamming_distance.cpp new file mode 100644 index 00000000000000..cd3c942607c183 --- /dev/null +++ b/be/src/exprs/function/function_hamming_distance.cpp @@ -0,0 +1,278 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "common/status.h" +#include "core/column/column_nullable.h" +#include "core/column/column_string.h" +#include "core/data_type/data_type_number.h" +#include "core/string_ref.h" +#include "exprs/function/simple_function_factory.h" +#include "util/simd/vstring_function.h" + +namespace doris { +#include "common/compile_check_begin.h" + +class FunctionHammingDistance : public IFunction { +public: + using ResultDataType = DataTypeInt64; + using ResultPaddedPODArray = PaddedPODArray; + using ResultColumnType = ColumnVector; + + static constexpr auto name = "hamming_distance"; + + static FunctionPtr create() { return std::make_shared(); } + + String get_name() const override { return name; } + size_t get_number_of_arguments() const override { return 2; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + const bool has_nullable = std::ranges::any_of( + arguments, [](const DataTypePtr& type) { return type->is_nullable(); }); + if (has_nullable) { + return make_nullable(std::make_shared()); + } + return std::make_shared(); + } + + bool use_default_implementation_for_nulls() const override { return false; } + + Status execute_impl(FunctionContext* /*context*/, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + const auto& [left_col, left_const] = + unpack_if_const(block.get_by_position(arguments[0]).column); + const auto& [right_col, right_const] = + unpack_if_const(block.get_by_position(arguments[1]).column); + + const auto* left_nullable = check_and_get_column(left_col.get()); + const auto* right_nullable = check_and_get_column(right_col.get()); + + const IColumn* left_nested = + left_nullable ? &left_nullable->get_nested_column() : left_col.get(); + const IColumn* right_nested = + right_nullable ? &right_nullable->get_nested_column() : right_col.get(); + + const auto* left_str_col = check_and_get_column(left_nested); + const auto* right_str_col = check_and_get_column(right_nested); + if (!left_str_col || !right_str_col) { + return Status::NotSupported("Illegal columns {}, {} of argument of function {}", + left_col->get_name(), right_col->get_name(), get_name()); + } + + auto res_col = ResultColumnType::create(input_rows_count); + auto& res_data = res_col->get_data(); + + const NullMap* left_null_map = + left_nullable ? &left_nullable->get_null_map_data() : nullptr; + const NullMap* right_null_map = + right_nullable ? &right_nullable->get_null_map_data() : nullptr; + const bool has_nullable = left_null_map != nullptr || right_null_map != nullptr; + + if (!has_nullable) { + if (left_const) { + RETURN_IF_ERROR(scalar_vector(left_str_col->get_data_at(0).trim_tail_padding_zero(), + *right_str_col, res_data)); + } else if (right_const) { + RETURN_IF_ERROR(vector_scalar( + *left_str_col, right_str_col->get_data_at(0).trim_tail_padding_zero(), + res_data)); + } else { + RETURN_IF_ERROR(vector_vector(*left_str_col, *right_str_col, res_data)); + } + block.replace_by_position(result, std::move(res_col)); + return Status::OK(); + } + + auto null_col = ColumnUInt8::create(input_rows_count, 0); + auto& null_map = null_col->get_data(); + for (size_t i = 0; i < input_rows_count; ++i) { + const size_t left_idx = left_const ? 0 : i; + const size_t right_idx = right_const ? 0 : i; + + const bool left_is_null = left_null_map && (*left_null_map)[left_idx]; + const bool right_is_null = right_null_map && (*right_null_map)[right_idx]; + if (left_is_null || right_is_null) { + null_map[i] = 1; + res_data[i] = 0; + continue; + } + + RETURN_IF_ERROR( + hamming_distance(left_str_col->get_data_at(left_idx).trim_tail_padding_zero(), + right_str_col->get_data_at(right_idx).trim_tail_padding_zero(), + res_data[i], i)); + } + + block.replace_by_position(result, + ColumnNullable::create(std::move(res_col), std::move(null_col))); + return Status::OK(); + } + +private: + static Status vector_vector(const ColumnString& lcol, const ColumnString& rcol, + ResultPaddedPODArray& res) { + DCHECK_EQ(lcol.size(), rcol.size()); + + const size_t size = lcol.size(); + res.resize(size); + std::vector left_offsets; + std::vector right_offsets; + for (size_t i = 0; i < size; ++i) { + RETURN_IF_ERROR(hamming_distance(lcol.get_data_at(i).trim_tail_padding_zero(), + rcol.get_data_at(i).trim_tail_padding_zero(), + left_offsets, right_offsets, res[i], i)); + } + return Status::OK(); + } + + static Status vector_scalar(const ColumnString& lcol, const StringRef& rdata, + ResultPaddedPODArray& res) { + const size_t size = lcol.size(); + res.resize(size); + const bool right_ascii = simd::VStringFunctions::is_ascii(rdata); + std::vector right_offsets; + utf8_char_offsets(rdata, right_offsets); + std::vector left_offsets; + for (size_t i = 0; i < size; ++i) { + const auto left = lcol.get_data_at(i).trim_tail_padding_zero(); + RETURN_IF_ERROR(hamming_distance_with_offsets( + left, left_offsets, false, simd::VStringFunctions::is_ascii(left), rdata, + right_offsets, true, right_ascii, res[i], i)); + } + return Status::OK(); + } + + static Status scalar_vector(const StringRef& ldata, const ColumnString& rcol, + ResultPaddedPODArray& res) { + const size_t size = rcol.size(); + res.resize(size); + const bool left_ascii = simd::VStringFunctions::is_ascii(ldata); + std::vector left_offsets; + utf8_char_offsets(ldata, left_offsets); + std::vector right_offsets; + for (size_t i = 0; i < size; ++i) { + const auto right = rcol.get_data_at(i).trim_tail_padding_zero(); + RETURN_IF_ERROR(hamming_distance_with_offsets( + ldata, left_offsets, true, left_ascii, right, right_offsets, false, + simd::VStringFunctions::is_ascii(right), res[i], i)); + } + return Status::OK(); + } + + static void utf8_char_offsets(const StringRef& ref, std::vector& offsets) { + offsets.clear(); + offsets.reserve(ref.size); + simd::VStringFunctions::get_char_len(ref.data, ref.size, offsets); + } + + static bool utf8_char_equal(const StringRef& left, size_t left_off, size_t left_next, + const StringRef& right, size_t right_off, size_t right_next) { + const size_t left_len = left_next - left_off; + const size_t right_len = right_next - right_off; + return left_len == right_len && + std::memcmp(left.data + left_off, right.data + right_off, left_len) == 0; + } + + static Status hamming_distance_ascii(const StringRef& left, const StringRef& right, + Int64& result, size_t row) { + if (left.size != right.size) { + return Status::InvalidArgument( + "hamming_distance requires strings of the same length at row {}", row); + } + + Int64 distance = 0; + for (size_t i = 0; i < left.size; ++i) { + distance += static_cast(left.data[i] != right.data[i]); + } + result = distance; + return Status::OK(); + } + + static Status hamming_distance_utf8(const StringRef& left, + const std::vector& left_offsets, + const StringRef& right, + const std::vector& right_offsets, Int64& result, + size_t row) { + if (left_offsets.size() != right_offsets.size()) { + return Status::InvalidArgument( + "hamming_distance requires strings of the same length at row {}", row); + } + + Int64 distance = 0; + const size_t len = left_offsets.size(); + for (size_t i = 0; i + 1 < len; ++i) { + const size_t left_off = left_offsets[i]; + const size_t left_next = left_offsets[i + 1]; + const size_t right_off = right_offsets[i]; + const size_t right_next = right_offsets[i + 1]; + distance += static_cast( + !utf8_char_equal(left, left_off, left_next, right, right_off, right_next)); + } + if (len > 0) { + const size_t left_off = left_offsets[len - 1]; + const size_t right_off = right_offsets[len - 1]; + distance += static_cast( + !utf8_char_equal(left, left_off, left.size, right, right_off, right.size)); + } + + result = distance; + return Status::OK(); + } + + static Status hamming_distance(const StringRef& left, const StringRef& right, + std::vector& left_offsets, + std::vector& right_offsets, Int64& result, size_t row) { + const bool left_ascii = simd::VStringFunctions::is_ascii(left); + const bool right_ascii = simd::VStringFunctions::is_ascii(right); + return hamming_distance_with_offsets(left, left_offsets, false, left_ascii, right, + right_offsets, false, right_ascii, result, row); + } + + static Status hamming_distance_with_offsets( + const StringRef& left, std::vector& left_offsets, bool left_offsets_ready, + bool left_ascii, const StringRef& right, std::vector& right_offsets, + bool right_offsets_ready, bool right_ascii, Int64& result, size_t row) { + if (left_ascii && right_ascii) { + return hamming_distance_ascii(left, right, result, row); + } + + if (!left_offsets_ready) { + utf8_char_offsets(left, left_offsets); + } + if (!right_offsets_ready) { + utf8_char_offsets(right, right_offsets); + } + return hamming_distance_utf8(left, left_offsets, right, right_offsets, result, row); + } + + static Status hamming_distance(const StringRef& left, const StringRef& right, Int64& result, + size_t row) { + std::vector left_offsets; + std::vector right_offsets; + return hamming_distance(left, right, left_offsets, right_offsets, result, row); + } +}; + +void register_function_hamming_distance(SimpleFunctionFactory& factory) { + factory.register_function(); +} + +#include "common/compile_check_end.h" +} // namespace doris diff --git a/be/src/exprs/function/function_levenshtein.cpp b/be/src/exprs/function/function_levenshtein.cpp new file mode 100644 index 00000000000000..1312273a37bfe3 --- /dev/null +++ b/be/src/exprs/function/function_levenshtein.cpp @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "common/status.h" +#include "core/data_type/data_type_number.h" +#include "core/string_ref.h" +#include "exprs/function/function_totype.h" +#include "exprs/function/simple_function_factory.h" +#include "util/simd/vstring_function.h" + +namespace doris { +#include "common/compile_check_begin.h" + +struct NameLevenshtein { + static constexpr auto name = "levenshtein"; +}; + +template +struct LevenshteinImpl { + using ResultDataType = DataTypeInt32; + using ResultPaddedPODArray = PaddedPODArray; + + static Status vector_vector(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, + const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + DCHECK_EQ(loffsets.size(), roffsets.size()); + + const size_t size = loffsets.size(); + res.resize(size); + std::vector left_offsets; + std::vector right_offsets; + for (size_t i = 0; i < size; ++i) { + res[i] = levenshtein_distance(string_ref_at(ldata, loffsets, i), + string_ref_at(rdata, roffsets, i), left_offsets, + right_offsets); + } + return Status::OK(); + } + + static Status vector_scalar(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, const StringRef& rdata, + ResultPaddedPODArray& res) { + const size_t size = loffsets.size(); + res.resize(size); + const bool right_ascii = simd::VStringFunctions::is_ascii(rdata); + std::vector right_offsets; + utf8_char_offsets(rdata, right_offsets); + std::vector left_offsets; + for (size_t i = 0; i < size; ++i) { + res[i] = levenshtein_distance_with_right_offsets(string_ref_at(ldata, loffsets, i), + left_offsets, rdata, right_offsets, + right_ascii); + } + return Status::OK(); + } + + static Status scalar_vector(const StringRef& ldata, const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + const size_t size = roffsets.size(); + res.resize(size); + const bool left_ascii = simd::VStringFunctions::is_ascii(ldata); + std::vector left_offsets; + utf8_char_offsets(ldata, left_offsets); + std::vector right_offsets; + for (size_t i = 0; i < size; ++i) { + res[i] = levenshtein_distance_with_left_offsets(ldata, left_offsets, left_ascii, + string_ref_at(rdata, roffsets, i), + right_offsets); + } + return Status::OK(); + } + +private: + static StringRef string_ref_at(const ColumnString::Chars& data, + const ColumnString::Offsets& offsets, size_t i) { + DCHECK_LT(i, offsets.size()); + const auto idx = static_cast(i); + return StringRef(data.data() + offsets[idx - 1], offsets[idx] - offsets[idx - 1]) + .trim_tail_padding_zero(); + } + + static void utf8_char_offsets(const StringRef& ref, std::vector& offsets) { + offsets.clear(); + offsets.reserve(ref.size); + simd::VStringFunctions::get_char_len(ref.data, ref.size, offsets); + } + + static bool utf8_char_equal(const StringRef& left, size_t left_off, size_t left_next, + const StringRef& right, size_t right_off, size_t right_next) { + const size_t left_len = left_next - left_off; + const size_t right_len = right_next - right_off; + return left_len == right_len && + std::memcmp(left.data + left_off, right.data + right_off, left_len) == 0; + } + + static Int32 levenshtein_distance_utf8(const StringRef& left, + const std::vector& left_offsets, + const StringRef& right, + const std::vector& right_offsets) { + const StringRef* left_ref = &left; + const StringRef* right_ref = &right; + const std::vector* left_offsets_ref = &left_offsets; + const std::vector* right_offsets_ref = &right_offsets; + if (right_offsets_ref->size() > left_offsets_ref->size()) { + std::swap(left_offsets_ref, right_offsets_ref); + std::swap(left_ref, right_ref); + } + + const size_t m = left_offsets_ref->size(); + const size_t n = right_offsets_ref->size(); + + std::vector prev(n + 1); + std::vector curr(n + 1); + for (size_t j = 0; j <= n; ++j) { + prev[j] = static_cast(j); + } + + for (size_t i = 1; i <= m; ++i) { + curr[0] = static_cast(i); + const size_t left_off = (*left_offsets_ref)[i - 1]; + const size_t left_next = i < m ? (*left_offsets_ref)[i] : left_ref->size; + + for (size_t j = 1; j <= n; ++j) { + const size_t right_off = (*right_offsets_ref)[j - 1]; + const size_t right_next = j < n ? (*right_offsets_ref)[j] : right_ref->size; + + const Int32 cost = utf8_char_equal(*left_ref, left_off, left_next, *right_ref, + right_off, right_next) + ? 0 + : 1; + + const Int32 insert_cost = curr[j - 1] + 1; + const Int32 delete_cost = prev[j] + 1; + const Int32 replace_cost = prev[j - 1] + cost; + curr[j] = std::min(std::min(insert_cost, delete_cost), replace_cost); + } + std::swap(prev, curr); + } + + return prev[n]; + } + + static Int32 levenshtein_distance_ascii(const StringRef& left, const StringRef& right) { + const StringRef* left_ref = &left; + const StringRef* right_ref = &right; + size_t m = left.size; + size_t n = right.size; + + if (n > m) { + std::swap(left_ref, right_ref); + std::swap(m, n); + } + + std::vector prev(n + 1); + std::vector curr(n + 1); + for (size_t j = 0; j <= n; ++j) { + prev[j] = static_cast(j); + } + + for (size_t i = 1; i <= m; ++i) { + curr[0] = static_cast(i); + const char left_char = left_ref->data[i - 1]; + + for (size_t j = 1; j <= n; ++j) { + const Int32 cost = left_char == right_ref->data[j - 1] ? 0 : 1; + const Int32 insert_cost = curr[j - 1] + 1; + const Int32 delete_cost = prev[j] + 1; + const Int32 replace_cost = prev[j - 1] + cost; + curr[j] = std::min(std::min(insert_cost, delete_cost), replace_cost); + } + std::swap(prev, curr); + } + + return prev[n]; + } + + static Int32 levenshtein_distance(const StringRef& left, const StringRef& right, + std::vector& left_offsets, + std::vector& right_offsets) { + const bool left_ascii = simd::VStringFunctions::is_ascii(left); + const bool right_ascii = simd::VStringFunctions::is_ascii(right); + if (left_ascii && right_ascii) { + return levenshtein_distance_ascii(left, right); + } + + if (left.size == 0) { + return static_cast(simd::VStringFunctions::get_char_len(right.data, right.size)); + } + if (right.size == 0) { + return static_cast(simd::VStringFunctions::get_char_len(left.data, left.size)); + } + + utf8_char_offsets(left, left_offsets); + utf8_char_offsets(right, right_offsets); + return levenshtein_distance_utf8(left, left_offsets, right, right_offsets); + } + + static Int32 levenshtein_distance_with_right_offsets(const StringRef& left, + std::vector& left_offsets, + const StringRef& right, + const std::vector& right_offsets, + bool right_ascii) { + const bool left_ascii = simd::VStringFunctions::is_ascii(left); + if (left_ascii && right_ascii) { + return levenshtein_distance_ascii(left, right); + } + + if (left.size == 0) { + return static_cast(right_offsets.size()); + } + if (right.size == 0) { + return left_ascii ? static_cast(left.size) + : static_cast( + simd::VStringFunctions::get_char_len(left.data, left.size)); + } + + utf8_char_offsets(left, left_offsets); + return levenshtein_distance_utf8(left, left_offsets, right, right_offsets); + } + + static Int32 levenshtein_distance_with_left_offsets(const StringRef& left, + const std::vector& left_offsets, + bool left_ascii, const StringRef& right, + std::vector& right_offsets) { + const bool right_ascii = simd::VStringFunctions::is_ascii(right); + if (left_ascii && right_ascii) { + return levenshtein_distance_ascii(left, right); + } + + if (left.size == 0) { + return static_cast( + right_ascii ? right.size + : simd::VStringFunctions::get_char_len(right.data, right.size)); + } + if (right.size == 0) { + return static_cast(left_offsets.size()); + } + + utf8_char_offsets(right, right_offsets); + return levenshtein_distance_utf8(left, left_offsets, right, right_offsets); + } + + static Int32 levenshtein_distance(const StringRef& left, const StringRef& right) { + std::vector left_offsets; + std::vector right_offsets; + return levenshtein_distance(left, right, left_offsets, right_offsets); + } +}; + +using FunctionLevenshtein = + FunctionBinaryToType; + +void register_function_levenshtein(SimpleFunctionFactory& factory) { + factory.register_function(); +} + +#include "common/compile_check_end.h" +} // namespace doris diff --git a/be/src/exprs/function/function_string.cpp b/be/src/exprs/function/function_string.cpp index 053921ce4fa0d6..c131ff23e805b9 100644 --- a/be/src/exprs/function/function_string.cpp +++ b/be/src/exprs/function/function_string.cpp @@ -46,6 +46,7 @@ namespace doris { #include "common/compile_check_begin.h" + struct NameStringASCII { static constexpr auto name = "ascii"; }; @@ -1326,7 +1327,6 @@ using FunctionStringLocate = FunctionBinaryToType; using FunctionStringFindInSet = FunctionBinaryToType; - using FunctionQuote = FunctionStringToString; using FunctionToLower = FunctionStringToString, NameToLower>; diff --git a/be/src/exprs/function/simple_function_factory.h b/be/src/exprs/function/simple_function_factory.h index c1ebcc34535c67..1d7e26fe5593cf 100644 --- a/be/src/exprs/function/simple_function_factory.h +++ b/be/src/exprs/function/simple_function_factory.h @@ -120,6 +120,8 @@ void register_function_ai(SimpleFunctionFactory& factory); void register_function_score(SimpleFunctionFactory& factory); void register_function_variant_type(SimpleFunctionFactory& factory); void register_function_binary(SimpleFunctionFactory& factory); +void register_function_levenshtein(SimpleFunctionFactory& factory); +void register_function_hamming_distance(SimpleFunctionFactory& factory); void register_function_soundex(SimpleFunctionFactory& factory); #if defined(BE_TEST) && !defined(BE_BENCHMARK) @@ -356,6 +358,8 @@ class SimpleFunctionFactory { register_function_ai(instance); register_function_score(instance); register_function_binary(instance); + register_function_levenshtein(instance); + register_function_hamming_distance(instance); register_function_soundex(instance); register_function_json_transform(instance); register_function_json_hash(instance); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index f7b21c7dfbf095..bd2c45601d4b4a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -234,6 +234,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Greatest; import org.apache.doris.nereids.trees.expressions.functions.scalar.Grouping; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId; +import org.apache.doris.nereids.trees.expressions.functions.scalar.HammingDistance; import org.apache.doris.nereids.trees.expressions.functions.scalar.Hex; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllEmpty; @@ -316,6 +317,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Least; import org.apache.doris.nereids.trees.expressions.functions.scalar.Left; import org.apache.doris.nereids.trees.expressions.functions.scalar.Length; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Levenshtein; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ln; import org.apache.doris.nereids.trees.expressions.functions.scalar.Locate; import org.apache.doris.nereids.trees.expressions.functions.scalar.Log; @@ -797,6 +799,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(Greatest.class, "greatest"), scalar(Grouping.class, "grouping"), scalar(GroupingId.class, "grouping_id"), + scalar(HammingDistance.class, "hamming_distance"), scalar(Hex.class, "hex"), scalar(HllCardinality.class, "hll_cardinality"), scalar(HllEmpty.class, "hll_empty"), @@ -882,6 +885,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(LastQueryId.class, "last_query_id"), scalar(Lcm.class, "lcm"), scalar(Least.class, "least"), + scalar(Levenshtein.class, "levenshtein"), scalar(Left.class, "left", "strleft"), scalar(Length.class, "length", "octet_length"), scalar(Crc32.class, "crc32"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java index 0172c3b433940f..570d0bb4b98c95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java @@ -1123,6 +1123,74 @@ public static Expression soundex(StringLikeLiteral first) { return castStringLikeLiteral(first, result); } + /** + * Executable arithmetic functions levenshtein + */ + @ExecFunction(name = "levenshtein") + public static Expression levenshtein(StringLikeLiteral first, StringLikeLiteral second) { + int[] left = first.getValue().codePoints().toArray(); + int[] right = second.getValue().codePoints().toArray(); + + if (right.length > left.length) { + int[] tmp = left; + left = right; + right = tmp; + } + + int m = left.length; + int n = right.length; + if (n == 0) { + return new IntegerLiteral(m); + } + if (m == 0) { + return new IntegerLiteral(n); + } + + int[] prev = new int[n + 1]; + int[] curr = new int[n + 1]; + for (int j = 0; j <= n; j++) { + prev[j] = j; + } + + for (int i = 1; i <= m; i++) { + curr[0] = i; + int leftChar = left[i - 1]; + for (int j = 1; j <= n; j++) { + int cost = leftChar == right[j - 1] ? 0 : 1; + int insertCost = curr[j - 1] + 1; + int deleteCost = prev[j] + 1; + int replaceCost = prev[j - 1] + cost; + curr[j] = Math.min(insertCost, Math.min(deleteCost, replaceCost)); + } + int[] tmp = prev; + prev = curr; + curr = tmp; + } + + return new IntegerLiteral(prev[n]); + } + + /** + * Executable arithmetic functions hamming_distance + */ + @ExecFunction(name = "hamming_distance") + public static Expression hammingDistance(StringLikeLiteral first, StringLikeLiteral second) { + int[] left = first.getValue().codePoints().toArray(); + int[] right = second.getValue().codePoints().toArray(); + + if (left.length != right.length) { + throw new AnalysisException("hamming_distance requires strings of the same length"); + } + + long distance = 0; + for (int i = 0; i < left.length; i++) { + if (left[i] != right[i]) { + distance++; + } + } + return new BigIntLiteral(distance); + } + /** * Executable arithmetic functions make_set */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/HammingDistance.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/HammingDistance.java new file mode 100644 index 00000000000000..a874ed7a912f2d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/HammingDistance.java @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarcharType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * ScalarFunction 'hamming_distance'. + */ +public class HammingDistance extends ScalarFunction + implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE) + .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(BigIntType.INSTANCE) + .args(StringType.INSTANCE, StringType.INSTANCE) + ); + + /** + * constructor with 2 arguments. + */ + public HammingDistance(Expression arg0, Expression arg1) { + super("hamming_distance", arg0, arg1); + } + + /** constructor for withChildren and reuse signature */ + private HammingDistance(ScalarFunctionParams functionParams) { + super(functionParams); + } + + /** + * withChildren. + */ + @Override + public HammingDistance withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new HammingDistance(getFunctionParams(children)); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitHammingDistance(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Levenshtein.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Levenshtein.java new file mode 100644 index 00000000000000..c1095b27a26262 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Levenshtein.java @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VarcharType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * ScalarFunction 'levenshtein'. + */ +public class Levenshtein extends ScalarFunction + implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(IntegerType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(IntegerType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE) + ); + + /** + * constructor with 2 arguments. + */ + public Levenshtein(Expression arg0, Expression arg1) { + super("levenshtein", arg0, arg1); + } + + /** constructor for withChildren and reuse signature */ + private Levenshtein(ScalarFunctionParams functionParams) { + super(functionParams); + } + + /** + * withChildren. + */ + @Override + public Levenshtein withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new Levenshtein(getFunctionParams(children)); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitLevenshtein(this, context); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index a20abfeae853c7..26f1ab235dea44 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -248,6 +248,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.GetFormat; import org.apache.doris.nereids.trees.expressions.functions.scalar.GetVariantType; import org.apache.doris.nereids.trees.expressions.functions.scalar.Greatest; +import org.apache.doris.nereids.trees.expressions.functions.scalar.HammingDistance; import org.apache.doris.nereids.trees.expressions.functions.scalar.Hex; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality; import org.apache.doris.nereids.trees.expressions.functions.scalar.HllEmpty; @@ -336,6 +337,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Least; import org.apache.doris.nereids.trees.expressions.functions.scalar.Left; import org.apache.doris.nereids.trees.expressions.functions.scalar.Length; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Levenshtein; import org.apache.doris.nereids.trees.expressions.functions.scalar.Ln; import org.apache.doris.nereids.trees.expressions.functions.scalar.Locate; import org.apache.doris.nereids.trees.expressions.functions.scalar.Log; @@ -1885,6 +1887,14 @@ default R visitLocate(Locate locate, C context) { return visitScalarFunction(locate, context); } + default R visitHammingDistance(HammingDistance hammingDistance, C context) { + return visitScalarFunction(hammingDistance, context); + } + + default R visitLevenshtein(Levenshtein levenshtein, C context) { + return visitScalarFunction(levenshtein, context); + } + default R visitLog(Log log, C context) { return visitScalarFunction(log, context); } diff --git a/regression-test/data/query_p0/sql_functions/string_functions/test_string_all.out b/regression-test/data/query_p0/sql_functions/string_functions/test_string_all.out index d126d2cd8ea602..d76de3fb6195aa 100644 --- a/regression-test/data/query_p0/sql_functions/string_functions/test_string_all.out +++ b/regression-test/data/query_p0/sql_functions/string_functions/test_string_all.out @@ -965,6 +965,42 @@ S530 S530 -- !soundex_330 -- R163 R163 +-- !levenshtein_331 -- +0 3 2 1 1 + +-- !levenshtein_332 -- +0 3 3 \N \N + +-- !levenshtein_333 -- +2 1 1 + +-- !levenshtein_tbl -- +1 3 +2 0 +3 1 +4 \N +5 1 +6 2 +7 3 + +-- !hamming_distance_333 -- +0 0 1 1 + +-- !hamming_distance_334 -- +0 \N \N + +-- !hamming_distance_335 -- +4 1 2 + +-- !hamming_distance_tbl -- +1 0 +2 1 +3 1 +4 \N +5 4 +6 1 +7 2 + -- !space_333 -- @@ -1411,4 +1447,3 @@ Hello Test123 -- !xpath_string_486 -- 123 - diff --git a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_all.groovy b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_all.groovy index 9d7123b03a79d4..5417511317c265 100644 --- a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_all.groovy +++ b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_all.groovy @@ -753,6 +753,64 @@ suite("string_functions_all") { testFoldConst("SELECT soundex('R@b-e123rt'), soundex('Robert');") // SOUNDEX tests with non-ASCII characters - Skipped (not supported) + // LEVENSHTEIN tests + qt_levenshtein_331 "SELECT levenshtein('', ''), levenshtein('kitten', 'sitting'), levenshtein('flaw', 'lawn'), levenshtein('你好', '你们'), levenshtein('数据库', '数据');" + testFoldConst("SELECT levenshtein('', ''), levenshtein('kitten', 'sitting'), levenshtein('flaw', 'lawn'), levenshtein('你好', '你们'), levenshtein('数据库', '数据');") + qt_levenshtein_332 "SELECT levenshtein('abc', 'abc'), levenshtein('abc', ''), levenshtein('', 'abc'), levenshtein(NULL, 'abc'), levenshtein('abc', NULL);" + testFoldConst("SELECT levenshtein('abc', 'abc'), levenshtein('abc', ''), levenshtein('', 'abc'), levenshtein(NULL, 'abc'), levenshtein('abc', NULL);") + qt_levenshtein_333 "SELECT levenshtein('abcd', 'abdc'), levenshtein('你好呀', '你好'), levenshtein('a你b', 'a们b');" + testFoldConst("SELECT levenshtein('abcd', 'abdc'), levenshtein('你好呀', '你好'), levenshtein('a你b', 'a们b');") + sql """DROP TABLE IF EXISTS string_distance_lv_test""" + sql """ + CREATE TABLE IF NOT EXISTS string_distance_lv_test ( + id int, + s1 VARCHAR, + s2 VARCHAR + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ("replication_num"="1") + """ + sql """ + insert into string_distance_lv_test values + (1, 'kitten', 'sitting'), + (2, 'abc', 'abc'), + (3, '数据库', '数据'), + (4, null, 'abc'), + (5, '你好呀', '你好'), + (6, 'abcd', 'abdc'), + (7, '', '数据库') + """ + qt_levenshtein_tbl "SELECT id, levenshtein(s1, s2) FROM string_distance_lv_test ORDER BY id" + + // HAMMING_DISTANCE tests + qt_hamming_distance_333 "SELECT hamming_distance('', ''), hamming_distance('abc', 'abc'), hamming_distance('abc', 'abd'), hamming_distance('你好', '你们');" + testFoldConst("SELECT hamming_distance('', ''), hamming_distance('abc', 'abc'), hamming_distance('abc', 'abd'), hamming_distance('你好', '你们');") + qt_hamming_distance_334 "SELECT hamming_distance('abc', 'abc'), hamming_distance(NULL, 'abc'), hamming_distance('abc', NULL);" + testFoldConst("SELECT hamming_distance('abc', 'abc'), hamming_distance(NULL, 'abc'), hamming_distance('abc', NULL);") + qt_hamming_distance_335 "SELECT hamming_distance('abcd', 'wxyz'), hamming_distance('你好吗', '你们吗'), hamming_distance('数据库', '数库据');" + testFoldConst("SELECT hamming_distance('abcd', 'wxyz'), hamming_distance('你好吗', '你们吗'), hamming_distance('数据库', '数库据');") + sql """DROP TABLE IF EXISTS string_distance_hd_test""" + sql """ + CREATE TABLE IF NOT EXISTS string_distance_hd_test ( + id int, + s1 VARCHAR, + s2 VARCHAR + ) + DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ("replication_num"="1") + """ + sql """ + insert into string_distance_hd_test values + (1, 'abc', 'abc'), + (2, 'abc', 'abd'), + (3, '你好', '你们'), + (4, null, 'abc'), + (5, 'abcd', 'wxyz'), + (6, '你好吗', '你们吗'), + (7, '数据库', '数库据') + """ + qt_hamming_distance_tbl "SELECT id, hamming_distance(s1, s2) FROM string_distance_hd_test ORDER BY id" + // SPACE tests qt_space_333 "SELECT space(5);" testFoldConst("SELECT space(5);") @@ -1092,4 +1150,4 @@ suite("string_functions_all") { testFoldConst("SELECT xpath_string(NULL, '/a');") qt_xpath_string_486 "SELECT xpath_string('123', '/a');" testFoldConst("SELECT xpath_string('123', '/a');") -} \ No newline at end of file +}