async def async_get_status(self, session: aiohttp.ClientSession, rollout_id: str) -> Optional[Dict[str, Any]]:
    """Fetch rollout status from the lightweight /status endpoint.

    ``{base_url}/v1/status`` is tried first, then ``{base_url}/status``. A URL
    is abandoned (and the next one tried) when it answers 404, answers with
    any other HTTP error, times out, returns a body that is not valid JSON,
    or returns JSON that is not an object.

    Args:
        session: Open aiohttp client session used to issue the GET requests.
        rollout_id: Rollout identifier, sent as the ``rollout_id`` query
            parameter.

    Returns:
        The parsed JSON object, ``{}`` when the body parses to an empty/null
        value, or ``None`` when no URL produced a usable response.
        Response shape: {"rollout_id": "...", "status": {"code": ...} | null, "extras": {...} | null}
    """
    headers = {
        "Authorization": f"Bearer {self._get_api_key()}",
        "User-Agent": get_user_agent(),
    }
    params: Dict[str, Any] = {"rollout_id": rollout_id}
    timeout = aiohttp.ClientTimeout(total=self.timeout)

    urls_to_try = [f"{self.base_url}/v1/status", f"{self.base_url}/status"]
    last_error: Optional[str] = None
    for url in urls_to_try:
        try:
            async with session.get(url, params=params, headers=headers, timeout=timeout) as resp:
                if resp.status == 404:
                    last_error = f"404 for {url}"
                    continue
                resp.raise_for_status()
                # content_type=None: parse the body as JSON regardless of the
                # Content-Type header the server sent.
                payload = await resp.json(content_type=None)
        except (aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError) as e:
            last_error = str(e)
            continue

        if not payload:
            # Empty/null body: normalise to an empty mapping.
            return {}
        if isinstance(payload, dict):
            return payload
        # Callers use dict methods on the result, so a JSON array/string/number
        # is handled like any other malformed response instead of being
        # returned and failing later at the call site.
        last_error = f"unexpected {type(payload).__name__} payload from {url}"

    if last_error:
        logger.error("Failed to fetch status from Fireworks (tried %s): %s", urls_to_try, last_error)
    return None
self._tracing_adapter.async_search_logs( - session, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"] + status_result = await self._tracing_adapter.async_get_status( + session, + rollout_id=row.execution_metadata.rollout_id, ) - # Filter for logs that actually have status information - status_logs = [] - for log in completed_logs: - status_dict = log.get("status") - if status_dict and isinstance(status_dict, dict) and "code" in status_dict: - status_logs.append(log) - - if status_logs: - if len(status_logs) > 1: - logger.warning( - "Found %s status logs for rollout %s; expected at most 1. Using the first one: %s", - len(status_logs), - row.execution_metadata.rollout_id, - status_logs[0], - ) - # Use the first log with status information - status_log = status_logs[0] - status_dict = status_log.get("status") - raw_extras = status_log.get("extras") or {} - status_extras = { - k: v for k, v in raw_extras.items() if k not in ("logger_name", "level", "timestamp") - } + status = (status_result or {}).get("status") + if isinstance(status, dict) and "code" in status: + status_code = status["code"] + if status_code == Status.Code.RUNNING: + await asyncio.sleep(poll_interval) + continue logger.info( - f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}" + "Found status for rollout %s with code %s", + row.execution_metadata.rollout_id, + status_code, ) - status_code = status_dict.get("code") - status_message = status_dict.get("message", "") - status_details = status_dict.get("details", []) + status_message = status.get("message", "") or "" + status_details = status.get("details", []) or [] - logger.info( - f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}" - ) - - # Create and raise exception if appropriate, preserving original message exception = exception_for_status_code(status_code, status_message) if exception is not None: raise exception @@ -171,10 +152,12 @@ async 
def _process_row(row: EvaluationRow) -> EvaluationRow: details=status_details, ) - if row.execution_metadata.extra: - row.execution_metadata.extra.update(status_extras) - else: - row.execution_metadata.extra = status_extras + status_extras = (status_result or {}).get("extras") + if isinstance(status_extras, dict): + if row.execution_metadata.extra: + row.execution_metadata.extra.update(status_extras) + else: + row.execution_metadata.extra = status_extras logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id) break