Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from eval_protocol.get_pep440_version import get_pep440_version
from eval_protocol.human_id import generate_id
Expand Down Expand Up @@ -517,6 +517,13 @@ class Message(BaseModel):
function_call: Optional[FunctionCall] = None
control_plane_step: Optional[Dict[str, Any]] = None
weight: Optional[int] = None
token_ids: Optional[List[int]] = Field(
default=None,
description=(
"Optional token IDs for this message. When set on assistant messages, "
"these should come from the same generation call as logprobs."
),
)
logprobs: Optional[Any] = Field(
default=None,
description=(
Expand All @@ -529,9 +536,21 @@ def dump_mdoel_for_chat_completion_request(self):
"""Only keep chat completion accepted fields"""
return self.model_dump(
exclude_none=True,
exclude={"control_plane_step", "reasoning_content", "weight", "logprobs"},
exclude={"control_plane_step", "reasoning_content", "weight", "token_ids", "logprobs"},
)

@model_validator(mode="after")
def _validate_token_ids_logprobs_alignment(self) -> "Message":
    """Reject messages whose token_ids and flat float logprobs disagree in length.

    Only a plain list of numeric logprobs is length-checked; provider
    dict-shaped payloads carry their own structure and are left untouched.
    Messages missing either field pass through unchanged.
    """
    token_ids, logprobs = self.token_ids, self.logprobs
    if token_ids is None or logprobs is None:
        return self
    flat_numeric = isinstance(logprobs, list) and all(
        isinstance(entry, (int, float)) for entry in logprobs
    )
    if flat_numeric and len(token_ids) != len(logprobs):
        raise ValueError(
            "token_ids and float logprobs must have the same length "
            f"(got {len(token_ids)} token_ids and {len(logprobs)} logprobs)"
        )
    return self

@classmethod
def model_validate(cls, obj, *args, **kwargs):
if isinstance(obj, dict):
Expand Down
24 changes: 24 additions & 0 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ def _serialize_logprobs(logprobs: Any) -> Any:
return logprobs


def _extract_token_ids_from_logprobs(logprobs: Any) -> List[int] | None:
"""Extract token IDs from a serialized provider logprobs payload when present."""

if not isinstance(logprobs, dict):
return None

content = logprobs.get("content")
if isinstance(content, list) and content:
token_ids: List[int] = []
for item in content:
if not isinstance(item, dict) or item.get("token_id") is None:
return None
token_ids.append(int(item["token_id"]))
return token_ids

raw_token_ids = logprobs.get("token_ids")
if isinstance(raw_token_ids, list) and raw_token_ids:
return [int(token_id) for token_id in raw_token_ids]

return None


class SingleTurnRolloutProcessor(RolloutProcessor):
"""Single turn rollout processor for direct LLM calls."""

Expand Down Expand Up @@ -136,6 +158,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
assistant_message = response.choices[0].message
finish_reason = getattr(response.choices[0], "finish_reason", None)
assistant_logprobs = _serialize_logprobs(getattr(response.choices[0], "logprobs", None))
assistant_token_ids = _extract_token_ids_from_logprobs(assistant_logprobs)

# Extract content
assistant_content = assistant_message.content or ""
Expand Down Expand Up @@ -190,6 +213,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
content=assistant_content,
reasoning_content=reasoning_content,
tool_calls=converted_tool_calls,
token_ids=assistant_token_ids,
logprobs=assistant_logprobs,
)
]
Expand Down
17 changes: 17 additions & 0 deletions tests/test_eval_protocol_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,23 @@ def test_message_creation(self):
assert msg.role == "user"
assert msg.content == "Test message"

def test_message_preserves_token_ids(self):
    """Token IDs assigned at construction survive a model_dump round-trip."""
    from eval_protocol import Message

    message = Message(
        role="assistant",
        content="Hi",
        token_ids=[1, 2],
        logprobs=[-0.1, -0.2],
    )
    assert message.model_dump()["token_ids"] == [1, 2]

def test_message_rejects_misaligned_float_logprobs(self):
    """Construction fails when token_ids and flat float logprobs differ in length."""
    import pytest
    from pydantic import ValidationError

    from eval_protocol import Message

    misaligned = {
        "role": "assistant",
        "content": "Hi",
        "token_ids": [1, 2],
        "logprobs": [-0.1],
    }
    with pytest.raises(ValidationError):
        Message(**misaligned)

def test_utility_functions(self):
"""Test that utility functions work through eval_protocol."""
from eval_protocol import create_llm_resource, load_jsonl
Expand Down
3 changes: 2 additions & 1 deletion tests/test_rollout_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_single_turn_rollout_captures_logprobs(monkeypatch):
async def fake_acompletion(**kwargs):
assert kwargs["logprobs"] is True
assert kwargs["top_logprobs"] == 2
logprobs = {"content": [{"token": "hello", "logprob": -0.1, "top_logprobs": []}]}
logprobs = {"content": [{"token": "hello", "token_id": 15339, "logprob": -0.1, "top_logprobs": []}]}
return ModelResponse(
id="resp-1",
choices=[
Expand All @@ -53,6 +53,7 @@ async def _run() -> None:
assistant_logprobs = completed_rows[0].messages[-1].logprobs
assert isinstance(assistant_logprobs, dict)
assert assistant_logprobs["content"][0]["token"] == "hello"
assert completed_rows[0].messages[-1].token_ids == [15339]
assert assistant_logprobs["content"][0]["logprob"] == -0.1

asyncio.run(_run())
Loading