From c9673d70f2a0c2131ba86f95ee5f648aefa0dd12 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Mota Date: Fri, 10 Apr 2026 07:58:49 +0200 Subject: [PATCH] =?UTF-8?q?Revert=20"Add=20LRU=20capacity=20to=20Validatin?= =?UTF-8?q?gCache,=20remove=20sentinel=20pattern,=20add=20sto=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 1a6287bb75bcc8991aa2526da9a10e81a35c243e. --- go.mod | 2 - pkg/cache/validating_cache.go | 173 --------- pkg/transport/session/session_data_storage.go | 15 +- .../session/session_data_storage_local.go | 97 +++-- .../session/session_data_storage_redis.go | 32 -- .../session/session_data_storage_test.go | 94 +---- pkg/vmcp/server/sessionmanager/cache.go | 162 ++++++++ .../server/sessionmanager/cache_test.go} | 355 ++++++++---------- pkg/vmcp/server/sessionmanager/factory.go | 12 - .../server/sessionmanager/session_manager.go | 279 ++++++++------ .../sessionmanager/session_manager_test.go | 117 +++--- 11 files changed, 570 insertions(+), 768 deletions(-) delete mode 100644 pkg/cache/validating_cache.go create mode 100644 pkg/vmcp/server/sessionmanager/cache.go rename pkg/{cache/validating_cache_test.go => vmcp/server/sessionmanager/cache_test.go} (50%) diff --git a/go.mod b/go.mod index 146c7236eb..3a4af974aa 100644 --- a/go.mod +++ b/go.mod @@ -79,8 +79,6 @@ require ( require github.com/getsentry/sentry-go/otel v0.44.1 -require github.com/hashicorp/golang-lru/v2 v2.0.7 - require ( cel.dev/expr v0.25.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect diff --git a/pkg/cache/validating_cache.go b/pkg/cache/validating_cache.go deleted file mode 100644 index eb6dbd1aa9..0000000000 --- a/pkg/cache/validating_cache.go +++ /dev/null @@ -1,173 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -// Package cache provides a generic, capacity-bounded cache with singleflight -// deduplication and per-hit liveness validation. -package cache - -import ( - "errors" - "fmt" - - lru "github.com/hashicorp/golang-lru/v2" - "golang.org/x/sync/singleflight" -) - -// ErrExpired is returned by the check function passed to New to signal that a -// cached entry has definitively expired and should be evicted. -var ErrExpired = errors.New("cache entry expired") - -// ValidatingCache is a node-local write-through cache backed by a -// capacity-bounded LRU map, with singleflight-deduplicated restore on cache -// miss and lazy liveness validation on cache hit. -// -// Type parameter K is the key type (must be comparable). -// Type parameter V is the cached value type. -// -// The no-resurrection invariant (preventing a concurrent restore from -// overwriting a deletion) is enforced via ContainsOrAdd: if a concurrent -// writer stored a value between load() returning and the cache being updated, -// the prior writer's value wins and the just-loaded value is discarded via -// onEvict. -type ValidatingCache[K comparable, V any] struct { - lruCache *lru.Cache[K, V] - flight singleflight.Group - load func(key K) (V, error) - check func(key K, val V) error - // onEvict is kept here so we can call it when discarding a concurrently - // loaded value that lost the race to a prior writer. - onEvict func(key K, val V) -} - -// New creates a ValidatingCache with the given capacity and callbacks. -// -// capacity is the maximum number of entries; it must be >= 1. When the cache -// is full and a new entry must be stored, the least-recently-used entry is -// evicted first. Values less than 1 panic. -// -// load is called on a cache miss to restore the value; it must not be nil. -// check is called on every cache hit to confirm liveness. It receives both the -// key and the cached value so callers can inspect the value without a separate -// read. Returning ErrExpired evicts the entry; any other error is transient -// (cached value returned unchanged). It must not be nil. -// onEvict is called after any eviction (LRU or expiry); it may be nil. -func New[K comparable, V any]( - capacity int, - load func(K) (V, error), - check func(K, V) error, - onEvict func(K, V), -) *ValidatingCache[K, V] { - if capacity < 1 { - panic(fmt.Sprintf("cache.New: capacity must be >= 1, got %d", capacity)) - } - if load == nil { - panic("cache.New: load must not be nil") - } - if check == nil { - panic("cache.New: check must not be nil") - } - - c, err := lru.NewWithEvict(capacity, onEvict) - if err != nil { - // Only possible if size < 0, which we have already ruled out above. - panic(fmt.Sprintf("cache.New: lru.NewWithEvict: %v", err)) - } - - return &ValidatingCache[K, V]{ - lruCache: c, - load: load, - check: check, - onEvict: onEvict, - } -} - -// getHit validates a known-present cache entry and returns its value. -// If the entry has definitively expired it is evicted and (zero, false) is -// returned. Transient check errors leave the entry in place and return the -// cached value. -func (c *ValidatingCache[K, V]) getHit(key K, val V) (V, bool) { - if err := c.check(key, val); err != nil { - if errors.Is(err, ErrExpired) { - // Remove fires the eviction callback automatically. - c.lruCache.Remove(key) - var zero V - return zero, false - } - } - return val, true -} - -// Get returns the value for key, loading it on a cache miss. On a cache hit -// the entry's liveness is validated via the check function provided to New: -// ErrExpired evicts the entry and returns (zero, false); transient errors -// return the cached value unchanged. On a cache miss, load is called under a -// singleflight group so at most one restore runs concurrently per key. -func (c *ValidatingCache[K, V]) Get(key K) (V, bool) { - if val, ok := c.lruCache.Get(key); ok { - return c.getHit(key, val) - } - - // Cache miss: use singleflight to prevent concurrent restores for the same key. - type result struct{ v V } - raw, err, _ := c.flight.Do(fmt.Sprint(key), func() (any, error) { - // Re-check the cache: a concurrent singleflight group may have stored - // the value between our miss check above and acquiring this group. - if existing, ok := c.lruCache.Get(key); ok { - return result{v: existing}, nil - } - - v, loadErr := c.load(key) - if loadErr != nil { - return nil, loadErr - } - - // Guard against a concurrent Set or Remove that occurred while load() was - // running. ContainsOrAdd stores only if absent; if another writer got - // in first, their value wins and we discard ours via onEvict. - ok, _ := c.lruCache.ContainsOrAdd(key, v) - if ok { - // Another writer stored a value first; discard our loaded value and - // return the winner's. ContainsOrAdd and Get are separate lock - // acquisitions, so the winner may itself have been evicted by LRU - // pressure between the two calls — fall back to our freshly loaded - // value in that case rather than returning a zero value. - winner, found := c.lruCache.Get(key) - if !found { - // Winner was evicted between ContainsOrAdd and Get; keep our - // freshly loaded value rather than returning a zero value. - return result{v: v}, nil - } - // Discard our loaded value in favour of the winner. - if c.onEvict != nil { - c.onEvict(key, v) - } - return result{v: winner}, nil - } - - return result{v: v}, nil - }) - if err != nil { - var zero V - return zero, false - } - r, ok := raw.(result) - return r.v, ok -} - -// Set stores value under key, moving the entry to the MRU position. If the -// cache is at capacity, the least-recently-used entry is evicted first and -// onEvict is called for it. -func (c *ValidatingCache[K, V]) Set(key K, value V) { - c.lruCache.Add(key, value) -} - -// Remove evicts the entry for key, calling onEvict if the key was present. -// It is a no-op if the key is not in the cache. -func (c *ValidatingCache[K, V]) Remove(key K) { - c.lruCache.Remove(key) -} - -// Len returns the number of entries currently in the cache. -func (c *ValidatingCache[K, V]) Len() int { - return c.lruCache.Len() -} diff --git a/pkg/transport/session/session_data_storage.go b/pkg/transport/session/session_data_storage.go index 9093fcdf5d..40588ea5be 100644 --- a/pkg/transport/session/session_data_storage.go +++ b/pkg/transport/session/session_data_storage.go @@ -25,9 +25,6 @@ import ( // - Create atomically creates metadata for id only if it does not already exist. // Use this in preference to Load+Upsert to avoid TOCTOU races. // - Upsert creates or overwrites the metadata for id, refreshing the TTL. -// - Update overwrites metadata only if the key already exists (SET XX semantics). -// Returns (true, nil) if updated, (false, nil) if the session was not found. -// Use this instead of Load+Upsert to avoid TOCTOU resurrection races. // - Load retrieves metadata and refreshes the TTL (sliding-window expiry). // Returns ErrSessionNotFound if the session does not exist. // - Delete removes the session. It is not an error if the session is absent. @@ -42,13 +39,6 @@ type DataStorage interface { // Upsert creates or updates session metadata with a sliding TTL. Upsert(ctx context.Context, id string, metadata map[string]string) error - // Update overwrites session metadata only if the session ID already exists - // (conditional write, equivalent to Redis SET XX). Returns (true, nil) if - // the entry was updated, (false, nil) if it was not found, or (false, err) - // on storage errors. Use this instead of Load+Upsert to prevent resurrections - // after a concurrent Delete. - Update(ctx context.Context, id string, metadata map[string]string) (bool, error) - // Load retrieves session metadata and refreshes its TTL. // Returns ErrSessionNotFound if the session does not exist. Load(ctx context.Context, id string) (map[string]string, error) @@ -75,9 +65,8 @@ func NewLocalSessionDataStorage(ttl time.Duration) (*LocalSessionDataStorage, er return nil, fmt.Errorf("ttl must be a positive duration") } s := &LocalSessionDataStorage{ - sessions: make(map[string]*localDataEntry), - ttl: ttl, - stopCh: make(chan struct{}), + ttl: ttl, + stopCh: make(chan struct{}), } go s.cleanupRoutine() return s, nil diff --git a/pkg/transport/session/session_data_storage_local.go b/pkg/transport/session/session_data_storage_local.go index bc125a0480..abb02c9f7d 100644 --- a/pkg/transport/session/session_data_storage_local.go +++ b/pkg/transport/session/session_data_storage_local.go @@ -30,13 +30,12 @@ func (e *localDataEntry) lastAccess() time.Time { } // LocalSessionDataStorage implements DataStorage using an in-memory -// map with TTL-based eviction. +// sync.Map with TTL-based eviction. // // Sessions are evicted if they have not been accessed within the configured TTL. // A background goroutine runs until Close is called. type LocalSessionDataStorage struct { - sessions map[string]*localDataEntry // guarded by mu - mu sync.Mutex + sessions sync.Map // map[string]*localDataEntry ttl time.Duration stopCh chan struct{} stopOnce sync.Once @@ -50,9 +49,9 @@ func (s *LocalSessionDataStorage) Upsert(_ context.Context, id string, metadata if metadata == nil { metadata = make(map[string]string) } - s.mu.Lock() - s.sessions[id] = newLocalDataEntry(maps.Clone(metadata)) - s.mu.Unlock() + // Store a defensive copy so callers cannot mutate stored data. + copied := maps.Clone(metadata) + s.sessions.Store(id, newLocalDataEntry(copied)) return nil } @@ -62,20 +61,26 @@ func (s *LocalSessionDataStorage) Load(_ context.Context, id string) (map[string if id == "" { return nil, fmt.Errorf("cannot load session data with empty ID") } - s.mu.Lock() - entry, ok := s.sessions[id] - if ok { - entry.lastAccessNano.Store(time.Now().UnixNano()) - } - s.mu.Unlock() + + val, ok := s.sessions.Load(id) if !ok { return nil, ErrSessionNotFound } + entry, ok := val.(*localDataEntry) + if !ok { + return nil, fmt.Errorf("invalid entry type in local session data storage") + } + + // Refresh last-access in place. deleteExpired re-checks the timestamp + // immediately before calling CompareAndDelete, so this atomic store is + // sufficient to prevent eviction of an actively accessed entry. + entry.lastAccessNano.Store(time.Now().UnixNano()) + return maps.Clone(entry.metadata), nil } -// Create creates session metadata only if the session ID does not already exist. -// Returns (true, nil) if created, (false, nil) if the key already existed. +// Create atomically creates session metadata only if the session ID +// does not already exist. Uses sync.Map.LoadOrStore for atomicity. func (s *LocalSessionDataStorage) Create(_ context.Context, id string, metadata map[string]string) (bool, error) { if id == "" { return false, fmt.Errorf("cannot write session data with empty ID") @@ -83,31 +88,9 @@ func (s *LocalSessionDataStorage) Create(_ context.Context, id string, metadata if metadata == nil { metadata = make(map[string]string) } - s.mu.Lock() - defer s.mu.Unlock() - if _, exists := s.sessions[id]; exists { - return false, nil - } - s.sessions[id] = newLocalDataEntry(maps.Clone(metadata)) - return true, nil -} - -// Update overwrites session metadata only if the session ID already exists. -// Returns (true, nil) if updated, (false, nil) if not found. -func (s *LocalSessionDataStorage) Update(_ context.Context, id string, metadata map[string]string) (bool, error) { - if id == "" { - return false, fmt.Errorf("cannot write session data with empty ID") - } - if metadata == nil { - metadata = make(map[string]string) - } - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.sessions[id]; !ok { - return false, nil - } - s.sessions[id] = newLocalDataEntry(maps.Clone(metadata)) - return true, nil + copied := maps.Clone(metadata) + _, loaded := s.sessions.LoadOrStore(id, newLocalDataEntry(copied)) + return !loaded, nil } // Delete removes session metadata. Not an error if absent. @@ -115,18 +98,17 @@ func (s *LocalSessionDataStorage) Delete(_ context.Context, id string) error { if id == "" { return fmt.Errorf("cannot delete session data with empty ID") } - s.mu.Lock() - delete(s.sessions, id) - s.mu.Unlock() + s.sessions.Delete(id) return nil } // Close stops the background cleanup goroutine and clears all stored metadata. func (s *LocalSessionDataStorage) Close() error { s.stopOnce.Do(func() { close(s.stopCh) }) - s.mu.Lock() - s.sessions = make(map[string]*localDataEntry) - s.mu.Unlock() + s.sessions.Range(func(key, _ any) bool { + s.sessions.Delete(key) + return true + }) return nil } @@ -158,11 +140,26 @@ func (s *LocalSessionDataStorage) cleanupRoutine() { func (s *LocalSessionDataStorage) deleteExpired() { cutoff := time.Now().Add(-s.ttl) - s.mu.Lock() - defer s.mu.Unlock() - for id, entry := range s.sessions { - if entry.lastAccess().Before(cutoff) { - delete(s.sessions, id) + var toDelete []struct { + id string + entry *localDataEntry + } + s.sessions.Range(func(key, val any) bool { + entry, ok := val.(*localDataEntry) + if ok && entry.lastAccess().Before(cutoff) { + id, ok := key.(string) + if ok { + toDelete = append(toDelete, struct { + id string + entry *localDataEntry + }{id, entry}) + } + } + return true + }) + for _, item := range toDelete { + if item.entry.lastAccess().Before(cutoff) { + s.sessions.CompareAndDelete(item.id, item.entry) } } } diff --git a/pkg/transport/session/session_data_storage_redis.go b/pkg/transport/session/session_data_storage_redis.go index 02d04c2027..916a82050e 100644 --- a/pkg/transport/session/session_data_storage_redis.go +++ b/pkg/transport/session/session_data_storage_redis.go @@ -87,38 +87,6 @@ func (s *RedisSessionDataStorage) Load(ctx context.Context, id string) (map[stri return metadata, nil } -// Update overwrites session metadata only if the key already exists. -// Uses Redis SET XX (set-if-exists) to prevent resurrecting a session that -// was deleted by a concurrent Delete call (e.g. from another pod). -// Returns (true, nil) if updated, (false, nil) if the key was not found. -func (s *RedisSessionDataStorage) Update(ctx context.Context, id string, metadata map[string]string) (bool, error) { - if id == "" { - return false, fmt.Errorf("cannot write session data with empty ID") - } - if metadata == nil { - metadata = make(map[string]string) - } - data, err := json.Marshal(metadata) - if err != nil { - return false, fmt.Errorf("failed to serialize session metadata: %w", err) - } - // Mode "XX" means "only set if the key already exists". - res, err := s.client.SetArgs(ctx, s.key(id), data, redis.SetArgs{ - Mode: "XX", - TTL: s.ttl, - }).Result() - if err != nil { - // go-redis surfaces the "key does not exist" nil bulk reply as redis.Nil. - if errors.Is(err, redis.Nil) { - return false, nil - } - return false, fmt.Errorf("failed to conditionally update session metadata: %w", err) - } - // SetArgs with Mode "XX" returns "" when the key does not exist and "OK" - // when the write succeeded. - return res == "OK", nil -} - // Create atomically creates session metadata only if the key does not // already exist. Uses Redis SET NX (set-if-not-exists) to eliminate the // TOCTOU race between Load and Upsert in multi-pod deployments. diff --git a/pkg/transport/session/session_data_storage_test.go b/pkg/transport/session/session_data_storage_test.go index 63b41f9107..9b1be8b346 100644 --- a/pkg/transport/session/session_data_storage_test.go +++ b/pkg/transport/session/session_data_storage_test.go @@ -191,74 +191,6 @@ func runDataStorageTests(t *testing.T, newStorage func(t *testing.T) DataStorage err := s.Delete(ctx, "") assert.Error(t, err) }) - - t.Run("Update overwrites existing entry and returns true", func(t *testing.T) { - t.Parallel() - s := newStorage(t) - ctx := context.Background() - - require.NoError(t, s.Upsert(ctx, "sess-update", map[string]string{"v": "original"})) - - updated, err := s.Update(ctx, "sess-update", map[string]string{"v": "updated"}) - require.NoError(t, err) - assert.True(t, updated, "should return true when key exists") - - loaded, err := s.Load(ctx, "sess-update") - require.NoError(t, err) - assert.Equal(t, "updated", loaded["v"]) - }) - - t.Run("Update on missing key returns (false, nil) without creating it", func(t *testing.T) { - t.Parallel() - s := newStorage(t) - ctx := context.Background() - - updated, err := s.Update(ctx, "sess-absent", map[string]string{"v": "new"}) - require.NoError(t, err) - assert.False(t, updated, "should return false when key does not exist") - - // The key must not have been created. - _, err = s.Load(ctx, "sess-absent") - assert.ErrorIs(t, err, ErrSessionNotFound, "Update must not create a missing key") - }) - - t.Run("Update after Delete returns (false, nil)", func(t *testing.T) { - t.Parallel() - s := newStorage(t) - ctx := context.Background() - - require.NoError(t, s.Upsert(ctx, "sess-deleted", map[string]string{"v": "1"})) - require.NoError(t, s.Delete(ctx, "sess-deleted")) - - updated, err := s.Update(ctx, "sess-deleted", map[string]string{"v": "2"}) - require.NoError(t, err) - assert.False(t, updated, "should return false after key was deleted") - }) - - t.Run("Update with empty ID returns error", func(t *testing.T) { - t.Parallel() - s := newStorage(t) - ctx := context.Background() - - _, err := s.Update(ctx, "", map[string]string{}) - assert.Error(t, err) - }) - - t.Run("Update nil metadata is treated as empty map", func(t *testing.T) { - t.Parallel() - s := newStorage(t) - ctx := context.Background() - - require.NoError(t, s.Upsert(ctx, "sess-update-nil", map[string]string{"v": "original"})) - - updated, err := s.Update(ctx, "sess-update-nil", nil) - require.NoError(t, err) - assert.True(t, updated) - - loaded, err := s.Load(ctx, "sess-update-nil") - require.NoError(t, err) - assert.NotNil(t, loaded) - }) } // --------------------------------------------------------------------------- @@ -322,11 +254,9 @@ func TestLocalSessionDataStorage(t *testing.T) { // simulating an entry that has been idle for that duration. func backdateLocalEntry(t *testing.T, s *LocalSessionDataStorage, id string, age time.Duration) { t.Helper() - s.mu.Lock() - entry, ok := s.sessions[id] - s.mu.Unlock() + val, ok := s.sessions.Load(id) require.True(t, ok, "entry %q not found for backdating", id) - entry.lastAccessNano.Store(time.Now().Add(-age).UnixNano()) + val.(*localDataEntry).lastAccessNano.Store(time.Now().Add(-age).UnixNano()) } // --------------------------------------------------------------------------- @@ -400,24 +330,4 @@ func TestRedisSessionDataStorage(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, val) }) - - t.Run("Update refreshes TTL via SET XX", func(t *testing.T) { - t.Parallel() - s, mr := newTestRedisDataStorage(t) - ctx := context.Background() - - require.NoError(t, s.Upsert(ctx, "ttl-update", map[string]string{"v": "1"})) - mr.FastForward(29 * time.Minute) - - updated, err := s.Update(ctx, "ttl-update", map[string]string{"v": "2"}) - require.NoError(t, err) - assert.True(t, updated) - - // Advance past the original TTL; Update should have reset the clock. - mr.FastForward(2 * time.Minute) - - loaded, err := s.Load(ctx, "ttl-update") - require.NoError(t, err, "session should still be alive after TTL reset by Update") - assert.Equal(t, "2", loaded["v"]) - }) } diff --git a/pkg/vmcp/server/sessionmanager/cache.go b/pkg/vmcp/server/sessionmanager/cache.go new file mode 100644 index 0000000000..52ee95a4d5 --- /dev/null +++ b/pkg/vmcp/server/sessionmanager/cache.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package sessionmanager + +import ( + "errors" + "fmt" + "sync" + + "golang.org/x/sync/singleflight" +) + +// ErrExpired is returned by the check function passed to newRestorableCache to +// signal that a cached entry has definitively expired and should be evicted. +var ErrExpired = errors.New("cache entry expired") + +// errSentinelFound is returned inside the singleflight load function when a +// non-V value (e.g. terminatedSentinel) is present in the map. Returning an +// error aborts the load and causes Get to return (zero, false), consistent +// with the behaviour of the initial-hit path that also returns (zero, false) +// for non-V values. +var errSentinelFound = errors.New("sentinel stored in cache") + +// RestorableCache is a node-local write-through cache backed by a sync.Map, +// with singleflight-deduplicated restore on cache miss and lazy liveness +// validation on cache hit. +// +// Type parameter K is the key type (must be comparable). +// Type parameter V is the cached value type. +// +// Values are stored internally as any, which allows callers to place sentinel +// markers alongside V entries (e.g. a tombstone during teardown). Get performs +// a type assertion to V and treats non-V entries as "not found". Peek and +// Store expose raw any access for sentinel use. +type RestorableCache[K comparable, V any] struct { + m sync.Map + flight singleflight.Group + + // load is called on a cache miss. Return (value, nil) on success. + // A successful result is stored in the cache before being returned. + load func(key K) (V, error) + + // check is called on every cache hit to confirm liveness. Returning nil + // means the entry is alive. Returning ErrExpired means it has definitively + // expired (the entry is evicted). Any other error is treated as a transient + // failure and the cached value is returned unchanged. + check func(key K) error + + // onEvict is called after a confirmed-expired entry has been removed. The + // evicted value is passed to allow resource cleanup (e.g. closing + // connections). May be nil. + onEvict func(key K, v V) +} + +// TODO: add an age-based sweep to bound the lifetime of entries that are +// never accessed again after their storage TTL expires. The sweep would range +// over m, compare each entry's insertion time against a caller-supplied maxAge, +// and call onEvict for entries that are too old — all without touching storage. +// Until then, entries for idle sessions leak backend connections until the +// process restarts or the session ID is queried again. + +func newRestorableCache[K comparable, V any]( + load func(K) (V, error), + check func(K) error, + onEvict func(K, V), +) *RestorableCache[K, V] { + return &RestorableCache[K, V]{ + load: load, + check: check, + onEvict: onEvict, + } +} + +// Get returns the cached V value for key. +// +// On a cache hit, check is run first: ErrExpired evicts the entry and returns +// (zero, false); transient errors return the cached value unchanged. Non-V +// values stored via Store (e.g. sentinels) return (zero, false) without +// triggering a restore. +// +// On a cache miss, load is called under a singleflight group so at most one +// restore runs concurrently per key. +func (c *RestorableCache[K, V]) Get(key K) (V, bool) { + if raw, ok := c.m.Load(key); ok { + v, isV := raw.(V) + if !isV { + var zero V + return zero, false + } + if err := c.check(key); err != nil { + if errors.Is(err, ErrExpired) { + c.m.Delete(key) + if c.onEvict != nil { + c.onEvict(key, v) + } + var zero V + return zero, false + } + // Transient error — keep the cached value. + } + return v, true + } + + // Cache miss: use singleflight to prevent concurrent restores for the same key. + type result struct{ v V } + raw, err, _ := c.flight.Do(fmt.Sprint(key), func() (any, error) { + // Re-check the cache: a concurrent singleflight group may have stored + // the value between our miss check above and acquiring this group. + if stored, ok := c.m.Load(key); ok { + if v, isV := stored.(V); isV { + return result{v: v}, nil + } + // Non-V sentinel present (e.g. terminatedSentinel). Treat as a + // hard stop: do not call load() and do not overwrite the sentinel. + return nil, errSentinelFound + } + v, loadErr := c.load(key) + if loadErr != nil { + return nil, loadErr + } + // Guard against a sentinel being stored between load() completing and + // this Store call (Terminate() running concurrently). LoadOrStore is + // atomic: if a sentinel got in, we discard the freshly loaded value + // via onEvict rather than silently overwriting the sentinel. + if _, loaded := c.m.LoadOrStore(key, v); loaded { + if c.onEvict != nil { + c.onEvict(key, v) + } + return nil, errSentinelFound + } + return result{v: v}, nil + }) + if err != nil { + var zero V + return zero, false + } + r, ok := raw.(result) + return r.v, ok +} + +// Store sets key to value. value may be any type, including sentinel markers. +func (c *RestorableCache[K, V]) Store(key K, value any) { + c.m.Store(key, value) +} + +// Delete removes key from the cache. +func (c *RestorableCache[K, V]) Delete(key K) { + c.m.Delete(key) +} + +// Peek returns the raw value stored under key without type assertion, liveness +// check, or restore. Used for sentinel inspection. +func (c *RestorableCache[K, V]) Peek(key K) (any, bool) { + return c.m.Load(key) +} + +// CompareAndSwap atomically replaces the value stored under key from old to +// new. Both old and new may be any type, including sentinels. +func (c *RestorableCache[K, V]) CompareAndSwap(key K, old, replacement any) bool { + return c.m.CompareAndSwap(key, old, replacement) +} diff --git a/pkg/cache/validating_cache_test.go b/pkg/vmcp/server/sessionmanager/cache_test.go similarity index 50% rename from pkg/cache/validating_cache_test.go rename to pkg/vmcp/server/sessionmanager/cache_test.go index 209d8769d3..1123499b96 100644 --- a/pkg/cache/validating_cache_test.go +++ b/pkg/vmcp/server/sessionmanager/cache_test.go @@ -1,11 +1,10 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -package cache +package sessionmanager import ( "errors" - "fmt" "sync" "sync/atomic" "testing" @@ -14,41 +13,27 @@ import ( "github.com/stretchr/testify/require" ) -// newStringCache builds a ValidatingCache[string, string] for tests. +// sentinel type used to test that non-V values stored via Store are +// invisible to Get without triggering a restore. +type testSentinel struct{} + +// newStringCache builds a RestorableCache[string, string] for tests. func newStringCache( load func(string) (string, error), - check func(string, string) error, + check func(string) error, evict func(string, string), -) *ValidatingCache[string, string] { - return New(1000, load, check, evict) +) *RestorableCache[string, string] { + return newRestorableCache(load, check, evict) } // alwaysAliveCheck returns a check function that always reports the entry as alive. -func alwaysAliveCheck(_ string, _ string) error { return nil } - -// --------------------------------------------------------------------------- -// Construction invariants -// --------------------------------------------------------------------------- - -func TestValidatingCache_New_PanicsOnZeroCapacity(t *testing.T) { - t.Parallel() - assert.Panics(t, func() { - New(0, func(_ string) (string, error) { return "", nil }, alwaysAliveCheck, nil) - }) -} - -func TestValidatingCache_New_PanicsOnNegativeCapacity(t *testing.T) { - t.Parallel() - assert.Panics(t, func() { - New(-1, func(_ string) (string, error) { return "", nil }, alwaysAliveCheck, nil) - }) -} +func alwaysAliveCheck(_ string) error { return nil } // --------------------------------------------------------------------------- // Cache miss / restore // --------------------------------------------------------------------------- -func TestValidatingCache_CacheMiss_CallsLoad(t *testing.T) { +func TestRestorableCache_CacheMiss_CallsLoad(t *testing.T) { t.Parallel() loaded := false @@ -67,7 +52,7 @@ func TestValidatingCache_CacheMiss_CallsLoad(t *testing.T) { assert.True(t, loaded) } -func TestValidatingCache_CacheMiss_StoresResult(t *testing.T) { +func TestRestorableCache_CacheMiss_StoresResult(t *testing.T) { t.Parallel() calls := 0 @@ -85,7 +70,7 @@ func TestValidatingCache_CacheMiss_StoresResult(t *testing.T) { assert.Equal(t, 1, calls, "load should be called only once after caching") } -func TestValidatingCache_CacheMiss_LoadError_ReturnsNotFound(t *testing.T) { +func TestRestorableCache_CacheMiss_LoadError_ReturnsNotFound(t *testing.T) { t.Parallel() loadErr := errors.New("not found") @@ -104,7 +89,7 @@ func TestValidatingCache_CacheMiss_LoadError_ReturnsNotFound(t *testing.T) { // Cache hit / liveness // --------------------------------------------------------------------------- -func TestValidatingCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { +func TestRestorableCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { t.Parallel() c := newStringCache( @@ -120,14 +105,14 @@ func TestValidatingCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { assert.Equal(t, "loaded-k", v) } -func TestValidatingCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { +func TestRestorableCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { t.Parallel() evictedKey := "" evictedVal := "" c := newStringCache( func(_ string) (string, error) { return "v", nil }, - func(_ string, _ string) error { return ErrExpired }, + func(_ string) error { return ErrExpired }, func(key, val string) { evictedKey = key evictedVal = val @@ -142,7 +127,7 @@ func TestValidatingCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { assert.Equal(t, "v", evictedVal) } -func TestValidatingCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { +func TestRestorableCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { t.Parallel() calls := 0 @@ -152,7 +137,7 @@ func TestValidatingCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { calls++ return "v", nil }, - func(_ string, _ string) error { + func(_ string) error { if expired { return ErrExpired } @@ -170,12 +155,12 @@ func TestValidatingCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { assert.Equal(t, 2, calls, "load should be called twice: initial + after eviction") } -func TestValidatingCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T) { +func TestRestorableCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T) { t.Parallel() c := newStringCache( func(_ string) (string, error) { return "v", nil }, - func(_ string, _ string) error { return errors.New("transient storage error") }, + func(_ string) error { return errors.New("transient storage error") }, nil, ) c.Get("k") //nolint:errcheck // prime the cache @@ -186,54 +171,66 @@ func TestValidatingCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T } // --------------------------------------------------------------------------- -// Set +// Sentinel / raw access // --------------------------------------------------------------------------- -func TestValidatingCache_Set_StoresValue(t *testing.T) { +func TestRestorableCache_Sentinel_GetReturnsNotFound(t *testing.T) { t.Parallel() - c := newStringCache( - func(_ string) (string, error) { return "", errors.New("should not call load") }, + loadCalled := false + c := newRestorableCache( + func(_ string) (string, error) { + loadCalled = true + return "", errors.New("should not be called") + }, alwaysAliveCheck, nil, ) - c.Set("k", "v") + c.Store("k", testSentinel{}) v, ok := c.Get("k") - require.True(t, ok) - assert.Equal(t, "v", v) + assert.False(t, ok, "sentinel should not satisfy type assertion to V") + assert.Empty(t, v) + assert.False(t, loadCalled, "load should not be called when a sentinel is present") } -func TestValidatingCache_Set_UpdatesExisting(t *testing.T) { +func TestRestorableCache_Peek_ReturnsSentinel(t *testing.T) { t.Parallel() - c := newStringCache( - func(_ string) (string, error) { return "loaded", nil }, + c := newRestorableCache( + func(string) (string, error) { return "", nil }, alwaysAliveCheck, nil, ) - c.Get("k") //nolint:errcheck // prime with "loaded" - c.Set("k", "updated") - v, ok := c.Get("k") + c.Store("k", testSentinel{}) + + raw, ok := c.Peek("k") require.True(t, ok) - assert.Equal(t, "updated", v) + _, isSentinel := raw.(testSentinel) + assert.True(t, isSentinel) } -// --------------------------------------------------------------------------- -// LRU capacity -// --------------------------------------------------------------------------- - -func TestValidatingCache_LRU_EvictsLeastRecentlyUsed(t *testing.T) { +// TestRestorableCache_Sentinel_StoredDuringLoad verifies that a sentinel stored +// concurrently during load() is respected: load() should not overwrite the +// sentinel, and the loaded value should be discarded via onEvict. +func TestRestorableCache_Sentinel_StoredDuringLoad(t *testing.T) { t.Parallel() var evictedKeys []string var mu sync.Mutex - // capacity=2: inserting a third entry evicts the LRU. - c := New(2, - func(key string) (string, error) { return "val-" + key, nil }, + sentinelReady := make(chan struct{}) + loadStarted := make(chan struct{}) + + c := newRestorableCache( + func(_ string) (string, error) { + // Signal that load has started, then wait for the sentinel to be stored. + close(loadStarted) + <-sentinelReady + return "loaded-value", nil + }, alwaysAliveCheck, func(key, _ string) { mu.Lock() @@ -242,152 +239,158 @@ func TestValidatingCache_LRU_EvictsLeastRecentlyUsed(t *testing.T) { }, ) - c.Get("a") //nolint:errcheck // a=MRU - c.Get("b") //nolint:errcheck // b=MRU, a=LRU - c.Get("c") //nolint:errcheck // c=MRU, b, a=LRU → evicts a + done := make(chan struct{}) + go func() { + defer close(done) + v, ok := c.Get("k") + // The sentinel should have blocked the store; Get returns not-found. + assert.False(t, ok) + assert.Empty(t, v) + }() + + // Wait until load() has started, then inject a sentinel before it stores. + <-loadStarted + c.Store("k", testSentinel{}) + close(sentinelReady) + <-done + + // The sentinel must still be in the cache (not overwritten by the loaded value). + raw, ok := c.Peek("k") + require.True(t, ok) + _, isSentinel := raw.(testSentinel) + assert.True(t, isSentinel, "sentinel must not be overwritten by the restore") + // onEvict must have been called for the discarded loaded value. mu.Lock() defer mu.Unlock() - assert.Equal(t, []string{"a"}, evictedKeys, "LRU entry (a) should be evicted") - - // a is evicted; b and c remain. - _, bPresent := c.Get("b") - assert.True(t, bPresent) - _, cPresent := c.Get("c") - assert.True(t, cPresent) + assert.Equal(t, []string{"k"}, evictedKeys, "loaded value must be evicted when sentinel is present") } -func TestValidatingCache_LRU_GetRefreshesMRUPosition(t *testing.T) { +// TestRestorableCache_Sentinel_BlocksRestoreViaInitialHit verifies that a +// sentinel already present in the cache when Get is called causes load() to be +// skipped and Get to return not-found. This exercises the initial-hit branch +// (the outer c.m.Load check), which short-circuits before entering the +// singleflight group. +// +// The singleflight re-check branch (c.m.Load inside flight.Do) has structurally +// identical logic: if the stored value is not a V, errSentinelFound is returned +// and load is not called. That branch cannot be targeted deterministically from +// outside without code instrumentation, because the re-check runs in the same +// goroutine as the initial miss with no synchronisation point between them. +// The sentinel-stored-during-load path (TestRestorableCache_Sentinel_StoredDuringLoad) +// and the LoadOrStore guard cover the concurrent-store window that follows. +func TestRestorableCache_Sentinel_BlocksRestoreViaInitialHit(t *testing.T) { t.Parallel() - var evictedKeys []string - var mu sync.Mutex - - c := New(2, - func(key string) (string, error) { return "val-" + key, nil }, - alwaysAliveCheck, - func(key, _ string) { - mu.Lock() - evictedKeys = append(evictedKeys, key) - mu.Unlock() + loadCalled := false + c := newRestorableCache( + func(_ string) (string, error) { + loadCalled = true + return "loaded", nil }, + alwaysAliveCheck, + nil, ) - c.Get("a") //nolint:errcheck // a loaded (MRU) - c.Get("b") //nolint:errcheck // b loaded (MRU), a=LRU - c.Get("a") //nolint:errcheck // a accessed → a becomes MRU, b=LRU - c.Get("c") //nolint:errcheck // c loaded → evicts b (LRU), not a - - mu.Lock() - defer mu.Unlock() - assert.Equal(t, []string{"b"}, evictedKeys, "b should be evicted (LRU after a was re-accessed)") + // Sentinel is present before Get is called: the initial c.m.Load hit path + // returns (zero, false) without entering the singleflight group. + c.Store("k", testSentinel{}) - _, aPresent := c.Get("a") - assert.True(t, aPresent, "a should still be in cache") + v, ok := c.Get("k") + assert.False(t, ok, "Get must return not-found when sentinel is present") + assert.Empty(t, v) + assert.False(t, loadCalled, "load must not be called when a sentinel is in the cache") } -func TestValidatingCache_LRU_SetRefreshesMRUPosition(t *testing.T) { +func TestRestorableCache_Peek_MissingKey_ReturnsFalse(t *testing.T) { t.Parallel() - var evictedKeys []string - var mu sync.Mutex - - c := New(2, - func(key string) (string, error) { return "val-" + key, nil }, + c := newStringCache( + func(string) (string, error) { return "", nil }, alwaysAliveCheck, - func(key, _ string) { - mu.Lock() - evictedKeys = append(evictedKeys, key) - mu.Unlock() - }, + nil, ) - c.Get("a") //nolint:errcheck // a=MRU - c.Get("b") //nolint:errcheck // b=MRU, a=LRU - c.Set("a", "x") // Set refreshes a to MRU; b becomes LRU - c.Get("c") //nolint:errcheck // c loaded → evicts b - - mu.Lock() - defer mu.Unlock() - assert.Equal(t, []string{"b"}, evictedKeys) + _, ok := c.Peek("absent") + assert.False(t, ok) } -func TestValidatingCache_LRU_CapacityOne(t *testing.T) { - t.Parallel() +// --------------------------------------------------------------------------- +// CompareAndSwap +// --------------------------------------------------------------------------- - var evictedKeys []string - var mu sync.Mutex +func TestRestorableCache_CompareAndSwap_Success(t *testing.T) { + t.Parallel() - c := New(1, - func(key string) (string, error) { return "val-" + key, nil }, + c := newStringCache( + func(_ string) (string, error) { return "v1", nil }, alwaysAliveCheck, - func(key, _ string) { - mu.Lock() - evictedKeys = append(evictedKeys, key) - mu.Unlock() - }, + nil, ) + c.Get("k") //nolint:errcheck // prime with "v1" - c.Get("a") //nolint:errcheck - c.Get("b") //nolint:errcheck // evicts a - c.Get("c") //nolint:errcheck // evicts b + swapped := c.CompareAndSwap("k", "v1", "v2") + require.True(t, swapped) - mu.Lock() - defer mu.Unlock() - assert.Equal(t, []string{"a", "b"}, evictedKeys) + raw, ok := c.Peek("k") + require.True(t, ok) + assert.Equal(t, "v2", raw) } -func TestValidatingCache_LRU_LargeCapacityNoEviction(t *testing.T) { +func TestRestorableCache_CompareAndSwap_WrongOld_Fails(t *testing.T) { t.Parallel() - const n = 100 - c := New(n+1, - func(key string) (string, error) { return "val-" + key, nil }, + c := newStringCache( + func(_ string) (string, error) { return "v1", nil }, alwaysAliveCheck, - func(key, _ string) { - t.Errorf("unexpected eviction for key %s", key) - }, + nil, ) + c.Get("k") //nolint:errcheck - for i := range n { - c.Get(fmt.Sprintf("k%d", i)) //nolint:errcheck - } - assert.Equal(t, n, c.Len(), "no entries should be evicted when under capacity") + swapped := c.CompareAndSwap("k", "wrong", "v2") + assert.False(t, swapped) } -func TestValidatingCache_LRU_Len(t *testing.T) { +// --------------------------------------------------------------------------- +// Delete +// --------------------------------------------------------------------------- + +func TestRestorableCache_Delete_RemovesEntry(t *testing.T) { t.Parallel() - c := New(5, + c := newStringCache( func(_ string) (string, error) { return "v", nil }, alwaysAliveCheck, nil, ) + c.Get("k") //nolint:errcheck + + c.Delete("k") - assert.Equal(t, 0, c.Len()) - c.Get("a") //nolint:errcheck - assert.Equal(t, 1, c.Len()) - c.Get("b") //nolint:errcheck - assert.Equal(t, 2, c.Len()) + _, ok := c.Peek("k") + assert.False(t, ok) } // --------------------------------------------------------------------------- // Re-check inside singleflight (TOCTOU prevention) // --------------------------------------------------------------------------- -func TestValidatingCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) { +func TestRestorableCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) { t.Parallel() + // Simulate the TOCTOU window: a goroutine sees a cache miss, then the + // value is stored externally before it enters the singleflight group. + // The re-check inside the group should find the value and skip load. var loadCount atomic.Int32 // The load function is gated: it waits until we signal that an external - // Set has been applied, mimicking a value written by another goroutine + // Store has been applied, mimicking a value written by another goroutine // between the miss check and the singleflight group. storeApplied := make(chan struct{}) c := newStringCache( func(_ string) (string, error) { - <-storeApplied // wait until external Set is applied + <-storeApplied // wait until external Store is applied loadCount.Add(1) return "from-load", nil }, @@ -400,14 +403,16 @@ func TestValidatingCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) result string ok bool ) - wg.Go(func() { + wg.Add(1) + go func() { + defer wg.Done() result, ok = c.Get("k") - }) + }() - // Set the value externally to simulate a concurrent writer, then release + // Store the value externally to simulate a concurrent writer, then release // the load function. The re-check at the top of the singleflight function // fires first and finds "external-value", so load is never called. - c.Set("k", "external-value") + c.Store("k", "external-value") close(storeApplied) wg.Wait() @@ -416,61 +421,11 @@ func TestValidatingCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) assert.Equal(t, int32(0), loadCount.Load(), "re-check should short-circuit before load is called") } -// TestValidatingCache_Singleflight_EvictsLoserWhenLoadRacesWriter covers the -// path where load() runs to completion but loses the ContainsOrAdd race to a -// concurrent Set. The loaded-but-discarded value must be passed to onEvict so -// any resources it holds (e.g. connections) can be cleaned up. -func TestValidatingCache_Singleflight_EvictsLoserWhenLoadRacesWriter(t *testing.T) { - t.Parallel() - - // loadReached is closed when load() is about to return, giving us a hook to - // race a Set before ContainsOrAdd is called. - loadReached := make(chan struct{}) - // allowReturn lets the test control exactly when load() returns. - allowReturn := make(chan struct{}) - - var evictedKey, evictedVal string - c := newStringCache( - func(_ string) (string, error) { - close(loadReached) // signal: load has run - <-allowReturn // wait until test injects the concurrent Set - return "from-load", nil - }, - alwaysAliveCheck, - func(key, val string) { - evictedKey = key - evictedVal = val - }, - ) - - var wg sync.WaitGroup - var gotVal string - var gotOk bool - wg.Go(func() { - gotVal, gotOk = c.Get("k") - }) - - // Wait until load() is running, then inject a concurrent Set so that - // ContainsOrAdd finds the key already present and discards the loaded value. - <-loadReached - c.Set("k", "from-set") - close(allowReturn) // let load() return "from-load" - wg.Wait() - - // The concurrent Set wins: caller receives the Set value. - require.True(t, gotOk) - assert.Equal(t, "from-set", gotVal, "concurrent Set value should win") - - // The loaded-but-discarded value must be passed to onEvict. - assert.Equal(t, "k", evictedKey, "onEvict must be called for the discarded loaded value") - assert.Equal(t, "from-load", evictedVal, "onEvict must receive the discarded loaded value") -} - // --------------------------------------------------------------------------- // Singleflight deduplication // --------------------------------------------------------------------------- -func TestValidatingCache_Singleflight_DeduplicatesConcurrentMisses(t *testing.T) { +func TestRestorableCache_Singleflight_DeduplicatesConcurrentMisses(t *testing.T) { t.Parallel() const goroutines = 10 diff --git a/pkg/vmcp/server/sessionmanager/factory.go b/pkg/vmcp/server/sessionmanager/factory.go index 8adf30ed44..73cfb83c3e 100644 --- a/pkg/vmcp/server/sessionmanager/factory.go +++ b/pkg/vmcp/server/sessionmanager/factory.go @@ -32,11 +32,6 @@ import ( const instrumentationName = "github.com/stacklok/toolhive/pkg/vmcp" -// defaultCacheCapacity is the fallback used when FactoryConfig.CacheCapacity is -// zero (the Go zero value). This ensures the cache is always bounded; omitting -// CacheCapacity from a config does not silently enable unbounded growth. -const defaultCacheCapacity = 1000 - // FactoryConfig holds the session factory construction parameters that the // session manager needs to build its decorating factory. It is separate from // server.Config to avoid a circular import between the server and sessionmanager @@ -67,13 +62,6 @@ type FactoryConfig struct { // If non-nil, the optimizer factory (whether derived from OptimizerConfig or // supplied via OptimizerFactory) and workflow executors are wrapped with telemetry. TelemetryProvider *telemetry.Provider - - // CacheCapacity is the maximum number of live MultiSession entries held in - // the node-local ValidatingCache. When the cache is full the least-recently-used - // session is evicted (its backend connections are closed via onEvict). A value of - // 0 uses defaultCacheCapacity (1000). Negative values are rejected by - // sessionmanager.New. - CacheCapacity int } // resolveOptimizer wires the optimizer factory from cfg, applying telemetry diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index b7000c9b3f..f2394ad68d 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -26,7 +26,6 @@ import ( mcpserver "github.com/mark3labs/mcp-go/server" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/cache" transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/conversion" @@ -44,6 +43,13 @@ const ( MetadataValTrue = "true" ) +// terminatedSentinel is stored in sessions when Terminate() begins tearing +// down a MultiSession. sessions.Get returns (nil, false) for sentinel entries +// (non-V values), and DecorateSession's CAS-based re-check will fail, +// preventing concurrent writers from resurrecting a storage record that +// Terminate() has already deleted. +type terminatedSentinel struct{} + // Manager bridges the domain session lifecycle (MultiSession / MultiSessionFactory) // to the mark3labs SDK's SessionIdManager interface. // @@ -68,12 +74,6 @@ const ( // sticky routing when session-affinity is desired. When Redis is used as the // session-storage backend the metadata is durable across pod restarts, and the // live MultiSession can be re-created via factory.RestoreSession() on a cache miss. -// -// TODO: Long-term, the cache and storage should be layered behind a single -// interface so the session manager does not need to coordinate between them. -// Reads would go through the cache (handling misses, singleflight, and liveness -// transparently); writes go to storage; caching is an implementation detail -// hidden from the caller. type Manager struct { storage transportsession.DataStorage factory vmcpsession.MultiSessionFactory @@ -84,7 +84,7 @@ type Manager struct { // (HTTP connections, routing tables). On a cache miss it restores the // session from stored metadata; on a cache hit it confirms liveness via // storage.Load, which also refreshes the Redis TTL. - sessions *cache.ValidatingCache[string, vmcpsession.MultiSession] + sessions *RestorableCache[string, vmcpsession.MultiSession] } // New creates a Manager backed by the given SessionDataStorage and backend @@ -102,13 +102,6 @@ func New( if cfg == nil || cfg.Base == nil { return nil, nil, fmt.Errorf("sessionmanager.New: FactoryConfig.Base (SessionFactory) is required") } - if cfg.CacheCapacity < 0 { - return nil, nil, fmt.Errorf("sessionmanager.New: CacheCapacity must be >= 0 (got %d)", cfg.CacheCapacity) - } - capacity := cfg.CacheCapacity - if capacity == 0 { - capacity = defaultCacheCapacity - } if len(cfg.WorkflowDefs) > 0 && cfg.ComposerFactory == nil { return nil, nil, fmt.Errorf("sessionmanager.New: ComposerFactory is required when WorkflowDefs are provided") } @@ -142,8 +135,7 @@ func New( backendReg: backendRegistry, } - sm.sessions = cache.New( - capacity, + sm.sessions = newRestorableCache( sm.loadSession, sm.checkSession, func(id string, sess vmcpsession.MultiSession) { @@ -151,7 +143,7 @@ func New( slog.Warn("session cache: error closing evicted session", "session_id", id, "error", closeErr) } - slog.Warn("session cache: session evicted from node-local cache", + slog.Warn("session cache: evicted expired session from node-local cache", "session_id", id) }, ) @@ -351,29 +343,16 @@ func (sm *Manager) CreateSession( // Persist the serialisable session metadata to the pluggable backend (e.g. // Redis) so that Validate() and TTL management work correctly. The live // MultiSession itself is cached in the node-local multiSessions map below. - // - // Use Update (SET XX) rather than Upsert to close the TOCTOU window between - // the second placeholder check above and this write. If Terminate deleted the - // key in that window, Update returns (false, nil) and we bail without - // resurrecting the deleted session. storeCtx, storeCancel := context.WithTimeout(ctx, createSessionStorageTimeout) defer storeCancel() - stored, err := sm.storage.Update(storeCtx, sessionID, sess.GetMetadata()) - if err != nil { + if err := sm.storage.Upsert(storeCtx, sessionID, sess.GetMetadata()); err != nil { _ = sess.Close() sm.cleanupFailedPlaceholder(sessionID, placeholder2) return nil, fmt.Errorf("Manager.CreateSession: failed to store session metadata: %w", err) } - if !stored { - _ = sess.Close() - return nil, fmt.Errorf( - "Manager.CreateSession: session %q was terminated between placeholder check and metadata store", - sessionID, - ) - } // Cache the live MultiSession so that GetMultiSession can retrieve it. - sm.sessions.Set(sessionID, sess) + sm.sessions.Store(sessionID, sess) slog.Debug("Manager: created multi-session", "session_id", sessionID, @@ -387,21 +366,13 @@ func (sm *Manager) CreateSession( // as a valid session), and prevents repeated Validate() calls from refreshing // the Redis TTL and keeping the placeholder alive indefinitely. // -// Uses Update (SET XX) so that a Terminate() that already deleted the key is -// not inadvertently resurrected as a terminated entry. -// // Cleanup is best-effort: errors are logged but not returned, since the caller // already has an error to report. func (sm *Manager) cleanupFailedPlaceholder(sessionID string, metadata map[string]string) { - // Copy before mutating so the caller's map is not modified. - terminated := make(map[string]string, len(metadata)+1) - for k, v := range metadata { - terminated[k] = v - } - terminated[MetadataKeyTerminated] = MetadataValTrue + metadata[MetadataKeyTerminated] = MetadataValTrue cleanupCtx, cancel := context.WithTimeout(context.Background(), createSessionStorageTimeout) defer cancel() - if _, err := sm.storage.Update(cleanupCtx, sessionID, terminated); err != nil { + if err := sm.storage.Upsert(cleanupCtx, sessionID, metadata); err != nil { slog.Warn("Manager.CreateSession: failed to mark failed placeholder as terminated; it will linger until TTL expires", "session_id", sessionID, "error", err) } @@ -444,10 +415,11 @@ func (sm *Manager) Validate(sessionID string) (isTerminated bool, err error) { // where client termination during the Phase 1→Phase 2 window could resurrect // sessions with open backend connections: // -// - MultiSession (Phase 2): the storage key is deleted. The node-local cache -// self-heals on the next Get: checkSession detects ErrSessionNotFound, -// evicts the entry, and onEvict closes backend connections. After deletion -// Validate() returns (false, error) — the same response as "never existed". +// - MultiSession (Phase 2): Close() releases backend connections, then the +// session is deleted from storage immediately. After deletion Validate() +// returns (false, error) — the same response as "never existed". This is +// intentional: a terminated MultiSession has no resources to preserve, so +// immediate removal is cleaner than marking and waiting for TTL. // // - Placeholder (Phase 1): the session is marked terminated=true and left // for TTL cleanup. This prevents CreateSession() from opening backend @@ -466,10 +438,46 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { ctx, cancel := context.WithTimeout(context.Background(), terminateTimeout) defer cancel() - // Load current metadata to determine session phase. + // Check the node-local cache first: a fully-formed MultiSession is stored + // here while this pod owns it. + if v, ok := sm.sessions.Peek(sessionID); ok { + // A terminatedSentinel means another goroutine is already tearing down + // this session. Do not fall through to the placeholder path — that would + // race with the concurrent Terminate's storage.Delete and potentially + // recreate the storage record after it was deleted. + if _, isSentinel := v.(terminatedSentinel); isSentinel { + slog.Debug("Manager.Terminate: concurrent termination in progress, skipping", + "session_id", sessionID) + return false, nil + } + if multiSess, ok := v.(vmcpsession.MultiSession); ok { + // Publish the tombstone before deleting from storage. Any concurrent + // GetMultiSession call will see the terminatedSentinel and return + // (nil, false), and DecorateSession's CAS-based re-check will fail, + // preventing both from recreating the storage record after we delete it. + sm.sessions.Store(sessionID, terminatedSentinel{}) + + if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil { + // Rollback: restore the live session so the caller can retry. + sm.sessions.Store(sessionID, multiSess) + return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr) + } + + // Storage is clean; remove the sentinel and release backend connections. + sm.sessions.Delete(sessionID) + if closeErr := multiSess.Close(); closeErr != nil { + slog.Warn("Manager.Terminate: error closing multi-session backend connections", + "session_id", sessionID, "error", closeErr) + } + slog.Info("Manager.Terminate: session terminated", "session_id", sessionID) + return false, nil + } + } + + // No MultiSession in the local map — treat as a placeholder session. + // Load current metadata, mark as terminated, and store back. metadata, loadErr := sm.storage.Load(ctx, sessionID) if errors.Is(loadErr, transportsession.ErrSessionNotFound) { - // Already gone (concurrent termination or TTL expiry). slog.Debug("Manager.Terminate: session not found (already expired?)", "session_id", sessionID) return false, nil } @@ -477,39 +485,36 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, fmt.Errorf("Manager.Terminate: failed to load session %q: %w", sessionID, loadErr) } - if _, isFullSession := metadata[sessiontypes.MetadataKeyTokenHash]; isFullSession { - // Phase 2 (full MultiSession): delete from storage, then evict from the - // node-local cache so onEvict closes backend connections immediately rather - // than waiting for the next Get or an LRU eviction. - if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil { - return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr) - } - sm.sessions.Remove(sessionID) - slog.Info("Manager.Terminate: session terminated", "session_id", sessionID) - return false, nil - } - - // Phase 1 (placeholder): mark terminated so CreateSession fast-fails and - // Validate returns isTerminated=true during the TTL window. - // Use Update (SET XX) rather than Upsert so we never resurrect a key that - // was concurrently deleted or expired between the Load above and this write. - // (false, nil) means already gone — treat as success. + // Placeholder session (not yet upgraded to MultiSession). + // + // This handles the race condition where a client sends DELETE between + // Generate() (Phase 1) and CreateSession() (Phase 2). The two-phase + // pattern creates a window where the session exists as a placeholder: + // + // 1. Client sends initialize → Generate() creates placeholder + // 2. Client sends DELETE before OnRegisterSession hook fires + // 3. We mark the placeholder as terminated (don't delete it) + // 4. CreateSession() hook fires → sees terminated flag → fails fast + // + // Without this branch, CreateSession() would open backend HTTP connections + // for a session the client already terminated, silently resurrecting it. + // + // We mark (not delete) so Validate() can return isTerminated=true, which + // lets the SDK distinguish "actively terminated" from "never existed". + // TTL cleanup will remove the placeholder later. metadata[MetadataKeyTerminated] = MetadataValTrue - updated, storeErr := sm.storage.Update(ctx, sessionID, metadata) - if storeErr != nil { + if storeErr := sm.storage.Upsert(ctx, sessionID, metadata); storeErr != nil { slog.Warn("Manager.Terminate: failed to persist terminated flag for placeholder; attempting delete fallback", "session_id", sessionID, "error", storeErr) + // Use a fresh context: if ctx expired (deadline exceeded), the same + // context would cause the fallback delete to fail immediately too. deleteCtx, deleteCancel := context.WithTimeout(context.Background(), terminateTimeout) + defer deleteCancel() if deleteErr := sm.storage.Delete(deleteCtx, sessionID); deleteErr != nil { - deleteCancel() return false, fmt.Errorf( "Manager.Terminate: failed to persist terminated flag and delete placeholder: storeErr=%v, deleteErr=%w", storeErr, deleteErr) } - deleteCancel() - } else if !updated { - // Session expired or was concurrently deleted between Load and Update — already gone. - slog.Debug("Manager.Terminate: placeholder already gone before terminated flag could be set", "session_id", sessionID) } slog.Info("Manager.Terminate: session terminated", "session_id", sessionID) @@ -522,15 +527,13 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { // cross-pod RestoreSession call does not attempt to reconnect to the expired // backend session. // -// After a successful storage update, the cached entry is not immediately evicted. -// On the next GetMultiSession call, checkSession detects that the stored -// MetadataKeyBackendIDs differs from the cached session's value, evicts the stale -// entry via onEvict, and triggers RestoreSession with the updated metadata. -// On storage error, no eviction occurs and the caller retries on the next access. +// After a successful storage update the session is evicted from the node-local +// cache; the next GetMultiSession call triggers RestoreSession with the updated +// metadata, discarding the stale in-memory copy. // // This is a best-effort operation. If the session is absent from storage (not // found or terminated) the call is a silent no-op. Storage errors are logged -// but not returned. +// but not returned; on error the cache is not evicted. func (sm *Manager) NotifyBackendExpired(sessionID, workloadID string) { loadCtx, loadCancel := context.WithTimeout(context.Background(), notifyBackendExpiredTimeout) defer loadCancel() @@ -581,27 +584,50 @@ func (sm *Manager) NotifyBackendExpired(sessionID, workloadID string) { } } -// updateMetadata writes a complete metadata snapshot to storage using a -// conditional Update (SET XX). If the key is absent at update time (concurrent -// Delete), the call is a no-op. The cache self-heals on the next GetMultiSession -// call: checkSession detects metadata drift, evicts the stale entry, and -// RestoreSession reloads with fresh state. +// updateMetadata writes a complete metadata snapshot to storage and evicts the +// session from the node-local cache so the next GetMultiSession call triggers a +// fresh RestoreSession with the updated state. +// +// Cross-pod TOCTOU: a re-check Load is performed immediately before the Upsert +// to detect cross-pod session termination (where another pod calls +// storage.Delete). If the key is absent at re-check time we bail without +// upserting. A residual race remains between the re-check and the Upsert (a +// concurrent pod could delete the key in that window), but the window is now +// microseconds rather than the full NotifyBackendExpired span. Closing the race +// entirely would require a conditional write primitive (e.g. Redis SET XX / +// UpsertIfPresent) added to the DataStorage interface. +// +// NOTE: concurrent calls for the same session are last-write-wins. We assume +// parallel metadata writers within a session do not occur; NotifyBackendExpired +// is the only post-creation writer and backend expiry events are serialised by +// the backend registry. This can be retrofitted with CAS semantics or a version +// counter if that assumption changes. func (sm *Manager) updateMetadata(sessionID string, metadata map[string]string) error { + // Same-pod guard: if Terminate() is already tearing down this session on + // this pod the sentinel is in the cache and storage is already deleted. + if raw, ok := sm.sessions.Peek(sessionID); ok { + if _, isSentinel := raw.(terminatedSentinel); isSentinel { + return nil + } + } + ctx, cancel := context.WithTimeout(context.Background(), notifyBackendExpiredTimeout) defer cancel() - // Update only succeeds if the key still exists. A concurrent Delete (same - // pod or cross-pod) returns (false, nil), and we bail without resurrecting. - updated, err := sm.storage.Update(ctx, sessionID, metadata) - if err != nil { + // Cross-pod guard: re-check that the storage record still exists before + // upserting. If another pod terminated the session (deleting the key) after + // NotifyBackendExpired's initial Load, we must not recreate the record. + if _, err := sm.storage.Load(ctx, sessionID); err != nil { + if errors.Is(err, transportsession.ErrSessionNotFound) { + return nil // session was terminated elsewhere; nothing to update + } return err } - if !updated { - return nil // session was terminated; nothing to update + + if err := sm.storage.Upsert(ctx, sessionID, metadata); err != nil { + return err } - // The cache self-heals lazily: on the next GetMultiSession, checkSession detects - // either the absent storage key or stale MetadataKeyBackendIDs and evicts the - // entry, triggering a fresh RestoreSession. + sm.sessions.Delete(sessionID) return nil } @@ -638,26 +664,32 @@ func (sm *Manager) GetMultiSession(sessionID string) (vmcpsession.MultiSession, // replacing the old session and its backend connections. This ensures that a // backend-expiry update written by pod A propagates to pod B on the next // cache access rather than waiting for natural TTL expiry. -func (sm *Manager) checkSession(sessionID string, sess vmcpsession.MultiSession) error { +func (sm *Manager) checkSession(sessionID string) error { checkCtx, cancel := context.WithTimeout(context.Background(), restoreStorageTimeout) defer cancel() metadata, err := sm.storage.Load(checkCtx, sessionID) if errors.Is(err, transportsession.ErrSessionNotFound) { - return cache.ErrExpired + return ErrExpired } if err != nil { return err // transient storage error — keep cached } if metadata[MetadataKeyTerminated] == MetadataValTrue { - return cache.ErrExpired - } - - // Compare backend IDs to detect cross-pod metadata drift. - // Only compare when the cached session carries MetadataKeyBackendIDs to - // avoid spurious evictions for sessions that don't track backend IDs. - if cachedIDs, present := sess.GetMetadata()[vmcpsession.MetadataKeyBackendIDs]; present { - if cachedIDs != metadata[vmcpsession.MetadataKeyBackendIDs] { - return cache.ErrExpired + return ErrExpired + } + + // If the cached session has backend metadata and it differs from storage, + // evict to pick up the update. Only compare when the cached session + // explicitly carries MetadataKeyBackendIDs to avoid spurious evictions for + // sessions whose in-memory representation does not track backend IDs (e.g. + // test mocks that return an empty metadata map). + if raw, ok := sm.sessions.Peek(sessionID); ok { + if sess, ok := raw.(vmcpsession.MultiSession); ok { + if cachedIDs, present := sess.GetMetadata()[vmcpsession.MetadataKeyBackendIDs]; present { + if cachedIDs != metadata[vmcpsession.MetadataKeyBackendIDs] { + return ErrExpired + } + } } } @@ -715,9 +747,14 @@ func (sm *Manager) loadSession(sessionID string) (vmcpsession.MultiSession, erro // and stores the result back. Returns an error if the session is not found or // has not yet been upgraded from placeholder to MultiSession. // -// storage.Update is the concurrency guard. If it returns (false, nil), the -// session was deleted; the cache entry will be evicted on the next Get when -// checkSession detects ErrSessionNotFound. +// A re-check is performed immediately before storing to guard against a +// race with Terminate(): if the session is deleted between GetMultiSession and +// the store, the store would silently resurrect a terminated session. +// The re-check catches that window. A narrow TOCTOU gap remains between the +// re-check and the store, but its consequence is bounded: Terminate() already +// called Close() on the underlying MultiSession before deleting it, so any +// resurrected decorator wraps an already-closed session and will fail on first +// use rather than leaking backend connections. func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error { sess, ok := sm.GetMultiSession(sessionID) if !ok { @@ -730,24 +767,24 @@ func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiS if decorated.ID() != sessionID { return fmt.Errorf("DecorateSession: decorator changed session ID from %q to %q", sessionID, decorated.ID()) } - - // Persist metadata to storage first via conditional Update (SET XX). - // Only update the node-local cache after a successful write so that a - // storage error or a concurrent delete never leaves a decorated (but - // unpersisted) value in the cache where retries could stack decorations. + // Atomically replace the original entry with the decorated one. + // If Terminate() has stored a terminatedSentinel between the first + // GetMultiSession call above and here, CompareAndSwap returns false and + // we bail out before touching storage — preventing resurrection of a + // terminated session's storage record. + if !sm.sessions.CompareAndSwap(sessionID, sess, decorated) { + return fmt.Errorf("DecorateSession: session %q was terminated or concurrently modified during decoration", sessionID) + } + // Persist updated metadata to storage. On failure, attempt to rollback + // the local-map entry so the caller can retry. If Terminate() has since + // replaced the decorated entry with a sentinel, the rollback CAS returns + // false and we leave the sentinel in place. decorateCtx, decorateCancel := context.WithTimeout(context.Background(), decorateTimeout) defer decorateCancel() - updated, err := sm.storage.Update(decorateCtx, sessionID, decorated.GetMetadata()) - if err != nil { + if err := sm.storage.Upsert(decorateCtx, sessionID, decorated.GetMetadata()); err != nil { + _ = sm.sessions.CompareAndSwap(sessionID, decorated, sess) return fmt.Errorf("DecorateSession: failed to store decorated session metadata: %w", err) } - if !updated { - // Session was deleted (by Terminate or TTL) between Get and Update. - // Evict the stale cache entry so onEvict closes backend connections. - sm.sessions.Remove(sessionID) - return fmt.Errorf("DecorateSession: session %q was deleted during decoration", sessionID) - } - sm.sessions.Set(sessionID, decorated) return nil } diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index 4c5a8a25bd..8728bf9ecc 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -16,7 +16,6 @@ import ( "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/auth" - "github.com/stacklok/toolhive/pkg/cache" transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -93,9 +92,6 @@ func (alwaysFailDataStorage) Load(_ context.Context, _ string) (map[string]strin func (alwaysFailDataStorage) Create(_ context.Context, _ string, _ map[string]string) (bool, error) { return false, errors.New("storage unavailable") } -func (alwaysFailDataStorage) Update(_ context.Context, _ string, _ map[string]string) (bool, error) { - return false, errors.New("storage unavailable") -} func (alwaysFailDataStorage) Delete(_ context.Context, _ string) error { return nil } func (alwaysFailDataStorage) Close() error { return nil } @@ -180,7 +176,7 @@ func newTestSessionManager( ) (*Manager, transportsession.DataStorage) { t.Helper() storage := newTestSessionDataStorage(t) - sm, cleanup, err := New(storage, &FactoryConfig{Base: factory, CacheCapacity: 1000}, registry) + sm, cleanup, err := New(storage, &FactoryConfig{Base: factory}, registry) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) return sm, storage @@ -220,7 +216,7 @@ func TestSessionManager_Generate(t *testing.T) { ctrl := gomock.NewController(t) sess := newMockSession(t, ctrl, "placeholder", nil) factory := newMockFactory(t, ctrl, sess) - sm, cleanup, err := New(alwaysFailDataStorage{}, &FactoryConfig{Base: factory, CacheCapacity: 1000}, newFakeRegistry()) + sm, cleanup, err := New(alwaysFailDataStorage{}, &FactoryConfig{Base: factory}, newFakeRegistry()) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) @@ -577,8 +573,7 @@ func TestSessionManager_Terminate(t *testing.T) { MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { createdSess = newMockSession(t, ctrl, id, tools) - // Close() is called eagerly by onEvict when Terminate removes - // the entry from the node-local cache after storage.Delete. + // Close() will be called exactly once during Terminate createdSess.EXPECT().Close().Return(nil).Times(1) return createdSess, nil }).Times(1) @@ -594,7 +589,7 @@ func TestSessionManager_Terminate(t *testing.T) { require.NoError(t, err) require.NotNil(t, createdSess) - // Terminate deletes from storage and removes from cache; onEvict fires Close(). + // Terminate should close the backend connections. isNotAllowed, err := sm.Terminate(sessionID) require.NoError(t, err) assert.False(t, isNotAllowed) @@ -610,8 +605,7 @@ func TestSessionManager_Terminate(t *testing.T) { MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) - // Close is called by onEvict when Terminate removes the cache entry. - sess.EXPECT().Close().Return(nil).AnyTimes() + sess.EXPECT().Close().Return(nil).Times(1) return sess, nil }).Times(1) @@ -624,12 +618,6 @@ func TestSessionManager_Terminate(t *testing.T) { _, err := sm.CreateSession(context.Background(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this - // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. - require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", - })) - // Session must exist before termination. _, loadErr := storage.Load(context.Background(), sessionID) assert.NoError(t, loadErr, "session should exist in storage before Terminate") @@ -685,7 +673,7 @@ func TestSessionManager_Terminate(t *testing.T) { failStoreAfter: 1, // fail after 1 successful call (Generate's Create) failDelete: false, } - sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory, CacheCapacity: 1000}, registry) + sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory}, registry) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) @@ -723,7 +711,7 @@ func TestSessionManager_Terminate(t *testing.T) { failStoreAfter: 1, // fail after 1 successful call (Generate's Create) failDelete: true, } - sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory, CacheCapacity: 1000}, registry) + sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory}, registry) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) @@ -1887,26 +1875,20 @@ func TestSessionManager_DecorateSession(t *testing.T) { return sess, nil }).Times(1) - sm, storage := newTestSessionManager(t, factory, newFakeRegistry()) + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) sessionID := sm.Generate() require.NotEmpty(t, sessionID) _, err := sm.CreateSession(context.Background(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this - // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. - require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", - })) - err = sm.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { // Simulate concurrent Terminate() completing during decoration. _, _ = sm.Terminate(sessionID) return sess }) require.Error(t, err) - assert.Contains(t, err.Error(), "was deleted during decoration") + assert.Contains(t, err.Error(), "was terminated or concurrently modified during decoration") // The session must not be resurrected. _, ok := sm.GetMultiSession(sessionID) @@ -1934,21 +1916,13 @@ func TestSessionManager_CheckSession(t *testing.T) { return f } - makeEmptySess := func(t *testing.T) vmcpsession.MultiSession { - t.Helper() - ctrl := gomock.NewController(t) - m := sessionmocks.NewMockMultiSession(ctrl) - m.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() - return m - } - t.Run("alive session returns nil", func(t *testing.T) { t.Parallel() sm, storage := newTestSessionManager(t, makeFactory(t), newFakeRegistry()) sessionID := "alive-session" require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{})) - err := sm.checkSession(sessionID, makeEmptySess(t)) + err := sm.checkSession(sessionID) assert.NoError(t, err, "alive session must return nil") }) @@ -1956,8 +1930,8 @@ func TestSessionManager_CheckSession(t *testing.T) { t.Parallel() sm, _ := newTestSessionManager(t, makeFactory(t), newFakeRegistry()) - err := sm.checkSession("nonexistent-session", makeEmptySess(t)) - assert.ErrorIs(t, err, cache.ErrExpired, "deleted session must return ErrExpired") + err := sm.checkSession("nonexistent-session") + assert.ErrorIs(t, err, ErrExpired, "deleted session must return ErrExpired") }) t.Run("terminated session returns ErrExpired", func(t *testing.T) { @@ -1971,8 +1945,8 @@ func TestSessionManager_CheckSession(t *testing.T) { MetadataKeyTerminated: MetadataValTrue, })) - err := sm.checkSession(sessionID, makeEmptySess(t)) - assert.ErrorIs(t, err, cache.ErrExpired, "terminated session must return ErrExpired") + err := sm.checkSession(sessionID) + assert.ErrorIs(t, err, ErrExpired, "terminated session must return ErrExpired") }) t.Run("stale backend list triggers cross-pod eviction", func(t *testing.T) { @@ -1996,10 +1970,10 @@ func TestSessionManager_CheckSession(t *testing.T) { cached.EXPECT().GetMetadata().Return(map[string]string{ vmcpsession.MetadataKeyBackendIDs: "backend-a,backend-b", }).AnyTimes() - sm.sessions.Set(sessionID, cached) + sm.sessions.Store(sessionID, cached) - err := sm.checkSession(sessionID, cached) - assert.ErrorIs(t, err, cache.ErrExpired, + err := sm.checkSession(sessionID) + assert.ErrorIs(t, err, ErrExpired, "stale backend list must return ErrExpired to trigger cross-pod eviction") }) @@ -2017,9 +1991,9 @@ func TestSessionManager_CheckSession(t *testing.T) { cached.EXPECT().GetMetadata().Return(map[string]string{ vmcpsession.MetadataKeyBackendIDs: "backend-a", }).AnyTimes() - sm.sessions.Set(sessionID, cached) + sm.sessions.Store(sessionID, cached) - err := sm.checkSession(sessionID, cached) + err := sm.checkSession(sessionID) assert.NoError(t, err, "matching backend list must return nil") }) @@ -2037,9 +2011,9 @@ func TestSessionManager_CheckSession(t *testing.T) { ctrl := gomock.NewController(t) cached := sessionmocks.NewMockMultiSession(ctrl) cached.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() - sm.sessions.Set(sessionID, cached) + sm.sessions.Store(sessionID, cached) - err := sm.checkSession(sessionID, cached) + err := sm.checkSession(sessionID) assert.NoError(t, err, "absent MetadataKeyBackendIDs in cache must not cause eviction") }) } @@ -2192,12 +2166,6 @@ func TestNotifyBackendExpired(t *testing.T) { _, err := sm.CreateSession(t.Context(), sessionID) require.NoError(t, err) - // Seed MetadataKeyTokenHash into storage so Terminate recognises this - // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. - require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{ - sessiontypes.MetadataKeyTokenHash: "", - })) - _, err = sm.Terminate(sessionID) require.NoError(t, err) @@ -2209,13 +2177,9 @@ func TestNotifyBackendExpired(t *testing.T) { "terminated session must not be resurrected by NotifyBackendExpired") }) - t.Run("same-pod termination: storage.Update returns false, no resurrection", func(t *testing.T) { + t.Run("concurrent termination: sentinel prevents resurrection after Load succeeds", func(t *testing.T) { t.Parallel() - // Verify that updateMetadata's storage.Update (SET XX) prevents - // resurrection even when Terminate runs concurrently on the same pod. - // We model Terminate completing (key deleted) before updateMetadata - // reaches its storage.Update call. ctrl := gomock.NewController(t) registry := newFakeRegistry() sess := newMockSession(t, ctrl, "s", nil) @@ -2232,17 +2196,22 @@ func TestNotifyBackendExpired(t *testing.T) { map[string]string{"workload-a": "sess-a"}, ) - // Simulate Terminate having completed its storage.Delete already. - require.NoError(t, storage.Delete(context.Background(), sessionID)) + // Simulate Terminate-in-progress: inject the terminatedSentinel directly + // into the node-local cache (as Terminate does before calling + // storage.Delete) while leaving storage intact. This models the TOCTOU + // window where NotifyBackendExpired's Load succeeded before Terminate's + // storage.Delete ran but our sentinel check runs while the sentinel is + // still present. + sm.sessions.Store(sessionID, terminatedSentinel{}) - // storage.Update (SET XX) in updateMetadata returns (false, nil) because - // the key no longer exists — NotifyBackendExpired must bail without - // recreating the record. + // NotifyBackendExpired must detect the terminatedSentinel and bail + // before Upsert, leaving the storage record unmodified. sm.NotifyBackendExpired(sessionID, "workload-a") - _, loadErr := storage.Load(context.Background(), sessionID) - assert.ErrorIs(t, loadErr, transportsession.ErrSessionNotFound, - "NotifyBackendExpired must not resurrect a session whose storage key was deleted by Terminate") + got, loadErr := storage.Load(context.Background(), sessionID) + require.NoError(t, loadErr) + assert.Equal(t, "workload-a", got[vmcpsession.MetadataKeyBackendIDs], + "storage must not be modified when terminatedSentinel is present") }) t.Run("cross-pod termination: absent storage key is a no-op (no resurrection)", func(t *testing.T) { @@ -2278,7 +2247,7 @@ func TestNotifyBackendExpired(t *testing.T) { "NotifyBackendExpired must not resurrect a session terminated by another pod") }) - t.Run("lazy eviction: session stays in cache immediately after NotifyBackendExpired", func(t *testing.T) { + t.Run("evicts session from node-local cache on success", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -2292,8 +2261,9 @@ func TestNotifyBackendExpired(t *testing.T) { _, err := sm.CreateSession(t.Context(), sessionID) require.NoError(t, err) - // Session must be in cache after CreateSession. - assert.Equal(t, 1, sm.sessions.Len(), "session must be in node-local cache after CreateSession") + // CreateSession must have populated the node-local cache. + _, cached := sm.sessions.Peek(sessionID) + require.True(t, cached, "session must be in node-local cache after CreateSession") seedBackendMetadata(t, storage, sessionID, []string{"workload-a"}, @@ -2302,10 +2272,11 @@ func TestNotifyBackendExpired(t *testing.T) { sm.NotifyBackendExpired(sessionID, "workload-a") - // With lazy eviction, session is still in cache immediately after NotifyBackendExpired. - // checkSession detects drift on the next GetMultiSession call. - assert.Equal(t, 1, sm.sessions.Len(), - "session must still be in cache immediately after NotifyBackendExpired (eviction is lazy)") + // The session must have been evicted so the next GetMultiSession call + // triggers RestoreSession with the updated (backend-free) metadata. + _, stillCached := sm.sessions.Peek(sessionID) + assert.False(t, stillCached, + "session must be evicted from node-local cache after NotifyBackendExpired") }) }