diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index 1990042e..72142e9a 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -18,6 +18,10 @@ "MCPGymRolloutProcessor": (".default_mcp_gym_rollout_processor", "MCPGymRolloutProcessor"), "NoOpRolloutProcessor": (".default_no_op_rollout_processor", "NoOpRolloutProcessor"), "SingleTurnRolloutProcessor": (".default_single_turn_rollout_process", "SingleTurnRolloutProcessor"), + "FireworksTrainingRolloutProcessor": ( + ".default_fireworks_training_rollout_processor", + "FireworksTrainingRolloutProcessor", + ), "RemoteRolloutProcessor": (".remote_rollout_processor", "RemoteRolloutProcessor"), "GithubActionRolloutProcessor": (".github_action_rollout_processor", "GithubActionRolloutProcessor"), "RolloutProcessor": (".rollout_processor", "RolloutProcessor"), @@ -102,6 +106,7 @@ def __dir__(): "MCPGymRolloutProcessor", "RolloutProcessor", "SingleTurnRolloutProcessor", + "FireworksTrainingRolloutProcessor", "RemoteRolloutProcessor", "GithubActionRolloutProcessor", "NoOpRolloutProcessor", @@ -132,6 +137,9 @@ def __dir__(): from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor as MCPGymRolloutProcessor from .default_no_op_rollout_processor import NoOpRolloutProcessor as NoOpRolloutProcessor from .default_single_turn_rollout_process import SingleTurnRolloutProcessor as SingleTurnRolloutProcessor + from .default_fireworks_training_rollout_processor import ( + FireworksTrainingRolloutProcessor as FireworksTrainingRolloutProcessor, + ) from .remote_rollout_processor import RemoteRolloutProcessor as RemoteRolloutProcessor from .github_action_rollout_processor import GithubActionRolloutProcessor as GithubActionRolloutProcessor from .evaluation_test import evaluation_test as evaluation_test diff --git a/eval_protocol/pytest/default_fireworks_training_rollout_processor.py b/eval_protocol/pytest/default_fireworks_training_rollout_processor.py new file mode 100644 index 00000000..e4a0f61f --- /dev/null +++ b/eval_protocol/pytest/default_fireworks_training_rollout_processor.py @@ -0,0 +1,259 @@ +"""Default training-aware :class:`RolloutProcessor` for Fireworks RFT. + +Unlike :class:`SingleTurnRolloutProcessor`, which uses LiteLLM chat +completions and discards token-level information, this processor drives +``FireworksV1CompletionsClient`` against a Fireworks ``/v1/completions`` +endpoint and surfaces the per-sample token ids / inference logprobs +needed by reinforcement-fine-tuning training (GRPO, CISPO, DAPO, GSPO). + +Why this exists +--------------- +RFT training loops consume token-level data per completion (prompt ids, +completion ids, inference logprobs). ``SingleTurnRolloutProcessor`` was +never meant to produce that; customers who need it have to write a +bespoke :class:`RolloutProcessor` (see the FrozenLake example in +``fw-ai/cookbook``, which is ~800 lines). This processor promotes that +bespoke pattern into an Eval Protocol default so managed Fireworks RFT +jobs can wire it up for every customer evaluator bundle without +customer code changes. 
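+
+As a rough, illustrative sketch (the training adapter API itself is not part
+of this module), a consumer might read the keys documented below back off a
+processed ``row`` like this::
+
+    extra = row.execution_metadata.extra
+    for ids, logprobs in zip(extra["completion_ids"], extra["inference_logprobs"]):
+        sample = {
+            "prompt_ids": extra["prompt_ids"],
+            "completion_ids": ids,
+            "inference_logprobs": logprobs,
+        }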
+
+What it puts on the row
+-----------------------
+Besides the usual ``EvaluationRow.messages`` update (append the first
+completion as an ``assistant`` turn so existing evaluators keep working),
+this processor writes the following keys to
+``EvaluationRow.execution_metadata.extra``:
+
+* ``prompt_ids`` — ``list[int]`` (shared across the N completions)
+* ``completion_ids`` — ``list[list[int]]`` (one per completion)
+* ``inference_logprobs`` — ``list[list[float]]`` aligned to completion tokens
+* ``completions_text`` — ``list[str]`` (one per completion)
+* ``truncated`` — ``list[bool]`` (True when ``finish_reason == 'length'``)
+* ``finish_reasons`` — ``list[str]``
+
+Shape choice
+------------
+Keys are ``list[list[...]]`` keyed by completion index rather than the
+flattened concat convention used by
+:class:`OpenEnvRolloutProcessor` (which is multi-turn and has no natural
+per-completion structure). Single-turn RFT samples ``n>1`` completions
+per prompt for advantage estimation, so a per-completion shape is what
+the training adapter actually needs.
+
+Ergonomics
+----------
+The processor reads all sampling knobs (``model``, ``temperature``,
+``max_tokens``, ``n``) from ``config.completion_params``, matching
+:class:`SingleTurnRolloutProcessor`. Customer evaluator bundles don't
+need to reference this class — the managed RFT launcher swaps it in.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from typing import Any
+
+from openai.types import CompletionUsage
+
+from eval_protocol.dataset_logger import default_logger
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
+from eval_protocol.pytest.types import RolloutProcessorConfig
+
+logger = logging.getLogger(__name__)
+
+
+def _as_list_of_dicts(messages: list[Message]) -> list[dict[str, Any]]:
+    """Convert ``EvaluationRow.messages`` into the dict shape the client expects."""
+    return [m.dump_model_for_chat_completion_request() for m in messages]
+
+
+def _append_extra(row: EvaluationRow, updates: dict[str, Any]) -> None:
+    """Merge *updates* into ``row.execution_metadata.extra`` without clobbering."""
+    current = row.execution_metadata.extra if row.execution_metadata else None
+    merged: dict[str, Any] = dict(current) if current else {}
+    merged.update(updates)
+    row.execution_metadata.extra = merged
+
+
+class FireworksTrainingRolloutProcessor(RolloutProcessor):
+    """Single-turn rollout with token-level outputs attached for RFT training.
+
+    Args:
+        drop_trailing_assistant_messages: When True (default), strip trailing
+            assistant messages from the input conversation before sampling
+            — matches :class:`SingleTurnRolloutProcessor` behaviour.
+        tokenizer_name_or_path: Override for HuggingFace tokenizer lookup.
+            When not set, the model id from ``completion_params["model"]`` is
+            used (via ``FireworksV1CompletionsClient``'s default behaviour).
+        api_key: Override for the Fireworks API key. Defaults to the
+            ``FIREWORKS_API_KEY`` env var at first use.
+        base_url: Override for the Fireworks API base URL. Defaults to the
+            ``FIREWORKS_BASE_URL`` env var if set, else the SDK default.
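+
+    Example:
+        Illustrative sketch only; in managed RFT the launcher performs this
+        wiring, and ``rows`` / ``config`` stand in for whatever the harness
+        passes to the processor::
+
+            processor = FireworksTrainingRolloutProcessor()
+            processor.setup()
+            tasks = processor(rows, config)
+            done = await asyncio.gather(*tasks)
+            token_ids = done[0].execution_metadata.extra["completion_ids"]
+            await processor.acleanup()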
+ """ + + def __init__( + self, + *, + drop_trailing_assistant_messages: bool = True, + tokenizer_name_or_path: str | None = None, + api_key: str | None = None, + base_url: str | None = None, + ) -> None: + self.drop_trailing_assistant_messages = drop_trailing_assistant_messages + self.tokenizer_name_or_path = tokenizer_name_or_path + self._api_key = api_key + self._base_url = base_url + # One client per model id per processor instance; cached lazily in setup(). + self._clients: dict[str, Any] = {} + + def setup(self) -> None: + """Validate the Fireworks SDK / tokenizer deps up front.""" + # Defer the heavy import to setup so processor construction is cheap. + from eval_protocol.integrations.fireworks_v1_completions_client import ( # noqa: F401 + FireworksV1CompletionsClient, + ) + + async def acleanup(self) -> None: + for client in self._clients.values(): + try: + await client.close() + except Exception: + logger.debug("FireworksV1CompletionsClient.close() failed", exc_info=True) + self._clients.clear() + + def _client_for(self, *, model_id: str, temperature: float, max_tokens: int) -> Any: + """Get-or-create a ``FireworksV1CompletionsClient`` for *model_id*.""" + cached = self._clients.get(model_id) + if cached is not None: + return cached + + from eval_protocol.integrations.fireworks_v1_completions_client import ( + FireworksV1CompletionsClient, + ) + + client = FireworksV1CompletionsClient( + model_id=model_id, + tokenizer_name_or_path=self.tokenizer_name_or_path, + api_key=self._api_key or os.getenv("FIREWORKS_API_KEY"), + base_url=self._base_url or os.getenv("FIREWORKS_BASE_URL"), + temperature=temperature, + max_tokens=max_tokens, + logprobs=True, + ) + self._clients[model_id] = client + return client + + def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]: + """Return one asyncio.Task per input row.""" + + async def process_row(row: EvaluationRow) -> EvaluationRow: + start_time = time.perf_counter() + + if not row.messages: + raise ValueError("EvaluationRow.messages is empty") + + completion_params = dict(config.completion_params or {}) + row.input_metadata.completion_params = completion_params + + model_id = completion_params.get("model") + if not model_id: + raise ValueError("completion_params.model is required") + temperature = float(completion_params.get("temperature", 1.0)) + max_tokens = int(completion_params.get("max_tokens", 256)) + completions_per_prompt = int(completion_params.get("n", 1)) + if completions_per_prompt < 1: + raise ValueError(f"n must be >= 1, got {completions_per_prompt}") + + messages_for_request: list[Message] = list(row.messages) + if self.drop_trailing_assistant_messages: + while messages_for_request and messages_for_request[-1].role == "assistant": + messages_for_request.pop() + + client = self._client_for(model_id=str(model_id), temperature=temperature, max_tokens=max_tokens) + + prompt_messages = _as_list_of_dicts(messages_for_request) + prompt_token_ids = client.build_prompt_token_ids(messages=prompt_messages, tools=row.tools) + + # Fire N parallel calls against the *same* prompt_token_ids. Each + # call produces one completion. We sample in parallel because + # Fireworks /v1/completions handles n=1 most reliably; requesting + # n>1 sometimes collapses to a single choice on partial failures, + # and we'd rather surface per-completion retry behaviour. 
+            async def _one_completion() -> dict[str, Any]:
+                return await client.create_completion_from_prompt_ids(
+                    prompt_token_ids=prompt_token_ids, tools=row.tools
+                )
+
+            results = await asyncio.gather(*[_one_completion() for _ in range(completions_per_prompt)])
+
+            completion_ids: list[list[int]] = []
+            completions_text: list[str] = []
+            inference_logprobs: list[list[float]] = []
+            truncated: list[bool] = []
+            finish_reasons: list[str] = []
+
+            for result in results:
+                completion_ids.append(list(result.get("completion_ids") or []))
+                inference_logprobs.append(list(result.get("completion_logprobs") or []))
+                finish_reason = str(result.get("finish_reason") or "unknown")
+                finish_reasons.append(finish_reason)
+                truncated.append(finish_reason == "length")
+                # Take the completion text from the parsed assistant message content.
+                choice = (result.get("choices") or [{}])[0]
+                message = choice.get("message") or {}
+                text = str(message.get("content") or "")
+                completions_text.append(text)
+
+            first_result = results[0]
+            prompt_ids = list(first_result.get("prompt_ids") or prompt_token_ids)
+            first_message = (first_result.get("choices") or [{}])[0].get("message") or {}
+            first_tool_calls = first_message.get("tool_calls")
+
+            # Append the first completion as the assistant turn so that
+            # existing evaluators that inspect ``last_assistant_message`` keep
+            # working without modification.
+            row.messages = list(messages_for_request) + [
+                Message(
+                    role="assistant",
+                    content=completions_text[0] if completions_text else "",
+                    tool_calls=first_tool_calls,
+                    logprobs=inference_logprobs[0] if inference_logprobs else None,
+                )
+            ]
+
+            row.execution_metadata.finish_reason = finish_reasons[0] if finish_reasons else None
+            row.execution_metadata.tool_call_count = len(first_tool_calls) if first_tool_calls else 0
+            row.execution_metadata.usage = CompletionUsage(
+                prompt_tokens=len(prompt_ids),
+                completion_tokens=sum(len(ids) for ids in completion_ids),
+                total_tokens=len(prompt_ids) + sum(len(ids) for ids in completion_ids),
+            )
+            row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
+
+            _append_extra(
+                row,
+                {
+                    "prompt_ids": prompt_ids,
+                    "completion_ids": completion_ids,
+                    "inference_logprobs": inference_logprobs,
+                    "completions_text": completions_text,
+                    "truncated": truncated,
+                    "finish_reasons": finish_reasons,
+                },
+            )
+
+            default_logger.log(row)
+            return row
+
+        semaphore = config.semaphore
+
+        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
+            async with semaphore:
+                return await process_row(r)
+
+        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
diff --git a/tests/pytest/test_fireworks_training_rollout_processor.py b/tests/pytest/test_fireworks_training_rollout_processor.py
new file mode 100644
index 00000000..828da764
--- /dev/null
+++ b/tests/pytest/test_fireworks_training_rollout_processor.py
@@ -0,0 +1,300 @@
+"""Unit tests for :class:`FireworksTrainingRolloutProcessor`.
+
+These tests stub out :class:`FireworksV1CompletionsClient` so no
+network calls or tokenizers are required. They cover the contract that
+managed Fireworks RFT depends on:
+
+* per-completion ``prompt_ids`` / ``completion_ids`` / ``inference_logprobs``
+  land on ``EvaluationRow.execution_metadata.extra``
+* the first completion is appended as an ``assistant`` message so existing
+  evaluators keep working
+* ``n`` completions produce ``n``-length lists
+* ``finish_reason == 'length'`` is surfaced as ``truncated[i] = True``
+* trailing ``assistant`` messages are dropped before sampling by default
+"""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+import pytest
+
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest import FireworksTrainingRolloutProcessor
+
+
+class _StubConfig:
+    """Minimal stand-in for :class:`RolloutProcessorConfig`."""
+
+    def __init__(self, *, n: int = 2, max_tokens: int = 32, model: str = "accounts/fireworks/models/qwen3-8b"):
+        self.completion_params: dict[str, Any] = {
+            "model": model,
+            "temperature": 1.0,
+            "max_tokens": max_tokens,
+            "n": n,
+        }
+        self.semaphore = asyncio.Semaphore(4)
+
+
+class _FakeCompletionsClient:
+    """Mimics ``FireworksV1CompletionsClient`` just enough for the processor.
+
+    Each call to ``create_completion_from_prompt_ids`` returns a distinct
+    completion so we can assert per-completion indexing.
+    """
+
+    def __init__(self, completions: list[dict[str, Any]], prompt_token_ids: list[int] | None = None) -> None:
+        self._completions = list(completions)
+        self._prompt_token_ids = prompt_token_ids or [1, 2, 3]
+        self._call_count = 0
+        self.close_calls = 0
+
+    def build_prompt_token_ids(self, *, messages: list[dict[str, Any]], tools: Any = None) -> list[int]:
+        return list(self._prompt_token_ids)
+
+    async def create_completion_from_prompt_ids(
+        self, *, prompt_token_ids: list[int], tools: Any = None
+    ) -> dict[str, Any]:
+        idx = self._call_count
+        self._call_count += 1
+        # Cycle through the provided completions so repeated calls return
+        # distinct payloads.
+ return self._completions[idx % len(self._completions)] + + async def close(self) -> None: + self.close_calls += 1 + + +def _make_row(*, trailing_assistant: bool = False) -> EvaluationRow: + messages = [Message(role="user", content="What is 2+2?")] + if trailing_assistant: + messages.append(Message(role="assistant", content="old cached answer")) + return EvaluationRow(messages=messages) + + +def _install_fake_client(monkeypatch: pytest.MonkeyPatch, client: _FakeCompletionsClient) -> None: + """Patch the client constructor so the processor uses our fake.""" + import eval_protocol.integrations.fireworks_v1_completions_client as mod + + monkeypatch.setattr(mod, "FireworksV1CompletionsClient", lambda **_kwargs: client) + + +@pytest.mark.asyncio +async def test_produces_per_completion_token_ids_and_logprobs(monkeypatch: pytest.MonkeyPatch) -> None: + """With n=2, the processor returns a row whose ``extra`` has two-entry lists.""" + completions = [ + { + "choices": [{"message": {"role": "assistant", "content": "4"}, "finish_reason": "stop"}], + "prompt_ids": [10, 11, 12], + "completion_ids": [40, 41], + "completion_logprobs": [-0.1, -0.2], + "finish_reason": "stop", + }, + { + "choices": [{"message": {"role": "assistant", "content": "four"}, "finish_reason": "length"}], + "prompt_ids": [10, 11, 12], + "completion_ids": [42, 43, 44], + "completion_logprobs": [-0.3, -0.4, -0.5], + "finish_reason": "length", + }, + ] + client = _FakeCompletionsClient(completions=completions, prompt_token_ids=[10, 11, 12]) + _install_fake_client(monkeypatch, client) + + processor = FireworksTrainingRolloutProcessor() + processor.setup() + + tasks = processor([_make_row()], _StubConfig(n=2)) + result = await tasks[0] + + extra = result.execution_metadata.extra + assert extra is not None + assert extra["prompt_ids"] == [10, 11, 12] + assert extra["completion_ids"] == [[40, 41], [42, 43, 44]] + assert extra["inference_logprobs"] == [[-0.1, -0.2], [-0.3, -0.4, -0.5]] + assert extra["completions_text"] == ["4", "four"] + assert extra["truncated"] == [False, True] + assert extra["finish_reasons"] == ["stop", "length"] + # The first completion must be exposed as the assistant message so + # evaluators that call ``last_assistant_message()`` still score. + assert result.messages[-1].role == "assistant" + assert result.messages[-1].content == "4" + # Usage should aggregate across completions. 
+ assert result.execution_metadata.usage is not None + assert result.execution_metadata.usage.prompt_tokens == 3 + assert result.execution_metadata.usage.completion_tokens == 5 + + +@pytest.mark.asyncio +async def test_n_equals_one_single_completion(monkeypatch: pytest.MonkeyPatch) -> None: + """n=1 still produces the list-of-lists shape — just with length 1.""" + completions = [ + { + "choices": [{"message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}], + "prompt_ids": [1, 2], + "completion_ids": [9], + "completion_logprobs": [-0.01], + "finish_reason": "stop", + } + ] + client = _FakeCompletionsClient(completions=completions, prompt_token_ids=[1, 2]) + _install_fake_client(monkeypatch, client) + + processor = FireworksTrainingRolloutProcessor() + processor.setup() + + tasks = processor([_make_row()], _StubConfig(n=1)) + result = await tasks[0] + + extra = result.execution_metadata.extra + assert extra["completion_ids"] == [[9]] + assert extra["inference_logprobs"] == [[-0.01]] + assert extra["truncated"] == [False] + + +@pytest.mark.asyncio +async def test_drops_trailing_assistant_by_default(monkeypatch: pytest.MonkeyPatch) -> None: + """Trailing ``assistant`` messages are dropped before the prompt is built.""" + captured_messages: dict[str, Any] = {} + + class _CaptureClient(_FakeCompletionsClient): + def build_prompt_token_ids(self, *, messages: list[dict[str, Any]], tools: Any = None) -> list[int]: + captured_messages["messages"] = messages + return [1, 2] + + completions = [ + { + "choices": [{"message": {"role": "assistant", "content": "x"}, "finish_reason": "stop"}], + "prompt_ids": [1, 2], + "completion_ids": [3], + "completion_logprobs": [-0.1], + "finish_reason": "stop", + } + ] + client = _CaptureClient(completions=completions) + _install_fake_client(monkeypatch, client) + + processor = FireworksTrainingRolloutProcessor() + processor.setup() + + tasks = processor([_make_row(trailing_assistant=True)], _StubConfig(n=1)) + await tasks[0] + + sent = captured_messages["messages"] + assert [m["role"] for m in sent] == ["user"], "Trailing assistant should have been dropped" + + +@pytest.mark.asyncio +async def test_keeps_trailing_assistant_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Explicitly disable the drop to support continuations.""" + captured_messages: dict[str, Any] = {} + + class _CaptureClient(_FakeCompletionsClient): + def build_prompt_token_ids(self, *, messages: list[dict[str, Any]], tools: Any = None) -> list[int]: + captured_messages["messages"] = messages + return [1, 2] + + completions = [ + { + "choices": [{"message": {"role": "assistant", "content": "x"}, "finish_reason": "stop"}], + "prompt_ids": [1, 2], + "completion_ids": [3], + "completion_logprobs": [-0.1], + "finish_reason": "stop", + } + ] + client = _CaptureClient(completions=completions) + _install_fake_client(monkeypatch, client) + + processor = FireworksTrainingRolloutProcessor(drop_trailing_assistant_messages=False) + processor.setup() + + tasks = processor([_make_row(trailing_assistant=True)], _StubConfig(n=1)) + await tasks[0] + + sent = captured_messages["messages"] + assert [m["role"] for m in sent] == ["user", "assistant"] + + +@pytest.mark.asyncio +async def test_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None: + """``completion_params`` must carry a ``model`` id.""" + _install_fake_client(monkeypatch, _FakeCompletionsClient(completions=[])) + processor = FireworksTrainingRolloutProcessor() + processor.setup() + + config = _StubConfig() + del 
config.completion_params["model"] + + with pytest.raises(ValueError, match="completion_params.model"): + tasks = processor([_make_row()], config) + await tasks[0] + + +@pytest.mark.asyncio +async def test_invalid_n_raises(monkeypatch: pytest.MonkeyPatch) -> None: + """``n < 1`` is rejected.""" + _install_fake_client(monkeypatch, _FakeCompletionsClient(completions=[])) + processor = FireworksTrainingRolloutProcessor() + processor.setup() + + config = _StubConfig() + config.completion_params["n"] = 0 + + with pytest.raises(ValueError, match="n must be >= 1"): + tasks = processor([_make_row()], config) + await tasks[0] + + +@pytest.mark.asyncio +async def test_acleanup_closes_cached_clients(monkeypatch: pytest.MonkeyPatch) -> None: + """Every cached client should be ``.close()``-ed on cleanup.""" + completions = [ + { + "choices": [{"message": {"role": "assistant", "content": "x"}, "finish_reason": "stop"}], + "prompt_ids": [1], + "completion_ids": [2], + "completion_logprobs": [0.0], + "finish_reason": "stop", + } + ] + client = _FakeCompletionsClient(completions=completions) + _install_fake_client(monkeypatch, client) + + processor = FireworksTrainingRolloutProcessor() + processor.setup() + tasks = processor([_make_row()], _StubConfig(n=1)) + await tasks[0] + assert client.close_calls == 0 + + await processor.acleanup() + assert client.close_calls == 1 + + +@pytest.mark.asyncio +async def test_preserves_existing_extra(monkeypatch: pytest.MonkeyPatch) -> None: + """Pre-existing ``execution_metadata.extra`` keys must not be clobbered.""" + completions = [ + { + "choices": [{"message": {"role": "assistant", "content": "x"}, "finish_reason": "stop"}], + "prompt_ids": [1], + "completion_ids": [2], + "completion_logprobs": [0.0], + "finish_reason": "stop", + } + ] + client = _FakeCompletionsClient(completions=completions) + _install_fake_client(monkeypatch, client) + + processor = FireworksTrainingRolloutProcessor() + processor.setup() + + row = _make_row() + row.execution_metadata.extra = {"my_custom_field": "hello"} + tasks = processor([row], _StubConfig(n=1)) + result = await tasks[0] + + assert result.execution_metadata.extra["my_custom_field"] == "hello" + assert result.execution_metadata.extra["prompt_ids"] == [1]