Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 79 additions & 26 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from configuration import configuration
from constants import (
INTERRUPTED_RESPONSE_MESSAGE,
TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
LLM_TOKEN_EVENT,
LLM_TOOL_CALL_EVENT,
LLM_TOOL_RESULT_EVENT,
Expand Down Expand Up @@ -87,6 +88,7 @@
prepare_input,
store_query_results,
update_azure_token,
update_conversation_topic_summary,
validate_attachments_metadata,
validate_model_provider_override,
)
Expand Down Expand Up @@ -116,6 +118,9 @@
logger = get_logger(__name__)
router = APIRouter(tags=["streaming_query"])

# Module-level registry of fire-and-forget topic summary tasks spawned after an
# interrupted turn is persisted. Drained and cancelled on application shutdown
# by shutdown_background_topic_summary_tasks() so no task is abandoned mid-flight.
_background_topic_summary_tasks: list[asyncio.Task[None]] = []

streaming_query_responses: dict[int | str, dict[str, Any]] = {
200: StreamingQueryResponse.openapi_response(),
401: UnauthorizedResponse.openapi_response(
Expand Down Expand Up @@ -364,6 +369,61 @@ async def retrieve_response_generator(
raise HTTPException(**error_response.model_dump()) from e


async def _background_update_topic_summary(
    context: ResponseGeneratorContext,
    model: str,
) -> None:
    """Generate a topic summary for an interrupted turn and persist it.

    Runs as a fire-and-forget background task after the interrupted turn has
    already been stored. Every failure mode (timeout or any other exception)
    is caught and logged here so nothing propagates out of the task.

    Parameters:
        context: The response generator context of the interrupted request.
        model: The model identifier to use for summary generation.
    """
    try:
        # Bound the LLM call so a stuck summary request cannot linger forever.
        summary_coro = get_topic_summary(
            context.query_request.query,
            context.client,
            model,
        )
        topic_summary = await asyncio.wait_for(
            summary_coro,
            timeout=TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
        )
        # An empty/None summary means there is nothing worth persisting.
        if not topic_summary:
            return
        update_conversation_topic_summary(
            context.conversation_id,
            topic_summary,
            user_id=context.user_id,
            skip_userid_check=context.skip_userid_check,
        )
    except asyncio.TimeoutError:
        logger.warning(
            "Topic summary timed out for interrupted turn, request %s",
            context.request_id,
        )
    except Exception:  # pylint: disable=broad-except
        logger.exception(
            "Failed to generate topic summary for interrupted turn, request %s",
            context.request_id,
        )


async def shutdown_background_topic_summary_tasks() -> None:
    """Cancel and await any outstanding background topic summary tasks.

    Invoked from the application lifespan shutdown phase so that in-flight
    topic summary generation is cancelled and awaited instead of being
    abandoned when the process exits.
    """
    # Snapshot the registry: done-callbacks remove entries from the shared
    # list, so iterate over a private copy.
    pending = _background_topic_summary_tasks[:]
    if not pending:
        return
    logger.debug(
        "Shutting down %d outstanding background topic summary task(s)",
        len(pending),
    )
    for pending_task in pending:
        pending_task.cancel()
    # return_exceptions=True swallows the CancelledError from each task.
    await asyncio.gather(*pending, return_exceptions=True)


async def _persist_interrupted_turn(
context: ResponseGeneratorContext,
responses_params: ResponsesApiParams,
Expand All @@ -372,8 +432,9 @@ async def _persist_interrupted_turn(
"""Persist the user query and an interrupted response into the conversation.

Called when a streaming request is cancelled so the exchange is not lost.
All errors are caught and logged to avoid masking the original
cancellation.
Persists immediately with topic_summary=None so the conversation exists
when the client fetches. Topic summary is generated in a background task
and updated when ready.

Parameters:
context: The response generator context.
Expand All @@ -395,27 +456,6 @@ async def _persist_interrupted_turn(
)

try:
topic_summary = None
if not context.query_request.conversation_id:
should_generate = context.query_request.generate_topic_summary
if should_generate:
try:
logger.debug(
"Generating topic summary for interrupted new conversation"
)
topic_summary = await get_topic_summary(
context.query_request.query,
context.client,
responses_params.model,
)
except Exception as e: # pylint: disable=broad-except
logger.warning(
"Failed to generate topic summary for interrupted turn, "
"request %s: %s",
context.request_id,
e,
)

completed_at = datetime.datetime.now(datetime.UTC).strftime(
"%Y-%m-%dT%H:%M:%SZ"
)
Expand All @@ -428,8 +468,21 @@ async def _persist_interrupted_turn(
summary=turn_summary,
query=context.query_request.query,
skip_userid_check=context.skip_userid_check,
topic_summary=topic_summary,
topic_summary=None,
)

if (
not context.query_request.conversation_id
and context.query_request.generate_topic_summary
):
task = asyncio.create_task(
_background_update_topic_summary(
context=context,
model=responses_params.model,
)
)
_background_topic_summary_tasks.append(task)
task.add_done_callback(_background_topic_summary_tasks.remove)
except Exception: # pylint: disable=broad-except
logger.exception(
"Failed to store interrupted query results for request %s",
Expand All @@ -444,8 +497,8 @@ def _register_interrupt_callback(
) -> list[bool]:
"""Build an interrupt callback and register the stream for cancellation.

The callback is scheduled as a **separate** asyncio task by
``cancel_stream`` so it executes regardless of where the
The callback is invoked by ``cancel_stream`` when the client
interrupts, so persistence runs regardless of where the
``CancelledError`` is raised in the ASGI stack.

A mutable one-element list is used as a shared guard so the
Expand Down
2 changes: 2 additions & 0 deletions src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import version
from a2a_storage import A2AStorageFactory
from app import routers
from app.endpoints.streaming_query import shutdown_background_topic_summary_tasks
from app.database import create_tables, initialize_database
from authorization.azure_token_manager import AzureEntraIDManager
from client import AsyncLlamaStackClientHolder
Expand Down Expand Up @@ -80,6 +81,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
yield

# Cleanup resources on shutdown
await shutdown_background_topic_summary_tasks()
await A2AStorageFactory.cleanup()
logger.info("App shutdown complete")

Expand Down
3 changes: 3 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
# Response stored in the conversation when the user interrupts a streaming request
INTERRUPTED_RESPONSE_MESSAGE = "You interrupted this request."

# Max seconds to wait for topic summary in background task after interrupt persist.
TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS = 30.0

# Supported attachment types
ATTACHMENT_TYPES = frozenset(
{
Expand Down
47 changes: 47 additions & 0 deletions src/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,53 @@ def persist_user_conversation_details(
)


def update_conversation_topic_summary(
    conversation_id: str,
    topic_summary: str,
    user_id: Optional[str] = None,
    skip_userid_check: bool = False,
) -> None:
    """Store a freshly generated topic summary for an existing conversation.

    The database row is updated first; when a user ID is supplied and a
    conversation cache is configured, the cache entry is refreshed as well.
    Cache failures are logged and swallowed so they never break the caller.

    Args:
        conversation_id: The conversation ID (normalized or with conv_ prefix).
        topic_summary: The topic summary to store.
        user_id: Optional user ID; when provided together with a configured
            cache, the conversation cache is also updated.
        skip_userid_check: Whether to skip user ID validation for cache
            operations.
    """
    normalized_id = normalize_conversation_id(conversation_id)
    with get_session() as session:
        conversation = (
            session.query(UserConversation).filter_by(id=normalized_id).first()
        )
        if conversation is None:
            # Nothing to update — log enough context to diagnose the miss.
            logger.debug(
                "No conversation found for topic summary update: id=%s, "
                "topic_summary_len=%d",
                normalized_id,
                len(topic_summary),
            )
        else:
            conversation.topic_summary = topic_summary
            session.commit()
            logger.debug("Updated topic summary for conversation %s", normalized_id)

    if (
        user_id
        and configuration.conversation_cache_configuration.type is not None
        and configuration.conversation_cache is not None
    ):
        try:
            # NOTE(review): the cache receives the caller-supplied
            # conversation_id while the DB row is keyed by the normalized
            # form — presumably the cache expects the original id; confirm.
            configuration.conversation_cache.set_topic_summary(
                user_id, conversation_id, topic_summary, skip_userid_check
            )
        except Exception as e:  # pylint: disable=broad-except
            logger.warning(
                "Failed to update topic summary in cache for %s: %s",
                normalized_id,
                e,
            )


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
"""Validate the attachments metadata provided in the request.

Expand Down
17 changes: 15 additions & 2 deletions tests/unit/app/endpoints/test_streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,9 @@ async def mock_generator() -> AsyncIterator[str]:
store_query_results_mock = mocker.patch(
"app.endpoints.streaming_query.store_query_results"
)
update_topic_summary_mock = mocker.patch(
"app.endpoints.streaming_query.update_conversation_topic_summary"
)
mocker.patch(
"app.endpoints.streaming_query.append_turn_to_conversation",
new_callable=mocker.AsyncMock,
Expand All @@ -1511,14 +1514,22 @@ async def mock_generator() -> AsyncIterator[str]:
):
result.append(item)

await asyncio.sleep(0.1)

assert any('"event": "interrupted"' in item for item in result)
call_kwargs = store_query_results_mock.call_args[1]
assert call_kwargs["topic_summary"] is None
get_topic_summary_mock.assert_called_once_with(
"What is Kubernetes?",
mock_context.client,
"provider1/model1",
)
call_kwargs = store_query_results_mock.call_args[1]
assert call_kwargs["topic_summary"] == "Kubernetes container orchestration"
update_topic_summary_mock.assert_called_once_with(
"conv_new_456",
"Kubernetes container orchestration",
user_id="user_123",
skip_userid_check=False,
)
isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with(
test_request_id
)
Expand Down Expand Up @@ -1580,6 +1591,8 @@ async def mock_generator() -> AsyncIterator[str]:
):
result.append(item)

await asyncio.sleep(0.1)

assert any('"event": "interrupted"' in item for item in result)
store_query_results_mock.assert_called_once()
call_kwargs = store_query_results_mock.call_args[1]
Expand Down
Loading