diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 36ec5d7f..2b2f310b 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -55,6 +55,7 @@ #include "ops/matmul_static.h" // includes highway.h #include "ops/sum-inl.h" #include "hwy/contrib/algo/transform-inl.h" +#include "hwy/contrib/math/fast_math-inl.h" #include "hwy/contrib/math/math-inl.h" HWY_BEFORE_NAMESPACE(); @@ -1442,10 +1443,11 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx, hn::Transform(d, logits.data(), logits.size(), [pmax](const auto d, const V value) HWY_ATTR { if constexpr (HWY_TARGET & HWY_ALL_SVE) { - // Workaround for buggy SVE codegen: avoid inlined Exp(). - return hn::CallExp(d, hn::Sub(value, *pmax)); + // Workaround for buggy SVE codegen: avoid inlined + // FastExpMinusOrZero(). + return hn::CallFastExpMinusOrZero(d, hn::Sub(value, *pmax)); } else { - return hn::Exp(d, hn::Sub(value, *pmax)); + return hn::FastExpMinusOrZero(d, hn::Sub(value, *pmax)); } }); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 8d177264..3e291ee8 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -321,10 +321,10 @@ class TestSoftmax { for (size_t i = 0; i < count; ++i) { sum += x[i]; double rel = std::abs(x[i] - e[i]) / e[i]; - ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of " + ASSERT_LT(rel, 2e-5) << "Mismatch on coordinate " << i << " out of " << count; } - ASSERT_NEAR(sum, 1.0, 1e-6); + ASSERT_NEAR(sum, 1.0, 2e-5); } private: @@ -384,7 +384,7 @@ class TestSoftmaxState { } ASSERT_NEAR(softmax_max, maxval, 1e-6); - ASSERT_NEAR(softmax_d, sum_exp, 1e-6); + ASSERT_NEAR(softmax_d, sum_exp, 2e-5); } };