diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index d646600426..8ef9ec1443 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -648,6 +648,35 @@ async def test_async_bidi_stream_query(self): events.append(event) assert len(events) == 1 + @pytest.mark.asyncio + async def test_async_bidi_stream_query_with_state(self): + app = reasoning_engines.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + request_queue = asyncio.Queue() + request_dict = { + "user_id": _TEST_USER_ID, + "state": {"test_key": "test_val"}, + "live_request": { + "input": "What is the exchange rate from USD to SEK?", + }, + } + + await request_queue.put(request_dict) + await request_queue.put(None) # Sentinel to end the stream. + + with mock.patch.object( + app, "async_create_session", wraps=app.async_create_session + ) as mock_create_session: + async for _ in app.bidi_stream_query(request_queue): + pass + mock_create_session.assert_called_once_with( + user_id=_TEST_USER_ID, state={"test_key": "test_val"} + ) + def test_create_session(self): app = reasoning_engines.AdkApp( agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 8d970d66f5..4a3a62e3bb 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -1178,7 +1178,8 @@ async def bidi_stream_query( if not self._tmpl_attrs.get("runner"): self.set_up() if not session_id: - session = await self.async_create_session(user_id=user_id) + state = first_request.get("state") + session = await self.async_create_session(user_id=user_id, state=state) session_id = session.id run_config = _validate_run_config(run_config)