Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions src/agents/run_internal/oai_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,25 @@ def _has_output_payload(item: Any) -> bool:
return (isinstance(item, dict) and "output" in item) or hasattr(item, "output")


def _is_tracked_object(items: Sequence[Any], candidate: Any) -> bool:
"""Return True when the exact object instance is already tracked."""
return any(item is candidate for item in items)


def _track_object_once(items: list[Any], candidate: Any) -> None:
"""Track an object instance once, keeping it alive while identity dedupe is needed."""
if not _is_tracked_object(items, candidate):
items.append(candidate)


def _untrack_object(items: list[Any], candidate: Any) -> None:
"""Remove an object instance from an identity-tracking list."""
for index, item in enumerate(items):
if item is candidate:
items.pop(index)
return


@dataclass
class OpenAIServerConversationTracker:
"""Track server-side conversation state for conversation-aware runs.
Expand All @@ -113,9 +132,10 @@ class OpenAIServerConversationTracker:
previous_response_id: str | None = None
auto_previous_response_id: bool = False

# In-process object identity for items that have already been delivered or acknowledged.
sent_items: set[int] = field(default_factory=set)
server_items: set[int] = field(default_factory=set)
# In-process object identity for delivered or acknowledged items. Keep object references
# instead of id(obj) integers so a later allocation cannot reuse a stale address.
sent_items: list[Any] = field(default_factory=list)
server_items: list[Any] = field(default_factory=list)

# Stable provider identifiers returned by the Responses API.
server_item_ids: set[str] = field(default_factory=set)
Expand Down Expand Up @@ -200,7 +220,7 @@ def hydrate_from_state(
for output_item in response.output:
if output_item is None:
continue
self.server_items.add(id(output_item))
_track_object_once(self.server_items, output_item)
item_id = _normalize_server_item_id(
output_item.get("id")
if isinstance(output_item, dict)
Expand Down Expand Up @@ -263,8 +283,7 @@ def hydrate_from_state(
if not should_mark:
continue

raw_item_id = id(raw_item)
self.sent_items.add(raw_item_id)
_track_object_once(self.sent_items, raw_item)
fp = _fingerprint_for_tracker(raw_item)
if fp:
self.sent_item_fingerprints.add(fp)
Expand Down Expand Up @@ -297,7 +316,7 @@ def hydrate_from_state(
if not should_mark:
continue

self.sent_items.add(id(raw_item))
_track_object_once(self.sent_items, raw_item)
fp = _fingerprint_for_tracker(raw_item)
if fp:
self.sent_item_fingerprints.add(fp)
Expand All @@ -321,7 +340,7 @@ def track_server_items(self, model_response: ModelResponse | None) -> None:
for output_item in model_response.output:
if output_item is None:
continue
self.server_items.add(id(output_item))
_track_object_once(self.server_items, output_item)
item_id = _normalize_server_item_id(
output_item.get("id")
if isinstance(output_item, dict)
Expand Down Expand Up @@ -361,17 +380,16 @@ def mark_input_as_sent(self, items: Sequence[TResponseInputItem]) -> None:
if not items:
return

delivered_source_ids: set[int] = set()
delivered_sources: list[TResponseInputItem] = []
delivered_by_content: set[str] = set()
for item in items:
if item is None:
continue
source_item = self._consume_prepared_item_source(item)
source_item_id = id(source_item)
if source_item_id in delivered_source_ids:
if _is_tracked_object(delivered_sources, source_item):
continue
delivered_source_ids.add(source_item_id)
self.sent_items.add(source_item_id)
delivered_sources.append(source_item)
_track_object_once(self.sent_items, source_item)
fp = _fingerprint_for_tracker(source_item)
if fp:
delivered_by_content.add(fp)
Expand All @@ -382,7 +400,7 @@ def mark_input_as_sent(self, items: Sequence[TResponseInputItem]) -> None:

remaining: list[TResponseInputItem] = []
for pending in self.remaining_initial_input:
if id(pending) in delivered_source_ids:
if _is_tracked_object(delivered_sources, pending):
continue
pending_fp = _fingerprint_for_tracker(pending)
if pending_fp and pending_fp in delivered_by_content:
Expand All @@ -402,7 +420,7 @@ def rewind_input(self, items: Sequence[TResponseInputItem]) -> None:
continue
source_item = self._consume_prepared_item_source(item)
rewind_items.append(source_item)
self.sent_items.discard(id(source_item))
_untrack_object(self.sent_items, source_item)
fp = _fingerprint_for_tracker(source_item)
if fp:
self.sent_item_fingerprints.discard(fp)
Expand Down Expand Up @@ -469,8 +487,9 @@ def prepare_input(
):
continue

raw_item_id = id(raw_item)
if raw_item_id in self.sent_items or raw_item_id in self.server_items:
if _is_tracked_object(self.sent_items, raw_item) or _is_tracked_object(
self.server_items, raw_item
):
continue

converted_input_item = run_item_to_input_item(run_item, self.reasoning_item_id_policy)
Expand Down
73 changes: 67 additions & 6 deletions tests/test_server_conversation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def test_hydrate_from_state_does_not_track_string_initial_input_by_object_identi
model_responses=[],
)

assert tracker.sent_items == set()
assert tracker.sent_items == []
assert tracker.sent_initial_input is True
assert tracker.remaining_initial_input is None
assert len(tracker.sent_item_fingerprints) == 1
Expand All @@ -238,7 +238,7 @@ def test_hydrate_from_state_does_not_track_list_initial_input_by_object_identity
model_responses=[],
)

assert tracker.sent_items == set()
assert tracker.sent_items == []
assert tracker.sent_initial_input is True
assert tracker.remaining_initial_input is None
assert len(tracker.sent_item_fingerprints) == 1
Expand Down Expand Up @@ -278,8 +278,8 @@ def test_mark_input_as_sent_uses_raw_generated_source_for_rebuilt_filtered_item(

tracker.mark_input_as_sent([rebuilt_filtered_item])

assert id(raw_generated_item) in tracker.sent_items
assert id(rebuilt_filtered_item) not in tracker.sent_items
assert any(item is raw_generated_item for item in tracker.sent_items)
assert all(item is not rebuilt_filtered_item for item in tracker.sent_items)

prepared_again = tracker.prepare_input(
original_input=[],
Expand Down Expand Up @@ -821,8 +821,8 @@ def _filter_input(payload: Any) -> ModelInputData:
)

assert model.last_turn_args["input"] == [item_1]
assert id(item_1) in tracker.sent_items
assert id(item_2) not in tracker.sent_items
assert any(item is item_1 for item in tracker.sent_items)
assert all(item is not item_2 for item in tracker.sent_items)


@pytest.mark.asyncio
Expand Down Expand Up @@ -965,3 +965,64 @@ def _filter_input(payload: Any) -> ModelInputData:
assert len(tool_call_events) == 1
assert tool_call_events[0].description == "Search the docs."
assert tool_call_events[0].title == "Search Docs"


@pytest.mark.parametrize("stale_collection_name", ["sent_items", "server_items"])
def test_prepare_input_keeps_fresh_tool_output_when_stale_identity_matches(
    stale_collection_name: str,
) -> None:
    """A fresh tool output must survive prepare_input despite pre-existing tracked state.

    The test works against either representation of the tracking collection:
    for a set of ``id()`` integers it seeds the fresh item's own id (simulating
    a stale address match); for a list of references it seeds an unrelated,
    already-tracked object.
    """
    tracker = OpenAIServerConversationTracker(previous_response_id="resp-1")

    fresh_output: dict[str, Any] = {
        "type": "function_call_output",
        "call_id": "call_FRESH",
        "output": "42",
    }

    collection = getattr(tracker, stale_collection_name)
    if not isinstance(collection, set):
        # Reference-based tracking: an unrelated object is already tracked.
        collection.append({"type": "message", "content": "already tracked"})
    else:
        # Address-based tracking: pretend a stale id collides with the fresh item.
        collection.add(id(fresh_output))

    run_items = [DummyRunItem(fresh_output, type="function_call_output_item")]
    prepared = tracker.prepare_input(
        original_input=[],
        generated_items=cast(list[Any], run_items),
    )

    fresh_call_ids = []
    for entry in prepared:
        if isinstance(entry, dict) and entry.get("type") == "function_call_output":
            fresh_call_ids.append(entry.get("call_id"))

    assert "call_FRESH" in fresh_call_ids


def test_prepare_input_dedupes_same_delivered_tool_output_object() -> None:
    """Once the exact source object is marked as sent, prepare_input must skip it."""
    tracker = OpenAIServerConversationTracker(previous_response_id="resp-1")

    delivered_output: dict[str, Any] = {
        "type": "function_call_output",
        "call_id": "call_X",
        "output": "42",
    }
    run_items = cast(
        list[Any],
        [DummyRunItem(delivered_output, type="function_call_output_item")],
    )

    def _mentions_call_x(entries: Any) -> bool:
        # True when any prepared dict entry carries the call id under test.
        return any(
            isinstance(entry, dict) and entry.get("call_id") == "call_X" for entry in entries
        )

    # First pass: nothing has been delivered yet, so the output is included.
    first = tracker.prepare_input(original_input=[], generated_items=run_items)
    assert _mentions_call_x(first)

    # Delivery is recorded against the raw source object itself.
    tracker.mark_input_as_sent(first)
    assert any(tracked is delivered_output for tracked in tracker.sent_items)

    # Second pass: the same object is now deduped by identity.
    second = tracker.prepare_input(original_input=[], generated_items=run_items)
    assert not _mentions_call_x(second)
Loading