diff --git a/vecxt/src-js/doublearrays.scala b/vecxt/src-js/doublearrays.scala index 69c83601..1c30a909 100644 --- a/vecxt/src-js/doublearrays.scala +++ b/vecxt/src-js/doublearrays.scala @@ -767,6 +767,35 @@ object doublearrays: def maxElement: Double = vec.max // val t = js.Math.max( vec.toArray: _* ) + + inline def `zeroWhere!`( + other: Array[Double], + threshold: Double, + inline op: ComparisonOp + ): Unit = + assert(vec.length == other.length) + var i = 0 + while i < vec.length do + val hit = inline op match + case ComparisonOp.LE => other(i) <= threshold + case ComparisonOp.LT => other(i) < threshold + case ComparisonOp.GE => other(i) >= threshold + case ComparisonOp.GT => other(i) > threshold + case ComparisonOp.EQ => other(i) == threshold + case ComparisonOp.NE => other(i) != threshold + if hit then vec(i) = 0.0 + end if + i += 1 + end while + end `zeroWhere!` + + inline def zeroWhere( + other: Array[Double], + threshold: Double, + inline op: ComparisonOp + ): Array[Double] = + vec.clone().tap(_.`zeroWhere!`(other, threshold, op)) + end extension end doublearrays diff --git a/vecxt/src-js/floatarrays.scala b/vecxt/src-js/floatarrays.scala index 8aeb8d70..58d05224 100644 --- a/vecxt/src-js/floatarrays.scala +++ b/vecxt/src-js/floatarrays.scala @@ -835,6 +835,34 @@ object floatarrays: (cv / (vec.length - 1)).toFloat end covariance + inline def `zeroWhere!`( + other: Array[Float], + threshold: Float, + inline op: ComparisonOp + ): Unit = + assert(vec.length == other.length) + var i = 0 + while i < vec.length do + val hit = inline op match + case ComparisonOp.LE => other(i) <= threshold + case ComparisonOp.LT => other(i) < threshold + case ComparisonOp.GE => other(i) >= threshold + case ComparisonOp.GT => other(i) > threshold + case ComparisonOp.EQ => other(i) == threshold + case ComparisonOp.NE => other(i) != threshold + if hit then vec(i) = 0.0f + end if + i += 1 + end while + end `zeroWhere!` + + inline def zeroWhere( + other: Array[Float], + threshold: Float, + inline op: ComparisonOp + ): Array[Float] = + vec.clone().tap(_.`zeroWhere!`(other, threshold, op)) + end extension extension (vec: Array[Array[Double]]) diff --git a/vecxt/src-jvm/doublearrays.scala b/vecxt/src-jvm/doublearrays.scala index 2d860672..58de7b72 100644 --- a/vecxt/src-jvm/doublearrays.scala +++ b/vecxt/src-jvm/doublearrays.scala @@ -1239,6 +1239,52 @@ object doublearrays: // def max: Double = // vec(blas.idamax(vec.length, vec, 1)) // No JS version + + inline def `zeroWhere!`( + other: Array[Double], + threshold: Double, + inline op: ComparisonOp + ): Unit = + assert(vec.length == other.length) + val zero = DoubleVector.zero(spd) + val thresh = DoubleVector.broadcast(spd, threshold) + var i = 0 + + while i < spd.loopBound(vec.length) do + val values = DoubleVector.fromArray(spd, vec, i) + val cmp = DoubleVector.fromArray(spd, other, i) + val mask = inline op match + case ComparisonOp.LE => cmp.compare(VectorOperators.LE, thresh) + case ComparisonOp.LT => cmp.compare(VectorOperators.LT, thresh) + case ComparisonOp.GE => cmp.compare(VectorOperators.GE, thresh) + case ComparisonOp.GT => cmp.compare(VectorOperators.GT, thresh) + case ComparisonOp.EQ => cmp.compare(VectorOperators.EQ, thresh) + case ComparisonOp.NE => cmp.compare(VectorOperators.NE, thresh) + values.blend(zero, mask).intoArray(vec, i) + i += spdl + end while + + while i < vec.length do + val hit = inline op match + case ComparisonOp.LE => other(i) <= threshold + case ComparisonOp.LT => other(i) < threshold + case ComparisonOp.GE => other(i) >= threshold + case ComparisonOp.GT => other(i) > threshold + case ComparisonOp.EQ => other(i) == threshold + case ComparisonOp.NE => other(i) != threshold + if hit then vec(i) = 0.0 + end if + i += 1 + end while + end `zeroWhere!` + + inline def zeroWhere( + other: Array[Double], + threshold: Double, + inline op: ComparisonOp + ): Array[Double] = + vec.clone().tap(_.`zeroWhere!`(other, threshold, op)) + end extension end doublearrays diff --git a/vecxt/src-jvm/floatarrays.scala b/vecxt/src-jvm/floatarrays.scala index bfd87404..f49357ff 100644 --- a/vecxt/src-jvm/floatarrays.scala +++ b/vecxt/src-jvm/floatarrays.scala @@ -938,5 +938,50 @@ object floatarrays: Matrix(out, (n, m))(using BoundsCheck.DoBoundsCheck.no) end outer + inline def `zeroWhere!`( + other: Array[Float], + threshold: Float, + inline op: ComparisonOp + ): Unit = + assert(vec.length == other.length) + val zero = FloatVector.zero(spf) + val thresh = FloatVector.broadcast(spf, threshold) + var i = 0 + + while i < spf.loopBound(vec.length) do + val values = FloatVector.fromArray(spf, vec, i) + val cmp = FloatVector.fromArray(spf, other, i) + val mask = inline op match + case ComparisonOp.LE => cmp.compare(VectorOperators.LE, thresh) + case ComparisonOp.LT => cmp.compare(VectorOperators.LT, thresh) + case ComparisonOp.GE => cmp.compare(VectorOperators.GE, thresh) + case ComparisonOp.GT => cmp.compare(VectorOperators.GT, thresh) + case ComparisonOp.EQ => cmp.compare(VectorOperators.EQ, thresh) + case ComparisonOp.NE => cmp.compare(VectorOperators.NE, thresh) + values.blend(zero, mask).intoArray(vec, i) + i += spfl + end while + + while i < vec.length do + val hit = inline op match + case ComparisonOp.LE => other(i) <= threshold + case ComparisonOp.LT => other(i) < threshold + case ComparisonOp.GE => other(i) >= threshold + case ComparisonOp.GT => other(i) > threshold + case ComparisonOp.EQ => other(i) == threshold + case ComparisonOp.NE => other(i) != threshold + if hit then vec(i) = 0.0f + end if + i += 1 + end while + end `zeroWhere!` + + inline def zeroWhere( + other: Array[Float], + threshold: Float, + inline op: ComparisonOp + ): Array[Float] = + vec.clone().tap(_.`zeroWhere!`(other, threshold, op)) + end extension end floatarrays diff --git a/vecxt/src-native/doublearrays.scala b/vecxt/src-native/doublearrays.scala index 3dfc55d1..174c3fd3 100644 --- a/vecxt/src-native/doublearrays.scala +++ b/vecxt/src-native/doublearrays.scala @@ -751,6 +751,35 @@ object doublearrays: end covariance // def max: Double = vec(blas.cblas_idamax(vec.length, vec.at(0), 1)) // No JS version + + inline def `zeroWhere!`( + other: Array[Double], + threshold: Double, + inline op: ComparisonOp + ): Unit = + assert(vec.length == other.length) + var i = 0 + while i < vec.length do + val hit = inline op match + case ComparisonOp.LE => other(i) <= threshold + case ComparisonOp.LT => other(i) < threshold + case ComparisonOp.GE => other(i) >= threshold + case ComparisonOp.GT => other(i) > threshold + case ComparisonOp.EQ => other(i) == threshold + case ComparisonOp.NE => other(i) != threshold + if hit then vec(i) = 0.0 + end if + i += 1 + end while + end `zeroWhere!` + + inline def zeroWhere( + other: Array[Double], + threshold: Double, + inline op: ComparisonOp + ): Array[Double] = + vec.clone().tap(_.`zeroWhere!`(other, threshold, op)) + end extension end doublearrays diff --git a/vecxt/src-native/floatarrays.scala b/vecxt/src-native/floatarrays.scala index f9c8b480..c6788e10 100644 --- a/vecxt/src-native/floatarrays.scala +++ b/vecxt/src-native/floatarrays.scala @@ -818,6 +818,34 @@ object floatarrays: (cv / (vec.length - 1)).toFloat end covariance + inline def `zeroWhere!`( + other: Array[Float], + threshold: Float, + inline op: ComparisonOp + ): Unit = + assert(vec.length == other.length) + var i = 0 + while i < vec.length do + val hit = inline op match + case ComparisonOp.LE => other(i) <= threshold + case ComparisonOp.LT => other(i) < threshold + case ComparisonOp.GE => other(i) >= threshold + case ComparisonOp.GT => other(i) > threshold + case ComparisonOp.EQ => other(i) == threshold + case ComparisonOp.NE => other(i) != threshold + if hit then vec(i) = 0.0f + end if + i += 1 + end while + end `zeroWhere!` + + inline def zeroWhere( + other: Array[Float], + threshold: Float, + inline op: ComparisonOp + ): Array[Float] = + vec.clone().tap(_.`zeroWhere!`(other, threshold, op)) + end extension extension (vec: Array[Array[Double]]) diff --git a/vecxt/src/ComparisonOp.scala b/vecxt/src/ComparisonOp.scala new file mode 100644 index 00000000..49dcd3ef --- /dev/null +++ b/vecxt/src/ComparisonOp.scala @@ -0,0 +1,5 @@ +package vecxt + +enum ComparisonOp: + case LT, LE, GT, GE, EQ, NE +end ComparisonOp diff --git a/vecxt/src/all.scala b/vecxt/src/all.scala index ed04e11f..ef322fb4 100644 --- a/vecxt/src/all.scala +++ b/vecxt/src/all.scala @@ -18,6 +18,7 @@ object all: export vecxt.IntArraysX.* export vecxt.VarianceMode + export vecxt.ComparisonOp // matricies export vecxt.OneAndZero.given_OneAndZero_Boolean diff --git a/vecxt/test/src/zeroWhere.test.scala b/vecxt/test/src/zeroWhere.test.scala new file mode 100644 index 00000000..ed02df67 --- /dev/null +++ b/vecxt/test/src/zeroWhere.test.scala @@ -0,0 +1,231 @@ +package vecxt + +import scala.util.chaining.* +import all.* + +class ZeroWhereSuite extends munit.FunSuite: + + // ===== Float tests ===== + + test("zeroWhere! zeros elements where other <= threshold (Float)") { + val vec = Array[Float](1.0f, 2.0f, 3.0f, 4.0f, 5.0f) + val other = Array[Float](0.0f, -1.0f, 1.0f, 0.0f, 2.0f) + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LE) + assertEquals(vec(0), 0.0f) + assertEquals(vec(1), 0.0f) + assertEquals(vec(2), 3.0f) + assertEquals(vec(3), 0.0f) + assertEquals(vec(4), 5.0f) + } + + test("zeroWhere returns new array without mutating source (Float)") { + val vec = Array[Float](1.0f, 2.0f, 3.0f) + val other = Array[Float](-1.0f, 1.0f, -1.0f) + val result = vec.zeroWhere(other, 0.0f, ComparisonOp.LE) + assertEquals(result(0), 0.0f) + assertEquals(result(1), 2.0f) + assertEquals(result(2), 0.0f) + // source unchanged + assertEquals(vec(0), 1.0f) + assertEquals(vec(1), 2.0f) + assertEquals(vec(2), 3.0f) + } + + test("zeroWhere! respects all ComparisonOp variants (Float)") { + val other = Array[Float](1.0f, 2.0f, 3.0f) + + val vecLT = Array[Float](10.0f, 20.0f, 30.0f) + vecLT.`zeroWhere!`(other, 2.0f, ComparisonOp.LT) + assertEquals(vecLT(0), 0.0f, "LT index 0") + assertEquals(vecLT(1), 20.0f, "LT index 1") + assertEquals(vecLT(2), 30.0f, "LT index 2") + + val vecLE = Array[Float](10.0f, 20.0f, 30.0f) + vecLE.`zeroWhere!`(other, 2.0f, ComparisonOp.LE) + assertEquals(vecLE(0), 0.0f, "LE index 0") + assertEquals(vecLE(1), 0.0f, "LE index 1") + assertEquals(vecLE(2), 30.0f, "LE index 2") + + val vecGT = Array[Float](10.0f, 20.0f, 30.0f) + vecGT.`zeroWhere!`(other, 2.0f, ComparisonOp.GT) + assertEquals(vecGT(0), 10.0f, "GT index 0") + assertEquals(vecGT(1), 20.0f, "GT index 1") + assertEquals(vecGT(2), 0.0f, "GT index 2") + + val vecGE = Array[Float](10.0f, 20.0f, 30.0f) + vecGE.`zeroWhere!`(other, 2.0f, ComparisonOp.GE) + assertEquals(vecGE(0), 10.0f, "GE index 0") + assertEquals(vecGE(1), 0.0f, "GE index 1") + assertEquals(vecGE(2), 0.0f, "GE index 2") + + val vecEQ = Array[Float](10.0f, 20.0f, 30.0f) + vecEQ.`zeroWhere!`(other, 2.0f, ComparisonOp.EQ) + assertEquals(vecEQ(0), 10.0f, "EQ index 0") + assertEquals(vecEQ(1), 0.0f, "EQ index 1") + assertEquals(vecEQ(2), 30.0f, "EQ index 2") + + val vecNE = Array[Float](10.0f, 20.0f, 30.0f) + vecNE.`zeroWhere!`(other, 2.0f, ComparisonOp.NE) + assertEquals(vecNE(0), 0.0f, "NE index 0") + assertEquals(vecNE(1), 20.0f, "NE index 1") + assertEquals(vecNE(2), 0.0f, "NE index 2") + } + + test("zeroWhere! handles non-SIMD-aligned lengths (Float)") { + val n = 19 + val vec = Array.tabulate[Float](n)(i => (i + 1).toFloat) + val other = Array.tabulate[Float](n)(i => if i % 3 == 0 then -1.0f else 1.0f) + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LE) + var i = 0 + while i < n do + if i % 3 == 0 then assertEquals(vec(i), 0.0f, s"index $i should be zeroed") + else assertEquals(vec(i), (i + 1).toFloat, s"index $i should be kept") + end if + i += 1 + end while + } + + test("zeroWhere! on empty arrays is a no-op (Float)") { + val vec = Array.empty[Float] + val other = Array.empty[Float] + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LE) + assertEquals(vec.length, 0) + } + + test("zeroWhere! zeros all elements when all satisfy condition (Float)") { + val vec = Array[Float](1.0f, 2.0f, 3.0f) + val other = Array[Float](-1.0f, -2.0f, -3.0f) + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LT) + assertEquals(vec(0), 0.0f) + assertEquals(vec(1), 0.0f) + assertEquals(vec(2), 0.0f) + } + + test("zeroWhere! keeps all elements when none satisfy condition (Float)") { + val vec = Array[Float](1.0f, 2.0f, 3.0f) + val other = Array[Float](10.0f, 20.0f, 30.0f) + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LT) + assertEquals(vec(0), 1.0f) + assertEquals(vec(1), 2.0f) + assertEquals(vec(2), 3.0f) + } + + test("zeroWhere! handles NaN in other array (Float)") { + val vec = Array[Float](1.0f, 2.0f, 3.0f) + val other = Array[Float](Float.NaN, 1.0f, Float.NaN) + // NaN comparisons are always false (IEEE 754) + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LE) + assertEquals(vec(0), 1.0f) + assertEquals(vec(1), 2.0f) + assertEquals(vec(2), 3.0f) + } + + test("zeroWhere! handles threshold at Float boundaries") { + val vec = Array[Float](1.0f, 2.0f) + val other = Array[Float](Float.NegativeInfinity, Float.PositiveInfinity) + vec.`zeroWhere!`(other, 0.0f, ComparisonOp.LE) + assertEquals(vec(0), 0.0f) + assertEquals(vec(1), 2.0f) + } + + test("zeroWhere! works when vec and other are the same array (Float)") { + val arr = Array[Float](-2.0f, 3.0f, -1.0f, 4.0f) + arr.`zeroWhere!`(arr, 0.0f, ComparisonOp.LT) + assertEquals(arr(0), 0.0f) + assertEquals(arr(1), 3.0f) + assertEquals(arr(2), 0.0f) + assertEquals(arr(3), 4.0f) + } + + // ===== Double tests ===== + + test("zeroWhere! zeros elements where other <= threshold (Double)") { + val vec = Array[Double](1.0, 2.0, 3.0, 4.0, 5.0) + val other = Array[Double](0.0, -1.0, 1.0, 0.0, 2.0) + vec.`zeroWhere!`(other, 0.0, ComparisonOp.LE) + assertEquals(vec(0), 0.0) + assertEquals(vec(1), 0.0) + assertEquals(vec(2), 3.0) + assertEquals(vec(3), 0.0) + assertEquals(vec(4), 5.0) + } + + test("zeroWhere returns new array without mutating source (Double)") { + val vec = Array[Double](1.0, 2.0, 3.0) + val other = Array[Double](-1.0, 1.0, -1.0) + val result = vec.zeroWhere(other, 0.0, ComparisonOp.LE) + assertEquals(result(0), 0.0) + assertEquals(result(1), 2.0) + assertEquals(result(2), 0.0) + // source unchanged + assertEquals(vec(0), 1.0) + assertEquals(vec(1), 2.0) + assertEquals(vec(2), 3.0) + } + + test("zeroWhere! respects all ComparisonOp variants (Double)") { + val other = Array[Double](1.0, 2.0, 3.0) + + val vecLT = Array[Double](10.0, 20.0, 30.0) + vecLT.`zeroWhere!`(other, 2.0, ComparisonOp.LT) + assertEquals(vecLT(0), 0.0, "LT index 0") + assertEquals(vecLT(1), 20.0, "LT index 1") + assertEquals(vecLT(2), 30.0, "LT index 2") + + val vecLE = Array[Double](10.0, 20.0, 30.0) + vecLE.`zeroWhere!`(other, 2.0, ComparisonOp.LE) + assertEquals(vecLE(0), 0.0, "LE index 0") + assertEquals(vecLE(1), 0.0, "LE index 1") + assertEquals(vecLE(2), 30.0, "LE index 2") + + val vecGT = Array[Double](10.0, 20.0, 30.0) + vecGT.`zeroWhere!`(other, 2.0, ComparisonOp.GT) + assertEquals(vecGT(0), 10.0, "GT index 0") + assertEquals(vecGT(1), 20.0, "GT index 1") + assertEquals(vecGT(2), 0.0, "GT index 2") + + val vecGE = Array[Double](10.0, 20.0, 30.0) + vecGE.`zeroWhere!`(other, 2.0, ComparisonOp.GE) + assertEquals(vecGE(0), 10.0, "GE index 0") + assertEquals(vecGE(1), 0.0, "GE index 1") + assertEquals(vecGE(2), 0.0, "GE index 2") + + val vecEQ = Array[Double](10.0, 20.0, 30.0) + vecEQ.`zeroWhere!`(other, 2.0, ComparisonOp.EQ) + assertEquals(vecEQ(0), 10.0, "EQ index 0") + assertEquals(vecEQ(1), 0.0, "EQ index 1") + assertEquals(vecEQ(2), 30.0, "EQ index 2") + + val vecNE = Array[Double](10.0, 20.0, 30.0) + vecNE.`zeroWhere!`(other, 2.0, ComparisonOp.NE) + assertEquals(vecNE(0), 0.0, "NE index 0") + assertEquals(vecNE(1), 20.0, "NE index 1") + assertEquals(vecNE(2), 0.0, "NE index 2") + } + + test("zeroWhere! on empty arrays is a no-op (Double)") { + val vec = Array.empty[Double] + val other = Array.empty[Double] + vec.`zeroWhere!`(other, 0.0, ComparisonOp.LE) + assertEquals(vec.length, 0) + } + + test("zeroWhere! handles NaN in other array (Double)") { + val vec = Array[Double](1.0, 2.0, 3.0) + val other = Array[Double](Double.NaN, 1.0, Double.NaN) + // NaN comparisons are always false (IEEE 754) + vec.`zeroWhere!`(other, 0.0, ComparisonOp.LE) + assertEquals(vec(0), 1.0) + assertEquals(vec(1), 2.0) + assertEquals(vec(2), 3.0) + } + + test("zeroWhere! handles threshold at Double boundaries") { + val vec = Array[Double](1.0, 2.0) + val other = Array[Double](Double.NegativeInfinity, Double.PositiveInfinity) + vec.`zeroWhere!`(other, 0.0, ComparisonOp.LE) + assertEquals(vec(0), 0.0) + assertEquals(vec(1), 2.0) + } + +end ZeroWhereSuite