diff --git a/.changeset/optimize-lazy-snapshot-reads.md b/.changeset/optimize-lazy-snapshot-reads.md new file mode 100644 index 000000000..69cf62b80 --- /dev/null +++ b/.changeset/optimize-lazy-snapshot-reads.md @@ -0,0 +1,6 @@ +--- +"loro-crdt": patch +"loro-crdt-map": patch +--- + +Reduce memory usage for read-only access to snapshot-imported documents by avoiding unnecessary lazy container state initialization. diff --git a/crates/bench-utils/src/json.rs b/crates/bench-utils/src/json.rs index c0057c76f..dc6f860fc 100644 --- a/crates/bench-utils/src/json.rs +++ b/crates/bench-utils/src/json.rs @@ -53,10 +53,8 @@ impl ActionTrait for JsonAction { fn normalize_value(value: &mut LoroValue) { match value { - LoroValue::Double(f) => { - if f.is_nan() { - *f = 0.0; - } + LoroValue::Double(f) if f.is_nan() => { + *f = 0.0; } LoroValue::List(l) => { for v in l.make_mut().iter_mut() { diff --git a/crates/delta/src/delta_rope.rs b/crates/delta/src/delta_rope.rs index 08730aacb..66e9d2683 100644 --- a/crates/delta/src/delta_rope.rs +++ b/crates/delta/src/delta_rope.rs @@ -392,13 +392,11 @@ impl PartialEq for Delta b.next_with(len).unwrap(); } } - (DeltaItem::Retain { attr, .. }, DeltaItem::Retain { attr: b_attr, .. }) => { - if *attr == *b_attr { - a.next_with(len).unwrap(); - b.next_with(len).unwrap(); - } else { - return false; - } + (DeltaItem::Retain { attr, .. }, DeltaItem::Retain { attr: b_attr, .. }) + if *attr == *b_attr => + { + a.next_with(len).unwrap(); + b.next_with(len).unwrap(); } _ => return false, } diff --git a/crates/fuzz/src/container/tree.rs b/crates/fuzz/src/container/tree.rs index 97576b6cb..639168d2b 100644 --- a/crates/fuzz/src/container/tree.rs +++ b/crates/fuzz/src/container/tree.rs @@ -362,7 +362,7 @@ impl Actionable for TreeAction { } TreeActionInner::MetaDelete { key } => { let meta = super::unwrap(tree.get_meta(target))?; - meta.delete(key); + let _ = meta.delete(key); None } TreeActionInner::MetaClear => { diff --git a/crates/fuzz/src/one_doc_fuzzer.rs b/crates/fuzz/src/one_doc_fuzzer.rs index 4f537cc2f..7872ccb7d 100644 --- a/crates/fuzz/src/one_doc_fuzzer.rs +++ b/crates/fuzz/src/one_doc_fuzzer.rs @@ -646,7 +646,7 @@ impl OneDocFuzzer { undo.clear(); } } - Action::ForkAt { site, to } => { + Action::ForkAt { site, to: _ } => { let frontiers = self.branches[*site as usize].frontiers.clone(); let _forked = self.doc.fork_at(&frontiers); } diff --git a/crates/fuzz/tests/test.rs b/crates/fuzz/tests/test.rs index 41964a5d3..39875b144 100644 --- a/crates/fuzz/tests/test.rs +++ b/crates/fuzz/tests/test.rs @@ -31,6 +31,42 @@ fn test_empty() { test_multi_sites(5, vec![FuzzTarget::All], &mut []) } +#[test] +fn all_fuzz_lazy_richtext_append_style_anchor() { + test_multi_sites( + 5, + vec![FuzzTarget::All], + &mut [ + Handle { + site: 0, + target: 0, + container: 0, + action: Generic(GenericAction { + value: I32(0), + bool: false, + key: 0, + pos: 0, + length: 0, + prop: 2962851221704015872, + }), + }, + Handle { + site: 0, + target: 0, + container: 0, + action: Generic(GenericAction { + value: I32(0), + bool: false, + key: 59, + pos: 15950377895847788544, + length: 18386260117272886751, + prop: 4251980913, + }), + }, + ], + ) +} + #[test] fn all_fuzz_text_update_deleted_container() { test_multi_sites( diff --git a/crates/kv-store/src/block.rs b/crates/kv-store/src/block.rs index 082eb7108..beaf970f2 100644 --- a/crates/kv-store/src/block.rs +++ b/crates/kv-store/src/block.rs @@ -6,7 +6,7 @@ use std::{ }; use bytes::{Buf, Bytes}; -use loro_common::LoroResult; +use loro_common::{LoroError, LoroResult}; use once_cell::sync::OnceCell; use crate::{ @@ -17,6 +17,9 @@ use crate::{ use super::sstable::{SIZE_OF_U16, SIZE_OF_U8}; +const MAX_NORMAL_BLOCK_DATA_LEN: usize = u16::MAX as usize; +const MAX_NORMAL_BLOCK_ENTRIES: usize = u16::MAX as usize; + #[derive(Debug, Clone)] pub struct LargeValueBlock { // without checksum @@ -118,16 +121,40 @@ impl NormalBlock { first_key: Bytes, compression_type: CompressionType, ) -> LoroResult { + if raw_block_and_check.len() < SIZE_OF_U32 { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + let buf = raw_block_and_check.slice(..raw_block_and_check.len() - SIZE_OF_U32); let mut data = vec![]; decompress(&mut data, buf, compression_type)?; + if data.len() < SIZE_OF_U16 { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + let offsets_len = (&data[data.len() - SIZE_OF_U16..]).get_u16_le() as usize; - let data_end = data.len() - SIZE_OF_U16 * (offsets_len + 1); + if offsets_len == 0 { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let offsets_bytes_len = SIZE_OF_U16 + .checked_mul(offsets_len + 1) + .ok_or_else(|| LoroError::DecodeError("Invalid bytes".into()))?; + if data.len() < offsets_bytes_len { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let data_end = data.len() - offsets_bytes_len; + if data_end > u16::MAX as usize { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + let offsets = &data[data_end..data.len() - SIZE_OF_U16]; - let offsets = offsets + let offsets: Vec = offsets .chunks(SIZE_OF_U16) .map(|mut chunk| chunk.get_u16_le()) .collect(); + Self::validate_decoded_data(&data[..data_end], &offsets, &first_key)?; Ok(NormalBlock { data: Bytes::copy_from_slice(&data[..data_end]), encoded_data: OnceCell::with_value((raw_block_and_check, compression_type)), @@ -135,6 +162,67 @@ impl NormalBlock { first_key, }) } + + fn validate_decoded_data(data: &[u8], offsets: &[u16], first_key: &[u8]) -> LoroResult<()> { + if offsets.first().copied() != Some(0) { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let mut prev_key: Option> = None; + let mut prev_offset = 0usize; + for (idx, offset) in offsets.iter().map(|x| *x as usize).enumerate() { + let offset_end = offsets + .get(idx + 1) + .map_or(data.len(), |next| *next as usize); + if offset < prev_offset || offset > offset_end || offset_end > data.len() { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let key = if idx == 0 { + first_key.to_vec() + } else { + let header_end = offset + .checked_add(SIZE_OF_U8 + SIZE_OF_U16) + .ok_or_else(|| LoroError::DecodeError("Invalid bytes".into()))?; + if header_end > offset_end { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let common_prefix_len = data[offset] as usize; + if common_prefix_len > first_key.len() { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let key_suffix_len = + u16::from_le_bytes(data[offset + SIZE_OF_U8..header_end].try_into().unwrap()) + as usize; + let key_end = header_end + .checked_add(key_suffix_len) + .ok_or_else(|| LoroError::DecodeError("Invalid bytes".into()))?; + if key_end > offset_end { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + let mut key = Vec::with_capacity(common_prefix_len + key_suffix_len); + key.extend_from_slice(&first_key[..common_prefix_len]); + key.extend_from_slice(&data[header_end..key_end]); + key + }; + + if key.is_empty() + || prev_key + .as_ref() + .is_some_and(|prev_key| prev_key.as_slice() >= key.as_slice()) + { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + prev_offset = offset; + prev_key = Some(key); + } + + Ok(()) + } } #[derive(Debug, Clone)] @@ -189,20 +277,31 @@ impl Block { } } - pub fn decode( + pub(crate) fn try_decode( raw_block_and_check: Bytes, is_large: bool, key: Bytes, compression_type: CompressionType, - ) -> Self { - // The caller is responsible for validating SSTable integrity before lazy block reads. + ) -> LoroResult { + if key.is_empty() { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + if is_large { return LargeValueBlock::decode(raw_block_and_check, key, compression_type) - .map(Block::Large) - .expect("validated SSTable block should decode"); + .map(Block::Large); } - NormalBlock::decode(raw_block_and_check, key, compression_type) - .map(Block::Normal) + NormalBlock::decode(raw_block_and_check, key, compression_type).map(Block::Normal) + } + + pub fn decode( + raw_block_and_check: Bytes, + is_large: bool, + key: Bytes, + compression_type: CompressionType, + ) -> Self { + // The caller is responsible for validating SSTable integrity before lazy block reads. + Self::try_decode(raw_block_and_check, is_large, key, compression_type) .expect("validated SSTable block should decode") } @@ -273,9 +372,13 @@ impl BlockBuilder { /// └─────────────────────────────────────────────────────┘ /// pub fn add(&mut self, key: &[u8], value: &[u8]) -> bool { + if key.is_empty() { + return false; + } + debug_assert!(!key.is_empty(), "key cannot be empty"); if self.first_key.is_empty() { - if value.len() > self.block_size { + if value.len() > self.block_size || value.len() > MAX_NORMAL_BLOCK_DATA_LEN { self.data.extend_from_slice(value); self.is_large = true; self.first_key = Bytes::copy_from_slice(key); @@ -288,16 +391,39 @@ impl BlockBuilder { return true; } - // whether the block is full - if self.estimated_size() + key.len() + value.len() + SIZE_OF_U8 + SIZE_OF_U16 - > self.block_size - { + if self.offsets.len() >= MAX_NORMAL_BLOCK_ENTRIES { return false; } - self.offsets.push(self.data.len() as u16); let (common, suffix) = get_common_prefix_len_and_strip(key, &self.first_key); let key_len = suffix.len(); + let Some(next_data_len) = self + .data + .len() + .checked_add(SIZE_OF_U8 + SIZE_OF_U16) + .and_then(|len| len.checked_add(key_len)) + .and_then(|len| len.checked_add(value.len())) + else { + return false; + }; + if next_data_len > MAX_NORMAL_BLOCK_DATA_LEN { + return false; + } + + // whether the block is full + let Some(estimated_size) = self + .estimated_size() + .checked_add(key_len) + .and_then(|len| len.checked_add(value.len())) + .and_then(|len| len.checked_add(SIZE_OF_U8 + SIZE_OF_U16)) + else { + return false; + }; + if estimated_size > self.block_size { + return false; + } + + self.offsets.push(self.data.len() as u16); self.data.push(common); self.data.extend_from_slice(&(key_len as u16).to_le_bytes()); self.data.extend_from_slice(suffix); diff --git a/crates/kv-store/src/mem_store.rs b/crates/kv-store/src/mem_store.rs index 86bd15133..269b68fb2 100644 --- a/crates/kv-store/src/mem_store.rs +++ b/crates/kv-store/src/mem_store.rs @@ -99,10 +99,18 @@ impl MemKvStore { } pub fn set(&mut self, key: &[u8], value: Bytes) { + if key.is_empty() { + return; + } + self.mem_table.insert(Bytes::copy_from_slice(key), value); } pub fn compare_and_swap(&mut self, key: &[u8], old: Option, new: Bytes) -> bool { + if key.is_empty() { + return false; + } + match self.get(key) { Some(v) => { if old == Some(v) { @@ -131,6 +139,10 @@ impl MemKvStore { /// /// If the value is empty, it means the key is deleted pub fn contains_key(&self, key: &[u8]) -> bool { + if key.is_empty() { + return false; + } + if self.mem_table.contains_key(key) { return !self.mem_table.get(key).unwrap().is_empty(); } @@ -246,8 +258,6 @@ impl MemKvStore { return Ok(()); } - // Since all the export format right now has its own checksum on the header, - // it's safe for us to skip the checksum check internally here. let ss_table = SsTable::import_all(bytes, false).map_err(|e| e.to_string())?; self.ss_table.push(ss_table); Ok(()) @@ -593,6 +603,28 @@ mod tests { assert!(store.contains_key(key)); } + #[test] + fn empty_key_mutations_are_ignored() { + let mut store = new_store(); + let value = Bytes::from_static(b"value"); + + store.set(&[], value.clone()); + assert_eq!(store.get(&[]), None); + assert!(!store.contains_key(&[])); + assert_eq!(store.len(), 0); + assert!(store.is_empty()); + assert!(!store.compare_and_swap(&[], None, value.clone())); + store.remove(&[]); + + let bytes = store.export_all(); + assert!(bytes.is_empty()); + + let mut imported = new_store(); + imported.import_all(bytes).unwrap(); + assert_eq!(imported.get(&[]), None); + assert_eq!(imported.len(), 0); + } + #[test] fn same_key() { let mut store = new_store(); @@ -649,6 +681,41 @@ mod tests { assert_eq!(store.get(&e), Some(e.clone())); } + #[test] + fn large_config_splits_normal_blocks_before_u16_offsets_overflow() { + let mut store = MemKvStore::new( + MemKvConfig::default() + .block_size(usize::from(u16::MAX) * 2) + .should_encode_none(true), + ); + let a = Bytes::from(vec![1; 40_000]); + let b = Bytes::from(vec![2; 40_000]); + store.set(b"a", a.clone()); + store.set(b"b", b.clone()); + + let bytes = store.export_all(); + let mut imported = new_store(); + imported.import_all(bytes).unwrap(); + assert_eq!(imported.get(b"a"), Some(a)); + assert_eq!(imported.get(b"b"), Some(b)); + } + + #[test] + fn large_config_uses_large_block_for_values_bigger_than_normal_offsets() { + let mut store = MemKvStore::new( + MemKvConfig::default() + .block_size(usize::from(u16::MAX) * 2) + .should_encode_none(true), + ); + let value = Bytes::from(vec![7; usize::from(u16::MAX) + 1]); + store.set(b"large", value.clone()); + + let bytes = store.export_all(); + let mut imported = new_store(); + imported.import_all(bytes).unwrap(); + assert_eq!(imported.get(b"large"), Some(value)); + } + fn new_store() -> MemKvStore { MemKvStore::new(MemKvConfig::default().should_encode_none(true)) } diff --git a/crates/kv-store/src/sstable.rs b/crates/kv-store/src/sstable.rs index 52e47320d..4ebd18fc3 100644 --- a/crates/kv-store/src/sstable.rs +++ b/crates/kv-store/src/sstable.rs @@ -178,6 +178,10 @@ impl SsTableBuilder { } pub fn add(&mut self, key: Bytes, value: Bytes) { + if key.is_empty() { + return; + } + if !self.include_none && value.is_empty() { return; } @@ -344,7 +348,7 @@ impl SsTable { /// - [LoroError::DecodeError] /// - "Invalid magic number" /// - "Invalid schema version" - pub fn import_all(bytes: Bytes, check_checksum: bool) -> LoroResult { + pub fn import_all(bytes: Bytes, _check_checksum: bool) -> LoroResult { // magic number + schema version + meta offset if bytes.len() < SIZE_OF_U32 + SIZE_OF_U8 + SIZE_OF_U32 { return Err(LoroError::DecodeError("Invalid sstable bytes".into())); @@ -372,9 +376,9 @@ impl SsTable { } let raw_meta = &bytes[meta_offset..data_len - SIZE_OF_U32]; let meta = BlockMeta::decode_meta(raw_meta)?; - if check_checksum { - Self::check_block_checksum(&meta, &bytes, meta_offset)?; - } + Self::validate_block_ranges(&meta, meta_offset)?; + Self::validate_blocks(&meta, &bytes, meta_offset)?; + Self::check_block_checksum(&meta, &bytes, meta_offset)?; let first_key = meta .first() .map(|m| m.first_key.clone()) @@ -398,6 +402,61 @@ impl SsTable { Ok(ans) } + fn validate_block_ranges(meta: &[BlockMeta], meta_offset: usize) -> LoroResult<()> { + if meta.is_empty() { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + for i in 0..meta.len() { + let offset = meta[i].offset; + let offset_end = meta.get(i + 1).map_or(meta_offset, |m| m.offset); + if offset < SIZE_OF_U32 + SIZE_OF_U8 + || offset_end > meta_offset + || offset >= offset_end + || offset_end - offset < SIZE_OF_U32 + { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + } + + Ok(()) + } + + fn validate_blocks(meta: &[BlockMeta], bytes: &Bytes, meta_offset: usize) -> LoroResult<()> { + let mut last_key = None; + for i in 0..meta.len() { + let offset = meta[i].offset; + let offset_end = meta.get(i + 1).map_or(meta_offset, |m| m.offset); + let raw_block_and_check = bytes.slice(offset..offset_end); + let block = Block::try_decode( + raw_block_and_check, + meta[i].is_large, + meta[i].first_key.clone(), + meta[i].compression_type, + )?; + + let block_last_key = block.last_key(); + if meta[i].last_key.as_ref().unwrap_or(&meta[i].first_key) != &block_last_key { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + if meta[i].first_key > block_last_key { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + if last_key + .as_ref() + .is_some_and(|last_key| last_key >= &meta[i].first_key) + { + return Err(LoroError::DecodeError("Invalid bytes".into())); + } + + last_key = Some(block_last_key); + } + + Ok(()) + } + fn check_block_checksum( meta: &[BlockMeta], bytes: &Bytes, @@ -897,6 +956,45 @@ mod test { use super::*; use std::sync::Arc; + + fn malformed_sstable_bytes(block_bytes: &[u8], meta: &[BlockMeta]) -> Bytes { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&MAGIC_BYTES); + bytes.push(CURRENT_SCHEMA_VERSION); + bytes.extend_from_slice(block_bytes); + let meta_offset = bytes.len(); + BlockMeta::encode_meta(meta, &mut bytes); + bytes.extend_from_slice(&(meta_offset as u32).to_le_bytes()); + bytes.into() + } + + fn normal_block_bytes(key: &[u8], value: &[u8]) -> Vec { + let mut builder = BlockBuilder::new(4096); + builder.add(key, value); + let block = builder.build(); + let mut bytes = Vec::new(); + block.encode(&mut bytes, CompressionType::None); + bytes + } + + fn normal_block_bytes_from_pairs(pairs: &[(&[u8], &[u8])]) -> Vec { + let mut builder = BlockBuilder::new(4096); + for (key, value) in pairs { + assert!(builder.add(key, value)); + } + let block = builder.build(); + let mut bytes = Vec::new(); + block.encode(&mut bytes, CompressionType::None); + bytes + } + + fn large_block_bytes(value: &[u8]) -> Vec { + let mut bytes = value.to_vec(); + let checksum = xxhash_rust::xxh32::xxh32(value, XXH_SEED); + bytes.extend_from_slice(&checksum.to_le_bytes()); + bytes + } + #[test] fn block_double_end_iter() { let mut builder = BlockBuilder::new(4096); @@ -1137,4 +1235,140 @@ mod test { buffer[11] = 123; assert!(SsTable::import_all(buffer.into(), true).is_err()); } + + #[test] + fn sstable_import_rejects_empty_meta() { + assert!(SsTable::import_all(malformed_sstable_bytes(&[], &[]), false).is_err()); + } + + #[test] + fn sstable_import_rejects_invalid_block_ranges() { + let first_key = Bytes::from_static(b"key"); + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: false, + compression_type: CompressionType::None, + first_key: first_key.clone(), + last_key: Some(first_key.clone()), + }]; + assert!(SsTable::import_all(malformed_sstable_bytes(&[0, 1, 2], &meta), false).is_err()); + + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8 + 8, + is_large: false, + compression_type: CompressionType::None, + first_key: first_key.clone(), + last_key: Some(first_key), + }]; + assert!(SsTable::import_all(malformed_sstable_bytes(&[0, 1, 2, 3], &meta), false).is_err()); + } + + #[test] + fn sstable_import_rejects_undecodable_block_payload() { + let first_key = Bytes::from_static(b"key"); + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: false, + compression_type: CompressionType::None, + first_key: first_key.clone(), + last_key: Some(first_key), + }]; + let checksum_only = xxhash_rust::xxh32::xxh32(&[], XXH_SEED).to_le_bytes(); + assert!( + SsTable::import_all(malformed_sstable_bytes(&checksum_only, &meta), false).is_err() + ); + } + + #[test] + fn sstable_import_rejects_meta_last_key_that_mismatches_block() { + let first_key = Bytes::from_static(b"key"); + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: false, + compression_type: CompressionType::None, + first_key: first_key.clone(), + last_key: Some(Bytes::from_static(b"a")), + }]; + + assert!(SsTable::import_all( + malformed_sstable_bytes(&normal_block_bytes(b"key", b"value"), &meta), + false + ) + .is_err()); + } + + #[test] + fn sstable_import_rejects_unsorted_block_key_ranges() { + let first_block = normal_block_bytes(b"b", b"value"); + let second_block = normal_block_bytes(b"a", b"value"); + let second_offset = SIZE_OF_U32 + SIZE_OF_U8 + first_block.len(); + let meta = [ + BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: false, + compression_type: CompressionType::None, + first_key: Bytes::from_static(b"b"), + last_key: Some(Bytes::from_static(b"b")), + }, + BlockMeta { + offset: second_offset, + is_large: false, + compression_type: CompressionType::None, + first_key: Bytes::from_static(b"a"), + last_key: Some(Bytes::from_static(b"a")), + }, + ]; + let mut block_bytes = first_block; + block_bytes.extend_from_slice(&second_block); + + assert!(SsTable::import_all(malformed_sstable_bytes(&block_bytes, &meta), false).is_err()); + } + + #[test] + fn sstable_import_rejects_unsorted_keys_inside_block() { + let block_bytes = + normal_block_bytes_from_pairs(&[(b"a", b"1"), (b"c", b"2"), (b"b", b"3")]); + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: false, + compression_type: CompressionType::None, + first_key: Bytes::from_static(b"a"), + last_key: Some(Bytes::from_static(b"b")), + }]; + + assert!(SsTable::import_all(malformed_sstable_bytes(&block_bytes, &meta), false).is_err()); + } + + #[test] + fn sstable_import_rejects_empty_large_block_key() { + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: true, + compression_type: CompressionType::None, + first_key: Bytes::new(), + last_key: None, + }]; + + assert!(SsTable::import_all( + malformed_sstable_bytes(&large_block_bytes(b"value"), &meta), + false + ) + .is_err()); + } + + #[test] + fn sstable_import_rejects_block_checksum_mismatch_when_outer_checksum_is_skipped() { + let first_key = Bytes::from_static(b"key"); + let mut block_bytes = normal_block_bytes(b"key", b"value"); + *block_bytes.last_mut().unwrap() ^= 0xff; + let meta = [BlockMeta { + offset: SIZE_OF_U32 + SIZE_OF_U8, + is_large: false, + compression_type: CompressionType::None, + first_key: first_key.clone(), + last_key: Some(first_key), + }]; + + assert!(SsTable::import_all(malformed_sstable_bytes(&block_bytes, &meta), false).is_err()); + } } diff --git a/crates/loro-common/src/lib.rs b/crates/loro-common/src/lib.rs index a7b6ba4a3..81f10ed12 100644 --- a/crates/loro-common/src/lib.rs +++ b/crates/loro-common/src/lib.rs @@ -225,30 +225,63 @@ impl ContainerID { } pub fn from_bytes(bytes: &[u8]) -> Self { + Self::try_from_bytes(bytes).unwrap() + } + + pub fn try_from_bytes(bytes: &[u8]) -> LoroResult { + if bytes.is_empty() { + return Err(LoroError::DecodeError( + "Decode container id failed".to_string().into_boxed_str(), + )); + } + let first_byte = bytes[0]; - let container_type = ContainerType::try_from_u8(first_byte & 0b01111111).unwrap(); + let container_type = ContainerType::try_from_u8(first_byte & 0b01111111)?; let is_root = (first_byte & 0b10000000) != 0; let mut reader = &bytes[1..]; match is_root { true => { - let name_len = leb128::read::unsigned(&mut reader).unwrap(); - let name = InternalString::from( - std::str::from_utf8(&reader[..name_len as usize]).unwrap(), - ); - Self::Root { - name, - container_type, + let name_len = leb128::read::unsigned(&mut reader).map_err(|_| { + LoroError::DecodeError( + "Decode container id failed".to_string().into_boxed_str(), + ) + })?; + let name_len = usize::try_from(name_len).map_err(|_| { + LoroError::DecodeError( + "Decode container id failed".to_string().into_boxed_str(), + ) + })?; + if reader.len() != name_len { + return Err(LoroError::DecodeError( + "Decode container id failed".to_string().into_boxed_str(), + )); } + + let name = std::str::from_utf8(&reader[..name_len]).map_err(|_| { + LoroError::DecodeError( + "Decode container id failed".to_string().into_boxed_str(), + ) + })?; + Ok(Self::Root { + name: InternalString::from(name), + container_type, + }) } false => { + if reader.len() != 12 { + return Err(LoroError::DecodeError( + "Decode container id failed".to_string().into_boxed_str(), + )); + } + let peer = PeerID::from_le_bytes(reader[..8].try_into().unwrap()); let counter = i32::from_le_bytes(reader[8..12].try_into().unwrap()); - Self::Normal { + Ok(Self::Normal { peer, counter, container_type, - } + }) } } } @@ -794,4 +827,17 @@ mod test { let bytes = id.to_bytes(); assert_eq!(ContainerID::from_bytes(&bytes), id); } + + #[test] + fn container_id_try_from_bytes_rejects_trailing_bytes() { + let normal = ContainerID::new_normal(ID::new(1, 2), ContainerType::Map); + let mut normal_bytes = normal.to_bytes(); + normal_bytes.push(0); + assert!(ContainerID::try_from_bytes(&normal_bytes).is_err()); + + let root = ContainerID::new_root("root", ContainerType::List); + let mut root_bytes = root.to_bytes(); + root_bytes.push(0); + assert!(ContainerID::try_from_bytes(&root_bytes).is_err()); + } } diff --git a/crates/loro-common/src/value.rs b/crates/loro-common/src/value.rs index f623ff0d9..39c18c3a7 100644 --- a/crates/loro-common/src/value.rs +++ b/crates/loro-common/src/value.rs @@ -10,7 +10,7 @@ use crate::ContainerID; /// [LoroValue] is used to represents the state of CRDT at a given version. /// /// This struct is cheap to clone, the time complexity is O(1). -#[derive(Debug, PartialEq, Clone, EnumAsInner, Default)] +#[derive(Debug, Clone, EnumAsInner, Default)] pub enum LoroValue { #[default] Null, @@ -26,6 +26,23 @@ pub enum LoroValue { Container(ContainerID), } +impl PartialEq for LoroValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (LoroValue::Null, LoroValue::Null) => true, + (LoroValue::Bool(a), LoroValue::Bool(b)) => a == b, + (LoroValue::Double(a), LoroValue::Double(b)) => a == b || (a.is_nan() && b.is_nan()), + (LoroValue::I64(a), LoroValue::I64(b)) => a == b, + (LoroValue::Binary(a), LoroValue::Binary(b)) => a == b, + (LoroValue::String(a), LoroValue::String(b)) => a == b, + (LoroValue::List(a), LoroValue::List(b)) => a == b, + (LoroValue::Map(a), LoroValue::Map(b)) => a == b, + (LoroValue::Container(a), LoroValue::Container(b)) => a == b, + _ => false, + } + } +} + #[derive(Default, Debug, PartialEq, Clone, Arbitrary)] pub struct LoroBinaryValue(Arc>); #[derive(Default, Debug, PartialEq, Clone, Arbitrary)] @@ -231,9 +248,10 @@ impl LoroValue { match self { LoroValue::List(list) => { if index < 0 { - list.get(list.len() - (-index) as usize) + let offset = usize::try_from(index.checked_neg()?).ok()?; + list.get(list.len().checked_sub(offset)?) } else { - list.get(index as usize) + list.get(usize::try_from(index).ok()?) } } _ => None, @@ -359,7 +377,7 @@ impl TryFrom for i32 { fn try_from(value: LoroValue) -> Result { match value { - LoroValue::I64(v) => Ok(v as i32), + LoroValue::I64(v) => i32::try_from(v).map_err(|_| "i64 out of i32 range"), _ => Err("not a i32"), } } @@ -429,7 +447,14 @@ impl Hash for LoroValue { state.write_u8(*v as u8); } LoroValue::Double(v) => { - state.write_u64(v.to_bits()); + let normalized = if v.is_nan() { + f64::NAN + } else if *v == 0.0 { + 0.0 + } else { + *v + }; + state.write_u64(normalized.to_bits()); } LoroValue::I64(v) => { state.write_i64(*v); @@ -445,7 +470,9 @@ impl Hash for LoroValue { } LoroValue::Map(v) => { state.write_usize(v.len()); - for (k, v) in v.iter() { + let mut entries: Vec<_> = v.iter().collect(); + entries.sort_unstable_by_key(|(k, _)| *k); + for (k, v) in entries { k.hash(state); v.hash(state); } @@ -802,7 +829,9 @@ impl<'de> serde::de::Visitor<'de> for LoroValueVisitor { where E: serde::de::Error, { - Ok(LoroValue::I64(v as i64)) + Ok(i64::try_from(v) + .map(LoroValue::I64) + .unwrap_or_else(|_| LoroValue::Double(v as f64))) } fn visit_f64(self, v: f64) -> Result @@ -969,7 +998,7 @@ mod serde_json_impl { match value { LoroValue::Null => Value::Null, LoroValue::Bool(b) => Value::Bool(b), - LoroValue::Double(d) => Value::Number(Number::from_f64(d).unwrap()), + LoroValue::Double(d) => Number::from_f64(d).map_or(Value::Null, Value::Number), LoroValue::I64(i) => Value::Number(Number::from(i)), LoroValue::String(s) => Value::String(s.to_string()), LoroValue::List(l) => Value::Array(l.iter().cloned().map(Value::from).collect()), diff --git a/crates/loro-internal/src/awareness.rs b/crates/loro-internal/src/awareness.rs index 148932a32..581b2783b 100644 --- a/crates/loro-internal/src/awareness.rs +++ b/crates/loro-internal/src/awareness.rs @@ -108,10 +108,15 @@ impl Awareness { postcard::to_allocvec(&peers_info).unwrap() } - /// Returns (updated, added) - pub fn apply(&mut self, encoded_peers_info: &[u8]) -> (Vec, Vec) { - let peers_info: Vec = - postcard::from_bytes(encoded_peers_info).expect("Failed to decode awareness data"); + /// Try to apply encoded updates imported from another peer/process. + /// + /// Returns (updated, added). + pub fn try_apply( + &mut self, + encoded_peers_info: &[u8], + ) -> Result<(Vec, Vec), Box> { + let peers_info: Vec = postcard::from_bytes(encoded_peers_info) + .map_err(|err| format!("Failed to decode awareness data: {err}").into_boxed_str())?; let mut changed_peers = Vec::new(); let mut added_peers = Vec::new(); let now = get_sys_timestamp() as Timestamp; @@ -138,7 +143,13 @@ impl Awareness { } } - (changed_peers, added_peers) + Ok((changed_peers, added_peers)) + } + + /// Returns (updated, added) + pub fn apply(&mut self, encoded_peers_info: &[u8]) -> (Vec, Vec) { + self.try_apply(encoded_peers_info) + .expect("Failed to decode awareness data") } pub fn set_local_state(&mut self, value: impl Into) { diff --git a/crates/loro-internal/src/encoding/arena.rs b/crates/loro-internal/src/encoding/arena.rs index b47670d7c..5ede1b717 100644 --- a/crates/loro-internal/src/encoding/arena.rs +++ b/crates/loro-internal/src/encoding/arena.rs @@ -189,22 +189,43 @@ impl<'a> PositionArena<'a> { Self { positions: ans } } - pub fn parse_to_positions(self) -> Vec> { + pub fn try_parse_to_positions(self) -> LoroResult>> { let mut ans: Vec> = Vec::with_capacity(self.positions.len()); for PositionDelta { common_prefix_length, rest, } in self.positions { + if let Some(last_bytes) = ans.last() { + if common_prefix_length > last_bytes.len() { + return Err(LoroError::DecodeError( + "Decode position arena failed: common prefix out of range".into(), + )); + } + } else if common_prefix_length != 0 { + return Err(LoroError::DecodeError( + "Decode position arena failed: first common prefix is nonzero".into(), + )); + } + // +1 for Fractional Index - let mut p = Vec::with_capacity(rest.len() + common_prefix_length + 1); + let capacity = rest + .len() + .checked_add(common_prefix_length) + .and_then(|len| len.checked_add(1)) + .ok_or_else(|| { + LoroError::DecodeError( + "Decode position arena failed: position length overflow".into(), + ) + })?; + let mut p = Vec::with_capacity(capacity); if let Some(last_bytes) = ans.last() { p.extend_from_slice(&last_bytes[0..common_prefix_length]); } p.extend_from_slice(rest.as_ref()); ans.push(p); } - ans + Ok(ans) } pub fn encode(&self) -> Vec { @@ -235,3 +256,22 @@ impl<'a> PositionArena<'a> { fn longest_common_prefix_length(a: &[u8], b: &[u8]) -> usize { a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count() } + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use super::{PositionArena, PositionDelta}; + + #[test] + fn position_arena_rejects_invalid_common_prefix() { + let arena = PositionArena { + positions: vec![PositionDelta { + common_prefix_length: 1, + rest: Cow::Borrowed(&[]), + }], + }; + + assert!(arena.try_parse_to_positions().is_err()); + } +} diff --git a/crates/loro-internal/src/encoding/fast_snapshot.rs b/crates/loro-internal/src/encoding/fast_snapshot.rs index 1072ff2f8..2ebf60a71 100644 --- a/crates/loro-internal/src/encoding/fast_snapshot.rs +++ b/crates/loro-internal/src/encoding/fast_snapshot.rs @@ -79,6 +79,14 @@ pub(super) fn _decode_snapshot_bytes(bytes: Bytes) -> LoroResult { )); } let shallow_root_state_bytes = r.get_mut().copy_to_bytes(shallow_bytes_len); + if r.get_ref().has_remaining() { + return Err(LoroError::DecodeError( + "decode_snapshot: trailing bytes after snapshot" + .to_string() + .into_boxed_str(), + )); + } + Ok(Snapshot { oplog_bytes, state_bytes, @@ -100,6 +108,14 @@ pub(super) fn _decode_snapshot_meta_partial(bytes: &[u8]) -> LoroResult<(&[u8], } r = &r[state_bytes_len..]; let shallow_bytes_len = read_u32_le_slice(&mut r)? as usize; + if r.len() < shallow_bytes_len { + return Err(LoroError::DecodeDataCorruptionError); + } + r = &r[shallow_bytes_len..]; + if !r.is_empty() { + return Err(LoroError::DecodeDataCorruptionError); + } + Ok((oplog_bytes, shallow_bytes_len > 0)) } @@ -214,9 +230,6 @@ pub(crate) fn decode_snapshot_inner( } } - // FIXME: we may need to extract the unknown containers here? - // Or we should lazy load it when the time comes? - state.init_with_states_and_version(state_frontiers, &oplog, vec![], false, origin)?; Ok(()) })(); diff --git a/crates/loro-internal/src/encoding/json_schema.rs b/crates/loro-internal/src/encoding/json_schema.rs index cbb293d28..2f846c313 100644 --- a/crates/loro-internal/src/encoding/json_schema.rs +++ b/crates/loro-internal/src/encoding/json_schema.rs @@ -17,8 +17,8 @@ use either::Either; use itertools::Itertools; use json::{JsonChange, JsonOpContent, JsonSchema}; use loro_common::{ - ContainerID, ContainerType, HasCounterSpan, HasId, IdLp, IdSpan, LoroError, LoroResult, - LoroValue, PeerID, TreeID, ID, + ContainerID, ContainerType, Counter, HasCounterSpan, HasId, IdLp, IdSpan, LoroError, + LoroResult, LoroValue, PeerID, TreeID, ID, }; use rle::{HasLength, RleVec, Sliceable}; use std::sync::Arc; @@ -28,7 +28,7 @@ const SCHEMA_VERSION: u8 = 1; fn refine_vv(vv: &VersionVector, oplog: &OpLog) -> VersionVector { let mut refined = VersionVector::new(); for (&peer, &counter) in vv.iter() { - if counter == 0 { + if counter <= 0 { continue; } let end = oplog.vv().get(&peer).copied().unwrap_or(0); @@ -75,6 +75,12 @@ pub(crate) fn export_json<'a, 'c: 'a>( pub(crate) fn export_json_in_id_span(oplog: &OpLog, mut id_span: IdSpan) -> Vec { id_span.normalize_(); + if id_span.counter.end <= 0 { + return vec![]; + } + id_span.counter.start = id_span.counter.start.max(0); + id_span.counter.end = id_span.counter.end.max(0); + let end = oplog.vv().get(&id_span.peer).copied().unwrap_or(0); if id_span.counter.start >= end { return vec![]; @@ -82,7 +88,7 @@ pub(crate) fn export_json_in_id_span(oplog: &OpLog, mut id_span: IdSpan) -> Vec< id_span.counter.end = id_span.counter.end.min(end); let mut diff_changes: Vec> = Vec::new(); - while id_span.counter.end - id_span.counter.start > 0 { + while id_span.counter.start < id_span.counter.end { let Some(change) = oplog.get_change_at(id_span.id_start()) else { break; }; @@ -530,7 +536,22 @@ pub(crate) fn encode_change( } fn decode_changes(json: JsonSchema, arena: &SharedArena) -> LoroResult> { - let JsonSchema { peers, changes, .. } = json; + let JsonSchema { + schema_version, + start_version, + peers, + changes, + } = json; + if schema_version != SCHEMA_VERSION { + return Err(LoroError::DecodeError( + format!( + "unsupported json schema version: expected {SCHEMA_VERSION}, got {schema_version}" + ) + .into_boxed_str(), + )); + } + validate_json_frontiers(&start_version)?; + let mut ans = Vec::with_capacity(changes.len()); for json::JsonChange { id, @@ -542,19 +563,35 @@ fn decode_changes(json: JsonSchema, arena: &SharedArena) -> LoroResult = RleVec::new(); + let mut expected_counter = id.counter; + if json_ops.is_empty() { + return Err(LoroError::DecodeError( + "invalid json change: change must contain at least one op".into(), + )); + } + for op in json_ops { - ops.push(decode_op(op, arena, &peers)?); + let next_counter = validate_json_op_counter(expected_counter, &op)?; + let op = decode_op(op, arena, &peers)?; + validate_json_op_created_container_ids(id.peer, &op, arena)?; + ops.push(op); + expected_counter = next_counter; + } + + let deps = deps + .into_iter() + .map(|id| convert_id(&id, &peers)) + .collect::>>()?; + for dep in &deps { + validate_json_id_counter(*dep, "dependency id")?; } let change = Change { id, timestamp, - deps: Frontiers::from_iter( - deps.into_iter() - .map(|id| convert_id(&id, &peers)) - .collect::>>()?, - ), + deps: Frontiers::from_iter(deps), lamport, ops, commit_msg: msg.map(|x| x.into()), @@ -564,6 +601,187 @@ fn decode_changes(json: JsonSchema, arena: &SharedArena) -> LoroResult LoroResult<()> { + for id in frontiers.iter() { + validate_json_id_counter(id, "start version id")?; + } + + Ok(()) +} + +fn validate_json_id_counter(id: ID, name: &str) -> LoroResult<()> { + if id.counter < 0 { + return Err(LoroError::DecodeError( + format!("invalid json counter: {name} counter must be non-negative").into_boxed_str(), + )); + } + + Ok(()) +} + +fn validate_json_tree_id_counter(id: &TreeID, name: &str) -> LoroResult<()> { + if id.counter < 0 { + return Err(LoroError::DecodeError( + format!("invalid json counter: {name} counter must be non-negative").into_boxed_str(), + )); + } + + Ok(()) +} + +fn validate_json_container_id_counter(id: &ContainerID, name: &str) -> LoroResult<()> { + if let ContainerID::Normal { counter, .. } = id { + if *counter < 0 { + return Err(LoroError::DecodeError( + format!("invalid json counter: {name} counter must be non-negative") + .into_boxed_str(), + )); + } + } + + Ok(()) +} + +fn validate_json_op_counter(expected_counter: Counter, op: &json::JsonOp) -> LoroResult { + let op_len = op.content.op_len(); + if op_len == 0 { + return Err(LoroError::DecodeError( + "invalid json op counter: op length must be greater than zero".into(), + )); + } + + if expected_counter < 0 || op.counter < 0 { + return Err(LoroError::DecodeError( + "invalid json op counter: op counters must be non-negative".into(), + )); + } + + let op_len = Counter::try_from(op_len).map_err(|_| { + LoroError::DecodeError("invalid json op counter: op length is too large".into()) + })?; + if op.counter != expected_counter { + return Err(LoroError::DecodeError( + "invalid json op counter: op counters must be contiguous with the change id".into(), + )); + } + + expected_counter + .checked_add(op_len) + .ok_or_else(|| LoroError::DecodeError("invalid json op counter: counter overflow".into())) +} + +fn validate_json_op_created_container_ids( + peer: PeerID, + op: &Op, + arena: &SharedArena, +) -> LoroResult<()> { + match &op.content { + InnerContent::List(list) => match list { + InnerListOp::Insert { slice, .. } => { + for (offset, value) in arena.iter_value_slice(slice.to_range()).enumerate() { + let LoroValue::Container(id) = value else { + validate_json_value_has_no_container_refs(&value)?; + continue; + }; + let offset = Counter::try_from(offset).map_err(|_| { + LoroError::DecodeError( + "invalid json container id: list offset is too large".into(), + ) + })?; + let counter = op.counter.checked_add(offset).ok_or_else(|| { + LoroError::DecodeError("invalid json container id: counter overflow".into()) + })?; + validate_json_created_container_id(&id, ID::new(peer, counter))?; + } + } + InnerListOp::Set { value, .. } => { + if let LoroValue::Container(id) = value { + validate_json_created_container_id(id, ID::new(peer, op.counter))?; + } else { + validate_json_value_has_no_container_refs(value)?; + } + } + InnerListOp::Move { .. } + | InnerListOp::InsertText { .. } + | InnerListOp::Delete(_) + | InnerListOp::StyleEnd => {} + InnerListOp::StyleStart { value, .. } => { + validate_json_value_has_no_container_refs(value)?; + } + }, + InnerContent::Map(map) => { + if let Some(value) = &map.value { + if let LoroValue::Container(id) = value { + validate_json_created_container_id(id, ID::new(peer, op.counter))?; + } else { + validate_json_value_has_no_container_refs(value)?; + } + } + } + InnerContent::Tree(tree) => { + if let TreeOp::Create { target, .. } = tree.as_ref() { + validate_json_tree_create_target(target, ID::new(peer, op.counter))?; + } + } + InnerContent::Future(_) => {} + } + + Ok(()) +} + +fn validate_json_value_has_no_container_refs(value: &LoroValue) -> LoroResult<()> { + let mut stack = vec![value]; + while let Some(value) = stack.pop() { + match value { + LoroValue::Container(_) => { + return Err(LoroError::DecodeError( + "invalid json container id: container values must not be nested".into(), + )); + } + LoroValue::List(list) => { + stack.extend(list.iter()); + } + LoroValue::Map(map) => { + stack.extend(map.values()); + } + LoroValue::Null + | LoroValue::Bool(_) + | LoroValue::Double(_) + | LoroValue::I64(_) + | LoroValue::Binary(_) + | LoroValue::String(_) => {} + } + } + + Ok(()) +} + +fn validate_json_created_container_id(id: &ContainerID, expected_id: ID) -> LoroResult<()> { + let ContainerID::Normal { peer, counter, .. } = id else { + return Err(LoroError::DecodeError( + "invalid json container id: created child containers must be normal".into(), + )); + }; + + if *peer != expected_id.peer || *counter != expected_id.counter { + return Err(LoroError::DecodeError( + "invalid json container id: created child container id must match the op id".into(), + )); + } + + Ok(()) +} + +fn validate_json_tree_create_target(target: &TreeID, expected_id: ID) -> LoroResult<()> { + if target.peer != expected_id.peer || target.counter != expected_id.counter { + return Err(LoroError::DecodeError( + "invalid json tree target: tree create target must match the op id".into(), + )); + } + + Ok(()) +} + fn decode_op(op: json::JsonOp, arena: &SharedArena, peers: &Option>) -> LoroResult { let json::JsonOp { counter, @@ -571,6 +789,7 @@ fn decode_op(op: json::JsonOp, arena: &SharedArena, peers: &Option>) content, } = op; let container = convert_container_id(container, peers)?; + validate_json_container_id_counter(&container, "op container")?; let idx = arena.register_container(&container); let content = match container.container_type() { ContainerType::Text => match content { @@ -590,6 +809,7 @@ fn decode_op(op: json::JsonOp, arena: &SharedArena, peers: &Option>) start_id: id_start, } => { let id_start = convert_id(&id_start, peers)?; + validate_json_id_counter(id_start, "text delete start id")?; InnerContent::List(InnerListOp::Delete(DeleteSpanWithId { id_start, span: DeleteSpan { @@ -639,8 +859,10 @@ fn decode_op(op: json::JsonOp, arena: &SharedArena, peers: &Option>) }) } json::ListOp::Delete { pos, len, start_id } => { + let start_id = convert_id(&start_id, peers)?; + validate_json_id_counter(start_id, "list delete start id")?; InnerContent::List(InnerListOp::Delete(DeleteSpanWithId { - id_start: convert_id(&start_id, peers)?, + id_start: start_id, span: DeleteSpan { pos: pos as isize, signed_len: len as isize, @@ -674,8 +896,10 @@ fn decode_op(op: json::JsonOp, arena: &SharedArena, peers: &Option>) }) } json::MovableListOp::Delete { pos, len, start_id } => { + let start_id = convert_id(&start_id, peers)?; + validate_json_id_counter(start_id, "movable list delete start id")?; InnerContent::List(InnerListOp::Delete(DeleteSpanWithId { - id_start: convert_id(&start_id, peers)?, + id_start: start_id, span: DeleteSpan { pos: pos as isize, signed_len: len as isize, @@ -736,29 +960,49 @@ fn decode_op(op: json::JsonOp, arena: &SharedArena, peers: &Option>) target, parent, fractional_index, - } => InnerContent::Tree(Arc::new(TreeOp::Create { - target: convert_tree_id(&target, peers)?, - parent: match parent { - Some(p) => Some(convert_tree_id(&p, peers)?), + } => { + let target = convert_tree_id(&target, peers)?; + validate_json_tree_id_counter(&target, "tree create target")?; + let parent = match parent { + Some(p) => { + let parent = convert_tree_id(&p, peers)?; + validate_json_tree_id_counter(&parent, "tree create parent")?; + Some(parent) + } None => None, - }, - position: fractional_index, - })), + }; + InnerContent::Tree(Arc::new(TreeOp::Create { + target, + parent, + position: fractional_index, + })) + } json::TreeOp::Move { target, parent, fractional_index, - } => InnerContent::Tree(Arc::new(TreeOp::Move { - target: convert_tree_id(&target, peers)?, - parent: match parent { - Some(p) => Some(convert_tree_id(&p, peers)?), + } => { + let target = convert_tree_id(&target, peers)?; + validate_json_tree_id_counter(&target, "tree move target")?; + let parent = match parent { + Some(p) => { + let parent = convert_tree_id(&p, peers)?; + validate_json_tree_id_counter(&parent, "tree move parent")?; + Some(parent) + } None => None, - }, - position: fractional_index, - })), - json::TreeOp::Delete { target } => InnerContent::Tree(Arc::new(TreeOp::Delete { - target: convert_tree_id(&target, peers)?, - })), + }; + InnerContent::Tree(Arc::new(TreeOp::Move { + target, + parent, + position: fractional_index, + })) + } + json::TreeOp::Delete { target } => { + let target = convert_tree_id(&target, peers)?; + validate_json_tree_id_counter(&target, "tree delete target")?; + InnerContent::Tree(Arc::new(TreeOp::Delete { target })) + } }, _ => { return Err(LoroError::DecodeError( @@ -871,11 +1115,21 @@ pub mod json { } impl JsonChange { - pub fn op_len(&self) -> usize { + pub fn checked_op_len(&self) -> Option { let Some(last_op) = self.ops.last() else { - return 0; + return Some(0); }; - (last_op.counter - self.id.counter) as usize + last_op.content.op_len() + if last_op.counter < self.id.counter { + return None; + } + + let counter_delta = last_op.counter.checked_sub(self.id.counter)?; + let counter_delta = usize::try_from(counter_delta).ok()?; + counter_delta.checked_add(last_op.content.op_len()) + } + + pub fn op_len(&self) -> usize { + self.checked_op_len().unwrap_or(0) } } @@ -1057,7 +1311,7 @@ pub mod json { use loro_common::{ContainerID, ContainerType}; use serde::{ - de::{MapAccess, Visitor}, + de::{Error, IgnoredAny, MapAccess, Visitor}, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer, }; @@ -1095,84 +1349,45 @@ pub mod json { where A: MapAccess<'de>, { - let (_key, container) = map - .next_entry::()? - .ok_or_else(|| serde::de::Error::custom("missing container field"))?; - let is_unknown = container.ends_with(')'); - let container = ContainerID::try_from(container.as_str()) - .map_err(|_| serde::de::Error::custom("invalid container id"))?; - let op = if is_unknown { - let (_key, op) = map - .next_entry::()? - .ok_or_else(|| { - serde::de::Error::custom( - "missing op field for unknown container", - ) - })?; - super::JsonOpContent::Future(op) - } else { - match container.container_type() { - ContainerType::List => { - let (_key, op) = map - .next_entry::()? - .ok_or_else(|| { - serde::de::Error::custom("missing op field for list") - })?; - super::JsonOpContent::List(op) - } - ContainerType::MovableList => { - let (_key, op) = map - .next_entry::()? - .ok_or_else(|| { - serde::de::Error::custom( - "missing op field for movable list", - ) - })?; - super::JsonOpContent::MovableList(op) - } - ContainerType::Map => { - let (_key, op) = - map.next_entry::()?.ok_or_else( - || serde::de::Error::custom("missing op field for map"), - )?; - super::JsonOpContent::Map(op) + let mut container = None; + let mut content = None; + let mut counter = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "container" => { + if container.is_some() { + return Err(A::Error::duplicate_field("container")); + } + let value = map.next_value::()?; + container = + Some(ContainerID::try_from(value.as_str()).map_err( + |_| A::Error::custom("invalid container id"), + )?); } - ContainerType::Text => { - let (_key, op) = map - .next_entry::()? - .ok_or_else(|| { - serde::de::Error::custom("missing op field for text") - })?; - super::JsonOpContent::Text(op) + "content" => { + if content.is_some() { + return Err(A::Error::duplicate_field("content")); + } + content = Some(map.next_value::()?); } - ContainerType::Tree => { - let (_key, op) = map - .next_entry::()? - .ok_or_else(|| { - serde::de::Error::custom("missing op field for tree") - })?; - super::JsonOpContent::Tree(op) + "counter" => { + if counter.is_some() { + return Err(A::Error::duplicate_field("counter")); + } + counter = Some(map.next_value::()?); } - #[cfg(feature = "counter")] - ContainerType::Counter => { - let (_key, value) = map - .next_entry::()? - .ok_or_else(|| { - serde::de::Error::custom( - "missing value field for counter", - ) - })?; - super::JsonOpContent::Future(super::FutureOpWrapper { - prop: 0, - value: super::FutureOp::Counter(value), - }) + _ => { + let _ = map.next_value::()?; } - _ => unreachable!(), } - }; - let (_, counter) = map - .next_entry::()? - .ok_or_else(|| serde::de::Error::custom("missing counter field"))?; + } + + let container = + container.ok_or_else(|| A::Error::missing_field("container"))?; + let content = content.ok_or_else(|| A::Error::missing_field("content"))?; + let op = json_op_content_from_value(&container, content)?; + let counter = counter.ok_or_else(|| A::Error::missing_field("counter"))?; Ok(super::JsonOp { container, content: op, @@ -1185,6 +1400,46 @@ pub mod json { } } + fn json_op_content_from_value( + container: &ContainerID, + value: serde_json::Value, + ) -> Result { + match container.container_type() { + ContainerType::List => serde_json::from_value(value) + .map(super::JsonOpContent::List) + .map_err(E::custom), + ContainerType::MovableList => serde_json::from_value(value) + .map(super::JsonOpContent::MovableList) + .map_err(E::custom), + ContainerType::Map => serde_json::from_value(value) + .map(super::JsonOpContent::Map) + .map_err(E::custom), + ContainerType::Text => serde_json::from_value(value) + .map(super::JsonOpContent::Text) + .map_err(E::custom), + ContainerType::Tree => serde_json::from_value(value) + .map(super::JsonOpContent::Tree) + .map_err(E::custom), + ContainerType::Unknown(_) => serde_json::from_value(value) + .map(super::JsonOpContent::Future) + .map_err(E::custom), + #[cfg(feature = "counter")] + ContainerType::Counter => { + match serde_json::from_value::(value.clone()) { + Ok(op) => Ok(super::JsonOpContent::Future(op)), + Err(_) => { + let value = + serde_json::from_value::(value).map_err(E::custom)?; + Ok(super::JsonOpContent::Future(super::FutureOpWrapper { + prop: 0, + value: super::FutureOp::Counter(value), + })) + } + } + } + } + } + pub mod id { use loro_common::ID; use serde::{Deserialize, Deserializer, Serializer}; @@ -1277,13 +1532,12 @@ pub mod json { D: Deserializer<'de>, { let deps: Vec = Deserialize::deserialize(d)?; - Ok(deps - .into_iter() + deps.into_iter() .map(|x| { ID::try_from(x.as_str()) .map_err(|_| serde::de::Error::custom("invalid ID in deps")) }) - .collect::, _>>()?) + .collect::, _>>() } } @@ -1306,7 +1560,7 @@ pub mod json { D: Deserializer<'de>, { let peers: Option> = Deserialize::deserialize(d)?; - Ok(peers + peers .map(|x| { x.into_iter() .map(|x| { @@ -1315,7 +1569,7 @@ pub mod json { }) .collect::, _>>() }) - .transpose()?) + .transpose() } } @@ -1409,7 +1663,7 @@ pub mod json { D: Deserializer<'de>, { let str: String = Deserialize::deserialize(d)?; - if str.len() % 2 != 0 { + if !str.len().is_multiple_of(2) { return Err(serde::de::Error::custom( "invalid fractional index hex length", )); @@ -1436,6 +1690,47 @@ pub mod json { InvalidSchema(String), } + fn invalid_schema(message: impl Into) -> RedactError { + RedactError::InvalidSchema(message.into()) + } + + fn checked_redact_op_len(op: &JsonOp) -> Result { + let len = op.content.op_len(); + if len == 0 { + return Err(invalid_schema("op length must be greater than zero")); + } + + Counter::try_from(len).map_err(|_| invalid_schema("op length is too large")) + } + + fn validate_redactable_change(change: &JsonChange) -> Result { + if change.id.counter < 0 { + return Err(invalid_schema("change id counter must be non-negative")); + } + + let mut expected_counter = change.id.counter; + for op in &change.ops { + if op.counter < 0 { + return Err(invalid_schema("op counter must be non-negative")); + } + if op.counter != expected_counter { + return Err(invalid_schema( + "op counters must be contiguous with the change id", + )); + } + + let len = checked_redact_op_len(op)?; + expected_counter = expected_counter + .checked_add(len) + .ok_or_else(|| invalid_schema("op counter overflow"))?; + } + + let len = expected_counter + .checked_sub(change.id.counter) + .ok_or_else(|| invalid_schema("change counter overflow"))?; + usize::try_from(len).map_err(|_| invalid_schema("change length is too large")) + } + /// Redacts sensitive content within the specified range by replacing it with default values. /// /// This method applies the following redaction rules: @@ -1459,7 +1754,8 @@ pub mod json { let real_peer = get_peer_from_peers(&peers, change.id.peer) .map_err(|_| RedactError::InvalidSchema("peer index out of range".to_string()))?; let real_id = ID::new(real_peer, change.id.counter); - if !range.has_overlap_with(real_id.to_span(change.op_len())) { + let change_len = validate_redactable_change(change)?; + if !range.has_overlap_with(real_id.to_span(change_len)) { continue; } @@ -1469,15 +1765,19 @@ pub mod json { break; } - let len = op.content.op_len() as Counter; - if op.counter + len <= redact_range.0 { + let len = checked_redact_op_len(op)?; + let op_end = op + .counter + .checked_add(len) + .ok_or_else(|| invalid_schema("op counter overflow"))?; + if op_end <= redact_range.0 { continue; } let result = redact_op( &mut op.content, - (redact_range.0 - op.counter).max(0).min(len) - ..(redact_range.1 - op.counter).max(0).min(len), + redact_range.0.saturating_sub(op.counter).max(0).min(len) + ..redact_range.1.saturating_sub(op.counter).max(0).min(len), ); match result { Ok(()) => {} diff --git a/crates/loro-internal/src/encoding/shallow_snapshot.rs b/crates/loro-internal/src/encoding/shallow_snapshot.rs index c4991bcbd..ac2f4fee2 100644 --- a/crates/loro-internal/src/encoding/shallow_snapshot.rs +++ b/crates/loro-internal/src/encoding/shallow_snapshot.rs @@ -369,10 +369,8 @@ pub(crate) fn encode_snapshot_at( } doc.app_state().lock().take_events(); - let final_result = match result { + match result { Err(err) => Err(err), Ok(()) => restore_result, - }; - - final_result + } } diff --git a/crates/loro-internal/src/handler.rs b/crates/loro-internal/src/handler.rs index c30334df7..d48fa785a 100644 --- a/crates/loro-internal/src/handler.rs +++ b/crates/loro-internal/src/handler.rs @@ -33,11 +33,70 @@ pub use tree::TreeHandler; mod movable_list_apply_delta; mod tree; -const INSERT_CONTAINER_VALUE_ARG_ERROR: &str = - "Cannot insert a LoroValue::Container directly. To create child container, use insert_container"; +const REGULAR_CONTAINER_VALUE_ARG_ERROR: &str = + "Cannot use a LoroValue::Container as a regular value. To create child container, use insert_container or set_container"; mod text_update; +fn ensure_no_regular_container_value(value: &LoroValue) -> LoroResult<()> { + let mut stack = vec![value]; + while let Some(value) = stack.pop() { + match value { + LoroValue::Container(_) => { + return Err(LoroError::ArgErr( + REGULAR_CONTAINER_VALUE_ARG_ERROR + .to_string() + .into_boxed_str(), + )); + } + LoroValue::List(list) => { + stack.extend(list.iter()); + } + LoroValue::Map(map) => { + stack.extend(map.values()); + } + LoroValue::Null + | LoroValue::Bool(_) + | LoroValue::Double(_) + | LoroValue::I64(_) + | LoroValue::Binary(_) + | LoroValue::String(_) => {} + } + } + + Ok(()) +} + +fn checked_range_end( + pos: usize, + len: usize, + container_len: usize, + info: Box, +) -> LoroResult { + let end = pos.checked_add(len).ok_or_else(|| LoroError::OutOfBound { + pos: usize::MAX, + len: container_len, + info: info.clone(), + })?; + if end > container_len { + return Err(LoroError::OutOfBound { + pos: end, + len: container_len, + info, + }); + } + + Ok(end) +} + +fn checked_delta_index_end(pos: usize, len: usize, container_len: usize) -> LoroResult { + pos.checked_add(len).ok_or_else(|| LoroError::OutOfBound { + pos: usize::MAX, + len: container_len, + info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + }) +} + pub trait HandlerTrait: Clone + Sized { fn is_attached(&self) -> bool; fn attached_handler(&self) -> Option<&BasicHandler>; @@ -91,6 +150,15 @@ fn create_handler(inner: &BasicHandler, id: ContainerID) -> Handler { Handler::new_attached(id, inner.doc.clone()) } +fn value_to_value_or_handler(inner: &BasicHandler, value: LoroValue) -> ValueOrHandler { + match value { + LoroValue::Container(container_id) => { + ValueOrHandler::Handler(create_handler(inner, container_id)) + } + value => ValueOrHandler::Value(value), + } +} + /// Flatten attributes that allow overlap #[derive(Clone, Debug)] pub struct BasicHandler { @@ -241,6 +309,10 @@ impl BasicHandler { fn is_deleted(&self) -> bool { self.doc.state.lock().is_deleted(self.container_idx) } + + fn has_decoded_state(&self) -> bool { + self.with_doc_state(|state| state.has_decoded_container_state(self.container_idx)) + } } /// Flatten attributes that allow overlap @@ -1407,9 +1479,10 @@ impl TextHandler { pub fn is_empty(&self) -> bool { match &self.inner { MaybeDetached::Detached(t) => t.lock().value.is_empty(), - MaybeDetached::Attached(a) => { + MaybeDetached::Attached(a) if a.has_decoded_state() => { a.with_state(|state| state.as_richtext_state_mut().unwrap().is_empty()) } + MaybeDetached::Attached(a) => a.get_value().as_string().unwrap().is_empty(), } } @@ -1419,9 +1492,10 @@ impl TextHandler { let t = t.lock(); t.value.len_utf8() } - MaybeDetached::Attached(a) => { + MaybeDetached::Attached(a) if a.has_decoded_state() => { a.with_state(|state| state.as_richtext_state_mut().unwrap().len_utf8()) } + MaybeDetached::Attached(a) => a.get_value().as_string().unwrap().len(), } } @@ -1432,7 +1506,7 @@ impl TextHandler { t.value.len_utf16() } MaybeDetached::Attached(a) => { - a.with_state(|state| state.as_richtext_state_mut().unwrap().len_utf16()) + a.with_doc_state(|state| state.get_text_utf16_len(a.container_idx)) } } } @@ -1444,7 +1518,7 @@ impl TextHandler { t.value.len_unicode() } MaybeDetached::Attached(a) => { - a.with_state(|state| state.as_richtext_state_mut().unwrap().len_unicode()) + a.with_doc_state(|state| state.get_text_unicode_len(a.container_idx)) } } } @@ -1462,9 +1536,25 @@ impl TextHandler { fn len(&self, pos_type: PosType) -> usize { match &self.inner { MaybeDetached::Detached(t) => t.lock().value.len(pos_type), - MaybeDetached::Attached(a) => { + MaybeDetached::Attached(a) if a.has_decoded_state() || pos_type == PosType::Entity => { a.with_state(|state| state.as_richtext_state_mut().unwrap().len(pos_type)) } + MaybeDetached::Attached(a) => match pos_type { + PosType::Bytes => a.get_value().as_string().unwrap().len(), + PosType::Unicode => { + a.with_doc_state(|state| state.get_text_unicode_len(a.container_idx)) + } + PosType::Utf16 => { + a.with_doc_state(|state| state.get_text_utf16_len(a.container_idx)) + } + PosType::Event if cfg!(feature = "wasm") => { + a.with_doc_state(|state| state.get_text_utf16_len(a.container_idx)) + } + PosType::Event => { + a.with_doc_state(|state| state.get_text_unicode_len(a.container_idx)) + } + PosType::Entity => unreachable!("entity length is handled by the state path"), + }, } } @@ -1472,14 +1562,22 @@ impl TextHandler { let err = match pos_type { PosType::Bytes => Some(LoroError::UTF8InUnicodeCodePoint { pos }), PosType::Utf16 => Some(LoroError::UTF16InUnicodeCodePoint { pos }), + PosType::Event if cfg!(feature = "wasm") => { + Some(LoroError::UTF16InUnicodeCodePoint { pos }) + } _ => None, }; let Some(err) = err else { return Ok(()); }; - let Some(unicode_pos) = self.convert_pos(pos, pos_type, PosType::Unicode) else { + + if pos > self.len(pos_type) { return Ok(()); + } + + let Some(unicode_pos) = self.convert_pos(pos, pos_type, PosType::Unicode) else { + return Err(err); }; if self.convert_pos(unicode_pos, PosType::Unicode, pos_type) != Some(pos) { return Err(err); @@ -1546,14 +1644,18 @@ impl TextHandler { }; t.value.get_char_by_event_index(event_pos) } - MaybeDetached::Attached(a) => a.with_state(|state| { - let state = state.as_richtext_state_mut().unwrap(); - let event_pos = match pos_type { - PosType::Event => pos, - _ => state.index_to_event_index(pos, pos_type), - }; - state.get_char_by_event_index(event_pos) - }), + MaybeDetached::Attached(a) if a.has_decoded_state() || pos_type == PosType::Entity => a + .with_state(|state| { + let state = state.as_richtext_state_mut().unwrap(); + let event_pos = match pos_type { + PosType::Event => pos, + _ => state.index_to_event_index(pos, pos_type), + }; + state.get_char_by_event_index(event_pos) + }), + MaybeDetached::Attached(a) => { + return text_char_at(a.get_value().as_string().unwrap(), pos, pos_type); + } } { Ok(c) } else { @@ -1620,24 +1722,39 @@ impl TextHandler { }; t.value.get_text_slice_by_event_index(start, end - start) } - MaybeDetached::Attached(a) => a.with_state(|state| { - let state = state.as_richtext_state_mut().unwrap(); - let len = state.len(pos_type); - if end_index > len { - return Err(LoroError::OutOfBound { - pos: end_index, - len, - info: info(), - }); - } - let (start, end) = match pos_type { - PosType::Event => (start_index, end_index), - _ => ( - state.index_to_event_index(start_index, pos_type), - state.index_to_event_index(end_index, pos_type), - ), - }; - state.get_text_slice_by_event_index(start, end - start) + MaybeDetached::Attached(a) if a.has_decoded_state() || pos_type == PosType::Entity => a + .with_state(|state| { + let state = state.as_richtext_state_mut().unwrap(); + let len = state.len(pos_type); + if end_index > len { + return Err(LoroError::OutOfBound { + pos: end_index, + len, + info: info(), + }); + } + let (start, end) = match pos_type { + PosType::Event => (start_index, end_index), + _ => ( + state.index_to_event_index(start_index, pos_type), + state.index_to_event_index(end_index, pos_type), + ), + }; + state.get_text_slice_by_event_index(start, end - start) + }), + MaybeDetached::Attached(a) => text_slice( + a.get_value().as_string().unwrap(), + start_index, + end_index, + pos_type, + ) + .map_err(|err| match err { + LoroError::OutOfBound { pos, len, .. } => LoroError::OutOfBound { + pos, + len, + info: info(), + }, + err => err, }), } } @@ -1705,7 +1822,13 @@ impl TextHandler { /// /// This method requires auto_commit to be enabled. pub fn splice(&self, pos: usize, len: usize, s: &str, pos_type: PosType) -> LoroResult { - let x = self.slice(pos, pos + len, pos_type)?; + let end = checked_range_end( + pos, + len, + self.len(pos_type), + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + )?; + let x = self.slice(pos, end, pos_type)?; self.delete(pos, len, pos_type)?; self.insert(pos, s, pos_type)?; Ok(x) @@ -1800,15 +1923,14 @@ impl TextHandler { } let text_len = self.len(pos_type); - if pos + len > text_len { - return Err(LoroError::OutOfBound { - pos: pos + len, - len: text_len, - info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - }); - } + let end = checked_range_end( + pos, + len, + text_len, + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + )?; self.validate_text_boundary(pos, pos_type)?; - self.validate_text_boundary(pos + len, pos_type)?; + self.validate_text_boundary(end, pos_type)?; match &self.inner { MaybeDetached::Detached(t) => { @@ -2017,16 +2139,16 @@ impl TextHandler { return Ok(()); } - if pos + len > self.len(pos_type) { - error!("pos={} len={} len_event={}", pos, len, self.len_event()); - return Err(LoroError::OutOfBound { - pos: pos + len, - len: self.len_event(), - info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - }); - } + let text_len = self.len(pos_type); + let end = checked_range_end( + pos, + len, + text_len, + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + ) + .inspect_err(|_| error!("pos={} len={} len={}", pos, len, text_len))?; self.validate_text_boundary(pos, pos_type)?; - self.validate_text_boundary(pos + len, pos_type)?; + self.validate_text_boundary(end, pos_type)?; let inner = self.inner.try_attached_state()?; let s = tracing::span!(tracing::Level::INFO, "delete", "pos={} len={}", pos, len); @@ -2036,7 +2158,7 @@ impl TextHandler { let ranges = inner.with_state(|state| { let richtext_state = state.as_richtext_state_mut().unwrap(); event_pos = richtext_state.index_to_event_index(pos, pos_type); - let event_end = richtext_state.index_to_event_index(pos + len, pos_type); + let event_end = richtext_state.index_to_event_index(end, pos_type); event_len = event_end - event_pos; richtext_state.get_text_entity_ranges_in_event_index_range(event_pos, event_len) @@ -2108,6 +2230,7 @@ impl TextHandler { "Start must be less than end".to_string().into_boxed_str(), )); } + ensure_no_regular_container_value(value)?; let len = state.len(pos_type); if end > len { @@ -2188,6 +2311,7 @@ impl TextHandler { "Start must be less than end".to_string().into_boxed_str(), )); } + ensure_no_regular_container_value(&value)?; let inner = self.inner.try_attached_state()?; let key: InternalString = key.into(); @@ -2333,7 +2457,7 @@ impl TextHandler { empty_attr.as_ref().unwrap() }); - let end = index + insert_len; + let end = checked_delta_index_end(index, insert_len, self.len_event())?; let override_styles = self.insert_with_txn_and_attr( txn, index, @@ -2357,7 +2481,7 @@ impl TextHandler { self.delete_with_txn(txn, index, *delete, PosType::Event)?; } TextDelta::Retain { attributes, retain } => { - let end = index + *retain; + let end = checked_delta_index_end(index, *retain, self.len_event())?; match attributes { Some(attr) if !attr.is_empty() => { let mut pending_mark = PendingMark { @@ -2379,7 +2503,12 @@ impl TextHandler { } } - let mut len = self.len_event(); + let mut len = match &self.inner { + MaybeDetached::Detached(_) => self.len_event(), + MaybeDetached::Attached(a) => { + a.with_state(|state| state.as_richtext_state_mut().unwrap().len(PosType::Event)) + } + }; for pending_mark in marks { if pending_mark.start >= len { self.insert_with_txn( @@ -2619,7 +2748,7 @@ impl TextHandler { }; (event_index, unicode_index) } - MaybeDetached::Attached(a) => { + MaybeDetached::Attached(a) if a.has_decoded_state() => { let res: Option<(usize, usize)> = a.with_state(|state| { let state = state.as_richtext_state_mut().unwrap(); if index > state.len(from) { @@ -2639,10 +2768,14 @@ impl TextHandler { Some((event_index, unicode_index)) }); - match res { - Some(v) => v, - None => return None, - } + res? + } + MaybeDetached::Attached(a) => { + let value = a.get_value(); + let s = value.as_string().unwrap(); + let unicode_index = text_pos_to_unicode(s, index, from)?; + let event_index = unicode_to_text_pos(s, unicode_index, PosType::Event)?; + (event_index, unicode_index) } }; @@ -2659,7 +2792,7 @@ impl TextHandler { } t.value.get_text_slice_by_event_index(0, event_index).ok()? } - MaybeDetached::Attached(a) => { + MaybeDetached::Attached(a) if a.has_decoded_state() => { let res: Result = a.with_state(|state| { let state = state.as_richtext_state_mut().unwrap(); if event_index > state.len_event() { @@ -2675,6 +2808,11 @@ impl TextHandler { Err(_) => return None, } } + MaybeDetached::Attached(a) => { + let value = a.get_value(); + let s = value.as_string().unwrap(); + return unicode_to_text_pos(s, unicode_index, to); + } }; Some(match to { @@ -2697,6 +2835,162 @@ fn event_len(s: &str) -> usize { } } +fn text_len(s: &str, pos_type: PosType) -> Option { + Some(match pos_type { + PosType::Bytes => s.len(), + PosType::Unicode => s.chars().count(), + PosType::Utf16 => count_utf16_len(s.as_bytes()), + PosType::Event => event_len(s), + PosType::Entity => return None, + }) +} + +fn text_pos_to_unicode(s: &str, index: usize, pos_type: PosType) -> Option { + match pos_type { + PosType::Unicode => (index <= s.chars().count()).then_some(index), + PosType::Bytes => { + if index > s.len() { + None + } else { + Some( + s.char_indices() + .take_while(|(pos, c)| *pos + c.len_utf8() <= index) + .count(), + ) + } + } + PosType::Utf16 => utf16_to_unicode_pos(s, index), + PosType::Event if cfg!(feature = "wasm") => utf16_to_unicode_pos(s, index), + PosType::Event => (index <= s.chars().count()).then_some(index), + PosType::Entity => None, + } +} + +fn unicode_to_text_pos(s: &str, index: usize, pos_type: PosType) -> Option { + match pos_type { + PosType::Unicode => (index <= s.chars().count()).then_some(index), + PosType::Bytes => unicode_to_byte_pos(s, index), + PosType::Utf16 => unicode_to_utf16_pos(s, index), + PosType::Event if cfg!(feature = "wasm") => unicode_to_utf16_pos(s, index), + PosType::Event => (index <= s.chars().count()).then_some(index), + PosType::Entity => None, + } +} + +fn unicode_to_byte_pos(s: &str, index: usize) -> Option { + if index == 0 { + return Some(0); + } + + let mut unicode_pos = 0; + for (byte_pos, _) in s.char_indices() { + if unicode_pos == index { + return Some(byte_pos); + } + unicode_pos += 1; + } + + (unicode_pos == index).then_some(s.len()) +} + +fn unicode_to_utf16_pos(s: &str, index: usize) -> Option { + let mut unicode_pos = 0; + let mut utf16_pos = 0; + if index == 0 { + return Some(0); + } + + for c in s.chars() { + unicode_pos += 1; + utf16_pos += c.len_utf16(); + if unicode_pos == index { + return Some(utf16_pos); + } + } + + (unicode_pos == index).then_some(utf16_pos) +} + +fn utf16_to_unicode_pos(s: &str, index: usize) -> Option { + let mut unicode_pos = 0; + let mut utf16_pos = 0; + if index == 0 { + return Some(0); + } + + for c in s.chars() { + let next_utf16_pos = utf16_pos + c.len_utf16(); + if index < next_utf16_pos { + return Some(unicode_pos); + } + if index == next_utf16_pos { + return Some(unicode_pos + 1); + } + utf16_pos = next_utf16_pos; + unicode_pos += 1; + } + + (index == utf16_pos).then_some(unicode_pos) +} + +fn text_boundary_error(pos: usize, pos_type: PosType) -> LoroError { + match pos_type { + PosType::Bytes => LoroError::UTF8InUnicodeCodePoint { pos }, + PosType::Utf16 => LoroError::UTF16InUnicodeCodePoint { pos }, + PosType::Event if cfg!(feature = "wasm") => LoroError::UTF16InUnicodeCodePoint { pos }, + _ => LoroError::OutOfBound { + pos, + len: 0, + info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + }, + } +} + +fn text_char_at(s: &str, pos: usize, pos_type: PosType) -> LoroResult { + let len = text_len(s, pos_type).unwrap_or(0); + if pos >= len { + return Err(LoroError::OutOfBound { + pos, + len, + info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + }); + } + + let unicode_pos = + text_pos_to_unicode(s, pos, pos_type).ok_or_else(|| text_boundary_error(pos, pos_type))?; + s.chars().nth(unicode_pos).ok_or(LoroError::OutOfBound { + pos, + len, + info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + }) +} + +fn text_slice(s: &str, start: usize, end: usize, pos_type: PosType) -> LoroResult { + if end < start { + return Err(LoroError::EndIndexLessThanStartIndex { start, end }); + } + if start == end { + return Ok(String::new()); + } + + let len = text_len(s, pos_type).unwrap_or(0); + if end > len { + return Err(LoroError::OutOfBound { + pos: end, + len, + info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + }); + } + + let start = text_pos_to_unicode(s, start, pos_type) + .ok_or_else(|| text_boundary_error(start, pos_type))?; + let end = + text_pos_to_unicode(s, end, pos_type).ok_or_else(|| text_boundary_error(end, pos_type))?; + let start = unicode_to_byte_pos(s, start).expect("unicode index must map to a byte boundary"); + let end = unicode_to_byte_pos(s, end).expect("unicode index must map to a byte boundary"); + Ok(s[start..end].to_string()) +} + impl ListHandler { /// Create a new container that is detached from the document. /// The edits on a detached container will not be persisted. @@ -2719,7 +3013,9 @@ impl ListHandler { len, }); } - list.value.insert(pos, ValueOrHandler::Value(v.into())); + let value = v.into(); + ensure_no_regular_container_value(&value)?; + list.value.insert(pos, ValueOrHandler::Value(value)); Ok(()) } MaybeDetached::Attached(a) => { @@ -2743,13 +3039,7 @@ impl ListHandler { } let inner = self.inner.try_attached_state()?; - if let Some(_container) = v.as_container() { - return Err(LoroError::ArgErr( - INSERT_CONTAINER_VALUE_ARG_ERROR - .to_string() - .into_boxed_str(), - )); - } + ensure_no_regular_container_value(&v)?; txn.apply_local_op( inner.container_idx, @@ -2766,7 +3056,9 @@ impl ListHandler { match &self.inner { MaybeDetached::Detached(l) => { let mut list = l.lock(); - list.value.push(ValueOrHandler::Value(v.into())); + let value = v.into(); + ensure_no_regular_container_value(&value)?; + list.value.push(ValueOrHandler::Value(value)); Ok(()) } MaybeDetached::Attached(a) => a.with_txn(|txn| self.push_with_txn(txn, v.into())), @@ -2859,14 +3151,13 @@ impl ListHandler { match &self.inner { MaybeDetached::Detached(l) => { let mut list = l.lock(); - if pos + len > list.value.len() { - return Err(LoroError::OutOfBound { - pos: pos + len, - info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - len: list.value.len(), - }); - } - list.value.drain(pos..pos + len); + let end = checked_range_end( + pos, + len, + list.value.len(), + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + )?; + list.value.drain(pos..end); Ok(()) } MaybeDetached::Attached(a) => a.with_txn(|txn| self.delete_with_txn(txn, pos, len)), @@ -2878,20 +3169,18 @@ impl ListHandler { return Ok(()); } - if pos + len > self.len() { - return Err(LoroError::OutOfBound { - pos: pos + len, - info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - len: self.len(), - }); - } + let list_len = self.len(); + let end = checked_range_end( + pos, + len, + list_len, + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + )?; let inner = self.inner.try_attached_state()?; let ids: Vec<_> = inner.with_state(|state| { let list = state.as_list_state().unwrap(); - (pos..pos + len) - .map(|i| list.get_id_at(i).unwrap()) - .collect() + (pos..end).map(|i| list.get_id_at(i).unwrap()).collect() }); for id in ids.into_iter() { @@ -2930,19 +3219,17 @@ impl ListHandler { )), } } - MaybeDetached::Attached(a) => { - let Some(value) = a.with_state(|state| { - state.as_list_state().as_ref().unwrap().get(index).cloned() - }) else { + MaybeDetached::Attached(_) => { + let Some(value) = self.get_(index) else { return Err(LoroError::OutOfBound { pos: index, info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - len: a.with_state(|state| state.as_list_state().unwrap().len()), + len: self.len(), }); }; match value { - LoroValue::Container(id) => Ok(create_handler(a, id)), - _ => Err(LoroError::ArgErr( + ValueOrHandler::Handler(handler) => Ok(handler), + ValueOrHandler::Value(value) => Err(LoroError::ArgErr( format!( "Expected container at index {}, but found {:?}", index, value @@ -2958,7 +3245,7 @@ impl ListHandler { match &self.inner { MaybeDetached::Detached(l) => l.lock().value.len(), MaybeDetached::Attached(a) => { - a.with_state(|state| state.as_list_state().unwrap().len()) + a.with_doc_state(|state| state.get_list_len(a.container_idx)) } } } @@ -2977,10 +3264,9 @@ impl ListHandler { pub fn get(&self, index: usize) -> Option { match &self.inner { MaybeDetached::Detached(l) => l.lock().value.get(index).map(|x| x.to_value()), - MaybeDetached::Attached(a) => a.with_state(|state| { - let a = state.as_list_state().unwrap(); - a.get(index).cloned() - }), + MaybeDetached::Attached(a) => { + a.with_doc_state(|state| state.get_list_value_at(a.container_idx, index)) + } } } @@ -2992,15 +3278,9 @@ impl ListHandler { l.value.get(index).cloned() } MaybeDetached::Attached(inner) => { - let value = - inner.with_state(|state| state.as_list_state().unwrap().get(index).cloned()); - match value { - Some(LoroValue::Container(container_id)) => Some(ValueOrHandler::Handler( - create_handler(inner, container_id.clone()), - )), - Some(value) => Some(ValueOrHandler::Value(value.clone())), - None => None, - } + let value = inner + .with_doc_state(|state| state.get_list_value_at(inner.container_idx, index)); + value.map(|value| value_to_value_or_handler(inner, value)) } } } @@ -3017,22 +3297,12 @@ impl ListHandler { } } MaybeDetached::Attached(inner) => { - let mut temp = vec![]; - inner.with_state(|state| { - let a = state.as_list_state().unwrap(); - for v in a.iter() { - match v { - LoroValue::Container(c) => { - temp.push(ValueOrHandler::Handler(create_handler( - inner, - c.clone(), - ))); - } - value => { - temp.push(ValueOrHandler::Value(value.clone())); - } - } - } + let temp = inner.with_doc_state(|state| { + state + .get_list_values(inner.container_idx) + .into_iter() + .map(|value| value_to_value_or_handler(inner, value)) + .collect::>() }); for v in temp.into_iter() { f(v); @@ -3188,7 +3458,9 @@ impl MovableListHandler { len: d.value.len(), }); } - d.value.insert(pos, ValueOrHandler::Value(v.into())); + let value = v.into(); + ensure_no_regular_container_value(&value)?; + d.value.insert(pos, ValueOrHandler::Value(value)); Ok(()) } MaybeDetached::Attached(a) => { @@ -3212,13 +3484,7 @@ impl MovableListHandler { }); } - if v.is_container() { - return Err(LoroError::ArgErr( - INSERT_CONTAINER_VALUE_ARG_ERROR - .to_string() - .into_boxed_str(), - )); - } + ensure_no_regular_container_value(&v)?; let op_index = self.with_state(|state| { let list = state.as_movable_list_state().unwrap(); @@ -3344,7 +3610,7 @@ impl MovableListHandler { Ok(d.value.pop()) } MaybeDetached::Attached(a) => { - if self.len() == 0 { + if self.is_empty() { return Ok(None); } let last = self.len() - 1; @@ -3449,7 +3715,9 @@ impl MovableListHandler { len: d.value.len(), }); } - d.value[index] = ValueOrHandler::Value(value.into()); + let value = value.into(); + ensure_no_regular_container_value(&value)?; + d.value[index] = ValueOrHandler::Value(value); Ok(()) } MaybeDetached::Attached(a) => { @@ -3480,6 +3748,7 @@ impl MovableListHandler { else { unreachable!() }; + ensure_no_regular_container_value(&value)?; let op = crate::op::RawOpContent::List(crate::container::list::list_op::ListOp::Set { elem_id: elem_id.to_id(), @@ -3556,14 +3825,13 @@ impl MovableListHandler { match &self.inner { MaybeDetached::Detached(d) => { let mut d = d.lock(); - if pos + len > d.value.len() { - return Err(LoroError::OutOfBound { - pos: pos + len, - info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - len: d.value.len(), - }); - } - d.value.drain(pos..pos + len); + let end = checked_range_end( + pos, + len, + d.value.len(), + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + )?; + d.value.drain(pos..end); Ok(()) } MaybeDetached::Attached(a) => a.with_txn(|txn| self.delete_with_txn(txn, pos, len)), @@ -3576,20 +3844,20 @@ impl MovableListHandler { return Ok(()); } - if pos + len > self.len() { - return Err(LoroError::OutOfBound { - pos: pos + len, - info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - len: self.len(), - }); - } + let list_len = self.len(); + let end = checked_range_end( + pos, + len, + list_len, + format!("Position: {}:{}", file!(), line!()).into_boxed_str(), + )?; let (ids, new_poses) = self.with_state(|state| { let list = state.as_movable_list_state().unwrap(); - let ids: Vec<_> = (pos..pos + len) + let ids: Vec<_> = (pos..end) .map(|i| list.get_list_id_at(i, IndexType::ForUser).unwrap()) .collect(); - let poses: Vec<_> = (pos..pos + len) + let poses: Vec<_> = (pos..end) // need to -i because we delete the previous ones .map(|user_index| { let op_index = list @@ -3641,24 +3909,17 @@ impl MovableListHandler { )), } } - MaybeDetached::Attached(a) => { - let Some(value) = a.with_state(|state| { - state - .as_movable_list_state() - .as_ref() - .unwrap() - .get(index, IndexType::ForUser) - .cloned() - }) else { + MaybeDetached::Attached(_) => { + let Some(value) = self.get_(index) else { return Err(LoroError::OutOfBound { pos: index, info: format!("Position: {}:{}", file!(), line!()).into_boxed_str(), - len: a.with_state(|state| state.as_list_state().unwrap().len()), + len: self.len(), }); }; match value { - LoroValue::Container(id) => Ok(create_handler(a, id)), - _ => Err(LoroError::ArgErr( + ValueOrHandler::Handler(handler) => Ok(handler), + ValueOrHandler::Value(value) => Err(LoroError::ArgErr( format!( "Expected container at index {}, but found {:?}", index, value @@ -3677,7 +3938,7 @@ impl MovableListHandler { d.value.len() } MaybeDetached::Attached(a) => { - a.with_state(|state| state.as_movable_list_state().unwrap().len()) + a.with_doc_state(|state| state.get_list_len(a.container_idx)) } } } @@ -3701,10 +3962,9 @@ impl MovableListHandler { let d = d.lock(); d.value.get(index).map(|v| v.to_value()) } - MaybeDetached::Attached(a) => a.with_state(|state| { - let a = state.as_movable_list_state().unwrap(); - a.get(index, IndexType::ForUser).cloned() - }), + MaybeDetached::Attached(a) => { + a.with_doc_state(|state| state.get_list_value_at(a.container_idx, index)) + } } } @@ -3715,19 +3975,11 @@ impl MovableListHandler { let d = d.lock(); d.value.get(index).cloned() } - MaybeDetached::Attached(m) => m.with_state(|state| { - let a = state.as_movable_list_state().unwrap(); - match a.get(index, IndexType::ForUser) { - Some(v) => { - if let LoroValue::Container(c) = v { - Some(ValueOrHandler::Handler(create_handler(m, c.clone()))) - } else { - Some(ValueOrHandler::Value(v.clone())) - } - } - None => None, - } - }), + MaybeDetached::Attached(m) => { + let value = + m.with_doc_state(|state| state.get_list_value_at(m.container_idx, index)); + value.map(|value| value_to_value_or_handler(m, value)) + } } } @@ -3743,19 +3995,12 @@ impl MovableListHandler { } } MaybeDetached::Attached(m) => { - let mut temp = vec![]; - m.with_state(|state| { - let a = state.as_movable_list_state().unwrap(); - for v in a.iter() { - match v { - LoroValue::Container(c) => { - temp.push(ValueOrHandler::Handler(create_handler(m, c.clone()))); - } - value => { - temp.push(ValueOrHandler::Value(value.clone())); - } - } - } + let temp = m.with_doc_state(|state| { + state + .get_list_values(m.container_idx) + .into_iter() + .map(|value| value_to_value_or_handler(m, value)) + .collect::>() }); for v in temp.into_iter() { @@ -3912,8 +4157,9 @@ impl MapHandler { match &self.inner { MaybeDetached::Detached(m) => { let mut m = m.lock(); - m.value - .insert(key.into(), ValueOrHandler::Value(value.into())); + let value = value.into(); + ensure_no_regular_container_value(&value)?; + m.value.insert(key.into(), ValueOrHandler::Value(value)); Ok(()) } MaybeDetached::Attached(a) => { @@ -3927,20 +4173,15 @@ impl MapHandler { match &self.inner { MaybeDetached::Detached(m) => { let mut m = m.lock(); - m.value - .insert(key.into(), ValueOrHandler::Value(value.into())); + let value = value.into(); + ensure_no_regular_container_value(&value)?; + m.value.insert(key.into(), ValueOrHandler::Value(value)); Ok(()) } MaybeDetached::Attached(a) => a.with_txn(|txn| { let this = &self; let value = value.into(); - if let Some(_value) = value.as_container() { - return Err(LoroError::ArgErr( - INSERT_CONTAINER_VALUE_ARG_ERROR - .to_string() - .into_boxed_str(), - )); - } + ensure_no_regular_container_value(&value)?; let inner = this.inner.try_attached_state()?; txn.apply_local_op( @@ -3965,13 +4206,7 @@ impl MapHandler { key: &str, value: LoroValue, ) -> LoroResult<()> { - if let Some(_value) = value.as_container() { - return Err(LoroError::ArgErr( - INSERT_CONTAINER_VALUE_ARG_ERROR - .to_string() - .into_boxed_str(), - )); - } + ensure_no_regular_container_value(&value)?; if self.get(key).map(|x| x == value).unwrap_or(false) { // skip if the value is already set @@ -4072,24 +4307,14 @@ impl MapHandler { } } MaybeDetached::Attached(inner) => { - let mut temp = vec![]; - inner.with_state(|state| { - let a = state.as_map_state().unwrap(); - for (k, v) in a.iter() { - if let Some(v) = &v.value { - match v { - LoroValue::Container(c) => { - temp.push(( - k.to_string(), - ValueOrHandler::Handler(create_handler(inner, c.clone())), - )); - } - value => { - temp.push((k.to_string(), ValueOrHandler::Value(value.clone()))) - } - } - } - } + let temp = inner.with_doc_state(|state| { + state + .get_map_entries(inner.container_idx) + .into_iter() + .map(|(key, value)| { + (key.to_string(), value_to_value_or_handler(inner, value)) + }) + .collect::>() }); for (k, v) in temp.into_iter() { @@ -4111,19 +4336,18 @@ impl MapHandler { ValueOrHandler::Handler(h) => Ok(h.clone()), } } - MaybeDetached::Attached(inner) => { - let container_id = inner.with_state(|state| { - state - .as_map_state() - .as_ref() - .unwrap() - .get(key) - .unwrap() - .as_container() - .unwrap() - .clone() - }); - Ok(create_handler(inner, container_id)) + MaybeDetached::Attached(_) => { + let Some(value) = self.get_(key) else { + return Err(LoroError::ArgErr( + format!("Key {key} does not exist").into_boxed_str(), + )); + }; + match value { + ValueOrHandler::Handler(handler) => Ok(handler), + ValueOrHandler::Value(value) => Err(LoroError::ArgErr( + format!("Expected Handler but found {:?}", value).into_boxed_str(), + )), + } } } } @@ -4146,7 +4370,7 @@ impl MapHandler { m.value.get(key).map(|v| v.to_value()) } MaybeDetached::Attached(inner) => { - inner.with_state(|state| state.as_map_state().unwrap().get(key).cloned()) + inner.with_doc_state(|state| state.get_map_value_by_key(inner.container_idx, key)) } } } @@ -4159,15 +4383,9 @@ impl MapHandler { m.value.get(key).cloned() } MaybeDetached::Attached(inner) => { - let value = - inner.with_state(|state| state.as_map_state().unwrap().get(key).cloned()); - match value { - Some(LoroValue::Container(container_id)) => Some(ValueOrHandler::Handler( - create_handler(inner, container_id.clone()), - )), - Some(value) => Some(ValueOrHandler::Value(value.clone())), - None => None, - } + let value = inner + .with_doc_state(|state| state.get_map_value_by_key(inner.container_idx, key)); + value.map(|value| value_to_value_or_handler(inner, value)) } } } @@ -4202,7 +4420,9 @@ impl MapHandler { pub fn len(&self) -> usize { match &self.inner { MaybeDetached::Detached(m) => m.lock().value.len(), - MaybeDetached::Attached(a) => a.with_state(|state| state.as_map_state().unwrap().len()), + MaybeDetached::Attached(a) => { + a.with_doc_state(|state| state.get_map_len(a.container_idx)) + } } } @@ -4246,48 +4466,33 @@ impl MapHandler { } pub fn keys(&self) -> impl Iterator + '_ { - let mut keys: Vec = Vec::with_capacity(self.len()); - match &self.inner { + let keys: Vec = match &self.inner { MaybeDetached::Detached(m) => { let m = m.lock(); - keys = m.value.keys().map(|x| x.as_str().into()).collect(); + m.value.keys().map(|x| x.as_str().into()).collect() } MaybeDetached::Attached(a) => { - a.with_state(|state| { - for (k, v) in state.as_map_state().unwrap().iter() { - if v.value.is_some() { - keys.push(k.clone()); - } - } - }); + a.with_doc_state(|state| state.get_map_keys(a.container_idx)) } - } + }; keys.into_iter() } pub fn values(&self) -> impl Iterator + '_ { - let mut values: Vec = Vec::with_capacity(self.len()); - match &self.inner { + let values: Vec = match &self.inner { MaybeDetached::Detached(m) => { let m = m.lock(); - values = m.value.values().cloned().collect(); + m.value.values().cloned().collect() } - MaybeDetached::Attached(a) => { - a.with_state(|state| { - for (_, v) in state.as_map_state().unwrap().iter() { - let value = match &v.value { - Some(LoroValue::Container(container_id)) => { - ValueOrHandler::Handler(create_handler(a, container_id.clone())) - } - Some(value) => ValueOrHandler::Value(value.clone()), - None => continue, - }; - values.push(value); - } - }); - } - } + MaybeDetached::Attached(a) => a.with_doc_state(|state| { + state + .get_map_values(a.container_idx) + .into_iter() + .map(|value| value_to_value_or_handler(a, value)) + .collect() + }), + }; values.into_iter() } @@ -4489,6 +4694,7 @@ mod test { use super::{ Handler, HandlerTrait, ListHandler, MapHandler, MovableListHandler, TextDelta, TextHandler, + ValueOrHandler, }; use crate::container::list::list_op::ListOp; use crate::cursor::PosType; @@ -4499,9 +4705,29 @@ mod test { use crate::version::Frontiers; use crate::LoroDoc; use crate::{fx_map, ToJson}; - use loro_common::{ContainerID, ContainerType, LoroValue, ID}; + use loro_common::{ContainerID, ContainerType, LoroError, LoroValue, ID}; use serde_json::json; + fn recheck_fast_blob(mut bytes: Vec) -> Vec { + let checksum = xxhash_rust::xxh32::xxh32(&bytes[20..], u32::from_le_bytes(*b"LORO")); + bytes[16..20].copy_from_slice(&checksum.to_le_bytes()); + bytes + } + + fn replace_fast_snapshot_state_bytes(mut snapshot: Vec, state_bytes: &[u8]) -> Vec { + let mut body = &snapshot[22..]; + let oplog_len = u32::from_le_bytes(body[..4].try_into().unwrap()) as usize; + body = &body[4 + oplog_len..]; + let old_state_len = u32::from_le_bytes(body[..4].try_into().unwrap()) as usize; + let state_len_pos = 22 + 4 + oplog_len; + let state_start = state_len_pos + 4; + let state_end = state_start + old_state_len; + snapshot[state_len_pos..state_start] + .copy_from_slice(&(state_bytes.len() as u32).to_le_bytes()); + snapshot.splice(state_start..state_end, state_bytes.iter().copied()); + recheck_fast_blob(snapshot) + } + fn insert_many_with_single_list_op( txn: &mut crate::txn::Transaction, list: &crate::handler::ListHandler, @@ -4696,6 +4922,239 @@ mod test { ); } + #[test] + fn text_snapshot_string_queries_do_not_decode_state() { + let loro = LoroDoc::new_auto_commit(); + let text = loro.get_text("text"); + text.insert(0, "a😀文", PosType::Unicode).unwrap(); + text.mark(1, 3, "bold", true.into(), PosType::Unicode) + .unwrap(); + + let restored = LoroDoc::new(); + restored + .import(&loro.export(ExportMode::snapshot()).unwrap()) + .unwrap(); + let text = restored.get_text("text"); + assert!(!text.attached_handler().unwrap().has_decoded_state()); + + assert_eq!(text.len_unicode(), 3); + assert_eq!(text.len_utf16(), 4); + assert_eq!(text.len_utf8(), "a😀文".len()); + assert_eq!(text.char_at(1, PosType::Unicode).unwrap(), '😀'); + assert_eq!(text.slice(1, 3, PosType::Unicode).unwrap(), "😀文"); + assert_eq!( + text.convert_pos(2, PosType::Unicode, PosType::Utf16), + Some(3) + ); + assert!(matches!( + text.delete_utf16(2, 1), + Err(LoroError::UTF16InUnicodeCodePoint { pos: 2 }) + )); + assert!(matches!( + text.delete_utf8(2, 1), + Err(LoroError::UTF8InUnicodeCodePoint { pos: 2 }) + )); + assert!(matches!( + text.slice_delta(2, 3, PosType::Utf16), + Err(LoroError::UTF16InUnicodeCodePoint { pos: 2 }) + )); + assert!(matches!( + text.slice_delta(2, 3, PosType::Bytes), + Err(LoroError::UTF8InUnicodeCodePoint { pos: 2 }) + )); + assert!(!text.attached_handler().unwrap().has_decoded_state()); + + assert_eq!(text.get_delta().len(), 2); + assert!(text.attached_handler().unwrap().has_decoded_state()); + } + + #[test] + fn text_lazy_event_queries_match_decoded_state() { + let loro = LoroDoc::new_auto_commit(); + let text = loro.get_text("text"); + text.insert(0, "ab😀cd", PosType::Unicode).unwrap(); + text.mark(1, 4, "bold", true.into(), PosType::Unicode) + .unwrap(); + text.mark(2, 3, "link", "x".into(), PosType::Unicode) + .unwrap(); + + let lazy_doc = LoroDoc::new(); + lazy_doc + .import(&loro.export(ExportMode::snapshot()).unwrap()) + .unwrap(); + let lazy_text = lazy_doc.get_text("text"); + + let decoded_doc = LoroDoc::new(); + decoded_doc + .import(&loro.export(ExportMode::snapshot()).unwrap()) + .unwrap(); + let decoded_text = decoded_doc.get_text("text"); + decoded_text.get_delta(); + + assert!(!lazy_text.attached_handler().unwrap().has_decoded_state()); + assert!(decoded_text.attached_handler().unwrap().has_decoded_state()); + + for pos_type in [ + PosType::Event, + PosType::Unicode, + PosType::Utf16, + PosType::Bytes, + ] { + assert_eq!(lazy_text.len(pos_type), decoded_text.len(pos_type)); + for pos in 0..=decoded_text.len(pos_type) { + assert_eq!( + lazy_text.convert_pos(pos, pos_type, PosType::Unicode), + decoded_text.convert_pos(pos, pos_type, PosType::Unicode), + "convert {pos_type:?} pos {pos} to unicode" + ); + assert_eq!( + lazy_text.convert_pos(pos, pos_type, PosType::Event), + decoded_text.convert_pos(pos, pos_type, PosType::Event), + "convert {pos_type:?} pos {pos} to event" + ); + if pos < decoded_text.len(pos_type) { + assert_eq!( + lazy_text.char_at(pos, pos_type), + decoded_text.char_at(pos, pos_type), + "char_at {pos_type:?} pos {pos}" + ); + } + for end in pos..=decoded_text.len(pos_type) { + assert_eq!( + lazy_text.slice(pos, end, pos_type), + decoded_text.slice(pos, end, pos_type), + "slice {pos_type:?} {pos}..{end}" + ); + } + } + } + } + + #[test] + fn deep_value_with_id_uses_lazy_values_for_snapshot_roots() { + let loro = LoroDoc::new_auto_commit(); + let text = loro.get_text("text"); + text.insert(0, "hello", PosType::Unicode).unwrap(); + let map = loro.get_map("map"); + map.insert("key", "value").unwrap(); + let list = loro.get_list("list"); + list.push("item").unwrap(); + + let restored = LoroDoc::new(); + restored + .import(&loro.export(ExportMode::snapshot()).unwrap()) + .unwrap(); + let text = restored.get_text("text"); + let map = restored.get_map("map"); + let list = restored.get_list("list"); + + let value = restored.get_deep_value_with_id(); + assert_eq!(value["text"]["value"], "hello".into()); + assert_eq!(value["map"]["value"]["key"], "value".into()); + assert_eq!(value["list"]["value"][0], "item".into()); + assert!(!text.attached_handler().unwrap().has_decoded_state()); + assert!(!map.attached_handler().unwrap().has_decoded_state()); + assert!(!list.attached_handler().unwrap().has_decoded_state()); + } + + #[test] + fn lazy_value_reads_do_not_write_stale_snapshot_after_mutation() { + let loro = LoroDoc::new_auto_commit(); + let map = loro.get_map("map"); + map.insert("key", "old").unwrap(); + let child = map + .insert_container("child", MapHandler::new_detached()) + .unwrap(); + child.insert("nested", "old").unwrap(); + let list = loro.get_list("list"); + list.push("old").unwrap(); + let child_list = list.push_container(ListHandler::new_detached()).unwrap(); + child_list.push("nested-old").unwrap(); + + let restored = LoroDoc::new(); + restored + .import(&loro.export(ExportMode::snapshot()).unwrap()) + .unwrap(); + let map = restored.get_map("map"); + let list = restored.get_list("list"); + + assert_eq!(map.get("key").unwrap(), "old".into()); + assert_eq!(list.get(0).unwrap(), "old".into()); + let child = match map.get_("child").unwrap() { + ValueOrHandler::Handler(handler) => handler.into_map().unwrap(), + ValueOrHandler::Value(value) => panic!("expected child map, got {value:?}"), + }; + let child_list = match list.get_(1).unwrap() { + ValueOrHandler::Handler(handler) => handler.into_list().unwrap(), + ValueOrHandler::Value(value) => panic!("expected child list, got {value:?}"), + }; + + map.insert("key", "new").unwrap(); + child.insert("nested", "new").unwrap(); + list.delete(0, 1).unwrap(); + list.insert(0, "new").unwrap(); + child_list.delete(0, 1).unwrap(); + child_list.insert(0, "nested-new").unwrap(); + restored.commit_then_renew(); + + let roundtrip = LoroDoc::new(); + roundtrip + .import(&restored.export(ExportMode::snapshot()).unwrap()) + .unwrap(); + assert_eq!( + roundtrip.get_deep_value().to_json_value(), + serde_json::json!({ + "map": { "key": "new", "child": { "nested": "new" } }, + "list": ["new", ["nested-new"]] + }) + ); + } + + #[test] + fn fast_snapshot_with_trailing_bytes_is_rejected_on_import() { + let loro = LoroDoc::new_auto_commit(); + let map = loro.get_map("map"); + map.insert("key", "value").unwrap(); + let mut snapshot = loro.export(ExportMode::snapshot()).unwrap(); + snapshot.push(0xff); + let corrupted = recheck_fast_blob(snapshot); + + let doc = LoroDoc::new(); + assert!(doc.import(&corrupted).is_err()); + } + + #[test] + fn fast_snapshot_with_trailing_bytes_is_rejected_by_meta_decoder() { + let loro = LoroDoc::new_auto_commit(); + let map = loro.get_map("map"); + map.insert("key", "value").unwrap(); + let mut snapshot = loro.export(ExportMode::snapshot()).unwrap(); + snapshot.push(0xff); + let corrupted = recheck_fast_blob(snapshot); + + assert!(LoroDoc::decode_import_blob_meta(&corrupted, true).is_err()); + } + + #[test] + fn fast_snapshot_empty_sstable_meta_is_rejected_on_import() { + let loro = LoroDoc::new_auto_commit(); + let map = loro.get_map("map"); + map.insert("key", "value").unwrap(); + let snapshot = loro.export(ExportMode::snapshot()).unwrap(); + + let mut malformed_state = Vec::new(); + malformed_state.extend_from_slice(b"LORO"); + malformed_state.push(0); + malformed_state.extend_from_slice(&0u32.to_le_bytes()); + let checksum = xxhash_rust::xxh32::xxh32(&[], u32::from_le_bytes(*b"LORO")); + malformed_state.extend_from_slice(&checksum.to_le_bytes()); + malformed_state.extend_from_slice(&5u32.to_le_bytes()); + let corrupted = replace_fast_snapshot_state_bytes(snapshot, &malformed_state); + + let doc = LoroDoc::new(); + assert!(doc.import(&corrupted).is_err()); + } + #[test] fn tree_meta() { let loro = LoroDoc::new_auto_commit(); diff --git a/crates/loro-internal/src/history_cache.rs b/crates/loro-internal/src/history_cache.rs index 2d1aa9873..c75a24fb8 100644 --- a/crates/loro-internal/src/history_cache.rs +++ b/crates/loro-internal/src/history_cache.rs @@ -214,14 +214,14 @@ impl ContainerHistoryCache { ContainerType::Tree => {} } - let state = c.get_state_mut(*idx, default_ctx); + let state = c.get_state_mut(idx, default_ctx); match state { crate::state::State::MapState(m) => { if for_checkout { let c = self.for_checkout.as_mut().unwrap(); for (k, v) in m.iter() { - c.map.record_shallow_root_state_entry(*idx, k, v); + c.map.record_shallow_root_state_entry(idx, k, v); } } } @@ -242,7 +242,7 @@ impl ContainerHistoryCache { crate::state::State::TreeState(t) => { if for_importing { let c = self.for_importing.as_mut().unwrap(); - let tree = c.entry(*idx).or_insert_with(|| { + let tree = c.entry(idx).or_insert_with(|| { HistoryCacheForImporting::Tree(Default::default()) }); tree.as_tree_mut().unwrap().record_shallow_root_state( diff --git a/crates/loro-internal/src/loro.rs b/crates/loro-internal/src/loro.rs index 3c19937ea..d52254f21 100644 --- a/crates/loro-internal/src/loro.rs +++ b/crates/loro-internal/src/loro.rs @@ -44,8 +44,8 @@ use crate::{ }; use either::Either; use loro_common::{ - ContainerID, ContainerType, HasIdSpan, HasLamportSpan, IdSpan, LoroEncodeError, LoroResult, - LoroValue, ID, + ContainerID, ContainerType, HasCounterSpan, HasIdSpan, HasLamportSpan, IdSpan, LoroEncodeError, + LoroResult, LoroValue, ID, }; use rle::HasLength; use rustc_hash::{FxHashMap, FxHashSet}; @@ -345,14 +345,12 @@ impl LoroDoc { options = None; } } - if config.immediate_renew { - if self.can_edit() { - let mut t = self.txn().unwrap(); - if let Some(options) = options.as_ref() { - t.set_options(options.clone()); - } - *txn_guard = Some(t); + if config.immediate_renew && self.can_edit() { + let mut t = self.txn().unwrap(); + if let Some(options) = options.as_ref() { + t.set_options(options.clone()); } + *txn_guard = Some(t); } if let Some(on_commit) = on_commit { @@ -750,12 +748,20 @@ impl LoroDoc { } if !preflight.applies_to_dag { + let pending_root_containers = pending_root_containers_to_materialize(&oplog, &changes); let result = encoding::apply_decoded_changes_to_oplog(&mut oplog, changes); if result.has_deps_before_shallow_root { oplog.arena.rollback(arena_checkpoint); return Err(LoroError::ImportUpdatesThatDependsOnOutdatedVersion); } + if !pending_root_containers.is_empty() { + let mut state = self.state.lock(); + for id in pending_root_containers { + state.ensure_container(&id); + } + } + return Ok(result.status); } @@ -949,12 +955,20 @@ impl LoroDoc { #[inline] pub fn get_handler(&self, id: ContainerID) -> Option { if self.has_container(&id) { + self.ensure_root_container(&id); Some(Handler::new_attached(id, self.clone())) } else { None } } + #[inline] + fn ensure_root_container(&self, id: &ContainerID) { + if id.is_root() { + self.state.lock().ensure_container(id); + } + } + /// id can be a str, ContainerID, or ContainerIdRaw. /// if it's str it will use Root container, which will not be None #[inline] @@ -963,6 +977,7 @@ impl LoroDoc { if !self.has_container(&id) { return None; } + self.ensure_root_container(&id); Handler::new_attached(id, self.clone()).into_text().ok() } @@ -982,6 +997,7 @@ impl LoroDoc { if !self.has_container(&id) { return None; } + self.ensure_root_container(&id); Handler::new_attached(id, self.clone()).into_list().ok() } @@ -1001,6 +1017,7 @@ impl LoroDoc { if !self.has_container(&id) { return None; } + self.ensure_root_container(&id); Handler::new_attached(id, self.clone()) .into_movable_list() .ok() @@ -1022,6 +1039,7 @@ impl LoroDoc { if !self.has_container(&id) { return None; } + self.ensure_root_container(&id); Handler::new_attached(id, self.clone()).into_map().ok() } @@ -1041,6 +1059,7 @@ impl LoroDoc { if !self.has_container(&id) { return None; } + self.ensure_root_container(&id); Handler::new_attached(id, self.clone()).into_tree().ok() } @@ -1061,6 +1080,7 @@ impl LoroDoc { if !self.has_container(&id) { return None; } + self.ensure_root_container(&id); Handler::new_attached(id, self.clone()).into_counter().ok() } @@ -2062,6 +2082,36 @@ impl LoroDoc { } } +fn pending_root_containers_to_materialize(oplog: &OpLog, changes: &[Change]) -> Vec { + let mut roots = FxHashSet::default(); + for change in changes { + if change.ctr_end() <= oplog.vv().get(&change.id.peer).copied().unwrap_or(0) { + continue; + } + + if oplog.dag.is_before_shallow_root(&change.deps) + || oplog + .dag + .get_change_lamport_from_deps(&change.deps) + .is_some() + { + continue; + } + + for op in change.ops.iter() { + let id = oplog + .arena + .get_container_id(op.container) + .expect("decoded op container should be registered"); + if id.is_root() { + roots.insert(id); + } + } + } + + roots.into_iter().collect() +} + #[derive(Debug, thiserror::Error)] pub enum ChangeTravelError { #[error("Target id not found {0:?}")] @@ -2145,9 +2195,23 @@ impl LoroDoc { pub fn get_changed_containers_in(&self, id: ID, len: usize) -> FxHashSet { self.with_barrier(|| { let mut set = FxHashSet::default(); + let len = i64::try_from(len).unwrap_or(i64::MAX); + let start = i64::from(id.counter); + let end = start.saturating_add(len); + if end <= 0 { + return set; + } + + let start = start.max(0).min(i64::from(i32::MAX)); + let end = end.max(0).min(i64::from(i32::MAX)); + if start >= end { + return set; + } + { let oplog = self.oplog().lock(); - for op in oplog.iter_ops(id.to_span(len)) { + let span = IdSpan::new(id.peer, start as i32, end as i32); + for op in oplog.iter_ops(span) { let id = oplog.arena.get_container_id(op.container()).unwrap(); set.insert(id); } @@ -2170,11 +2234,14 @@ impl LoroDoc { return; }; + self.config + .deleted_root_containers + .lock() + .insert(cid.clone()); if let Err(e) = h.clear() { + self.config.deleted_root_containers.lock().remove(&cid); eprintln!("Failed to clear handler: {:?}", e); - return; } - self.config.deleted_root_containers.lock().insert(cid); } pub fn set_hide_empty_root_containers(&self, hide: bool) { @@ -2214,10 +2281,10 @@ fn find_last_delete_op(oplog: &OpLog, id: ID, idx: ContainerIdx) -> Option { if let InnerContent::List(InnerListOp::Delete(d)) = &op.content { if d.id_start.to_span(d.atom_len()).contains(id) { debug_assert!(op.counter >= change.id().counter); - let op_lamport = change.lamport - + (op.counter - change.id().counter) as loro_common::Lamport; + let op_lamport = + change.lamport + (op.counter - change.id().counter) as loro_common::Lamport; let key = (op_lamport, peer); - if best.map_or(true, |(bk, _)| key > bk) { + if best.is_none_or(|(bk, _)| key > bk) { best = Some((key, ID::new(peer, op.counter))); } } diff --git a/crates/loro-internal/src/op/content.rs b/crates/loro-internal/src/op/content.rs index ac64dd3d9..919b65343 100644 --- a/crates/loro-internal/src/op/content.rs +++ b/crates/loro-internal/src/op/content.rs @@ -54,8 +54,10 @@ impl InnerContent { } } crate::op::InnerContent::Tree(t) => { - let id = t.target().associated_meta_container(); - f(&id); + if let TreeOp::Create { target, .. } = t.as_ref() { + let id = target.associated_meta_container(); + f(&id); + } } crate::op::InnerContent::Future(f) => match &f { #[cfg(feature = "counter")] @@ -447,6 +449,22 @@ mod tests { tree.visit_created_children(&arena, &mut |id| children.push(id.clone())); assert_eq!(children, vec![tree_target.associated_meta_container()]); + for tree in [ + TreeOp::Move { + target: tree_target, + parent: None, + position: FractionalIndex::default(), + }, + TreeOp::Delete { + target: tree_target, + }, + ] { + let tree = InnerContent::Tree(Arc::new(tree)); + let mut children = Vec::new(); + tree.visit_created_children(&arena, &mut |id| children.push(id.clone())); + assert!(children.is_empty()); + } + let future = InnerContent::Future(FutureInnerContent::Unknown { prop: 1, value: Box::new(OwnedValue::False), diff --git a/crates/loro-internal/src/oplog.rs b/crates/loro-internal/src/oplog.rs index 3282e8de5..a1e415b81 100644 --- a/crates/loro-internal/src/oplog.rs +++ b/crates/loro-internal/src/oplog.rs @@ -679,6 +679,7 @@ impl OpLog { .flat_map(move |span| self.change_store.iter_changes(span)) } + #[allow(dead_code)] pub(crate) fn iter_changes_causally_rev<'a>( &'a self, from: &VersionVector, diff --git a/crates/loro-internal/src/oplog/change_store.rs b/crates/loro-internal/src/oplog/change_store.rs index 1b5aa8a98..750a8d7ce 100644 --- a/crates/loro-internal/src/oplog/change_store.rs +++ b/crates/loro-internal/src/oplog/change_store.rs @@ -205,6 +205,16 @@ impl ChangeStore { for span in spans { let mut span = *span; span.normalize_(); + if span.counter.end <= 0 { + continue; + } + + span.counter.start = span.counter.start.max(0); + span.counter.end = span.counter.end.max(0); + if span.counter.start >= span.counter.end { + continue; + } + // PERF: this can be optimized by reusing the current encoded blocks // In the current method, it needs to parse and re-encode the blocks for c in self.iter_changes(span) { @@ -427,6 +437,7 @@ impl ChangeStore { }) } + #[allow(dead_code)] pub(crate) fn get_blocks_in_range(&self, id_span: IdSpan) -> VecDeque> { let mut inner = self.inner.lock(); let start_counter = inner @@ -824,7 +835,7 @@ mod mut_inner_kv { panic!("counter should be continuous") } - if let Some(rollback) = rollback.as_deref_mut() { + if let Some(rollback) = &mut rollback { rollback.record_block_before_mutation(*_id, block.clone()); } @@ -1071,7 +1082,7 @@ mod mut_inner_kv { if !new_change.ops.is_empty() { total_len += new_change.atom_len(); - self.insert_change_inner(new_change, false, false, rollback.as_deref_mut()); + self.insert_change_inner(new_change, false, false, rollback); } assert_eq!(total_len, original_len); @@ -1200,8 +1211,7 @@ mod mut_inner_kv { let mut inner = self.inner.lock(); let Some((next_back_id, next_back_bytes)) = kv .scan(Bound::Unbounded, Bound::Included(&id.to_bytes())) - .filter(|(id, _)| id.len() == 12) - .next_back() + .rfind(|(id, _)| id.len() == 12) else { return; }; @@ -1314,6 +1324,7 @@ impl ChangesBlock { }) } + #[allow(dead_code)] pub(crate) fn content(&self) -> &ChangesBlockContent { &self.content } @@ -1586,6 +1597,7 @@ impl ChangesBlockContent { } } + #[allow(dead_code)] pub(crate) fn len_changes(&self) -> usize { match self { ChangesBlockContent::Changes(changes) => changes.len(), diff --git a/crates/loro-internal/src/oplog/change_store/block_encode.rs b/crates/loro-internal/src/oplog/change_store/block_encode.rs index 6c779ede1..1d733f29c 100644 --- a/crates/loro-internal/src/oplog/change_store/block_encode.rs +++ b/crates/loro-internal/src/oplog/change_store/block_encode.rs @@ -568,7 +568,7 @@ pub fn decode_block( keys, }; let positions = PositionArena::decode_v2(&positions)?; - let positions = positions.parse_to_positions(); + let positions = positions.try_parse_to_positions()?; let cids: &Vec = header.cids.get_or_try_init(|| { ContainerArena::decode(&cids)? .iter() diff --git a/crates/loro-internal/src/oplog/change_store/iter.rs b/crates/loro-internal/src/oplog/change_store/iter.rs index 27bba4113..153eaf0a2 100644 --- a/crates/loro-internal/src/oplog/change_store/iter.rs +++ b/crates/loro-internal/src/oplog/change_store/iter.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::{ collections::{BinaryHeap, VecDeque}, sync::Arc, diff --git a/crates/loro-internal/src/oplog/loro_dag.rs b/crates/loro-internal/src/oplog/loro_dag.rs index aeeead3b8..d71a08e5d 100644 --- a/crates/loro-internal/src/oplog/loro_dag.rs +++ b/crates/loro-internal/src/oplog/loro_dag.rs @@ -1203,7 +1203,7 @@ impl AppDag { let last_ids: Frontiers = this .iter() .filter_map(|(client_id, cnt)| { - if *cnt == 0 { + if *cnt <= 0 { return None; } @@ -1234,7 +1234,7 @@ impl AppDag { let last_ids: Frontiers = this .iter() .filter_map(|(client_id, cnt)| { - if *cnt == 0 { + if *cnt <= 0 { return None; } diff --git a/crates/loro-internal/src/state.rs b/crates/loro-internal/src/state.rs index 545d414ac..7b590c4ab 100644 --- a/crates/loro-internal/src/state.rs +++ b/crates/loro-internal/src/state.rs @@ -9,7 +9,7 @@ use dead_containers_cache::DeadContainersCache; use enum_as_inner::EnumAsInner; use enum_dispatch::enum_dispatch; use itertools::Itertools; -use loro_common::{ContainerID, LoroError, LoroResult, TreeID}; +use loro_common::{ContainerID, Lamport, LoroError, LoroResult, TreeID}; use loro_delta::DeltaItem; use rustc_hash::{FxHashMap, FxHashSet}; use tracing::{info_span, instrument, warn}; @@ -68,6 +68,90 @@ pub(crate) fn fail_next_import_state_apply_for_test() { FAIL_NEXT_IMPORT_STATE_APPLY.with(|fail| fail.set(true)); } +fn visible_container_value_is_empty(kind: ContainerType, value: &LoroValue) -> bool { + match kind { + ContainerType::Text => value.as_string().is_some_and(|value| value.is_empty()), + ContainerType::Map | ContainerType::List | ContainerType::MovableList => { + value.is_empty_collection() + } + ContainerType::Tree => value.as_list().is_some_and(|value| value.is_empty()), + #[cfg(feature = "counter")] + ContainerType::Counter => false, + ContainerType::Unknown(_) => false, + } +} + +fn deleted_root_container_value_is_cleared(kind: ContainerType, value: &LoroValue) -> bool { + match kind { + #[cfg(feature = "counter")] + ContainerType::Counter => value.as_double().is_some_and(|value| *value == 0.0), + _ => visible_container_value_is_empty(kind, value), + } +} + +fn state_decode_error(message: impl Into>) -> LoroError { + LoroError::DecodeError(message.into()) +} + +fn decode_peer_table(bytes: &mut &[u8], context: &str) -> LoroResult> { + let peer_num = leb128::read::unsigned(bytes) + .map_err(|_| state_decode_error(format!("{context}: invalid peer table length")))?; + let peer_num = usize::try_from(peer_num) + .map_err(|_| state_decode_error(format!("{context}: peer table length overflow")))?; + let peer_bytes_len = peer_num + .checked_mul(std::mem::size_of::()) + .ok_or_else(|| state_decode_error(format!("{context}: peer table byte length overflow")))?; + if bytes.len() < peer_bytes_len { + return Err(state_decode_error(format!( + "{context}: truncated peer table" + ))); + } + + let peer_bytes = &bytes[..peer_bytes_len]; + let peers = peer_bytes + .chunks_exact(std::mem::size_of::()) + .map(|chunk| { + let mut buf = [0u8; std::mem::size_of::()]; + buf.copy_from_slice(chunk); + PeerID::from_le_bytes(buf) + }) + .collect(); + *bytes = &bytes[peer_bytes_len..]; + Ok(peers) +} + +fn decode_peer_from_table(peers: &[PeerID], peer_idx: usize, context: &str) -> LoroResult { + peers + .get(peer_idx) + .copied() + .ok_or_else(|| state_decode_error(format!("{context}: peer index out of range"))) +} + +fn read_state_leb_u64(bytes: &mut &[u8], context: &str) -> LoroResult { + leb128::read::unsigned(bytes) + .map_err(|_| state_decode_error(format!("{context}: invalid integer"))) +} + +fn decode_counter(counter: i32, context: &str) -> LoroResult { + if counter < 0 { + return Err(state_decode_error(format!("{context}: negative counter"))); + } + + Ok(counter) +} + +fn decode_lamport_from_delta( + counter: i32, + lamport_sub_counter: i32, + context: &str, +) -> LoroResult { + decode_counter(counter, context)?; + let lamport = counter + .checked_add(lamport_sub_counter) + .ok_or_else(|| state_decode_error(format!("{context}: lamport overflow")))?; + u32::try_from(lamport).map_err(|_| state_decode_error(format!("{context}: negative lamport"))) +} + pub struct DocState { pub(super) peer: Arc, @@ -780,7 +864,7 @@ impl DocState { pub(crate) fn iter_all_containers_mut( &mut self, - ) -> impl Iterator { + ) -> impl Iterator { self.store.iter_all_containers() } @@ -834,6 +918,61 @@ impl DocState { .unwrap_or_else(|| container_idx.get_type().default_value()) } + pub(crate) fn get_map_value_by_key( + &mut self, + container_idx: ContainerIdx, + key: &str, + ) -> Option { + self.store.map_get(container_idx, key) + } + + pub(crate) fn get_map_len(&mut self, container_idx: ContainerIdx) -> usize { + self.store.map_len(container_idx) + } + + pub(crate) fn get_map_keys(&mut self, container_idx: ContainerIdx) -> Vec { + self.store.map_keys(container_idx) + } + + pub(crate) fn get_map_values(&mut self, container_idx: ContainerIdx) -> Vec { + self.store.map_values(container_idx) + } + + pub(crate) fn get_map_entries( + &mut self, + container_idx: ContainerIdx, + ) -> Vec<(InternalString, LoroValue)> { + self.store.map_entries(container_idx) + } + + pub(crate) fn get_list_value_at( + &mut self, + container_idx: ContainerIdx, + index: usize, + ) -> Option { + self.store.list_get(container_idx, index) + } + + pub(crate) fn get_list_len(&mut self, container_idx: ContainerIdx) -> usize { + self.store.list_len(container_idx) + } + + pub(crate) fn get_list_values(&mut self, container_idx: ContainerIdx) -> Vec { + self.store.list_values(container_idx) + } + + pub(crate) fn get_text_unicode_len(&mut self, container_idx: ContainerIdx) -> usize { + self.store.text_unicode_len(container_idx).unwrap_or(0) + } + + pub(crate) fn get_text_utf16_len(&mut self, container_idx: ContainerIdx) -> usize { + self.store.text_utf16_len(container_idx).unwrap_or(0) + } + + pub(crate) fn has_decoded_container_state(&mut self, container_idx: ContainerIdx) -> bool { + self.store.has_decoded_state(container_idx) + } + /// Set the state of the container with the given container idx. /// This is only used for decode. /// @@ -893,7 +1032,7 @@ impl DocState { let diff: Vec<_> = self .store .iter_all_containers() - .map(|(&idx, state)| InternalContainerDiff { + .map(|(idx, state)| InternalContainerDiff { idx, bring_back: false, diff: state @@ -993,8 +1132,14 @@ impl DocState { match &id { loro_common::ContainerID::Root { name, .. } => { let v = self.get_container_deep_value(root_idx); - if (should_hide_empty_root_container || deleted_root_container.contains(&id)) - && v.is_empty_collection() + if should_hide_empty_root_container + && visible_container_value_is_empty(root_idx.get_type(), &v) + { + continue; + } + + if deleted_root_container.contains(&id) + && deleted_root_container_value_is_cleared(root_idx.get_type(), &v) { continue; } @@ -1032,12 +1177,18 @@ impl DocState { } pub(crate) fn preferred_root_containers(&mut self) -> Vec { - let flag = self.store.load_all(); + let flag = self.store.load_root_containers(); let roots = self.arena.root_containers(flag); let mut selected = FxHashMap::default(); let mut names = Vec::new(); for idx in roots { + let Some(id) = self.arena.idx_to_id(idx) else { + continue; + }; + if !self.store.contains_id(&id) { + continue; + } let Some(name) = self.root_container_name(idx) else { continue; }; @@ -1068,11 +1219,17 @@ impl DocState { &mut self, root_index: &InternalString, ) -> Option { - let flag = self.store.load_all(); + let flag = self.store.load_root_containers(); let roots = self.arena.root_containers(flag); let mut selected = None; for idx in roots { + let Some(id) = self.arena.idx_to_id(idx) else { + continue; + }; + if !self.store.contains_id(&id) { + continue; + } let Some(name) = self.root_container_name(idx) else { continue; }; @@ -1102,10 +1259,11 @@ impl DocState { } fn root_container_is_empty(&mut self, idx: ContainerIdx) -> bool { - self.store + let value = self + .store .get_value(idx) - .unwrap_or_else(|| idx.get_type().default_value()) - .is_empty_collection() + .unwrap_or_else(|| idx.get_type().default_value()); + visible_container_value_is_empty(idx.get_type(), &value) } pub fn get_all_container_value_flat(&mut self) -> LoroValue { @@ -1125,10 +1283,9 @@ impl DocState { id: Option, ) -> LoroValue { let id = id.unwrap_or_else(|| self.arena.idx_to_id(container).unwrap()); - let Some(state) = self.store.get_container_mut(container) else { + let Some(value) = self.store.get_value(container) else { return container.get_type().default_value(); }; - let value = state.get_value(); let cid_str = LoroValue::String(format!("idx:{}, id:{}", container.to_index(), id).into()); match value { LoroValue::Container(_) => unreachable!(), @@ -1257,6 +1414,7 @@ impl DocState { .root_containers(flag) .iter() .map(|x| self.arena.get_container_id(*x).unwrap()) + .filter(|id| self.store.contains_id(id)) .collect_vec(); while let Some(id) = to_visit.pop() { diff --git a/crates/loro-internal/src/state/analyzer.rs b/crates/loro-internal/src/state/analyzer.rs index 2eb3b7a61..1ecef1cee 100644 --- a/crates/loro-internal/src/state/analyzer.rs +++ b/crates/loro-internal/src/state/analyzer.rs @@ -40,7 +40,7 @@ impl DocAnalysis { let mut containers = FxHashMap::default(); let mut state = doc.app_state().lock(); let alive_containers = state.get_all_alive_containers(); - for (&idx, c) in state.iter_all_containers_mut() { + for (idx, c) in state.iter_all_containers_mut() { let ops_num = ops_nums.get(&idx).unwrap_or(&0); let id = doc.arena().get_container_id(idx).unwrap(); let dropped = !alive_containers.contains(&id); diff --git a/crates/loro-internal/src/state/container_store.rs b/crates/loro-internal/src/state/container_store.rs index ec12f728a..3dc4255ce 100644 --- a/crates/loro-internal/src/state/container_store.rs +++ b/crates/loro-internal/src/state/container_store.rs @@ -7,7 +7,7 @@ use crate::{ }; use bytes::Bytes; use inner_store::InnerStore; -use loro_common::{ContainerID, LoroResult, LoroValue}; +use loro_common::{ContainerID, InternalString, LoroResult, LoroValue}; use std::sync::Arc; pub(crate) use container_wrapper::ContainerWrapper; @@ -107,8 +107,67 @@ impl ContainerStore { pub fn get_value(&mut self, idx: ContainerIdx) -> Option { self.store - .get_mut(idx) - .map(|c| c.get_value(idx, ctx!(self))) + .with_container_for_read(idx, |c| c.get_value(idx, ctx!(self))) + } + + pub fn map_get(&mut self, idx: ContainerIdx, key: &str) -> Option { + self.store + .with_container_for_read(idx, |c| c.map_get(idx, ctx!(self), key))? + } + + pub fn map_len(&mut self, idx: ContainerIdx) -> usize { + self.store + .with_container_for_read(idx, |c| c.map_len(idx, ctx!(self))) + .unwrap_or(0) + } + + pub fn map_keys(&mut self, idx: ContainerIdx) -> Vec { + self.store + .with_container_for_read(idx, |c| c.map_keys(idx, ctx!(self))) + .unwrap_or_default() + } + + pub fn map_values(&mut self, idx: ContainerIdx) -> Vec { + self.store + .with_container_for_read(idx, |c| c.map_values(idx, ctx!(self))) + .unwrap_or_default() + } + + pub fn map_entries(&mut self, idx: ContainerIdx) -> Vec<(InternalString, LoroValue)> { + self.store + .with_container_for_read(idx, |c| c.map_entries(idx, ctx!(self))) + .unwrap_or_default() + } + + pub fn list_get(&mut self, idx: ContainerIdx, index: usize) -> Option { + self.store + .with_container_for_read(idx, |c| c.list_get(idx, ctx!(self), index))? + } + + pub fn list_len(&mut self, idx: ContainerIdx) -> usize { + self.store + .with_container_for_read(idx, |c| c.list_len(idx, ctx!(self))) + .unwrap_or(0) + } + + pub fn list_values(&mut self, idx: ContainerIdx) -> Vec { + self.store + .with_container_for_read(idx, |c| c.list_values(idx, ctx!(self))) + .unwrap_or_default() + } + + pub fn text_unicode_len(&mut self, idx: ContainerIdx) -> Option { + self.store + .with_container_for_read(idx, |c| c.text_unicode_len(idx, ctx!(self)))? + } + + pub fn text_utf16_len(&mut self, idx: ContainerIdx) -> Option { + self.store + .with_container_for_read(idx, |c| c.text_utf16_len(idx, ctx!(self)))? + } + + pub fn has_decoded_state(&mut self, idx: ContainerIdx) -> bool { + self.store.has_decoded_state(idx) } pub fn encode(&mut self) -> Bytes { @@ -168,7 +227,7 @@ impl ContainerStore { pub fn iter_and_decode_all(&mut self) -> impl Iterator { self.store.iter_all_containers_mut().map(|(idx, v)| { v.get_state_mut( - *idx, + idx, ContainerCreationContext { configure: &self.conf, peer: self.peer.load(std::sync::atomic::Ordering::Relaxed), @@ -187,7 +246,7 @@ impl ContainerStore { pub fn iter_all_containers( &mut self, - ) -> impl Iterator { + ) -> impl Iterator { self.store.iter_all_containers_mut() } @@ -200,6 +259,11 @@ impl ContainerStore { LoadAllFlag } + pub fn load_root_containers(&mut self) -> LoadAllFlag { + self.store.load_roots(); + LoadAllFlag + } + pub(super) fn get_or_create_mut(&mut self, idx: ContainerIdx) -> &mut State { self.store .get_or_insert_with(idx, || { @@ -256,7 +320,7 @@ impl ContainerStore { #[allow(unused)] fn check_eq_after_parsing(&mut self, other: &mut ContainerStore) { for (idx, container) in self.store.iter_all_containers_mut() { - let id = self.arena.get_container_id(*idx).unwrap(); + let id = self.arena.get_container_id(idx).unwrap(); let other_idx = other.arena.register_container(&id); let other_container = other .store @@ -269,7 +333,7 @@ impl ContainerStore { id, other_id ); assert_eq!( - container.get_value(*idx, ctx!(self)), + container.get_value(idx, ctx!(self)), other_container.get_value(other_idx, ctx!(other)), "value mismatch" ); @@ -299,7 +363,8 @@ impl ContainerStore { mod test { use super::*; use crate::{ - cursor::PosType, state::TreeParentId, ListHandler, LoroDoc, MapHandler, MovableListHandler, + cursor::PosType, state::TreeParentId, ContainerType, ListHandler, LoroDoc, MapHandler, + MovableListHandler, }; fn decode_container_store(bytes: Bytes) -> ContainerStore { @@ -348,12 +413,31 @@ mod test { doc } + fn export_container_store(doc: &LoroDoc) -> Bytes { + let mut state = doc.app_state().lock(); + state.ensure_all_alive_containers(); + state.store.encode() + } + #[test] fn test_container_store_exports_imports() { let doc = init_doc(); - let mut s = doc.app_state().lock(); - let bytes = s.store.encode(); + let bytes = export_container_store(&doc); let mut new_store = decode_container_store(bytes); + let mut s = doc.app_state().lock(); s.store.check_eq_after_parsing(&mut new_store); } + + #[test] + fn first_lazy_read_caches_value() { + let doc = init_doc(); + let bytes = export_container_store(&doc); + let mut store = decode_container_store(bytes); + let map_id = ContainerID::new_root("map", ContainerType::Map); + let map_idx = store.arena.register_container(&map_id); + + assert!(!store.store.has_cached_value_for_test(map_idx)); + assert_eq!(store.map_len(map_idx), 2); + assert!(store.store.has_cached_value_for_test(map_idx)); + } } diff --git a/crates/loro-internal/src/state/container_store/container_wrapper.rs b/crates/loro-internal/src/state/container_store/container_wrapper.rs index f486195e4..696b0e623 100644 --- a/crates/loro-internal/src/state/container_store/container_wrapper.rs +++ b/crates/loro-internal/src/state/container_store/container_wrapper.rs @@ -1,5 +1,8 @@ +use std::collections::BTreeMap; + use bytes::Bytes; -use loro_common::{ContainerID, ContainerType, LoroResult, LoroValue}; +use loro_common::{ContainerID, ContainerType, InternalString, LoroError, LoroResult, LoroValue}; +use once_cell::sync::OnceCell; use tracing::trace; #[cfg(feature = "counter")] @@ -9,8 +12,9 @@ use crate::{ container::idx::ContainerIdx, state::{ unknown_state::UnknownState, ContainerCreationContext, ContainerState, FastStateSnapshot, - ListState, MapState, MovableListState, RichtextState, State, TreeState, + IndexType, ListState, MapState, MovableListState, RichtextState, State, TreeState, }, + utils::utf16::{count_unicode_chars, count_utf16_len}, }; #[derive(Debug)] @@ -18,18 +22,87 @@ pub(crate) struct ContainerWrapper { depth: usize, kind: ContainerType, parent: Option, - /// The possible combinations of is_some() are: - /// - /// 1. bytes: new container decoded from bytes - /// 2. bytes + value: new container decoded from bytes, with value decoded - /// 3. state + bytes + value: new container decoded from bytes, with value and state decoded - /// 4. state + data: ContainerData, + flushed: bool, +} + +#[derive(Debug)] +enum ContainerData { + State(State), + Lazy(Box), +} + +#[derive(Debug)] +struct LazyContainerData { + /// Lazily decoded snapshot bytes and optional decoded value. bytes: Option, - value: Option, + value: Option, bytes_offset_for_value: Option, bytes_offset_for_state: Option, - state: Option, - flushed: bool, +} + +#[derive(Debug)] +enum LazyDecodedValue { + Value(LoroValue), + Text { + value: LoroValue, + unicode_len: usize, + utf16_len: usize, + }, + Map { + ordered: BTreeMap, + value: OnceCell, + }, +} + +impl LazyDecodedValue { + fn text(value: LoroValue) -> Self { + let text = value + .as_string() + .expect("decoded text value should be a string"); + Self::Text { + unicode_len: count_unicode_chars(text.as_bytes()), + utf16_len: count_utf16_len(text.as_bytes()), + value, + } + } + + fn to_loro_value(&self) -> LoroValue { + match self { + Self::Value(value) | Self::Text { value, .. } => value.clone(), + Self::Map { ordered, value } => value + .get_or_init(|| { + LoroValue::Map( + ordered + .iter() + .map(|(key, value)| (key.clone(), value.clone())) + .collect::(), + ) + }) + .clone(), + } + } + + fn as_map(&self) -> Option<&BTreeMap> { + match self { + Self::Map { ordered, .. } => Some(ordered), + _ => None, + } + } + + fn unicode_len(&self) -> Option { + match self { + Self::Text { unicode_len, .. } => Some(*unicode_len), + _ => None, + } + } + + fn utf16_len(&self) -> Option { + match self { + Self::Text { utf16_len, .. } => Some(*utf16_len), + _ => None, + } + } } impl ContainerWrapper { @@ -43,11 +116,7 @@ impl ContainerWrapper { depth, parent, kind: idx.get_type(), - state: Some(state), - bytes: None, - value: None, - bytes_offset_for_state: None, - bytes_offset_for_value: None, + data: ContainerData::State(state), flushed: false, } } @@ -56,16 +125,26 @@ impl ContainerWrapper { self.depth } + pub(crate) fn kind(&self) -> ContainerType { + self.kind + } + /// It will not decode the state if it is not decoded #[allow(unused)] pub fn try_get_state(&self) -> Option<&State> { - self.state.as_ref() + match &self.data { + ContainerData::State(state) => Some(state), + ContainerData::Lazy(_) => None, + } } /// It will decode the state if it is not decoded pub fn get_state(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> &State { self.decode_state(idx, ctx).unwrap(); - self.state.as_ref().expect("ContainerWrapper is empty") + match &self.data { + ContainerData::State(state) => state, + ContainerData::Lazy(_) => unreachable!("ContainerWrapper state should be decoded"), + } } /// It will decode the state if it is not decoded @@ -75,32 +154,290 @@ impl ContainerWrapper { ctx: ContainerCreationContext, ) -> &mut State { self.decode_state(idx, ctx).unwrap(); - self.bytes = None; - self.value = None; self.flushed = false; - self.state.as_mut().unwrap() + match &mut self.data { + ContainerData::State(state) => state, + ContainerData::Lazy(_) => unreachable!("ContainerWrapper state should be decoded"), + } + } + + pub(crate) fn try_get_value( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> LoroResult { + match &mut self.data { + ContainerData::State(state) => { + trace!("state"); + Ok(state.get_value()) + } + ContainerData::Lazy(lazy) if lazy.value.is_some() => { + trace!("value"); + Ok(lazy.value.as_ref().unwrap().to_loro_value()) + } + ContainerData::Lazy(_) => { + trace!("decode value"); + self.decode_value(idx, ctx)?; + match &mut self.data { + ContainerData::State(state) => Ok(state.get_value()), + ContainerData::Lazy(lazy) => Ok(lazy.value.as_ref().unwrap().to_loro_value()), + } + } + } } pub fn get_value(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> LoroValue { - if let Some(v) = self.value.as_ref() { - trace!("value"); - return v.clone(); + self.try_get_value(idx, ctx).unwrap() + } + + pub fn map_get( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + key: &str, + ) -> Option { + match &mut self.data { + ContainerData::State(state) => state.as_map_state().unwrap().get(key).cloned(), + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &self.data { + ContainerData::Lazy(lazy) => lazy.value.as_ref()?.as_map()?.get(key).cloned(), + ContainerData::State(state) => state.as_map_state().unwrap().get(key).cloned(), + } + } + } + } + + pub fn map_len(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> usize { + match &mut self.data { + ContainerData::State(state) => state.as_map_state().unwrap().len(), + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &self.data { + ContainerData::Lazy(lazy) => lazy + .value + .as_ref() + .and_then(|value| value.as_map()) + .map_or(0, BTreeMap::len), + ContainerData::State(state) => state.as_map_state().unwrap().len(), + } + } + } + } + + pub fn map_keys( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> Vec { + match &mut self.data { + ContainerData::State(state) => state + .as_map_state() + .unwrap() + .iter() + .filter(|(_, value)| value.value.is_some()) + .map(|(key, _)| key.clone()) + .collect(), + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &self.data { + ContainerData::Lazy(lazy) => lazy + .value + .as_ref() + .and_then(|value| value.as_map()) + .map(|map| map.keys().map(|key| key.as_str().into()).collect()) + .unwrap_or_default(), + ContainerData::State(state) => state + .as_map_state() + .unwrap() + .iter() + .filter(|(_, value)| value.value.is_some()) + .map(|(key, _)| key.clone()) + .collect(), + } + } } + } - self.decode_value(idx, ctx).unwrap(); - if self.value.is_none() { - trace!("state"); - return self.state.as_mut().unwrap().get_value(); + pub fn map_values( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> Vec { + match &mut self.data { + ContainerData::State(state) => state + .as_map_state() + .unwrap() + .iter() + .filter_map(|(_, value)| value.value.clone()) + .collect(), + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &self.data { + ContainerData::Lazy(lazy) => lazy + .value + .as_ref() + .and_then(|value| value.as_map()) + .map(|map| map.values().cloned().collect()) + .unwrap_or_default(), + ContainerData::State(state) => state + .as_map_state() + .unwrap() + .iter() + .filter_map(|(_, value)| value.value.clone()) + .collect(), + } + } } + } - trace!("devalue"); - self.value.as_ref().unwrap().clone() + pub fn map_entries( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> Vec<(InternalString, LoroValue)> { + match &mut self.data { + ContainerData::State(state) => state + .as_map_state() + .unwrap() + .iter() + .filter_map(|(key, value)| value.value.clone().map(|value| (key.clone(), value))) + .collect(), + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &self.data { + ContainerData::Lazy(lazy) => lazy + .value + .as_ref() + .and_then(|value| value.as_map()) + .map(|map| { + map.iter() + .map(|(key, value)| (key.as_str().into(), value.clone())) + .collect() + }) + .unwrap_or_default(), + ContainerData::State(state) => state + .as_map_state() + .unwrap() + .iter() + .filter_map(|(key, value)| { + value.value.clone().map(|value| (key.clone(), value)) + }) + .collect(), + } + } + } } - pub fn encode(&mut self) -> Bytes { - if let Some(bytes) = self.bytes.as_ref() { - return bytes.clone(); + pub fn text_unicode_len( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> Option { + match &mut self.data { + ContainerData::State(state) => { + Some(state.as_richtext_state_mut().unwrap().len_unicode()) + } + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &mut self.data { + ContainerData::State(state) => { + Some(state.as_richtext_state_mut().unwrap().len_unicode()) + } + ContainerData::Lazy(lazy) => lazy.value.as_ref()?.unicode_len(), + } + } + } + } + + pub fn text_utf16_len( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> Option { + match &mut self.data { + ContainerData::State(state) => Some(state.as_richtext_state_mut().unwrap().len_utf16()), + ContainerData::Lazy(_) => { + self.decode_value(idx, ctx).unwrap(); + match &mut self.data { + ContainerData::State(state) => { + Some(state.as_richtext_state_mut().unwrap().len_utf16()) + } + ContainerData::Lazy(lazy) => lazy.value.as_ref()?.utf16_len(), + } + } + } + } + + pub fn list_get( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + index: usize, + ) -> Option { + match &mut self.data { + ContainerData::State(state) => match self.kind { + ContainerType::List => state.as_list_state().unwrap().get(index).cloned(), + ContainerType::MovableList => state + .as_movable_list_state() + .unwrap() + .get(index, IndexType::ForUser) + .cloned(), + _ => None, + }, + ContainerData::Lazy(_) => match self.get_value(idx, ctx) { + LoroValue::List(list) => list.get(index).cloned(), + _ => None, + }, + } + } + + pub fn list_len(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> usize { + match &mut self.data { + ContainerData::State(state) => match self.kind { + ContainerType::List => state.as_list_state().unwrap().len(), + ContainerType::MovableList => state.as_movable_list_state().unwrap().len(), + _ => 0, + }, + ContainerData::Lazy(_) => match self.get_value(idx, ctx) { + LoroValue::List(list) => list.len(), + _ => 0, + }, + } + } + + pub fn list_values( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> Vec { + match &mut self.data { + ContainerData::State(state) => match self.kind { + ContainerType::List => state.as_list_state().unwrap().iter().cloned().collect(), + ContainerType::MovableList => state + .as_movable_list_state() + .unwrap() + .iter() + .cloned() + .collect(), + _ => Vec::new(), + }, + ContainerData::Lazy(_) => match self.get_value(idx, ctx) { + LoroValue::List(list) => list.iter().cloned().collect(), + _ => Vec::new(), + }, } + } + + pub fn encode(&mut self) -> Bytes { + let ContainerData::State(state) = &mut self.data else { + let lazy = match &self.data { + ContainerData::Lazy(lazy) => lazy, + ContainerData::State(_) => unreachable!(), + }; + assert!(self.flushed, "lazy container should be flushed"); + return lazy.bytes.as_ref().unwrap().clone(); + }; // ContainerType // Depth @@ -110,13 +447,8 @@ impl ContainerWrapper { output.push(self.kind.to_u8()); leb128::write::unsigned(&mut output, self.depth as u64).unwrap(); postcard::to_io(&self.parent, &mut output).unwrap(); - self.state - .as_mut() - .unwrap() - .encode_snapshot_fast(&mut output); - let ans: Bytes = output.into(); - self.bytes = Some(ans.clone()); - ans + state.encode_snapshot_fast(&mut output); + output.into() } #[allow(unused)] @@ -128,86 +460,161 @@ impl ContainerWrapper { } pub fn new_from_bytes(bytes: Bytes) -> Self { - let kind = ContainerType::try_from_u8(bytes[0]).unwrap(); + Self::try_new_from_bytes(bytes).unwrap() + } + + pub fn try_new_from_bytes(bytes: Bytes) -> LoroResult { + if bytes.is_empty() { + return Err(LoroError::DecodeError( + "Decode container state failed".to_string().into_boxed_str(), + )); + } + + let kind = ContainerType::try_from_u8(bytes[0])?; let mut reader = &bytes[1..]; - let depth = leb128::read::unsigned(&mut reader).unwrap(); - let (parent, reader) = postcard::take_from_bytes(reader).unwrap(); + let depth = leb128::read::unsigned(&mut reader).map_err(|_| { + LoroError::DecodeError("Decode container state failed".to_string().into_boxed_str()) + })?; + let (parent, reader) = postcard::take_from_bytes(reader).map_err(|_| { + LoroError::DecodeError("Decode container state failed".to_string().into_boxed_str()) + })?; let size = bytes.len() - reader.len(); - Self { + Ok(Self { depth: depth as usize, kind, parent, - state: None, - value: None, - bytes: Some(bytes.clone()), - bytes_offset_for_value: Some(size), - bytes_offset_for_state: None, + data: ContainerData::Lazy(Box::new(LazyContainerData { + value: None, + bytes: Some(bytes.clone()), + bytes_offset_for_value: Some(size), + bytes_offset_for_state: None, + })), flushed: true, - } + }) } - #[allow(unused)] - pub fn ensure_value(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> &LoroValue { - if self.value.is_some() { - } else if self.state.is_some() { - let value = self.state.as_mut().unwrap().get_value(); - self.value = Some(value); - } else { - self.decode_value(idx, ctx).unwrap(); + fn decode_value(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> LoroResult<()> { + if matches!(self.data, ContainerData::State(_)) { + return Ok(()); } - self.value.as_ref().unwrap() - } - - fn decode_value(&mut self, idx: ContainerIdx, ctx: ContainerCreationContext) -> LoroResult<()> { - if self.value.is_some() || self.state.is_some() { + if matches!(&self.data, ContainerData::Lazy(lazy) if lazy.value.is_some()) { return Ok(()); } - let Some(bytes) = self.bytes.as_ref() else { + let (v, state_offset, decoded_state) = self.decode_value_from_bytes(idx, ctx)?; + if let Some(state) = decoded_state { + self.data = ContainerData::State(state); return Ok(()); + } + + let ContainerData::Lazy(lazy) = &mut self.data else { + unreachable!(); }; + lazy.value = Some(v); + lazy.bytes_offset_for_state = Some(state_offset); + Ok(()) + } - if self.bytes_offset_for_value.is_none() { - let mut reader: &[u8] = bytes; + pub(super) fn has_cached_value(&self) -> bool { + match &self.data { + ContainerData::State(_) => true, + ContainerData::Lazy(lazy) => lazy.value.is_some(), + } + } + + #[cfg(test)] + pub(super) fn has_cached_value_for_test(&self) -> bool { + self.has_cached_value() + } + + fn value_bytes_and_offset(&mut self) -> Option<(Bytes, usize)> { + let ContainerData::Lazy(lazy) = &mut self.data else { + return None; + }; + + let bytes = lazy.bytes.as_ref()?.clone(); + if lazy.bytes_offset_for_value.is_none() { + let mut reader: &[u8] = &bytes; reader = &reader[1..]; let _depth = leb128::read::unsigned(&mut reader).unwrap(); let (_parent, reader) = postcard::take_from_bytes::>(reader).unwrap(); // SAFETY: bytes is a slice of b let size = bytes.len() - reader.len(); - self.bytes_offset_for_value = Some(size); + lazy.bytes_offset_for_value = Some(size); } - let value_offset = self.bytes_offset_for_value.unwrap(); + Some((bytes, lazy.bytes_offset_for_value.unwrap())) + } + + fn decode_value_from_bytes( + &mut self, + idx: ContainerIdx, + ctx: ContainerCreationContext, + ) -> LoroResult<(LazyDecodedValue, usize, Option)> { + let Some((bytes, value_offset)) = self.value_bytes_and_offset() else { + return Ok((LazyDecodedValue::Value(self.kind.default_value()), 0, None)); + }; let b = &bytes[value_offset..]; - let (v, rest) = match self.kind { - ContainerType::Text => RichtextState::decode_value(b)?, - ContainerType::Map => MapState::decode_value(b)?, - ContainerType::List => ListState::decode_value(b)?, - ContainerType::MovableList => MovableListState::decode_value(b)?, + let mut decoded_state = None; + let (v, state_offset) = match self.kind { + ContainerType::Text => { + let (v, rest) = RichtextState::decode_value(b)?; + ( + LazyDecodedValue::text(v), + b.len() - rest.len() + value_offset, + ) + } + ContainerType::Map => { + let (v, rest) = MapState::decode_value_as_btree_map(b)?; + ( + LazyDecodedValue::Map { + ordered: v, + value: OnceCell::new(), + }, + b.len() - rest.len() + value_offset, + ) + } + ContainerType::List => { + let (v, rest) = ListState::decode_value(b)?; + ( + LazyDecodedValue::Value(v), + b.len() - rest.len() + value_offset, + ) + } + ContainerType::MovableList => { + let (v, rest) = MovableListState::decode_value(b)?; + ( + LazyDecodedValue::Value(v), + b.len() - rest.len() + value_offset, + ) + } ContainerType::Tree => { let mut state = TreeState::decode_snapshot_fast(idx, (LoroValue::Null, b), ctx)?; - self.value = Some(state.get_value()); - self.state = Some(State::TreeState(Box::new(state))); - self.bytes_offset_for_state = Some(value_offset); - return Ok(()); + let value = state.get_value(); + decoded_state = Some(State::TreeState(Box::new(state))); + (LazyDecodedValue::Value(value), value_offset) } #[cfg(feature = "counter")] ContainerType::Counter => { - let (v, _rest) = CounterState::decode_value(b)?; - self.value = Some(v); - self.bytes_offset_for_state = Some(0); - return Ok(()); + let (v, rest) = CounterState::decode_value(b)?; + ( + LazyDecodedValue::Value(v), + b.len() - rest.len() + value_offset, + ) + } + ContainerType::Unknown(_) => { + let (v, rest) = UnknownState::decode_value(b)?; + ( + LazyDecodedValue::Value(v), + b.len() - rest.len() + value_offset, + ) } - ContainerType::Unknown(_) => UnknownState::decode_value(b)?, }; - self.value = Some(v); - let offset = b.len() - rest.len(); - self.bytes_offset_for_state = Some(offset + value_offset); - Ok(()) + Ok((v, state_offset, decoded_state)) } pub(super) fn decode_state( @@ -215,18 +622,29 @@ impl ContainerWrapper { idx: ContainerIdx, ctx: ContainerCreationContext, ) -> LoroResult<()> { - if self.state.is_some() { + if matches!(self.data, ContainerData::State(_)) { return Ok(()); } - if self.value.is_none() { + let need_value = match &self.data { + ContainerData::Lazy(lazy) => lazy.value.is_none(), + ContainerData::State(_) => false, + }; + if need_value { self.decode_value(idx, ctx)?; } - let b = self.bytes.as_ref().unwrap(); - let offset = self.bytes_offset_for_state.unwrap(); - let b = &b[offset..]; - let v = self.value.as_ref().unwrap().clone(); + if matches!(self.data, ContainerData::State(_)) { + return Ok(()); + } + + let ContainerData::Lazy(lazy) = &self.data else { + unreachable!(); + }; + let bytes = lazy.bytes.as_ref().unwrap(); + let offset = lazy.bytes_offset_for_state.unwrap(); + let b = &bytes[offset..]; + let v = lazy.value.as_ref().unwrap().to_loro_value(); let state: State = match self.kind { ContainerType::Text => RichtextState::decode_snapshot_fast(idx, (v, b), ctx)?.into(), ContainerType::Map => MapState::decode_snapshot_fast(idx, (v, b), ctx)?.into(), @@ -241,25 +659,46 @@ impl ContainerWrapper { UnknownState::decode_snapshot_fast(idx, (v, b), ctx)?.into() } }; - self.state = Some(state); + self.data = ContainerData::State(state); Ok(()) } #[allow(unused)] pub(crate) fn is_state_empty(&self) -> bool { - if let Some(state) = self.state.as_ref() { - return state.is_state_empty(); + match &self.data { + ContainerData::State(state) => state.is_state_empty(), + ContainerData::Lazy(lazy) => { + // FIXME: it's not very accurate... + lazy.bytes.as_ref().unwrap().len() > 10 + } } + } - // FIXME: it's not very accurate... - self.bytes.as_ref().unwrap().len() > 10 + pub(crate) fn is_deleted_root_value_cleared(&mut self) -> bool { + fn value_is_cleared(kind: ContainerType, value: &LoroValue) -> bool { + match kind { + ContainerType::Text => value.as_string().is_some_and(|value| value.is_empty()), + ContainerType::Map | ContainerType::List | ContainerType::MovableList => { + value.is_empty_collection() + } + ContainerType::Tree => value.as_list().is_some_and(|value| value.is_empty()), + #[cfg(feature = "counter")] + ContainerType::Counter => value.as_double().is_some_and(|value| *value == 0.0), + ContainerType::Unknown(_) => false, + } + } + + match &mut self.data { + ContainerData::State(state) => value_is_cleared(self.kind, &state.get_value()), + ContainerData::Lazy(lazy) => lazy + .value + .as_ref() + .is_some_and(|value| value_is_cleared(self.kind, &value.to_loro_value())), + } } pub(crate) fn clear_bytes(&mut self) { - assert!(self.state.is_some()); - self.bytes = None; - self.bytes_offset_for_state = None; - self.bytes_offset_for_value = None; + assert!(matches!(self.data, ContainerData::State(_))); } pub(crate) fn is_flushed(&self) -> bool { diff --git a/crates/loro-internal/src/state/container_store/inner_store.rs b/crates/loro-internal/src/state/container_store/inner_store.rs index 9ca5cb484..a4bcea935 100644 --- a/crates/loro-internal/src/state/container_store/inner_store.rs +++ b/crates/loro-internal/src/state/container_store/inner_store.rs @@ -4,7 +4,6 @@ use crate::{ }; use bytes::Bytes; use loro_common::ContainerID; -use rustc_hash::FxHashMap; use std::ops::Bound; use super::ContainerWrapper; @@ -12,17 +11,24 @@ use super::ContainerWrapper; /// The invariants about this struct: /// /// - `kv` is either the same or older than `store`. -/// - if `all_loaded` is true, then `store` contains all the entries from `kv` +/// - if `load_state` is `AllLoaded`, then `store` contains all the entries from `kv` /// /// Invariants: it should be agnostic to the users of this struct whether a container is stored in `kv` or `store` pub(crate) struct InnerStore { arena: SharedArena, - store: FxHashMap, + store: Vec>, kv: KvWrapper, - all_loaded: bool, + load_state: LoadState, config: Configure, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum LoadState { + Lazy, + RootsLoaded, + AllLoaded, +} + impl std::fmt::Debug for InnerStore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("InnerStore").finish() @@ -31,27 +37,72 @@ impl std::fmt::Debug for InnerStore { /// This impl block contains all the mutation code that may break the invariants of this struct impl InnerStore { + #[inline] + fn slot(idx: ContainerIdx) -> usize { + idx.to_index() as usize + } + + #[inline] + fn get_entry_mut_in( + store: &mut [Option], + idx: ContainerIdx, + ) -> Option<&mut ContainerWrapper> { + let entry = store.get_mut(Self::slot(idx))?.as_mut()?; + debug_assert_eq!(entry.kind(), idx.get_type()); + Some(entry) + } + + #[inline] + fn get_entry_mut(&mut self, idx: ContainerIdx) -> Option<&mut ContainerWrapper> { + Self::get_entry_mut_in(&mut self.store, idx) + } + + #[inline] + fn contains_idx_in(store: &[Option], idx: ContainerIdx) -> bool { + store + .get(Self::slot(idx)) + .and_then(|entry| entry.as_ref()) + .is_some_and(|entry| entry.kind() == idx.get_type()) + } + + #[inline] + fn contains_idx(&self, idx: ContainerIdx) -> bool { + Self::contains_idx_in(&self.store, idx) + } + + fn insert_entry( + store: &mut Vec>, + idx: ContainerIdx, + container: ContainerWrapper, + ) -> Option { + let slot = Self::slot(idx); + if store.len() <= slot { + store.resize_with(slot + 1, || None); + } + + store[slot].replace(container) + } + pub(super) fn get_or_insert_with( &mut self, idx: ContainerIdx, f: impl FnOnce() -> ContainerWrapper, ) -> &mut ContainerWrapper { - match self.store.entry(idx) { - std::collections::hash_map::Entry::Vacant(e) => { - let id = self.arena.get_container_id(idx).unwrap(); - let key = id.to_bytes(); - if !self.all_loaded { - if let Some(v) = self.kv.get(&key) { - let c = ContainerWrapper::new_from_bytes(v); - return e.insert(c); - } - } - - let c = f(); - e.insert(c) - } - std::collections::hash_map::Entry::Occupied(e) => e.into_mut(), + if self.get_entry_mut(idx).is_none() { + let id = self.arena.get_container_id(idx).unwrap(); + let key = id.to_bytes(); + let container = if self.load_state != LoadState::AllLoaded { + self.kv + .get(&key) + .map(ContainerWrapper::new_from_bytes) + .unwrap_or_else(f) + } else { + f() + }; + Self::insert_entry(&mut self.store, idx, container); } + + self.get_entry_mut(idx).unwrap() } pub(super) fn ensure_container( @@ -59,54 +110,77 @@ impl InnerStore { idx: ContainerIdx, f: impl FnOnce() -> ContainerWrapper, ) { - if self.store.contains_key(&idx) { + if self.contains_idx(idx) { return; } - if !self.all_loaded { + if self.load_state != LoadState::AllLoaded { let id = self.arena.get_container_id(idx).unwrap(); let key = id.to_bytes(); if let Some(v) = self.kv.get(&key) { let c = ContainerWrapper::new_from_bytes(v); - self.store.insert(idx, c); + Self::insert_entry(&mut self.store, idx, c); return; } } let c = f(); - self.store.insert(idx, c); + Self::insert_entry(&mut self.store, idx, c); } pub(crate) fn get_mut(&mut self, idx: ContainerIdx) -> Option<&mut ContainerWrapper> { - if let std::collections::hash_map::Entry::Vacant(e) = self.store.entry(idx) { - if !self.all_loaded { - let id = self.arena.get_container_id(idx).unwrap(); - let key = id.to_bytes(); - if let Some(v) = self.kv.get(&key) { - let c = ContainerWrapper::new_from_bytes(v); - e.insert(c); + if self.get_entry_mut(idx).is_none() && self.load_state != LoadState::AllLoaded { + let id = self.arena.get_container_id(idx).unwrap(); + let key = id.to_bytes(); + if let Some(v) = self.kv.get(&key) { + let c = ContainerWrapper::new_from_bytes(v); + Self::insert_entry(&mut self.store, idx, c); + } + } + + self.get_entry_mut(idx) + } + + pub(crate) fn with_container_for_read( + &mut self, + idx: ContainerIdx, + f: impl FnOnce(&mut ContainerWrapper) -> R, + ) -> Option { + if let Some(entry) = self.get_entry_mut(idx) { + return Some(f(entry)); + } + + if self.load_state != LoadState::AllLoaded { + let id = self.arena.get_container_id(idx).unwrap(); + let key = id.to_bytes(); + if let Some(v) = self.kv.get(&key) { + let mut container = ContainerWrapper::new_from_bytes(v); + let ans = f(&mut container); + if container.has_cached_value() { + Self::insert_entry(&mut self.store, idx, container); } + return Some(ans); } } - self.store.get_mut(&idx) + None + } + + pub(crate) fn has_decoded_state(&mut self, idx: ContainerIdx) -> bool { + self.get_entry_mut(idx) + .is_some_and(|entry| entry.try_get_state().is_some()) } pub(crate) fn contains_id(&mut self, id: &ContainerID) -> bool { if let Some(idx) = self.arena.id_to_idx(id) { - if self.store.contains_key(&idx) { + if self.contains_idx(idx) { return true; } } - if !self.all_loaded { + if self.load_state != LoadState::AllLoaded { let key = id.to_bytes(); - if let Some(v) = self.kv.get(&key) { - let idx = self.arena.register_container(id); - let c = ContainerWrapper::new_from_bytes(v); - self.store.insert(idx, c); - return true; - } + return self.kv.contains_key(&key); } false @@ -114,17 +188,30 @@ impl InnerStore { pub(crate) fn iter_all_containers_mut( &mut self, - ) -> impl Iterator { + ) -> impl Iterator { self.load_all(); - self.store.iter_mut() + self.store + .iter_mut() + .enumerate() + .filter_map(|(slot, entry)| { + entry.as_mut().map(|container| { + ( + ContainerIdx::from_index_and_type(slot as u32, container.kind()), + container, + ) + }) + }) } pub(crate) fn iter_all_container_ids(&mut self) -> impl Iterator + '_ { // PERF: we don't need to load all the containers here self.load_all(); - self.store - .keys() - .map(|idx| self.arena.get_container_id(*idx).unwrap()) + self.store.iter().enumerate().filter_map(|(slot, entry)| { + entry.as_ref().map(|container| { + let idx = ContainerIdx::from_index_and_type(slot as u32, container.kind()); + self.arena.get_container_id(idx).unwrap() + }) + }) } pub(crate) fn encode(&mut self) -> Bytes { @@ -134,22 +221,36 @@ impl InnerStore { pub(crate) fn flush(&mut self) { let deleted = self.config.deleted_root_containers.lock(); - self.kv - .set_all(self.store.iter_mut().filter_map(|(idx, c)| { - if c.is_flushed() { - return None; - } + let mut updates = Vec::new(); + let mut deleted_roots = Vec::new(); + + for (slot, entry) in self.store.iter_mut().enumerate() { + let Some(c) = entry.as_mut() else { + continue; + }; + let idx = ContainerIdx::from_index_and_type(slot as u32, c.kind()); + let cid = self.arena.get_container_id(idx).unwrap(); + if cid.is_root() && deleted.contains(&cid) && c.is_deleted_root_value_cleared() { + deleted_roots.push(cid.to_bytes()); + c.set_flushed(true); + continue; + } - let cid = self.arena.get_container_id(*idx).unwrap(); - if c.is_state_empty() && cid.is_root() && deleted.contains(&cid) { - return None; - } + if c.is_flushed() { + continue; + } - let cid: Bytes = cid.to_bytes().into(); - let value = c.encode(); - c.set_flushed(true); - Some((cid, value)) - })); + let cid: Bytes = cid.to_bytes().into(); + let value = c.encode(); + c.set_flushed(true); + updates.push((cid, value)); + } + + drop(deleted); + for cid in deleted_roots { + self.kv.remove(&cid); + } + self.kv.set_all(updates.into_iter()); } pub(crate) fn get_kv_clone(&self) -> KvWrapper { @@ -179,7 +280,7 @@ impl InnerStore { })); self.store.clear(); - self.all_loaded = false; + self.load_state = LoadState::Lazy; Ok(fr) } @@ -197,8 +298,10 @@ impl InnerStore { .import(bytes_b) .map_err(|e| loro_common::LoroError::DecodeError(e.into_boxed_str()))?; self.kv.remove(FRONTIERS_KEY); + let store = &mut self.store; + let arena = &self.arena; self.kv.with_kv(|kv| { - self.arena.with_guards(|guards| { + arena.with_guards(|guards| { let iter = kv.scan(Bound::Unbounded, Bound::Unbounded); for (k, v) in iter { let cid = ContainerID::from_bytes(&k); @@ -207,39 +310,61 @@ impl InnerStore { let idx = guards.register_container(&cid); let p = parent.as_ref().map(|p| guards.register_container(p)); guards.set_parent(idx, p); - if self.store.insert(idx, c).is_some() {} + if Self::insert_entry(store, idx, c).is_some() {} } }); }); - self.all_loaded = true; + self.load_state = LoadState::AllLoaded; Ok(()) } pub fn load_all(&mut self) { - if self.all_loaded { + if self.load_state == LoadState::AllLoaded { return; } + let store = &mut self.store; + let arena = &self.arena; self.kv.with_kv(|kv| { let iter = kv.scan(Bound::Unbounded, Bound::Unbounded); - self.arena.with_guards(|guards| { + arena.with_guards(|guards| { for (k, v) in iter { let cid = ContainerID::from_bytes(&k); let idx = guards.register_container(&cid); - if self.store.contains_key(&idx) { + if Self::contains_idx_in(store, idx) { // the container is already loaded // the content in `store` is guaranteed to be newer than the content in `kv` continue; } let container = ContainerWrapper::new_from_bytes(v); - self.store.insert(idx, container); + Self::insert_entry(store, idx, container); } }); }); - self.all_loaded = true; + self.load_state = LoadState::AllLoaded; + } + + pub fn load_roots(&mut self) { + if self.load_state != LoadState::Lazy { + return; + } + + let arena = &self.arena; + self.kv.with_kv(|kv| { + let iter = kv.scan(Bound::Unbounded, Bound::Unbounded); + arena.with_guards(|guards| { + for (k, _) in iter { + let cid = ContainerID::from_bytes(&k); + if cid.is_root() { + guards.register_container(&cid); + } + } + }); + }); + self.load_state = LoadState::RootsLoaded; } pub(crate) fn can_import_snapshot(&self) -> bool { @@ -247,7 +372,16 @@ impl InnerStore { return false; } - self.store.iter().all(|(_, c)| c.is_state_empty()) + self.store + .iter() + .filter_map(|entry| entry.as_ref()) + .all(|c| c.is_state_empty()) + } + + #[cfg(test)] + pub(super) fn has_cached_value_for_test(&mut self, idx: ContainerIdx) -> bool { + self.get_entry_mut(idx) + .is_some_and(|entry| entry.has_cached_value_for_test()) } } @@ -255,9 +389,9 @@ impl InnerStore { pub(crate) fn new(arena: SharedArena, config: Configure) -> Self { Self { arena, - store: FxHashMap::default(), + store: Vec::new(), kv: KvWrapper::new_mem(), - all_loaded: true, + load_state: LoadState::AllLoaded, config, } } diff --git a/crates/loro-internal/src/state/counter_state.rs b/crates/loro-internal/src/state/counter_state.rs index d47d24077..332b2f52d 100644 --- a/crates/loro-internal/src/state/counter_state.rs +++ b/crates/loro-internal/src/state/counter_state.rs @@ -97,6 +97,17 @@ mod snapshot { } fn decode_value(bytes: &[u8]) -> LoroResult<(LoroValue, &[u8])> { + // Builds without the counter feature can re-export an untouched counter as header only. + if bytes.is_empty() { + return Ok((LoroValue::Double(0.0), bytes)); + } + + if bytes.len() != 8 { + return Err(loro_common::LoroError::DecodeError( + "Decode counter value failed".to_string().into_boxed_str(), + )); + } + let value = f64::from_le_bytes(bytes[..8].try_into().unwrap()); Ok((LoroValue::Double(value), &bytes[8..])) } @@ -109,6 +120,12 @@ mod snapshot { where Self: Sized, { + if !v.1.is_empty() { + return Err(loro_common::LoroError::DecodeError( + "Decode counter state failed".to_string().into_boxed_str(), + )); + } + let mut counter = CounterState::new(idx); counter.value = *v.0.as_double().unwrap(); Ok(counter) diff --git a/crates/loro-internal/src/state/list_state.rs b/crates/loro-internal/src/state/list_state.rs index 626aebf4c..474d4b11d 100644 --- a/crates/loro-internal/src/state/list_state.rs +++ b/crates/loro-internal/src/state/list_state.rs @@ -18,12 +18,14 @@ use generic_btree::{ use loro_common::{IdFull, LoroError, LoroResult, ID}; use loro_delta::array_vec::ArrayVec; use rustc_hash::FxHashMap; +use smallvec::SmallVec; + +const TINY_LIST_MAX: usize = 4; #[derive(Debug)] pub struct ListState { idx: ContainerIdx, - list: BTree, - child_container_to_leaf: FxHashMap, + list: ListEntries, } impl Clone for ListState { @@ -31,7 +33,6 @@ impl Clone for ListState { Self { idx: self.idx, list: self.list.clone(), - child_container_to_leaf: self.child_container_to_leaf.clone(), } } } @@ -139,38 +140,159 @@ impl UseLengthFinder for ListImpl { } } +#[derive(Debug, Clone)] +enum ListEntries { + Tiny { + entries: SmallVec<[Elem; 1]>, + child_container_to_index: FxHashMap, + }, + Tree { + tree: BTree, + child_container_to_index: FxHashMap, + }, +} + +impl ListEntries { + fn new_tiny() -> Self { + Self::Tiny { + entries: SmallVec::new(), + child_container_to_index: Default::default(), + } + } + + fn is_empty(&self) -> bool { + match self { + Self::Tiny { entries, .. } => entries.is_empty(), + Self::Tree { tree, .. } => tree.is_empty(), + } + } + + fn len(&self) -> usize { + match self { + Self::Tiny { entries, .. } => entries.len(), + Self::Tree { tree, .. } => *tree.root_cache() as usize, + } + } + + fn iter(&self) -> Box + '_> { + match self { + Self::Tiny { entries, .. } => Box::new(entries.iter()), + Self::Tree { tree, .. } => Box::new(tree.iter()), + } + } +} + impl ListState { pub fn new(idx: ContainerIdx) -> Self { - let tree = BTree::new(); Self { idx, - list: tree, - child_container_to_leaf: Default::default(), + list: ListEntries::new_tiny(), } } - pub fn contains_child_container(&self, id: &ContainerID) -> bool { - let Some(&leaf) = self.child_container_to_leaf.get(id) else { - return false; + fn rebuild_tiny_child_index(&mut self) { + let ListEntries::Tiny { + entries, + child_container_to_index, + } = &mut self.list + else { + return; + }; + + child_container_to_index.clear(); + for (index, elem) in entries.iter().enumerate() { + if let LoroValue::Container(id) = &elem.v { + child_container_to_index.insert(id.clone(), index); + } + } + } + + fn upgrade_to_tree(&mut self) { + let ListEntries::Tiny { entries, .. } = + std::mem::replace(&mut self.list, ListEntries::new_tiny()) + else { + return; }; - self.list.get_elem(leaf).is_some() + let mut tree = BTree::new(); + let mut child_container_to_index = FxHashMap::default(); + for elem in entries { + let container = elem.v.as_container().cloned(); + let leaf = tree.push(elem); + if let Some(container) = container { + child_container_to_index.insert(container, leaf.leaf); + } + } + + self.list = ListEntries::Tree { + tree, + child_container_to_index, + }; + } + + fn downgrade_tree_if_tiny(&mut self) { + let ListEntries::Tree { tree, .. } = &self.list else { + return; + }; + + if *tree.root_cache() as usize > TINY_LIST_MAX { + return; + } + + let entries = tree.iter().cloned().collect(); + self.list = ListEntries::Tiny { + entries, + child_container_to_index: Default::default(), + }; + self.rebuild_tiny_child_index(); + } + + pub fn contains_child_container(&self, id: &ContainerID) -> bool { + match &self.list { + ListEntries::Tiny { + entries, + child_container_to_index, + } => child_container_to_index.get(id).is_some_and(|index| { + entries + .get(*index) + .is_some_and(|elem| elem.v.as_container() == Some(id)) + }), + ListEntries::Tree { + tree, + child_container_to_index, + } => child_container_to_index + .get(id) + .is_some_and(|leaf| tree.get_elem(*leaf).is_some()), + } } pub fn get_child_container_index(&self, id: &ContainerID) -> Option { - let leaf = *self.child_container_to_leaf.get(id)?; - self.list.get_elem(leaf)?; + let (leaf, list) = match &self.list { + ListEntries::Tiny { + entries, + child_container_to_index, + } => { + let index = *child_container_to_index.get(id)?; + entries.get(index)?; + return Some(index); + } + ListEntries::Tree { + tree, + child_container_to_index, + } => (*child_container_to_index.get(id)?, tree), + }; + + list.get_elem(leaf)?; let mut index = 0; - self.list - .visit_previous_caches(Cursor { leaf, offset: 0 }, |cache| match cache { - generic_btree::PreviousCache::NodeCache(cache) => { - index += *cache; - } - generic_btree::PreviousCache::PrevSiblingElem(..) => { - index += 1; - } - generic_btree::PreviousCache::ThisElemAndOffset { .. } => {} - }); + list.visit_previous_caches(Cursor { leaf, offset: 0 }, |cache| match cache { + generic_btree::PreviousCache::NodeCache(cache) => { + index += *cache; + } + generic_btree::PreviousCache::PrevSiblingElem(..) => { + index += 1; + } + generic_btree::PreviousCache::ThisElemAndOffset { .. } => {} + }); Some(index as usize) } @@ -180,20 +302,30 @@ impl ListState { panic!("Index {index} out of range. The length is {}", self.len()); } - if self.list.is_empty() { - let idx = self.list.push(Elem { - v: value.clone(), - id, - }); - - if value.is_container() { - self.child_container_to_leaf - .insert(value.into_container().unwrap(), idx.leaf); + if let ListEntries::Tiny { entries, .. } = &mut self.list { + if entries.len() < TINY_LIST_MAX { + entries.insert( + index, + Elem { + v: value.clone(), + id, + }, + ); + self.rebuild_tiny_child_index(); + return; } - return; + + self.upgrade_to_tree(); } - let (leaf, data) = self.list.insert::( + let ListEntries::Tree { + tree, + child_container_to_index, + } = &mut self.list + else { + unreachable!() + }; + let (leaf, data) = tree.insert::( &index, Elem { v: value.clone(), @@ -202,45 +334,63 @@ impl ListState { ); if value.is_container() { - self.child_container_to_leaf - .insert(value.into_container().unwrap(), leaf.leaf); + child_container_to_index.insert(value.into_container().unwrap(), leaf.leaf); } assert!(data.arr.is_empty()); } pub fn push(&mut self, value: LoroValue, id: IdFull) { - if self.list.is_empty() { - let idx = self.list.push(Elem { - v: value.clone(), - id, - }); - - if value.is_container() { - self.child_container_to_leaf - .insert(value.into_container().unwrap(), idx.leaf); + if let ListEntries::Tiny { entries, .. } = &mut self.list { + if entries.len() < TINY_LIST_MAX { + entries.push(Elem { + v: value.clone(), + id, + }); + self.rebuild_tiny_child_index(); + return; } - return; + + self.upgrade_to_tree(); } - let leaf = self.list.push(Elem { + let ListEntries::Tree { + tree, + child_container_to_index, + } = &mut self.list + else { + unreachable!() + }; + let leaf = tree.push(Elem { v: value.clone(), id, }); if value.is_container() { - self.child_container_to_leaf - .insert(value.into_container().unwrap(), leaf.leaf); + child_container_to_index.insert(value.into_container().unwrap(), leaf.leaf); } } pub fn delete(&mut self, index: usize) -> LoroValue { - let leaf = self.list.query::(&index); - let leaf = self.list.remove_leaf(leaf.unwrap().cursor).unwrap(); + if let ListEntries::Tiny { entries, .. } = &mut self.list { + let elem = entries.remove(index); + self.rebuild_tiny_child_index(); + return elem.v; + } + + let ListEntries::Tree { + tree, + child_container_to_index, + } = &mut self.list + else { + unreachable!() + }; + let leaf = tree.query::(&index); + let leaf = tree.remove_leaf(leaf.unwrap().cursor).unwrap(); if leaf.v.is_container() { - self.child_container_to_leaf - .remove(leaf.v.as_container().unwrap()); + child_container_to_index.remove(leaf.v.as_container().unwrap()); } + self.downgrade_tree_if_tiny(); leaf.v } @@ -259,6 +409,19 @@ impl ListState { std::ops::Bound::Excluded(x) => *x, std::ops::Bound::Unbounded => self.len(), }; + + if let ListEntries::Tiny { entries, .. } = &mut self.list { + for elem in entries.drain(start..end) { + if let LoroValue::Container(c) = elem.v { + if let Some(notify_deletion) = &mut notify_deletion { + notify_deletion.push(c); + } + } + } + self.rebuild_tiny_child_index(); + return; + } + if end - start == 1 { if let LoroValue::Container(c) = self.delete(start) { if let Some(notify_deletion) = &mut notify_deletion { @@ -268,19 +431,25 @@ impl ListState { return; } - let list = &mut self.list; + let ListEntries::Tree { + tree, + child_container_to_index, + } = &mut self.list + else { + unreachable!() + }; let q = start..end; - let start1 = list.query::(&q.start); - let end1 = list.query::(&q.end); - for v in iter::Drain::new(list, start1, end1) { + let start1 = tree.query::(&q.start); + let end1 = tree.query::(&q.end); + for v in iter::Drain::new(tree, start1, end1) { if v.v.is_container() { - self.child_container_to_leaf - .remove(v.v.as_container().unwrap()); + child_container_to_index.remove(v.v.as_container().unwrap()); if let Some(notify_deletion) = &mut notify_deletion { notify_deletion.push(v.v.into_container().unwrap()); } } } + self.downgrade_tree_if_tiny(); } // PERF: use &[LoroValue] @@ -294,16 +463,16 @@ impl ListState { } pub fn iter(&self) -> impl Iterator { - self.list.iter().map(|x| &x.v) + self.iter_with_id().map(|x| &x.v) } #[allow(unused)] - pub(crate) fn iter_with_id(&self) -> impl Iterator { + pub(crate) fn iter_with_id(&self) -> Box + '_> { self.list.iter() } pub fn len(&self) -> usize { - *self.list.root_cache() as usize + self.list.len() } fn to_vec(&self) -> Vec { @@ -315,20 +484,30 @@ impl ListState { } pub fn get(&self, index: usize) -> Option<&LoroValue> { - let result = self.list.query::(&index)?; - if result.found { - Some(&result.elem(&self.list).unwrap().v) - } else { - None + match &self.list { + ListEntries::Tiny { entries, .. } => entries.get(index).map(|elem| &elem.v), + ListEntries::Tree { tree, .. } => { + let result = tree.query::(&index)?; + if result.found { + Some(&result.elem(tree).unwrap().v) + } else { + None + } + } } } pub fn get_id_at(&self, index: usize) -> Option { - let result = self.list.query::(&index)?; - if result.found { - Some(result.elem(&self.list).unwrap().id) - } else { - None + match &self.list { + ListEntries::Tiny { entries, .. } => entries.get(index).map(|elem| elem.id), + ListEntries::Tree { tree, .. } => { + let result = tree.query::(&index)?; + if result.found { + Some(result.elem(tree).unwrap().id) + } else { + None + } + } } } @@ -551,7 +730,7 @@ impl ContainerState for ListState { fn get_child_containers(&self) -> Vec { let mut ans = Vec::new(); - for elem in self.list.iter() { + for elem in self.iter_with_id() { if elem.v.is_container() { ans.push(elem.v.as_container().unwrap().clone()); } @@ -565,12 +744,16 @@ impl ContainerState for ListState { } mod snapshot { - use std::io::Read; - - use loro_common::{Counter, Lamport, PeerID}; + use loro_common::PeerID; use serde_columnar::columnar; - use crate::{encoding::value_register::ValueRegister, state::ContainerCreationContext}; + use crate::{ + encoding::value_register::ValueRegister, + state::{ + decode_lamport_from_delta, decode_peer_from_table, decode_peer_table, + state_decode_error, ContainerCreationContext, + }, + }; use super::*; #[columnar(vec, ser, de, iterable)] @@ -640,26 +823,33 @@ mod snapshot { where Self: Sized, { - let peer_num = leb128::read::unsigned(&mut bytes).unwrap() as usize; - let mut peers = Vec::with_capacity(peer_num); - for _ in 0..peer_num { - let mut buf = [0u8; 8]; - bytes.read_exact(&mut buf).unwrap(); - peers.push(PeerID::from_le_bytes(buf)); - } + let peers = decode_peer_table(&mut bytes, "Decode list state failed")?; - let EncodedListIds { ids } = serde_columnar::from_bytes(bytes).unwrap(); + let EncodedListIds { ids } = serde_columnar::from_bytes(bytes).map_err(|err| { + state_decode_error(format!("Decode list state failed: invalid id table: {err}")) + })?; let list = v.as_list().unwrap(); + if ids.len() != list.len() { + return Err(state_decode_error( + "Decode list state failed: id/value length mismatch", + )); + } + let mut ans = Self::new(idx); for (i, id) in ids.into_iter().enumerate() { + let peer = decode_peer_from_table(&peers, id.peer_idx, "Decode list state failed")?; ans.insert( i, list[i].clone(), IdFull::new( - peers[id.peer_idx], - id.counter as Counter, - (id.lamport_sub_counter + id.counter) as Lamport, + peer, + id.counter, + decode_lamport_from_delta( + id.counter, + id.lamport_sub_counter, + "Decode list state failed", + )?, ), ); } @@ -667,6 +857,45 @@ mod snapshot { Ok(ans) } } + + #[cfg(test)] + mod tests { + use loro_common::LoroValue; + + use crate::{container::idx::ContainerIdx, state::ContainerCreationContext}; + + use super::*; + + #[test] + fn list_fast_snapshot_rejects_negative_counter() { + let mut bytes = Vec::new(); + postcard::to_io(&vec![LoroValue::I64(1)], &mut bytes).unwrap(); + leb128::write::unsigned(&mut bytes, 1).unwrap(); + bytes.extend_from_slice(&1_u64.to_le_bytes()); + bytes.extend_from_slice( + &serde_columnar::to_vec(&EncodedListIds { + ids: vec![EncodedListId { + peer_idx: 0, + counter: -1, + lamport_sub_counter: 1, + }], + }) + .unwrap(), + ); + + let idx = ContainerIdx::from_index_and_type(0, loro_common::ContainerType::List); + let (value, state_bytes) = ListState::decode_value(&bytes).unwrap(); + assert!(ListState::decode_snapshot_fast( + idx, + (value, state_bytes), + ContainerCreationContext { + configure: &Default::default(), + peer: 0, + }, + ) + .is_err()); + } + } } #[cfg(test)] @@ -741,4 +970,22 @@ mod test { assert_eq!(v[2].id.counter, 2 as Counter); assert_eq!(v[2].id.lamport, 2 as Lamport); } + + #[test] + fn list_fast_snapshot_rejects_corrupt_state_metadata() { + let idx = ContainerIdx::from_index_and_type(0, loro_common::ContainerType::List); + let ctx = ContainerCreationContext { + configure: &Default::default(), + peer: 0, + }; + + let value = LoroValue::from(vec![LoroValue::I64(1)]); + assert!(ListState::decode_snapshot_fast(idx, (value.clone(), &[1]), ctx).is_err()); + + let mut empty = ListState::new(idx); + let mut bytes = Vec::new(); + empty.encode_snapshot_fast(&mut bytes); + let (_, state_bytes) = ListState::decode_value(&bytes).unwrap(); + assert!(ListState::decode_snapshot_fast(idx, (value, state_bytes), ctx).is_err()); + } } diff --git a/crates/loro-internal/src/state/map_state.rs b/crates/loro-internal/src/state/map_state.rs index a8863eb27..0ff7f0b5d 100644 --- a/crates/loro-internal/src/state/map_state.rs +++ b/crates/loro-internal/src/state/map_state.rs @@ -2,6 +2,7 @@ use std::{collections::BTreeMap, sync::Weak}; use loro_common::{ContainerID, IdLp, LoroResult, PeerID}; use rustc_hash::FxHashMap; +use smallvec::SmallVec; use crate::{ configure::Configure, @@ -16,11 +17,196 @@ use crate::{ use super::{ApplyLocalOpReturn, ContainerState, DiffApplyContext}; +const TINY_MAP_MAX: usize = 4; + +#[derive(Debug, Clone)] +enum MapEntries { + SortedTiny(SmallVec<[(InternalString, MapValue); 1]>), + Tree(BTreeMap), +} + +impl Default for MapEntries { + fn default() -> Self { + Self::SortedTiny(SmallVec::new()) + } +} + +enum MapEntriesIter<'a> { + SortedTiny(std::slice::Iter<'a, (InternalString, MapValue)>), + Tree(std::collections::btree_map::Iter<'a, InternalString, MapValue>), +} + +impl<'a> Iterator for MapEntriesIter<'a> { + type Item = (&'a InternalString, &'a MapValue); + + fn next(&mut self) -> Option { + match self { + Self::SortedTiny(iter) => iter.next().map(|(key, value)| (key, value)), + Self::Tree(iter) => iter.next(), + } + } +} + +impl MapEntries { + fn debug_assert_sorted_tiny(entries: &[(InternalString, MapValue)]) { + debug_assert!(entries.windows(2).all(|pair| { + let left = &pair[0].0; + let right = &pair[1].0; + left < right + })); + } + + fn is_empty(&self) -> bool { + match self { + Self::SortedTiny(entries) => entries.is_empty(), + Self::Tree(map) => map.is_empty(), + } + } + + fn get(&self, key: &InternalString) -> Option<&MapValue> { + match self { + Self::SortedTiny(entries) => { + Self::debug_assert_sorted_tiny(entries); + entries + .binary_search_by(|(entry_key, _)| entry_key.cmp(key)) + .ok() + .map(|index| &entries[index].1) + } + Self::Tree(map) => map.get(key), + } + } + + fn insert(&mut self, key: InternalString, value: MapValue) -> Option { + match self { + Self::SortedTiny(entries) => { + Self::debug_assert_sorted_tiny(entries); + match entries.binary_search_by(|(entry_key, _)| entry_key.cmp(&key)) { + Ok(index) => return Some(std::mem::replace(&mut entries[index].1, value)), + Err(index) if entries.len() < TINY_MAP_MAX => { + entries.insert(index, (key, value)); + Self::debug_assert_sorted_tiny(entries); + return None; + } + Err(_) => {} + } + + let mut map = BTreeMap::new(); + for (key, value) in entries.drain(..) { + map.insert(key, value); + } + let result = map.insert(key, value); + *self = Self::Tree(map); + result + } + Self::Tree(map) => map.insert(key, value), + } + } + + fn remove(&mut self, key: &InternalString) -> Option { + match self { + Self::SortedTiny(entries) => { + Self::debug_assert_sorted_tiny(entries); + entries + .binary_search_by(|(entry_key, _)| entry_key.cmp(key)) + .ok() + .map(|index| entries.remove(index).1) + } + Self::Tree(map) => { + let result = map.remove(key); + if result.is_some() && map.len() <= TINY_MAP_MAX { + let entries = std::mem::take(map).into_iter().collect(); + *self = Self::SortedTiny(entries); + } + result + } + } + } + + fn iter(&self) -> MapEntriesIter<'_> { + match self { + Self::SortedTiny(entries) => { + Self::debug_assert_sorted_tiny(entries); + MapEntriesIter::SortedTiny(entries.iter()) + } + Self::Tree(map) => MapEntriesIter::Tree(map.iter()), + } + } +} + +#[derive(Debug, Clone)] +enum ChildContainers { + Tiny(SmallVec<[(ContainerID, InternalString); 1]>), + Map(FxHashMap), +} + +impl Default for ChildContainers { + fn default() -> Self { + Self::Tiny(SmallVec::new()) + } +} + +impl ChildContainers { + fn get(&self, id: &ContainerID) -> Option<&InternalString> { + match self { + Self::Tiny(entries) => entries + .iter() + .find_map(|(entry_id, key)| (entry_id == id).then_some(key)), + Self::Map(map) => map.get(id), + } + } + + fn contains_key(&self, id: &ContainerID) -> bool { + self.get(id).is_some() + } + + fn insert(&mut self, id: ContainerID, key: InternalString) -> Option { + match self { + Self::Tiny(entries) => { + if let Some((_, old_key)) = entries.iter_mut().find(|(entry_id, _)| entry_id == &id) + { + return Some(std::mem::replace(old_key, key)); + } + + if entries.len() < TINY_MAP_MAX { + entries.push((id, key)); + return None; + } + + let mut map = FxHashMap::default(); + for (id, key) in entries.drain(..) { + map.insert(id, key); + } + let result = map.insert(id, key); + *self = Self::Map(map); + result + } + Self::Map(map) => map.insert(id, key), + } + } + + fn remove(&mut self, id: &ContainerID) -> Option { + match self { + Self::Tiny(entries) => entries + .iter() + .position(|(entry_id, _)| entry_id == id) + .map(|index| entries.swap_remove(index).1), + Self::Map(map) => { + let result = map.remove(id); + if result.is_some() && map.len() <= TINY_MAP_MAX { + let entries = std::mem::take(map).into_iter().collect(); + *self = Self::Tiny(entries); + } + result + } + } + } +} + #[derive(Debug, Clone)] pub struct MapState { idx: ContainerIdx, - map: BTreeMap, - child_containers: FxHashMap, + map: MapEntries, + child_containers: ChildContainers, size: usize, } @@ -119,9 +305,8 @@ impl ContainerState for MapState { Diff::Map(ResolvedMapDelta { updated: self .map - .clone() - .into_iter() - .map(|(k, v)| (k, ResolvedMapValue::from_map_value(v, doc))) + .iter() + .map(|(k, v)| (k.clone(), ResolvedMapValue::from_map_value(v.clone(), doc))) .collect::>(), }) } @@ -176,18 +361,14 @@ impl MapState { } match (&result, value_yes) { - (Some(x), true) => { - if x.value.is_none() { - self.size += 1; - } + (Some(x), true) if x.value.is_none() => { + self.size += 1; } (None, true) => { self.size += 1; } - (Some(x), false) => { - if x.value.is_some() { - self.size -= 1; - } + (Some(x), false) if x.value.is_some() => { + self.size -= 1; } _ => {} }; @@ -207,7 +388,7 @@ impl MapState { }; } - pub fn iter(&self) -> std::collections::btree_map::Iter<'_, InternalString, MapValue> { + pub fn iter(&self) -> impl Iterator { self.map.iter() } @@ -245,6 +426,8 @@ impl MapState { mod snapshot { + use std::collections::BTreeMap; + use loro_common::{InternalString, LoroValue}; use rustc_hash::{FxHashMap, FxHashSet}; use serde_columnar::Itertools; @@ -252,7 +435,10 @@ mod snapshot { use crate::{ delta::MapValue, encoding::value_register::ValueRegister, - state::{ContainerCreationContext, ContainerState, FastStateSnapshot}, + state::{ + decode_peer_from_table, decode_peer_table, read_state_leb_u64, state_decode_error, + ContainerCreationContext, ContainerState, FastStateSnapshot, + }, }; use super::MapState; @@ -274,7 +460,7 @@ mod snapshot { .collect_vec(); postcard::to_io(&keys_with_none_value, &mut w).unwrap(); let mut peer_register = ValueRegister::new(); - for v in self.map.values() { + for (_, v) in self.map.iter() { peer_register.register(&v.peer); } @@ -282,7 +468,7 @@ mod snapshot { for p in peer_register.vec() { w.write_all(&p.to_le_bytes()).unwrap(); } - let mut keys: Vec<&InternalString> = self.map.keys().collect(); + let mut keys: Vec<&InternalString> = self.map.iter().map(|(key, _)| key).collect(); keys.sort_unstable(); for key in keys.into_iter() { let value = self.map.get(key).unwrap(); @@ -323,13 +509,7 @@ mod snapshot { let keys_with_none_value: FxHashSet<_> = keys_with_none_value.into_iter().collect(); // peers - let peer_count = leb128::read::unsigned(&mut bytes).unwrap() as usize; - let mut peers = Vec::with_capacity(peer_count); - for _ in 0..peer_count { - let peer = u64::from_le_bytes(bytes[..8].try_into().unwrap()); - bytes = &bytes[8..]; - peers.push(peer); - } + let peers = decode_peer_table(&mut bytes, "Decode map state failed")?; // let mut ans = MapState::new(idx); @@ -338,9 +518,17 @@ mod snapshot { keys.sort_unstable(); for key in keys { - let peer_idx = leb128::read::unsigned(&mut bytes).unwrap() as usize; - let lamp = leb128::read::unsigned(&mut bytes).unwrap() as u32; - let peer = peers[peer_idx]; + let peer_idx = + usize::try_from(read_state_leb_u64(&mut bytes, "Decode map state failed")?) + .map_err(|_| { + state_decode_error("Decode map state failed: peer index overflow") + })?; + let lamp = + u32::try_from(read_state_leb_u64(&mut bytes, "Decode map state failed")?) + .map_err(|_| { + state_decode_error("Decode map state failed: lamport overflow") + })?; + let peer = decode_peer_from_table(&peers, peer_idx, "Decode map state failed")?; if keys_with_none_value.contains(&key) { ans.insert( @@ -364,10 +552,30 @@ mod snapshot { } } + if !bytes.is_empty() { + return Err(loro_common::LoroError::DecodeError( + "Decode map state failed".to_string().into_boxed_str(), + )); + } + Ok(ans) } } + impl MapState { + pub(crate) fn decode_value_as_btree_map( + bytes: &[u8], + ) -> loro_common::LoroResult<(BTreeMap, &[u8])> { + let (value, bytes) = postcard::take_from_bytes::>(bytes) + .map_err(|_| { + loro_common::LoroError::DecodeError( + "Decode map value failed".to_string().into_boxed_str(), + ) + })?; + Ok((value, bytes)) + } + } + #[cfg(test)] mod map_snapshot_test { use loro_common::LoroValue; @@ -447,5 +655,24 @@ mod snapshot { } ); } + + #[test] + fn map_fast_snapshot_rejects_corrupt_state_metadata() { + let idx = ContainerIdx::from_index_and_type(0, loro_common::ContainerType::Map); + let ctx = ContainerCreationContext { + configure: &Default::default(), + peer: 0, + }; + let value = LoroValue::from(std::collections::HashMap::from([( + "key".to_string(), + LoroValue::I64(1), + )])); + + let mut empty = MapState::new(idx); + let mut bytes = Vec::new(); + empty.encode_snapshot_fast(&mut bytes); + let (_, state_bytes) = MapState::decode_value(&bytes).unwrap(); + assert!(MapState::decode_snapshot_fast(idx, (value, state_bytes), ctx).is_err()); + } } } diff --git a/crates/loro-internal/src/state/movable_list_state.rs b/crates/loro-internal/src/state/movable_list_state.rs index 93df44855..ea8f89ba1 100644 --- a/crates/loro-internal/src/state/movable_list_state.rs +++ b/crates/loro-internal/src/state/movable_list_state.rs @@ -1133,73 +1133,71 @@ impl ContainerState for MovableListState { match self.inner.elements().get(&elem_id).cloned() { Some(elem) => { // Update value if needed - if value_id.is_some() - && elem.value != value - && (!need_compare || elem.value_id < value_id.unwrap()) - { - maybe_moved.remove(&elem_id); - self.inner - .update_value(elem_id, value.clone(), value_id.unwrap()); - let index = self.get_index_of_elem(elem_id); - if let Some(index) = index { - event.compose( - &DeltaRopeBuilder::new() - .retain(index, Default::default()) - .delete(1) - .insert( - ArrayVec::from([ValueOrHandler::from_value( - value, doc, - )]), - ListDeltaMeta { from_move: false }, - ) - .build(), - ) + if let Some(value_id) = value_id { + if elem.value != value && (!need_compare || elem.value_id < value_id) { + maybe_moved.remove(&elem_id); + self.inner.update_value(elem_id, value.clone(), value_id); + let index = self.get_index_of_elem(elem_id); + if let Some(index) = index { + event.compose( + &DeltaRopeBuilder::new() + .retain(index, Default::default()) + .delete(1) + .insert( + ArrayVec::from([ValueOrHandler::from_value( + value, doc, + )]), + ListDeltaMeta { from_move: false }, + ) + .build(), + ) + } } } // Update pos if needed - if pos.is_some() - && elem.pos != pos.unwrap() - && (!need_compare || elem.pos < pos.unwrap()) - { - // don't need to update old list item, because it's handled by list diff already - let result = self.inner.update_pos(elem_id, pos.unwrap(), false); - let result = self.inner.convert_update_to_event_pos(result); - if let Some(new_index) = result.insert { - let new_value = - self.elements().get(&elem_id).unwrap().value.clone(); - let from_delete = if let Some((_elem_index, elem_old_value)) = - maybe_moved.remove(&elem_id) - { - elem_old_value == new_value - } else { - false - }; - let new_delta: ListDiff = DeltaRopeBuilder::new() - .retain(new_index, Default::default()) - .insert( - ArrayVec::from([ValueOrHandler::from_value( - new_value, doc, - )]), - ListDeltaMeta { - from_move: (result.delete.is_some() && !value_updated) - || from_delete, - }, - ) - .build(); - event.compose(&new_delta); - } - if let Some(del_index) = result.delete { - event.compose( - &DeltaRopeBuilder::new() - .retain(del_index, Default::default()) - .delete(1) - .build(), - ); - } - if !result.activate_new_list_item { - // not matched list item found, remove directly - self.inner.remove_elem_by_id(&elem_id); + if let Some(pos) = pos { + if elem.pos != pos && (!need_compare || elem.pos < pos) { + // don't need to update old list item, because it's handled by list diff already + let result = self.inner.update_pos(elem_id, pos, false); + let result = self.inner.convert_update_to_event_pos(result); + if let Some(new_index) = result.insert { + let new_value = + self.elements().get(&elem_id).unwrap().value.clone(); + let from_delete = if let Some((_elem_index, elem_old_value)) = + maybe_moved.remove(&elem_id) + { + elem_old_value == new_value + } else { + false + }; + let new_delta: ListDiff = DeltaRopeBuilder::new() + .retain(new_index, Default::default()) + .insert( + ArrayVec::from([ValueOrHandler::from_value( + new_value, doc, + )]), + ListDeltaMeta { + from_move: (result.delete.is_some() + && !value_updated) + || from_delete, + }, + ) + .build(); + event.compose(&new_delta); + } + if let Some(del_index) = result.delete { + event.compose( + &DeltaRopeBuilder::new() + .retain(del_index, Default::default()) + .delete(1) + .build(), + ); + } + if !result.activate_new_list_item { + // not matched list item found, remove directly + self.inner.remove_elem_by_id(&elem_id); + } } } } @@ -1432,13 +1430,14 @@ struct EncodedFastSnapshot { } mod snapshot { - use std::io::Read; - use loro_common::{IdFull, IdLp, LoroValue, PeerID}; use crate::{ encoding::value_register::ValueRegister, - state::{ContainerCreationContext, ContainerState, FastStateSnapshot}, + state::{ + decode_lamport_from_delta, decode_peer_from_table, decode_peer_table, + state_decode_error, ContainerCreationContext, ContainerState, FastStateSnapshot, + }, }; use super::{ @@ -1550,59 +1549,121 @@ mod snapshot { where Self: Sized, { - let peer_num = leb128::read::unsigned(&mut bytes).unwrap() as usize; - let mut peers = Vec::with_capacity(peer_num); - for _ in 0..peer_num { - let mut buf = [0u8; 8]; - bytes.read_exact(&mut buf).unwrap(); - peers.push(PeerID::from_le_bytes(buf)); - } + let peers = decode_peer_table(&mut bytes, "Decode movable list state failed")?; let mut ans = MovableListState::new(idx); - let iters = serde_columnar::iter_from_bytes::(bytes).unwrap(); + let iters = + serde_columnar::iter_from_bytes::(bytes).map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid metadata: {err}" + )) + })?; let mut elem_iter = iters.elem_ids; let item_iter = iters.items; let mut list_item_id_iter = iters.list_item_ids; let mut last_set_id_iter = iters.last_set_ids; let mut is_first = true; - let list_value = list_value.into_list().unwrap(); + let list_value = list_value.into_list().map_err(|_| { + state_decode_error("Decode movable list state failed: value is not a list") + })?; let mut list_value_iter = list_value.iter(); for item in item_iter { let EncodedItemForFastSnapshot { invisible_list_item, pos_id_eq_elem_id, elem_id_eq_last_set_id, - } = item.unwrap(); + } = item.map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid item: {err}" + )) + })?; if !is_first { let EncodedIdFull { peer_idx, counter, lamport_sub_counter, - } = list_item_id_iter.next().unwrap().unwrap(); + } = list_item_id_iter + .next() + .ok_or_else(|| { + state_decode_error( + "Decode movable list state failed: missing list item id", + ) + })? + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid list item id: {err}" + )) + })?; + let peer = decode_peer_from_table( + &peers, + peer_idx, + "Decode movable list state failed", + )?; let id_full = IdFull::new( - peers[peer_idx], + peer, counter, - (lamport_sub_counter + counter) as u32, + decode_lamport_from_delta( + counter, + lamport_sub_counter, + "Decode movable list state failed", + )?, ); let elem_id = if pos_id_eq_elem_id { id_full.idlp() } else { - let EncodedId { peer_idx, lamport } = elem_iter.next().unwrap().unwrap(); - IdLp::new(peers[peer_idx], lamport) + let EncodedId { peer_idx, lamport } = elem_iter + .next() + .ok_or_else(|| { + state_decode_error( + "Decode movable list state failed: missing element id", + ) + })? + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid element id: {err}" + )) + })?; + IdLp::new( + decode_peer_from_table( + &peers, + peer_idx, + "Decode movable list state failed", + )?, + lamport, + ) }; let last_set_id = if elem_id_eq_last_set_id { elem_id } else { - let EncodedId { peer_idx, lamport } = - last_set_id_iter.next().unwrap().unwrap(); - IdLp::new(peers[peer_idx], lamport) + let EncodedId { peer_idx, lamport } = last_set_id_iter + .next() + .ok_or_else(|| { + state_decode_error( + "Decode movable list state failed: missing last set id", + ) + })? + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid last set id: {err}" + )) + })?; + IdLp::new( + decode_peer_from_table( + &peers, + peer_idx, + "Decode movable list state failed", + )?, + lamport, + ) }; - let value = list_value_iter.next().unwrap(); + let value = list_value_iter.next().ok_or_else(|| { + state_decode_error("Decode movable list state failed: missing list value") + })?; ans.inner.push_inner( id_full, Some(PushElemInfo { @@ -1619,20 +1680,88 @@ mod snapshot { peer_idx, counter, lamport_sub_counter, - } = list_item_id_iter.next().unwrap().unwrap(); + } = list_item_id_iter + .next() + .ok_or_else(|| { + state_decode_error( + "Decode movable list state failed: missing invisible list item id", + ) + })? + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid invisible list item id: {err}" + )) + })?; + let peer = decode_peer_from_table( + &peers, + peer_idx, + "Decode movable list state failed", + )?; let id_full = IdFull::new( - peers[peer_idx], + peer, counter, - (counter + lamport_sub_counter) as u32, + decode_lamport_from_delta( + counter, + lamport_sub_counter, + "Decode movable list state failed", + )?, ); ans.inner.push_inner(id_full, None); } } - debug_assert!(elem_iter.next().is_none()); - debug_assert!(list_item_id_iter.next().is_none()); - debug_assert!(last_set_id_iter.next().is_none()); - debug_assert!(list_value_iter.next().is_none()); + if is_first { + return Err(state_decode_error( + "Decode movable list state failed: missing sentinel item", + )); + } + if elem_iter + .next() + .transpose() + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid extra element id: {err}" + )) + })? + .is_some() + { + return Err(state_decode_error( + "Decode movable list state failed: unused element id", + )); + } + if list_item_id_iter + .next() + .transpose() + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid extra list item id: {err}" + )) + })? + .is_some() + { + return Err(state_decode_error( + "Decode movable list state failed: unused list item id", + )); + } + if last_set_id_iter + .next() + .transpose() + .map_err(|err| { + state_decode_error(format!( + "Decode movable list state failed: invalid extra last set id: {err}" + )) + })? + .is_some() + { + return Err(state_decode_error( + "Decode movable list state failed: unused last set id", + )); + } + if list_value_iter.next().is_some() { + return Err(state_decode_error( + "Decode movable list state failed: unused list value", + )); + } Ok(ans) } @@ -1772,6 +1901,34 @@ mod snapshot { list.encode_snapshot_fast(&mut bytes); assert!(bytes.len() <= 47, "{}", bytes.len()); } + + #[test] + fn movable_list_fast_snapshot_rejects_corrupt_state_metadata() { + let idx = ContainerIdx::from_index_and_type(0, loro_common::ContainerType::MovableList); + let configure = Default::default(); + let ctx = ContainerCreationContext { + configure: &configure, + peer: 0, + }; + + assert!(MovableListState::decode_snapshot_fast( + idx, + (Vec::::new().into(), &[1]), + ctx, + ) + .is_err()); + + let mut list = MovableListState::new(idx); + let mut bytes = Vec::new(); + list.encode_snapshot_fast(&mut bytes); + let (_, state_bytes) = MovableListState::decode_value(&bytes).unwrap(); + assert!(MovableListState::decode_snapshot_fast( + idx, + (vec![LoroValue::I64(1)].into(), state_bytes), + ctx, + ) + .is_err()); + } } } diff --git a/crates/loro-internal/src/state/richtext_state.rs b/crates/loro-internal/src/state/richtext_state.rs index 32f969376..a5e40dd87 100644 --- a/crates/loro-internal/src/state/richtext_state.rs +++ b/crates/loro-internal/src/state/richtext_state.rs @@ -1,7 +1,8 @@ use generic_btree::{rle::HasLength, rle::Sliceable as _, Cursor}; use loro_common::{ContainerID, InternalString, LoroError, LoroResult, LoroValue, ID}; -use loro_delta::DeltaRopeBuilder; +use loro_delta::{DeltaRope, DeltaRopeBuilder}; use rustc_hash::{FxHashMap, FxHashSet}; +use smallvec::SmallVec; use std::ops::Range; use std::sync::{Arc, Weak}; @@ -91,6 +92,7 @@ impl RichtextState { } } + #[allow(unused)] pub(crate) fn get_text_slice_by_event_index( &mut self, pos: usize, @@ -110,6 +112,7 @@ impl RichtextState { .slice_delta(start_index, end_index, pos_type) } + #[allow(unused)] pub(crate) fn get_char_by_event_index(&mut self, pos: usize) -> Result { self.state.get_mut().get_char_by_event_index(pos) } @@ -577,6 +580,12 @@ impl ContainerState for RichtextState { unreachable!() }; + if let LazyLoad::Src(loader) = &mut self.state { + if loader.try_apply_append_delta(&richtext) { + return Ok(()); + } + } + // Fast path for plain-text deltas (no style anchors / style ranges). // // Rebuilding avoids repeated BTree queries and mutations when the delta is very "choppy" @@ -877,7 +886,10 @@ impl ContainerState for RichtextState { // value is a list fn get_value(&mut self) -> LoroValue { - self.state.get_mut().to_string().into() + match &self.state { + LazyLoad::Src(loader) => loader.to_plain_string().into(), + LazyLoad::Dst(_) => self.state.get_mut().to_string().into(), + } } #[doc = r" Get the index of the child container"] @@ -938,6 +950,7 @@ impl RichtextState { } #[inline] + #[allow(dead_code)] pub(crate) fn has_styles(&mut self) -> bool { self.state.get_mut().has_styles() } @@ -1057,7 +1070,7 @@ impl RichtextState { #[derive(Debug, Default, Clone)] pub(crate) struct RichtextStateLoader { start_anchor_pos: FxHashMap, - elements: Vec, + elements: SmallVec<[RichtextStateChunk; 1]>, style_ranges: Vec<(Arc, Range)>, entity_index: usize, } @@ -1106,13 +1119,84 @@ impl RichtextStateLoader { fn is_empty(&self) -> bool { self.elements.is_empty() } + + fn to_plain_string(&self) -> String { + let len = self + .elements + .iter() + .map(|elem| match elem { + RichtextStateChunk::Text(text) => text.bytes().len(), + RichtextStateChunk::Style { .. } => 0, + }) + .sum(); + let mut text = String::with_capacity(len); + for elem in &self.elements { + if let RichtextStateChunk::Text(chunk) = elem { + text.push_str(chunk.as_str()); + } + } + text + } + + fn try_apply_append_delta(&mut self, delta: &DeltaRope) -> bool { + // Style anchors/ranges need InnerState's range maintenance. + if !self.start_anchor_pos.is_empty() || !self.style_ranges.is_empty() { + return false; + } + + let old_len = self.entity_index; + let mut current_len = self.entity_index; + let mut entity_index = 0; + let mut appended = Vec::new(); + for span in delta.iter() { + match span { + loro_delta::DeltaItem::Retain { len, .. } => { + entity_index += len; + if entity_index > old_len { + return false; + } + } + loro_delta::DeltaItem::Replace { value, delete, .. } => { + if *delete > 0 { + return false; + } + + let insert_len = value.rle_len(); + if insert_len == 0 { + continue; + } + + if !matches!(value, RichtextStateChunk::Text(_)) { + return false; + } + + if entity_index != current_len { + return false; + } + + appended.push(value.clone()); + entity_index += insert_len; + current_len += insert_len; + } + } + } + + if appended.is_empty() { + return false; + } + + for elem in appended { + self.push(elem); + } + true + } } mod snapshot { use loro_common::{IdFull, InternalString, LoroValue, PeerID}; use rustc_hash::FxHashMap; use serde_columnar::columnar; - use std::{io::Read, sync::Arc}; + use std::sync::Arc; use crate::{ container::richtext::{ @@ -1120,7 +1204,10 @@ mod snapshot { TextStyleInfoFlag, }, encoding::value_register::ValueRegister, - state::{ContainerCreationContext, ContainerState, FastStateSnapshot}, + state::{ + decode_lamport_from_delta, decode_peer_from_table, decode_peer_table, + state_decode_error, ContainerCreationContext, ContainerState, FastStateSnapshot, + }, utils::lazy::LazyLoad, }; @@ -1271,17 +1358,17 @@ mod snapshot { { let mut text = RichtextState::new(idx, ctx.configure.text_style_config.clone()); let mut loader = RichtextStateLoader::default(); - let peer_num = leb128::read::unsigned(&mut bytes).unwrap() as usize; - let mut peers = Vec::with_capacity(peer_num); - for _ in 0..peer_num { - let mut buf = [0u8; 8]; - bytes.read_exact(&mut buf).unwrap(); - peers.push(PeerID::from_le_bytes(buf)); - } + let peers = decode_peer_table(&mut bytes, "Decode richtext state failed")?; - let string = string.into_string().unwrap(); + let string = string.into_string().map_err(|_| { + state_decode_error("Decode richtext state failed: value is not a string") + })?; let mut s = StrSlice::new_from_str(&string); - let iters = serde_columnar::from_bytes::(bytes).unwrap(); + let iters = serde_columnar::from_bytes::(bytes).map_err(|err| { + state_decode_error(format!( + "Decode richtext state failed: invalid spans: {err}" + )) + })?; let keys = iters.keys; let span_iter = iters.spans.into_iter(); let mut mark_iter = iters.marks.into_iter(); @@ -1293,11 +1380,14 @@ mod snapshot { lamport_sub_counter, len, } = span; - let id_full = IdFull::new( - peers[peer_idx], + let peer = + decode_peer_from_table(&peers, peer_idx, "Decode richtext state failed")?; + let lamport = decode_lamport_from_delta( counter, - (lamport_sub_counter + counter) as u32, - ); + lamport_sub_counter, + "Decode richtext state failed", + )?; + let id_full = IdFull::new(peer, counter, lamport); let chunk = match len { 0 => { // Style Start @@ -1305,12 +1395,19 @@ mod snapshot { key_idx, value, info, - } = mark_iter.next().unwrap(); + } = mark_iter.next().ok_or_else(|| { + state_decode_error("Decode richtext state failed: missing style mark") + })?; + let key = keys.get(key_idx).ok_or_else(|| { + state_decode_error( + "Decode richtext state failed: style key index out of range", + ) + })?; let style_op = Arc::new(StyleOp { - lamport: (lamport_sub_counter + counter) as u32, + lamport, peer: id_full.peer, cnt: id_full.counter, - key: keys[key_idx].clone(), + key: key.clone(), value, info: TextStyleInfoFlag::from_byte(info), }); @@ -1319,11 +1416,25 @@ mod snapshot { } -1 => { // Style End - let style = id_to_style.remove(&id_full.id().inc(-1)).unwrap(); + let style = id_to_style.remove(&id_full.id().inc(-1)).ok_or_else(|| { + state_decode_error("Decode richtext state failed: unmatched style end") + })?; RichtextStateChunk::new_style(style, richtext::AnchorType::End) } len => { + if len < -1 { + return Err(state_decode_error( + "Decode richtext state failed: invalid text span length", + )); + } + // Text + if s.as_str().chars().count() < len as usize { + return Err(state_decode_error( + "Decode richtext state failed: text span exceeds value length", + )); + } + let (new, rest) = s.split_at_unicode_pos(len as usize); s = rest; RichtextStateChunk::new_text(new.bytes().clone(), id_full) @@ -1332,6 +1443,21 @@ mod snapshot { loader.push(chunk); } + if !s.as_str().is_empty() { + return Err(state_decode_error( + "Decode richtext state failed: text value not fully covered", + )); + } + if mark_iter.next().is_some() { + return Err(state_decode_error( + "Decode richtext state failed: unused style mark", + )); + } + if !id_to_style.is_empty() { + return Err(state_decode_error( + "Decode richtext state failed: unclosed style mark", + )); + } text.state = LazyLoad::Src(loader); // NOTE: We need to ensure the invariance that the version id is always increased when the richtext state is changed // This is used to avoid the version_id to be the same as the previous zero version @@ -1339,4 +1465,41 @@ mod snapshot { Ok(text) } } + + #[cfg(test)] + mod tests { + use loro_common::{ContainerType, LoroValue}; + + use crate::container::idx::ContainerIdx; + + use super::*; + + #[test] + fn richtext_fast_snapshot_rejects_corrupt_state_metadata() { + let idx = ContainerIdx::from_index_and_type(0, ContainerType::Text); + let configure = Default::default(); + let ctx = ContainerCreationContext { + configure: &configure, + peer: 0, + }; + + assert!(RichtextState::decode_snapshot_fast( + idx, + (LoroValue::String("a".into()), &[1]), + ctx, + ) + .is_err()); + + let mut text = RichtextState::new(idx, configure.text_style_config.clone()); + let mut bytes = Vec::new(); + text.encode_snapshot_fast(&mut bytes); + let (_, state_bytes) = RichtextState::decode_value(&bytes).unwrap(); + assert!(RichtextState::decode_snapshot_fast( + idx, + (LoroValue::String("a".into()), state_bytes), + ctx, + ) + .is_err()); + } + } } diff --git a/crates/loro-internal/src/state/tree_state.rs b/crates/loro-internal/src/state/tree_state.rs index d97c646e8..e89964a2c 100644 --- a/crates/loro-internal/src/state/tree_state.rs +++ b/crates/loro-internal/src/state/tree_state.rs @@ -1534,18 +1534,20 @@ mod jitter { } mod snapshot { - use std::{borrow::Cow, collections::BTreeSet, io::Read}; + use std::{borrow::Cow, collections::BTreeSet}; use fractional_index::FractionalIndex; - use itertools::Itertools; - use loro_common::{IdFull, Lamport, PeerID, TreeID}; + use loro_common::{IdFull, PeerID, TreeID}; use rustc_hash::FxHashMap; use serde_columnar::columnar; use crate::{ encoding::{arena::PositionArena, value_register::ValueRegister}, - state::FastStateSnapshot, + state::{ + decode_counter, decode_lamport_from_delta, decode_peer_from_table, decode_peer_table, + state_decode_error, FastStateSnapshot, + }, }; use super::{TreeNode, TreeParentId, TreeState}; @@ -1698,24 +1700,30 @@ mod snapshot { where Self: Sized, { - let peer_num = leb128::read::unsigned(&mut bytes).unwrap() as usize; - let mut peers = Vec::with_capacity(peer_num); - for _ in 0..peer_num { - let mut buf = [0u8; 8]; - bytes.read_exact(&mut buf).unwrap(); - peers.push(PeerID::from_le_bytes(buf)); - } + let peers = decode_peer_table(&mut bytes, "Decode tree state failed")?; let mut tree = TreeState::new(idx, ctx.peer); - let encoded: EncodedTree = serde_columnar::from_bytes(bytes)?; - let fractional_indexes = PositionArena::decode(&encoded.fractional_indexes).unwrap(); - let fractional_indexes = fractional_indexes.parse_to_positions(); + let encoded: EncodedTree = serde_columnar::from_bytes(bytes).map_err(|err| { + state_decode_error(format!("Decode tree state failed: invalid metadata: {err}")) + })?; + let fractional_indexes = + PositionArena::decode(&encoded.fractional_indexes)?.try_parse_to_positions()?; + if encoded.node_ids.len() != encoded.nodes.len() { + return Err(state_decode_error( + "Decode tree state failed: node id/node length mismatch", + )); + } let node_ids = encoded .node_ids .iter() - .map(|x| TreeID::new(peers[x.peer_idx], x.counter)) - .collect_vec(); - for (node_id, node) in node_ids.iter().zip(encoded.nodes.into_iter()) { + .map(|x| { + Ok(TreeID::new( + decode_peer_from_table(&peers, x.peer_idx, "Decode tree state failed")?, + decode_counter(x.counter, "Decode tree state failed")?, + )) + }) + .collect::>>()?; + for (node_id, node) in node_ids.iter().zip(encoded.nodes) { // PERF: we don't need to mov the deleted node, instead we can cache them // If the parent is TreeParentId::Deleted, then all the nodes afterwards are deleted tree._init_push_tree_node_in_order( @@ -1724,23 +1732,100 @@ mod snapshot { 0 => TreeParentId::Root, 1 => TreeParentId::Deleted, n => { - let id = node_ids[n - 2]; + let id = *node_ids.get(n - 2).ok_or_else(|| { + state_decode_error( + "Decode tree state failed: parent index out of range", + ) + })?; TreeParentId::from(Some(id)) } }, IdFull::new( - peers[node.last_set_peer_idx], + decode_peer_from_table( + &peers, + node.last_set_peer_idx, + "Decode tree state failed", + )?, node.last_set_counter, - (node.last_set_lamport_sub_counter + node.last_set_counter) as Lamport, + decode_lamport_from_delta( + node.last_set_counter, + node.last_set_lamport_sub_counter, + "Decode tree state failed", + )?, ), Some(FractionalIndex::from_bytes( - fractional_indexes[node.fractional_index_idx].clone(), + fractional_indexes + .get(node.fractional_index_idx) + .ok_or_else(|| { + state_decode_error( + "Decode tree state failed: fractional index out of range", + ) + })? + .clone(), )), ) - .unwrap(); + .map_err(|err| { + state_decode_error(format!("Decode tree state failed: invalid node: {err}")) + })?; } Ok(tree) } } + + #[cfg(test)] + mod tests { + use loro_common::{ContainerType, LoroValue}; + + use crate::{container::idx::ContainerIdx, state::ContainerCreationContext}; + + use super::*; + + #[test] + fn tree_fast_snapshot_rejects_corrupt_state_metadata() { + let idx = ContainerIdx::from_index_and_type(0, ContainerType::Tree); + let configure = Default::default(); + let ctx = ContainerCreationContext { + configure: &configure, + peer: 0, + }; + + assert!(TreeState::decode_snapshot_fast(idx, (LoroValue::Null, &[1]), ctx).is_err()); + } + + #[test] + fn tree_fast_snapshot_rejects_negative_node_counter() { + let idx = ContainerIdx::from_index_and_type(0, ContainerType::Tree); + let configure = Default::default(); + let ctx = ContainerCreationContext { + configure: &configure, + peer: 0, + }; + + let position = fractional_index::FractionalIndex::default(); + let positions = PositionArena::from_positions([position.as_bytes()]); + let encoded = EncodedTree { + node_ids: vec![EncodedTreeNodeId { + peer_idx: 0, + counter: -1, + }], + nodes: vec![EncodedTreeNode { + parent_idx_plus_two: 0, + last_set_peer_idx: 0, + last_set_counter: 0, + last_set_lamport_sub_counter: 0, + fractional_index_idx: 0, + }], + fractional_indexes: positions.encode().into(), + reserved_has_effect_bool_rle: vec![].into(), + }; + + let mut bytes = Vec::new(); + leb128::write::unsigned(&mut bytes, 1).unwrap(); + bytes.extend_from_slice(&1_u64.to_le_bytes()); + bytes.extend_from_slice(&serde_columnar::to_vec(&encoded).unwrap()); + + assert!(TreeState::decode_snapshot_fast(idx, (LoroValue::Null, &bytes), ctx).is_err()); + } + } } diff --git a/crates/loro-internal/src/txn.rs b/crates/loro-internal/src/txn.rs index b4185e013..26d4606c1 100644 --- a/crates/loro-internal/src/txn.rs +++ b/crates/loro-internal/src/txn.rs @@ -906,7 +906,7 @@ fn change_to_diff( }), EventHint::Tree(tree_diff) => { let mut diff = TreeDiff::default(); - diff.diff.extend(tree_diff.into_iter()); + diff.diff.extend(tree_diff); ans.push(TxnContainerDiff { idx: container_idx, diff: Diff::Tree(diff), diff --git a/crates/loro-internal/src/version.rs b/crates/loro-internal/src/version.rs index fa3a42f45..1c6ea17a4 100644 --- a/crates/loro-internal/src/version.rs +++ b/crates/loro-internal/src/version.rs @@ -75,7 +75,10 @@ impl VersionRange { pub fn from_vv(vv: &VersionVector) -> Self { let mut ans = Self::new(); for (peer, counter) in vv.iter() { - ans.insert(*peer, 0, *counter); + let counter = normalize_vv_counter(*counter); + if counter > 0 { + ans.insert(*peer, 0, counter); + } } ans } @@ -159,6 +162,16 @@ impl VersionRange { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ImVersionVector(im::HashMap); +#[inline] +fn normalize_vv_counter(counter: Counter) -> Counter { + counter.max(0) +} + +#[inline] +fn last_id_to_vv_end(id: ID) -> Counter { + id.counter.saturating_add(1).max(0) +} + impl ImVersionVector { pub fn new() -> Self { Self(Default::default()) @@ -222,7 +235,13 @@ impl ImVersionVector { vv: impl Iterator, ) { for (&client_id, &counter) in vv { + let counter = normalize_vv_counter(counter); + if counter == 0 { + continue; + } + if let Some(my_counter) = self.0.get_mut(&client_id) { + *my_counter = normalize_vv_counter(*my_counter); if *my_counter < counter { *my_counter = counter; } @@ -244,35 +263,46 @@ impl ImVersionVector { #[inline] pub fn set_last(&mut self, id: ID) { - self.0.insert(id.peer, id.counter + 1); + let end = last_id_to_vv_end(id); + if end == 0 { + self.0.remove(&id.peer); + } else { + self.0.insert(id.peer, end); + } } pub fn extend_to_include_last_id(&mut self, id: ID) { + let end = last_id_to_vv_end(id); + if end == 0 { + return; + } + if let Some(counter) = self.0.get_mut(&id.peer) { - if *counter <= id.counter { - *counter = id.counter + 1; + *counter = normalize_vv_counter(*counter); + if *counter < end { + *counter = end; } } else { - self.set_last(id) + self.0.insert(id.peer, end); } } pub(crate) fn includes_id(&self, x: ID) -> bool { - if self.is_empty() { + if self.is_empty() || x.counter < 0 { return false; } - self.get(&x.peer).copied().unwrap_or(0) > x.counter + normalize_vv_counter(self.get(&x.peer).copied().unwrap_or(0)) > x.counter } } impl PartialEq for VersionVector { fn eq(&self, other: &Self) -> bool { - self.iter() - .all(|(client, counter)| other.get(client).unwrap_or(&0) == counter) - && other - .iter() - .all(|(client, counter)| self.get(client).unwrap_or(&0) == counter) + self.iter().all(|(client, counter)| { + normalize_vv_counter(*other.get(client).unwrap_or(&0)) == normalize_vv_counter(*counter) + }) && other.iter().all(|(client, counter)| { + normalize_vv_counter(*self.get(client).unwrap_or(&0)) == normalize_vv_counter(*counter) + }) } } @@ -280,13 +310,13 @@ impl Eq for VersionVector {} impl PartialEq for ImVersionVector { fn eq(&self, other: &Self) -> bool { - self.0 - .iter() - .all(|(client, counter)| other.0.get(client).unwrap_or(&0) == counter) - && other - .0 - .iter() - .all(|(client, counter)| self.0.get(client).unwrap_or(&0) == counter) + self.0.iter().all(|(client, counter)| { + normalize_vv_counter(*other.0.get(client).unwrap_or(&0)) + == normalize_vv_counter(*counter) + }) && other.0.iter().all(|(client, counter)| { + normalize_vv_counter(*self.0.get(client).unwrap_or(&0)) + == normalize_vv_counter(*counter) + }) } } @@ -372,7 +402,9 @@ impl PartialOrd for VersionVector { let mut other_greater = true; let mut eq = true; for (client_id, other_end) in other.iter() { + let other_end = normalize_vv_counter(*other_end); if let Some(self_end) = self.get(client_id) { + let self_end = normalize_vv_counter(*self_end); if self_end < other_end { self_greater = false; eq = false; @@ -381,7 +413,7 @@ impl PartialOrd for VersionVector { other_greater = false; eq = false; } - } else if *other_end > 0 { + } else if other_end > 0 { self_greater = false; eq = false; } @@ -390,7 +422,7 @@ impl PartialOrd for VersionVector { for (client_id, self_end) in self.iter() { if other.contains_key(client_id) { continue; - } else if *self_end > 0 { + } else if normalize_vv_counter(*self_end) > 0 { other_greater = false; eq = false; } @@ -414,7 +446,9 @@ impl PartialOrd for ImVersionVector { let mut other_greater = true; let mut eq = true; for (client_id, other_end) in other.iter() { + let other_end = normalize_vv_counter(*other_end); if let Some(self_end) = self.get(client_id) { + let self_end = normalize_vv_counter(*self_end); if self_end < other_end { self_greater = false; eq = false; @@ -423,7 +457,7 @@ impl PartialOrd for ImVersionVector { other_greater = false; eq = false; } - } else if *other_end > 0 { + } else if other_end > 0 { self_greater = false; eq = false; } @@ -432,7 +466,7 @@ impl PartialOrd for ImVersionVector { for (client_id, self_end) in self.iter() { if other.contains_key(client_id) { continue; - } else if *self_end > 0 { + } else if normalize_vv_counter(*self_end) > 0 { other_greater = false; eq = false; } @@ -460,7 +494,9 @@ impl VersionVector { pub fn diff(&self, rhs: &Self) -> VersionVectorDiff { let mut ans: VersionVectorDiff = Default::default(); for (client_id, &counter) in self.iter() { + let counter = normalize_vv_counter(counter); if let Some(&rhs_counter) = rhs.get(client_id) { + let rhs_counter = normalize_vv_counter(rhs_counter); match counter.cmp(&rhs_counter) { Ordering::Less => { ans.forward.insert( @@ -482,7 +518,7 @@ impl VersionVector { } Ordering::Equal => {} } - } else { + } else if counter > 0 { ans.retreat.insert( *client_id, CounterSpan { @@ -493,7 +529,8 @@ impl VersionVector { } } for (client_id, &rhs_counter) in rhs.iter() { - if !self.contains_key(client_id) { + let rhs_counter = normalize_vv_counter(rhs_counter); + if rhs_counter > 0 && !self.contains_key(client_id) { ans.forward.insert( *client_id, CounterSpan { @@ -524,7 +561,9 @@ impl VersionVector { /// Returns the spans that are in `self` but not in `rhs` pub fn sub_iter<'a>(&'a self, rhs: &'a Self) -> impl Iterator + 'a { self.iter().filter_map(move |(peer, &counter)| { + let counter = normalize_vv_counter(counter); if let Some(&rhs_counter) = rhs.get(peer) { + let rhs_counter = normalize_vv_counter(rhs_counter); if counter > rhs_counter { Some(IdSpan { peer: *peer, @@ -556,7 +595,9 @@ impl VersionVector { rhs: &'a ImVersionVector, ) -> impl Iterator + 'a { self.iter().filter_map(move |(peer, &counter)| { + let counter = normalize_vv_counter(counter); if let Some(&rhs_counter) = rhs.get(peer) { + let rhs_counter = normalize_vv_counter(rhs_counter); if counter > rhs_counter { Some(IdSpan { peer: *peer, @@ -593,34 +634,37 @@ impl VersionVector { } pub fn distance_between(&self, other: &Self) -> usize { - let mut ans = 0; + let mut ans = 0usize; for (client_id, &counter) in self.iter() { + let counter = counter.max(0) as i64; if let Some(&other_counter) = other.get(client_id) { - ans += (counter - other_counter).abs(); - } else if counter > 0 { - ans += counter; + let other_counter = other_counter.max(0) as i64; + ans += counter.abs_diff(other_counter) as usize; + } else { + ans += counter as usize; } } for (client_id, &counter) in other.iter() { if !self.contains_key(client_id) { - ans += counter; + ans += counter.max(0) as usize; } } - ans as usize + ans } pub fn to_spans(&self) -> IdSpanVector { self.iter() - .map(|(client_id, &counter)| { - ( + .filter_map(|(client_id, &counter)| { + let counter = normalize_vv_counter(counter); + (counter > 0).then_some(( *client_id, CounterSpan { start: 0, end: counter, }, - ) + )) }) .collect() } @@ -649,14 +693,24 @@ impl VersionVector { /// set the inclusive ending point. target id will be included by self #[inline] pub fn set_last(&mut self, id: ID) { - self.0.insert(id.peer, id.counter + 1); + let end = last_id_to_vv_end(id); + if end == 0 { + self.0.remove(&id.peer); + } else { + self.0.insert(id.peer, end); + } } #[inline] pub fn get_last(&self, client_id: PeerID) -> Option { - self.0 - .get(&client_id) - .and_then(|&x| if x == 0 { None } else { Some(x - 1) }) + self.0.get(&client_id).and_then(|&x| { + let x = normalize_vv_counter(x); + if x == 0 { + None + } else { + Some(x - 1) + } + }) } /// set the exclusive ending point. target id will NOT be included by self @@ -673,15 +727,24 @@ impl VersionVector { /// Return whether updated #[inline] pub fn try_update_last(&mut self, id: ID) -> bool { + let new_end = last_id_to_vv_end(id); + if new_end == 0 { + if self.0.get(&id.peer).is_some_and(|counter| *counter < 0) { + self.0.remove(&id.peer); + } + return false; + } + if let Some(end) = self.0.get_mut(&id.peer) { - if *end < id.counter + 1 { - *end = id.counter + 1; + *end = normalize_vv_counter(*end); + if *end < new_end { + *end = new_end; true } else { false } } else { - self.0.insert(id.peer, id.counter + 1); + self.0.insert(id.peer, new_end); true } } @@ -689,12 +752,17 @@ impl VersionVector { pub fn get_missing_span(&self, target: &Self) -> Vec { let mut ans = vec![]; for (client_id, other_end) in target.iter() { - if let Some(my_end) = self.get(client_id) { - if my_end < other_end { - ans.push(IdSpan::new(*client_id, *my_end, *other_end)); - } - } else { - ans.push(IdSpan::new(*client_id, 0, *other_end)); + let other_end = normalize_vv_counter(*other_end); + if other_end == 0 { + continue; + } + + let my_end = self + .get(client_id) + .map(|counter| normalize_vv_counter(*counter)) + .unwrap_or(0); + if my_end < other_end { + ans.push(IdSpan::new(*client_id, my_end, other_end)); } } @@ -703,7 +771,13 @@ impl VersionVector { pub fn merge(&mut self, other: &Self) { for (&client_id, &other_end) in other.iter() { + let other_end = normalize_vv_counter(other_end); + if other_end == 0 { + continue; + } + if let Some(my_end) = self.get_mut(&client_id) { + *my_end = normalize_vv_counter(*my_end); if *my_end < other_end { *my_end = other_end; } @@ -725,8 +799,12 @@ impl VersionVector { } pub fn includes_id(&self, id: ID) -> bool { + if id.counter < 0 { + return false; + } + if let Some(end) = self.get(&id.peer) { - if *end > id.counter { + if normalize_vv_counter(*end) > id.counter { return true; } } @@ -735,10 +813,12 @@ impl VersionVector { pub fn intersect_span(&self, target: IdSpan) -> Option { if let Some(&end) = self.get(&target.peer) { - if end > target.ctr_start() { - let count_end = target.ctr_end(); + let end = normalize_vv_counter(end); + let count_start = target.ctr_start().max(0); + let count_end = target.ctr_end().max(0); + if end > count_start && count_end > count_start { return Some(CounterSpan { - start: target.ctr_start(), + start: count_start, end: end.min(count_end), }); } @@ -752,7 +832,13 @@ impl VersionVector { vv: impl Iterator, ) { for (&client_id, &counter) in vv { + let counter = normalize_vv_counter(counter); + if counter == 0 { + continue; + } + if let Some(my_counter) = self.get_mut(&client_id) { + *my_counter = normalize_vv_counter(*my_counter); if *my_counter < counter { *my_counter = counter; } @@ -763,44 +849,75 @@ impl VersionVector { } pub fn extend_to_include_last_id(&mut self, id: ID) { + let end = last_id_to_vv_end(id); + if end == 0 { + return; + } + if let Some(counter) = self.get_mut(&id.peer) { - if *counter <= id.counter { - *counter = id.counter + 1; + *counter = normalize_vv_counter(*counter); + if *counter < end { + *counter = end; } } else { - self.set_last(id) + self.0.insert(id.peer, end); } } pub fn extend_to_include_end_id(&mut self, id: ID) { + let end = normalize_vv_counter(id.counter); + if end == 0 { + return; + } + if let Some(counter) = self.get_mut(&id.peer) { - if *counter < id.counter { - *counter = id.counter; + *counter = normalize_vv_counter(*counter); + if *counter < end { + *counter = end; } } else { - self.set_end(id) + self.0.insert(id.peer, end); } } pub fn extend_to_include(&mut self, span: IdSpan) { + let end = normalize_vv_counter(span.counter.norm_end()); + if end == 0 { + if self.0.get(&span.peer).is_some_and(|counter| *counter < 0) { + self.0.remove(&span.peer); + } + return; + } + if let Some(counter) = self.get_mut(&span.peer) { - if *counter < span.counter.norm_end() { - *counter = span.counter.norm_end(); + *counter = normalize_vv_counter(*counter); + if *counter < end { + *counter = end; } } else { - self.insert(span.peer, span.counter.norm_end()); + self.insert(span.peer, end); } } pub fn shrink_to_exclude(&mut self, span: IdSpan) { - if span.counter.min() == 0 { + let start = normalize_vv_counter(span.counter.min()); + let end = normalize_vv_counter(span.counter.norm_end()); + if end <= start { + return; + } + + if start == 0 { self.remove(&span.peer); return; } if let Some(counter) = self.get_mut(&span.peer) { - if *counter > span.counter.min() { - *counter = span.counter.min(); + *counter = normalize_vv_counter(*counter); + if *counter > start { + *counter = start; + } + if *counter == 0 { + self.remove(&span.peer); } } } @@ -826,7 +943,9 @@ impl VersionVector { pub fn intersection(&self, other: &VersionVector) -> VersionVector { let mut ans = VersionVector::new(); for (client_id, &counter) in self.iter() { + let counter = normalize_vv_counter(counter); if let Some(&other_counter) = other.get(client_id) { + let other_counter = normalize_vv_counter(other_counter); if counter < other_counter { if counter != 0 { ans.insert(*client_id, counter); diff --git a/crates/loro-internal/tests/undo.rs b/crates/loro-internal/tests/undo.rs index 23e2103b9..c1ad818f2 100644 --- a/crates/loro-internal/tests/undo.rs +++ b/crates/loro-internal/tests/undo.rs @@ -175,7 +175,8 @@ fn test_clear_redo() { // Make some edits text.update("hello", UpdateOptions::default()).unwrap(); doc.commit_then_renew(); - text.update("hello world", UpdateOptions::default()).unwrap(); + text.update("hello world", UpdateOptions::default()) + .unwrap(); doc.commit_then_renew(); // Undo to create redo stack @@ -187,7 +188,10 @@ fn test_clear_redo() { // Clear only redo stack undo_manager.clear_redo(); assert!(!undo_manager.can_redo(), "redo stack should be empty"); - assert!(undo_manager.can_undo(), "undo stack should still have items"); + assert!( + undo_manager.can_undo(), + "undo stack should still have items" + ); // Verify undo still works undo_manager.undo().unwrap(); @@ -203,7 +207,8 @@ fn test_clear_undo() { // Make some edits text.update("hello", UpdateOptions::default()).unwrap(); doc.commit_then_renew(); - text.update("hello world", UpdateOptions::default()).unwrap(); + text.update("hello world", UpdateOptions::default()) + .unwrap(); doc.commit_then_renew(); // Undo to create redo stack @@ -214,7 +219,10 @@ fn test_clear_undo() { // Clear only undo stack undo_manager.clear_undo(); - assert!(undo_manager.can_redo(), "redo stack should still have items"); + assert!( + undo_manager.can_redo(), + "redo stack should still have items" + ); assert!(!undo_manager.can_undo(), "undo stack should be empty"); // Verify redo still works diff --git a/crates/loro-wasm/.gitignore b/crates/loro-wasm/.gitignore index 9fa05e381..03b46fbc4 100644 --- a/crates/loro-wasm/.gitignore +++ b/crates/loro-wasm/.gitignore @@ -2,6 +2,7 @@ node_modules/ npm/ nodejs/ bundler/ +browser/ web/ docs/ base64/ diff --git a/crates/loro-wasm/src/awareness.rs b/crates/loro-wasm/src/awareness.rs index 8e08a968e..6e105319d 100644 --- a/crates/loro-wasm/src/awareness.rs +++ b/crates/loro-wasm/src/awareness.rs @@ -73,7 +73,10 @@ impl AwarenessWasm { /// Each peer's deletion countdown will be reset upon update, requiring them to pass through the `timeout` /// interval again before being eligible for deletion. pub fn apply(&mut self, encoded_peers_info: Vec) -> JsResult { - let (updated, added) = self.inner.apply(&encoded_peers_info); + let (updated, added) = self + .inner + .try_apply(&encoded_peers_info) + .map_err(|e| JsValue::from_str(&e))?; let ans = Object::new(); let updated = Array::from_iter(updated.into_iter().map(peer_to_str_js)); let added = Array::from_iter(added.into_iter().map(peer_to_str_js)); diff --git a/crates/loro-wasm/src/convert.rs b/crates/loro-wasm/src/convert.rs index 154c86743..6fc8855ad 100644 --- a/crates/loro-wasm/src/convert.rs +++ b/crates/loro-wasm/src/convert.rs @@ -36,7 +36,7 @@ fn js_wbg_ptr(js: &JsValue) -> JsResult { } fn validate_wbg_ptr_alignment(ptr: u32) -> Result<(), &'static str> { - if (ptr as usize) % std::mem::align_of::>() != 0 { + if !(ptr as usize).is_multiple_of(std::mem::align_of::>()) { return Err("Invalid wasm-bindgen pointer alignment"); } diff --git a/crates/loro-wasm/src/lib.rs b/crates/loro-wasm/src/lib.rs index 7234e3a36..8797f20eb 100644 --- a/crates/loro-wasm/src/lib.rs +++ b/crates/loro-wasm/src/lib.rs @@ -90,7 +90,7 @@ type JsResult = Result; type EventCallback = Box bool + Send + Sync + 'static>; thread_local! { - static IN_PRE_COMMIT_CALLBACK: Cell = Cell::new(false); + static IN_PRE_COMMIT_CALLBACK: Cell = const { Cell::new(false) }; } /// The CRDTs document. Loro supports different CRDTs include [**List**](LoroList), diff --git a/crates/loro-wasm/tests/awareness.test.ts b/crates/loro-wasm/tests/awareness.test.ts index 41baf08a2..429c94e51 100644 --- a/crates/loro-wasm/tests/awareness.test.ts +++ b/crates/loro-wasm/tests/awareness.test.ts @@ -28,6 +28,15 @@ describe("Awareness", () => { expect(awarenessB.getAllStates()).toEqual({ "123": { foo: "bar" } }); }); + it("rejects invalid payloads", () => { + const awareness = new AwarenessWasm("123", 30_000); + + expect(() => awareness.apply(Uint8Array.from([0xff, 0xff, 0xff, 0xff]))).toThrow( + /Failed to decode awareness data/, + ); + expect(awareness.getAllStates()).toEqual({}); + }); + it("not sync if peer is not in sync list", () => { const awareness = new AwarenessWasm("123", 30_000); awareness.setLocalState({ foo: "bar" }); diff --git a/crates/loro/src/event.rs b/crates/loro/src/event.rs index eca51308c..4273b2770 100644 --- a/crates/loro/src/event.rs +++ b/crates/loro/src/event.rs @@ -19,6 +19,7 @@ use std::ops::Deref; use std::sync::Arc; use crate::ValueOrContainer; +use crate::{LoroError, LoroResult}; /// A subscriber to the event. #[allow(clippy::unused_unit)] @@ -287,6 +288,59 @@ impl DiffBatch { } } } + + pub(crate) fn validate_for_apply(&self) -> LoroResult<()> { + for (_, diff) in self.iter() { + validate_diff_for_apply(diff)?; + } + + Ok(()) + } +} + +fn validate_diff_for_apply(diff: &Diff<'_>) -> LoroResult<()> { + let mut index = 0usize; + match diff { + Diff::Text(delta) => { + for item in delta { + match item { + TextDelta::Retain { retain, .. } => { + index = checked_apply_diff_index_end(index, *retain)?; + } + TextDelta::Delete { delete } => { + index = checked_apply_diff_index_end(index, *delete)?; + } + TextDelta::Insert { .. } => {} + } + } + } + Diff::List(delta) => { + for item in delta { + match item { + ListDiffItem::Retain { retain } => { + index = checked_apply_diff_index_end(index, *retain)?; + } + ListDiffItem::Delete { delete } => { + index = checked_apply_diff_index_end(index, *delete)?; + } + ListDiffItem::Insert { .. } => {} + } + } + } + Diff::Map(_) | Diff::Tree(_) | Diff::Unknown => {} + #[cfg(feature = "counter")] + Diff::Counter(_) => {} + } + + Ok(()) +} + +fn checked_apply_diff_index_end(pos: usize, len: usize) -> LoroResult { + pos.checked_add(len).ok_or_else(|| LoroError::OutOfBound { + pos: usize::MAX, + len: usize::MAX, + info: "apply_diff".into(), + }) } impl From for DiffBatch { diff --git a/crates/loro/src/lib.rs b/crates/loro/src/lib.rs index f8eeff7e2..81fc79260 100644 --- a/crates/loro/src/lib.rs +++ b/crates/loro/src/lib.rs @@ -1467,6 +1467,7 @@ impl LoroDoc { /// Internally, it will apply the diff to the current state. #[inline] pub fn apply_diff(&self, diff: DiffBatch) -> LoroResult<()> { + diff.validate_for_apply()?; self.doc.apply_diff(diff.into()) } diff --git a/crates/loro/tests/contracts/awareness.rs b/crates/loro/tests/contracts/awareness.rs index cbfc0ae37..966d0cbf9 100644 --- a/crates/loro/tests/contracts/awareness.rs +++ b/crates/loro/tests/contracts/awareness.rs @@ -163,6 +163,15 @@ fn legacy_awareness_selected_sync_stale_updates_and_timeout_follow_contract() { assert!(added.is_empty()); } +#[allow(deprecated)] +#[test] +fn legacy_awareness_try_apply_rejects_invalid_payloads() { + let mut awareness = Awareness::new(1, 30_000); + + assert!(awareness.try_apply(&[0xff, 0xff, 0xff, 0xff]).is_err()); + assert!(awareness.get_all_states().is_empty()); +} + #[test] fn ephemeral_store_rejects_invalid_payloads_and_unsubscribes_false_callbacks() { let store = EphemeralStore::new(30_000); diff --git a/crates/loro/tests/contracts/container_handlers.rs b/crates/loro/tests/contracts/container_handlers.rs index fa5133a85..2cd3a0258 100644 --- a/crates/loro/tests/contracts/container_handlers.rs +++ b/crates/loro/tests/contracts/container_handlers.rs @@ -169,6 +169,58 @@ fn map_contracts_cover_iteration_lookup_and_root_hiding() -> LoroResult<()> { Ok(()) } +#[test] +fn imported_lazy_maps_iterate_in_key_order() -> LoroResult<()> { + fn insert_entries(map: &LoroMap, entries: &[&str]) -> LoroResult<()> { + for key in entries { + map.insert(*key, format!("{key}-value"))?; + } + Ok(()) + } + + fn assert_map_order(map: &LoroMap, expected: &[&str]) { + let keys = map.keys().map(|key| key.to_string()).collect::>(); + assert_eq!(keys, expected); + + let values = map + .values() + .map(|value| value.get_deep_value().to_json_value()) + .collect::>(); + assert_eq!( + values, + expected + .iter() + .map(|key| json!(format!("{key}-value"))) + .collect::>() + ); + + let mut entries = Vec::new(); + map.for_each(|key, value| { + entries.push((key.to_string(), value.get_deep_value().to_json_value())); + }); + assert_eq!( + entries, + expected + .iter() + .map(|key| (key.to_string(), json!(format!("{key}-value")))) + .collect::>() + ); + } + + let doc = LoroDoc::new(); + insert_entries(&doc.get_map("small"), &["z", "a", "m", "b"])?; + insert_entries(&doc.get_map("large"), &["z", "a", "m", "b", "q"])?; + doc.commit(); + + let snapshot = doc.export(ExportMode::Snapshot)?; + let restored = LoroDoc::from_snapshot(&snapshot)?; + + assert_map_order(&restored.get_map("small"), &["a", "b", "m", "z"]); + assert_map_order(&restored.get_map("large"), &["a", "b", "m", "q", "z"]); + + Ok(()) +} + #[test] fn list_contracts_cover_insert_delete_pop_clear_cursor_and_nested_containers() -> LoroResult<()> { let doc = LoroDoc::new(); diff --git a/crates/loro/tests/contracts/doc_analysis.rs b/crates/loro/tests/contracts/doc_analysis.rs index a9d1cce37..c5ac0e6fa 100644 --- a/crates/loro/tests/contracts/doc_analysis.rs +++ b/crates/loro/tests/contracts/doc_analysis.rs @@ -165,6 +165,17 @@ fn change_traversal_exposes_changed_containers_and_id_spans() -> anyhow::Result< let changed_second = doc.get_changed_containers_in(second_change.id, second_change.len); assert!(changed_second.contains(&text.id())); + let first_counter = first_change.id.counter; + let changed_from_negative = + doc.get_changed_containers_in(ID::new(first_change.id.peer, first_counter - 1), 2); + assert_eq!( + changed_from_negative, + doc.get_changed_containers_in(first_change.id, 1) + ); + assert!(doc + .get_changed_containers_in(ID::new(first_change.id.peer, i32::MAX), 1) + .is_empty()); + let between = doc.find_id_spans_between(&Frontiers::from_id(first), &Frontiers::from_id(third)); assert_eq!( between.forward.get(&103).map(|span| (span.start, span.end)), diff --git a/crates/loro/tests/contracts/doc_export.rs b/crates/loro/tests/contracts/doc_export.rs index 8dfc5ddd6..ec8657c2e 100644 --- a/crates/loro/tests/contracts/doc_export.rs +++ b/crates/loro/tests/contracts/doc_export.rs @@ -1,8 +1,8 @@ use std::sync::{Arc, Mutex}; use loro::{ - CommitOptions, ExportMode, IdSpan, Index, LoroDoc, LoroList, LoroMap, LoroText, Timestamp, - ToJson, TreeParentId, VersionVector, ID, + CommitOptions, ContainerID, ContainerType, ExportMode, IdSpan, Index, LoroDoc, LoroList, + LoroMap, LoroText, Timestamp, ToJson, TreeParentId, VersionVector, ID, }; use pretty_assertions::assert_eq; use serde_json::{json, Value}; @@ -38,6 +38,62 @@ fn deep_value_prefers_non_empty_root_when_same_name_has_empty_container() -> any Ok(()) } +#[test] +fn deleted_imported_root_containers_are_removed_from_snapshots() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.get_list("list"); + doc.get_map("map"); + doc.get_movable_list("movable"); + doc.get_text("text"); + doc.get_tree("tree"); + #[cfg(feature = "counter")] + doc.get_counter("counter"); + let restored = LoroDoc::from_snapshot(&doc.export(ExportMode::Snapshot)?)?; + restored.delete_root_container(ContainerID::new_root("list", ContainerType::List)); + restored.delete_root_container(ContainerID::new_root("map", ContainerType::Map)); + restored.delete_root_container(ContainerID::new_root("movable", ContainerType::MovableList)); + restored.delete_root_container(ContainerID::new_root("text", ContainerType::Text)); + restored.delete_root_container(ContainerID::new_root("tree", ContainerType::Tree)); + #[cfg(feature = "counter")] + restored.delete_root_container(ContainerID::new_root("counter", ContainerType::Counter)); + assert_eq!(deep_json(&restored), json!({})); + + let restored_again = LoroDoc::from_snapshot(&restored.export(ExportMode::Snapshot)?)?; + assert_eq!(deep_json(&restored_again), json!({})); + + Ok(()) +} + +#[test] +fn deleted_imported_non_empty_root_containers_are_removed_from_snapshots() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.get_list("list").push("item")?; + doc.get_map("map").insert("key", "value")?; + doc.get_movable_list("movable").push("item")?; + doc.get_text("text").insert(0, "hello")?; + let tree = doc.get_tree("tree"); + let node = tree.create(TreeParentId::Root)?; + tree.get_meta(node)?.insert("key", "value")?; + #[cfg(feature = "counter")] + doc.get_counter("counter").increment(1.5)?; + doc.commit(); + + let restored = LoroDoc::from_snapshot(&doc.export(ExportMode::Snapshot)?)?; + restored.delete_root_container(ContainerID::new_root("list", ContainerType::List)); + restored.delete_root_container(ContainerID::new_root("map", ContainerType::Map)); + restored.delete_root_container(ContainerID::new_root("movable", ContainerType::MovableList)); + restored.delete_root_container(ContainerID::new_root("text", ContainerType::Text)); + restored.delete_root_container(ContainerID::new_root("tree", ContainerType::Tree)); + #[cfg(feature = "counter")] + restored.delete_root_container(ContainerID::new_root("counter", ContainerType::Counter)); + assert_eq!(deep_json(&restored), json!({})); + + let restored_again = LoroDoc::from_snapshot(&restored.export(ExportMode::Snapshot)?)?; + assert_eq!(deep_json(&restored_again), json!({})); + + Ok(()) +} + #[test] fn commit_metadata_empty_commit_and_json_updates_roundtrip_follow_contract() -> anyhow::Result<()> { let doc = LoroDoc::new(); @@ -371,6 +427,51 @@ fn state_only_shallow_since_and_updates_till_cover_export_boundaries() -> anyhow Ok(()) } +#[test] +fn updates_in_range_clamps_negative_counter_ranges() -> anyhow::Result<()> { + let source = LoroDoc::new(); + source.set_peer_id(92)?; + source.set_change_merge_interval(0); + + let text = source.get_text("text"); + text.insert(0, "a")?; + source.commit(); + text.insert(1, "b")?; + source.commit(); + text.insert(2, "c")?; + source.commit(); + + let peer = source.peer_id(); + let updates = source.export(ExportMode::updates_in_range(vec![IdSpan::new( + peer, + i32::MIN, + 1, + )]))?; + let restored = LoroDoc::new(); + restored.import(&updates)?; + assert_eq!(restored.get_text("text").to_string(), "a"); + + let updates = source.export(ExportMode::updates_in_range(vec![IdSpan::new( + peer, + i32::MIN, + -1, + )]))?; + let restored = LoroDoc::new(); + restored.import(&updates)?; + assert_eq!(restored.get_text("text").to_string(), ""); + + let updates = source.export(ExportMode::updates_in_range(vec![IdSpan::new( + peer, + 1, + i32::MIN, + )]))?; + let restored = LoroDoc::new(); + restored.import(&updates)?; + assert_eq!(restored.get_text("text").to_string(), "ab"); + + Ok(()) +} + #[test] fn import_batch_reports_pending_until_dependencies_arrive() -> anyhow::Result<()> { let source = LoroDoc::new(); diff --git a/crates/loro/tests/contracts/handler_edges.rs b/crates/loro/tests/contracts/handler_edges.rs index 007d0f5e0..7cda85f39 100644 --- a/crates/loro/tests/contracts/handler_edges.rs +++ b/crates/loro/tests/contracts/handler_edges.rs @@ -1,10 +1,11 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashMap}; use loro::{ cursor::{PosType, Side}, - Container, ContainerTrait, ExpandType, Index, LoroDoc, LoroError, LoroList, LoroMap, - LoroMovableList, LoroResult, LoroText, LoroTree, LoroValue, StyleConfig, StyleConfigMap, - TextDelta, ToJson, TreeParentId, ValueOrContainer, + event::{Diff, DiffBatch, ListDiffItem}, + Container, ContainerID, ContainerTrait, ExpandType, Index, LoroDoc, LoroError, LoroList, + LoroMap, LoroMovableList, LoroResult, LoroText, LoroTree, LoroValue, StyleConfig, + StyleConfigMap, TextDelta, ToJson, TreeParentId, ValueOrContainer, }; use pretty_assertions::assert_eq; use serde_json::json; @@ -63,6 +64,180 @@ fn assert_container_deleted(result: LoroResult) { } } +fn value_with_nested_container(id: ContainerID) -> LoroValue { + LoroValue::from(HashMap::from([( + "nested".to_string(), + LoroValue::Container(id), + )])) +} + +fn assert_arg_error(result: LoroResult<()>) { + match result { + Err(LoroError::ArgErr(_)) => {} + other => panic!("expected ArgErr, got {other:?}"), + } +} + +fn assert_out_of_bound(result: LoroResult) { + match result { + Err(LoroError::OutOfBound { .. }) => {} + other => panic!("expected OutOfBound, got {other:?}"), + } +} + +#[test] +fn regular_value_writes_reject_nested_container_refs() -> LoroResult<()> { + let doc = LoroDoc::new(); + let root = doc.get_map("root"); + let child = root.insert_container("child", LoroMap::new())?; + let value = value_with_nested_container(child.id()); + + assert_arg_error(root.insert("bad", value.clone())); + + let list = doc.get_list("list"); + assert_arg_error(list.insert(0, value.clone())); + + let movable = doc.get_movable_list("movable"); + movable.insert(0, "seed")?; + assert_arg_error(movable.insert(1, value.clone())); + assert_arg_error(movable.set(0, value.clone())); + assert_arg_error(movable.set(0, LoroValue::Container(child.id()))); + + let text = doc.get_text("text"); + text.insert(0, "a")?; + assert_arg_error(text.mark(0..1, "bold", value.clone())); + + let detached_map = LoroMap::new(); + assert_arg_error(detached_map.insert("bad", value.clone())); + + let detached_list = LoroList::new(); + assert_arg_error(detached_list.insert(0, value.clone())); + + let detached_movable = LoroMovableList::new(); + detached_movable.insert(0, "seed")?; + assert_arg_error(detached_movable.insert(1, value.clone())); + assert_arg_error(detached_movable.set(0, value.clone())); + + let detached_text = LoroText::new(); + detached_text.insert(0, "a")?; + assert_arg_error(detached_text.mark(0..1, "bold", value)); + + Ok(()) +} + +#[test] +fn range_mutations_reject_overflowing_delete_lengths() -> LoroResult<()> { + let doc = LoroDoc::new(); + + let text = doc.get_text("text"); + text.insert(0, "abc")?; + assert_out_of_bound(text.delete(usize::MAX, 1)); + assert_out_of_bound(text.delete(1, usize::MAX)); + assert_out_of_bound(text.splice(usize::MAX, 1, "x")); + assert_out_of_bound(text.splice(1, usize::MAX, "x")); + + let list = doc.get_list("list"); + list.push(1)?; + assert_out_of_bound(list.delete(usize::MAX, 1)); + assert_out_of_bound(list.delete(1, usize::MAX)); + + let movable = doc.get_movable_list("movable"); + movable.push(1)?; + assert_out_of_bound(movable.delete(usize::MAX, 1)); + assert_out_of_bound(movable.delete(1, usize::MAX)); + + let detached_text = LoroText::new(); + detached_text.insert(0, "abc")?; + assert_out_of_bound(detached_text.delete(usize::MAX, 1)); + assert_out_of_bound(detached_text.delete(1, usize::MAX)); + assert_out_of_bound(detached_text.splice(usize::MAX, 1, "x")); + assert_out_of_bound(detached_text.splice(1, usize::MAX, "x")); + + let detached_list = LoroList::new(); + detached_list.push(1)?; + assert_out_of_bound(detached_list.delete(usize::MAX, 1)); + assert_out_of_bound(detached_list.delete(1, usize::MAX)); + + let detached_movable = LoroMovableList::new(); + detached_movable.push(1)?; + assert_out_of_bound(detached_movable.delete(usize::MAX, 1)); + assert_out_of_bound(detached_movable.delete(1, usize::MAX)); + + Ok(()) +} + +#[test] +fn text_apply_delta_rejects_overflowing_retain_positions() -> LoroResult<()> { + let doc = LoroDoc::new(); + let text = doc.get_text("text"); + text.insert(0, "a")?; + + assert_out_of_bound(text.apply_delta(&[ + TextDelta::Retain { + retain: usize::MAX, + attributes: None, + }, + TextDelta::Insert { + insert: "x".to_string(), + attributes: None, + }, + ])); + + assert_out_of_bound(text.apply_delta(&[ + TextDelta::Retain { + retain: usize::MAX, + attributes: None, + }, + TextDelta::Retain { + retain: 1, + attributes: Some([("bold".to_string(), true.into())].into_iter().collect()), + }, + ])); + + Ok(()) +} + +#[test] +fn doc_apply_diff_rejects_overflowing_text_and_list_positions() -> LoroResult<()> { + let doc = LoroDoc::new(); + let text = doc.get_text("text"); + text.insert(0, "a")?; + + let mut text_batch = DiffBatch::default(); + text_batch + .push( + text.id(), + Diff::Text(vec![ + TextDelta::Retain { + retain: usize::MAX, + attributes: None, + }, + TextDelta::Retain { + retain: 1, + attributes: None, + }, + ]), + ) + .unwrap(); + assert_out_of_bound(doc.apply_diff(text_batch)); + + let list = doc.get_list("list"); + list.push(1)?; + let mut list_batch = DiffBatch::default(); + list_batch + .push( + list.id(), + Diff::List(vec![ + ListDiffItem::Retain { retain: usize::MAX }, + ListDiffItem::Retain { retain: 1 }, + ]), + ) + .unwrap(); + assert_out_of_bound(doc.apply_diff(list_batch)); + + Ok(()) +} + #[test] fn detached_bundle_contracts_cover_child_handler_lookup_attachment_and_deletion() -> LoroResult<()> { diff --git a/crates/loro/tests/contracts/value_encoding.rs b/crates/loro/tests/contracts/value_encoding.rs index e2de1fd6a..cd6a99f59 100644 --- a/crates/loro/tests/contracts/value_encoding.rs +++ b/crates/loro/tests/contracts/value_encoding.rs @@ -1,11 +1,17 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{hash_map::DefaultHasher, HashMap}, + hash::{Hash, Hasher}, + sync::Arc, +}; use loro::{ - ContainerID, ContainerTrait, IdSpan, JsonListOp, JsonMapOp, JsonMovableListOp, JsonOpContent, - JsonTextOp, JsonTreeOp, LoroDoc, LoroList, LoroMap, LoroMovableList, LoroText, LoroTree, - LoroValue, ToJson, ValueOrContainer, VersionVector, + ContainerID, ContainerTrait, ContainerType, IdSpan, JsonListOp, JsonMapOp, JsonMovableListOp, + JsonOpContent, JsonTextOp, JsonTreeOp, LoroDoc, LoroList, LoroMap, LoroMapValue, + LoroMovableList, LoroText, LoroTree, LoroValue, ToJson, TreeID, ValueOrContainer, + VersionVector, ID, }; use pretty_assertions::assert_eq; +use rustc_hash::FxHashMap; use serde_json::{json, Value}; fn nested_value() -> LoroValue { @@ -29,6 +35,12 @@ fn value_json(value: &LoroValue) -> Value { value.to_json_value() } +fn value_hash(value: &LoroValue) -> u64 { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + hasher.finish() +} + fn nested_container(map: &LoroMap, key: &str) -> T { let value = map .get(key) @@ -151,11 +163,45 @@ fn loro_value_contracts_roundtrip_for_scalars_collections_and_containers() -> an serde_json::from_value::(json!(i64::MAX - 3))?, i64_value ); + let large_u64_value = LoroValue::Double(u64::MAX as f64); + assert_eq!( + serde_json::from_value::(json!(u64::MAX))?, + large_u64_value + ); + assert_eq!(LoroValue::from(json!(u64::MAX)), large_u64_value); assert_eq!(serde_json::to_value(&float_value)?, json!(-12.25)); assert_eq!( serde_json::from_value::(json!(-12.25))?, float_value ); + for non_finite in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + let non_finite_value = LoroValue::Double(non_finite); + assert_eq!(serde_json::to_value(&non_finite_value)?, Value::Null); + assert_eq!(value_json(&non_finite_value), Value::Null); + assert_eq!(Value::from(non_finite_value), Value::Null); + } + assert_eq!(LoroValue::Double(f64::NAN), LoroValue::Double(f64::NAN)); + assert_eq!(LoroValue::Double(0.0), LoroValue::Double(-0.0)); + assert_eq!( + value_hash(&LoroValue::Double(0.0)), + value_hash(&LoroValue::Double(-0.0)) + ); + let mut compact_map = FxHashMap::default(); + compact_map.insert("a".to_string(), 1_i64.into()); + compact_map.insert("b".to_string(), 2_i64.into()); + let mut sparse_map = FxHashMap::default(); + for i in 0..64 { + sparse_map.insert(format!("padding-{i}"), LoroValue::from(i)); + } + for i in 0..64 { + sparse_map.remove(&format!("padding-{i}")); + } + sparse_map.insert("b".to_string(), 2_i64.into()); + sparse_map.insert("a".to_string(), 1_i64.into()); + let compact_value = LoroValue::Map(LoroMapValue::from(compact_map)); + let sparse_value = LoroValue::Map(LoroMapValue::from(sparse_map)); + assert_eq!(compact_value, sparse_value); + assert_eq!(value_hash(&compact_value), value_hash(&sparse_value)); assert_eq!(serde_json::to_value(&string_value)?, json!("hello")); assert_eq!( serde_json::from_value::(json!("hello"))?, @@ -198,6 +244,8 @@ fn loro_value_contracts_roundtrip_for_scalars_collections_and_containers() -> an assert_eq!(bool::try_from(LoroValue::from(false)).unwrap(), false); assert_eq!(f64::try_from(LoroValue::from(1.5_f64)).unwrap(), 1.5); assert_eq!(i32::try_from(LoroValue::from(123_i64)).unwrap(), 123); + assert!(i32::try_from(LoroValue::from(i64::from(i32::MAX) + 1)).is_err()); + assert!(i32::try_from(LoroValue::from(i64::from(i32::MIN) - 1)).is_err()); assert_eq!( ContainerID::try_from(container_value.clone()).unwrap(), text.id() @@ -219,6 +267,9 @@ fn loro_value_contracts_roundtrip_for_scalars_collections_and_containers() -> an Some(&LoroValue::from(vec![4_i64, 5_i64])) ); assert_eq!(list_value.get_by_index(0), Some(&LoroValue::Null)); + assert_eq!(list_value.get_by_index(-7), None); + assert_eq!(list_value.get_by_index(isize::MIN), None); + assert_eq!(LoroValue::Null.get_by_index(-1), None); assert_eq!(list_value[5], LoroValue::from(vec![4_i64, 5_i64])); assert_eq!(map_value["missing"], LoroValue::Null); assert_eq!(list_value[99], LoroValue::Null); @@ -331,6 +382,149 @@ fn json_updates_roundtrip_nested_values_and_peer_compression() -> anyhow::Result Ok(()) } +#[test] +fn import_json_updates_accepts_reordered_op_fields() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(76)?; + let root = doc.get_map("root"); + root.insert("key", "value")?; + doc.commit(); + + let json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + let value = serde_json::to_value(json)?; + let change = &value["changes"][0]; + let op = &change["ops"][0]; + let raw_json = format!( + r#"{{"schema_version":{},"start_version":{},"peers":{},"changes":[{{"id":{},"timestamp":{},"deps":{},"lamport":{},"msg":{},"ops":[{{"counter":{},"content":{},"container":{}}}]}}]}}"#, + serde_json::to_string(&value["schema_version"])?, + serde_json::to_string(&value["start_version"])?, + serde_json::to_string(&value["peers"])?, + serde_json::to_string(&change["id"])?, + serde_json::to_string(&change["timestamp"])?, + serde_json::to_string(&change["deps"])?, + serde_json::to_string(&change["lamport"])?, + serde_json::to_string(&change["msg"])?, + serde_json::to_string(&op["counter"])?, + serde_json::to_string(&op["content"])?, + serde_json::to_string(&op["container"])?, + ); + + let imported = LoroDoc::new(); + imported.import_json_updates(raw_json)?; + + assert_eq!( + imported.get_map("root").get_deep_value().to_json_value(), + root.get_deep_value().to_json_value() + ); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_unsupported_schema_version() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.get_map("root").insert("key", "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + json.schema_version = 2; + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("schema version"), + "expected schema version validation error, got {err:?}" + ); + + Ok(()) +} + +#[test] +fn export_json_in_id_span_clamps_negative_counter_ranges() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(92)?; + doc.get_text("text").insert(0, "a")?; + doc.commit(); + + let changes = doc.export_json_in_id_span(IdSpan::new(92, i32::MIN, 1)); + assert_eq!(changes.len(), 1); + assert_eq!(changes[0].id.counter, 0); + + let changes = doc.export_json_in_id_span(IdSpan::new(92, i32::MIN, -1)); + assert!(changes.is_empty()); + + let changes = doc.export_json_in_id_span(IdSpan::new(92, 1, i32::MIN)); + assert_eq!(changes.len(), 1); + assert_eq!(changes[0].id.counter, 0); + + Ok(()) +} + +#[test] +fn export_json_updates_clamps_negative_version_ranges() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(93)?; + doc.get_text("text").insert(0, "a")?; + doc.commit(); + + let mut negative_start = VersionVector::new(); + negative_start.insert(doc.peer_id(), i32::MIN); + let json = doc.export_json_updates(&negative_start, &doc.oplog_vv()); + let restored = LoroDoc::new(); + restored.import_json_updates(json)?; + assert_eq!(restored.get_text("text").to_string(), "a"); + + let mut negative_end = VersionVector::new(); + negative_end.insert(doc.peer_id(), -1); + let json = doc.export_json_updates(&VersionVector::default(), &negative_end); + let restored = LoroDoc::new(); + restored.import_json_updates(json)?; + assert_eq!(restored.get_text("text").to_string(), ""); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_negative_op_counters() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.get_map("root").insert("key", "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + json.changes[0].id.counter = -1; + json.changes[0].ops[0].counter = -1; + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("counter"), + "expected counter validation error, got {err:?}" + ); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_negative_dependency_counters() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(77)?; + doc.get_map("root").insert("key", "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + json.changes[0].deps.push(ID::new(77, -1)); + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("counter"), + "expected counter validation error, got {err:?}" + ); + + Ok(()) +} + #[test] fn json_update_schema_covers_list_map_text_tree_and_movable_list_ops() -> anyhow::Result<()> { let doc = LoroDoc::new(); @@ -497,3 +691,241 @@ fn json_update_schema_covers_list_map_text_tree_and_movable_list_ops() -> anyhow Ok(()) } + +#[test] +fn import_json_updates_rejects_non_contiguous_op_counters() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(71)?; + doc.get_map("root").insert("key", "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + json.changes[0].ops[0].counter += 1; + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("op counter"), + "expected op counter validation error, got {err:?}" + ); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_mismatched_created_container_id() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(72)?; + doc.get_map("root") + .insert_container("child", LoroList::new())?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + for op in &mut json.changes[0].ops { + if let JsonOpContent::Map(JsonMapOp::Insert { + value: LoroValue::Container(id), + .. + }) = &mut op.content + { + *id = ContainerID::new_normal(ID::new(72, op.counter + 10), ContainerType::List); + break; + } + } + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("container id"), + "expected container id validation error, got {err:?}" + ); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_mismatched_list_created_container_ids() -> anyhow::Result<()> { + let list_doc = LoroDoc::new(); + list_doc.set_peer_id(74)?; + list_doc + .get_list("list") + .insert_container(0, LoroMap::new())?; + list_doc.commit(); + + let mut list_json = list_doc.export_json_updates_without_peer_compression( + &VersionVector::default(), + &list_doc.oplog_vv(), + ); + for op in &mut list_json.changes[0].ops { + if let JsonOpContent::List(JsonListOp::Insert { value, .. }) = &mut op.content { + if let Some(LoroValue::Container(id)) = value.first_mut() { + *id = ContainerID::new_normal(ID::new(74, op.counter + 10), ContainerType::Map); + break; + } + } + } + + let err = LoroDoc::new().import_json_updates(list_json).unwrap_err(); + assert!( + err.to_string().contains("container id"), + "expected list container id validation error, got {err:?}" + ); + + let movable_doc = LoroDoc::new(); + movable_doc.set_peer_id(75)?; + let movable = movable_doc.get_movable_list("movable"); + movable.insert(0, "seed")?; + movable.set_container(0, LoroText::new())?; + movable_doc.commit(); + + let mut movable_json = movable_doc.export_json_updates_without_peer_compression( + &VersionVector::default(), + &movable_doc.oplog_vv(), + ); + for op in &mut movable_json.changes[0].ops { + if let JsonOpContent::MovableList(JsonMovableListOp::Set { + value: LoroValue::Container(id), + .. + }) = &mut op.content + { + *id = ContainerID::new_normal(ID::new(75, op.counter + 10), ContainerType::Text); + break; + } + } + + let err = LoroDoc::new() + .import_json_updates(movable_json) + .unwrap_err(); + assert!( + err.to_string().contains("container id"), + "expected movable list container id validation error, got {err:?}" + ); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_tree_create_target_not_matching_op_id() -> anyhow::Result<()> { + let doc = LoroDoc::new(); + doc.set_peer_id(73)?; + doc.get_tree("tree").create(None)?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + for op in &mut json.changes[0].ops { + if let JsonOpContent::Tree(JsonTreeOp::Create { target, .. }) = &mut op.content { + *target = TreeID { + peer: 73, + counter: op.counter + 10, + }; + break; + } + } + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("tree target"), + "expected tree target validation error, got {err:?}" + ); + + Ok(()) +} + +#[test] +fn import_json_updates_rejects_nested_container_values() -> anyhow::Result<()> { + fn nested_container_value(peer: u64, counter: i32) -> LoroValue { + LoroValue::from(HashMap::from([( + "nested".to_string(), + LoroValue::Container(ContainerID::new_normal( + ID::new(peer, counter), + ContainerType::Map, + )), + )])) + } + + let doc = LoroDoc::new(); + doc.set_peer_id(76)?; + doc.get_map("root").insert("key", "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + for op in &mut json.changes[0].ops { + if let JsonOpContent::Map(JsonMapOp::Insert { value, .. }) = &mut op.content { + *value = nested_container_value(76, op.counter); + break; + } + } + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("container"), + "expected nested container validation error, got {err:?}" + ); + + let doc = LoroDoc::new(); + doc.set_peer_id(77)?; + doc.get_list("list").insert(0, "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + for op in &mut json.changes[0].ops { + if let JsonOpContent::List(JsonListOp::Insert { value, .. }) = &mut op.content { + value[0] = nested_container_value(77, op.counter); + break; + } + } + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("container"), + "expected nested list container validation error, got {err:?}" + ); + + let doc = LoroDoc::new(); + doc.set_peer_id(78)?; + let text = doc.get_text("text"); + text.insert(0, "a")?; + text.mark(0..1, "bold", true)?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + for op in &mut json.changes[0].ops { + if let JsonOpContent::Text(JsonTextOp::Mark { style_value, .. }) = &mut op.content { + *style_value = nested_container_value(78, op.counter); + break; + } + } + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("container"), + "expected nested text style container validation error, got {err:?}" + ); + + let doc = LoroDoc::new(); + doc.set_peer_id(79)?; + let movable = doc.get_movable_list("movable"); + movable.insert(0, "seed")?; + movable.set(0, "value")?; + doc.commit(); + + let mut json = doc + .export_json_updates_without_peer_compression(&VersionVector::default(), &doc.oplog_vv()); + for op in &mut json.changes[0].ops { + if let JsonOpContent::MovableList(JsonMovableListOp::Set { value, .. }) = &mut op.content { + *value = nested_container_value(79, op.counter); + break; + } + } + + let err = LoroDoc::new().import_json_updates(json).unwrap_err(); + assert!( + err.to_string().contains("container"), + "expected nested movable list container validation error, got {err:?}" + ); + + Ok(()) +} diff --git a/crates/loro/tests/contracts/version_frontiers.rs b/crates/loro/tests/contracts/version_frontiers.rs index ddf99ee5d..baaa2496a 100644 --- a/crates/loro/tests/contracts/version_frontiers.rs +++ b/crates/loro/tests/contracts/version_frontiers.rs @@ -110,6 +110,32 @@ fn version_vector_contracts_follow_semantics() -> anyhow::Result<()> { assert_eq!(diff_left.sub_vec(&diff_right), span_map(&[(1, (1, 3))])); assert_eq!(diff_left.distance_between(&diff_right), 3); assert_eq!(zero_entry.distance_between(&diff_left), 4); + let negative = vv_pairs(&[(11, i32::MIN)]); + let negative_small = vv_pairs(&[(11, -1)]); + let max_counter = vv_pairs(&[(11, i32::MAX)]); + assert_eq!(negative, VersionVector::new()); + assert_eq!(negative.partial_cmp(&negative_small), Some(Ordering::Equal)); + assert_eq!( + negative.diff(&VersionVector::new()), + VersionVectorDiff::default() + ); + assert_eq!(negative.to_spans(), IdSpanVector::default()); + assert_eq!( + sorted_spans(max_counter.sub_iter(&negative)), + vec![(11, 0, i32::MAX)] + ); + assert_eq!( + sorted_spans(negative.sub_iter(&VersionVector::new())), + Vec::<(u64, i32, i32)>::new() + ); + assert_eq!( + VersionVector::new().get_missing_span(&negative), + Vec::::new() + ); + assert_eq!(negative.distance_between(&VersionVector::new()), 0); + assert_eq!(VersionVector::new().distance_between(&negative), 0); + assert_eq!(negative.distance_between(&max_counter), i32::MAX as usize); + assert_eq!(VersionRange::from_vv(&negative), VersionRange::new()); assert_eq!(diff_left.to_spans(), span_map(&[(1, (0, 3)), (2, (0, 1))])); assert_eq!( diff_left.get_frontiers(), @@ -141,6 +167,10 @@ fn version_vector_contracts_follow_semantics() -> anyhow::Result<()> { assert_eq!(adjust.get(&5), Some(&3)); adjust.set_end(ID::new(5, 0)); assert!(!adjust.contains_key(&5)); + assert!(!adjust.try_update_last(ID::new(6, -1))); + assert!(!adjust.contains_key(&6)); + adjust.set_last(ID::new(6, i32::MAX)); + assert_eq!(adjust.get(&6), Some(&i32::MAX)); let mut span_ops = vv_pairs(&[(10, 2)]); span_ops.extend_to_include_last_id(ID::new(10, 3)); @@ -152,6 +182,15 @@ fn version_vector_contracts_follow_semantics() -> anyhow::Result<()> { span_ops.shrink_to_exclude(span(10, 0, 2)); assert!(!span_ops.contains_key(&10)); + let mut invalid_spans = VersionVector::new(); + invalid_spans.extend_to_include(span(12, -3, -1)); + assert!(!invalid_spans.contains_key(&12)); + invalid_spans.set_end(ID::new(12, 5)); + invalid_spans.shrink_to_exclude(span(12, -3, -1)); + assert_eq!(invalid_spans.get(&12), Some(&5)); + invalid_spans.shrink_to_exclude(span(12, -3, 2)); + assert!(!invalid_spans.contains_key(&12)); + let mut span_ops = VersionVector::new(); span_ops.forward(&{ let mut spans = IdSpanVector::default(); @@ -195,6 +234,10 @@ fn version_vector_contracts_follow_semantics() -> anyhow::Result<()> { assert!(im2.contains_key(&3)); im2.set_last(ID::new(4, 1)); assert_eq!(im2.to_vv().get(&4), Some(&2)); + im2.set_last(ID::new(4, -1)); + assert!(!im2.contains_key(&4)); + im2.set_last(ID::new(4, i32::MAX)); + assert_eq!(im2.get(&4), Some(&i32::MAX)); let im_encoded = im2.encode(); assert_eq!(ImVersionVector::decode(&im_encoded)?, im2); im2.clear(); @@ -347,6 +390,7 @@ fn frontiers_contracts_follow_semantics() -> anyhow::Result<()> { assert!(!doc_frontiers.is_empty()); let vv = doc.frontiers_to_vv(&doc_frontiers).unwrap(); assert_eq!(doc.vv_to_frontiers(&vv), doc_frontiers); + assert_eq!(doc.vv_to_frontiers(&vv_pairs(&[(77, -1)])), Frontiers::None); assert_eq!( doc.frontiers_to_vv(&Frontiers::None), Some(VersionVector::new()) diff --git a/crates/loro/tests/integration_test/redact_test.rs b/crates/loro/tests/integration_test/redact_test.rs index 95239af8c..1d5601576 100644 --- a/crates/loro/tests/integration_test/redact_test.rs +++ b/crates/loro/tests/integration_test/redact_test.rs @@ -1,6 +1,7 @@ -use loro::json::redact; +use loro::json::{redact, RedactError}; use loro::{LoroDoc, LoroList, LoroMovableList, LoroTree, LoroValue}; use loro_internal::version::VersionRange; +use std::panic::{catch_unwind, AssertUnwindSafe}; #[test] fn redact_text_doc() { @@ -25,6 +26,29 @@ fn redact_text_doc() { assert_ne!(text.to_string(), redacted_text.to_string()); } +#[test] +fn redact_rejects_overflowing_json_counters_without_panicking() { + let doc = LoroDoc::new(); + doc.set_peer_id(1).unwrap(); + let text = doc.get_text("text"); + text.insert(0, "secret").unwrap(); + + let mut json = doc.export_json_updates(&Default::default(), &doc.oplog_vv()); + let change = &mut json.changes[0]; + change.id.counter = i32::MAX; + change.ops[0].counter = i32::MAX; + + let mut range = VersionRange::new(); + range.insert(1, 0, i32::MAX); + let result = catch_unwind(AssertUnwindSafe(|| redact(&mut json, range))); + + assert!(result.is_ok()); + assert!(matches!( + result.unwrap(), + Err(RedactError::InvalidSchema(_)) + )); +} + #[test] fn redact_map_list_insertions() { let doc = LoroDoc::new();