Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,37 @@ async def async_search_logs(
)
return results

async def async_get_status(self, session: aiohttp.ClientSession, rollout_id: str) -> Optional[Dict[str, Any]]:
    """Fetch rollout status from the lightweight /status endpoint.

    Tries ``{base_url}/v1/status`` first, then ``{base_url}/status``. A URL is
    skipped (and the next one tried) when it answers 404, fails with an aiohttp
    client error or timeout, returns a body that is not valid JSON, or returns
    JSON that is not an object.

    Args:
        session: aiohttp client session used to issue the GET requests.
        rollout_id: Sent as the ``rollout_id`` query parameter.

    Returns:
        The parsed JSON object (``{}`` when the body parses to a falsy value),
        or None if no candidate URL produced a usable response.
        Response shape: {"rollout_id": "...", "status": {"code": ...} | null, "extras": {...} | null}
    """
    headers = {
        "Authorization": f"Bearer {self._get_api_key()}",
        "User-Agent": get_user_agent(),
    }
    params: Dict[str, Any] = {"rollout_id": rollout_id}
    timeout = aiohttp.ClientTimeout(total=self.timeout)

    urls_to_try = [f"{self.base_url}/v1/status", f"{self.base_url}/status"]
    last_error: Optional[str] = None
    for url in urls_to_try:
        try:
            async with session.get(url, params=params, headers=headers, timeout=timeout) as resp:
                if resp.status == 404:
                    last_error = f"404 for {url}"
                    continue
                resp.raise_for_status()
                # content_type=None: parse the body as JSON regardless of the
                # Content-Type header the server sent.
                payload = await resp.json(content_type=None)
        except (aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError) as e:
            # Some exceptions (notably timeouts) stringify to "", which would
            # otherwise make the failure invisible in the log below.
            last_error = str(e) or f"{type(e).__name__} for {url}"
            continue

        if not payload:
            return {}
        if not isinstance(payload, dict):
            # Callers treat the result as a mapping; a list/scalar body is
            # handled as a failed attempt rather than returned.
            last_error = f"unexpected {type(payload).__name__} payload from {url}"
            continue
        return payload

    if last_error is not None:
        logger.error("Failed to fetch status from Fireworks (tried %s): %s", urls_to_try, last_error)
    return None

def get_evaluation_rows(
self,
tags: List[str],
Expand Down
57 changes: 20 additions & 37 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,45 +122,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:

while time.time() < deadline:
session = self._get_or_create_session()
completed_logs = await self._tracing_adapter.async_search_logs(
session, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
status_result = await self._tracing_adapter.async_get_status(
session,
rollout_id=row.execution_metadata.rollout_id,
)
# Filter for logs that actually have status information
status_logs = []
for log in completed_logs:
status_dict = log.get("status")
if status_dict and isinstance(status_dict, dict) and "code" in status_dict:
status_logs.append(log)

if status_logs:
if len(status_logs) > 1:
logger.warning(
"Found %s status logs for rollout %s; expected at most 1. Using the first one: %s",
len(status_logs),
row.execution_metadata.rollout_id,
status_logs[0],
)
# Use the first log with status information
status_log = status_logs[0]
status_dict = status_log.get("status")
raw_extras = status_log.get("extras") or {}
status_extras = {
k: v for k, v in raw_extras.items() if k not in ("logger_name", "level", "timestamp")
}
status = (status_result or {}).get("status")
if isinstance(status, dict) and "code" in status:
status_code = status["code"]
if status_code == Status.Code.RUNNING:
await asyncio.sleep(poll_interval)
continue

logger.info(
f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}"
"Found status for rollout %s with code %s",
row.execution_metadata.rollout_id,
status_code,
)

status_code = status_dict.get("code")
status_message = status_dict.get("message", "")
status_details = status_dict.get("details", [])
status_message = status.get("message", "") or ""
status_details = status.get("details", []) or []

logger.info(
f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}"
)

# Create and raise exception if appropriate, preserving original message
exception = exception_for_status_code(status_code, status_message)
if exception is not None:
raise exception
Expand All @@ -171,10 +152,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
details=status_details,
)

if row.execution_metadata.extra:
row.execution_metadata.extra.update(status_extras)
else:
row.execution_metadata.extra = status_extras
status_extras = (status_result or {}).get("extras")
if isinstance(status_extras, dict):
if row.execution_metadata.extra:
row.execution_metadata.extra.update(status_extras)
else:
row.execution_metadata.extra = status_extras
Comment thread
xzrderek marked this conversation as resolved.

logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
break
Expand Down
Loading