Skip to content
6 changes: 5 additions & 1 deletion client/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,15 @@ func (c *Client) Save(ctx context.Context, key Key, baseDir string, paths []stri
}
}

wc, err := c.Create(ctx, key, headers, cfg.ttl)
createCtx, cancelCreate := context.WithCancelCause(ctx)
defer cancelCreate(nil)

wc, err := c.Create(createCtx, key, headers, cfg.ttl)
if err != nil {
return errors.Wrap(err, "failed to create object")
}
if err := Archive(ctx, wc, baseDir, paths, cfg.exclude, cfg.zstdThreads); err != nil {
cancelCreate(err)
return errors.Join(err, wc.Close())
}
return errors.Wrap(wc.Close(), "failed to close writer")
Expand Down
18 changes: 18 additions & 0 deletions client/files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,24 @@ func TestHashFilesSkipsDirectories(t *testing.T) {
assert.Equal(t, h1, h2, "directories should be skipped, not cause errors")
}

// TestSaveAbortOnArchiveFailure checks that a Save whose archive step fails
// returns an error and leaves nothing retrievable under the key.
func TestSaveAbortOnArchiveFailure(t *testing.T) {
	fake := newFakeServer(nil)
	defer fake.Close()

	api := client.New(fake.URL, nil).Namespace("test")
	defer api.Close()

	missingKey := client.NewKey("should-not-exist")

	// The base directory does not exist, so archiving is expected to fail.
	saveErr := api.Save(t.Context(), missingKey, "/nonexistent/path", []string{"."})
	assert.Error(t, saveErr)

	// Opening the key afterwards must report that no object was stored.
	_, _, openErr := api.Open(t.Context(), missingKey)
	assert.IsError(t, openErr, os.ErrNotExist)
}

func TestHashKeySaveRestore(t *testing.T) {
srv := newFakeServer(nil)
defer srv.Close()
Expand Down
6 changes: 5 additions & 1 deletion cmd/cachew/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,16 @@ func (c *PutCmd) Run(ctx context.Context, api *client.Client) error {
headers.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filepath.Base(filename))) //nolint:perfsprint
}

wc, err := api.Namespace(c.Namespace).Create(ctx, c.Key.Key(), headers, c.TTL)
createCtx, cancelCreate := context.WithCancelCause(ctx)
defer cancelCreate(nil)

wc, err := api.Namespace(c.Namespace).Create(createCtx, c.Key.Key(), headers, c.TTL)
if err != nil {
return errors.Wrap(err, "failed to create object")
}

if _, err := io.Copy(wc, c.Input); err != nil {
cancelCreate(err)
return errors.Join(errors.Wrap(err, "failed to copy data"), wc.Close())
}

Expand Down
34 changes: 34 additions & 0 deletions cmd/cachew/save_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"io"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -75,3 +76,36 @@ func TestRestoreCacheMiss(t *testing.T) {
err := cmd.Run(context.Background(), api, &CLI{})
assert.True(t, errors.Is(err, errCacheMiss), "expected errCacheMiss sentinel, got %v", err)
}

// TestPutCmdCancelOnCopyFailure exercises the cancel-before-close sequence that
// PutCmd.Run performs when copying its input fails: the writer is created with
// a cancellable context, some data is written, the context is cancelled with a
// cause, and Close is then expected to report the cancellation.
func TestPutCmdCancelOnCopyFailure(t *testing.T) {
	// Minimal backend: drain and accept POST uploads, 404 everything else.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost {
			_, _ = io.Copy(io.Discard, r.Body)
			w.WriteHeader(http.StatusOK)
			return
		}
		// Pass the real request rather than nil; NotFound takes a *http.Request.
		http.NotFound(w, r)
	}))
	defer srv.Close()

	api := client.New(srv.URL, nil)
	defer api.Close()

	// This mirrors what PutCmd.Run does internally:
	// create with a cancellable context, write partial data, then cancel
	// before closing — simulating an io.Copy error.
	key := client.NewKey("put-cancel-test")
	createCtx, cancel := context.WithCancelCause(t.Context())
	defer cancel(nil)

	wc, err := api.Namespace("test").Create(createCtx, key, nil, 0)
	assert.NoError(t, err)

	// The write result is deliberately ignored: only the Close outcome after
	// cancellation is under test.
	_, _ = wc.Write([]byte("partial"))

	// Simulate the cancel-before-close pattern from PutCmd.Run.
	cancel(errors.New("copy failed"))
	err = wc.Close()
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "cancelled")
}
9 changes: 8 additions & 1 deletion internal/cache/http.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cache

import (
"context"
"io"
"maps"
"net/http"
Expand Down Expand Up @@ -53,17 +54,23 @@ func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http.
}

responseHeaders := maps.Clone(resp.Header)
cw, err := c.Create(r.Context(), key, responseHeaders, 0)
createCtx, cancelCreate := context.WithCancelCause(r.Context())
cw, err := c.Create(createCtx, key, responseHeaders, 0)
if err != nil {
cancelCreate(nil)
_ = resp.Body.Close()
return nil, httputil.Errorf(http.StatusInternalServerError, "failed to create cache entry: %w", err)
}

originalBody := resp.Body
pr, pw := io.Pipe()
go func() {
defer cancelCreate(nil)
mw := io.MultiWriter(pw, cw)
_, copyErr := io.Copy(mw, originalBody)
if copyErr != nil {
cancelCreate(copyErr)
}
closeErr := errors.Join(cw.Close(), originalBody.Close())
pw.CloseWithError(errors.Join(copyErr, closeErr))
}()
Expand Down
37 changes: 37 additions & 0 deletions internal/cache/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log/slog"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

Expand Down Expand Up @@ -83,3 +84,39 @@ func TestCachedFetchNonOKStatus(t *testing.T) {
assert.NoError(t, resp.Body.Close())
assert.Equal(t, "not found", string(body))
}

// TestFetchDirectAbortsOnPartialResponse verifies that when the upstream
// connection dies mid-body, FetchDirect does not leave the truncated response
// in the cache.
func TestFetchDirectAbortsOnPartialResponse(t *testing.T) {
	_, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError})
	memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
	assert.NoError(t, err)
	defer memCache.Close()

	// Backend that sends partial data then closes the connection.
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("partial"))
		// Hijack the connection to simulate a mid-stream failure. If hijacking
		// is unavailable the response would complete normally and the test
		// would be meaningless, so report that explicitly.
		hj, ok := w.(http.Hijacker)
		if !ok {
			t.Error("response writer does not implement http.Hijacker")
			return
		}
		conn, _, hijackErr := hj.Hijack()
		if hijackErr != nil {
			t.Errorf("hijack failed: %v", hijackErr)
			return
		}
		_ = conn.Close()
	}))
	defer backend.Close()

	client := &http.Client{}
	key := cache.NewKey(backend.URL + "/fail")

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, backend.URL+"/fail", nil)
	assert.NoError(t, err)

	resp, err := cache.FetchDirect(client, req, memCache, key)
	assert.NoError(t, err)
	// Read the body — the copy goroutine will encounter the broken connection.
	// The read error is expected and intentionally ignored.
	_, _ = io.ReadAll(resp.Body)
	_ = resp.Body.Close()

	// The partial response must not be cached.
	_, _, err = memCache.Open(ctx, key)
	assert.IsError(t, err, os.ErrNotExist)
}
31 changes: 31 additions & 0 deletions internal/cache/s3_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package cache_test

import (
"context"
"log/slog"
"os"
"testing"
"time"

"github.com/alecthomas/assert/v2"
"github.com/alecthomas/errors"

"github.com/block/cachew/internal/cache"
"github.com/block/cachew/internal/cache/cachetest"
Expand Down Expand Up @@ -40,6 +42,35 @@ func TestS3Cache(t *testing.T) {
cachetest.Suite(t, func(t *testing.T) cache.Cache { return newS3Cache(t, bucket) })
}

// TestS3ContextCancellationAbortsUpload checks that cancelling the Create
// context before Close makes Close fail and leaves no object under the key.
// snapshot.CreatePaths relies on this same cancel-then-close sequence to avoid
// persisting partial uploads.
func TestS3ContextCancellationAbortsUpload(t *testing.T) {
	bucket := s3clienttest.Start(t)
	store := newS3Cache(t, bucket)
	defer store.Close()

	abortedKey := cache.NewKey("aborted-upload")

	uploadCtx, abort := context.WithCancelCause(t.Context())

	writer, createErr := store.Create(uploadCtx, abortedKey, nil, time.Hour)
	assert.NoError(t, createErr)

	// Use a non-empty payload so the test is not just the 0-byte edge case.
	_, writeErr := writer.Write([]byte("partial data that should not be persisted"))
	assert.NoError(t, writeErr)

	// Cancel first, then close: this stands in for an archive failure.
	abort(errors.New("archive failed"))
	closeErr := writer.Close()
	assert.Error(t, closeErr)

	// Nothing may be readable under the key afterwards.
	_, _, openErr := store.Open(t.Context(), abortedKey)
	assert.IsError(t, openErr, os.ErrNotExist)
}

func TestS3CacheSoak(t *testing.T) {
if os.Getenv("SOAK_TEST") == "" {
t.Skip("Skipping soak test; set SOAK_TEST=1 to run")
Expand Down
8 changes: 7 additions & 1 deletion internal/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,18 @@ func CreatePaths(ctx context.Context, remote cache.Cache, key cache.Key, baseDir
}
}

wc, err := remote.Create(ctx, key, headers, ttl)
// Wrap the context so we can cancel the upload on archive failure,
// preventing partial data from being persisted.
createCtx, cancelCreate := context.WithCancelCause(ctx)
defer cancelCreate(nil)

wc, err := remote.Create(createCtx, key, headers, ttl)
if err != nil {
return errors.Wrap(err, "failed to create object")
}

if err := client.Archive(ctx, wc, baseDir, includePaths, excludePatterns, threads); err != nil {
cancelCreate(err)
return errors.Join(err, wc.Close())
}
return errors.Wrap(wc.Close(), "failed to close writer")
Expand Down
7 changes: 6 additions & 1 deletion internal/strategy/apiv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,19 @@ func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) {
// Extract and filter headers from request
headers := httputil.FilterHeaders(r.Header, httputil.TransportHeaders...)

createCtx, cancelCreate := context.WithCancelCause(r.Context())
defer cancelCreate(nil)

namespacedCache := d.cache.Namespace(namespace)
cw, err := namespacedCache.Create(r.Context(), key, headers, ttl)
cw, err := namespacedCache.Create(createCtx, key, headers, ttl)
if err != nil {
d.httpError(w, http.StatusInternalServerError, err, "Failed to create cache writer", "key", key)
return
}

if _, err := io.Copy(cw, r.Body); err != nil {
cancelCreate(err)
_ = cw.Close()
d.httpError(w, http.StatusInternalServerError, err, "Failed to copy request body to cache writer")
return
}
Expand Down
65 changes: 65 additions & 0 deletions internal/strategy/apiv1_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package strategy_test

import (
"context"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/alecthomas/assert/v2"

"github.com/block/cachew/internal/cache"
"github.com/block/cachew/internal/logging"
"github.com/block/cachew/internal/strategy"
)

// TestPutObjectAbortsOnReadError posts a body that fails partway through to the
// APIV1 object endpoint and checks that the handler answers 500 and that
// nothing is stored under the key.
func TestPutObjectAbortsOnReadError(t *testing.T) {
	_, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError})
	store, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
	assert.NoError(t, err)
	defer store.Close()

	router := http.NewServeMux()
	_, err = strategy.NewAPIV1(ctx, struct{}{}, store, router)
	assert.NoError(t, err)

	objectKey := cache.NewKey("abort-test")

	// Request body that yields a few bytes and then reports a read error.
	brokenBody := &failingReader{data: []byte("partial data"), failAfter: 5}
	target := "/api/v1/object/test/" + objectKey.String()
	req := httptest.NewRequest(http.MethodPost, target, brokenBody).WithContext(ctx)
	rec := httptest.NewRecorder()

	router.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusInternalServerError, rec.Code)

	// The truncated upload must not be retrievable from the namespace.
	_, _, err = store.Namespace("test").Open(ctx, objectKey)
	assert.IsError(t, err, os.ErrNotExist)
}

// failingReader is an io.Reader that serves a prefix of data and then fails
// with io.ErrUnexpectedEOF. It serves at most failAfter bytes, and never more
// than len(data).
type failingReader struct {
	data      []byte // bytes available to be served
	failAfter int    // number of bytes to serve before failing
	read      int    // number of bytes served so far
}

// Read copies up to len(p) of the remaining servable bytes into p. It returns
// io.ErrUnexpectedEOF together with the final bytes once the failure point is
// reached, and (0, io.ErrUnexpectedEOF) on every call after that.
func (r *failingReader) Read(p []byte) (int, error) {
	// Clamp the failure point to the available data. Without the clamp, a
	// failAfter larger than len(data) would make Read return (0, nil) forever
	// once data ran out, and callers such as io.Copy would never terminate.
	limit := min(r.failAfter, len(r.data))
	if r.read >= limit {
		return 0, io.ErrUnexpectedEOF
	}
	n := copy(p, r.data[r.read:limit])
	r.read += n
	if r.read >= limit {
		return n, io.ErrUnexpectedEOF
	}
	return n, nil
}
6 changes: 6 additions & 0 deletions internal/strategy/git/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package git
import (
"context"

"github.com/block/cachew/internal/cache"
"github.com/block/cachew/internal/gitclone"
)

Expand All @@ -17,3 +18,8 @@ func (s *Strategy) GenerateAndUploadSnapshot(ctx context.Context, repo *gitclone
func (s *Strategy) GenerateAndUploadMirrorSnapshot(ctx context.Context, repo *gitclone.Repository) error {
return s.generateAndUploadMirrorSnapshot(ctx, repo)
}

// CacheBundleSync exports cacheBundleSync for testing. It forwards ctx, key
// and data unchanged to the unexported method and returns its error as-is.
func (s *Strategy) CacheBundleSync(ctx context.Context, key cache.Key, data []byte) error {
return s.cacheBundleSync(ctx, key, data)
}
6 changes: 5 additions & 1 deletion internal/strategy/git/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,16 @@ func (s *Strategy) cacheBundleAsync(ctx context.Context, key cache.Key, data []b
}

func (s *Strategy) cacheBundleSync(ctx context.Context, key cache.Key, data []byte) error {
createCtx, cancelCreate := context.WithCancelCause(ctx)
defer cancelCreate(nil)

headers := http.Header{"Content-Type": {"application/x-git-bundle"}}
wc, err := s.cache.Create(ctx, key, headers, bundleCacheTTL)
wc, err := s.cache.Create(createCtx, key, headers, bundleCacheTTL)
if err != nil {
return errors.Wrap(err, "create cache entry")
}
if _, err := wc.Write(data); err != nil {
cancelCreate(err)
_ = wc.Close()
return errors.Wrap(err, "write bundle to cache")
}
Expand Down
Loading