-
Notifications
You must be signed in to change notification settings - Fork 823
feat(a2a): implement full A2A task lifecycle state support #2245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7ce5a09
ec92858
451cae9
2c8539f
a66e210
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| streamed requests to the A2AServer. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import base64 | ||
| import json | ||
| import logging | ||
|
|
@@ -42,7 +43,9 @@ class StrandsA2AExecutor(AgentExecutor): | |
| """Executor that adapts a Strands Agent to the A2A protocol. | ||
|
|
||
| This executor uses streaming mode to handle the execution of agent requests | ||
| and converts Strands Agent responses to A2A protocol events. | ||
| and converts Strands Agent responses to A2A protocol events. It supports the | ||
| full A2A task lifecycle including error handling (failed state), cancellation, | ||
| and interrupt-based input_required flows. | ||
| """ | ||
|
|
||
| # Default formats for each file type when MIME type is unavailable or unrecognized | ||
|
|
@@ -75,14 +78,18 @@ async def execute( | |
| """Execute a request using the Strands Agent and send the response as A2A events. | ||
|
|
||
| This method executes the user's input using the Strands Agent in streaming mode | ||
| and converts the agent's response to A2A events. | ||
| and converts the agent's response to A2A events. If the agent raises an exception, | ||
| the task transitions to the `failed` state. If the agent returns with interrupts, | ||
| the task transitions to the `input_required` state. | ||
|
|
||
| Args: | ||
| context: The A2A request context, containing the user's input and task metadata. | ||
| event_queue: The A2A event queue used to send response events back to the client. | ||
|
|
||
| Raises: | ||
| ServerError: If an error occurs during agent execution | ||
| ServerError: If an unrecoverable error occurs during agent execution setup | ||
| (e.g., missing input). Agent execution errors are handled gracefully | ||
| by transitioning the task to the failed state. | ||
| """ | ||
| task = context.current_task | ||
| if not task: | ||
|
|
@@ -93,8 +100,34 @@ async def execute( | |
|
|
||
| try: | ||
| await self._execute_streaming(context, updater) | ||
| except Exception as e: | ||
| raise ServerError(error=InternalError()) from e | ||
| except ServerError: | ||
| # Re-raise ServerErrors (setup failures like missing input) | ||
| raise | ||
| except asyncio.CancelledError: | ||
| # asyncio.CancelledError is a BaseException (not Exception) — raised when | ||
| # the asyncio task is cancelled (e.g., HTTP client disconnect, server shutdown). | ||
| # We transition to canceled state so the task doesn't remain a zombie in "working". | ||
| logger.warning("task_id=<%s> | asyncio task cancelled, transitioning to canceled state", task.id) | ||
| try: | ||
| await updater.cancel( | ||
| message=updater.new_agent_message( | ||
| parts=[Part(root=TextPart(text="Task cancelled due to connection termination"))] | ||
| ) | ||
| ) | ||
| except RuntimeError: | ||
| # Task already in terminal state | ||
| logger.debug("task_id=<%s> | task already in terminal state, cannot transition to canceled", task.id) | ||
| raise | ||
| except Exception: | ||
| # Agent execution failures transition to failed state | ||
| logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id) | ||
| try: | ||
| await updater.failed( | ||
| message=updater.new_agent_message(parts=[Part(root=TextPart(text="Agent execution failed"))]) | ||
| ) | ||
| except RuntimeError: | ||
| # Task already in terminal state (e.g., completed before error in cleanup) | ||
| logger.debug("task_id=<%s> | task already in terminal state, cannot transition to failed", task.id) | ||
|
|
||
| async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: | ||
| """Execute request in streaming mode. | ||
|
|
@@ -105,14 +138,19 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater | |
| Args: | ||
| context: The A2A request context, containing the user's input and other metadata. | ||
| updater: The task updater for managing task state and sending updates. | ||
|
|
||
| Raises: | ||
| ServerError: If input conversion fails (missing or empty content). | ||
| """ | ||
| # Convert A2A message parts to Strands ContentBlocks | ||
| if context.message and hasattr(context.message, "parts"): | ||
| content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) | ||
| if not content_blocks: | ||
| raise ValueError("No content blocks available") | ||
| raise ServerError( | ||
| error=InternalError(message="No valid content found in request message parts") | ||
| ) from None | ||
| else: | ||
| raise ValueError("No content blocks available") | ||
| raise ServerError(error=InternalError(message="Request message is missing or has no parts")) from None | ||
|
|
||
| if not self.enable_a2a_compliant_streaming: | ||
| warnings.warn( | ||
|
|
@@ -133,8 +171,20 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater | |
| invocation_state: dict[str, Any] = {"a2a_request_context": context} | ||
|
|
||
| try: | ||
| result: SAAgentResult | None = None | ||
| async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): | ||
| await self._handle_streaming_event(event, updater) | ||
| if "result" in event: | ||
| result = event["result"] | ||
| else: | ||
| await self._handle_streaming_event(event, updater) | ||
|
|
||
| # Check if agent returned with interrupts (input_required) | ||
| # Note: stop_reason="interrupt" is the authoritative signal. Even if interrupts | ||
| # list is empty (edge case), the agent still indicated it needs input. | ||
| if result is not None and result.stop_reason == "interrupt": | ||
| await self._handle_interrupt_result(result, updater) | ||
| else: | ||
| await self._handle_agent_result(result, updater) | ||
| except Exception: | ||
| logger.exception("Error in streaming execution") | ||
| raise | ||
|
|
@@ -143,6 +193,34 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater | |
| self._current_artifact_id = None | ||
| self._is_first_chunk = True | ||
|
|
||
| async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpdater) -> None: | ||
| """Handle an agent result that contains interrupts. | ||
|
|
||
| When the Strands Agent returns with stop_reason="interrupt", this maps to | ||
| the A2A `input_required` state. The interrupt details are communicated to | ||
| the client via the status message. | ||
|
|
||
| Args: | ||
| result: The agent result containing interrupts. | ||
| updater: The task updater for managing task state. | ||
| """ | ||
| # Build a descriptive message about what input is needed | ||
| interrupt_descriptions = [] | ||
| for interrupt in result.interrupts or []: | ||
| desc = f"- {interrupt.name}" | ||
| if interrupt.reason: | ||
| desc += f": {interrupt.reason}" | ||
| interrupt_descriptions.append(desc) | ||
|
|
||
| if interrupt_descriptions: | ||
| input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) | ||
| else: | ||
| # Edge case: stop_reason="interrupt" but no interrupt details provided. | ||
| # Still transition to input_required — the agent signaled it needs input. | ||
| input_message = "Agent requires additional input to continue" | ||
|
|
||
| await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))])) | ||
|
|
||
| async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: | ||
| """Handle a single streaming event from the Strands Agent. | ||
|
|
||
|
|
@@ -175,8 +253,6 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda | |
| updater.task_id, | ||
| ), | ||
| ) | ||
| elif "result" in event: | ||
| await self._handle_agent_result(event["result"], updater) | ||
|
|
||
| async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: | ||
| """Handle the final result from the Strands Agent. | ||
|
|
@@ -219,20 +295,45 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task | |
| async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: | ||
| """Cancel an ongoing execution. | ||
|
|
||
| This method is called when a request cancellation is requested. Currently, | ||
| cancellation is not supported by the Strands Agent executor, so this method | ||
| always raises an UnsupportedOperationError. | ||
| Transitions the task to the canceled state and attempts to stop the agent. | ||
| The agent's cancel() method is called if available to signal cooperative | ||
| cancellation of in-flight execution. | ||
|
|
||
| Note: This transitions the A2A task state. The underlying agent execution | ||
| may still complete its current model call before stopping. | ||
|
|
||
| Args: | ||
| context: The A2A request context. | ||
| event_queue: The A2A event queue. | ||
|
|
||
| Raises: | ||
| ServerError: Always raised with an UnsupportedOperationError, as cancellation | ||
| is not currently supported. | ||
| ServerError: If no current task exists or the task is already in a terminal state. | ||
| """ | ||
| logger.warning("Cancellation requested but not supported") | ||
| raise ServerError(error=UnsupportedOperationError()) | ||
| task = context.current_task | ||
| if not task: | ||
| logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) | ||
| raise ServerError(error=UnsupportedOperationError()) from None | ||
|
|
||
| # Attempt to cooperatively cancel the agent's execution (best-effort). | ||
| # Agent.cancel() may not exist on all implementations, so we guard with hasattr. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: The comment says "we guard with hasattr" but the implementation actually uses try/except on Suggestion: Update the comment to match the implementation: # Attempt to cooperatively cancel the agent's execution (best-effort).
# Catches AttributeError/NotImplementedError if agent doesn't support cancel(). |
||
| try: | ||
| self.agent.cancel() | ||
| except (AttributeError, NotImplementedError): | ||
| # Agent doesn't support cancel — proceed with state transition only | ||
| pass | ||
| except Exception: | ||
| logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) | ||
|
|
||
| updater = TaskUpdater(event_queue, task.id, task.context_id) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agent has
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch — fixed in the latest push. # Attempt to stop the agent if it supports cancellation
if hasattr(self.agent, "cancel") and callable(self.agent.cancel):
try:
self.agent.cancel()
except Exception:
logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id)This is best-effort — if Docstring updated to reflect this accurately.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: isn't the type of self.agent Agent? why do we need ducktyping checks? |
||
|
|
||
| try: | ||
| await updater.cancel( | ||
| message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))]) | ||
| ) | ||
| except RuntimeError: | ||
| # TaskUpdater raises RuntimeError when task is already in a terminal state | ||
| logger.warning("task_id=<%s> | cannot cancel, already in terminal state", task.id) | ||
| raise ServerError(error=UnsupportedOperationError()) from None | ||
|
|
||
| def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: | ||
| """Classify file type based on MIME type. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.