async def _background_update_topic_summary(
    context: ResponseGeneratorContext,
    model: str,
) -> None:
    """Generate a topic summary and store it for an interrupted conversation.

    Intended to run as a fire-and-forget task after an interrupted turn has
    been persisted. Generation and persistence are guarded separately so the
    log tells which step failed. Exceptions derived from ``Exception`` are
    caught and logged; ``asyncio.CancelledError`` is not caught here, so the
    task can still be cancelled (see
    ``shutdown_background_topic_summary_tasks``).

    Parameters:
        context: The response generator context. Supplies the query text,
            the client, the conversation ID, the user ID, the
            ``skip_userid_check`` flag and the request ID used in log lines.
        model: Model identifier forwarded to ``get_topic_summary``.
    """
    # Step 1: generation. Only the LLM call is bounded by the timeout, and
    # only its failures are reported as generation failures.
    try:
        topic_summary = await asyncio.wait_for(
            get_topic_summary(
                context.query_request.query,
                context.client,
                model,
            ),
            timeout=TOPIC_SUMMARY_INTERRUPT_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        logger.warning(
            "Topic summary timed out for interrupted turn, request %s",
            context.request_id,
        )
        return
    except Exception:  # pylint: disable=broad-except
        logger.exception(
            "Failed to generate topic summary for interrupted turn, request %s",
            context.request_id,
        )
        return

    # Nothing to store when the model produced no summary.
    if not topic_summary:
        return

    # Step 2: persistence. Kept in its own try so a storage error (including
    # a TimeoutError raised by the storage layer) is not mislabelled as a
    # generation failure or generation timeout.
    # NOTE(review): update_conversation_topic_summary is a synchronous call
    # made from a coroutine; confirm that blocking the event loop for the
    # duration of the write is acceptable here.
    try:
        update_conversation_topic_summary(
            context.conversation_id,
            topic_summary,
            user_id=context.user_id,
            skip_userid_check=context.skip_userid_check,
        )
    except Exception:  # pylint: disable=broad-except
        logger.exception(
            "Failed to store topic summary for interrupted turn, request %s",
            context.request_id,
        )


async def shutdown_background_topic_summary_tasks() -> None:
    """Cancel and await outstanding background topic summary tasks on shutdown.

    Ensures graceful shutdown so in-flight topic summary generation can be
    cleaned up. Called from the application lifespan shutdown phase.

    Every task currently present in the module-level registry is cancelled
    and then awaited, so this coroutine returns only after all of them have
    finished. Returns immediately when the registry is empty.
    """
    # Work on a snapshot: each task's done-callback removes it from the
    # registry, so the registry itself shrinks while the tasks are awaited.
    tasks = list(_background_topic_summary_tasks)
    if not tasks:
        return
    logger.debug(
        "Shutting down %d outstanding background topic summary task(s)",
        len(tasks),
    )
    for task in tasks:
        task.cancel()
    # return_exceptions=True collects each task's CancelledError (or other
    # outcome) as a result instead of raising it out of the shutdown path.
    await asyncio.gather(*tasks, return_exceptions=True)
@@ -395,27 +456,6 @@ async def _persist_interrupted_turn( ) try: - topic_summary = None - if not context.query_request.conversation_id: - should_generate = context.query_request.generate_topic_summary - if should_generate: - try: - logger.debug( - "Generating topic summary for interrupted new conversation" - ) - topic_summary = await get_topic_summary( - context.query_request.query, - context.client, - responses_params.model, - ) - except Exception as e: # pylint: disable=broad-except - logger.warning( - "Failed to generate topic summary for interrupted turn, " - "request %s: %s", - context.request_id, - e, - ) - completed_at = datetime.datetime.now(datetime.UTC).strftime( "%Y-%m-%dT%H:%M:%SZ" ) @@ -428,8 +468,21 @@ async def _persist_interrupted_turn( summary=turn_summary, query=context.query_request.query, skip_userid_check=context.skip_userid_check, - topic_summary=topic_summary, + topic_summary=None, ) + + if ( + not context.query_request.conversation_id + and context.query_request.generate_topic_summary + ): + task = asyncio.create_task( + _background_update_topic_summary( + context=context, + model=responses_params.model, + ) + ) + _background_topic_summary_tasks.append(task) + task.add_done_callback(_background_topic_summary_tasks.remove) except Exception: # pylint: disable=broad-except logger.exception( "Failed to store interrupted query results for request %s", @@ -444,8 +497,8 @@ def _register_interrupt_callback( ) -> list[bool]: """Build an interrupt callback and register the stream for cancellation. - The callback is scheduled as a **separate** asyncio task by - ``cancel_stream`` so it executes regardless of where the + The callback is invoked by ``cancel_stream`` when the client + interrupts, so persistence runs regardless of where the ``CancelledError`` is raised in the ASGI stack. 
def update_conversation_topic_summary(
    conversation_id: str,
    topic_summary: str,
    user_id: Optional[str] = None,
    skip_userid_check: bool = False,
) -> None:
    """Update topic_summary for an existing conversation in DB and optionally cache.

    The conversation row is looked up by its normalized ID. When no row
    matches, nothing is written at all: the cache is deliberately left
    untouched so it cannot hold a topic summary for a conversation the
    database does not know about (for example one that was deleted while the
    summary was still being generated).

    Args:
        conversation_id: The conversation ID (normalized or with conv_ prefix).
        topic_summary: The topic summary to store.
        user_id: Optional user ID for cache update; when provided with cache
            configured, also updates the conversation cache.
        skip_userid_check: Whether to skip user ID validation for cache operations.
    """
    normalized_id = normalize_conversation_id(conversation_id)
    with get_session() as session:
        existing = session.query(UserConversation).filter_by(id=normalized_id).first()
        if existing is None:
            logger.debug(
                "No conversation found for topic summary update: id=%s, "
                "topic_summary_len=%d",
                normalized_id,
                len(topic_summary),
            )
            # Do not fall through to the cache update; that would let the
            # cache diverge from the database.
            return
        existing.topic_summary = topic_summary
        session.commit()
        logger.debug("Updated topic summary for conversation %s", normalized_id)

    # The cache is only touched when a user ID is known and a cache is both
    # configured and available.
    if (
        user_id
        and configuration.conversation_cache_configuration.type is not None
        and configuration.conversation_cache is not None
    ):
        try:
            # NOTE(review): the cache receives the caller-supplied
            # conversation_id, not normalized_id; confirm this matches the
            # key convention used when the conversation was cached.
            configuration.conversation_cache.set_topic_summary(
                user_id, conversation_id, topic_summary, skip_userid_check
            )
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: the database already holds the summary, so a cache
            # failure is logged rather than raised.
            logger.warning(
                "Failed to update topic summary in cache for %s: %s",
                normalized_id,
                e,
            )
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 99dee264e..643a34477 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1497,6 +1497,9 @@ async def mock_generator() -> AsyncIterator[str]: store_query_results_mock = mocker.patch( "app.endpoints.streaming_query.store_query_results" ) + update_topic_summary_mock = mocker.patch( + "app.endpoints.streaming_query.update_conversation_topic_summary" + ) mocker.patch( "app.endpoints.streaming_query.append_turn_to_conversation", new_callable=mocker.AsyncMock, @@ -1511,14 +1514,22 @@ async def mock_generator() -> AsyncIterator[str]: ): result.append(item) + await asyncio.sleep(0.1) + assert any('"event": "interrupted"' in item for item in result) + call_kwargs = store_query_results_mock.call_args[1] + assert call_kwargs["topic_summary"] is None get_topic_summary_mock.assert_called_once_with( "What is Kubernetes?", mock_context.client, "provider1/model1", ) - call_kwargs = store_query_results_mock.call_args[1] - assert call_kwargs["topic_summary"] == "Kubernetes container orchestration" + update_topic_summary_mock.assert_called_once_with( + "conv_new_456", + "Kubernetes container orchestration", + user_id="user_123", + skip_userid_check=False, + ) isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( test_request_id ) @@ -1580,6 +1591,8 @@ async def mock_generator() -> AsyncIterator[str]: ): result.append(item) + await asyncio.sleep(0.1) + assert any('"event": "interrupted"' in item for item in result) store_query_results_mock.assert_called_once() call_kwargs = store_query_results_mock.call_args[1]