diff --git a/encryption/cache.go b/encryption/cache.go
new file mode 100644
index 0000000..54ed107
--- /dev/null
+++ b/encryption/cache.go
@@ -0,0 +1,179 @@
+package encryption
+
+import (
+	"container/list"
+	"sync"
+	"time"
+)
+
+// CacheConfig configures the optional DEK (data encryption key) cache.
+// A zero-value config disables caching.
+type CacheConfig struct {
+	// MaxSize is the maximum number of DEKs to cache. Must be > 0 to enable caching.
+	MaxSize int
+	// TTL is the time-to-live for each cache entry. Must be > 0 to enable caching.
+	TTL time.Duration
+}
+
+// dekCacheEntry holds a cached DEK and its LRU/TTL metadata.
+type dekCacheEntry struct {
+	dek       []byte // 32-byte AES-256 key (owned copy)
+	keyRef    string // for reverse lookup from LRU list element
+	expiresAt time.Time
+	element   *list.Element // back-pointer into LRU list
+}
+
+// dekCache is a thread-safe LRU cache for decrypted data encryption keys.
+// It zeroes key material on every eviction path (TTL, LRU, delete, clear).
+type dekCache struct {
+	mu      sync.Mutex
+	entries map[string]*dekCacheEntry
+	order   *list.List // front = most recently used
+	maxSize int
+	ttl     time.Duration
+
+	inflightMu sync.Mutex
+	inflight   map[string]*inflightEntry
+}
+
+// inflightEntry coordinates singleflight deduplication for concurrent
+// cache misses on the same keyRef.
+type inflightEntry struct {
+	done chan struct{}
+	dek  []byte
+	err  error
+}
+
+func newDEKCache(maxSize int, ttl time.Duration) *dekCache {
+	return &dekCache{
+		entries:  make(map[string]*dekCacheEntry),
+		order:    list.New(),
+		maxSize:  maxSize,
+		ttl:      ttl,
+		inflight: make(map[string]*inflightEntry),
+	}
+}
+
+// get returns a copy of the cached DEK for keyRef, or ok=false on miss/expiry.
+func (c *dekCache) get(keyRef string) ([]byte, bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	entry, ok := c.entries[keyRef]
+	if !ok {
+		return nil, false
+	}
+
+	if time.Now().After(entry.expiresAt) {
+		c.evictLocked(entry)
+		return nil, false
+	}
+
+	c.order.MoveToFront(entry.element)
+	return copyBytes(entry.dek), true
+}
+
+// put stores a copy of dek in the cache, evicting the LRU entry if full.
+func (c *dekCache) put(keyRef string, dek []byte) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if entry, ok := c.entries[keyRef]; ok {
+		// Update existing entry.
+		clear(entry.dek)
+		entry.dek = copyBytes(dek)
+		entry.expiresAt = time.Now().Add(c.ttl)
+		c.order.MoveToFront(entry.element)
+		return
+	}
+
+	entry := &dekCacheEntry{
+		dek:       copyBytes(dek),
+		keyRef:    keyRef,
+		expiresAt: time.Now().Add(c.ttl),
+	}
+	entry.element = c.order.PushFront(entry)
+	c.entries[keyRef] = entry
+
+	if len(c.entries) > c.maxSize {
+		back := c.order.Back()
+		if back != nil {
+			c.evictLocked(back.Value.(*dekCacheEntry))
+		}
+	}
+}
+
+// delete removes and zeroes a specific entry. Called by RotateKey.
+func (c *dekCache) delete(keyRef string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if entry, ok := c.entries[keyRef]; ok {
+		c.evictLocked(entry)
+	}
+}
+
+// clear removes and zeroes all entries.
+func (c *dekCache) clear() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for _, entry := range c.entries {
+		clear(entry.dek)
+	}
+	c.entries = make(map[string]*dekCacheEntry)
+	c.order.Init()
+}
+
+// evictLocked removes an entry, zeroing its DEK. Caller must hold c.mu.
+func (c *dekCache) evictLocked(entry *dekCacheEntry) {
+	clear(entry.dek)
+	c.order.Remove(entry.element)
+	delete(c.entries, entry.keyRef)
+}
+
+// waitOrStart implements singleflight deduplication. If another goroutine is
+// already fetching the DEK for keyRef, started=false and wait blocks until the
+// result is available. Otherwise started=true and the caller must call finish.
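+//
+// A sketch of the intended call pattern (Pool.fetchDEK in pool.go follows this
+// shape; expensiveFetch is a hypothetical stand-in for the real DynamoDB/KMS
+// path):
+//
+//	started, wait := c.waitOrStart(keyRef)
+//	if !started {
+//		return wait() // block on the in-flight fetch; receive a copy
+//	}
+//	dek, err := expensiveFetch(keyRef)
+//	c.finish(keyRef, dek, err) // wake all waiters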
+func (c *dekCache) waitOrStart(keyRef string) (started bool, wait func() ([]byte, error)) {
+	c.inflightMu.Lock()
+
+	if entry, ok := c.inflight[keyRef]; ok {
+		c.inflightMu.Unlock()
+		return false, func() ([]byte, error) {
+			<-entry.done
+			if entry.err != nil {
+				return nil, entry.err
+			}
+			return copyBytes(entry.dek), nil
+		}
+	}
+
+	entry := &inflightEntry{done: make(chan struct{})}
+	c.inflight[keyRef] = entry
+	c.inflightMu.Unlock()
+
+	return true, nil
+}
+
+// finish signals all waiters for keyRef with the fetch result.
+func (c *dekCache) finish(keyRef string, dek []byte, err error) {
+	c.inflightMu.Lock()
+	entry, ok := c.inflight[keyRef]
+	if ok {
+		entry.dek = copyBytes(dek)
+		entry.err = err
+		close(entry.done)
+		delete(c.inflight, keyRef)
+	}
+	c.inflightMu.Unlock()
+}
+
+func copyBytes(b []byte) []byte {
+	if b == nil {
+		return nil
+	}
+	cp := make([]byte, len(b))
+	copy(cp, b)
+	return cp
+}
diff --git a/encryption/cache_test.go b/encryption/cache_test.go
new file mode 100644
index 0000000..4de3f05
--- /dev/null
+++ b/encryption/cache_test.go
@@ -0,0 +1,384 @@
+package encryption
+
+import (
+	"sync"
+	"testing"
+	"time"
+)
+
+func TestDEKCache_GetPut(t *testing.T) {
+	c := newDEKCache(10, time.Minute)
+	dek := []byte("0123456789abcdef0123456789abcdef")
+
+	c.put("ref1", dek)
+
+	got, ok := c.get("ref1")
+	if !ok {
+		t.Fatal("expected cache hit")
+	}
+	if string(got) != string(dek) {
+		t.Fatalf("got %x, want %x", got, dek)
+	}
+
+	// Returned slice must be an independent copy.
+	got[0] = 0xFF
+	got2, ok := c.get("ref1")
+	if !ok {
+		t.Fatal("expected cache hit")
+	}
+	if got2[0] == 0xFF {
+		t.Fatal("cache returned same underlying slice, expected independent copy")
+	}
+
+	// Stored slice must be an independent copy of the input.
+	dek[0] = 0xAA
+	got3, ok := c.get("ref1")
+	if !ok {
+		t.Fatal("expected cache hit")
+	}
+	if got3[0] == 0xAA {
+		t.Fatal("cache stored same underlying slice as input, expected independent copy")
+	}
+}
+
+func TestDEKCache_Miss(t *testing.T) {
+	c := newDEKCache(10, time.Minute)
+
+	_, ok := c.get("unknown")
+	if ok {
+		t.Fatal("expected cache miss for unknown key")
+	}
+}
+
+func TestDEKCache_TTLExpiry(t *testing.T) {
+	c := newDEKCache(10, time.Millisecond)
+	dek := []byte("0123456789abcdef0123456789abcdef")
+
+	c.put("ref1", dek)
+
+	// Grab internal slice before expiry.
+	c.mu.Lock()
+	internalSlice := c.entries["ref1"].dek
+	c.mu.Unlock()
+
+	time.Sleep(5 * time.Millisecond)
+
+	_, ok := c.get("ref1")
+	if ok {
+		t.Fatal("expected cache miss after TTL expiry")
+	}
+
+	// Expired entry should be removed from map.
+	c.mu.Lock()
+	_, exists := c.entries["ref1"]
+	c.mu.Unlock()
+	if exists {
+		t.Fatal("expired entry should have been removed from map")
+	}
+
+	// Internal DEK should be zeroed.
+	for _, b := range internalSlice {
+		if b != 0 {
+			t.Fatal("expired DEK should have been zeroed")
+		}
+	}
+}
+
+func TestDEKCache_LRUEviction(t *testing.T) {
+	c := newDEKCache(2, time.Minute)
+
+	c.put("ref1", []byte("key1key1key1key1key1key1key1key1"))
+	c.put("ref2", []byte("key2key2key2key2key2key2key2key2"))
+
+	// Grab internal slice of ref2 before eviction.
+	c.mu.Lock()
+	internalRef2 := c.entries["ref2"].dek
+	c.mu.Unlock()
+
+	// Access ref1 to make it more recent than ref2.
+	_, _ = c.get("ref1")
+
+	// Adding ref3 should evict ref2 (LRU).
+ c.put("ref3", []byte("key3key3key3key3key3key3key3key3")) + + if _, ok := c.get("ref2"); ok { + t.Fatal("expected ref2 to be evicted (LRU)") + } + if _, ok := c.get("ref1"); !ok { + t.Fatal("expected ref1 to still be cached") + } + if _, ok := c.get("ref3"); !ok { + t.Fatal("expected ref3 to still be cached") + } + + // Verify evicted internal DEK was zeroed. + for _, b := range internalRef2 { + if b != 0 { + t.Fatal("evicted DEK should have been zeroed") + } + } +} + +func TestDEKCache_PutUpdatesExisting(t *testing.T) { + c := newDEKCache(10, time.Minute) + dek1 := []byte("old_key_old_key_old_key_old_key_") + dek2 := []byte("new_key_new_key_new_key_new_key_") + + c.put("ref1", dek1) + + // Grab internal slice and expiry before update. + c.mu.Lock() + internalSlice := c.entries["ref1"].dek + oldExpiry := c.entries["ref1"].expiresAt + c.mu.Unlock() + + time.Sleep(time.Millisecond) // ensure time advances + c.put("ref1", dek2) + + got, ok := c.get("ref1") + if !ok { + t.Fatal("expected cache hit") + } + if string(got) != string(dek2) { + t.Fatalf("got %x, want %x", got, dek2) + } + + // Old internal slice should be zeroed. + for _, b := range internalSlice { + if b != 0 { + t.Fatal("old DEK slice should have been zeroed") + } + } + + // TTL should be refreshed. + c.mu.Lock() + newExpiry := c.entries["ref1"].expiresAt + c.mu.Unlock() + if !newExpiry.After(oldExpiry) { + t.Fatal("put on existing key should refresh TTL") + } +} + +func TestDEKCache_Delete(t *testing.T) { + c := newDEKCache(10, time.Minute) + dek := []byte("0123456789abcdef0123456789abcdef") + + c.put("ref1", dek) + + // Grab internal slice reference. + c.mu.Lock() + internalSlice := c.entries["ref1"].dek + c.mu.Unlock() + + c.delete("ref1") + + if _, ok := c.get("ref1"); ok { + t.Fatal("expected cache miss after delete") + } + + // Verify zeroed. + for _, b := range internalSlice { + if b != 0 { + t.Fatal("deleted DEK should have been zeroed") + } + } +} + +func TestDEKCache_DeleteMissing(t *testing.T) { + c := newDEKCache(10, time.Minute) + // Should not panic. + c.delete("nonexistent") +} + +func TestDEKCache_Clear(t *testing.T) { + c := newDEKCache(10, time.Minute) + + dek1 := []byte("key1key1key1key1key1key1key1key1") + dek2 := []byte("key2key2key2key2key2key2key2key2") + c.put("ref1", dek1) + c.put("ref2", dek2) + + // Grab internal slice references. + c.mu.Lock() + internal1 := c.entries["ref1"].dek + internal2 := c.entries["ref2"].dek + c.mu.Unlock() + + c.clear() + + if _, ok := c.get("ref1"); ok { + t.Fatal("expected miss after clear") + } + if _, ok := c.get("ref2"); ok { + t.Fatal("expected miss after clear") + } + + for _, b := range internal1 { + if b != 0 { + t.Fatal("cleared DEK 1 should be zeroed") + } + } + for _, b := range internal2 { + if b != 0 { + t.Fatal("cleared DEK 2 should be zeroed") + } + } +} + +func TestDEKCache_ClearThenReuse(t *testing.T) { + c := newDEKCache(10, time.Minute) + dek := []byte("0123456789abcdef0123456789abcdef") + + c.put("ref1", dek) + c.clear() + + // Cache should work normally after clear. + c.put("ref1", dek) + got, ok := c.get("ref1") + if !ok { + t.Fatal("expected cache hit after clear + put") + } + if string(got) != string(dek) { + t.Fatalf("got %x, want %x", got, dek) + } +} + +func TestDEKCache_Singleflight(t *testing.T) { + c := newDEKCache(10, time.Minute) + + dek := []byte("0123456789abcdef0123456789abcdef") + + // First caller starts the fetch. 
+func TestDEKCache_Singleflight(t *testing.T) {
+	c := newDEKCache(10, time.Minute)
+
+	dek := []byte("0123456789abcdef0123456789abcdef")
+
+	// First caller starts the fetch.
+	started, _ := c.waitOrStart("ref1")
+	if !started {
+		t.Fatal("first caller should start")
+	}
+
+	// Second and third callers should wait.
+	var wg sync.WaitGroup
+	results := make([][]byte, 2)
+	errs := make([]error, 2)
+
+	for i := 0; i < 2; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			s, wait := c.waitOrStart("ref1")
+			if s {
+				t.Error("subsequent caller should not start")
+				return
+			}
+			results[idx], errs[idx] = wait()
+		}(i)
+	}
+
+	// Simulate fetch completing.
+	time.Sleep(10 * time.Millisecond) // let goroutines reach wait()
+	c.finish("ref1", dek, nil)
+
+	wg.Wait()
+
+	for i := 0; i < 2; i++ {
+		if errs[i] != nil {
+			t.Fatalf("waiter %d got error: %v", i, errs[i])
+		}
+		if string(results[i]) != string(dek) {
+			t.Fatalf("waiter %d got wrong dek", i)
+		}
+	}
+
+	// Each waiter should have received an independent copy.
+	results[0][0] = 0xFF
+	if results[1][0] == 0xFF {
+		t.Fatal("waiters should receive independent copies")
+	}
+}
+
+func TestDEKCache_SingleflightError(t *testing.T) {
+	c := newDEKCache(10, time.Minute)
+	fetchErr := &testError{msg: "kms failed"}
+
+	started, _ := c.waitOrStart("ref1")
+	if !started {
+		t.Fatal("first caller should start")
+	}
+
+	var wg sync.WaitGroup
+	waiterErrs := make([]error, 2)
+
+	for i := 0; i < 2; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			_, wait := c.waitOrStart("ref1")
+			_, waiterErrs[idx] = wait()
+		}(i)
+	}
+
+	time.Sleep(10 * time.Millisecond)
+	c.finish("ref1", nil, fetchErr)
+
+	wg.Wait()
+
+	for i := 0; i < 2; i++ {
+		if waiterErrs[i] == nil {
+			t.Fatalf("waiter %d should have received error", i)
+		}
+		if waiterErrs[i].Error() != "kms failed" {
+			t.Fatalf("waiter %d got error %q, want %q", i, waiterErrs[i].Error(), "kms failed")
+		}
+	}
+}
+
+func TestDEKCache_SingleflightIndependentKeys(t *testing.T) {
+	c := newDEKCache(10, time.Minute)
+
+	dek1 := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+	dek2 := []byte("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
+
+	// Start fetch for ref1.
+	started1, _ := c.waitOrStart("ref1")
+	if !started1 {
+		t.Fatal("first caller for ref1 should start")
+	}
+
+	// Start fetch for ref2 — should NOT be blocked by ref1.
+	started2, _ := c.waitOrStart("ref2")
+	if !started2 {
+		t.Fatal("first caller for ref2 should start independently")
+	}
+
+	c.finish("ref1", dek1, nil)
+	c.finish("ref2", dek2, nil)
+}
+
+func TestDEKCache_ConcurrentAccess(t *testing.T) {
+	c := newDEKCache(10, time.Minute)
+	dek := []byte("0123456789abcdef0123456789abcdef")
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(3)
+		go func(n int) {
+			defer wg.Done()
+			c.put("ref1", dek)
+		}(i)
+		go func(n int) {
+			defer wg.Done()
+			c.get("ref1")
+		}(i)
+		go func(n int) {
+			defer wg.Done()
+			if n%10 == 0 {
+				c.delete("ref1")
+			}
+		}(i)
+	}
+	wg.Wait()
+}
+
+type testError struct {
+	msg string
+}
+
+func (e *testError) Error() string { return e.msg }
diff --git a/encryption/pool.go b/encryption/pool.go
index 4a9284d..a750ab5 100644
--- a/encryption/pool.go
+++ b/encryption/pool.go
@@ -42,19 +42,38 @@ type Pool struct {
 	keysTable  KeysTable
 	dataTables []EncryptedDataTable
 	logger     *slog.Logger
+	cache      *dekCache // nil when caching disabled
 }
 
-func NewPool(attester Attester, configs []*Config, keysTable KeysTable, dataTables []EncryptedDataTable, logger *slog.Logger) *Pool {
+// PoolOption configures optional Pool behavior.
+type PoolOption func(*Pool)
+
+// WithCache enables an in-memory LRU cache for decrypted data encryption keys,
+// eliminating KMS round-trips on cache hits. The cache is local to this process
+// and zeroes key material on eviction.
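+//
+// A minimal wiring sketch (attester, configs, and the table values are assumed
+// to come from the caller's context):
+//
+//	pool := NewPool(attester, configs, keysTable, dataTables, logger,
+//		WithCache(CacheConfig{MaxSize: 1024, TTL: 5 * time.Minute}))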
+func WithCache(cfg CacheConfig) PoolOption {
+	return func(p *Pool) {
+		if cfg.MaxSize > 0 && cfg.TTL > 0 {
+			p.cache = newDEKCache(cfg.MaxSize, cfg.TTL)
+		}
+	}
+}
+
+func NewPool(attester Attester, configs []*Config, keysTable KeysTable, dataTables []EncryptedDataTable, logger *slog.Logger, opts ...PoolOption) *Pool {
 	if logger == nil {
 		logger = slog.Default()
 	}
-	return &Pool{
+	p := &Pool{
 		attester:   attester,
 		configs:    configs,
 		keysTable:  keysTable,
 		dataTables: dataTables,
 		logger:     logger,
 	}
+	for _, opt := range opts {
+		opt(p)
+	}
+	return p
 }
 
 // Encrypt encrypts the plaintext using a randomly selected cipher key from the Pool. It returns the key reference
@@ -95,17 +114,32 @@ func (p *Pool) Encrypt(ctx context.Context, att *enclave.Attestation, plaintext
 		if err != nil {
 			return "", nil, fmt.Errorf("generate key: %w", err)
 		}
-	} else if err := p.VerifyKey(ctx, att, key); err != nil {
-		return "", nil, fmt.Errorf("verify key: %w", err)
-	}
-	span.SetAnnotation("key_ref", key.KeyRef)
-	if privateKey == nil {
-		privateKey, err = p.combineShares(ctx, att, config, key.EncryptedShares)
-		if err != nil {
-			return "", nil, fmt.Errorf("combine shares: %w", err)
+		if p.cache != nil {
+			p.cache.put(key.KeyRef, privateKey)
+		}
+	} else {
+		// Existing key — try cache before KMS.
+		if p.cache != nil {
+			if dek, ok := p.cache.get(key.KeyRef); ok {
+				privateKey = dek
+			}
+		}
+		if privateKey == nil {
+			if err := p.VerifyKey(ctx, att, key); err != nil {
+				return "", nil, fmt.Errorf("verify key: %w", err)
+			}
+			privateKey, err = p.combineShares(ctx, att, config, key.EncryptedShares)
+			if err != nil {
+				return "", nil, fmt.Errorf("combine shares: %w", err)
+			}
+
+			if p.cache != nil {
+				p.cache.put(key.KeyRef, privateKey)
+			}
 		}
 	}
+	span.SetAnnotation("key_ref", key.KeyRef)
 
 	encrypted, err := aesgcm.Encrypt(att, privateKey, plaintext, additionalData)
 	if err != nil {
@@ -127,6 +161,7 @@
 // Decrypt decrypts the ciphertext using the latest cipher key from the Pool referenced by the keyRef.
 //
 // The key is verified against the attestation and migrated to the current generation if needed.
+// If a DEK cache is configured, a cache hit skips the DynamoDB lookup, key verification, and KMS entirely.
 func (p *Pool) Decrypt(ctx context.Context, att *enclave.Attestation, keyRef string, ciphertext []byte, additionalData []byte) (plaintext []byte, err error) {
 	ctx, span := tracing.Trace(ctx, "encryption.Pool.Decrypt", tracing.WithAnnotation("key_ref", keyRef))
 	defer func() {
@@ -139,6 +174,50 @@ func (p *Pool) Decrypt(ctx context.Context, att *enclave.Attestation, keyRef str
 		return nil, fmt.Errorf("decode ciphertext: %w", err)
 	}
 
+	// Try cache.
+	var privateKey []byte
+	if p.cache != nil {
+		if dek, ok := p.cache.get(keyRef); ok {
+			privateKey = dek
+		}
+	}
+
+	// Cache miss — full fetch with singleflight dedup.
+	if privateKey == nil {
+		privateKey, err = p.fetchDEK(ctx, att, keyRef)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	// Decrypt data.
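+	// Version 1 payloads were written with AES-CBC; versions 2 and 3 use
+	// AES-GCM, which also authenticates additionalData.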
+	var decrypted []byte
+	switch decoded.Version {
+	case 1:
+		decrypted, err = aescbc.Decrypt(privateKey, decoded.EncryptedData)
+	case 2, 3:
+		decrypted, err = aesgcm.Decrypt(privateKey, decoded.EncryptedData, additionalData)
+	}
+	if err != nil {
+		return nil, fmt.Errorf("decrypt: %w", err)
+	}
+
+	return decrypted, nil
+}
+
+// fetchDEK retrieves a DEK through the full path: DynamoDB lookup, attestation
+// verification, KMS share decryption, and Shamir combine. It uses singleflight
+// to deduplicate concurrent fetches for the same keyRef, and populates the cache.
+func (p *Pool) fetchDEK(ctx context.Context, att *enclave.Attestation, keyRef string) (privateKey []byte, err error) {
+	// Singleflight: if another goroutine is already fetching this keyRef, wait.
+	if p.cache != nil {
+		started, wait := p.cache.waitOrStart(keyRef)
+		if !started {
+			return wait()
+		}
+		defer func() { p.cache.finish(keyRef, privateKey, err) }()
+
+		// Re-check the cache after winning the singleflight slot: another
+		// fetch may have completed between the caller's cache miss and the
+		// waitOrStart call above. Without this re-check, that window would
+		// trigger a duplicate KMS fetch.
+		if dek, ok := p.cache.get(keyRef); ok {
+			return dek, nil
+		}
+	}
+
 	key, found, err := p.keysTable.GetLatestByKeyRef(ctx, keyRef, false)
 	if err != nil {
 		return nil, fmt.Errorf("get latest key: %w", err)
@@ -150,11 +229,6 @@
 		return nil, fmt.Errorf("verify key: %w", err)
 	}
 
-	span.SetAnnotation("generation", strconv.Itoa(key.Generation))
-	if key.KeyIndex != nil {
-		span.SetAnnotation("key_index", strconv.Itoa(*key.KeyIndex))
-	}
-
 	config, err := p.getConfig(key.Generation)
 	if err != nil {
 		return nil, fmt.Errorf("get config: %w", err)
@@ -163,31 +237,24 @@
 		return nil, fmt.Errorf("shares are invalid")
 	}
 
-	privateKey, err := p.combineShares(ctx, att, config, key.EncryptedShares)
+	privateKey, err = p.combineShares(ctx, att, config, key.EncryptedShares)
 	if err != nil {
 		return nil, fmt.Errorf("combine shares: %w", err)
 	}
 
-	var decrypted []byte
-	switch decoded.Version {
-	case 1:
-		decrypted, err = aescbc.Decrypt(privateKey, decoded.EncryptedData)
-	case 2, 3:
-		decrypted, err = aesgcm.Decrypt(privateKey, decoded.EncryptedData, additionalData)
-	}
-	if err != nil {
-		return nil, fmt.Errorf("decrypt: %w", err)
+	if p.cache != nil {
+		p.cache.put(keyRef, privateKey)
 	}
 
+	// Trigger migration if needed. Migration is synchronous but non-fatal:
+	// failure is logged and does not affect the returned DEK.
 	if p.keyNeedsMigration(key) {
-		err := p.migrateKey(ctx, att, key, privateKey)
-		if err != nil {
-			// We don't want to fail the decryption if migration fails, log the error and continue
+		if err := p.migrateKey(ctx, att, key, privateKey); err != nil {
 			p.logger.ErrorContext(ctx, "migrating key failed", "error", err, "key_ref", key.KeyRef, "generation", key.Generation, "key_index", key.KeyIndex)
 		}
 	}
 
-	return decrypted, nil
+	return privateKey, nil
 }
 
 // RotateKey marks a key as inactive by setting its KeyIndex to a negative value. It won't be used for encrypting
@@ -232,6 +299,10 @@ func (p *Pool) RotateKey(ctx context.Context, att *enclave.Attestation, keyRef s
 		return fmt.Errorf("deactivate key: %w", err)
 	}
 
+	if p.cache != nil {
+		p.cache.delete(keyRef)
+	}
+
 	return nil
 }
 
diff --git a/encryption/pool_test.go b/encryption/pool_test.go
index 097488d..7be8a89 100644
--- a/encryption/pool_test.go
+++ b/encryption/pool_test.go
@@ -5,7 +5,9 @@ import (
 	"crypto/x509"
 	"encoding/pem"
 	"errors"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/0xsequence/nitrocontrol/enclave"
 	"github.com/0xsequence/nitrocontrol/encryption"
@@ -1028,3 +1030,538 @@ func TestPool_CleanupUnusedKeys(t *testing.T) {
 		require.Equal(t, 0, deleted)
 	})
 }
+
+func TestPool_DecryptCacheHit(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	cipherKey, privateKey := newCipherKey(t, enc)
+	shares, err := shamir.Split(privateKey, 2, 2)
+	require.NoError(t, err)
+
+	// Encrypt to get a ciphertext (this also populates the cache via Encrypt path).
+	keysTable.On("Get", mock.Anything, 0, 4).Return(cipherKey, true, nil)
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil)
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil)
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKey, true, nil)
+
+	pool := encryption.NewPool(enc, configs, keysTable, nil, nil,
+		encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute}))
+
+	keyRef, ciphertext, err := pool.Encrypt(context.Background(), att, []byte("test"), []byte("aad"))
+	require.NoError(t, err)
+
+	callsAfterEncrypt := len(remoteKey1.Calls)
+
+	// Decrypt — cache hit from Encrypt's put. No additional KMS calls.
+	plaintext, err := pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, "test", string(plaintext))
+	require.Equal(t, callsAfterEncrypt, len(remoteKey1.Calls), "Decrypt should hit cache populated by Encrypt")
+
+	// Second Decrypt — still a cache hit.
+	plaintext, err = pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, "test", string(plaintext))
+	require.Equal(t, callsAfterEncrypt, len(remoteKey1.Calls), "repeated Decrypt should hit cache")
+}
+
+// TestPool_DecryptPopulatesCache verifies the Decrypt path itself populates the
+// cache (via fetchDEK), independent of Encrypt.
+func TestPool_DecryptPopulatesCache(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	cipherKey, privateKey := newCipherKey(t, enc)
+	shares, err := shamir.Split(privateKey, 2, 2)
+	require.NoError(t, err)
+
+	keysTable.On("Get", mock.Anything, 0, 4).Return(cipherKey, true, nil)
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKey, true, nil)
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil)
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil)
+
+	// Encrypt WITHOUT cache to get a ciphertext.
+	poolNoCache := encryption.NewPool(enc, configs, keysTable, nil, nil)
+	keyRef, ciphertext, err := poolNoCache.Encrypt(context.Background(), att, []byte("test"), []byte("aad"))
+	require.NoError(t, err)
+
+	// Create a NEW pool with cache — cache is cold.
+	poolWithCache := encryption.NewPool(enc, configs, keysTable, nil, nil,
+		encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute}))
+
+	// First Decrypt — cache miss, calls KMS.
+	plaintext, err := poolWithCache.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, "test", string(plaintext))
+
+	callsAfterFirst := len(remoteKey1.Calls)
+	require.Greater(t, callsAfterFirst, 0, "first Decrypt should have called KMS")
+
+	// Second Decrypt — cache hit from fetchDEK's put. No additional KMS calls.
+	plaintext, err = poolWithCache.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, "test", string(plaintext))
+	require.Equal(t, callsAfterFirst, len(remoteKey1.Calls), "second Decrypt should hit cache populated by fetchDEK")
+}
+
+func TestPool_EncryptCacheHit(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	cipherKey, privateKey := newCipherKey(t, enc)
+	shares, err := shamir.Split(privateKey, 2, 2)
+	require.NoError(t, err)
+
+	keysTable.On("Get", mock.Anything, 0, 4).Return(cipherKey, true, nil)
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil)
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil)
+
+	pool := encryption.NewPool(enc, configs, keysTable, nil, nil,
+		encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute}))
+
+	// First encrypt — cache miss, calls KMS Decrypt to combine shares.
+	_, _, err = pool.Encrypt(context.Background(), att, []byte("test1"), []byte("aad"))
+	require.NoError(t, err)
+
+	decrypt1Calls := len(remoteKey1.Calls)
+
+	// Second encrypt — same key index (deterministic random), should be a cache hit.
+	_, _, err = pool.Encrypt(context.Background(), att, []byte("test2"), []byte("aad"))
+	require.NoError(t, err)
+
+	require.Equal(t, decrypt1Calls, len(remoteKey1.Calls), "expected no additional KMS calls on encrypt cache hit")
+}
+
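+// TestPool_RotateInvalidatesCache verifies RotateKey evicts the cached DEK, so
+// the next Decrypt goes back through DynamoDB and KMS.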
+func TestPool_RotateInvalidatesCache(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	cipherKey, privateKey := newCipherKey(t, enc)
+	shares, err := shamir.Split(privateKey, 2, 2)
+	require.NoError(t, err)
+
+	keysTable.On("Get", mock.Anything, 0, 4).Return(cipherKey, true, nil)
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKey, true, nil)
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", true).Return(cipherKey, true, nil)
+	keysTable.On("Deactivate", mock.Anything, "cipherKey4", 0, mock.AnythingOfType("time.Time"), mock.Anything).Return(nil)
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil)
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil)
+
+	pool := encryption.NewPool(enc, configs, keysTable, nil, nil,
+		encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute}))
+
+	// Encrypt to get ciphertext + populate cache.
+	keyRef, ciphertext, err := pool.Encrypt(context.Background(), att, []byte("test"), []byte("aad"))
+	require.NoError(t, err)
+
+	// Decrypt — should be cache hit.
+	plaintext, err := pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, "test", string(plaintext))
+
+	callsBefore := len(remoteKey1.Calls)
+
+	// Rotate — should invalidate cache.
+	err = pool.RotateKey(context.Background(), att, "cipherKey4")
+	require.NoError(t, err)
+
+	// Decrypt again — should be cache miss, hitting KMS again.
+	plaintext, err = pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, "test", string(plaintext))
+
+	require.Greater(t, len(remoteKey1.Calls), callsBefore, "expected additional KMS calls after cache invalidation")
+}
+
+func TestPool_NoCacheByDefault(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	cipherKey, privateKey := newCipherKey(t, enc)
+	shares, err := shamir.Split(privateKey, 2, 2)
+	require.NoError(t, err)
+
+	keysTable.On("Get", mock.Anything, 0, 4).Return(cipherKey, true, nil)
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKey, true, nil)
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil)
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil)
+
+	// No WithCache option.
+	pool := encryption.NewPool(enc, configs, keysTable, nil, nil)
+
+	keyRef, ciphertext, err := pool.Encrypt(context.Background(), att, []byte("test"), []byte("aad"))
+	require.NoError(t, err)
+
+	// First decrypt.
+	_, err = pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+	callsAfterFirst := len(remoteKey1.Calls)
+
+	// Second decrypt — no cache, so KMS is called again.
+	_, err = pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad"))
+	require.NoError(t, err)
+
+	require.Greater(t, len(remoteKey1.Calls), callsAfterFirst, "without cache, every decrypt should call KMS")
+}
+
+func TestPool_DecryptSingleflight(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	cipherKey, privateKey := newCipherKey(t, enc)
+	shares, err := shamir.Split(privateKey, 2, 2)
+	require.NoError(t, err)
+
+	// First, encrypt to get a ciphertext.
+ keysTable.On("Get", mock.Anything, 0, 4).Return(cipherKey, true, nil) + remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil) + remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil) + keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKey, true, nil) + + pool := encryption.NewPool(enc, configs, keysTable, nil, nil, + encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute})) + + _, ciphertext, err := pool.Encrypt(context.Background(), att, []byte("test"), []byte("aad")) + require.NoError(t, err) + + // Clear mock call history so we count only decrypt-path calls. + remoteKey1.Calls = nil + remoteKey2.Calls = nil + remoteKey1.ExpectedCalls = nil + remoteKey2.ExpectedCalls = nil + + // Re-register expectations. + remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(shares[0], nil) + remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(shares[1], nil) + + // Invalidate cache so all goroutines start with a cold cache. + keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", true).Return(cipherKey, true, nil) + keysTable.On("Deactivate", mock.Anything, "cipherKey4", 0, mock.AnythingOfType("time.Time"), mock.Anything).Return(nil) + err = pool.RotateKey(context.Background(), att, "cipherKey4") + require.NoError(t, err) + + // Launch N concurrent decrypts. + const N = 10 + var wg sync.WaitGroup + errs := make([]error, N) + results := make([]string, N) + for i := 0; i < N; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + pt, err := pool.Decrypt(context.Background(), att, "cipherKey4", ciphertext, []byte("aad")) + errs[idx] = err + if pt != nil { + results[idx] = string(pt) + } + }(i) + } + wg.Wait() + + for i := 0; i < N; i++ { + require.NoError(t, errs[i], "goroutine %d failed", i) + require.Equal(t, "test", results[i], "goroutine %d got wrong result", i) + } + + // With singleflight, exactly one goroutine fetches. RemoteKey1.Decrypt + // should be called once (one share per remote key, one fetch total). + decryptCalls := len(remoteKey1.Calls) + require.Equal(t, 1, decryptCalls, "singleflight should deduplicate concurrent fetches (got %d calls to remoteKey1)", decryptCalls) +} + +func TestPool_EncryptNewKeyPopulatesCache(t *testing.T) { + block, _ := pem.Decode([]byte(dummyPrivKey)) + privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + kmsClient := &MockKMS{} + remoteKey1 := &MockRemoteKey{} + remoteKey2 := &MockRemoteKey{} + keysTable := &MockKeysTable{} + + random := &constantReader{value: 0x42} + enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey) + require.NoError(t, err) + + configs := []*encryption.Config{ + { + PoolSize: 10, + Threshold: 2, + RemoteKeys: map[string]encryption.RemoteKey{ + "remoteKey1": remoteKey1, + "remoteKey2": remoteKey2, + }, + }, + } + + att, err := enc.GetAttestation(context.Background(), nil, nil) + require.NoError(t, err) + defer func() { _ = att.Close() }() + + // Key does not exist — GenerateKey will be called. 
+ keysTable.On("Get", mock.Anything, 0, 4).Return(nil, false, nil).Once() + remoteKey1.On("Encrypt", mock.Anything, att, mock.Anything).Return("encryptedShare1", nil) + remoteKey2.On("Encrypt", mock.Anything, att, mock.Anything).Return("encryptedShare2", nil) + keysTable.On("Create", mock.Anything, mock.Anything).Return(false, nil) + + pool := encryption.NewPool(enc, configs, keysTable, nil, nil, + encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute})) + + // First Encrypt — GenerateKey creates key and caches DEK. + keyRef, ciphertext, err := pool.Encrypt(context.Background(), att, []byte("test"), []byte("aad")) + require.NoError(t, err) + require.NotEmpty(t, keyRef) + + // Decrypt should hit cache — no KMS Decrypt calls needed. + cipherKey := &data.CipherKey{ + Generation: 0, + KeyIndex: intPtr(4), + KeyRef: keyRef, + EncryptedShares: map[string]string{ + "remoteKey1": "encryptedShare1", + "remoteKey2": "encryptedShare2", + }, + CreatedAt: time.Now(), + } + hash, err := cipherKey.Hash() + require.NoError(t, err) + cipherKeyAtt, err := enc.GetAttestation(context.Background(), nil, hash) + require.NoError(t, err) + cipherKey.Attestation = cipherKeyAtt.Document() + _ = cipherKeyAtt.Close() + + keysTable.On("GetLatestByKeyRef", mock.Anything, keyRef, false).Return(cipherKey, true, nil) + + // RemoteKey.Decrypt should NOT be called — DEK is cached from GenerateKey. + plaintext, err := pool.Decrypt(context.Background(), att, keyRef, ciphertext, []byte("aad")) + require.NoError(t, err) + require.Equal(t, "test", string(plaintext)) + + remoteKey1.AssertNotCalled(t, "Decrypt", mock.Anything, mock.Anything, mock.Anything) + remoteKey2.AssertNotCalled(t, "Decrypt", mock.Anything, mock.Anything, mock.Anything) +} + +func TestPool_MultipleKeyRefs(t *testing.T) { + block, _ := pem.Decode([]byte(dummyPrivKey)) + privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + kmsClient := &MockKMS{} + remoteKey1 := &MockRemoteKey{} + remoteKey2 := &MockRemoteKey{} + keysTable := &MockKeysTable{} + + random := &constantReader{value: 0x42} + enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey) + require.NoError(t, err) + + configs := []*encryption.Config{ + { + PoolSize: 10, + Threshold: 2, + RemoteKeys: map[string]encryption.RemoteKey{ + "remoteKey1": remoteKey1, + "remoteKey2": remoteKey2, + }, + }, + } + + att, err := enc.GetAttestation(context.Background(), nil, nil) + require.NoError(t, err) + defer func() { _ = att.Close() }() + + // Create two distinct cipher keys. + cipherKeyA, privateKeyA := newCipherKey(t, enc) + cipherKeyB, privateKeyB := newCipherKey(t, enc, func(key *data.CipherKey) { + key.KeyRef = "cipherKeyB" + }) + + sharesA, err := shamir.Split(privateKeyA, 2, 2) + require.NoError(t, err) + sharesB, err := shamir.Split(privateKeyB, 2, 2) + require.NoError(t, err) + + keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKeyA, true, nil) + keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKeyB", false).Return(cipherKeyB, true, nil) + remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(sharesA[0], nil).Once() + remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(sharesA[1], nil).Once() + + pool := encryption.NewPool(enc, configs, keysTable, nil, nil, + encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute})) + + // Decrypt keyA — cache miss, populates cache for keyA. 
+func TestPool_MultipleKeyRefs(t *testing.T) {
+	block, _ := pem.Decode([]byte(dummyPrivKey))
+	privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	require.NoError(t, err)
+
+	kmsClient := &MockKMS{}
+	remoteKey1 := &MockRemoteKey{}
+	remoteKey2 := &MockRemoteKey{}
+	keysTable := &MockKeysTable{}
+
+	random := &constantReader{value: 0x42}
+	enc, err := enclave.New(context.Background(), enclave.DummyProvider(random), kmsClient, privKey)
+	require.NoError(t, err)
+
+	configs := []*encryption.Config{
+		{
+			PoolSize:  10,
+			Threshold: 2,
+			RemoteKeys: map[string]encryption.RemoteKey{
+				"remoteKey1": remoteKey1,
+				"remoteKey2": remoteKey2,
+			},
+		},
+	}
+
+	att, err := enc.GetAttestation(context.Background(), nil, nil)
+	require.NoError(t, err)
+	defer func() { _ = att.Close() }()
+
+	// Create two cipher key records with distinct keyRefs.
+	cipherKeyA, privateKeyA := newCipherKey(t, enc)
+	cipherKeyB, privateKeyB := newCipherKey(t, enc, func(key *data.CipherKey) {
+		key.KeyRef = "cipherKeyB"
+	})
+
+	sharesA, err := shamir.Split(privateKeyA, 2, 2)
+	require.NoError(t, err)
+	sharesB, err := shamir.Split(privateKeyB, 2, 2)
+	require.NoError(t, err)
+
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKey4", false).Return(cipherKeyA, true, nil)
+	keysTable.On("GetLatestByKeyRef", mock.Anything, "cipherKeyB", false).Return(cipherKeyB, true, nil)
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(sharesA[0], nil).Once()
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(sharesA[1], nil).Once()
+
+	pool := encryption.NewPool(enc, configs, keysTable, nil, nil,
+		encryption.WithCache(encryption.CacheConfig{MaxSize: 10, TTL: time.Minute}))
+
+	// Decrypt keyA — cache miss, populates cache for keyA.
+	_, err = pool.Decrypt(context.Background(), att, "cipherKey4", legacyCiphertext55_v2, []byte("aad"))
+	require.NoError(t, err)
+
+	// Decrypt keyA again — cache hit, no KMS.
+	callsAfterA := len(remoteKey1.Calls)
+	_, err = pool.Decrypt(context.Background(), att, "cipherKey4", legacyCiphertext55_v2, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, callsAfterA, len(remoteKey1.Calls), "keyA should be a cache hit")
+
+	// Decrypt keyB — cache miss for keyB (keyA cached, keyB not).
+	remoteKey1.On("Decrypt", mock.Anything, att, "encryptedShare1").Return(sharesB[0], nil).Once()
+	remoteKey2.On("Decrypt", mock.Anything, att, "encryptedShare2").Return(sharesB[1], nil).Once()
+	_, err = pool.Decrypt(context.Background(), att, "cipherKeyB", legacyCiphertext55_v2, []byte("aad"))
+	require.NoError(t, err)
+	require.Greater(t, len(remoteKey1.Calls), callsAfterA, "keyB should be a cache miss")
+
+	// Decrypt keyB again — now a cache hit.
+	callsAfterB := len(remoteKey1.Calls)
+	_, err = pool.Decrypt(context.Background(), att, "cipherKeyB", legacyCiphertext55_v2, []byte("aad"))
+	require.NoError(t, err)
+	require.Equal(t, callsAfterB, len(remoteKey1.Calls), "keyB should now be a cache hit")
+}
+
+func intPtr(v int) *int { return &v }