diff --git a/client/files.go b/client/files.go index 501589eb..efeb06f5 100644 --- a/client/files.go +++ b/client/files.go @@ -95,11 +95,15 @@ func (c *Client) Save(ctx context.Context, key Key, baseDir string, paths []stri } } - wc, err := c.Create(ctx, key, headers, cfg.ttl) + createCtx, cancelCreate := context.WithCancelCause(ctx) + defer cancelCreate(nil) + + wc, err := c.Create(createCtx, key, headers, cfg.ttl) if err != nil { return errors.Wrap(err, "failed to create object") } if err := Archive(ctx, wc, baseDir, paths, cfg.exclude, cfg.zstdThreads); err != nil { + cancelCreate(err) return errors.Join(err, wc.Close()) } return errors.Wrap(wc.Close(), "failed to close writer") diff --git a/client/files_test.go b/client/files_test.go index 5774cd53..ace44607 100644 --- a/client/files_test.go +++ b/client/files_test.go @@ -120,6 +120,24 @@ func TestHashFilesSkipsDirectories(t *testing.T) { assert.Equal(t, h1, h2, "directories should be skipped, not cause errors") } +func TestSaveAbortOnArchiveFailure(t *testing.T) { + srv := newFakeServer(nil) + defer srv.Close() + + c := client.New(srv.URL, nil).Namespace("test") + defer c.Close() + + key := client.NewKey("should-not-exist") + + // Save from a nonexistent directory — Archive will fail. + err := c.Save(t.Context(), key, "/nonexistent/path", []string{"."}) + assert.Error(t, err) + + // The object must not have been persisted. + _, _, err = c.Open(t.Context(), key) + assert.IsError(t, err, os.ErrNotExist) +} + func TestHashKeySaveRestore(t *testing.T) { srv := newFakeServer(nil) defer srv.Close() diff --git a/cmd/cachew/main.go b/cmd/cachew/main.go index 3deee38e..ce1273c1 100644 --- a/cmd/cachew/main.go +++ b/cmd/cachew/main.go @@ -136,12 +136,16 @@ func (c *PutCmd) Run(ctx context.Context, api *client.Client) error { headers.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filepath.Base(filename))) //nolint:perfsprint } - wc, err := api.Namespace(c.Namespace).Create(ctx, c.Key.Key(), headers, c.TTL) + createCtx, cancelCreate := context.WithCancelCause(ctx) + defer cancelCreate(nil) + + wc, err := api.Namespace(c.Namespace).Create(createCtx, c.Key.Key(), headers, c.TTL) if err != nil { return errors.Wrap(err, "failed to create object") } if _, err := io.Copy(wc, c.Input); err != nil { + cancelCreate(err) return errors.Join(errors.Wrap(err, "failed to copy data"), wc.Close()) } diff --git a/cmd/cachew/save_test.go b/cmd/cachew/save_test.go index 77c5edc0..e01b6447 100644 --- a/cmd/cachew/save_test.go +++ b/cmd/cachew/save_test.go @@ -2,6 +2,7 @@ package main import ( "context" + "io" "net/http" "net/http/httptest" "os" @@ -75,3 +76,36 @@ func TestRestoreCacheMiss(t *testing.T) { err := cmd.Run(context.Background(), api, &CLI{}) assert.True(t, errors.Is(err, errCacheMiss), "expected errCacheMiss sentinel, got %v", err) } + +func TestPutCmdCancelOnCopyFailure(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = io.Copy(io.Discard, r.Body) + w.WriteHeader(http.StatusOK) + return + } + http.NotFound(w, nil) + })) + defer srv.Close() + + api := client.New(srv.URL, nil) + defer api.Close() + + // This mirrors what PutCmd.Run does internally: + // create with a cancellable context, write partial data, then cancel + // before closing — simulating an io.Copy error. + key := client.NewKey("put-cancel-test") + createCtx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + wc, err := api.Namespace("test").Create(createCtx, key, nil, 0) + assert.NoError(t, err) + + _, _ = wc.Write([]byte("partial")) + + // Simulate the cancel-before-close pattern from PutCmd.Run + cancel(errors.New("copy failed")) + err = wc.Close() + assert.Error(t, err) + assert.Contains(t, err.Error(), "cancelled") +} diff --git a/internal/cache/http.go b/internal/cache/http.go index ec2ba016..ff0bde65 100644 --- a/internal/cache/http.go +++ b/internal/cache/http.go @@ -1,6 +1,7 @@ package cache import ( + "context" "io" "maps" "net/http" @@ -53,8 +54,10 @@ func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http. } responseHeaders := maps.Clone(resp.Header) - cw, err := c.Create(r.Context(), key, responseHeaders, 0) + createCtx, cancelCreate := context.WithCancelCause(r.Context()) + cw, err := c.Create(createCtx, key, responseHeaders, 0) if err != nil { + cancelCreate(nil) _ = resp.Body.Close() return nil, httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err) } @@ -62,8 +65,12 @@ func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http. originalBody := resp.Body pr, pw := io.Pipe() go func() { + defer cancelCreate(nil) mw := io.MultiWriter(pw, cw) _, copyErr := io.Copy(mw, originalBody) + if copyErr != nil { + cancelCreate(copyErr) + } closeErr := errors.Join(cw.Close(), originalBody.Close()) pw.CloseWithError(errors.Join(copyErr, closeErr)) }() diff --git a/internal/cache/http_test.go b/internal/cache/http_test.go index 407381ff..8dadccc2 100644 --- a/internal/cache/http_test.go +++ b/internal/cache/http_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" "testing" "time" @@ -83,3 +84,39 @@ func TestCachedFetchNonOKStatus(t *testing.T) { assert.NoError(t, resp.Body.Close()) assert.Equal(t, "not found", string(body)) } + +func TestFetchDirectAbortsOnPartialResponse(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + // Backend that sends partial data then closes the connection. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("partial")) + // Hijack the connection to simulate a mid-stream failure. + if hj, ok := w.(http.Hijacker); ok { + conn, _, _ := hj.Hijack() + _ = conn.Close() + } + })) + defer backend.Close() + + client := &http.Client{} + key := cache.NewKey(backend.URL + "/fail") + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/fail", nil) + assert.NoError(t, err) + + resp, err := cache.FetchDirect(client, req, memCache, key) + assert.NoError(t, err) + // Read the body — the copy goroutine will encounter the connection reset. + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + + // The partial response must not be cached. + _, _, err = memCache.Open(ctx, key) + assert.IsError(t, err, os.ErrNotExist) +} diff --git a/internal/cache/s3_test.go b/internal/cache/s3_test.go index 8a604670..457c32bd 100644 --- a/internal/cache/s3_test.go +++ b/internal/cache/s3_test.go @@ -1,12 +1,14 @@ package cache_test import ( + "context" "log/slog" "os" "testing" "time" "github.com/alecthomas/assert/v2" + "github.com/alecthomas/errors" "github.com/block/cachew/internal/cache" "github.com/block/cachew/internal/cache/cachetest" @@ -40,6 +42,35 @@ func TestS3Cache(t *testing.T) { cachetest.Suite(t, func(t *testing.T) cache.Cache { return newS3Cache(t, bucket) }) } +// TestS3ContextCancellationAbortsUpload verifies that cancelling the context before +// closing the writer aborts the S3 upload and does not leave any object behind. +// This is the mechanism snapshot.CreatePaths uses to prevent partial/corrupt uploads. +func TestS3ContextCancellationAbortsUpload(t *testing.T) { + bucket := s3clienttest.Start(t) + c := newS3Cache(t, bucket) + defer c.Close() + + key := cache.NewKey("aborted-upload") + + ctx, cancel := context.WithCancelCause(t.Context()) + + w, err := c.Create(ctx, key, nil, time.Hour) + assert.NoError(t, err) + + // Write some data so this isn't just a 0-byte edge case. + _, err = w.Write([]byte("partial data that should not be persisted")) + assert.NoError(t, err) + + // Cancel the context before closing, simulating an archive failure. + cancel(errors.New("archive failed")) + err = w.Close() + assert.Error(t, err) + + // The object must not be retrievable. + _, _, err = c.Open(t.Context(), key) + assert.IsError(t, err, os.ErrNotExist) +} + func TestS3CacheSoak(t *testing.T) { if os.Getenv("SOAK_TEST") == "" { t.Skip("Skipping soak test; set SOAK_TEST=1 to run") diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 6d1e4db5..49119dee 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -46,12 +46,18 @@ func CreatePaths(ctx context.Context, remote cache.Cache, key cache.Key, baseDir } } - wc, err := remote.Create(ctx, key, headers, ttl) + // Wrap the context so we can cancel the upload on archive failure, + // preventing partial data from being persisted. + createCtx, cancelCreate := context.WithCancelCause(ctx) + defer cancelCreate(nil) + + wc, err := remote.Create(createCtx, key, headers, ttl) if err != nil { return errors.Wrap(err, "failed to create object") } if err := client.Archive(ctx, wc, baseDir, includePaths, excludePatterns, threads); err != nil { + cancelCreate(err) return errors.Join(err, wc.Close()) } return errors.Wrap(wc.Close(), "failed to close writer") diff --git a/internal/strategy/apiv1.go b/internal/strategy/apiv1.go index 85e77d26..45a2b64a 100644 --- a/internal/strategy/apiv1.go +++ b/internal/strategy/apiv1.go @@ -131,14 +131,19 @@ func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) { // Extract and filter headers from request headers := httputil.FilterHeaders(r.Header, httputil.TransportHeaders...) + createCtx, cancelCreate := context.WithCancelCause(r.Context()) + defer cancelCreate(nil) + namespacedCache := d.cache.Namespace(namespace) - cw, err := namespacedCache.Create(r.Context(), key, headers, ttl) + cw, err := namespacedCache.Create(createCtx, key, headers, ttl) if err != nil { d.httpError(w, http.StatusInternalServerError, err, "Failed to create cache writer", "key", key) return } if _, err := io.Copy(cw, r.Body); err != nil { + cancelCreate(err) + _ = cw.Close() d.httpError(w, http.StatusInternalServerError, err, "Failed to copy request body to cache writer") return } diff --git a/internal/strategy/apiv1_test.go b/internal/strategy/apiv1_test.go new file mode 100644 index 00000000..eac16f08 --- /dev/null +++ b/internal/strategy/apiv1_test.go @@ -0,0 +1,65 @@ +package strategy_test + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/logging" + "github.com/block/cachew/internal/strategy" +) + +func TestPutObjectAbortsOnReadError(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + mux := http.NewServeMux() + _, err = strategy.NewAPIV1(ctx, struct{}{}, memCache, mux) + assert.NoError(t, err) + + key := cache.NewKey("abort-test") + + // Create a reader that returns an error after some data. + body := &failingReader{data: []byte("partial data"), failAfter: 5} + req := httptest.NewRequest(http.MethodPost, "/api/v1/object/test/"+key.String(), body) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + + // The partial data must not be cached. + nsCache := memCache.Namespace("test") + _, _, err = nsCache.Open(ctx, key) + assert.IsError(t, err, os.ErrNotExist) +} + +// failingReader returns data up to failAfter bytes, then returns an error. +type failingReader struct { + data []byte + failAfter int + read int +} + +func (r *failingReader) Read(p []byte) (int, error) { + if r.read >= r.failAfter { + return 0, io.ErrUnexpectedEOF + } + n := min(len(p), r.failAfter-r.read, len(r.data)-r.read) + copy(p[:n], r.data[r.read:r.read+n]) + r.read += n + if r.read >= r.failAfter { + return n, io.ErrUnexpectedEOF + } + return n, nil +} diff --git a/internal/strategy/git/export_test.go b/internal/strategy/git/export_test.go index 34bc88f3..461106e8 100644 --- a/internal/strategy/git/export_test.go +++ b/internal/strategy/git/export_test.go @@ -3,6 +3,7 @@ package git import ( "context" + "github.com/block/cachew/internal/cache" "github.com/block/cachew/internal/gitclone" ) @@ -17,3 +18,8 @@ func (s *Strategy) GenerateAndUploadSnapshot(ctx context.Context, repo *gitclone func (s *Strategy) GenerateAndUploadMirrorSnapshot(ctx context.Context, repo *gitclone.Repository) error { return s.generateAndUploadMirrorSnapshot(ctx, repo) } + +// CacheBundleSync exports cacheBundleSync for testing. +func (s *Strategy) CacheBundleSync(ctx context.Context, key cache.Key, data []byte) error { + return s.cacheBundleSync(ctx, key, data) +} diff --git a/internal/strategy/git/snapshot.go b/internal/strategy/git/snapshot.go index 85352197..6548989f 100644 --- a/internal/strategy/git/snapshot.go +++ b/internal/strategy/git/snapshot.go @@ -401,12 +401,16 @@ func (s *Strategy) cacheBundleAsync(ctx context.Context, key cache.Key, data []b } func (s *Strategy) cacheBundleSync(ctx context.Context, key cache.Key, data []byte) error { + createCtx, cancelCreate := context.WithCancelCause(ctx) + defer cancelCreate(nil) + headers := http.Header{"Content-Type": {"application/x-git-bundle"}} - wc, err := s.cache.Create(ctx, key, headers, bundleCacheTTL) + wc, err := s.cache.Create(createCtx, key, headers, bundleCacheTTL) if err != nil { return errors.Wrap(err, "create cache entry") } if _, err := wc.Write(data); err != nil { + cancelCreate(err) _ = wc.Close() return errors.Wrap(err, "write bundle to cache") } diff --git a/internal/strategy/git/snapshot_test.go b/internal/strategy/git/snapshot_test.go index be6626cb..fa889fea 100644 --- a/internal/strategy/git/snapshot_test.go +++ b/internal/strategy/git/snapshot_test.go @@ -2,6 +2,7 @@ package git_test import ( "context" + "io" "net/http" "net/http/httptest" "os" @@ -688,6 +689,68 @@ func TestDeferredRestoreOnlyScheduledOnce(t *testing.T) { // The key assertion is that it doesn't panic from double-scheduling. } +func TestCacheBundleSyncAbortsOnWriteFailure(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not found in PATH") + } + + _, ctx := logging.Configure(context.Background(), logging.Config{}) + tmpDir := t.TempDir() + mirrorRoot := filepath.Join(tmpDir, "mirrors") + + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + + // Wrap cache so Create returns a writer that fails on Write. + failCache := &failWriteCache{Cache: memCache} + + mux := newTestMux() + cm := gitclone.NewManagerProvider(ctx, gitclone.Config{MirrorRoot: mirrorRoot}, nil) + s, err := git.New(ctx, git.Config{}, newTestScheduler(ctx, t), failCache, mux, cm, func() (*githubapp.TokenManager, error) { return nil, nil }) //nolint:nilnil + assert.NoError(t, err) + waitForReady(t, s) + + key := cache.NewKey("test-bundle-abort") + data := []byte("bundle data that should not persist") + + err = s.CacheBundleSync(ctx, key, data) + assert.Error(t, err) + + // Verify nothing was persisted — check underlying memCache, not failCache. + _, _, err = memCache.Open(ctx, key) + assert.IsError(t, err, os.ErrNotExist) +} + +// failWriteCache wraps a cache.Cache and makes Create return a writer that +// always fails on Write. +type failWriteCache struct { + cache.Cache +} + +func (f *failWriteCache) Create(ctx context.Context, key cache.Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) { + wc, err := f.Cache.Create(ctx, key, headers, ttl) + if err != nil { + return nil, err + } + return &failWriter{inner: wc}, nil +} + +func (f *failWriteCache) Namespace(ns cache.Namespace) cache.Cache { + return &failWriteCache{Cache: f.Cache.Namespace(ns)} +} + +type failWriter struct { + inner io.WriteCloser +} + +func (w *failWriter) Write(_ []byte) (int, error) { + return 0, io.ErrShortWrite +} + +func (w *failWriter) Close() error { + return w.inner.Close() +} + func TestSnapshotRemoteURLUsesUpstreamURL(t *testing.T) { if _, err := exec.LookPath("git"); err != nil { t.Skip("git not found in PATH") diff --git a/internal/strategy/gomod/cacher.go b/internal/strategy/gomod/cacher.go index 2d791962..348f4802 100644 --- a/internal/strategy/gomod/cacher.go +++ b/internal/strategy/gomod/cacher.go @@ -33,17 +33,23 @@ func (g *goproxyCacher) Put(ctx context.Context, name string, content io.ReadSee key := cache.NewKey(name) - wc, err := g.cache.Create(ctx, key, nil, 0) + createCtx, cancelCreate := context.WithCancelCause(ctx) + defer cancelCreate(nil) + + wc, err := g.cache.Create(createCtx, key, nil, 0) if err != nil { return errors.Errorf("create cache entry: %w", err) } - defer wc.Close() if _, err := content.Seek(0, io.SeekStart); err != nil { + cancelCreate(err) + _ = wc.Close() return errors.Errorf("seek to start: %w", err) } if _, err := io.Copy(wc, content); err != nil { + cancelCreate(err) + _ = wc.Close() return errors.Errorf("write to cache: %w", err) } diff --git a/internal/strategy/gomod/cacher_test.go b/internal/strategy/gomod/cacher_test.go new file mode 100644 index 00000000..5482adf3 --- /dev/null +++ b/internal/strategy/gomod/cacher_test.go @@ -0,0 +1,69 @@ +package gomod_test + +import ( + "context" + "io" + "log/slog" + "os" + "testing" + "time" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/logging" + "github.com/block/cachew/internal/strategy/gomod" +) + +func TestGoproxyCacherPutAbortsOnReadError(t *testing.T) { + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + cacher := gomod.NewTestableGoproxyCacher(memCache) + + // Create a reader that fails after some bytes + content := &failingReadSeeker{data: []byte("partial module data"), failAfter: 7} + + err = cacher.Put(ctx, "example.com/mod/@v/v1.0.0.zip", content) + assert.Error(t, err) + + // The partial data must not be cached + key := cache.NewKey("example.com/mod/@v/v1.0.0.zip") + _, _, err = memCache.Open(ctx, key) + assert.IsError(t, err, os.ErrNotExist) +} + +type failingReadSeeker struct { + data []byte + failAfter int + pos int +} + +func (r *failingReadSeeker) Read(p []byte) (int, error) { + if r.pos >= r.failAfter { + return 0, io.ErrUnexpectedEOF + } + n := copy(p, r.data[r.pos:]) + if r.pos+n > r.failAfter { + n = r.failAfter - r.pos + } + r.pos += n + if r.pos >= r.failAfter { + return n, io.ErrUnexpectedEOF + } + return n, nil +} + +func (r *failingReadSeeker) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + r.pos = int(offset) + case io.SeekCurrent: + r.pos += int(offset) + case io.SeekEnd: + r.pos = len(r.data) + int(offset) + } + return int64(r.pos), nil +} diff --git a/internal/strategy/gomod/export_test.go b/internal/strategy/gomod/export_test.go new file mode 100644 index 00000000..15a35845 --- /dev/null +++ b/internal/strategy/gomod/export_test.go @@ -0,0 +1,23 @@ +package gomod + +import ( + "context" + "io" + + "github.com/block/cachew/internal/cache" +) + +// TestableGoproxyCacher exposes goproxyCacher.Put for external tests. +type TestableGoproxyCacher struct { + inner goproxyCacher +} + +// NewTestableGoproxyCacher creates a testable goproxy cacher for the given cache. +func NewTestableGoproxyCacher(c cache.Cache) *TestableGoproxyCacher { + return &TestableGoproxyCacher{inner: goproxyCacher{cache: c}} +} + +// Put delegates to goproxyCacher.Put. +func (t *TestableGoproxyCacher) Put(ctx context.Context, name string, content io.ReadSeeker) error { + return t.inner.Put(ctx, name, content) +} diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go index f99201bc..733274c5 100644 --- a/internal/strategy/handler/handler.go +++ b/internal/strategy/handler/handler.go @@ -1,6 +1,7 @@ package handler import ( + "context" "io" "maps" "net/http" @@ -179,16 +180,22 @@ func (h *Handler) streamNonOKResponse(w http.ResponseWriter, resp *http.Response func (h *Handler) streamAndCache(w http.ResponseWriter, r *http.Request, key cache.Key, resp *http.Response) error { ttl := h.ttlFunc(r) responseHeaders := maps.Clone(resp.Header) - cw, err := h.cache.Create(r.Context(), key, responseHeaders, ttl) + createCtx, cancelCreate := context.WithCancelCause(r.Context()) + cw, err := h.cache.Create(createCtx, key, responseHeaders, ttl) if err != nil { + cancelCreate(nil) h.errorHandler(httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err), w, r) return nil } pr, pw := io.Pipe() go func() { + defer cancelCreate(nil) mw := io.MultiWriter(pw, cw) _, copyErr := io.Copy(mw, resp.Body) + if copyErr != nil { + cancelCreate(copyErr) + } closeErr := cw.Close() pw.CloseWithError(errors.Join(copyErr, closeErr)) }() diff --git a/internal/strategy/handler/handler_test.go b/internal/strategy/handler/handler_test.go index cc8f4622..4c430262 100644 --- a/internal/strategy/handler/handler_test.go +++ b/internal/strategy/handler/handler_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" "strings" "testing" "time" @@ -368,6 +369,38 @@ func TestHandlerMethodChaining(t *testing.T) { assert.Equal(t, h, result, "methods should return the same handler instance") } +func TestStreamAndCacheAbortsOnUpstreamError(t *testing.T) { + // Backend that sends partial data then abruptly closes. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("partial")) + if hj, ok := w.(http.Hijacker); ok { + conn, _, _ := hj.Hijack() + _ = conn.Close() + } + })) + defer backend.Close() + + c := mustNewMemoryCache() + ctx := logging.ContextWithLogger(context.Background(), slog.Default()) + + h := handler.New(http.DefaultClient, c). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, backend.URL+"/fail", nil) + }) + + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + // The partial response must not be cached. + key := cache.NewKey("http://example.com/test") + _, _, err := c.Open(ctx, key) + assert.IsError(t, err, os.ErrNotExist) +} + func mustNewMemoryCache() cache.Cache { _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) c, err := cache.NewMemory(ctx, cache.MemoryConfig{