diff --git a/tests/integrations/stdlib/test_httplib.py b/tests/integrations/stdlib/test_httplib.py index cdbf6cd68c..48fa85ec9e 100644 --- a/tests/integrations/stdlib/test_httplib.py +++ b/tests/integrations/stdlib/test_httplib.py @@ -1,4 +1,5 @@ import os +import socket import datetime from http.client import HTTPConnection, HTTPSConnection from http.server import BaseHTTPRequestHandler, HTTPServer @@ -202,15 +203,31 @@ def test_httplib_misuse(sentry_init, capture_events, request): ) -def test_outgoing_trace_headers(sentry_init, monkeypatch): - # HTTPSConnection.send is passed a string containing (among other things) - # the headers on the request. Mock it so we can check the headers, and also - # so it doesn't try to actually talk to the internet. - mock_send = mock.Mock() - monkeypatch.setattr(HTTPSConnection, "send", mock_send) - +def test_outgoing_trace_headers(sentry_init, capture_events): sentry_init(traces_sample_rate=1.0) + already_patched_getresponse = HTTPSConnection.getresponse + + request_headers = {} + + class HTTPSConnectionRecordingRequestHeaders(HTTPSConnection): + def send(self, *args, **kwargs) -> None: + request_str = args[0] + for line in request_str.decode("utf-8").split("\r\n")[1:]: + if line: + key, val = line.split(": ") + request_headers[key] = val + + server_sock, client_sock = socket.socketpair() + server_sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n") + server_sock.close() + self.sock = client_sock + + def getresponse(self, *args, **kwargs): + return already_patched_getresponse(self, *args, **kwargs) + + events = capture_events() + headers = { "sentry-trace": "771a43a4192642f0b136d5159a501700-1234567890abcdef-1", "baggage": ( @@ -228,73 +245,83 @@ def test_outgoing_trace_headers(sentry_init, monkeypatch): op="greeting.sniff", trace_id="12312012123120121231201212312012", ) as transaction: - HTTPSConnection("www.squirrelchasers.com").request("GET", "/top-chasers") - - (request_str,) = mock_send.call_args[0] - request_headers = {} - for line in request_str.decode("utf-8").split("\r\n")[1:]: - if line: - key, val = line.split(": ") - request_headers[key] = val - - request_span = transaction._span_recorder.spans[-1] - expected_sentry_trace = "{trace_id}-{parent_span_id}-{sampled}".format( - trace_id=transaction.trace_id, - parent_span_id=request_span.span_id, - sampled=1, - ) - assert request_headers["sentry-trace"] == expected_sentry_trace - - expected_outgoing_baggage = ( - "sentry-trace_id=771a43a4192642f0b136d5159a501700," - "sentry-public_key=49d0f7386ad645858ae85020e393bef3," - "sentry-sample_rate=1.0," - "sentry-user_id=Am%C3%A9lie," - "sentry-sample_rand=0.132521102938283" - ) + connection = HTTPSConnectionRecordingRequestHeaders("localhost", port=PORT) + connection.request("GET", "/top-chasers") + connection.getresponse() - assert request_headers["baggage"] == expected_outgoing_baggage + (event,) = events + request_span = event["spans"][-1] + expected_sentry_trace = "{trace_id}-{parent_span_id}-{sampled}".format( + trace_id=event["contexts"]["trace"]["trace_id"], + parent_span_id=request_span["span_id"], + sampled=1, + ) + assert request_headers["sentry-trace"] == expected_sentry_trace + + expected_outgoing_baggage = ( + "sentry-trace_id=771a43a4192642f0b136d5159a501700," + "sentry-public_key=49d0f7386ad645858ae85020e393bef3," + "sentry-sample_rate=1.0," + "sentry-user_id=Am%C3%A9lie," + "sentry-sample_rand=0.132521102938283" + ) + assert request_headers["baggage"] == expected_outgoing_baggage -def test_outgoing_trace_headers_head_sdk(sentry_init, monkeypatch): - # HTTPSConnection.send is passed a string containing (among other things) - # the headers on the request. Mock it so we can check the headers, and also - # so it doesn't try to actually talk to the internet. - mock_send = mock.Mock() - monkeypatch.setattr(HTTPSConnection, "send", mock_send) +def test_outgoing_trace_headers_head_sdk(sentry_init, capture_events): sentry_init(traces_sample_rate=0.5, release="foo") + + already_patched_getresponse = HTTPSConnection.getresponse + + request_headers = {} + + class HTTPSConnectionRecordingRequestHeaders(HTTPSConnection): + def send(self, *args, **kwargs) -> None: + request_str = args[0] + for line in request_str.decode("utf-8").split("\r\n")[1:]: + if line: + key, val = line.split(": ") + request_headers[key] = val + + server_sock, client_sock = socket.socketpair() + server_sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n") + server_sock.close() + self.sock = client_sock + + def getresponse(self, *args, **kwargs): + return already_patched_getresponse(self, *args, **kwargs) + + events = capture_events() + with mock.patch("sentry_sdk.tracing_utils.Random.randrange", return_value=250000): transaction = continue_trace({}) with start_transaction(transaction=transaction, name="Head SDK tx") as transaction: - HTTPSConnection("www.squirrelchasers.com").request("GET", "/top-chasers") - - (request_str,) = mock_send.call_args[0] - request_headers = {} - for line in request_str.decode("utf-8").split("\r\n")[1:]: - if line: - key, val = line.split(": ") - request_headers[key] = val - - request_span = transaction._span_recorder.spans[-1] - expected_sentry_trace = "{trace_id}-{parent_span_id}-{sampled}".format( - trace_id=transaction.trace_id, - parent_span_id=request_span.span_id, - sampled=1, - ) - assert request_headers["sentry-trace"] == expected_sentry_trace + connection = HTTPSConnectionRecordingRequestHeaders("localhost", port=PORT) + connection.request("GET", "/top-chasers") + connection.getresponse() + + (event,) = events + request_span = event["spans"][-1] + expected_sentry_trace = "{trace_id}-{parent_span_id}-{sampled}".format( + trace_id=event["contexts"]["trace"]["trace_id"], + parent_span_id=request_span["span_id"], + sampled=1, + ) + + assert request_headers["sentry-trace"] == expected_sentry_trace - expected_outgoing_baggage = ( - "sentry-trace_id=%s," - "sentry-sample_rand=0.250000," - "sentry-environment=production," - "sentry-release=foo," - "sentry-sample_rate=0.5," - "sentry-sampled=%s" - ) % (transaction.trace_id, "true" if transaction.sampled else "false") + expected_outgoing_baggage = ( + "sentry-trace_id=%s," + "sentry-sample_rand=0.250000," + "sentry-environment=production," + "sentry-release=foo," + "sentry-sample_rate=0.5," + "sentry-sampled=%s" + ) % (transaction.trace_id, "true" if transaction.sampled else "false") - assert request_headers["baggage"] == expected_outgoing_baggage + assert request_headers["baggage"] == expected_outgoing_baggage @pytest.mark.parametrize( @@ -357,19 +384,33 @@ def test_outgoing_trace_headers_head_sdk(sentry_init, monkeypatch): ], ) def test_option_trace_propagation_targets( - sentry_init, monkeypatch, trace_propagation_targets, host, path, trace_propagated + sentry_init, trace_propagation_targets, host, path, trace_propagated ): - # HTTPSConnection.send is passed a string containing (among other things) - # the headers on the request. Mock it so we can check the headers, and also - # so it doesn't try to actually talk to the internet. - mock_send = mock.Mock() - monkeypatch.setattr(HTTPSConnection, "send", mock_send) - sentry_init( trace_propagation_targets=trace_propagation_targets, traces_sample_rate=1.0, ) + already_patched_getresponse = HTTPSConnection.getresponse + + request_headers = {} + + class HTTPSConnectionRecordingRequestHeaders(HTTPSConnection): + def send(self, *args, **kwargs) -> None: + request_str = args[0] + for line in request_str.decode("utf-8").split("\r\n")[1:]: + if line: + key, val = line.split(": ") + request_headers[key] = val + + server_sock, client_sock = socket.socketpair() + server_sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n") + server_sock.close() + self.sock = client_sock + + def getresponse(self, *args, **kwargs): + return already_patched_getresponse(self, *args, **kwargs) + headers = { "baggage": ( "sentry-trace_id=771a43a4192642f0b136d5159a501700, " @@ -385,21 +426,16 @@ def test_option_trace_propagation_targets( op="greeting.sniff", trace_id="12312012123120121231201212312012", ) as transaction: - HTTPSConnection(host).request("GET", path) - - (request_str,) = mock_send.call_args[0] - request_headers = {} - for line in request_str.decode("utf-8").split("\r\n")[1:]: - if line: - key, val = line.split(": ") - request_headers[key] = val - - if trace_propagated: - assert "sentry-trace" in request_headers - assert "baggage" in request_headers - else: - assert "sentry-trace" not in request_headers - assert "baggage" not in request_headers + connection = HTTPSConnectionRecordingRequestHeaders(host) + connection.request("GET", path) + connection.getresponse() + + if trace_propagated: + assert "sentry-trace" in request_headers + assert "baggage" in request_headers + else: + assert "sentry-trace" not in request_headers + assert "baggage" not in request_headers def test_request_source_disabled(sentry_init, capture_events):