Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dist/assets/oldWorker-Braq5_-n.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/oldWorker-CrOUJ63V.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/oldWorker-Ga0lx4u7.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/oldWorker-gPs-Rff0.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/worker-BUwnwGa6.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/worker-Or79jBlE.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/worker-XXgjG2ck.js.map

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions dist/assets/worker-kPeRSrjE.js.map

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions dist/spark.cjs.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/spark.cjs.js.map

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions dist/spark.module.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dist/spark.module.js.map

Large diffs are not rendered by default.

208 changes: 146 additions & 62 deletions rust/spark-worker-rs/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ impl SortBuffers {
}

pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result<u32, String> {
let SortBuffers { readback, ordering, buckets } = buffers;
let SortBuffers {
readback,
ordering,
buckets,
} = buffers;
let readback = &readback[..num_splats];

// Set the bucket counts to zero
Expand Down Expand Up @@ -57,28 +61,26 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result<u32
if buckets[0] != active_splats {
return Err(format!(
"Expected {} active splats but got {}",
active_splats,
buckets[0]
active_splats, buckets[0]
));
}
Ok(active_splats)
}

const DEPTH_INFINITY_F32: u32 = 0x7f800000;
const RADIX_BASE: usize = 1 << 16; // 65536
// 16-bit radix (2 passes)
const RADIX_BITS: u32 = 16;
const RADIX_BASE: usize = 1 << RADIX_BITS; // 65536
const RADIX_MASK: u32 = RADIX_BASE as u32 - 1;

#[derive(Default)]
pub struct Sort32Buffers {
/// raw f32 bit‑patterns (one per splat)
pub readback: Vec<u32>,
/// output indices
pub ordering: Vec<u32>,
/// bucket counts / offsets (length == RADIX_BASE)
pub buckets16lo: Vec<u32>,
/// bucket counts / offsets (length == RADIX_BASE)
pub buckets16hi: Vec<u32>,
/// scratch space for indices
pub scratch: Vec<u32>,
pub scratch: Vec<u64>, // (key, index)
pub buckets: Vec<u32>, // 2 * 65536
}

impl Sort32Buffers {
Expand All @@ -93,84 +95,166 @@ impl Sort32Buffers {
if self.scratch.len() < max_splats {
self.scratch.resize(max_splats, 0);
}
if self.buckets16lo.len() < RADIX_BASE {
self.buckets16lo.resize(RADIX_BASE, 0);
}
if self.buckets16hi.len() < RADIX_BASE {
self.buckets16hi.resize(RADIX_BASE, 0);
if self.buckets.len() < RADIX_BASE * 2 {
self.buckets.resize(RADIX_BASE * 2, 0);
}
}
}

/// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns,
/// descending order (largest keys first). Mirrors the JS `sort32Splats`.
#[inline(always)]
/// Rewrites `buckets` in place as an exclusive prefix sum and returns the total.
///
/// After the call, each slot holds the (wrapping) sum of all the counts that
/// originally preceded it, i.e. the starting offset of that bucket. The return
/// value is the (wrapping) sum of every original count. An empty slice is left
/// untouched and yields `0`.
fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 {
    buckets.iter_mut().fold(0u32, |running, slot| {
        // Swap the running offset into the slot and take the old count out.
        let count = std::mem::replace(slot, running);
        // Wrapping add: overflow wraps around instead of panicking.
        running.wrapping_add(count)
    })
}

pub fn sort32_internal(
buffers: &mut Sort32Buffers,
max_splats: usize,
num_splats: usize,
) -> Result<u32, String> {
// make sure our buffers can hold `max_splats`
buffers.ensure_size(max_splats);

let Sort32Buffers { readback, ordering, buckets16lo, buckets16hi, scratch } = buffers;
let Sort32Buffers {
readback,
ordering,
scratch,
buckets,
} = buffers;
let keys = &readback[..num_splats];

// tally low and high buckets
buckets16lo.fill(0);
buckets16hi.fill(0);
for &key in keys.iter() {
if key < DEPTH_INFINITY_F32 {
let inv = !key;
buckets16lo[(inv & 0xFFFF) as usize] += 1;
buckets16hi[(inv >> 16) as usize] += 1;
// Split buckets
let (b0, b1) = buckets.split_at_mut(RADIX_BASE);
let b1 = &mut b1[..RADIX_BASE];

b0.fill(0);
b1.fill(0);

// pass 1: Histogram (branchless)
let mut chunks = keys.chunks_exact(8);

for chunk in chunks.by_ref() {
macro_rules! tick {
($k:expr) => {{
let valid = ($k < DEPTH_INFINITY_F32) as u32;
let inv = !$k;

let r0 = inv & RADIX_MASK;
let r1 = inv >> RADIX_BITS;

b0[r0 as usize] += valid;
unsafe { *b1.get_unchecked_mut(r1 as usize) += valid };
}};
}

tick!(chunk[0]);
tick!(chunk[1]);
tick!(chunk[2]);
tick!(chunk[3]);
tick!(chunk[4]);
tick!(chunk[5]);
tick!(chunk[6]);
tick!(chunk[7]);
}

// ——— Pass #1: bucket by inv(low 16 bits) ———
// exclusive prefix‑sum → starting offsets
let mut total: u32 = 0;
for slot in buckets16lo.iter_mut() {
let cnt = *slot;
*slot = total;
total = total.wrapping_add(cnt);
for &k in chunks.remainder() {
let valid = (k < DEPTH_INFINITY_F32) as u32;
let inv = !k;
b0[(inv & RADIX_MASK) as usize] += valid;
unsafe { *b1.get_unchecked_mut((inv >> RADIX_BITS) as usize) += valid };
}
let active_splats = total;

// scatter into scratch by low bits of inv
for (i, &key) in keys.iter().enumerate() {
if key < DEPTH_INFINITY_F32 {
let inv = !key;
let lo = (inv & 0xFFFF) as usize;
scratch[buckets16lo[lo] as usize] = i as u32;
buckets16lo[lo] += 1;

let active = prefix_sum_exclusive(b0) as usize;
prefix_sum_exclusive(b1);

// pass 1: scatter into scratch
let mut chunks = keys.chunks_exact(8);
let mut i = 0;

for chunk in chunks.by_ref() {
macro_rules! place {
($k:expr, $idx:expr) => {{
let valid = ($k < DEPTH_INFINITY_F32) as u32;
let inv = !$k;

let r0 = (inv & RADIX_MASK) as usize;
let pos = unsafe { *b0.get_unchecked(r0) } as usize;

// Always write (branchless), but only advance if valid
unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | ($idx as u64) };
unsafe { *b0.get_unchecked_mut(r0) += valid };
}};
}

place!(chunk[0], i);
place!(chunk[1], i + 1);
place!(chunk[2], i + 2);
place!(chunk[3], i + 3);
place!(chunk[4], i + 4);
place!(chunk[5], i + 5);
place!(chunk[6], i + 6);
place!(chunk[7], i + 7);

i += 8;
}

// ——— Pass #2: bucket by inv(high 16 bits) ———
// exclusive prefix‑sum again
let mut sum: u32 = 0;
for slot in buckets16hi.iter_mut() {
let cnt = *slot;
*slot = sum;
sum = sum.wrapping_add(cnt);
for &k in chunks.remainder() {
let valid = (k < DEPTH_INFINITY_F32) as u32;
let inv = !k;

let r0 = (inv & RADIX_MASK) as usize;
let pos = unsafe { *b0.get_unchecked(r0) } as usize;

unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | (i as u64) };
unsafe { *b0.get_unchecked_mut(r0) += valid };

i += 1;
}
// scatter into final ordering by high bits of inv
for &idx in scratch.iter().take(active_splats as usize) {
let key = keys[idx as usize];
let inv = !key;
let hi = (inv >> 16) as usize;
ordering[buckets16hi[hi] as usize] = idx;
buckets16hi[hi] += 1;

// pass 2: scatter into final ordering
let mut chunks = scratch[..active].chunks_exact(8);

for chunk in chunks.by_ref() {
macro_rules! place2 {
($kv:expr) => {{
let r1 = (($kv >> 48) & RADIX_MASK as u64) as usize;
let pos = unsafe { *b1.get_unchecked(r1) } as usize;

unsafe { *ordering.get_unchecked_mut(pos) = $kv as u32 };
unsafe { *b1.get_unchecked_mut(r1) += 1 };
}};
}

place2!(chunk[0]);
place2!(chunk[1]);
place2!(chunk[2]);
place2!(chunk[3]);
place2!(chunk[4]);
place2!(chunk[5]);
place2!(chunk[6]);
place2!(chunk[7]);
}

for &kv in chunks.remainder() {
let r1 = ((kv >> 48) & RADIX_MASK as u64) as usize;
let pos = unsafe { *b1.get_unchecked(r1) } as usize;

unsafe { *ordering.get_unchecked_mut(pos) = kv as u32 };
unsafe { *b1.get_unchecked_mut(r1) += 1 };
}

// sanity‑check: last bucket should have consumed all entries
if buckets16hi[RADIX_BASE - 1] != active_splats {
if b1[RADIX_BASE - 1] != active as u32 {
return Err(format!(
"Expected {} active splats but got {}",
active_splats,
buckets16hi[RADIX_BASE - 1]
active,
b1[RADIX_BASE - 1]
));
}

Ok(active_splats)
}
Ok(active as u32)
}