diff --git a/context.go b/context.go index ec7fdd998..510f51855 100644 --- a/context.go +++ b/context.go @@ -467,13 +467,9 @@ func (c *Context) json(code int, i any, indent string) error { // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until // (global) error handler decides correct status code for the error to be sent to the client. // For that we need to use writer that can store the proposed status code until the first Write is called. - if r, err := UnwrapResponse(c.response); err == nil { - r.Status = code - } else { - resp := c.Response() - c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) - defer c.SetResponse(resp) - } + resp := c.Response() + c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) + defer c.SetResponse(resp) return c.echo.JSONSerializer.Serialize(c, i, indent) } diff --git a/echo_test.go b/echo_test.go index b5045e111..6847e56bd 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1233,6 +1233,89 @@ func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code) } +func TestRouterAutoHandleHEADFullHTTPHandlerFlow(t *testing.T) { + tests := []struct { + name string + givenAutoHandleHEAD bool + whenMethod string + expectBody string + expectCode int + expectContentLength string + }{ + { + name: "AutoHandleHEAD disabled - HEAD returns 405", + givenAutoHandleHEAD: false, + whenMethod: http.MethodHead, + expectCode: http.StatusMethodNotAllowed, + expectBody: "", + }, + { + name: "AutoHandleHEAD enabled - HEAD returns 200 with Content-Length", + givenAutoHandleHEAD: true, + whenMethod: http.MethodHead, + expectCode: http.StatusOK, + expectBody: "", + expectContentLength: "4", + }, + { + name: "GET request works normally with AutoHandleHEAD enabled", + givenAutoHandleHEAD: true, + whenMethod: http.MethodGet, + expectCode: http.StatusOK, + expectBody: "test", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + e := NewWithConfig(Config{ + Router: NewRouter(RouterConfig{ + AutoHandleHEAD: tc.givenAutoHandleHEAD, + }), + }) + + e.GET("/hello", func(c *Context) error { + return c.String(http.StatusOK, "test") + }) + + req := httptest.NewRequest(tc.whenMethod, "/hello", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectContentLength, rec.Header().Get(HeaderContentLength)) + assert.Equal(t, tc.expectBody, rec.Body.String()) + }) + } +} + +func TestAutoHeadExplicitHeadTakesPrecedence(t *testing.T) { + e := NewWithConfig(Config{ + Router: NewRouter(RouterConfig{ + AutoHandleHEAD: true, + }), + }) + + // Register explicit HEAD route FIRST with custom behavior + e.HEAD("/api/users", func(c *Context) error { + c.Response().Header().Set("X-Custom-Header", "explicit-head") + return c.NoContent(http.StatusTeapot) + }) + + e.GET("/api/users", func(c *Context) error { + return c.JSON(http.StatusNotFound, map[string]string{"name": "John"}) + }) + + req := httptest.NewRequest(http.MethodHead, "/api/users", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "explicit-head", rec.Header().Get("X-Custom-Header")) + assert.Equal(t, "", rec.Body.String()) +} + func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/response.go b/response.go index 4da729c47..c018af2cb 100644 --- a/response.go +++ b/response.go @@ -10,6 +10,7 @@ import ( "log/slog" "net" "net/http" + "strconv" ) // Response wraps an http.ResponseWriter and implements its interface to be used @@ -170,3 +171,89 @@ func (w *delayedStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (w *delayedStatusWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } + +// headResponseWriter captures the response that a GET handler would produce for a +// rewritten HEAD request, suppresses the body, and preserves response metadata. +// +// The writer buffers status until the downstream handler returns, so it +// can compute a Content-Length value from the number of body bytes that would have +// been written by the GET handler. If the handler already sets Content-Length +// explicitly, that value is preserved. +// +// Flush is intentionally a no-op because emitting headers early would prevent +// finalizing Content-Length after the handler completes. +type headResponseWriter struct { + rw http.ResponseWriter + status int + wroteStatus bool + bodyBytes int64 +} + +func (w *headResponseWriter) Header() http.Header { + return w.rw.Header() +} + +func (w *headResponseWriter) WriteHeader(code int) { + if w.wroteStatus { + return + } + w.wroteStatus = true + w.status = code +} + +func (w *headResponseWriter) Write(b []byte) (int, error) { + if !w.wroteStatus { + w.WriteHeader(http.StatusOK) + } + w.bodyBytes += int64(len(b)) + return len(b), nil +} + +func (w *headResponseWriter) Flush() { + // No-op on purpose. A HEAD response has no body, and flushing early would + // commit headers before Content-Length can be finalized. +} + +func (w *headResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(w.rw).Hijack() +} + +func (w *headResponseWriter) Unwrap() http.ResponseWriter { + return w.rw +} + +func (w *headResponseWriter) commit() { + dst := w.rw.Header() + if dst.Get(HeaderContentLength) == "" && + dst.Get("Transfer-Encoding") == "" && + !statusMustNotHaveBody(w.status) { + dst.Set(HeaderContentLength, strconv.FormatInt(w.bodyBytes, 10)) + } + + // "commit" the Response only when the headers were written otherwise the Echo errorhandler cannot properly handle errors + if w.wroteStatus { + w.rw.WriteHeader(w.status) + } +} + +func statusMustNotHaveBody(code int) bool { + return (code >= 100 && code < 200) || + code == http.StatusNoContent || + code == http.StatusNotModified +} + +func wrapHeadHandler(handler HandlerFunc) HandlerFunc { + return func(c *Context) error { + originalWriter := c.Response() + headWriter := &headResponseWriter{rw: originalWriter} + + c.SetResponse(headWriter) + defer func() { + c.SetResponse(originalWriter) + }() + + err := handler(c) + headWriter.commit() + return err + } +} diff --git a/router.go b/router.go index 48341cb1b..86e2dfd26 100644 --- a/router.go +++ b/router.go @@ -69,6 +69,7 @@ type DefaultRouter struct { allowOverwritingRoute bool unescapePathParamValues bool useEscapedPathForRouting bool + autoHandleHEAD bool } // RouterConfig is configuration options for (default) router @@ -79,6 +80,20 @@ type RouterConfig struct { AllowOverwritingRoute bool UnescapePathParamValues bool UseEscapedPathForMatching bool + + // AutoHandleHEAD enables automatic handling of HTTP HEAD requests by + // falling back to the corresponding GET route. + // + // When enabled, a HEAD request will match the same handler as GET for + // the route, but the response body is suppressed in accordance with + // HTTP semantics. Headers (e.g., Content-Length, Content-Type) are + // preserved as if a GET request was made. + // + // Note that the GET handler is still executed, so any side effects + // (such as database queries or logging) will occur. + // + // Disabled by default. + AutoHandleHEAD bool } // NewRouter returns a new Router instance. @@ -98,6 +113,7 @@ func NewRouter(config RouterConfig) *DefaultRouter { notFoundHandler: notFoundHandler, methodNotAllowedHandler: methodNotAllowedHandler, optionsMethodHandler: optionsMethodHandler, + autoHandleHEAD: config.AutoHandleHEAD, } if config.NotFoundHandler != nil { r.notFoundHandler = config.NotFoundHandler @@ -210,7 +226,7 @@ func (m *routeMethods) set(method string, r *routeMethod) { m.updateAllowHeader() } -func (m *routeMethods) find(method string, fallbackToAny bool) *routeMethod { +func (m *routeMethods) find(method string, fallbackToAny bool, autoHandleHEAD bool) *routeMethod { var r *routeMethod switch method { case http.MethodConnect: @@ -221,6 +237,9 @@ func (m *routeMethods) find(method string, fallbackToAny bool) *routeMethod { r = m.get case http.MethodHead: r = m.head + if autoHandleHEAD && r == nil { + r = m.get + } case http.MethodOptions: r = m.options case http.MethodPatch: @@ -374,7 +393,7 @@ func (r *DefaultRouter) Remove(method string, path string) error { return errors.New("could not find route to remove by given path") } - if mh := nodeToRemove.methods.find(method, false); mh == nil { + if mh := nodeToRemove.methods.find(method, false, false); mh == nil { return errors.New("could not find route to remove by given path and method") } nodeToRemove.setHandler(method, nil) @@ -904,7 +923,7 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if h := currentNode.methods.find(req.Method, true); h != nil { + if h := currentNode.methods.find(req.Method, true, r.autoHandleHEAD); h != nil { matchedRouteMethod = h break } @@ -955,7 +974,7 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { searchIndex += len(search) search = "" - if rMethod := currentNode.methods.find(req.Method, true); rMethod != nil { + if rMethod := currentNode.methods.find(req.Method, true, r.autoHandleHEAD); rMethod != nil { matchedRouteMethod = rMethod break } @@ -995,6 +1014,11 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc { var rInfo *RouteInfo if matchedRouteMethod != nil { rHandler = matchedRouteMethod.handler + if r.autoHandleHEAD && req.Method == http.MethodHead { + rHandler = wrapHeadHandler(rHandler) + // we are not touching rInfo.Method and let it be value from GET routeInfo + } + rPath = matchedRouteMethod.Path rInfo = matchedRouteMethod.RouteInfo } else {