2 changes: 1 addition & 1 deletion backend/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
"gunicorn==23.0.0",
"uvicorn[standard]==0.40.0",
"websockets==15.0.1",
"requests==2.32.5",
"requests==2.33.0",
"itsdangerous==2.2.0",
"Pillow==12.1.1",
"drf-spectacular==0.29.0",
9 changes: 5 additions & 4 deletions backend/uv.lock

Some generated files are not rendered by default.
@@ -163,12 +163,14 @@ def get_model_string(model: str | None = None) -> str:

 @lru_cache(maxsize=1)
 def check_lm_ready_or_raise() -> None:
+    from baserow_enterprise.assistant.retrying_model import _resolve_model
+
     model = get_model_string()
     test_agent = Agent(
         output_type=str, instructions="Respond with 'ok'.", name="test_agent"
     )
     try:
-        test_agent.run_sync("Test", model=model)
+        test_agent.run_sync("Test", model=_resolve_model(model))
     except Exception as e:
         raise AssistantModelNotSupportedError(
             f"The model '{model}' is not supported or accessible: {e}"
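For context, a hedged sketch of the call pattern this hunk introduces. The model string is illustrative, not part of the PR, and running it assumes valid credentials are set:

from pydantic_ai import Agent

from baserow_enterprise.assistant.retrying_model import _resolve_model

agent = Agent(output_type=str, instructions="Respond with 'ok'.", name="probe")
# Passing a resolved Model keeps credential lookup inside _resolve_model();
# passing the bare string would make pydantic-ai read only the provider's
# native env vars and miss the deprecated UDSPY_LM_API_KEY fallback.
agent.run_sync("Test", model=_resolve_model("anthropic:claude-sonnet-4-5"))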
165 changes: 148 additions & 17 deletions enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py
@@ -33,8 +33,9 @@

 import asyncio
 import json
+import os
 import re
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from contextlib import asynccontextmanager
 from datetime import datetime, timezone
 from typing import Any
@@ -172,29 +172,159 @@ def _try_recover_tool_use_failed(exc: Exception) -> ModelResponse | None:
     return _recover_failed_generation(failed_gen, model_name)


+# ---------------------------------------------------------------------------
+# Provider credential resolution
+# ---------------------------------------------------------------------------
+# Maps provider prefixes to their native env-var names.
+#
+# Backward-compat: when a provider-specific var is not set we fall back to
+# the deprecated UDSPY_LM_* vars so existing deployments keep working.
+# This compat layer is intentionally minimal — new providers should NOT be
+# added here; operators should use the standard env vars instead.
+
+_PROVIDER_ENV: dict[str, dict[str, str | None]] = {
+    "openai": {
+        "api_key": "OPENAI_API_KEY",
+        "base_url": "OPENAI_BASE_URL",
+    },
+    "groq": {
+        "api_key": "GROQ_API_KEY",
+    },
+    "anthropic": {
+        "api_key": "ANTHROPIC_API_KEY",
+    },
+    "ollama": {
+        "base_url": "OLLAMA_BASE_URL",
+    },
+}
+
+
+def _resolve_credentials(provider: str) -> dict[str, str | None]:
+    """Return ``{"api_key": ..., "base_url": ...}`` for *provider*.
+
+    Checks the provider-specific env var first, then falls back to the
+    deprecated ``UDSPY_LM_API_KEY`` / ``UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL``
+    for backward-compat. Never touches ``os.environ``.
+    """
+
+    env = _PROVIDER_ENV.get(provider, {})
+    api_key_var = env.get("api_key")
+    base_url_var = env.get("base_url")
+
+    api_key = (
+        (os.getenv(api_key_var) if api_key_var else None)
+        or os.getenv("UDSPY_LM_API_KEY")
+        or None
+    )
+    base_url = (
+        (os.getenv(base_url_var) if base_url_var else None)
+        or os.getenv("UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL")
+        or None
+    )
+    return {"api_key": api_key, "base_url": base_url}


+# ---------------------------------------------------------------------------
+# Per-provider model factories
+# ---------------------------------------------------------------------------
+
+
+def _make_openai(name: str, creds: dict[str, str | None]) -> Model:
+    from pydantic_ai.models.openai import OpenAIChatModel
+    from pydantic_ai.providers.openai import OpenAIProvider
+
+    kwargs = {k: v for k, v in creds.items() if v is not None}
+    return OpenAIChatModel(name, provider=OpenAIProvider(**kwargs))
+
+
+def _make_groq(name: str, creds: dict[str, str | None]) -> Model:
+    from pydantic_ai.models.groq import GroqModel
+    from pydantic_ai.providers.groq import GroqProvider
+
+    return GroqModel(name, provider=GroqProvider(api_key=creds["api_key"]))
+
+
+def _make_anthropic(name: str, creds: dict[str, str | None]) -> Model:
+    from pydantic_ai.models.anthropic import AnthropicModel
+    from pydantic_ai.providers.anthropic import AnthropicProvider
+
+    return AnthropicModel(name, provider=AnthropicProvider(api_key=creds["api_key"]))
+
+
+def _make_ollama(name: str, creds: dict[str, str | None]) -> Model:
+    from pydantic_ai.models.openai import OpenAIChatModel
+    from pydantic_ai.providers.ollama import OllamaProvider
+
+    base_url = creds["base_url"] or "http://localhost:11434/v1"
+    return OpenAIChatModel(name, provider=OllamaProvider(base_url=base_url))
+
+
+def _make_google(name: str, creds: dict[str, str | None]) -> Model:
+    """Google models need a fresh httpx client per call to avoid event-loop
+    binding issues in Django async views.
+    See: https://github.com/pydantic/pydantic-ai/issues/3240
+    """
+
+    import httpx
+    from pydantic_ai.models.google import GoogleModel
+    from pydantic_ai.providers.google import GoogleProvider
+
+    return GoogleModel(
+        name,
+        provider=GoogleProvider(
+            api_key=creds["api_key"], http_client=httpx.AsyncClient()
+        ),
+    )
+
+
+def _make_google_vertex(name: str, creds: dict[str, str | None]) -> Model:
+    import httpx
+    from pydantic_ai.models.google import GoogleModel
+    from pydantic_ai.providers.google import GoogleProvider
+
+    return GoogleModel(
+        name,
+        provider=GoogleProvider(
+            api_key=creds["api_key"],
+            http_client=httpx.AsyncClient(),
+            vertexai=True,
+        ),
+    )
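The fresh-client requirement is easiest to see in isolation. A minimal sketch of the failure mode the linked issue describes; the URL is a placeholder and the exact error can vary:

import asyncio

import httpx

shared = httpx.AsyncClient()  # connection state binds to the first loop used

async def ping(client: httpx.AsyncClient) -> httpx.Response:
    return await client.get("https://example.com/")

asyncio.run(ping(shared))  # loop A: fine
# asyncio.run(ping(shared))  # loop B: can raise "Event loop is closed",
# because pooled connections still reference loop A. Django async views run
# on varying loops, hence one httpx.AsyncClient per GoogleModel construction.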


+_PROVIDER_FACTORIES: dict[str, Callable[[str, dict[str, str | None]], Model]] = {
+    "openai": _make_openai,
+    "groq": _make_groq,
+    "anthropic": _make_anthropic,
+    "ollama": _make_ollama,
+    "google-gla": _make_google,
+    "google": _make_google,
+    "google-vertex": _make_google_vertex,
+}
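The registry makes new providers a small, local addition. A hypothetical sketch for Mistral (not part of this PR; the module paths assume pydantic-ai's optional mistral support is installed):

def _make_mistral(name: str, creds: dict[str, str | None]) -> Model:
    from pydantic_ai.models.mistral import MistralModel
    from pydantic_ai.providers.mistral import MistralProvider

    return MistralModel(name, provider=MistralProvider(api_key=creds["api_key"]))

_PROVIDER_FACTORIES["mistral"] = _make_mistral
# Without a _PROVIDER_ENV entry, _resolve_credentials("mistral") would only
# return the deprecated UDSPY_LM_API_KEY fallback, so a MISTRAL_API_KEY
# mapping would belong there too.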


+# ---------------------------------------------------------------------------
+# Model resolution
+# ---------------------------------------------------------------------------
+
+
 def _resolve_model(model_name: str) -> Model:
     """Resolve a model name to a pydantic-ai Model instance.

-    For Google models, constructs the model with a fresh
-    ``httpx.AsyncClient`` instead of relying on ``infer_model()`` which
-    uses a process-global cached client. That cached client binds to the
-    event loop at creation time and breaks when reused on a different loop
-    (common in Django async views).
-    See: https://github.com/pydantic/pydantic-ai/issues/3240
+    Uses explicit provider construction with credential fallback to
+    ``UDSPY_LM_API_KEY`` / ``UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL``
+    so we never need to set ``os.environ``.
     """

-    if model_name.startswith(("google-gla:", "google:", "google-vertex:")):
-        import httpx
-        from pydantic_ai.models.google import GoogleModel
-        from pydantic_ai.providers.google import GoogleProvider
+    provider = model_name.split(":")[0] if ":" in model_name else "openai"
+    name = model_name.split(":", 1)[1] if ":" in model_name else model_name

-        vertexai = model_name.startswith("google-vertex:")
-        google_model_name = model_name.split(":", 1)[1]
-        return GoogleModel(
-            google_model_name,
-            provider=GoogleProvider(http_client=httpx.AsyncClient(), vertexai=vertexai),
-        )
+    factory = _PROVIDER_FACTORIES.get(provider)
+    if factory is not None:
+        creds = _resolve_credentials(provider)
+        return factory(name, creds)

+    # Unknown provider — let pydantic-ai handle it.
     return infer_model(model_name)


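A hedged usage sketch of the resolver's three paths; the model strings are illustrative, and constructing the models assumes the relevant credentials are available:

from baserow_enterprise.assistant.retrying_model import _resolve_model

m1 = _resolve_model("groq:llama-3.3-70b-versatile")  # known prefix -> GroqModel
m2 = _resolve_model("gpt-4o-mini")                   # no prefix -> "openai" default
m3 = _resolve_model("bedrock:some-model")            # unknown prefix -> infer_model()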
@@ -199,8 +199,11 @@ def _search(question: str) -> list[KnowledgeBaseChunk]:
f"Documentation context (source URL -> content):\n{context}"
)
from baserow_enterprise.assistant.model_profiles import get_model_string
from baserow_enterprise.assistant.retrying_model import _resolve_model

agent_result = await search_docs_agent.run(prompt, model=get_model_string())
agent_result = await search_docs_agent.run(
prompt, model=_resolve_model(get_model_string())
)
prediction = agent_result.output

sources = []
@@ -84,29 +84,13 @@ def setup(settings):
         float(_temp_raw) if _temp_raw else None
     )

-    # Backward compatibility: bridge old UDSPY_LM_* env vars so existing
-    # deployments continue to work without config changes.
+    # Backward compatibility: bridge old UDSPY_LM_MODEL to the new setting.
+    # Credential fallback (UDSPY_LM_API_KEY, UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL)
+    # is handled at model-creation time in retrying_model._resolve_model().
     _udspy_model = os.getenv("UDSPY_LM_MODEL", "")
     if _udspy_model and not settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL:
         settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = _udspy_model

-    _udspy_api_key = os.getenv("UDSPY_LM_API_KEY", "")
-    if _udspy_api_key:
-        # pydantic-ai reads provider-specific env vars. Set them all as
-        # fallbacks so the old catch-all key works regardless of provider.
-        for _key in (
-            "OPENAI_API_KEY",
-            "GROQ_API_KEY",
-            "ANTHROPIC_API_KEY",
-            "GEMINI_API_KEY",
-        ):
-            os.environ.setdefault(_key, _udspy_api_key)
-
-    _udspy_base_url = os.getenv("UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL", "")
-    if _udspy_base_url:
-        # pydantic-ai's OpenAI provider reads OPENAI_BASE_URL.
-        os.environ.setdefault("OPENAI_BASE_URL", _udspy_base_url)
-
     # Bridge old AWS_REGION_NAME to boto3's standard AWS_DEFAULT_REGION.
     _aws_region = os.getenv("AWS_REGION_NAME", "")
     if _aws_region:
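The net effect of dropping the bridge, as a sketch (assumes no provider-native keys are exported; values are placeholders):

import os

os.environ["UDSPY_LM_API_KEY"] = "legacy-key"

from baserow_enterprise.assistant.retrying_model import _resolve_credentials

# Previously setup() copied the legacy key into OPENAI_API_KEY, GROQ_API_KEY,
# etc. for the whole process; now os.environ stays untouched and the fallback
# is applied only when a provider's credentials are resolved.
assert os.getenv("OPENAI_API_KEY") is None
assert _resolve_credentials("openai")["api_key"] == "legacy-key"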