diff --git a/src/agents/run_internal/oai_conversation.py b/src/agents/run_internal/oai_conversation.py index 4a0e088353..25d3cc4c14 100644 --- a/src/agents/run_internal/oai_conversation.py +++ b/src/agents/run_internal/oai_conversation.py @@ -95,6 +95,25 @@ def _has_output_payload(item: Any) -> bool: return (isinstance(item, dict) and "output" in item) or hasattr(item, "output") +def _is_tracked_object(items: Sequence[Any], candidate: Any) -> bool: + """Return True when the exact object instance is already tracked.""" + return any(item is candidate for item in items) + + +def _track_object_once(items: list[Any], candidate: Any) -> None: + """Track an object instance once, keeping it alive while identity dedupe is needed.""" + if not _is_tracked_object(items, candidate): + items.append(candidate) + + +def _untrack_object(items: list[Any], candidate: Any) -> None: + """Remove an object instance from an identity-tracking list.""" + for index, item in enumerate(items): + if item is candidate: + items.pop(index) + return + + @dataclass class OpenAIServerConversationTracker: """Track server-side conversation state for conversation-aware runs. @@ -113,9 +132,10 @@ class OpenAIServerConversationTracker: previous_response_id: str | None = None auto_previous_response_id: bool = False - # In-process object identity for items that have already been delivered or acknowledged. - sent_items: set[int] = field(default_factory=set) - server_items: set[int] = field(default_factory=set) + # In-process object identity for delivered or acknowledged items. Keep object references + # instead of id(obj) integers so a later allocation cannot reuse a stale address. + sent_items: list[Any] = field(default_factory=list) + server_items: list[Any] = field(default_factory=list) # Stable provider identifiers returned by the Responses API. server_item_ids: set[str] = field(default_factory=set) @@ -200,7 +220,7 @@ def hydrate_from_state( for output_item in response.output: if output_item is None: continue - self.server_items.add(id(output_item)) + _track_object_once(self.server_items, output_item) item_id = _normalize_server_item_id( output_item.get("id") if isinstance(output_item, dict) @@ -263,8 +283,7 @@ def hydrate_from_state( if not should_mark: continue - raw_item_id = id(raw_item) - self.sent_items.add(raw_item_id) + _track_object_once(self.sent_items, raw_item) fp = _fingerprint_for_tracker(raw_item) if fp: self.sent_item_fingerprints.add(fp) @@ -297,7 +316,7 @@ def hydrate_from_state( if not should_mark: continue - self.sent_items.add(id(raw_item)) + _track_object_once(self.sent_items, raw_item) fp = _fingerprint_for_tracker(raw_item) if fp: self.sent_item_fingerprints.add(fp) @@ -321,7 +340,7 @@ def track_server_items(self, model_response: ModelResponse | None) -> None: for output_item in model_response.output: if output_item is None: continue - self.server_items.add(id(output_item)) + _track_object_once(self.server_items, output_item) item_id = _normalize_server_item_id( output_item.get("id") if isinstance(output_item, dict) @@ -361,17 +380,16 @@ def mark_input_as_sent(self, items: Sequence[TResponseInputItem]) -> None: if not items: return - delivered_source_ids: set[int] = set() + delivered_sources: list[TResponseInputItem] = [] delivered_by_content: set[str] = set() for item in items: if item is None: continue source_item = self._consume_prepared_item_source(item) - source_item_id = id(source_item) - if source_item_id in delivered_source_ids: + if _is_tracked_object(delivered_sources, source_item): continue - delivered_source_ids.add(source_item_id) - self.sent_items.add(source_item_id) + delivered_sources.append(source_item) + _track_object_once(self.sent_items, source_item) fp = _fingerprint_for_tracker(source_item) if fp: delivered_by_content.add(fp) @@ -382,7 +400,7 @@ def mark_input_as_sent(self, items: Sequence[TResponseInputItem]) -> None: remaining: list[TResponseInputItem] = [] for pending in self.remaining_initial_input: - if id(pending) in delivered_source_ids: + if _is_tracked_object(delivered_sources, pending): continue pending_fp = _fingerprint_for_tracker(pending) if pending_fp and pending_fp in delivered_by_content: @@ -402,7 +420,7 @@ def rewind_input(self, items: Sequence[TResponseInputItem]) -> None: continue source_item = self._consume_prepared_item_source(item) rewind_items.append(source_item) - self.sent_items.discard(id(source_item)) + _untrack_object(self.sent_items, source_item) fp = _fingerprint_for_tracker(source_item) if fp: self.sent_item_fingerprints.discard(fp) @@ -469,8 +487,9 @@ def prepare_input( ): continue - raw_item_id = id(raw_item) - if raw_item_id in self.sent_items or raw_item_id in self.server_items: + if _is_tracked_object(self.sent_items, raw_item) or _is_tracked_object( + self.server_items, raw_item + ): continue converted_input_item = run_item_to_input_item(run_item, self.reasoning_item_id_policy) diff --git a/tests/test_server_conversation_tracker.py b/tests/test_server_conversation_tracker.py index 703e2c6824..d7fff5a5ed 100644 --- a/tests/test_server_conversation_tracker.py +++ b/tests/test_server_conversation_tracker.py @@ -220,7 +220,7 @@ def test_hydrate_from_state_does_not_track_string_initial_input_by_object_identi model_responses=[], ) - assert tracker.sent_items == set() + assert tracker.sent_items == [] assert tracker.sent_initial_input is True assert tracker.remaining_initial_input is None assert len(tracker.sent_item_fingerprints) == 1 @@ -238,7 +238,7 @@ def test_hydrate_from_state_does_not_track_list_initial_input_by_object_identity model_responses=[], ) - assert tracker.sent_items == set() + assert tracker.sent_items == [] assert tracker.sent_initial_input is True assert tracker.remaining_initial_input is None assert len(tracker.sent_item_fingerprints) == 1 @@ -278,8 +278,8 @@ def test_mark_input_as_sent_uses_raw_generated_source_for_rebuilt_filtered_item( tracker.mark_input_as_sent([rebuilt_filtered_item]) - assert id(raw_generated_item) in tracker.sent_items - assert id(rebuilt_filtered_item) not in tracker.sent_items + assert any(item is raw_generated_item for item in tracker.sent_items) + assert all(item is not rebuilt_filtered_item for item in tracker.sent_items) prepared_again = tracker.prepare_input( original_input=[], @@ -821,8 +821,8 @@ def _filter_input(payload: Any) -> ModelInputData: ) assert model.last_turn_args["input"] == [item_1] - assert id(item_1) in tracker.sent_items - assert id(item_2) not in tracker.sent_items + assert any(item is item_1 for item in tracker.sent_items) + assert all(item is not item_2 for item in tracker.sent_items) @pytest.mark.asyncio @@ -965,3 +965,64 @@ def _filter_input(payload: Any) -> ModelInputData: assert len(tool_call_events) == 1 assert tool_call_events[0].description == "Search the docs." assert tool_call_events[0].title == "Search Docs" + + +@pytest.mark.parametrize("stale_collection_name", ["sent_items", "server_items"]) +def test_prepare_input_keeps_fresh_tool_output_when_stale_identity_matches( + stale_collection_name: str, +) -> None: + """Tracked object identity must not become a stale address-based dedupe key.""" + tracker = OpenAIServerConversationTracker(previous_response_id="resp-1") + + output_raw_item: dict[str, Any] = { + "type": "function_call_output", + "call_id": "call_FRESH", + "output": "42", + } + tracked_items = getattr(tracker, stale_collection_name) + if isinstance(tracked_items, set): + tracked_items.add(id(output_raw_item)) + else: + old_item = {"type": "message", "content": "already tracked"} + tracked_items.append(old_item) + + generated_items = [DummyRunItem(output_raw_item, type="function_call_output_item")] + + prepared = tracker.prepare_input( + original_input=[], + generated_items=cast(list[Any], generated_items), + ) + + prepared_output_call_ids = [ + item.get("call_id") + for item in prepared + if isinstance(item, dict) and item.get("type") == "function_call_output" + ] + assert "call_FRESH" in prepared_output_call_ids + + +def test_prepare_input_dedupes_same_delivered_tool_output_object() -> None: + """Identity dedupe still skips the exact source object after it is delivered.""" + tracker = OpenAIServerConversationTracker(previous_response_id="resp-1") + + output_raw_item: dict[str, Any] = { + "type": "function_call_output", + "call_id": "call_X", + "output": "42", + } + generated_items = [DummyRunItem(output_raw_item, type="function_call_output_item")] + + first = tracker.prepare_input( + original_input=[], + generated_items=cast(list[Any], generated_items), + ) + assert any(isinstance(item, dict) and item.get("call_id") == "call_X" for item in first) + + tracker.mark_input_as_sent(first) + assert any(item is output_raw_item for item in tracker.sent_items) + + second = tracker.prepare_input( + original_input=[], + generated_items=cast(list[Any], generated_items), + ) + assert all(not (isinstance(item, dict) and item.get("call_id") == "call_X") for item in second)