feat: add FireworksTrainingRolloutProcessor for RFT (FIR2-1351) #445
Closed
Conversation
A new default `RolloutProcessor` that drives Fireworks `/v1/completions` via `FireworksV1CompletionsClient` and surfaces the per-sample token-level data required by reinforcement fine-tuning training (GRPO, CISPO, DAPO, GSPO).

Problem
-------
The existing `SingleTurnRolloutProcessor` uses LiteLLM chat completions and discards token ids and inference logprobs, so scored `EvaluationRow`s are fine for evaluation but cannot feed a training loop. Today, teams that need training-ready rollouts write a bespoke `RolloutProcessor` (the FrozenLake example in fw-ai/cookbook is ~800 lines). This puts token ids / logprobs out of reach of every customer evaluator bundle unless they rewrite their own processor.

What it does
------------
For each `EvaluationRow`, `FireworksTrainingRolloutProcessor`:

* Reads model / temperature / max_tokens / n from `completion_params`.
* Builds prompt token ids locally via `FireworksV1CompletionsClient`.
* Fires `n` parallel `/v1/completions` calls from the same `prompt_token_ids`, so each completion gets independent retry behaviour rather than collapsing on partial server failures.
* Appends the first completion as the assistant message so existing evaluators that inspect `last_assistant_message()` keep scoring.
* Populates `EvaluationRow.execution_metadata.extra` with:
  - `prompt_ids: list[int]` (shared across completions)
  - `completion_ids: list[list[int]]` (per-completion)
  - `inference_logprobs: list[list[float]]` (aligned to completion tokens)
  - `completions_text: list[str]`
  - `truncated: list[bool]` (`finish_reason == 'length'`)
  - `finish_reasons: list[str]`
* Merges into pre-existing `extra` rather than clobbering it.
* Caches one client per model id; closes them all via `acleanup()`.

Shape rationale
---------------
`OpenEnvRolloutProcessor` already writes flat `prompt_ids` / `completion_ids` concatenated across turns (multi-turn, per-episode agent rollouts). Single-turn RFT samples n>1 completions per prompt for advantage estimation and needs per-completion indexing, hence the `list[list[...]]` shape here. The training adapter on the consumer side can key into either convention without loss of generality.

Tests
-----
8 new unit tests stub `FireworksV1CompletionsClient` so no network calls or tokenizers are needed; the existing `SingleTurnRolloutProcessor` suite still passes.

Fixes FIR2-1351
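The non-clobbering merge into `execution_metadata.extra` described above can be sketched as follows. This is a hypothetical helper, not the PR's actual code; the field names come from the description, while the completion dicts and their keys are assumptions for illustration.

```python
# Hypothetical sketch of the extra-merge behaviour. The completion
# dicts ("token_ids", "logprobs", "text", "finish_reason") are assumed
# shapes, not the real FireworksV1CompletionsClient return type.
def merge_training_extra(existing_extra, prompt_ids, completions):
    merged = dict(existing_extra or {})  # keep keys written by other processors
    merged.update({
        "prompt_ids": prompt_ids,                                    # list[int], shared
        "completion_ids": [c["token_ids"] for c in completions],     # list[list[int]]
        "inference_logprobs": [c["logprobs"] for c in completions],  # list[list[float]]
        "completions_text": [c["text"] for c in completions],
        "truncated": [c["finish_reason"] == "length" for c in completions],
        "finish_reasons": [c["finish_reason"] for c in completions],
    })
    return merged
```

Because the merge starts from a copy of the existing dict, keys written by other processors (tracing, OpenEnv) survive alongside the training payload.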
Contributor
Author
Closing: not needed. Upstream review revealed that the Fireworks managed RFT cutover does not require a training-aware `RolloutProcessor` in eval-protocol.

That pattern works for CISPO (and by extension GRPO/DAPO/GSPO) today, and it keeps eval-protocol's rollout contract single-purpose. No reason to push a training-specific default into EP. No harm done — the module is self-contained and not wired into anything else.
Description
Adds a new default `RolloutProcessor` subclass that drives Fireworks `/v1/completions` via `FireworksV1CompletionsClient` and surfaces the per-sample token ids, completion ids, and inference logprobs required by reinforcement-fine-tuning training loops (GRPO, CISPO, DAPO, GSPO).

Problem

`SingleTurnRolloutProcessor` uses LiteLLM chat completions and discards token-level data, so scored `EvaluationRow`s are fine for evaluation but cannot feed a training loop. Today, teams that need training-ready rollouts write a bespoke `RolloutProcessor` — the FrozenLake example in fw-ai/cookbook is ~800 lines. This puts training-compatible rollouts out of reach of every customer evaluator bundle unless they reimplement the rollout path themselves. Fireworks' managed RFT flow needs this for every customer job, so promoting the pattern into an Eval Protocol default removes the per-customer 800-line tax.
What it does
For each `EvaluationRow`, `FireworksTrainingRolloutProcessor`:

- Reads `model` / `temperature` / `max_tokens` / `n` from `completion_params`.
- Builds prompt token ids locally via `FireworksV1CompletionsClient.build_prompt_token_ids(...)`.
- Fires `n` parallel `/v1/completions` calls from the same `prompt_token_ids`, so each completion gets independent retry behaviour rather than collapsing on partial server failures.
- Appends the first completion as the assistant message so existing evaluators that inspect `last_assistant_message()` keep scoring without modification.
- Populates `EvaluationRow.execution_metadata.extra` with the per-completion payload:
  - `prompt_ids: list[int]` (shared across completions)
  - `completion_ids: list[list[int]]` (one per completion)
  - `inference_logprobs: list[list[float]]` (aligned to completion tokens)
  - `completions_text: list[str]`
  - `truncated: list[bool]` (True when `finish_reason == 'length'`)
  - `finish_reasons: list[str]`
- Merges into pre-existing `extra` rather than clobbering it (coexists with `OpenEnvRolloutProcessor`, tracing_utils, etc.).
- Caches one `FireworksV1CompletionsClient` per model id; closes them all via `acleanup()`.
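The independent-retry behaviour in the list above could look roughly like this. It is a sketch under the assumption of a `client.complete(...)` coroutine; the real `FireworksV1CompletionsClient` API may differ.

```python
import asyncio

# Sketch: fire n independent /v1/completions calls from one prompt.
# Each call retries on its own, so a transient failure in one call does
# not discard the other n - 1 completions. The client interface here is
# an assumption for illustration, not the actual Fireworks client.
async def sample_n(client, prompt_token_ids, n, retries=3, **params):
    async def one_call():
        for attempt in range(retries):
            try:
                return await client.complete(prompt_token_ids, **params)
            except ConnectionError:
                await asyncio.sleep(0.05 * 2 ** attempt)  # simple backoff
        raise RuntimeError("completion failed after retries")

    # gather keeps the calls concurrent and independent:
    # one prompt, n sampled completions
    return await asyncio.gather(*(one_call() for _ in range(n)))
```

Retrying inside each call, rather than around the whole batch, is what avoids collapsing on partial server failures: a single flaky request is retried alone while its siblings proceed.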
`OpenEnvRolloutProcessor` already writes flat `prompt_ids` / `completion_ids` concatenated across turns (multi-turn, per-episode agent rollouts). Single-turn RFT samples `n > 1` completions per prompt for advantage estimation and needs per-completion indexing, hence the `list[list[...]]` shape here. A training adapter on the consumer side can key into either convention without loss of generality.

Architecture
```mermaid
flowchart LR
    Row[EvaluationRow<br/>messages, tools] -->|messages → dicts| P[FireworksTrainingRolloutProcessor]
    P -->|build_prompt_token_ids| Client[FireworksV1CompletionsClient]
    Client -->|/v1/completions × n| API[Fireworks API]
    API -->|n completions<br/>with prompt_ids,<br/>completion_ids, logprobs| P
    P -->|first completion<br/>→ assistant message| OutMsgs[row.messages]
    P -->|execution_metadata.extra| Extra[prompt_ids, completion_ids,<br/>inference_logprobs, completions_text,<br/>truncated, finish_reasons]
```

Type of Change
Testing
8 new unit tests in `tests/pytest/test_fireworks_training_rollout_processor.py`, using a stub `FireworksV1CompletionsClient` so no network calls or HF tokenizers are required:

- The `extra` payload has `n`-length lists with correct shapes and values (`n = 2` case).
- `n = 1` still produces the list-of-lists shape with length 1 (not a scalar).
- `assistant` messages are dropped by default before sampling.
- `assistant` messages are preserved when the flag is disabled.
- A missing `model` in `completion_params` raises `ValueError`.
- `n < 1` raises `ValueError`.
- `acleanup()` closes every cached client.
- Pre-existing entries in `execution_metadata.extra` are preserved across rollout.

All existing `SingleTurnRolloutProcessor` tests still pass.

Surface
- New: `eval_protocol.pytest.FireworksTrainingRolloutProcessor`.
- No changes to the `RolloutProcessor` base class or the `EvaluationRow` schema — all new data is carried through the existing `execution_metadata.extra` bag.

Follow-ups (not in this PR)
- FIR2-1352.
- `execution_metadata.extra` shape introduced here — tracked as FIR2-1353.
- FIR2-1366.
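As a consumer-side illustration of the "Shape rationale" section above, a hypothetical training adapter could normalise both metadata conventions (flat per-episode lists from `OpenEnvRolloutProcessor`, nested per-completion lists from this processor) into uniform prompt/completion pairs. This adapter is not part of this PR; it is a sketch of the "key into either convention" claim.

```python
# Hypothetical adapter sketch, not part of this PR. It detects which
# convention a row's extra payload uses by inspecting the first element
# of completion_ids: ints mean the flat per-episode convention, lists
# mean the nested per-completion convention.
def iter_training_samples(extra):
    prompt_ids = extra["prompt_ids"]
    completion_ids = extra["completion_ids"]
    if completion_ids and isinstance(completion_ids[0], int):
        # flat convention: one concatenated completion per episode
        yield prompt_ids, completion_ids
    else:
        # nested convention: n sampled completions sharing one prompt
        for ids in completion_ids:
            yield prompt_ids, ids
```

Either way the trainer sees `(prompt_ids, completion_ids)` pairs, which is why the two processors can keep different on-row shapes without loss of generality.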