diff --git a/sentry_sdk/integrations/stdlib.py b/sentry_sdk/integrations/stdlib.py index 5d8df43eb2..7573f8da7c 100644 --- a/sentry_sdk/integrations/stdlib.py +++ b/sentry_sdk/integrations/stdlib.py @@ -2,7 +2,7 @@ import subprocess import sys import platform -from http.client import HTTPConnection +from http.client import HTTPConnection, HTTPResponse import sentry_sdk from sentry_sdk.consts import OP, SPANDATA @@ -66,9 +66,22 @@ def add_python_runtime_context( return event +def _complete_span(span: "Union[Span, StreamedSpan]") -> None: + if isinstance(span, StreamedSpan): + with capture_internal_exceptions(): + add_http_request_source(span) + span.end() + else: + span.finish() + with capture_internal_exceptions(): + add_http_request_source(span) + + def _install_httplib() -> None: real_putrequest = HTTPConnection.putrequest real_getresponse = HTTPConnection.getresponse + real_read = HTTPResponse.read + real_close = HTTPResponse.close def putrequest( self: "HTTPConnection", method: str, url: str, *args: "Any", **kwargs: "Any" @@ -172,29 +185,57 @@ def getresponse(self: "HTTPConnection", *args: "Any", **kwargs: "Any") -> "Any": try: rv = real_getresponse(self, *args, **kwargs) + except BaseException: + _complete_span(span) + raise + + if isinstance(span, StreamedSpan): + status_code = int(rv.status) + span.status = "error" if status_code >= 400 else "ok" + span.set_attribute("http.response.status_code", status_code) + else: + span.set_http_status(int(rv.status)) + span.set_data("reason", rv.reason) + + # getresponse doesn't include actually reading the response body. This + # is done in read(). So if the metadata/headers suggest there's a body to + # read, don't finish the span just yet, but save it for ending it later. + has_body = rv.chunked or (rv.length is not None and rv.length > 0) + if has_body: + rv._sentrysdk_span = span # type: ignore[attr-defined] + else: + _complete_span(span) - if isinstance(span, StreamedSpan): - status_code = int(rv.status) - span.status = "error" if status_code >= 400 else "ok" - span.set_attribute("http.response.status_code", status_code) - else: - span.set_http_status(int(rv.status)) - span.set_data("reason", rv.reason) - finally: - if isinstance(span, StreamedSpan): - with capture_internal_exceptions(): - add_http_request_source(span) - span.end() - else: - span.finish() + return rv - with capture_internal_exceptions(): - add_http_request_source(span) + def read(self: "HTTPResponse", *args: "Any", **kwargs: "Any") -> "Any": + try: + return real_read(self, *args, **kwargs) + finally: + span = getattr(self, "_sentrysdk_span", None) + # read() might be called multiple times to consume a single body, + # so we can't just end the span when read() is done. Instead, + # try to figure out whether the response body has been fully read. + if span and (self.fp is None or self.closed): + self._sentrysdk_span = None # type: ignore[attr-defined] + _complete_span(span) + + def close(self: "HTTPResponse") -> None: + # We patch close() as a best effort fallback in case the span is not + # ended yet in getresponse() or read(). - return rv + try: + real_close(self) + finally: + span = getattr(self, "_sentrysdk_span", None) + if span is not None: + self._sentrysdk_span = None # type: ignore[attr-defined] + _complete_span(span) HTTPConnection.putrequest = putrequest # type: ignore[method-assign] HTTPConnection.getresponse = getresponse # type: ignore[method-assign] + HTTPResponse.read = read # type: ignore[method-assign] + HTTPResponse.close = close # type: ignore[assignment,method-assign] def _init_argument( diff --git a/tests/integrations/stdlib/test_httplib.py b/tests/integrations/stdlib/test_httplib.py index 33aa95825d..589a8e8e97 100644 --- a/tests/integrations/stdlib/test_httplib.py +++ b/tests/integrations/stdlib/test_httplib.py @@ -1,6 +1,7 @@ import os import socket import datetime +import time from http.client import HTTPConnection, HTTPSConnection from http.server import BaseHTTPRequestHandler, HTTPServer from socket import SocketIO @@ -44,6 +45,37 @@ def create_mock_proxy_server(): PROXY_PORT = create_mock_proxy_server() +CHUNK_DELAY = 0.1 +NUM_CHUNKS = 3 + + +class ChunkedResponseHandler(BaseHTTPRequestHandler): + def do_GET(self): + self.send_response(200) + self.send_header("Transfer-Encoding", "chunked") + self.end_headers() + for _ in range(NUM_CHUNKS): + chunk = b"x" * 100 + self.wfile.write(f"{len(chunk):x}\r\n".encode() + chunk + b"\r\n") + self.wfile.flush() + time.sleep(CHUNK_DELAY) + self.wfile.write(b"0\r\n\r\n") + + def log_message(self, *args): + pass + + +def create_chunked_server(): + port = get_free_port() + server = HTTPServer(("localhost", port), ChunkedResponseHandler) + thread = Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return port + + +CHUNKED_PORT = create_chunked_server() + def test_crumb_capture(sentry_init, capture_events): sentry_init(integrations=[StdlibIntegration()]) @@ -1161,3 +1193,50 @@ def test_proxy_http_tunnel( assert span["data"][SPANDATA.HTTP_METHOD] == "GET" assert span["data"][SPANDATA.NETWORK_PEER_ADDRESS] == "localhost" assert span["data"][SPANDATA.NETWORK_PEER_PORT] == PROXY_PORT + + +@pytest.mark.parametrize("span_streaming", [True, False]) +def test_chunked_response_span_covers_body_read( + sentry_init, + capture_events, + capture_items, + span_streaming, +): + sentry_init( + traces_sample_rate=1.0, + _experiments={"trace_lifecycle": "stream" if span_streaming else "static"}, + ) + + min_expected_duration = CHUNK_DELAY * NUM_CHUNKS + + if span_streaming: + items = capture_items("span") + + with sentry_sdk.traces.start_span(name="custom parent"): + conn = HTTPConnection("localhost", CHUNKED_PORT) + conn.request("GET", "/chunked") + response = conn.getresponse() + response.read() + + sentry_sdk.flush() + http_span, parent_span = [item.payload for item in items] + + duration = http_span["end_timestamp"] - http_span["start_timestamp"] + assert duration >= min_expected_duration + else: + events = capture_events() + + with start_transaction(name="test_chunked"): + conn = HTTPConnection("localhost", CHUNKED_PORT) + conn.request("GET", "/chunked") + response = conn.getresponse() + response.read() + + (event,) = events + (span,) = event["spans"] + + fmt = "%Y-%m-%dT%H:%M:%S.%fZ" + start = datetime.datetime.strptime(span["start_timestamp"], fmt) + end = datetime.datetime.strptime(span["timestamp"], fmt) + duration = (end - start).total_seconds() + assert duration >= min_expected_duration