Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/strands/agent/a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

import httpx
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent
from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskStatusUpdateEvent

from .._async import run_async
from ..multiagent.a2a._converters import convert_input_to_message, convert_response_to_agent_result
from ..multiagent.a2a._converters import (
_STATE_TO_STOP_REASON,
convert_input_to_message,
convert_response_to_agent_result,
)
from ..types._events import AgentResultEvent
from ..types.a2a import A2AResponse, A2AStreamEvent
from ..types.agent import AgentInput
Expand All @@ -29,6 +33,13 @@

_DEFAULT_TIMEOUT = 300

# A2A task states that indicate the response stream is complete.
# Derived from the canonical _STATE_TO_STOP_REASON mapping in _converters.
# Terminal states (end_turn) mean no more events; input states (interrupt) mean execution is paused.
_TERMINAL_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "end_turn"}
_INPUT_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "interrupt"}
_COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES


class A2AAgent(AgentBase):
"""Client wrapper for remote A2A agents."""
Expand Down Expand Up @@ -265,6 +276,9 @@ async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]:
def _is_complete_event(self, event: A2AResponse) -> bool:
"""Check if an A2A event represents a complete response.

Recognizes all terminal states (completed, failed, canceled, rejected)
and pausing states (input_required, auth_required) as complete events.

Args:
event: A2A event.

Expand All @@ -289,9 +303,10 @@ def _is_complete_event(self, event: A2AResponse) -> bool:
return update_event.last_chunk
return False

# Status update with completed state
# Status update - check for terminal or pausing states
if isinstance(update_event, TaskStatusUpdateEvent):
if update_event.status and hasattr(update_event.status, "state"):
return update_event.status.state == TaskState.completed
state = update_event.status.state
return state in _COMPLETE_STATES

return False
60 changes: 54 additions & 6 deletions src/strands/multiagent/a2a/_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@
from uuid import uuid4

from a2a.types import Message as A2AMessage
from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart
from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart

from ...agent.agent_result import AgentResult
from ...telemetry.metrics import EventLoopMetrics
from ...types.a2a import A2AResponse
from ...types.agent import AgentInput
from ...types.content import ContentBlock, Message
from ...types.event_loop import StopReason

# Mapping from A2A TaskState to Strands stop_reason
_STATE_TO_STOP_REASON: dict[TaskState, StopReason] = {
TaskState.completed: "end_turn",
TaskState.failed: "end_turn",
TaskState.canceled: "end_turn",
TaskState.rejected: "end_turn",
TaskState.input_required: "interrupt",
TaskState.auth_required: "interrupt",
}


def convert_input_to_message(prompt: AgentInput) -> A2AMessage:
Expand Down Expand Up @@ -79,37 +90,69 @@ def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[
return parts


def _extract_task_state(response: A2AResponse) -> TaskState | None:
"""Extract the task state from an A2A response.

Args:
response: A2A response (either A2AMessage or tuple of task and update event).

Returns:
The TaskState if available, None otherwise.
"""
if isinstance(response, tuple) and len(response) == 2:
_task, update_event = response
if isinstance(update_event, TaskStatusUpdateEvent):
if update_event.status and hasattr(update_event.status, "state"):
return update_event.status.state
return None


def convert_response_to_agent_result(response: A2AResponse) -> AgentResult:
"""Convert A2A response to AgentResult.

Maps A2A task lifecycle states to appropriate Strands stop_reasons:
- completed → end_turn
- failed → end_turn (with error content)
- canceled → end_turn (with cancellation info)
- rejected → end_turn (with rejection info)
- input_required → interrupt (agent needs user input)
- auth_required → interrupt (agent needs authentication)

Args:
response: A2A response (either A2AMessage or tuple of task and update event).

Returns:
AgentResult with extracted content and metadata.
"""
content: list[ContentBlock] = []
task_state = _extract_task_state(response)
stop_reason: StopReason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn"

if isinstance(response, tuple) and len(response) == 2:
task, update_event = response

# Handle artifact updates
if isinstance(update_event, TaskArtifactUpdateEvent):
if update_event.artifact and hasattr(update_event.artifact, "parts"):
if update_event.artifact and hasattr(update_event.artifact, "parts") and update_event.artifact.parts:
for part in update_event.artifact.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
# Handle status updates with messages
elif isinstance(update_event, TaskStatusUpdateEvent):
if update_event.status and hasattr(update_event.status, "message") and update_event.status.message:
if (
update_event.status
and hasattr(update_event.status, "message")
and update_event.status.message
and update_event.status.message.parts
):
for part in update_event.status.message.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})

# Use task.artifacts when no content was extracted from the event
if not content and task and hasattr(task, "artifacts") and task.artifacts is not None:
for artifact in task.artifacts:
if hasattr(artifact, "parts"):
if hasattr(artifact, "parts") and artifact.parts:
for part in artifact.parts:
if hasattr(part, "root") and hasattr(part.root, "text"):
content.append({"text": part.root.text})
Expand All @@ -123,9 +166,14 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult:
"content": content,
}

# Build state dict with A2A metadata
state: dict[str, str] = {}
if task_state is not None:
state["a2a_task_state"] = task_state.value

return AgentResult(
stop_reason="end_turn",
stop_reason=stop_reason,
message=message,
metrics=EventLoopMetrics(),
state={},
state=state,
)
135 changes: 118 additions & 17 deletions src/strands/multiagent/a2a/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
streamed requests to the A2AServer.
"""

import asyncio
import base64
import json
import logging
Expand Down Expand Up @@ -42,7 +43,9 @@ class StrandsA2AExecutor(AgentExecutor):
"""Executor that adapts a Strands Agent to the A2A protocol.

This executor uses streaming mode to handle the execution of agent requests
and converts Strands Agent responses to A2A protocol events.
and converts Strands Agent responses to A2A protocol events. It supports the
full A2A task lifecycle including error handling (failed state), cancellation,
and interrupt-based input_required flows.
"""

# Default formats for each file type when MIME type is unavailable or unrecognized
Expand Down Expand Up @@ -75,14 +78,18 @@ async def execute(
"""Execute a request using the Strands Agent and send the response as A2A events.

This method executes the user's input using the Strands Agent in streaming mode
and converts the agent's response to A2A events.
and converts the agent's response to A2A events. If the agent raises an exception,
the task transitions to the `failed` state. If the agent returns with interrupts,
the task transitions to the `input_required` state.

Args:
context: The A2A request context, containing the user's input and task metadata.
event_queue: The A2A event queue used to send response events back to the client.

Raises:
ServerError: If an error occurs during agent execution
ServerError: If an unrecoverable error occurs during agent execution setup
(e.g., missing input). Agent execution errors are handled gracefully
by transitioning the task to the failed state.
"""
task = context.current_task
if not task:
Expand All @@ -93,8 +100,34 @@ async def execute(

try:
await self._execute_streaming(context, updater)
except Exception as e:
raise ServerError(error=InternalError()) from e
except ServerError:
# Re-raise ServerErrors (setup failures like missing input)
raise
except asyncio.CancelledError:
# asyncio.CancelledError is a BaseException (not Exception) — raised when
# the asyncio task is cancelled (e.g., HTTP client disconnect, server shutdown).
# We transition to canceled state so the task doesn't remain a zombie in "working".
logger.warning("task_id=<%s> | asyncio task cancelled, transitioning to canceled state", task.id)
try:
await updater.cancel(
message=updater.new_agent_message(
parts=[Part(root=TextPart(text="Task cancelled due to connection termination"))]
)
)
except RuntimeError:
# Task already in terminal state
logger.debug("task_id=<%s> | task already in terminal state, cannot transition to canceled", task.id)
raise
except Exception:
# Agent execution failures transition to failed state
logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id)
try:
await updater.failed(
message=updater.new_agent_message(parts=[Part(root=TextPart(text="Agent execution failed"))])
)
except RuntimeError:
Comment thread
mkmeral marked this conversation as resolved.
# Task already in terminal state (e.g., completed before error in cleanup)
logger.debug("task_id=<%s> | task already in terminal state, cannot transition to failed", task.id)

async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None:
"""Execute request in streaming mode.
Expand All @@ -105,14 +138,19 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
Args:
context: The A2A request context, containing the user's input and other metadata.
updater: The task updater for managing task state and sending updates.

Raises:
ServerError: If input conversion fails (missing or empty content).
"""
# Convert A2A message parts to Strands ContentBlocks
if context.message and hasattr(context.message, "parts"):
content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts)
if not content_blocks:
raise ValueError("No content blocks available")
raise ServerError(
error=InternalError(message="No valid content found in request message parts")
) from None
else:
raise ValueError("No content blocks available")
raise ServerError(error=InternalError(message="Request message is missing or has no parts")) from None

if not self.enable_a2a_compliant_streaming:
warnings.warn(
Expand All @@ -133,8 +171,20 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
invocation_state: dict[str, Any] = {"a2a_request_context": context}

try:
result: SAAgentResult | None = None
async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state):
await self._handle_streaming_event(event, updater)
if "result" in event:
result = event["result"]
else:
await self._handle_streaming_event(event, updater)

# Check if agent returned with interrupts (input_required)
# Note: stop_reason="interrupt" is the authoritative signal. Even if interrupts
# list is empty (edge case), the agent still indicated it needs input.
if result is not None and result.stop_reason == "interrupt":
await self._handle_interrupt_result(result, updater)
else:
await self._handle_agent_result(result, updater)
except Exception:
logger.exception("Error in streaming execution")
raise
Expand All @@ -143,6 +193,34 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
self._current_artifact_id = None
self._is_first_chunk = True

async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpdater) -> None:
"""Handle an agent result that contains interrupts.

When the Strands Agent returns with stop_reason="interrupt", this maps to
the A2A `input_required` state. The interrupt details are communicated to
the client via the status message.

Args:
result: The agent result containing interrupts.
updater: The task updater for managing task state.
"""
# Build a descriptive message about what input is needed
interrupt_descriptions = []
for interrupt in result.interrupts or []:
desc = f"- {interrupt.name}"
if interrupt.reason:
desc += f": {interrupt.reason}"
interrupt_descriptions.append(desc)

if interrupt_descriptions:
input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions)
else:
# Edge case: stop_reason="interrupt" but no interrupt details provided.
# Still transition to input_required — the agent signaled it needs input.
input_message = "Agent requires additional input to continue"

await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))]))

async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None:
"""Handle a single streaming event from the Strands Agent.

Expand Down Expand Up @@ -175,8 +253,6 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda
updater.task_id,
),
)
elif "result" in event:
await self._handle_agent_result(event["result"], updater)

async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None:
"""Handle the final result from the Strands Agent.
Expand Down Expand Up @@ -219,20 +295,45 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
"""Cancel an ongoing execution.

This method is called when a request cancellation is requested. Currently,
cancellation is not supported by the Strands Agent executor, so this method
always raises an UnsupportedOperationError.
Transitions the task to the canceled state and attempts to stop the agent.
The agent's cancel() method is called if available to signal cooperative
cancellation of in-flight execution.

Note: This transitions the A2A task state. The underlying agent execution
may still complete its current model call before stopping.

Args:
context: The A2A request context.
event_queue: The A2A event queue.

Raises:
ServerError: Always raised with an UnsupportedOperationError, as cancellation
is not currently supported.
ServerError: If no current task exists or the task is already in a terminal state.
"""
logger.warning("Cancellation requested but not supported")
raise ServerError(error=UnsupportedOperationError())
task = context.current_task
if not task:
logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id)
raise ServerError(error=UnsupportedOperationError()) from None

# Attempt to cooperatively cancel the agent's execution (best-effort).
# Agent.cancel() may not exist on all implementations, so we guard with hasattr.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: The comment says "we guard with hasattr" but the implementation actually uses try/except on AttributeError (which is the right approach per mkmeral's feedback). The comment is a leftover from the previous iteration.

Suggestion: Update the comment to match the implementation:

# Attempt to cooperatively cancel the agent's execution (best-effort).
# Catches AttributeError/NotImplementedError if agent doesn't support cancel().

try:
self.agent.cancel()
except (AttributeError, NotImplementedError):
# Agent doesn't support cancel — proceed with state transition only
pass
except Exception:
logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id)

updater = TaskUpdater(event_queue, task.id, task.context_id)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agent has cancel method. where do we actually use it? I don't see it. where do we cancel the current agent run?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — fixed in the latest push. cancel() now calls self.agent.cancel() if the method exists (cooperative cancellation):

# Attempt to stop the agent if it supports cancellation
if hasattr(self.agent, "cancel") and callable(self.agent.cancel):
    try:
        self.agent.cancel()
    except Exception:
        logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id)

This is best-effort — if agent.cancel() fails, we still transition the A2A task state to canceled. The agent's current model call may complete, but no new iterations will start.

Docstring updated to reflect this accurately.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: isn't the type of self.agent Agent? why do we need ducktyping checks?


try:
await updater.cancel(
message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))])
)
except RuntimeError:
# TaskUpdater raises RuntimeError when task is already in a terminal state
logger.warning("task_id=<%s> | cannot cancel, already in terminal state", task.id)
raise ServerError(error=UnsupportedOperationError()) from None

def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]:
"""Classify file type based on MIME type.
Expand Down
Loading
Loading