diff --git a/e2e-cli/e2e-config.json b/e2e-cli/e2e-config.json index b0ccf30..e1a02d5 100644 --- a/e2e-cli/e2e-config.json +++ b/e2e-cli/e2e-config.json @@ -1,6 +1,6 @@ { "sdk": "python", - "test_suites": "basic", + "test_suites": "basic,retry", "auto_settings": false, "patch": null, "env": {} diff --git a/segment/analytics/client.py b/segment/analytics/client.py index 0f8015c..5cb0973 100644 --- a/segment/analytics/client.py +++ b/segment/analytics/client.py @@ -30,7 +30,9 @@ class DefaultConfig(object): max_queue_size = 10000 gzip = False timeout = 15 - max_retries = 10 + max_retries = 1000 + max_total_backoff_duration = 43200 + max_rate_limit_duration = 43200 proxies = None thread = 1 upload_interval = 0.5 @@ -65,9 +67,16 @@ def __init__(self, oauth_client_key=DefaultConfig.oauth_client_key, oauth_key_id=DefaultConfig.oauth_key_id, oauth_auth_server=DefaultConfig.oauth_auth_server, - oauth_scope=DefaultConfig.oauth_scope,): + oauth_scope=DefaultConfig.oauth_scope, + max_total_backoff_duration=DefaultConfig.max_total_backoff_duration, + max_rate_limit_duration=DefaultConfig.max_rate_limit_duration,): require('write_key', write_key, str) + if max_total_backoff_duration is not None and max_total_backoff_duration < 0: + raise ValueError('max_total_backoff_duration must be non-negative') + if max_rate_limit_duration is not None and max_rate_limit_duration < 0: + raise ValueError('max_rate_limit_duration must be non-negative') + self.queue = queue.Queue(max_queue_size) self.write_key = write_key self.on_error = on_error @@ -78,6 +87,8 @@ def __init__(self, self.gzip = gzip self.timeout = timeout self.proxies = proxies + self.max_total_backoff_duration = max_total_backoff_duration + self.max_rate_limit_duration = max_rate_limit_duration self.oauth_manager = None if(oauth_client_id and oauth_client_key and oauth_key_id): self.oauth_manager = OauthManager(oauth_client_id, oauth_client_key, oauth_key_id, @@ -110,6 +121,8 @@ def __init__(self, upload_size=upload_size, 
upload_interval=upload_interval, gzip=gzip, retries=max_retries, timeout=timeout, proxies=proxies, oauth_manager=self.oauth_manager, + max_total_backoff_duration=max_total_backoff_duration, + max_rate_limit_duration=max_rate_limit_duration, ) self.consumers.append(consumer) diff --git a/segment/analytics/consumer.py b/segment/analytics/consumer.py index 157e3c9..b2143f8 100644 --- a/segment/analytics/consumer.py +++ b/segment/analytics/consumer.py @@ -1,10 +1,10 @@ import logging import time +import random from threading import Thread -import backoff import json -from segment.analytics.request import post, APIError, DatetimeSerializer +from segment.analytics.request import post, APIError, DatetimeSerializer, parse_retry_after from queue import Empty @@ -14,6 +14,10 @@ # lower to leave space for extra data that will be added later, eg. "sentAt". BATCH_SIZE_LIMIT = 475000 +# Default duration limits (12 hours in seconds) +DEFAULT_MAX_TOTAL_BACKOFF_DURATION = 43200 +DEFAULT_MAX_RATE_LIMIT_DURATION = 43200 + class FatalError(Exception): def __init__(self, message): @@ -29,8 +33,10 @@ class Consumer(Thread): log = logging.getLogger('segment') def __init__(self, queue, write_key, upload_size=100, host=None, - on_error=None, upload_interval=0.5, gzip=False, retries=10, - timeout=15, proxies=None, oauth_manager=None): + on_error=None, upload_interval=0.5, gzip=False, retries=1000, + timeout=15, proxies=None, oauth_manager=None, + max_total_backoff_duration=DEFAULT_MAX_TOTAL_BACKOFF_DURATION, + max_rate_limit_duration=DEFAULT_MAX_RATE_LIMIT_DURATION): """Create a consumer thread.""" Thread.__init__(self) # Make consumer a daemon thread so that it doesn't block program exit @@ -51,6 +57,12 @@ def __init__(self, queue, write_key, upload_size=100, host=None, self.timeout = timeout self.proxies = proxies self.oauth_manager = oauth_manager + self.max_total_backoff_duration = max_total_backoff_duration + self.max_rate_limit_duration = max_rate_limit_duration + + # Rate-limit state 
+ self.rate_limited_until = None + self.rate_limit_start_time = None def run(self): """Runs the consumer.""" @@ -64,6 +76,19 @@ def pause(self): """Pause the consumer.""" self.running = False + def set_rate_limit_state(self, response): + """Set rate-limit state from a 429 response with a valid Retry-After header.""" + retry_after = parse_retry_after(response) if response else None + if retry_after is not None: + self.rate_limited_until = time.time() + retry_after + if self.rate_limit_start_time is None: + self.rate_limit_start_time = time.time() + + def clear_rate_limit_state(self): + """Clear rate-limit state after successful request or duration exceeded.""" + self.rate_limited_until = None + self.rate_limit_start_time = None + def upload(self): """Upload the next batch of items, return whether successful.""" success = False @@ -71,9 +96,57 @@ def upload(self): if len(batch) == 0: return False + # Check rate-limit state before attempting upload + if self.rate_limited_until is not None: + now = time.time() + + # Check if maxRateLimitDuration has been exceeded + if (self.rate_limit_start_time is not None and + now - self.rate_limit_start_time > self.max_rate_limit_duration): + self.log.error( + 'Rate limit duration exceeded (%ds). Clearing rate-limit state and dropping batch.', + self.max_rate_limit_duration + ) + self.clear_rate_limit_state() + # Drop the batch by marking items as done + if self.on_error: + self.on_error( + Exception('Rate limit duration exceeded, batch dropped'), + batch + ) + for _ in batch: + self.queue.task_done() + return False + + # Still rate-limited; wait until the rate limit expires + wait_time = self.rate_limited_until - now + if wait_time > 0: + self.log.debug( + 'Rate-limited. 
Waiting %.2fs before next upload attempt.', + wait_time + ) + time.sleep(wait_time) + try: self.request(batch) + # Success — clear rate-limit state + self.clear_rate_limit_state() success = True + except APIError as e: + if e.status == 429 and self.rate_limited_until is not None: + # 429: rate-limit state already set by request(). Re-queue batch. + self.log.debug('429 received. Re-queuing batch and halting upload iteration.') + for item in batch: + try: + self.queue.put(item, block=False) + except Exception: + pass # Queue full, item lost + success = False + else: + self.log.error('error uploading: %s', e) + success = False + if self.on_error: + self.on_error(e, batch) except Exception as e: self.log.error('error uploading: %s', e) success = False @@ -120,40 +193,133 @@ def next(self): return items def request(self, batch): - """Attempt to upload the batch and retry before raising an error """ - - def fatal_exception(exc): - if isinstance(exc, APIError): - # retry on server errors and client errors - # with 429 status code (rate limited), - # don't retry on other client errors - return (400 <= exc.status < 500) and exc.status != 429 - elif isinstance(exc, FatalError): + """Attempt to upload the batch and retry before raising an error""" + + def is_retryable_status(status): + """ + Determine if a status code is retryable. + Retryable 4xx: 408, 410, 429, 460 + Non-retryable 4xx: 400, 401, 403, 404, 413, 422, and all other 4xx + Retryable 5xx: All except 501, 505 + - 511 is only retryable when OauthManager is configured + Non-retryable 5xx: 501, 505 + """ + if 400 <= status < 500: + return status in (408, 410, 429, 460) + elif 500 <= status < 600: + if status in (501, 505): + return False + if status == 511: + return self.oauth_manager is not None return True - else: - # retry on all other errors (eg. network) - return False + return False + + def calculate_backoff_delay(attempt): + """ + Calculate exponential backoff delay with jitter. 
+ First retry is immediate, then 0.5s, 1s, 2s, 4s, etc. + """ + if attempt == 1: + return 0 # First retry is immediate + base_delay = 0.5 * (2 ** (attempt - 2)) + jitter = random.uniform(0, 0.1 * base_delay) + return min(base_delay + jitter, 60) # Cap at 60 seconds + + total_attempts = 0 + backoff_attempts = 0 + first_failure_time = None + + while True: + total_attempts += 1 - attempt_count = 0 - - @backoff.on_exception( - backoff.expo, - Exception, - max_tries=self.retries + 1, - giveup=fatal_exception, - on_backoff=lambda details: self.log.debug( - f"Retry attempt {details['tries']}/{self.retries + 1} after {details['elapsed']:.2f}s" - )) - def send_request(): - nonlocal attempt_count - attempt_count += 1 try: - return post(self.write_key, self.host, gzip=self.gzip, - timeout=self.timeout, batch=batch, proxies=self.proxies, - oauth_manager=self.oauth_manager) - except Exception as e: - if attempt_count >= self.retries + 1: - self.log.error(f"All {self.retries} retries exhausted. Final error: {e}") + # Make the request with current retry count + response = post( + self.write_key, + self.host, + gzip=self.gzip, + timeout=self.timeout, + batch=batch, + proxies=self.proxies, + oauth_manager=self.oauth_manager, + retry_count=total_attempts - 1 + ) + # Success + return response + + except FatalError as e: + # Non-retryable error + self.log.error(f"Fatal error after {total_attempts} attempts: {e}") raise - send_request() + except APIError as e: + # 429 with valid Retry-After: set rate-limit state and raise + # to caller (pipeline blocking). Without Retry-After, fall + # through to counted backoff like any other retryable error. 
+ if e.status == 429: + retry_after = parse_retry_after(e.response) if e.response else None + if retry_after is not None: + self.set_rate_limit_state(e.response) + raise + + # Check if status is retryable + if not is_retryable_status(e.status): + self.log.error( + f"Non-retryable error {e.status} after {total_attempts} attempts: {e}" + ) + raise + + # Transient error -- per-batch backoff + if first_failure_time is None: + first_failure_time = time.time() + if time.time() - first_failure_time > self.max_total_backoff_duration: + self.log.error( + f"Max total backoff duration ({self.max_total_backoff_duration}s) exceeded " + f"after {total_attempts} attempts. Final error: {e}" + ) + raise + + # Count this against backoff attempts + backoff_attempts += 1 + if backoff_attempts >= self.retries + 1: + self.log.error( + f"All {self.retries} retries exhausted after {total_attempts} total attempts. Final error: {e}" + ) + raise + + # Calculate exponential backoff delay with jitter + delay = calculate_backoff_delay(backoff_attempts) + + self.log.debug( + f"Retry attempt {backoff_attempts}/{self.retries} (total attempts: {total_attempts}) " + f"after {delay:.2f}s for status {e.status}" + ) + time.sleep(delay) + + except Exception as e: + # Network errors or other exceptions - retry with backoff + if first_failure_time is None: + first_failure_time = time.time() + if time.time() - first_failure_time > self.max_total_backoff_duration: + self.log.error( + f"Max total backoff duration ({self.max_total_backoff_duration}s) exceeded " + f"after {total_attempts} attempts. Final error: {e}" + ) + raise + + backoff_attempts += 1 + + if backoff_attempts >= self.retries + 1: + self.log.error( + f"All {self.retries} retries exhausted after {total_attempts} total attempts. 
Final error: {e}" + ) + raise + + # Calculate exponential backoff delay with jitter + delay = calculate_backoff_delay(backoff_attempts) + + self.log.debug( + f"Network error retry {backoff_attempts}/{self.retries} (total attempts: {total_attempts}) " + f"after {delay:.2f}s: {e}" + ) + time.sleep(delay) diff --git a/segment/analytics/request.py b/segment/analytics/request.py index ab92b80..600fda3 100644 --- a/segment/analytics/request.py +++ b/segment/analytics/request.py @@ -3,8 +3,9 @@ from gzip import GzipFile import logging import json +import base64 from dateutil.tz import tzutc -from requests.auth import HTTPBasicAuth + from requests import sessions from segment.analytics.version import VERSION @@ -12,8 +13,32 @@ _session = sessions.Session() +# Maximum Retry-After delay to respect (5 minutes) +MAX_RETRY_AFTER_SECONDS = 300 + + +def parse_retry_after(response): + """ + Parse Retry-After header from response. + Returns the delay in seconds, or None if header is not present or invalid. + Caps the value at MAX_RETRY_AFTER_SECONDS. 
+ """ + retry_after = response.headers.get('Retry-After') + if not retry_after: + return None -def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manager=None, **kwargs): + try: + # Try parsing as integer (delay in seconds) + delay = int(retry_after) + # Ensure delay is non-negative before applying upper bound + return min(max(delay, 0), MAX_RETRY_AFTER_SECONDS) + except ValueError: + # Could be HTTP-date format, but for simplicity we'll skip that + # Most APIs use integer seconds + return None + + +def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manager=None, retry_count=0, **kwargs): """Post the `kwargs` to the API""" log = logging.getLogger('segment') body = kwargs @@ -28,10 +53,18 @@ def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manag log.debug('making request: %s', data) headers = { 'Content-Type': 'application/json', - 'User-Agent': 'analytics-python/' + VERSION + 'User-Agent': 'analytics-python/' + VERSION, + 'X-Retry-Count': str(retry_count) } + + # Add Authorization header - prefer OAuth Bearer token, fallback to Basic auth if auth: headers['Authorization'] = 'Bearer {}'.format(auth) + else: + # Basic auth with write key (format: "writeKey:" encoded in base64) + credentials = '{}:'.format(write_key) + encoded = base64.b64encode(credentials.encode('utf-8')).decode('utf-8') + headers['Authorization'] = 'Basic {}'.format(encoded) if gzip: headers['Content-Encoding'] = 'gzip' @@ -60,24 +93,25 @@ def post(write_key, host=None, gzip=False, timeout=15, proxies=None, oauth_manag log.debug('data uploaded successfully') return res - if oauth_manager and res.status_code in [400, 401, 403]: + if oauth_manager and res.status_code in [400, 401, 403, 511]: oauth_manager.clear_token() try: payload = res.json() log.debug('received response: %s', payload) - raise APIError(res.status_code, payload['code'], payload['message']) - except ValueError: + raise APIError(res.status_code, payload['code'], 
payload['message'], res) + except (ValueError, KeyError): log.error('Unknown error: [%s] %s', res.status_code, res.reason) - raise APIError(res.status_code, 'unknown', res.text) + raise APIError(res.status_code, 'unknown', res.text, res) class APIError(Exception): - def __init__(self, status, code, message): + def __init__(self, status, code, message, response=None): self.message = message self.status = status self.code = code + self.response = response def __str__(self): msg = "[Segment] {0}: {1} ({2})" diff --git a/segment/analytics/test/test_consumer.py b/segment/analytics/test/test_consumer.py index 8371726..a197843 100644 --- a/segment/analytics/test/test_consumer.py +++ b/segment/analytics/test/test_consumer.py @@ -8,7 +8,7 @@ except ImportError: from Queue import Queue -from segment.analytics.consumer import Consumer, MAX_MSG_SIZE +from segment.analytics.consumer import Consumer, MAX_MSG_SIZE, FatalError from segment.analytics.request import APIError @@ -144,7 +144,7 @@ def test_request_retry(self): self._test_request_retry(consumer, APIError( 500, 'code', 'Internal Server Error'), 2) - # we should retry on HTTP 429 errors + # 429 without Retry-After uses counted backoff (like other retryable errors) consumer = Consumer(None, 'testsecret') self._test_request_retry(consumer, APIError( 429, 'code', 'Too Many Requests'), 2) @@ -155,7 +155,7 @@ def test_request_retry(self): try: self._test_request_retry(consumer, api_error, 1) except APIError: - pass + pass # Expected: 400 is non-retryable, so the error propagates here else: self.fail('request() should not retry on client errors') @@ -220,3 +220,663 @@ def mock_post_fn(*args, **kwargs): args, kwargs = mock_post.call_args cls().assertIn('proxies', kwargs) cls().assertEqual(kwargs['proxies'], proxies) + + def test_retry_count_header_increments(self): + """Test that X-Retry-Count header increments on each retry""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python 
event', 'userId': 'userId'} + + retry_counts = [] + + def mock_post_fn(*args, **kwargs): + retry_counts.append(kwargs.get('retry_count', 0)) + if len(retry_counts) < 3: + raise APIError(500, 'error', 'Server Error') + # Success on third attempt + return mock.Mock(status_code=200) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + consumer.request([track]) + + # Should have been called 3 times with retry counts 0, 1, 2 + self.assertEqual(retry_counts, [0, 1, 2]) + + def test_non_retryable_4xx_status_codes(self): + """Test that non-retryable 4xx errors are not retried""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + non_retryable_codes = [400, 401, 403, 404, 413, 422] + + for status_code in non_retryable_codes: + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise APIError(status_code, 'error', f'Client Error {status_code}') + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + try: + consumer.request([track]) + except APIError as e: + self.assertEqual(e.status, status_code) + + # Should only be called once (no retries) + self.assertEqual(call_count, 1, f'Status {status_code} should not be retried') + + def test_retryable_4xx_status_codes(self): + """Test that retryable 4xx errors are retried (429 without Retry-After uses backoff too)""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + retryable_codes = [408, 410, 429, 460] + + for status_code in retryable_codes: + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise APIError(status_code, 'error', f'Retryable Error {status_code}') + # Success on third attempt + return mock.Mock(status_code=200) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): 
+ with mock.patch('time.sleep'): # Mock sleep to speed up test + consumer.request([track]) + + # Should have been called 3 times + self.assertEqual(call_count, 3, f'Status {status_code} should be retried') + + def test_non_retryable_5xx_status_codes(self): + """Test that non-retryable 5xx errors are not retried""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + non_retryable_codes = [501, 505] + + for status_code in non_retryable_codes: + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise APIError(status_code, 'error', f'Server Error {status_code}') + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + try: + consumer.request([track]) + except APIError as e: + self.assertEqual(e.status, status_code) + + # Should only be called once (no retries) + self.assertEqual(call_count, 1, f'Status {status_code} should not be retried') + + def test_retryable_5xx_status_codes(self): + """Test that retryable 5xx errors are retried""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + retryable_codes = [500, 502, 503, 504] + + for status_code in retryable_codes: + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise APIError(status_code, 'error', f'Server Error {status_code}') + # Success on third attempt + return mock.Mock(status_code=200) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): # Mock sleep to speed up test + consumer.request([track]) + + # Should have been called 3 times + self.assertEqual(call_count, 3, f'Status {status_code} should be retried') + + def test_429_sets_rate_limit_state_with_retry_after(self): + """Test that 429 with Retry-After sets rate_limited_until on consumer""" + consumer = 
Consumer(None, 'testsecret', retries=2) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + def mock_post_fn(*args, **kwargs): + response = mock.Mock() + response.headers = {'Retry-After': '10'} + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = response + raise error + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with self.assertRaises(APIError) as ctx: + consumer.request([track]) + self.assertEqual(ctx.exception.status, 429) + + # Rate-limit state should be set + self.assertIsNotNone(consumer.rate_limited_until) + self.assertIsNotNone(consumer.rate_limit_start_time) + # rate_limited_until should be ~10 seconds in the future + self.assertGreater(consumer.rate_limited_until, time.time() + 5) + + def test_retry_after_capped_at_300_seconds(self): + """Test that Retry-After delay is capped at 300 seconds when setting rate-limit state""" + consumer = Consumer(None, 'testsecret', retries=2) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + def mock_post_fn(*args, **kwargs): + response = mock.Mock() + response.headers = {'Retry-After': '600'} # 10 minutes + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = response + raise error + + now = time.time() + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with self.assertRaises(APIError): + consumer.request([track]) + + # rate_limited_until should be capped at ~300s from now (not 600s) + self.assertIsNotNone(consumer.rate_limited_until) + self.assertLessEqual(consumer.rate_limited_until, now + 310) + self.assertGreater(consumer.rate_limited_until, now + 290) + + def test_408_and_503_use_backoff(self): + """Test that 408 and 503 use exponential backoff""" + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + for status_code in [408, 503]: + consumer = Consumer(None, 'testsecret', retries=2) + call_count = 0 + sleep_durations = [] + + 
def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + response = mock.Mock() + response.headers = {'Retry-After': '5'} + error = APIError(status_code, 'error', 'Error') + error.response = response + raise error + return mock.Mock(status_code=200) + + def mock_sleep(duration): + sleep_durations.append(duration) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep', side_effect=mock_sleep): + consumer.request([track]) + + # Should use backoff delay (0 for first retry), NOT the Retry-After value of 5 + self.assertEqual(call_count, 2) + self.assertEqual(len(sleep_durations), 1) + self.assertEqual(sleep_durations[0], 0, f'{status_code} should use backoff, not Retry-After') + + def test_exponential_backoff_with_jitter(self): + """Test that exponential backoff is used for retries without Retry-After""" + consumer = Consumer(None, 'testsecret', retries=4) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + sleep_durations = [] + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + + if call_count <= 3: + raise APIError(500, 'error', 'Server Error') + + return mock.Mock(status_code=200) + + def mock_sleep(duration): + sleep_durations.append(duration) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep', side_effect=mock_sleep): + consumer.request([track]) + + # Should have 3 backoff delays + self.assertEqual(len(sleep_durations), 3) + + # Delays should be increasing (exponential) + # First: 0s (immediate), Second: ~0.5s, Third: ~1s (with jitter) + self.assertEqual(sleep_durations[0], 0) # First retry is immediate + self.assertGreater(sleep_durations[1], 0.4) + self.assertLess(sleep_durations[1], 0.6) + self.assertGreater(sleep_durations[2], 0.9) + self.assertLess(sleep_durations[2], 1.2) + + def test_fatal_error_not_retried(self): + """Test that 
FatalError is not retried""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise FatalError('Fatal error occurred') + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with self.assertRaises(FatalError): + consumer.request([track]) + + # Should only be called once (no retries) + self.assertEqual(call_count, 1) + + def test_max_retries_exhausted(self): + """Test that request fails after max retries exhausted""" + consumer = Consumer(None, 'testsecret', retries=2) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + # Always fail with retryable error + raise APIError(500, 'error', 'Server Error') + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): # Mock sleep to speed up test + try: + consumer.request([track]) + except APIError as e: + self.assertEqual(e.status, 500) + + # Should be called 3 times (initial + 2 retries) + self.assertEqual(call_count, 3) + + def test_first_request_has_retry_count_zero(self): + """T01: First successful request includes X-Retry-Count=0""" + consumer = Consumer(None, 'testsecret') + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + retry_count = None + + def mock_post_fn(*args, **kwargs): + nonlocal retry_count + retry_count = kwargs.get('retry_count') + return mock.Mock(status_code=200) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + consumer.request([track]) + + # First request should have retry_count=0 + self.assertEqual(retry_count, 0) + + def test_429_without_retry_after_uses_counted_backoff(self): + """429 without Retry-After uses counted backoff (not pipeline blocking)""" + consumer 
= Consumer(None, 'testsecret', retries=2) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = mock.Mock() + error.response.headers = {} # No Retry-After + raise error + return mock.Mock(status_code=200) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): + consumer.request([track]) + + # Should retry with backoff (3 calls: initial + 2 retries) + self.assertEqual(call_count, 3) + # Rate-limit state should NOT be set (no pipeline blocking) + self.assertIsNone(consumer.rate_limited_until) + + def test_408_without_retry_after_uses_backoff(self): + """T10: 408 without Retry-After header uses backoff retry""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + retry_counts = [] + sleep_duration = None + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + retry_counts.append(kwargs.get('retry_count', 0)) + + if call_count == 1: + # 408 without Retry-After header + error = APIError(408, 'timeout', 'Request Timeout') + error.response = mock.Mock() + error.response.headers = {} # No Retry-After + raise error + + return mock.Mock(status_code=200) + + def mock_sleep(duration): + nonlocal sleep_duration + sleep_duration = duration + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep', side_effect=mock_sleep): + consumer.request([track]) + + # Should have two attempts + self.assertEqual(call_count, 2) + self.assertEqual(retry_counts, [0, 1]) + + # First retry should be immediate (0s delay) + self.assertIsNotNone(sleep_duration) + if sleep_duration is not None: + self.assertEqual(sleep_duration, 0) + + def 
test_network_error_retried_with_backoff(self): + """T15: Network/IO error is retried with backoff""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + retry_counts = [] + sleep_duration = None + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + retry_counts.append(kwargs.get('retry_count', 0)) + + if call_count == 1: + # Network error + raise ConnectionError('Network connection failed') + + return mock.Mock(status_code=200) + + def mock_sleep(duration): + nonlocal sleep_duration + sleep_duration = duration + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep', side_effect=mock_sleep): + consumer.request([track]) + + # Should have two attempts + self.assertEqual(call_count, 2) + self.assertEqual(retry_counts, [0, 1]) + + # First retry should be immediate (0s delay) + self.assertIsNotNone(sleep_duration) + if sleep_duration is not None: + self.assertEqual(sleep_duration, 0) + + def test_511_not_retryable_without_oauth(self): + """T17: 511 is NOT retried when OauthManager is not configured""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise APIError(511, 'auth_required', 'Network Authentication Required') + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with self.assertRaises(APIError) as ctx: + consumer.request([track]) + self.assertEqual(ctx.exception.status, 511) + + # Should only be called once (not retried without OAuth) + self.assertEqual(call_count, 1) + + def test_511_retryable_with_oauth(self): + """T17: 511 IS retried when OauthManager is configured""" + oauth_manager = mock.Mock() + consumer = Consumer(None, 'testsecret', retries=3, oauth_manager=oauth_manager) 
+ track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise APIError(511, 'auth_required', 'Network Authentication Required') + return mock.Mock(status_code=200) + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): + consumer.request([track]) + + # Should have been called 3 times (511 is retryable with OAuth) + self.assertEqual(call_count, 3) + + def test_429_with_retry_after_does_not_count_against_backoff_budget(self): + """429 with Retry-After raises immediately (pipeline blocking) without consuming backoff budget""" + consumer = Consumer(None, 'testsecret', retries=1) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = mock.Mock() + error.response.headers = {'Retry-After': '1'} + raise error + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with self.assertRaises(APIError) as ctx: + consumer.request([track]) + self.assertEqual(ctx.exception.status, 429) + + # 429 with Retry-After raises on first attempt (pipeline blocking) + self.assertEqual(call_count, 1) + + def test_413_payload_too_large_not_retried(self): + """T12: 413 Payload Too Large is non-retryable (won't succeed on retry)""" + consumer = Consumer(None, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise APIError(413, 'payload_too_large', 'Payload Too Large') + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + try: + consumer.request([track]) + except APIError as e: + self.assertEqual(e.status, 
413) + + # Should only be called once (no retries) + self.assertEqual(call_count, 1) + + def test_t04_429_halts_upload_iteration(self): + """T04: 429 halts current upload iteration — batch is re-queued, not dropped""" + q = Queue() + consumer = Consumer(q, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + # Put a message in the queue + q.put(track) + + call_count = 0 + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + response = mock.Mock() + response.headers = {'Retry-After': '10'} + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = response + raise error + + on_error_called = [] + + def on_error(e, batch): + on_error_called.append((e, batch)) + + consumer.on_error = on_error + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): + result = consumer.upload() + + # upload() should return False (not successful) + self.assertFalse(result) + # request() should have been called exactly once + self.assertEqual(call_count, 1) + # on_error should NOT have been called (batch was re-queued, not dropped) + self.assertEqual(len(on_error_called), 0) + # Rate-limit state should be set + self.assertIsNotNone(consumer.rate_limited_until) + self.assertIsNotNone(consumer.rate_limit_start_time) + + def test_429_without_retry_after_does_not_requeue_batch(self): + """429 without Retry-After is treated as normal failure in upload() and is not re-queued""" + q = Queue() + consumer = Consumer(q, 'testsecret', retries=0) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + q.put(track) + + def mock_post_fn(*args, **kwargs): + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = mock.Mock() + error.response.headers = {} + raise error + + on_error_called = [] + + def on_error(e, batch): + on_error_called.append((e, batch)) + + consumer.on_error = on_error + + with 
mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): + result = consumer.upload() + + self.assertFalse(result) + self.assertEqual(len(on_error_called), 1) + self.assertIsNone(consumer.rate_limited_until) + self.assertEqual(q.qsize(), 0) + + def test_retry_after_zero_sets_rate_limit_state(self): + """429 with Retry-After: 0 still sets rate-limit state for consistent pipeline handling""" + consumer = Consumer(None, 'testsecret', retries=1) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + def mock_post_fn(*args, **kwargs): + response = mock.Mock() + response.headers = {'Retry-After': '0'} + error = APIError(429, 'rate_limit', 'Too Many Requests') + error.response = response + raise error + + before = time.time() + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with self.assertRaises(APIError) as ctx: + consumer.request([track]) + self.assertEqual(ctx.exception.status, 429) + after = time.time() + + self.assertIsNotNone(consumer.rate_limited_until) + self.assertIsNotNone(consumer.rate_limit_start_time) + self.assertGreaterEqual(consumer.rate_limited_until, before) + self.assertLessEqual(consumer.rate_limited_until, after + 0.1) + + def test_t19_max_total_backoff_duration(self): + """T19: Gives up after maxTotalBackoffDuration elapsed""" + consumer = Consumer(None, 'testsecret', retries=1000, + max_total_backoff_duration=5) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + call_count = 0 + fake_time = [100.0] # Start time + + def mock_post_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise APIError(500, 'error', 'Server Error') + + # NOTE: no need to save time.time here; mock.patch restores it on exit + + def mock_time(): + # Advance time by 3 seconds on each call after the first + result = fake_time[0] + fake_time[0] += 3.0 + return result + + with mock.patch('segment.analytics.consumer.post', side_effect=mock_post_fn): + with mock.patch('time.sleep'): + 
with mock.patch('time.time', side_effect=mock_time): + with self.assertRaises(APIError) as ctx: + consumer.request([track]) + self.assertEqual(ctx.exception.status, 500) + + # With max_total_backoff_duration=5 and time advancing 3s per call: + # Attempt 1: fails, first_failure_time set at 100, time now 103 + # Attempt 2: fails, time is 106, 106-100=6 > 5, exceeds duration + # So should be called exactly 2 times + self.assertEqual(call_count, 2) + + def test_t20_max_rate_limit_duration(self): + """T20: Rate-limited state clears and batch is dropped after maxRateLimitDuration""" + q = Queue() + consumer = Consumer(q, 'testsecret', retries=3, + max_rate_limit_duration=10) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + # Pre-set rate-limit state as if we entered it 15 seconds ago + now = time.time() + consumer.rate_limit_start_time = now - 15 # 15s ago, exceeds 10s limit + consumer.rate_limited_until = now + 5 # Would still be rate-limited + + # Put a message in the queue + q.put(track) + + on_error_called = [] + + def on_error(e, batch): + on_error_called.append((e, batch)) + + consumer.on_error = on_error + + # upload() should detect duration exceeded, clear state, drop batch + result = consumer.upload() + + self.assertFalse(result) + # Rate-limit state should be cleared + self.assertIsNone(consumer.rate_limited_until) + self.assertIsNone(consumer.rate_limit_start_time) + # on_error should have been called (batch was dropped) + self.assertEqual(len(on_error_called), 1) + + def test_rate_limit_state_cleared_on_success(self): + """Rate-limit state is cleared after a successful request""" + q = Queue() + consumer = Consumer(q, 'testsecret', retries=3) + track = {'type': 'track', 'event': 'python event', 'userId': 'userId'} + + # Set rate-limit state + consumer.rate_limited_until = time.time() - 1 # Already expired + consumer.rate_limit_start_time = time.time() - 10 + + q.put(track) + + with mock.patch('segment.analytics.consumer.post', 
return_value=mock.Mock(status_code=200)): + result = consumer.upload() + + self.assertTrue(result) + # Rate-limit state should be cleared on success + self.assertIsNone(consumer.rate_limited_until) + self.assertIsNone(consumer.rate_limit_start_time) diff --git a/segment/analytics/test/test_request.py b/segment/analytics/test/test_request.py index 5ffca00..e0206dd 100644 --- a/segment/analytics/test/test_request.py +++ b/segment/analytics/test/test_request.py @@ -2,9 +2,10 @@ import unittest import json import requests +import base64 from unittest import mock -from segment.analytics.request import post, DatetimeSerializer +from segment.analytics.request import post, DatetimeSerializer, parse_retry_after, APIError class TestRequests(unittest.TestCase): @@ -72,3 +73,151 @@ def mock_post_fn(*args, **kwargs): args, kwargs = mock_post.call_args self.assertIn('proxies', kwargs) self.assertEqual(kwargs['proxies'], proxies) + + def test_authorization_header_basic_auth(self): + """Test that Basic Authorization header is added when no OAuth manager""" + def mock_post_fn(*args, **kwargs): + res = mock.Mock() + res.status_code = 200 + return res + + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn) as mock_post: + post('testsecret', batch=[{ + 'userId': 'userId', + 'event': 'python event', + 'type': 'track' + }]) + + args, kwargs = mock_post.call_args + headers = kwargs['headers'] + self.assertIn('Authorization', headers) + + # Verify it's Basic auth with correct encoding + expected_credentials = base64.b64encode(b'testsecret:').decode('utf-8') + expected_auth = f'Basic {expected_credentials}' + self.assertEqual(headers['Authorization'], expected_auth) + + def test_authorization_header_oauth(self): + """Test that Bearer Authorization header is used with OAuth manager""" + oauth_manager = mock.Mock() + oauth_manager.get_token.return_value = 'test_token_123' + + def mock_post_fn(*args, **kwargs): + res = mock.Mock() + res.status_code = 200 + 
return res + + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn) as mock_post: + post('testsecret', oauth_manager=oauth_manager, batch=[{ + 'userId': 'userId', + 'event': 'python event', + 'type': 'track' + }]) + + args, kwargs = mock_post.call_args + headers = kwargs['headers'] + self.assertIn('Authorization', headers) + self.assertEqual(headers['Authorization'], 'Bearer test_token_123') + + def test_x_retry_count_header(self): + """Test that X-Retry-Count header is included""" + def mock_post_fn(*args, **kwargs): + res = mock.Mock() + res.status_code = 200 + return res + + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn) as mock_post: + # Test with retry_count=0 (first attempt) + post('testsecret', retry_count=0, batch=[{ + 'userId': 'userId', + 'event': 'python event', + 'type': 'track' + }]) + + args, kwargs = mock_post.call_args + headers = kwargs['headers'] + self.assertIn('X-Retry-Count', headers) + self.assertEqual(headers['X-Retry-Count'], '0') + + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn) as mock_post: + # Test with retry_count=5 + post('testsecret', retry_count=5, batch=[{ + 'userId': 'userId', + 'event': 'python event', + 'type': 'track' + }]) + + args, kwargs = mock_post.call_args + headers = kwargs['headers'] + self.assertEqual(headers['X-Retry-Count'], '5') + + def test_parse_retry_after_integer(self): + """Test parsing Retry-After header with integer seconds""" + response = mock.Mock() + response.headers = {'Retry-After': '30'} + result = parse_retry_after(response) + self.assertEqual(result, 30) + + def test_parse_retry_after_capped(self): + """Test that Retry-After is capped at 300 seconds""" + response = mock.Mock() + response.headers = {'Retry-After': '600'} + result = parse_retry_after(response) + self.assertEqual(result, 300) + + def test_parse_retry_after_missing(self): + """Test parsing when Retry-After header is missing""" + 
response = mock.Mock() + response.headers = {} + result = parse_retry_after(response) + self.assertIsNone(result) + + def test_parse_retry_after_invalid(self): + """Test parsing with invalid Retry-After header""" + response = mock.Mock() + response.headers = {'Retry-After': 'invalid'} + result = parse_retry_after(response) + self.assertIsNone(result) + + def test_oauth_token_cleared_on_511(self): + """Test that OAuth token is cleared on 511 status""" + oauth_manager = mock.Mock() + oauth_manager.get_token.return_value = 'test_token' + + def mock_post_fn(*args, **kwargs): + res = mock.Mock() + res.status_code = 511 + res.json.return_value = {'code': 'error', 'message': 'Network Authentication Required'} + return res + + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn): + with self.assertRaises(APIError): + post('testsecret', oauth_manager=oauth_manager, batch=[{ + 'userId': 'userId', + 'event': 'python event', + 'type': 'track' + }]) + + # Verify clear_token was called + oauth_manager.clear_token.assert_called_once() + + def test_api_error_includes_response(self): + """Test that APIError includes the response object""" + def mock_post_fn(*args, **kwargs): + res = mock.Mock() + res.status_code = 429 + res.json.return_value = {'code': 'rate_limit', 'message': 'Too Many Requests'} + return res + + with mock.patch('segment.analytics.request._session.post', side_effect=mock_post_fn): + try: + post('testsecret', batch=[{ + 'userId': 'userId', + 'event': 'python event', + 'type': 'track' + }]) + except APIError as e: + self.assertEqual(e.status, 429) + self.assertIsNotNone(e.response) + else: + self.fail('Expected APIError to be raised')