Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ops/ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "ops/matmul_static.h" // includes highway.h
#include "ops/sum-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/math/fast_math-inl.h"
#include "hwy/contrib/math/math-inl.h"

HWY_BEFORE_NAMESPACE();
Expand Down Expand Up @@ -1442,10 +1443,11 @@ static HWY_NOINLINE void Softmax(Logits logits, ThreadingContext& ctx,
hn::Transform(d, logits.data(), logits.size(),
[pmax](const auto d, const V value) HWY_ATTR {
if constexpr (HWY_TARGET & HWY_ALL_SVE) {
// Workaround for buggy SVE codegen: avoid inlined Exp().
return hn::CallExp(d, hn::Sub(value, *pmax));
// Workaround for buggy SVE codegen: avoid inlined
// FastExpMinusOrZero().
return hn::CallFastExpMinusOrZero(d, hn::Sub(value, *pmax));
} else {
return hn::Exp(d, hn::Sub(value, *pmax));
return hn::FastExpMinusOrZero(d, hn::Sub(value, *pmax));
}
});

Expand Down
6 changes: 3 additions & 3 deletions ops/ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ class TestSoftmax {
for (size_t i = 0; i < count; ++i) {
sum += x[i];
double rel = std::abs(x[i] - e[i]) / e[i];
ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
ASSERT_LT(rel, 2e-5) << "Mismatch on coordinate " << i << " out of "
<< count;
}
ASSERT_NEAR(sum, 1.0, 1e-6);
ASSERT_NEAR(sum, 1.0, 2e-5);
}

private:
Expand Down Expand Up @@ -384,7 +384,7 @@ class TestSoftmaxState {
}

ASSERT_NEAR(softmax_max, maxval, 1e-6);
ASSERT_NEAR(softmax_d, sum_exp, 1e-6);
ASSERT_NEAR(softmax_d, sum_exp, 2e-5);
}
};

Expand Down
Loading