Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 108 additions & 18 deletions taskiq_redis/list_schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import time as _time
from logging import getLogger
from typing import Any

Expand All @@ -23,6 +24,7 @@ def __init__(
serializer: TaskiqSerializer | None = None,
buffer_size: int = 50,
skip_past_schedules: bool = False,
populate_time_index: bool = False,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -34,6 +36,11 @@ def __init__(
:param serializer: Serializer to use for the schedules
:param buffer_size: Buffer size for getting schedules
:param skip_past_schedules: Skip schedules that are in the past.
:param populate_time_index: If True, on startup run a one-time SCAN
to populate the time index sorted set from existing time keys.
This is needed for migrating from an older version that did not
maintain the time index. Set this to True once to backfill the
index, then set it back to False for subsequent runs.
:param connection_kwargs: Additional connection kwargs
"""
super().__init__()
Expand All @@ -51,6 +58,8 @@ def __init__(
self._previous_schedule_source: ScheduleSource | None = None
self._delete_schedules_after_migration: bool = True
self._skip_past_schedules = skip_past_schedules
self._populate_time_index = populate_time_index
self._last_cleanup_time: float = 0

async def startup(self) -> None:
"""
Expand All @@ -59,6 +68,9 @@ async def startup(self) -> None:
By default this function does nothing.
But if the previous schedule source is set,
it will try to migrate schedules from it.

If populate_time_index is True, it will scan for existing
time keys and populate the time index sorted set.
"""
if self._previous_schedule_source is not None:
logger.info("Migrating schedules from previous source")
Expand All @@ -74,13 +86,36 @@ async def startup(self) -> None:
await self._previous_schedule_source.shutdown()
logger.info("Migration complete")

if self._populate_time_index:
logger.info("Populating time index from existing keys via scan")
async with Redis(connection_pool=self._connection_pool) as redis:
batch: dict[str, float] = {}
async for key in redis.scan_iter(f"{self._prefix}:time:*"):
key_str = key.decode()
key_time = self._parse_time_key(key_str)
if key_time:
batch[key_str] = key_time.timestamp()
if len(batch) >= self._buffer_size:
await redis.zadd(
self._get_time_index_key(),
batch,
)
batch = {}
if batch:
await redis.zadd(self._get_time_index_key(), batch)
logger.info("Time index population complete")

def _get_time_key(self, time: datetime.datetime) -> str:
    """
    Build the Redis key for schedules that fire at the given minute.

    A naive datetime is treated as UTC; an aware one is converted to UTC.
    The key has minute resolution, so seconds and microseconds are dropped.

    :param time: moment the schedule should run at.
    :return: key of the form ``<prefix>:time:YYYY-MM-DDTHH:MM``.
    """
    utc = datetime.timezone.utc
    aware = time if time.tzinfo is not None else time.replace(tzinfo=utc)
    minute_stamp = aware.astimezone(utc).strftime("%Y-%m-%dT%H:%M")
    return f"{self._prefix}:time:{minute_stamp}"

def _get_time_index_key(self) -> str:
    """
    Return the Redis key of the time index sorted set.

    :return: key of the form ``<prefix>:time_index``.
    """
    prefix = self._prefix
    return f"{prefix}:time_index"

def _get_cron_key(self) -> str:
    """
    Return the Redis key under which cron-based schedules are stored.

    :return: key of the form ``<prefix>:cron``.
    """
    prefix = self._prefix
    return f"{prefix}:cron"
Expand All @@ -103,35 +138,77 @@ def _parse_time_key(self, key: str) -> datetime.datetime | None:
logger.debug("Failed to parse time key %s", key)
return None

async def _get_previous_time_schedules(self) -> list[bytes]:
async def _maybe_cleanup_time_index(self, redis: Redis) -> None:  # type: ignore[type-arg]
    """
    Throttled entry point for pruning the time index.

    The actual cleanup is executed only when at least 60 seconds of
    monotonic time have passed since the previously recorded attempt.
    It is invoked from delete_schedule after a time-based schedule is
    removed, because that is where time key lists can become empty.

    :param redis: open Redis client that is forwarded to the cleanup.
    """
    started = _time.monotonic()
    elapsed = started - self._last_cleanup_time
    if elapsed >= 60:
        # Record the attempt before awaiting, so the throttle window
        # starts even if the cleanup itself raises.
        self._last_cleanup_time = started
        await self._cleanup_time_index(redis)

async def _cleanup_time_index(self, redis: Redis) -> None:  # type: ignore[type-arg]
    """
    Remove stale entries from the time index sorted set.

    Only removes entries that are older than 1 hour AND whose
    corresponding time key list is empty (or no longer exists).
    This avoids a race condition where an eager cleanup in
    delete_schedule could remove an index entry right as
    add_schedule is creating a new schedule at the same minute.

    All empty keys are removed with a single ZREM call instead of
    one call per key, which keeps the number of round-trips down.

    :param redis: open Redis client used for all commands.
    """
    index_key = self._get_time_index_key()
    one_hour_ago = (
        datetime.datetime.now(datetime.timezone.utc)
        - datetime.timedelta(hours=1)
    ).timestamp()
    stale_keys: list[bytes] = await redis.zrangebyscore(
        index_key,
        "-inf",
        one_hour_ago,
    )
    # LLEN reports 0 for a missing key as well as for an empty list,
    # so both cases are treated as stale here.
    empty_keys = [key for key in stale_keys if await redis.llen(key) == 0]
    if empty_keys:
        # ZREM is variadic: drop every stale member in one command.
        await redis.zrem(index_key, *empty_keys)

async def _get_previous_time_schedules(
self,
current_time: datetime.datetime,
) -> list[bytes]:
"""
Function that gets all timed schedules that are in the past.

Since this source doesn't retrieve all the schedules at once,
we need to get all the schedules that are in the past and haven't
been sent yet.

We do this by getting all the time keys and checking if the time
is less than the current time.
Uses the time index sorted set to look up past time keys
instead of scanning all Redis keys.

This function is called only during the first run to minimize
the number of requests to the Redis server.

:param current_time: The reference time captured by the caller,
used to derive the cutoff so that the "previous" and "current"
windows never overlap.
"""
logger.info("Getting previous time schedules")
minute_before = datetime.datetime.now(
datetime.timezone.utc,
).replace(second=0, microsecond=0) - datetime.timedelta(
minute_before = current_time.replace(
second=0, microsecond=0,
) - datetime.timedelta(
minutes=1,
)
schedules = []
async with Redis(connection_pool=self._connection_pool) as redis:
time_keys: list[str] = []
# We need to get all the time keys and check if the time is less than
# the current time.
async for key in redis.scan_iter(f"{self._prefix}:time:*"):
key_time = self._parse_time_key(key.decode())
if key_time and key_time <= minute_before:
time_keys.append(key.decode())
max_score = minute_before.timestamp()
time_keys: list[bytes] = await redis.zrangebyscore(
self._get_time_index_key(),
"-inf",
max_score,
)
for key in time_keys:
schedules.extend(await redis.lrange(key, 0, -1)) # type: ignore[misc]

Expand All @@ -153,6 +230,7 @@ async def delete_schedule(self, schedule_id: str) -> None:
elif schedule.time is not None:
time_key = self._get_time_key(schedule.time)
await redis.lrem(time_key, 0, schedule_id) # type: ignore[misc]
await self._maybe_cleanup_time_index(redis)
elif schedule.interval:
await redis.lrem(self._get_interval_key(), 0, schedule_id) # type: ignore[misc]

Expand All @@ -170,9 +248,21 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None:
if schedule.cron is not None:
await redis.rpush(self._get_cron_key(), schedule.schedule_id) # type: ignore[misc]
elif schedule.time is not None:
await redis.rpush( # type: ignore[misc]
self._get_time_key(schedule.time),
schedule.schedule_id,
time_key = self._get_time_key(schedule.time)
await redis.rpush(time_key, schedule.schedule_id) # type: ignore[misc]
# Add to the time index sorted set so we can look up
# past time keys without scanning all Redis keys.
time_val = schedule.time
if time_val.tzinfo is None:
time_val = time_val.replace(tzinfo=datetime.timezone.utc)
score = (
time_val.astimezone(datetime.timezone.utc)
.replace(second=0, microsecond=0)
.timestamp()
)
await redis.zadd( # type: ignore[misc]
self._get_time_index_key(),
{time_key: score},
)
elif schedule.interval:
await redis.rpush( # type: ignore[misc]
Expand Down Expand Up @@ -200,8 +290,8 @@ async def get_schedules(self) -> list["ScheduledTask"]:
current_time = datetime.datetime.now(datetime.timezone.utc)
timed: list[bytes] = []
# Only during first run, we need to get previous time schedules
if not self._skip_past_schedules:
timed = await self._get_previous_time_schedules()
if not self._skip_past_schedules and self._is_first_run:
timed = await self._get_previous_time_schedules(current_time)
self._is_first_run = False
async with Redis(connection_pool=self._connection_pool) as redis:
buffer = []
Expand Down
Loading