diff --git a/README.md b/README.md index 00148af..65e60fe 100644 --- a/README.md +++ b/README.md @@ -154,9 +154,26 @@ This is very ineficent and should not be used for high-volume schedules. Because This source holds values in lists. * For cron tasks it uses key `{prefix}:cron`. +* For interval tasks it uses key `{prefix}:interval`. * For timed schedules it uses key `{prefix}:time:{time}` where `{time}` is actually time where schedules should run. +* A sorted set at `{prefix}:time_index` tracks all time keys with their Unix timestamps as scores, so that past time schedules can be discovered via `ZRANGEBYSCORE` instead of scanning all Redis keys. Stale entries (older than 5 minutes with empty time key lists) are cleaned up lazily: the cleanup runs at most once per minute, when a time schedule is deleted. -The main advantage of this approach is that we only fetch tasks we need to run at a given time and do not perform any excesive calls to redis. +The main advantage of this approach is that we only fetch tasks we need to run at a given time and do not perform any excessive calls to redis. + +#### `populate_time_index` + +If you are upgrading from an older version that did not maintain the `{prefix}:time_index` sorted set, existing time keys will not be present in the index. 
Set `populate_time_index=True` once on startup to backfill the index via a one-time `SCAN`, then set it back to `False` for subsequent runs: + +```python +# First run after upgrading — backfills the time index +source = ListRedisScheduleSource( + "redis://localhost/1", + populate_time_index=True, +) + +# All subsequent runs — no SCAN, uses the time index +source = ListRedisScheduleSource("redis://localhost/1") +``` ### Migration from one source to another diff --git a/taskiq_redis/list_schedule_source.py b/taskiq_redis/list_schedule_source.py index 977a16d..7ca132e 100644 --- a/taskiq_redis/list_schedule_source.py +++ b/taskiq_redis/list_schedule_source.py @@ -1,4 +1,5 @@ import datetime +import time as _time from logging import getLogger from typing import Any @@ -23,6 +24,7 @@ def __init__( serializer: TaskiqSerializer | None = None, buffer_size: int = 50, skip_past_schedules: bool = False, + populate_time_index: bool = False, **connection_kwargs: Any, ) -> None: """ @@ -34,6 +36,11 @@ def __init__( :param serializer: Serializer to use for the schedules :param buffer_size: Buffer size for getting schedules :param skip_past_schedules: Skip schedules that are in the past. + :param populate_time_index: If True, on startup run a one-time SCAN + to populate the time index sorted set from existing time keys. + This is needed for migrating from an older version that did not + maintain the time index. Set this to True once to backfill the + index, then set it back to False for subsequent runs. 
:param connection_kwargs: Additional connection kwargs """ super().__init__() @@ -47,10 +54,11 @@ def __init__( if serializer is None: serializer = PickleSerializer() self._serializer = serializer - self._is_first_run = True self._previous_schedule_source: ScheduleSource | None = None self._delete_schedules_after_migration: bool = True self._skip_past_schedules = skip_past_schedules + self._populate_time_index = populate_time_index + self._last_cleanup_time: float = 0 async def startup(self) -> None: """ @@ -59,6 +67,9 @@ async def startup(self) -> None: By default this function does nothing. But if the previous schedule source is set, it will try to migrate schedules from it. + + If populate_time_index is True, it will scan for existing + time keys and populate the time index sorted set. """ if self._previous_schedule_source is not None: logger.info("Migrating schedules from previous source") @@ -74,6 +85,25 @@ async def startup(self) -> None: await self._previous_schedule_source.shutdown() logger.info("Migration complete") + if self._populate_time_index: + logger.info("Populating time index from existing keys via scan") + async with Redis(connection_pool=self._connection_pool) as redis: + batch: dict[str, float] = {} + async for key in redis.scan_iter(f"{self._prefix}:time:*"): + key_str = key.decode() + key_time = self._parse_time_key(key_str) + if key_time: + batch[key_str] = key_time.timestamp() + if len(batch) >= self._buffer_size: + await redis.zadd( + self._get_time_index_key(), + batch, + ) + batch = {} + if batch: + await redis.zadd(self._get_time_index_key(), batch) + logger.info("Time index population complete") + def _get_time_key(self, time: datetime.datetime) -> str: """Get the key for a time-based schedule.""" if time.tzinfo is None: @@ -81,6 +111,10 @@ def _get_time_key(self, time: datetime.datetime) -> str: iso_time = time.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M") return f"{self._prefix}:time:{iso_time}" + def 
_get_time_index_key(self) -> str: + """Get the key for the time index sorted set.""" + return f"{self._prefix}:time_index" + def _get_cron_key(self) -> str: """Get the key for a cron-based schedule.""" return f"{self._prefix}:cron" @@ -103,7 +137,46 @@ def _parse_time_key(self, key: str) -> datetime.datetime | None: logger.debug("Failed to parse time key %s", key) return None - async def _get_previous_time_schedules(self) -> list[bytes]: + async def _maybe_cleanup_time_index(self, redis: Redis) -> None: # type: ignore[type-arg] + """ + Run time index cleanup at most once per minute. + + Called from delete_schedule after removing a time-based schedule, + since that's the path where time key lists become empty. + """ + now = _time.monotonic() + if now - self._last_cleanup_time < 60: + return + self._last_cleanup_time = now + await self._cleanup_time_index(redis) + + async def _cleanup_time_index(self, redis: Redis) -> None: # type: ignore[type-arg] + """ + Remove stale entries from the time index sorted set. + + Only removes entries that are older than 5 minutes AND whose + corresponding time key list is empty (or no longer exists). + This avoids a race condition where an eager cleanup in + delete_schedule could remove an index entry right as + add_schedule is creating a new schedule at the same minute. + """ + five_minutes_ago = ( + datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(minutes=5) + ).timestamp() + stale_keys: list[bytes] = await redis.zrangebyscore( + self._get_time_index_key(), + "-inf", + five_minutes_ago, + ) + for key in stale_keys: + if await redis.llen(key) == 0: + await redis.zrem(self._get_time_index_key(), key) + + async def _get_previous_time_schedules( + self, + current_time: datetime.datetime, + ) -> list[bytes]: """ Function that gets all timed schedules that are in the past. 
@@ -111,27 +184,31 @@ async def _get_previous_time_schedules(self) -> list[bytes]: we need to get all the schedules that are in the past and haven't been sent yet. - We do this by getting all the time keys and checking if the time - is less than the current time. + Uses the time index sorted set to look up past time keys + instead of scanning all Redis keys. + + Called on every get_schedules invocation so that schedules + added in a past minute (after the previous get_schedules call + but before the minute rolled over) are never missed. - This function is called only during the first run to minimize - the number of requests to the Redis server. + :param current_time: The reference time captured by the caller, + used to derive the cutoff so that the "previous" and "current" + windows never overlap. """ logger.info("Getting previous time schedules") - minute_before = datetime.datetime.now( - datetime.timezone.utc, - ).replace(second=0, microsecond=0) - datetime.timedelta( + minute_before = current_time.replace( + second=0, microsecond=0, + ) - datetime.timedelta( minutes=1, ) schedules = [] async with Redis(connection_pool=self._connection_pool) as redis: - time_keys: list[str] = [] - # We need to get all the time keys and check if the time is less than - # the current time. 
- async for key in redis.scan_iter(f"{self._prefix}:time:*"): - key_time = self._parse_time_key(key.decode()) - if key_time and key_time <= minute_before: - time_keys.append(key.decode()) + max_score = minute_before.timestamp() + time_keys: list[bytes] = await redis.zrangebyscore( + self._get_time_index_key(), + "-inf", + max_score, + ) for key in time_keys: schedules.extend(await redis.lrange(key, 0, -1)) # type: ignore[misc] @@ -153,6 +230,7 @@ async def delete_schedule(self, schedule_id: str) -> None: elif schedule.time is not None: time_key = self._get_time_key(schedule.time) await redis.lrem(time_key, 0, schedule_id) # type: ignore[misc] + await self._maybe_cleanup_time_index(redis) elif schedule.interval: await redis.lrem(self._get_interval_key(), 0, schedule_id) # type: ignore[misc] @@ -170,9 +248,21 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None: if schedule.cron is not None: await redis.rpush(self._get_cron_key(), schedule.schedule_id) # type: ignore[misc] elif schedule.time is not None: - await redis.rpush( # type: ignore[misc] - self._get_time_key(schedule.time), - schedule.schedule_id, + time_key = self._get_time_key(schedule.time) + await redis.rpush(time_key, schedule.schedule_id) # type: ignore[misc] + # Add to the time index sorted set so we can look up + # past time keys without scanning all Redis keys. + time_val = schedule.time + if time_val.tzinfo is None: + time_val = time_val.replace(tzinfo=datetime.timezone.utc) + score = ( + time_val.astimezone(datetime.timezone.utc) + .replace(second=0, microsecond=0) + .timestamp() + ) + await redis.zadd( # type: ignore[misc] + self._get_time_index_key(), + {time_key: score}, ) elif schedule.interval: await redis.rpush( # type: ignore[misc] @@ -190,19 +280,19 @@ async def get_schedules(self) -> list["ScheduledTask"]: Get all schedules. This function gets all the schedules from the schedule source. 
- What it does is get all the cron schedules and time schedules - for the current time and return them. + What it does is get all the cron schedules, interval schedules, + past time schedules, and current-minute time schedules and + return them. - If it's the first run, it also gets all the time schedules - that are in the past and haven't been sent yet. + Past time schedules are fetched on every call so that + schedules added after the previous call but before the + minute rolled over are never missed. """ schedules = [] current_time = datetime.datetime.now(datetime.timezone.utc) timed: list[bytes] = [] - # Only during first run, we need to get previous time schedules if not self._skip_past_schedules: - timed = await self._get_previous_time_schedules() - self._is_first_run = False + timed = await self._get_previous_time_schedules(current_time) async with Redis(connection_pool=self._connection_pool) as redis: buffer = [] crons = await redis.lrange(self._get_cron_key(), 0, -1) # type: ignore[misc] diff --git a/tests/test_list_schedule_source.py b/tests/test_list_schedule_source.py index c21486b..76a0005 100644 --- a/tests/test_list_schedule_source.py +++ b/tests/test_list_schedule_source.py @@ -3,6 +3,7 @@ import pytest from freezegun import freeze_time +from redis.asyncio import BlockingConnectionPool, Redis from taskiq import ScheduledTask from taskiq_redis.list_schedule_source import ListRedisScheduleSource @@ -179,3 +180,328 @@ async def test_migration(redis_url: str) -> None: for old_schedule in old_schedules: with freeze_time(old_schedule.time): assert await source.get_schedules() == [old_schedule] + + +@pytest.mark.anyio +@freeze_time("2025-01-01 00:00:00") +async def test_time_index_populated_on_add(redis_url: str) -> None: + """Test that adding a time schedule populates the time index sorted set.""" + prefix = uuid.uuid4().hex + source = ListRedisScheduleSource(redis_url, prefix=prefix) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + 
args=[], + kwargs={}, + time=datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(minutes=5), + ) + await source.add_schedule(schedule) + + # Verify the time index sorted set has an entry. + async with Redis(connection_pool=source._connection_pool) as redis: + members = await redis.zrange(source._get_time_index_key(), 0, -1) + assert len(members) == 1 + assert members[0].decode() == source._get_time_key(schedule.time) + + +@pytest.mark.anyio +@freeze_time("2025-01-01 00:00:00") +async def test_time_index_not_eagerly_cleaned_on_delete(redis_url: str) -> None: + """Test that delete_schedule does NOT eagerly remove the index entry. + This avoids a race condition where a concurrent add_schedule at the + same minute could lose its index entry.""" + prefix = uuid.uuid4().hex + source = ListRedisScheduleSource(redis_url, prefix=prefix) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + args=[], + kwargs={}, + time=datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(minutes=5), + ) + await source.add_schedule(schedule) + + # Index has 1 entry. + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + await source.delete_schedule(schedule.schedule_id) + + # Index entry is still present (lazy cleanup handles it later). + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + +@pytest.mark.anyio +async def test_cleanup_removes_old_empty_entries(redis_url: str) -> None: + """Test that _cleanup_time_index removes index entries that are + older than 5 minutes and whose time key lists are empty.""" + prefix = uuid.uuid4().hex + with freeze_time("2025-01-01 00:10:00"): + source = ListRedisScheduleSource(redis_url, prefix=prefix) + # 10 minutes before "now" — well past the 5-minute threshold. 
+ old_time = datetime.datetime( + 2025, 1, 1, 0, 0, tzinfo=datetime.timezone.utc, + ) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + args=[], + kwargs={}, + time=old_time, + ) + await source.add_schedule(schedule) + # Prevent delete_schedule from triggering cleanup by pretending + # cleanup just ran (rate limiter blocks it). + import time + + source._last_cleanup_time = time.monotonic() + await source.delete_schedule(schedule.schedule_id) + + # Index still has the stale entry (cleanup was rate-limited). + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + # Run cleanup directly — entry is > 5 minutes old and empty. + with freeze_time("2025-01-01 00:10:00"): + async with Redis(connection_pool=source._connection_pool) as redis: + await source._cleanup_time_index(redis) + + # Now it should be cleaned up. + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 0 + + +@pytest.mark.anyio +async def test_cleanup_keeps_non_empty_entries(redis_url: str) -> None: + """Test that _cleanup_time_index does NOT remove index entries whose + time key lists still have schedules, even if older than 5 minutes.""" + prefix = uuid.uuid4().hex + with freeze_time("2025-01-01 00:10:00"): + source = ListRedisScheduleSource(redis_url, prefix=prefix) + old_time = datetime.datetime( + 2025, 1, 1, 0, 0, tzinfo=datetime.timezone.utc, + ) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + args=[], + kwargs={}, + time=old_time, + ) + await source.add_schedule(schedule) + + # Run cleanup — entry is > 5 minutes old but list is NOT empty. + with freeze_time("2025-01-01 00:10:00"): + async with Redis(connection_pool=source._connection_pool) as redis: + await source._cleanup_time_index(redis) + + # Entry should still be present. 
+ async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + +@pytest.mark.anyio +async def test_cleanup_keeps_recent_empty_entries(redis_url: str) -> None: + """Test that _cleanup_time_index does NOT remove index entries that + are less than 5 minutes old, even if their time key lists are empty.""" + prefix = uuid.uuid4().hex + with freeze_time("2025-01-01 00:04:00"): + source = ListRedisScheduleSource(redis_url, prefix=prefix) + # 2 minutes ago — within the 5-minute safety window. + recent_time = datetime.datetime( + 2025, 1, 1, 0, 2, tzinfo=datetime.timezone.utc, + ) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + args=[], + kwargs={}, + time=recent_time, + ) + await source.add_schedule(schedule) + await source.delete_schedule(schedule.schedule_id) + + # Run cleanup — entry is empty but only 2 minutes old. + with freeze_time("2025-01-01 00:04:00"): + async with Redis(connection_pool=source._connection_pool) as redis: + await source._cleanup_time_index(redis) + + # Entry should still be present (not old enough). + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + +@pytest.mark.anyio +@freeze_time("2025-01-01 00:00:00") +async def test_past_schedules_found_via_time_index(redis_url: str) -> None: + """Test that past schedules are discovered via the time index + instead of a full SCAN.""" + prefix = uuid.uuid4().hex + source = ListRedisScheduleSource(redis_url, prefix=prefix) + past_time = datetime.datetime.now( + datetime.timezone.utc, + ) - datetime.timedelta(minutes=5) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + args=[], + kwargs={}, + time=past_time, + ) + await source.add_schedule(schedule) + + # First call to get_schedules should find the past schedule via time index. 
+ schedules = await source.get_schedules() + assert schedules == [schedule] + + +@pytest.mark.anyio +@freeze_time("2025-01-01 00:00:00") +async def test_populate_time_index_from_existing_keys(redis_url: str) -> None: + """Test that populate_time_index=True backfills the sorted set + from existing time keys created without the index.""" + prefix = uuid.uuid4().hex + + # Simulate old-style data: create time key lists directly in Redis + # without populating the time index sorted set. + pool = BlockingConnectionPool.from_url(url=redis_url) + past_times = [ + datetime.datetime(2024, 12, 31, 23, 55, tzinfo=datetime.timezone.utc), + datetime.datetime(2024, 12, 31, 23, 56, tzinfo=datetime.timezone.utc), + datetime.datetime(2024, 12, 31, 23, 57, tzinfo=datetime.timezone.utc), + ] + + source_for_keys = ListRedisScheduleSource(redis_url, prefix=prefix) + async with Redis(connection_pool=pool) as redis: + for t in past_times: + time_key = source_for_keys._get_time_key(t) + # Push a dummy schedule ID directly (bypassing add_schedule + # to simulate old behavior without time index). + await redis.rpush(time_key, f"sched_{t.minute}") # type: ignore[misc] + + # Verify no time index exists yet. + assert await redis.zcard(source_for_keys._get_time_index_key()) == 0 + await pool.disconnect() + + # Now create a source with populate_time_index=True. + source = ListRedisScheduleSource( + redis_url, + prefix=prefix, + populate_time_index=True, + ) + await source.startup() + + # The time index should now be populated. 
+ async with Redis(connection_pool=source._connection_pool) as redis: + count = await redis.zcard(source._get_time_index_key()) + assert count == len(past_times) + + +@pytest.mark.anyio +async def test_post_send_triggers_cleanup(redis_url: str) -> None: + """Test the full lifecycle: add schedule, get it, post_send it, + then verify cleanup (triggered from delete_schedule) removes + the stale index entry when it's > 5 minutes old.""" + prefix = uuid.uuid4().hex + + with freeze_time("2025-01-01 00:10:00"): + source = ListRedisScheduleSource(redis_url, prefix=prefix) + schedule = ScheduledTask( + task_name="test_task", + labels={}, + args=[], + kwargs={}, + time=datetime.datetime( + 2025, 1, 1, 0, 0, tzinfo=datetime.timezone.utc, + ), + ) + await source.add_schedule(schedule) + + # First run picks up past schedules. + schedules = await source.get_schedules() + assert schedules == [schedule] + + # post_send -> delete_schedule -> _maybe_cleanup_time_index. + # The entry is > 5 minutes old and the list becomes empty, + # so cleanup should remove it. + for s in schedules: + await source.post_send(s) + + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 0 + + # Second run should return nothing. 
+ with freeze_time("2025-01-01 00:11:00"): + schedules = await source.get_schedules() + assert schedules == [] + + +@pytest.mark.anyio +async def test_cleanup_rate_limited(redis_url: str) -> None: + """Test that _maybe_cleanup_time_index only runs once per minute.""" + prefix = uuid.uuid4().hex + + with freeze_time("2025-01-01 00:10:00"): + source = ListRedisScheduleSource(redis_url, prefix=prefix) + old_time = datetime.datetime( + 2025, 1, 1, 0, 0, tzinfo=datetime.timezone.utc, + ) + sched1 = ScheduledTask( + task_name="task1", + labels={}, + args=[], + kwargs={}, + time=old_time, + ) + sched2 = ScheduledTask( + task_name="task2", + labels={}, + args=[], + kwargs={}, + time=old_time, + ) + await source.add_schedule(sched1) + await source.add_schedule(sched2) + + # First delete triggers cleanup (first call, _last_cleanup_time=0). + # But the time key list still has sched2, so the entry is kept. + await source.delete_schedule(sched1.schedule_id) + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + # Second delete happens within the same minute, so cleanup + # is rate-limited and does NOT run — index entry remains + # even though the list is now empty. 
+ await source.delete_schedule(sched2.schedule_id) + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 1 + + +@pytest.mark.anyio +@freeze_time("2025-01-01 00:00:00") +async def test_cron_and_interval_not_in_time_index(redis_url: str) -> None: + """Test that cron and interval schedules do not affect the time index.""" + prefix = uuid.uuid4().hex + source = ListRedisScheduleSource(redis_url, prefix=prefix) + cron_schedule = ScheduledTask( + task_name="cron_task", + labels={}, + args=[], + kwargs={}, + cron="* * * * *", + ) + interval_schedule = ScheduledTask( + task_name="interval_task", + labels={}, + args=[], + kwargs={}, + interval=datetime.timedelta(seconds=30), + ) + await source.add_schedule(cron_schedule) + await source.add_schedule(interval_schedule) + + async with Redis(connection_pool=source._connection_pool) as redis: + assert await redis.zcard(source._get_time_index_key()) == 0