Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 38 additions & 39 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import os

from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message
from eval_protocol.tracing import PayloadType, decode_payloads
from .base import BaseAdapter
from .lp_deserializer import decompress_and_parse_lp
from .r3_deserializer import decompress_and_parse_r3
from .utils import extract_messages_from_data
from ..common_utils import get_user_agent

Expand Down Expand Up @@ -102,45 +101,45 @@
):
break # Break early if we've found all the metadata we need

# Extract router replay payloads when present
# Decode out-of-band gateway payloads (router replay, logprobs, prompt
# token ids) via the standalone tracing decoder registry, then map the
# decoded values onto the row. Format/decoding lives in
# ``eval_protocol.tracing``; this adapter only does EvaluationRow glue.
payloads = trace.get("payloads")
if isinstance(payloads, dict):
router_replay = payloads.get("router_replay")
if isinstance(router_replay, dict) and router_replay.get("data"):
try:
matrices, r3_meta = decompress_and_parse_r3(router_replay["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["routing_matrices"] = matrices
execution_metadata.extra["routing_metadata"] = r3_meta
except Exception as e:
logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e)

logprobs_payload = payloads.get("logprobs")
if isinstance(logprobs_payload, dict) and logprobs_payload.get("data"):
try:
logprobs, token_ids, lp_meta = decompress_and_parse_lp(logprobs_payload["data"])
if execution_metadata.extra is None:
execution_metadata.extra = {}
execution_metadata.extra["completion_logprobs"] = logprobs
if token_ids is not None:
execution_metadata.extra["completion_token_ids"] = token_ids
execution_metadata.extra["logprobs_metadata"] = lp_meta

for i in range(len(messages) - 1, -1, -1):
if messages[i].role == "assistant":
content_entries = [{"logprob": lp} for lp in logprobs]
if token_ids is not None:
for entry, tid in zip(content_entries, token_ids):
entry["token_id"] = tid
messages[i].logprobs = {"content": content_entries}
break
except Exception as e:
logger.warning(
"Failed to decompress logprobs payload for trace %s: %s",
trace.get("id"),
e,
)
decoded = decode_payloads(
payloads,
on_error=lambda pt, e: logger.warning(
"Failed to decode %s payload for trace %s: %s", pt.value, trace.get("id"), e
),
)
if decoded and execution_metadata.extra is None:
execution_metadata.extra = {}

if (dp := decoded.get(PayloadType.ROUTER_REPLAY)) is not None:
execution_metadata.extra["routing_matrices"] = dp.value

Check failure on line 120 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
execution_metadata.extra["routing_metadata"] = dp.metadata

Check failure on line 121 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)

if (dp := decoded.get(PayloadType.LOGPROBS)) is not None:
logprobs = dp.value
token_ids = dp.extras.get("token_ids")
execution_metadata.extra["completion_logprobs"] = logprobs

Check failure on line 126 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
if token_ids is not None:
execution_metadata.extra["completion_token_ids"] = token_ids

Check failure on line 128 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
execution_metadata.extra["logprobs_metadata"] = dp.metadata

Check failure on line 129 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)

for i in range(len(messages) - 1, -1, -1):
if messages[i].role == "assistant":
content_entries = [{"logprob": lp} for lp in logprobs]
if token_ids is not None:
for entry, tid in zip(content_entries, token_ids):
entry["token_id"] = tid
messages[i].logprobs = {"content": content_entries}
break

if (dp := decoded.get(PayloadType.PROMPT_TOKEN_IDS)) is not None:
execution_metadata.extra["prompt_token_ids"] = dp.value

Check failure on line 141 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)
execution_metadata.extra["prompt_token_ids_metadata"] = dp.metadata

Check failure on line 142 in eval_protocol/adapters/fireworks_tracing.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Object of type "None" is not subscriptable (reportOptionalSubscript)

return EvaluationRow(
messages=messages,
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def _merge_payloads_into_longest_row(longest_row: EvaluationRow, rows: List[Eval
for key in (
"completion_logprobs",
"completion_token_ids",
"prompt_token_ids",
"prompt_token_ids_metadata",
"logprobs_metadata",
"routing_matrices",
"routing_metadata",
Expand Down
109 changes: 109 additions & 0 deletions eval_protocol/tracing/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# `eval_protocol.tracing` — Fireworks tracing-gateway payload decoders

Standalone helpers for decoding the out-of-band **payloads** the Fireworks
tracing gateway stores alongside a trace (prompt token IDs, completion logprobs,
router-replay routing matrices).

This package is intentionally self-contained: it depends only on the stdlib and
`zstandard`. It does **not** import `EvaluationRow`, rollout processors, or any
other Eval Protocol machinery, so you can use it even if you are not using EP for
rollouts — just point at it for extracting gateway payloads.

## What is a "payload"?

When you read a trace with payloads included:

```
GET {gateway}/v1/traces?rollout_id=...&include_payloads=true
```

each trace carries a `payloads` object like:

```json
{
"payloads": {
"prompt_token_ids": {
"manifest": { "PayloadVersion": "pti/v1", "...": "..." },
"data": "<base64 of zstd-compressed bytes>"
},
"logprobs": { "manifest": { "PayloadVersion": "lp/v1" }, "data": "..." },
"router_replay": { "manifest": { "PayloadVersion": "r3/v1" }, "data": "..." }
}
}
```

The `data` field is `base64(zstd(raw_bytes))`. Each payload type has its own
`raw_bytes` encoding (`pti/v1` is a JSON int array; `lp/v1` and `r3/v1` are packed
binary). This package hides all of that.

## Usage

Decode everything at once (the common case):

```python
from eval_protocol.tracing import decode_payloads, PayloadType

decoded = decode_payloads(trace["payloads"])

if PayloadType.PROMPT_TOKEN_IDS in decoded:
token_ids = decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int]

if PayloadType.LOGPROBS in decoded:
lp = decoded[PayloadType.LOGPROBS]
logprobs = lp.value # List[float]
token_ids = lp.extras.get("token_ids") # Optional[List[int]]

if PayloadType.ROUTER_REPLAY in decoded:
matrices = decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]]
```

If you have the whole trace dict, `decode_trace(trace)` reaches into
`trace["payloads"]` for you.

Decode a single payload:

```python
from eval_protocol.tracing import decode_payload, PayloadType

dp = decode_payload(PayloadType.PROMPT_TOKEN_IDS, trace["payloads"]["prompt_token_ids"]["data"])
dp.value # List[int]
```

### Error handling

`decode_payloads` isolates per-payload failures: if one payload fails to decode,
the others are still returned. Pass `on_error=callback(payload_type, exc)` to
control logging (defaults to a warning):

```python
decode_payloads(payloads, on_error=lambda pt, e: print(f"{pt} failed: {e}"))
```

## Return type

`decode_payloads` / `decode_trace` return `Dict[PayloadType, DecodedPayload]`.

`DecodedPayload` fields:

| field | meaning |
|----------------|-------------------------------------------------------------------|
| `payload_type` | `PayloadType` enum member |
| `value` | decoded value (type depends on `payload_type`, see below) |
| `metadata` | decoded header/manifest metadata (token counts, scope, etc.) |
| `extras` | type-specific extras (e.g. logprobs `token_ids`) |

`value` by type:

| `PayloadType` | `value` | notes |
|---------------------|--------------------------|----------------------------------------------|
| `PROMPT_TOKEN_IDS` | `List[int]` | prompt token ids |
| `LOGPROBS` | `List[float]` | per completion token; ids in `extras["token_ids"]` (or `None`) |
| `ROUTER_REPLAY` | `List[Optional[str]]` | per-token base64 routing matrices; `None` where absent |

## Adding a new payload type

1. Add a member to `PayloadType` in `types.py`.
2. Add a `decode_<name>(data_b64) -> DecodedPayload` function in a new module.
3. Register it in `PAYLOAD_DECODERS` in `registry.py`.

`decode_payloads` picks it up automatically.
43 changes: 43 additions & 0 deletions eval_protocol/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Decode Fireworks tracing-gateway payloads.

Standalone, dependency-light helpers (stdlib + ``zstandard`` only) for turning
the binary/JSON ``payloads`` returned by the Fireworks tracing gateway
(``GET /traces?include_payloads=true``) into Python values. No EvaluationRow or
rollout machinery required -- usable on its own.

Typical use::

from eval_protocol.tracing import decode_payloads, PayloadType

decoded = decode_payloads(trace["payloads"])
decoded[PayloadType.PROMPT_TOKEN_IDS].value # List[int]
decoded[PayloadType.LOGPROBS].value # List[float]
decoded[PayloadType.ROUTER_REPLAY].value # List[Optional[str]]

See ``README.md`` in this package for details.
"""

from __future__ import annotations

from .logprobs import decode_logprobs
from .prompt_token_ids import decode_prompt_token_ids
from .registry import (
PAYLOAD_DECODERS,
decode_payload,
decode_payloads,
decode_trace,
)
from .router_replay import decode_router_replay
from .types import DecodedPayload, PayloadType

__all__ = [
"PayloadType",
"DecodedPayload",
"PAYLOAD_DECODERS",
"decode_payload",
"decode_payloads",
"decode_trace",
"decode_prompt_token_ids",
"decode_logprobs",
"decode_router_replay",
]
17 changes: 17 additions & 0 deletions eval_protocol/tracing/_decompress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Shared helper for the Fireworks tracing-gateway payload decoders."""

from __future__ import annotations

import base64

import zstandard as zstd


def decompress_b64(data_b64: str) -> bytes:
"""Base64-decode then zstd-decompress a gateway ``payloads.*.data`` blob.

The gateway stores every payload as ``base64(zstd(raw_bytes))``; this is the
common first step every decoder shares before interpreting ``raw_bytes``.
"""
compressed = base64.b64decode(data_b64)
return zstd.ZstdDecompressor().decompress(compressed)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import zstandard as zstd

from .types import DecodedPayload, PayloadType

MAGIC = b"LP01"
HEADER_VERSION = 1
MISSING_TOKEN_ID = -1
Expand Down Expand Up @@ -107,3 +109,18 @@ def decompress_and_parse_lp(data_b64: str) -> Tuple[List[float], Optional[List[i
decompressor = zstd.ZstdDecompressor()
raw = decompressor.decompress(compressed)
return parse_logprobs(raw)


def decode_logprobs(data_b64: str) -> DecodedPayload:
"""Decode a gateway ``payloads.logprobs.data`` blob into a ``DecodedPayload``.

``value`` is the per-completion-token logprob list; per-token ids (when all
valid) are available under ``extras["token_ids"]``.
"""
logprobs, token_ids, metadata = decompress_and_parse_lp(data_b64)
return DecodedPayload(
payload_type=PayloadType.LOGPROBS,
value=logprobs,
metadata=metadata,
extras={"token_ids": token_ids},
)
31 changes: 31 additions & 0 deletions eval_protocol/tracing/prompt_token_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""``pti/v1`` decoder for prompt token ID payloads.

Inverse of the tracing gateway's ``serialize_prompt_token_ids``: the gateway
stores prompt token IDs as ``base64(zstd(json.dumps(token_ids)))`` -- a compact
JSON int array, no bespoke binary header.
"""

from __future__ import annotations

import json
from typing import Any, Dict, List, Tuple

from ._decompress import decompress_b64
from .types import DecodedPayload, PayloadType


def parse_prompt_token_ids(raw: bytes) -> Tuple[List[int], Dict[str, Any]]:
"""Parse uncompressed ``pti/v1`` bytes (a JSON int array) into ids + metadata."""
token_ids = json.loads(raw)
metadata: Dict[str, Any] = {"scope": "prompt_only", "token_count": len(token_ids)}
return token_ids, metadata


def decode_prompt_token_ids(data_b64: str) -> DecodedPayload:
"""Decode a gateway ``payloads.prompt_token_ids.data`` blob."""
token_ids, metadata = parse_prompt_token_ids(decompress_b64(data_b64))
return DecodedPayload(
payload_type=PayloadType.PROMPT_TOKEN_IDS,
value=token_ids,
metadata=metadata,
)
Loading
Loading