From b6c1af91b5fd9c3195723c63f91843a9412d11b5 Mon Sep 17 00:00:00 2001
From: hecate0821
Date: Tue, 28 Apr 2026 03:31:54 +0000
Subject: [PATCH] feat: preserve message token ids

---
 eval_protocol/models.py                    | 23 +++++++++++++++++++++--
 .../default_single_turn_rollout_process.py | 24 ++++++++++++++++++++++++
 tests/test_eval_protocol_import.py         | 17 +++++++++++++++++
 tests/test_rollout_logprobs.py             |  3 ++-
 4 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/eval_protocol/models.py b/eval_protocol/models.py
index 90ca21c8..cbe42351 100644
--- a/eval_protocol/models.py
+++ b/eval_protocol/models.py
@@ -14,7 +14,7 @@
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
 )
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from eval_protocol.get_pep440_version import get_pep440_version
 from eval_protocol.human_id import generate_id
@@ -517,6 +517,13 @@ class Message(BaseModel):
     function_call: Optional[FunctionCall] = None
     control_plane_step: Optional[Dict[str, Any]] = None
     weight: Optional[int] = None
+    token_ids: Optional[List[int]] = Field(
+        default=None,
+        description=(
+            "Optional token IDs for this message. When set on assistant messages, "
+            "these should come from the same generation call as logprobs."
+        ),
+    )
     logprobs: Optional[Any] = Field(
         default=None,
         description=(
@@ -529,9 +536,21 @@ def dump_mdoel_for_chat_completion_request(self):
         """Only keep chat completion accepted fields"""
         return self.model_dump(
             exclude_none=True,
-            exclude={"control_plane_step", "reasoning_content", "weight", "logprobs"},
+            exclude={"control_plane_step", "reasoning_content", "weight", "token_ids", "logprobs"},
         )
 
+    @model_validator(mode="after")
+    def _validate_token_ids_logprobs_alignment(self) -> "Message":
+        if self.token_ids is None or self.logprobs is None:
+            return self
+        if isinstance(self.logprobs, list) and all(isinstance(lp, (int, float)) for lp in self.logprobs):
+            if len(self.token_ids) != len(self.logprobs):
+                raise ValueError(
+                    "token_ids and float logprobs must have the same length "
+                    f"(got {len(self.token_ids)} token_ids and {len(self.logprobs)} logprobs)"
+                )
+        return self
+
     @classmethod
     def model_validate(cls, obj, *args, **kwargs):
         if isinstance(obj, dict):
diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py
index cabab274..c704474f 100644
--- a/eval_protocol/pytest/default_single_turn_rollout_process.py
+++ b/eval_protocol/pytest/default_single_turn_rollout_process.py
@@ -44,6 +44,28 @@ def _serialize_logprobs(logprobs: Any) -> Any:
     return logprobs
 
 
+def _extract_token_ids_from_logprobs(logprobs: Any) -> List[int] | None:
+    """Extract token IDs from a serialized provider logprobs payload when present."""
+
+    if not isinstance(logprobs, dict):
+        return None
+
+    content = logprobs.get("content")
+    if isinstance(content, list) and content:
+        token_ids: List[int] = []
+        for item in content:
+            if not isinstance(item, dict) or item.get("token_id") is None:
+                return None
+            token_ids.append(int(item["token_id"]))
+        return token_ids
+
+    raw_token_ids = logprobs.get("token_ids")
+    if isinstance(raw_token_ids, list) and raw_token_ids:
+        return [int(token_id) for token_id in raw_token_ids]
+
+    return None
+
+
 class SingleTurnRolloutProcessor(RolloutProcessor):
     """Single turn rollout processor for direct LLM calls."""
 
@@ -136,6 +158,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             assistant_message = response.choices[0].message
             finish_reason = getattr(response.choices[0], "finish_reason", None)
             assistant_logprobs = _serialize_logprobs(getattr(response.choices[0], "logprobs", None))
+            assistant_token_ids = _extract_token_ids_from_logprobs(assistant_logprobs)
 
             # Extract content
             assistant_content = assistant_message.content or ""
@@ -190,6 +213,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                     content=assistant_content,
                     reasoning_content=reasoning_content,
                     tool_calls=converted_tool_calls,
+                    token_ids=assistant_token_ids,
                     logprobs=assistant_logprobs,
                 )
             ]
diff --git a/tests/test_eval_protocol_import.py b/tests/test_eval_protocol_import.py
index 4777be1e..b270b773 100644
--- a/tests/test_eval_protocol_import.py
+++ b/tests/test_eval_protocol_import.py
@@ -262,6 +262,23 @@ def test_message_creation(self):
         assert msg.role == "user"
         assert msg.content == "Test message"
 
+    def test_message_preserves_token_ids(self):
+        """Test token IDs round-trip on messages."""
+        from eval_protocol import Message
+
+        msg = Message(role="assistant", content="Hi", token_ids=[1, 2], logprobs=[-0.1, -0.2])
+        assert msg.model_dump()["token_ids"] == [1, 2]
+
+    def test_message_rejects_misaligned_float_logprobs(self):
+        """Test token IDs and flat float logprobs must align."""
+        import pytest
+        from pydantic import ValidationError
+
+        from eval_protocol import Message
+
+        with pytest.raises(ValidationError):
+            Message(role="assistant", content="Hi", token_ids=[1, 2], logprobs=[-0.1])
+
     def test_utility_functions(self):
         """Test that utility functions work through eval_protocol."""
         from eval_protocol import create_llm_resource, load_jsonl
diff --git a/tests/test_rollout_logprobs.py b/tests/test_rollout_logprobs.py
index 1aad2322..6e593067 100644
--- a/tests/test_rollout_logprobs.py
+++ b/tests/test_rollout_logprobs.py
@@ -28,7 +28,7 @@ def test_single_turn_rollout_captures_logprobs(monkeypatch):
     async def fake_acompletion(**kwargs):
         assert kwargs["logprobs"] is True
         assert kwargs["top_logprobs"] == 2
-        logprobs = {"content": [{"token": "hello", "logprob": -0.1, "top_logprobs": []}]}
+        logprobs = {"content": [{"token": "hello", "token_id": 15339, "logprob": -0.1, "top_logprobs": []}]}
         return ModelResponse(
             id="resp-1",
             choices=[
@@ -53,6 +53,7 @@ async def _run() -> None:
         assistant_logprobs = completed_rows[0].messages[-1].logprobs
         assert isinstance(assistant_logprobs, dict)
         assert assistant_logprobs["content"][0]["token"] == "hello"
+        assert completed_rows[0].messages[-1].token_ids == [15339]
         assert assistant_logprobs["content"][0]["logprob"] == -0.1
 
     asyncio.run(_run())
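
Note (editorial, not part of the patch): the sketch below illustrates how the new Message.token_ids field is expected to behave once this patch is applied. The Message model, its after-validator, and the two logprobs shapes are taken from the diff above; the surrounding script is only an assumed usage example, not code from the repository.

# Editorial usage sketch, not part of the patch above. Assumes the patched
# eval_protocol package, where Message gains an optional token_ids field and
# an after-validator that checks alignment against flat float logprobs.
from pydantic import ValidationError

from eval_protocol import Message

# Aligned token_ids / flat float logprobs are accepted and round-trip.
msg = Message(
    role="assistant",
    content="hello",
    token_ids=[15339],  # one ID per generated token
    logprobs=[-0.1],    # same length as token_ids
)
assert msg.model_dump()["token_ids"] == [15339]

# Misaligned lengths are rejected by _validate_token_ids_logprobs_alignment.
try:
    Message(role="assistant", content="hello", token_ids=[15339, 2], logprobs=[-0.1])
except ValidationError:
    pass  # expected: 2 token_ids vs. 1 float logprob

# Provider-style dict logprobs skip the length check; in the rollout processor,
# _extract_token_ids_from_logprobs pulls token IDs out of such payloads before
# they are attached to the assistant message.
msg2 = Message(
    role="assistant",
    content="hello",
    token_ids=[15339],
    logprobs={"content": [{"token": "hello", "token_id": 15339, "logprob": -0.1}]},
)
assert msg2.token_ids == [15339]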