diff --git a/backend/justfile b/backend/justfile index a87d620520..dcfc6dd74d 100644 --- a/backend/justfile +++ b/backend/justfile @@ -81,9 +81,10 @@ uv_run := "uv run --active" # Repo root (parent of backend/) - clean() normalizes path (removes ..) repo_root := clean(justfile_directory() / "..") +_set_pythonpath := 'export PYTHONPATH="' + repo_root / 'backend/src:' + repo_root / 'premium/backend/src:' + repo_root / 'enterprise/backend/src:' + repo_root / 'backend/tests:' + repo_root / 'premium/backend/tests:' + repo_root / 'enterprise/backend/tests${PYTHONPATH:+:$PYTHONPATH}"' # Helper to load .env.local if present and set PYTHONPATH with absolute paths # Include this at the start of bash recipes that need env vars -_load_env := 'if [ -f "../.env.local" ]; then set -a; source "../.env.local"; set +a; fi; export PYTHONPATH="' + repo_root / 'backend/src:' + repo_root / 'premium/backend/src:' + repo_root / 'enterprise/backend/src:' + repo_root / 'backend/tests:' + repo_root / 'premium/backend/tests:' + repo_root / 'enterprise/backend/tests${PYTHONPATH:+:$PYTHONPATH}"' +_load_env := 'if [ -f "../.env.local" ]; then set -a; source "../.env.local"; set +a; fi; ' + _set_pythonpath # Source directories backend_source_dirs := "src/ ../premium/backend/src/ ../enterprise/backend/src/" @@ -228,14 +229,14 @@ alias f := fix # PYTHONPATH for test fixtures across all test directories test_pythonpath := "tests:../premium/backend/tests:../enterprise/backend/tests" -_pytest := 'PYTHONPATH="' + test_pythonpath + ':${PYTHONPATH:-}" ' + uv_run + ' pytest' +_pytest := 'PYTHONPATH="' + test_pythonpath + ':${PYTHONPATH:-}" ' + uv_run + ' pytest -c pytest.ini' # Run tests. 
Pass -n=auto to run in parallel with pytest-xdist [group('3 - testing')] test *ARGS: _check-dev #!/usr/bin/env bash set -euo pipefail - {{ _load_env }} + {{ _set_pythonpath }} {{ _pytest }} {{ ARGS }} # Run tests with coverage report diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2b4f0d62da..2f50270b68 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -89,8 +89,8 @@ dependencies = [ "langchain==0.3.28", "langchain-openai==0.3.35", "openai==2.14.0", - "anthropic==0.77.0", - "mistralai==1.1.0", + "anthropic==0.84.0", + "mistralai==2.0.0", "icalendar==6.3.2", "jira2markdown==0.5", "openpyxl==3.1.5", @@ -100,7 +100,8 @@ dependencies = [ "genson==1.3.0", "pyotp==2.9.0", "qrcode==8.2", - "udspy==0.1.8", + "pydantic-ai-slim[anthropic,bedrock,google,groq,openai]==1.66.0", + "opentelemetry-sdk>=1.20.0", "netifaces==0.11.0", "requests-futures>=1.0.2", ] diff --git a/backend/pytest.ini b/backend/pytest.ini index d15be43dcd..7fc761fe45 100644 --- a/backend/pytest.ini +++ b/backend/pytest.ini @@ -56,3 +56,4 @@ markers = workspace_search: All tests related to workspace search functionality enable_all_signals: Disables signal deferral for this test (all signals enabled) enable_signals: Enables specific signals for this test (accepts dotted callable paths) + eval: mark test as an eval test (requires LLM API key) diff --git a/backend/src/baserow/config/settings/base.py b/backend/src/baserow/config/settings/base.py index 86213e7f05..6796069e99 100644 --- a/backend/src/baserow/config/settings/base.py +++ b/backend/src/baserow/config/settings/base.py @@ -1332,15 +1332,15 @@ def __setitem__(self, key, value): from sentry_sdk.integrations.django import DjangoIntegration from sentry_sdk.scrubber import DEFAULT_DENYLIST, EventScrubber - # Exclude the langchain integration from auto-discovery: its module-level - # imports are incompatible with Python 3.14 (langchain/pydantic type - # evaluation crash), and the import happens before 
disabled_integrations - # can take effect. + # Exclude integrations whose module-level imports are incompatible: + # - langchain: Python 3.14 type evaluation crash + # - pydantic_ai: sentry-sdk patches ToolManager._call_tool which was + # removed in pydantic-ai >= 1.x (now execute_tool_call) _sentry_integrations._AUTO_ENABLING_INTEGRATIONS[:] = [ entry for entry in _sentry_integrations._AUTO_ENABLING_INTEGRATIONS - if "langchain" not in entry + if "langchain" not in entry and "pydantic_ai" not in entry ] SENTRY_DENYLIST = DEFAULT_DENYLIST + ["username", "email", "name"] diff --git a/backend/src/baserow/config/settings/dev.py b/backend/src/baserow/config/settings/dev.py index db3b3c4434..532eff15c0 100755 --- a/backend/src/baserow/config/settings/dev.py +++ b/backend/src/baserow/config/settings/dev.py @@ -66,6 +66,24 @@ post_migrate.connect(setup_dev_e2e, dispatch_uid="setup_dev_e2e") +# Mirror logs to a file when BASEROW_LOG_FILE is set (e.g. for AI access when +# running locally). Truncated on each restart. +BASEROW_LOG_FILE = os.getenv("BASEROW_LOG_FILE", "") +if BASEROW_LOG_FILE: + LOGGING["handlers"]["file"] = { # noqa: F405 + "class": "logging.FileHandler", + "filename": BASEROW_LOG_FILE, + "formatter": "console", + "mode": "w", + } + LOGGING["root"]["handlers"].append("file") # noqa: F405 + + # Also route loguru to the same file so modules using loguru (e.g. + # the assistant telemetry) appear alongside stdlib log output. 
+ from loguru import logger as _loguru_logger + + _loguru_logger.add(BASEROW_LOG_FILE, mode="a") + try: from .local import * # noqa: F403, F401 except ImportError: diff --git a/backend/src/baserow/config/settings/test.py b/backend/src/baserow/config/settings/test.py index 8245a298ce..d6449937e8 100644 --- a/backend/src/baserow/config/settings/test.py +++ b/backend/src/baserow/config/settings/test.py @@ -26,13 +26,13 @@ TEST_ENV_VARS = {} # Prefixes for vars that can be overridden via env vars (for DB/Redis configuration) -ALLOWED_ENV_PREFIXES = ("DATABASE_",) +ALLOWED_ENV_PREFIXES = ("DATABASE_", "BASEROW_EMBEDDINGS_API_URL") def getenv_for_tests(key: str, default: str = "") -> str: """ Get env var for tests: - - DATABASE_* vars: check real env first, then TEST_ENV_FILE, then default + - ALLOWED_ENV_PREFIXES vars: use real env var if set, else TEST_ENV_FILE, else default - Other vars: only use TEST_ENV_FILE or default (never real env) """ @@ -65,9 +65,9 @@ def getenv_for_tests(key: str, default: str = "") -> str: BASEROW_TESTS_SETUP_DB_FIXTURE = str_to_bool( os.getenv("BASEROW_TESTS_SETUP_DB_FIXTURE", "on") ) -DATABASES["default"]["TEST"] = { - "MIGRATE": not BASEROW_TESTS_SETUP_DB_FIXTURE, -} +DATABASES["default"].setdefault("TEST", {})[ + "MIGRATE" +] = not BASEROW_TESTS_SETUP_DB_FIXTURE # Open a second database connection that can be used to test transactions. 
DATABASES["default-copy"] = deepcopy(DATABASES["default"]) diff --git a/backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py b/backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py new file mode 100644 index 0000000000..f0a4ba6a3d --- /dev/null +++ b/backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py @@ -0,0 +1,19 @@ +# Generated by Django 5.2.12 on 2026-03-17 09:16 + +from django.db import migrations, models +from django.contrib.postgres.operations import AddIndexConcurrently + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ('database', '0205_formvieweditrowfield'), + ] + + operations = [ + AddIndexConcurrently( + model_name='rowhistory', + index=models.Index(fields=['action_timestamp'], name='database_ro_action__6ea699_idx'), + ), + ] diff --git a/backend/src/baserow/contrib/database/rows/history.py b/backend/src/baserow/contrib/database/rows/history.py index d9a55e2839..aa13498020 100644 --- a/backend/src/baserow/contrib/database/rows/history.py +++ b/backend/src/baserow/contrib/database/rows/history.py @@ -2,7 +2,7 @@ from itertools import groupby from django.conf import settings -from django.db import router +from django.db import connection from django.db.models import QuerySet from django.dispatch import receiver @@ -18,6 +18,7 @@ from baserow.contrib.database.rows.types import ActionData from baserow.core.action.signals import action_done from baserow.core.models import Workspace +from baserow.core.psycopg import sql from baserow.core.telemetry.utils import baserow_trace from baserow.core.types import AnyUser @@ -68,15 +69,33 @@ def list_row_history( return queryset @classmethod - def delete_entries_older_than(cls, cutoff: datetime): + def delete_entries_older_than(cls, cutoff: datetime, batch_size: int = 20_000): """ - Deletes all row history entries that are older than the given cutoff 
date. + Deletes all row history entries that are older than the given cutoff date + in batches to avoid long-running transactions. :param cutoff: The date and time before which all entries will be deleted. + :param batch_size: The number of rows to delete per batch. """ - delete_qs = RowHistory.objects.filter(action_timestamp__lt=cutoff) - delete_qs._raw_delete(using=router.db_for_write(delete_qs.model)) + table = sql.Identifier(RowHistory._meta.db_table) + query = sql.SQL( + """ + WITH to_delete AS ( + SELECT id FROM {table} + WHERE action_timestamp < %s + LIMIT %s + ) + DELETE FROM {table} + USING to_delete + WHERE {table}.id = to_delete.id + """ + ).format(table=table) + while True: + with connection.cursor() as cursor: + cursor.execute(query, [cutoff, batch_size]) + if cursor.rowcount == 0: + break @receiver(action_done) diff --git a/backend/src/baserow/contrib/database/rows/models.py b/backend/src/baserow/contrib/database/rows/models.py index efdc954a73..075e08cfab 100644 --- a/backend/src/baserow/contrib/database/rows/models.py +++ b/backend/src/baserow/contrib/database/rows/models.py @@ -59,4 +59,9 @@ class RowHistory(models.Model): class Meta: ordering = ("-action_timestamp", "-id") - indexes = [models.Index(fields=["table", "row_id", "-action_timestamp", "-id"])] + indexes = [ + # For deleting history entries by action timestamp. + models.Index(fields=["action_timestamp"]), + # For listing the history of a row. 
+ models.Index(fields=["table", "row_id", "-action_timestamp", "-id"]), + ] diff --git a/backend/uv.lock b/backend/uv.lock index edd52787b9..6715e97cc4 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -40,7 +40,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.77.0" +version = "0.84.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -52,9 +52,9 @@ dependencies = [ { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/eb/85/6cb5da3cf91de2eeea89726316e8c5c8c31e2d61ee7cb1233d7e95512c31/anthropic-0.77.0.tar.gz", hash = "sha256:ce36efeb80cb1e25430a88440dc0f9aa5c87f10d080ab70a1bdfd5c2c5fbedb4", size = 504575, upload-time = "2026-01-29T18:20:41.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/04/ea/0869d6df9ef83dcf393aeefc12dd81677d091c6ffc86f783e51cf44062f2/anthropic-0.84.0.tar.gz", hash = "sha256:72f5f90e5aebe62dca316cb013629cfa24996b0f5a4593b8c3d712bc03c43c37", size = 539457, upload-time = "2026-02-25T05:22:38.54Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/27/9df785d3f94df9ac72f43ee9e14b8120b37d992b18f4952774ed46145022/anthropic-0.77.0-py3-none-any.whl", hash = "sha256:65cc83a3c82ce622d5c677d0d7706c77d29dc83958c6b10286e12fda6ffb2651", size = 397867, upload-time = "2026-01-29T18:20:39.481Z" }, + { url = "https://files.pythonhosted.org/packages/64/ca/218fa25002a332c0aa149ba18ffc0543175998b1f65de63f6d106689a345/anthropic-0.84.0-py3-none-any.whl", hash = "sha256:861c4c50f91ca45f942e091d83b60530ad6d4f98733bfe648065364da05d29e7", size = 455156, upload-time = "2026-02-25T05:22:40.468Z" }, ] [[package]] @@ -255,6 +255,7 @@ dependencies = [ { name = "prosemirror", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { 
name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "psycopg2-binary", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic-ai-slim", extra = ["anthropic", "bedrock", "google", "groq", "openai"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pyotp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pysaml2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "qrcode", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -270,7 +271,6 @@ dependencies = [ { name = "twisted", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "tzdata", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "udspy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "unicodecsv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "uvicorn", extra = ["standard"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "validators", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -323,7 +323,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "anthropic", specifier = "==0.77.0" }, + { name = "anthropic", specifier = "==0.84.0" }, { name = "antlr4-python3-runtime", specifier = "==4.9.3" }, { name = "asgiref", specifier = "==3.11.0" }, { name = "boto3", specifier = "==1.42.57" }, @@ -359,7 +359,7 @@ requires-dist = [ { name = "langchain-openai", specifier = "==0.3.35" }, { name = "loguru", specifier = "==0.7.3" }, { name = "mcp", specifier = "==1.26.0" }, - { name = "mistralai", specifier = "==1.1.0" }, + { name = "mistralai", specifier = "==2.0.0" }, { name = "netifaces", specifier = "==0.11.0" }, { name = "openai", specifier = "==2.14.0" }, { name = "openpyxl", specifier = "==3.1.5" }, 
@@ -381,6 +381,7 @@ requires-dist = [ { name = "opentelemetry-instrumentation-wsgi", specifier = "==0.60b1" }, { name = "opentelemetry-proto", specifier = "==1.39.1" }, { name = "opentelemetry-sdk", specifier = "==1.39.1" }, + { name = "opentelemetry-sdk", specifier = ">=1.20.0" }, { name = "opentelemetry-semantic-conventions", specifier = "==0.60b1" }, { name = "opentelemetry-util-http", specifier = "==0.60b1" }, { name = "pgvector", specifier = "==0.4.2" }, @@ -389,6 +390,7 @@ requires-dist = [ { name = "prosemirror", specifier = "==0.5.2" }, { name = "psutil", specifier = "==7.2.2" }, { name = "psycopg2-binary", specifier = "==2.9.11" }, + { name = "pydantic-ai-slim", extras = ["anthropic", "bedrock", "google", "groq", "openai"], specifier = "==1.66.0" }, { name = "pyotp", specifier = "==2.9.0" }, { name = "pysaml2", specifier = "==7.5.4" }, { name = "qrcode", specifier = "==8.2" }, @@ -404,7 +406,6 @@ requires-dist = [ { name = "twisted", specifier = "==25.5.0" }, { name = "typing-extensions", specifier = ">=4.14.1" }, { name = "tzdata", specifier = "==2025.3" }, - { name = "udspy", specifier = "==0.1.8" }, { name = "unicodecsv", specifier = "==0.14.1" }, { name = "uvicorn", extras = ["standard"], specifier = "==0.40.0" }, { name = "validators", specifier = "==0.35.0" }, @@ -1249,6 +1250,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/2e/b41d8a1a917d6581fc27a35d05561037b048e47df50f27f8ac9c7e27a710/freezegun-1.5.5-py3-none-any.whl", hash = "sha256:cd557f4a75cf074e84bc374249b9dd491eaeacd61376b9eb3c423282211619d2", size = 19266, upload-time = "2025-08-09T10:39:06.636Z" }, ] +[[package]] +name = "genai-prices" +version = "0.0.55" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/77/67/de9d9be180db6d80b298c281dff71502095c0776d7cc9286f486f667f61a/genai_prices-0.0.55.tar.gz", hash = "sha256:8692c65d0deefe2ad0680d71841eb12822a35945a6060d2b6adbcbdf4945e1cb", size = 59987, upload-time = "2026-02-26T17:56:41.467Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/98/66a06b82a5c840f896490d5ef9c7691776b147589f2e8d2fa66c67a3db9c/genai_prices-0.0.55-py3-none-any.whl", hash = "sha256:ccd795c90c926b3c71066bf5656f14c67fc11fdba6d71e072c7fb4fa311e1b12", size = 62603, upload-time = "2026-02-26T17:56:40.502Z" }, +] + [[package]] name = "genson" version = "1.3.0" @@ -1287,6 +1301,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl", hash = "sha256:c516d68336bfde7cf0da26aab674a36fedcf04b37ac4edd59c597178760c3498", size = 234867, upload-time = "2026-01-06T21:55:28.6Z" }, ] +[package.optional-dependencies] +requests = [ + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] + [[package]] name = "google-cloud-core" version = "2.5.0" @@ -1329,6 +1348,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, ] +[[package]] +name = "google-genai" +version = "1.66.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "google-auth", extra = ["requests"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "httpx", marker = "sys_platform == 
'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "tenacity", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "websockets", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/ba/0b343b0770d4710ad2979fd9301d7caa56c940174d5361ed4a7cc4979241/google_genai-1.66.0.tar.gz", hash = "sha256:ffc01647b65046bca6387320057aa51db0ad64bcc72c8e3e914062acfa5f7c49", size = 504386, upload-time = "2026-03-04T22:15:28.156Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/dd/403949d922d4e261b08b64aaa132af4e456c3b15c8e2a2d9e6ef693f66e2/google_genai-1.66.0-py3-none-any.whl", hash = "sha256:7f127a39cf695277104ce4091bb26e417c59bb46e952ff3699c3a982d9c474ee", size = 732174, upload-time = "2026-03-04T22:15:26.63Z" }, +] + [[package]] name = "google-resumable-media" version = "2.8.0" @@ -1391,6 +1431,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" }, ] +[[package]] +name = "griffelib" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = 
"2026-02-09T19:09:40.561Z" }, +] + +[[package]] +name = "groq" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/12/f4099a141677fcd2ed79dcc1fcec431e60c52e0e90c9c5d935f0ffaf8c0e/groq-1.0.0.tar.gz", hash = "sha256:66cb7bb729e6eb644daac7ce8efe945e99e4eb33657f733ee6f13059ef0c25a9", size = 146068, upload-time = "2025-12-17T23:34:23.115Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/88/3175759d2ef30406ea721f4d837bfa1ba4339fde3b81ba8c5640a96ed231/groq-1.0.0-py3-none-any.whl", hash = "sha256:6e22bf92ffad988f01d2d4df7729add66b8fd5dbfb2154b5bbf3af245b72c731", size = 138292, upload-time = "2025-12-17T23:34:21.957Z" }, +] + [[package]] name = "gunicorn" version = "23.0.0" @@ -1441,18 +1506,17 @@ wheels = [ [[package]] name = "httpx" -version = "0.27.2" +version = "0.28.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "httpcore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/78/82/08f8c936781f67d9e6b9eeb8a0c8b4e406136ea4c3d1f89a5db71d42e0e6/httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2", size = 144189, upload-time = "2024-08-27T12:54:01.334Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/95/9377bcb415797e44274b51d46e3249eba641711cf3348050f76ee7b15ffc/httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", size = 76395, upload-time = "2024-08-27T12:53:59.653Z" }, + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] [[package]] @@ -1705,15 +1769,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, ] -[[package]] -name = "jsonpath-python" -version = "1.1.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b8/bf/626a72f2d093c5eb4f4de55b443714afa7231beeae40d4a1c69b5c5aa4d1/jsonpath_python-1.1.4.tar.gz", hash = "sha256:bb3e13854e4807c078a1503ae2d87c211b8bff4d9b40b6455ed583b3b50a7fdd", size = 84766, upload-time = "2025-11-25T12:08:39.521Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/ee/bc/52e5bf0d9839e082b976c19afcab7561d0d719c7627483bf5dc251d27eed/jsonpath_python-1.1.4-py3-none-any.whl", hash = "sha256:8700cb8610c44da6e5e9bff50232779c44bf7dc5bc62662d49319ee746898442", size = 12687, upload-time = "2025-11-25T12:08:38.453Z" }, -] - [[package]] name = "jsonpointer" version = "3.0.0" @@ -1904,6 +1959,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/71/0f5d010e92ed9747e14bef35e91b6580533510f1e36a8a09eb79ee70b2f0/librt-0.7.8-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cf243da9e42d914036fd362ac3fa77d80a41cadcd11ad789b1b5eec4daaf67ca", size = 224731, upload-time = "2026-01-14T12:55:58.175Z" }, ] +[[package]] +name = "logfire-api" +version = "4.25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/5c/026cec30d85394aec8f5f12d70edbe2d706837bc9a411bd71a542cedae50/logfire_api-4.25.0.tar.gz", hash = "sha256:7562d5adfe3987291039dddb21947c86cb9d832d068c87d9aa23db86ef07095b", size = 75853, upload-time = "2026-02-19T15:27:29.518Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/39/83414c0fadb4f11f90e6b80b631aa79f62a605664f0c4693e2ebc7ee73f3/logfire_api-4.25.0-py3-none-any.whl", hash = "sha256:0d607eb09ef5426e26f376ff277a8d401bc5b7b4178ea66db404e13c368494cf", size = 120473, upload-time = "2026-02-19T15:27:25.832Z" }, +] + [[package]] name = "loguru" version = "0.7.3" @@ -2060,19 +2124,20 @@ wheels = [ [[package]] name = "mistralai" -version = "1.1.0" +version = "2.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "eval-type-backport", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "jsonpath-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { 
name = "opentelemetry-semantic-conventions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-inspect", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f8/9c/4ea3ee3c8aac270e3d7fde9eb18c34209348f89815fbb356d04bf949e2aa/mistralai-1.1.0.tar.gz", hash = "sha256:9d1fe778e0e8c6ddab714e6a64c6096bd39cfe119ff38ceb5019d8e089df08ba", size = 117553, upload-time = "2024-09-17T16:25:53.342Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/5c/22fd7d1ec7e333f83dc5e2d0b176952a5d9a1f08519898c55616c92a81d8/mistralai-2.0.0.tar.gz", hash = "sha256:acb7937a53119ece67f4978809d4cf630fbf54b4dfe85c0eeae778ac40850fab", size = 317705, upload-time = "2026-03-10T17:12:48.616Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/64/9b/97d1f2f8fb4648008882284b2235d0b7b64b094ad4a4ee02c9c67c361578/mistralai-1.1.0-py3-none-any.whl", hash = "sha256:eea0938975195f331d0ded12d14e3c982f09f1b68210200ed4ff0c6b9b22d0fb", size = 229749, upload-time = "2024-09-17T16:25:51.963Z" }, + { url = "https://files.pythonhosted.org/packages/b8/95/1587d555837bf635db28e2acee366cc47edc473cd3155515be14acced91b/mistralai-2.0.0-py3-none-any.whl", hash = "sha256:e551fc36d60d4c969140e37f10eab04986480e487f357c900da05d740b9a0baf", size = 709642, upload-time = "2026-03-10T17:12:50.104Z" }, ] [[package]] @@ -2840,6 +2905,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = 
"2025-11-26T15:11:44.605Z" }, ] +[[package]] +name = "pydantic-ai-slim" +version = "1.66.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "genai-prices", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "griffelib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic-graph", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/31/1b291e2c169c684290b458a1333d438e34c542d355c60c0bc92866c192a2/pydantic_ai_slim-1.66.0.tar.gz", hash = "sha256:d675f3cf7171c7ea767084a2228d7a2e8eb88e18bfefba71387ed150fcb64069", size = 435408, upload-time = "2026-03-05T00:54:58.587Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/c9/098d675eb20863c6c92a23e09b6cc0d10df3f96191f04f3daefb31f180bc/pydantic_ai_slim-1.66.0-py3-none-any.whl", hash = "sha256:59dcccbcbf948d356dd4a03457962b4079db42c56edf8a11113d827015027e66", size = 566105, upload-time = "2026-03-05T00:54:51.611Z" }, +] + +[package.optional-dependencies] +anthropic = [ + { name = "anthropic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +bedrock = [ + { name = "boto3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +google = [ + { name = "google-genai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +groq = [ + { name = "groq", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +openai = [ + { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = 
"tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] + [[package]] name = "pydantic-core" version = "2.41.5" @@ -2873,6 +2974,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/02/3c562f3a51afd4d88fff8dffb1771b30cfdfd79befd9883ee094f5b6c0d8/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470", size = 2331955, upload-time = "2025-11-04T13:41:54.079Z" }, ] +[[package]] +name = "pydantic-graph" +version = "1.66.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "logfire-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/5e/4a3ed6c4047fd2676b248cee3666299b6214f691c086fd5f9bdda96ace1d/pydantic_graph-1.66.0.tar.gz", hash = "sha256:834df5137098c2c95d2241b98d4dd61af4a3ff24784751c82cc543db46dd29f5", size = 58522, upload-time = "2026-03-05T00:55:01.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/95/22c0ad3f3830d7fdd4dbfdc78548705f6c9ac434ada0d790ffc02491b39e/pydantic_graph-1.66.0-py3-none-any.whl", hash = "sha256:8f75d34efbaa4b65767d39faa2b3270fd321fb4104a66d3773754f4854876739", size = 72351, upload-time = "2026-03-05T00:54:54.661Z" }, +] + [[package]] name = "pydantic-settings" version = "2.12.0" @@ -3761,19 +3877,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] 
-[[package]] -name = "typing-inspect" -version = "0.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mypy-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, -] - [[package]] name = "typing-inspection" version = "0.4.2" @@ -3813,22 +3916,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/5e/512aeb40fd819f4660d00f96f5c7371ee36fc8c6b605128c5ee59e0b28c6/u_msgpack_python-2.8.0-py2.py3-none-any.whl", hash = "sha256:1d853d33e78b72c4228a2025b4db28cda81214076e5b0422ed0ae1b1b2bb586a", size = 10590, upload-time = "2023-05-18T09:28:10.323Z" }, ] -[[package]] -name = "udspy" -version = "0.1.8" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jiter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "tenacity", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = 
"https://files.pythonhosted.org/packages/d4/d8/0ab2a0258f4932f40004c759f79336100a590c1aa8296c75d797d47836e5/udspy-0.1.8.tar.gz", hash = "sha256:8da68fcbd118850eeff3750942c053006bafea335bf74f411055ccf27d800b3b", size = 270081, upload-time = "2025-11-28T15:20:55.19Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/ec/10076e9cb53685ffb01d2229df6372d4dbca3d4a3a0a93a03ad5126c40b2/udspy-0.1.8-py3-none-any.whl", hash = "sha256:3a66427b60f4cd6360ff95db76fd34a8fa9201fde9244b5a3db6b3fb2e424042", size = 60418, upload-time = "2025-11-28T15:20:53.849Z" }, -] - [[package]] name = "ujson" version = "5.12.0" diff --git a/changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json b/changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json new file mode 100644 index 0000000000..6d9e88583f --- /dev/null +++ b/changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json @@ -0,0 +1,9 @@ +{ + "type": "bug", + "message": "Fix error when changing filter type on link row fields", + "issue_origin": "github", + "issue_number": 5014, + "domain": "database", + "bullet_points": [], + "created_at": "2026-03-20" +} \ No newline at end of file diff --git a/changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json b/changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json new file mode 100644 index 0000000000..eeca089da0 --- /dev/null +++ b/changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json @@ -0,0 +1,9 @@ +{ + "type": "bug", + "message": "Fixed slow cleanup of row history table by adding a database index.", + "issue_origin": "github", + "issue_number": null, + "domain": "database", + "bullet_points": [], + "created_at": "2026-03-17" +} \ No newline at end of file diff --git 
a/changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json b/changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json new file mode 100644 index 0000000000..c1aa56f4a9 --- /dev/null +++ b/changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json @@ -0,0 +1,9 @@ +{ + "type": "refactor", + "message": "Silence defined error codes in Sentry", + "issue_origin": "github", + "issue_number": 5011, + "domain": "core", + "bullet_points": [], + "created_at": "2026-03-19" +} \ No newline at end of file diff --git a/changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json b/changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json new file mode 100644 index 0000000000..18f9dd2271 --- /dev/null +++ b/changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json @@ -0,0 +1,9 @@ +{ + "type": "refactor", + "message": "Replace udspy with pydantic-ai", + "issue_origin": "github", + "issue_number": null, + "domain": "core", + "bullet_points": [], + "created_at": "2026-03-04" +} \ No newline at end of file diff --git a/docs/installation/ai-assistant.md b/docs/installation/ai-assistant.md index 93e0369b02..ff0f900654 100644 --- a/docs/installation/ai-assistant.md +++ b/docs/installation/ai-assistant.md @@ -6,8 +6,8 @@ server. ## 1) Core concepts -- The assistant runs via **UDSPy** — see https://github.com/baserow/udspy/ -- UDSPy speaks to **any OpenAI-compatible API**. +- The assistant is built on [**pydantic-ai**](https://ai.pydantic.dev/) — a + Python agent framework that supports multiple LLM providers out of the box. - You **must** set `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` with the provider and model of your choosing. - The assistant has been mostly tested with the `gpt-oss-120b` family. 
Other models can @@ -21,61 +21,82 @@ Set the model you want, restart Baserow, and let migrations run. ```dotenv # Required -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai/gpt-4o +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai:gpt-5.2 OPENAI_API_KEY=your_api_key -# Optional - adjust LLM temperature (default: 0) -BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE=0 +# Optional - adjust LLM temperature (default: 0.3) +BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE=0.3 ``` **About temperature:** -- Controls randomness in LLM responses (0.0 to 2.0) -- **Default: 0** (deterministic, consistent responses - recommended for production) -- Higher values (e.g., 0.7-1.0) = more creative/varied responses -- Lower values (e.g., 0-0.3) = more focused/consistent responses +- Controls randomness in the main assistant's LLM responses. +- **Default: 0.3** (focused, consistent responses) +- Higher values (depending on the model) = more creative/varied responses. +- Lower values (e.g., 0-0.1) = more analytical responses. Note that even with temperature of 0.0, the results will not be fully deterministic. ## 3) Provider presets -Choose **one** provider block and set its variables. +Choose **one** provider block and set its variables. pydantic-ai uses the standard +environment variables for each provider (e.g. `OPENAI_API_KEY`, `GROQ_API_KEY`). ### OpenAI / OpenAI-compatible ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai/gpt-4o +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai:gpt-5.2 OPENAI_API_KEY=your_api_key -# Optional alternative endpoints (OpenAI EU or Azure OpenAI, etc.) 
-UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=https://eu.api.openai.com/v1 +# Optional: point to an alternative OpenAI-compatible endpoint +OPENAI_BASE_URL=https://eu.api.openai.com/v1 # or -UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=https://.openai.azure.com -# or any OpenAI compatible endpoint +OPENAI_BASE_URL=https://.openai.azure.com +``` + +### Anthropic + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=anthropic:claude-sonnet-4-20250514 +ANTHROPIC_API_KEY=your_api_key ``` ### AWS Bedrock +pydantic-ai supports two authentication methods for Bedrock. Use whichever matches your setup. + +**Option A — Standard AWS credentials (boto3)** + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock:openai.gpt-oss-120b-1:0 +AWS_ACCESS_KEY_ID=your_access_key +AWS_SECRET_ACCESS_KEY=your_secret_key +AWS_DEFAULT_REGION=eu-central-1 +``` + +Any boto3-compatible credential method works: env vars, IAM roles, instance profiles, `~/.aws/credentials`, etc. + +**Option B — Bedrock bearer token** + ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock/openai.gpt-oss-120b-1:0 -AWS_BEARER_TOKEN_BEDROCK=your_bedrock_token -AWS_REGION_NAME=eu-central-1 +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock:openai.gpt-oss-120b-1:0 +AWS_BEARER_TOKEN_BEDROCK=your_bearer_token +AWS_DEFAULT_REGION=eu-central-1 ``` ### Groq ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=groq/openai/gpt-oss-120b +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=groq:openai/gpt-oss-120b GROQ_API_KEY=your_api_key ``` ### Ollama ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=ollama/gpt-oss:120b -OLLAMA_API_KEY=your_api_key -# Optionally and alternative endpoint -UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=http://localhost:11434/v1 +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=ollama:gpt-oss:120b +# Point to your Ollama instance (defaults to http://localhost:11434/v1) +OLLAMA_BASE_URL=http://localhost:11434/v1 ``` -Under the hood, UDSPy auto-detects provider from the model prefix and builds an -OpenAI-compatible client accordingly. 
+pydantic-ai auto-detects the provider from the model prefix and routes requests +accordingly. ## 4) Knowledge-base lookup @@ -123,3 +144,42 @@ just dcd run --rm web-frontend bash -c env | grep LLM_MODEL ``` Both commands must return the same value for `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL`. If either is missing or they differ, update your environment configuration and restart the services. + +## 6) Supported models + +OpenAI, Anthropic, AWS Bedrock, Groq, Gemini/Vertex AI and any OpenAI-compatible +endpoint (Azure, DeepSeek, Fireworks, LiteLLM, Perplexity, Together AI, etc.). + +## 7) Framework change: UDSPy to pydantic-ai + +The assistant previously used [UDSPy](https://github.com/baserow/udspy/) as its agent +framework. It now uses [pydantic-ai](https://ai.pydantic.dev/). Most environment +variables are unchanged or bridged for backward compatibility. + +### What stays the same + +| Variable | Notes | +|----------|-------| +| `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` | Works exactly as before. Both `provider/model` and `provider:model` formats are accepted. | +| `BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE` | Still supported. Overrides the orchestrator temperature when set. | +| `OPENAI_API_KEY` | Unchanged. | +| `GROQ_API_KEY` | Unchanged. | +| `AWS_BEARER_TOKEN_BEDROCK` | Still works — pydantic-ai supports Bedrock bearer token auth natively. | + +### Bridged for backward compatibility (no action needed) + +| Old variable | Equivalent | Notes | +|--------------|------------|-------| +| `UDSPY_LM_MODEL` | `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` | If set and the new var is absent, the old value is used automatically. | +| `UDSPY_LM_API_KEY` | `OPENAI_API_KEY` / `GROQ_API_KEY` / etc. | Propagated to all provider key variables as a fallback. | +| `UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL` | `OPENAI_BASE_URL` | Still works; bridged automatically. | +| `AWS_REGION_NAME` | `AWS_DEFAULT_REGION` | Still works; bridged automatically. 
| + +### New variables + +| Variable | Notes | +|----------|-------| +| `OPENAI_BASE_URL` | Preferred replacement for `UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL`. | +| `AWS_DEFAULT_REGION` | Preferred replacement for `AWS_REGION_NAME`. | +| `OLLAMA_BASE_URL` | Replaces `UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL` for Ollama. Defaults to `http://localhost:11434/v1`. | +| `ANTHROPIC_API_KEY` | New provider — Anthropic models are now supported. | diff --git a/docs/installation/configuration.md b/docs/installation/configuration.md index c0d66d6153..fc65e449c7 100644 --- a/docs/installation/configuration.md +++ b/docs/installation/configuration.md @@ -188,12 +188,12 @@ The installation methods referred to in the variable descriptions are: | Name | Description | Defaults | |---------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------| | BASEROW\_EMBEDDINGS\_API\_URL | If not empty, the AI-assistant will use this as embedding server for the knowledge base lookup. Must point to a container running this image: https://hub.docker.com/r/baserow/embeddings | "" (empty string) | -| BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL | If not empty, then this model will be used for the AI-assistant. Provide like `groq/openai/gpt-oss-120b` or `bedrock/openai.gpt-oss-120b-1:0`. Note that additional API keys must be provided as environment variable depending on the provider. Instructions can be found at https://baserow.io/docs/installation/ai-assistant | "" (empty string) | -| AWS\_BEARER\_TOKEN\_BEDROCK | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=bedrock/bedrock/openai.gpt-oss-120b-1:0, then this environment variable must be set. 
Instructions on how to obtain: https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html | "" (empty string) | -| AWS\_REGION\_NAME | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=groq/openai/gpt-oss-120b, then the AWS region for the AI-assistant can be provided here. | us-east-1 | -| GROQ\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=bedrock/bedrock/openai.gpt-oss-120b-1:0, then the Groq API key can be provided here. | "" (empty string) | -| OLLAMA\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=ollama/gpt-oss:120b, then the Ollama API key can be provided here. | "" (empty string) | -| UDSPY\_LM\_OPENAI\_COMPATIBLE\_BASE\_URL | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=openai/gpt-5-nano, then the base URL can be changed here. This can be used to point to a different OpenAI compatible API like Azure, for example. | "" (empty string) | +| BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL | If not empty, then this model will be used for the AI-assistant. Provide in pydantic-ai format like `groq:openai/gpt-oss-120b` or `bedrock:openai.gpt-oss-120b-1:0`. Note that additional API keys must be provided as environment variable depending on the provider. Instructions can be found at https://baserow.io/docs/installation/ai-assistant | "" (empty string) | +| AWS\_BEARER\_TOKEN\_BEDROCK | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses a bedrock provider, then this environment variable must be set. Instructions on how to obtain: https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html | "" (empty string) | +| AWS\_REGION\_NAME | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses a bedrock provider, then the AWS region for the AI-assistant can be provided here. | us-east-1 | +| GROQ\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses a groq provider (e.g. `groq:openai/gpt-oss-120b`), then the Groq API key must be provided here. 
| "" (empty string) | +| OLLAMA\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses an ollama provider (e.g. `ollama:gpt-oss:120b`), then the Ollama API key can be provided here. | "" (empty string) | +| UDSPY\_LM\_OPENAI\_COMPATIBLE\_BASE\_URL | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses an openai provider (e.g. `openai:gpt-5-nano`), then the base URL can be changed here. This can be used to point to a different OpenAI compatible API like Azure, for example. | "" (empty string) | ### Data sync configuration diff --git a/docs/installation/install-with-helm.md b/docs/installation/install-with-helm.md index 103c678978..f71e9cad4b 100644 --- a/docs/installation/install-with-helm.md +++ b/docs/installation/install-with-helm.md @@ -183,7 +183,7 @@ Add to your `config.yaml`: ```yaml baserow-embeddings: enabled: true - assistantLLMModel: "groq/openai/gpt-oss-120b" + assistantLLMModel: "groq:openai/gpt-oss-120b" backendSecrets: GROQ_API_KEY: "your-groq-api-key" diff --git a/docs/testing/ai-assistant-evals.md b/docs/testing/ai-assistant-evals.md new file mode 100644 index 0000000000..77950cc27c --- /dev/null +++ b/docs/testing/ai-assistant-evals.md @@ -0,0 +1,255 @@ +# AI Assistant Evals + +The assistant eval suite runs the real agent against a live LLM to verify +end-to-end behaviour: tool selection, schema compatibility, row creation, etc. + +All eval tests live under +`enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/` and are +marked with `@pytest.mark.eval` so they are **skipped by default** in CI and +local test runs. + +## Prerequisites + +1. A running PostgreSQL database (see [running-tests.md](../development/running-tests.md)). +2. An API key for the LLM provider you want to test against. +3. **For `test_eval_search_user_docs` only:** an embeddings server and a + synced knowledge base (see [Search docs evals](#search-docs-evals) below). 
+ +## Quick start + +```bash +# Set your API key (Groq example — works with any pydantic-ai provider) +export GROQ_API_KEY=gsk_... + +# Run all evals with the default model (groq:openai/gpt-oss-120b) +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ \ + -m eval -v + +# Run a single eval file +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py \ + -m eval -v +``` + +> **Tip:** Do **not** pass `-s`. Without it, pytest captures `print_message_history` output and shows it only in the failure report — passing tests stay silent. Use `-s` only when you want to watch the agent's tool calls in real time for a single test. + + +## Configuration + +All configuration is via environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `EVAL_LLM_MODEL` | `groq:openai/gpt-oss-120b` | Model string in pydantic-ai format (`provider:model`). Accepts a comma-separated list to parametrize every eval across multiple models. | +| `EVAL_RETRIES` | `0` | Retry each failing eval test up to N times. If a test passes on retry it's a flake (LLM non-determinism); if it fails all N retries it's a consistent bug. | +| `GROQ_API_KEY` | — | Required when using a Groq model. | +| `OPENAI_API_KEY` | — | Required when using an OpenAI model. | +| `ANTHROPIC_API_KEY` | — | Required when using an Anthropic model. | + +### API keys from a file + +The eval conftest reads API keys from the same `TEST_ENV_FILE` that +`baserow/config/settings/test.py` already parses, and exposes them via +`os.environ` so that LLM provider SDKs can find them: + +```bash +TEST_ENV_FILE=.env.testing-local just b test \ + ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ -m eval -v -s +``` + +Variables already present in `os.environ` take precedence. + +### Running against multiple models + +```bash +GROQ_API_KEY=... OPENAI_API_KEY=... 
EVAL_LLM_MODEL="groq:openai/gpt-oss-120b,openai:gpt-4o" \ +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ \ + -m eval -v -s +``` + +Each test will run once per model, with the model name shown in the test ID. + +## Test files + +File names follow the pattern `test_eval_{module}_{feature}.py`, where module +maps to the tool directory (`core`, `database`, `automation`, `navigation`, +`search_user_docs`). Browse +`enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/` for the +full list. Each file defines its prompts as module-level `PROMPT_*` constants +at the top, making it easy to scan which scenarios are covered without reading +the test bodies. + +## Writing a new eval + +1. Create a new `test_eval_.py` file in the `evals/` directory. +2. Define prompts as `PROMPT_*` constants at the top, so it's easier to have an overview of the existing evals. +3. Mark each test with `@pytest.mark.eval` and + `@pytest.mark.django_db(transaction=True)`. +4. 
Use the helpers from `eval_utils.py`: + +```python +import pytest +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + print_message_history, +) + +PROMPT_DOES_SOMETHING = "Do something useful in database {database_name}" + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_does_something(data_fixture, eval_model): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace, name="Test") + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + deps.tool_helpers.request_context["ui_context"] = ui_context + + result = agent.run_sync( + user_prompt=PROMPT_DOES_SOMETHING.format(database_name=database.name), + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + with EvalChecklist("does something") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + # Add domain-specific checks here + checks.check("created the thing", some_condition, hint="details if failed") +``` + +### Key helpers + +| Helper | Purpose | +|--------|---------| +| `create_eval_assistant(user, workspace, max_iters, model)` | Returns `(agent, deps, tracker, model, usage_limits, toolset)` configured like production. | +| `build_database_ui_context(user, workspace, database, table)` | Builds the UI context JSON the agent receives. | +| `count_tool_errors(result)` | Returns `(error_count, hint)` — count of tool validation errors (pydantic retries) and a formatted hint string. Use with `EvalChecklist`: `checks.check("no tool errors", err_count == 0, hint=err_hint)`. 
| +| `EvalChecklist(name)` | Context manager for soft assertions: collects checks, prints a score table (`4/6 (66%)`), and only hard-fails at the end. Use for tests with multiple independent checks. | +| `print_message_history(result)` | Prints the full agent conversation to stdout. | +| `format_message_history(result)` | Returns the conversation as a list of dicts for programmatic assertions. | + +## Search docs evals + +`test_eval_search_user_docs.py` tests the `search_user_docs` tool end-to-end: +the agent receives a real user question, decides to call the tool, the tool +performs a vector search against the knowledge base, and a sub-agent produces +an answer with source URLs. The test verifies that: + +1. The agent called `search_user_docs`. +2. The answer mentions expected concepts (e.g. "date_diff" for a date + formula question). +3. Returned source URLs match expected documentation pages (non-fatal + warning if not — URLs can change). + +### Additional prerequisites + +These tests are **automatically skipped** when the knowledge base is not +available. To enable them: + +1. **Embeddings server** — start the embeddings service and set: + ```bash + # Running tests outside Docker (local dev): + export BASEROW_EMBEDDINGS_API_URL=http://localhost:7999 + # Running tests inside Docker: + export BASEROW_EMBEDDINGS_API_URL=http://embeddings + ``` + +2. **pgvector extension** — the PostgreSQL instance must have the `vector` + extension installed. If you use the dev Docker setup this is already + included. + +3. 
**Sync the knowledge base** — the test suite handles this automatically + (see [Knowledge base caching](#knowledge-base-caching) below), but you + can also trigger a manual sync: + ```bash + # From the backend directory, with the Django env active: + python -m baserow sync_knowledge_base + ``` + This reads `website_export.csv` (user docs) and `docs/` (dev docs), + creates `KnowledgeBaseDocument` / `KnowledgeBaseChunk` rows, and + generates embeddings via the embeddings server. + +### Knowledge base caching + +Syncing the knowledge base is slow (it generates embeddings for every +documentation chunk). To avoid repeating this on every test run, the eval +suite uses two mechanisms together: + +1. **Session-scoped fixture** — the `synced_knowledge_base` fixture in + `conftest.py` runs once per pytest session. It checks whether the KB is + already populated (`handler.can_search()`) and only calls + `sync_knowledge_base()` when it isn't. + +2. **`--reuse-db`** — pytest-django's `--reuse-db` flag keeps the test + database between sessions instead of recreating it. Combined with the + fixture above, the expensive sync only happens on the very first run. + Subsequent runs detect that the data is already there and skip the sync + entirely. + +3. **No `transaction=True`** — search docs tests use + `@pytest.mark.django_db` (savepoint rollback) rather than + `@pytest.mark.django_db(transaction=True)` (full table truncation). This + is important: `transaction=True` would wipe the knowledge base tables + after each test, defeating the caching. + +**Typical workflow:** + +| Run | What happens | Time | +|-----|--------------|------| +| First ever | DB created, KB synced, tests run | Several minutes | +| Subsequent | DB reused, KB already populated, tests run | Seconds | + +To force a fresh sync (e.g. 
after schema changes or new documentation): + +```bash +# Drop and recreate the test DB, then re-sync +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py \ + -m eval -v -s --create-db +``` + +### Running search docs evals + +```bash +# Only search docs evals +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py \ + -m eval -v -s + +# A single test case by parametrize ID +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py \ + -m eval -v -s -k "vlookup-to-link-row" +``` + +If the embeddings server is not running or the knowledge base has not been +synced, all search docs tests will be skipped with a clear message. + +## Troubleshooting + +### `FAILED — No API key` + +Make sure the correct `*_API_KEY` env var is set for your provider. + +### Flaky results + +LLM evals are inherently non-deterministic. If a test fails intermittently: + +- Use `EVAL_RETRIES` to automatically distinguish flakes from consistent bugs: + ```bash + EVAL_RETRIES=3 just b test \ + ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py \ + -m eval -v -s + ``` + A test that passes on retry is a flake; one that fails all 3 retries is a real problem. +- Check the printed message history (`-s` flag) to see what the agent did. +- If a prompt is ambiguous, tighten the wording in the `PROMPT_*` constant. +- Consider lowering the temperature in the model profile for the eval model. diff --git a/docs/testing/ai-assistant-test-plan.md b/docs/testing/ai-assistant-test-plan.md new file mode 100644 index 0000000000..ff9b12de83 --- /dev/null +++ b/docs/testing/ai-assistant-test-plan.md @@ -0,0 +1,134 @@ +# AI Assistant Test Plan + +## How to test + +### 1. 
Automated tests (unit) + +Run the unit test suite (no LLM needed): + +```bash +just b test -n auto ../enterprise/backend/tests/baserow_enterprise_tests/assistant/ \ + -v --ignore=enterprise/backend/tests/baserow_enterprise_tests/assistant/evals +``` + +All tests must pass. These cover: assistant orchestrator, all tool modules, +telemetry event emission, history compaction, and streaming. + +### 2. Automated tests (evals, optional) + +Run the eval suite against a live LLM. The default model is +`groq:openai/gpt-oss-120b`, so you need a `GROQ_API_KEY`. Evals that exercise +the `search_user_docs` tool also require a running embedding service — set +`BASEROW_EMBEDDINGS_API_URL` to point to it, or those evals will fail. + +```bash +GROQ_API_KEY=gsk_... BASEROW_EMBEDDINGS_API_URL=http://... \ +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ \ + -m eval -v -s +``` + +> **Note:** Evals are non-deterministic and are not guaranteed to pass every +> run. When a failure occurs, check whether the model did something +> fundamentally wrong or whether the result is still acceptable. See +> [ai-assistant-evals.md](ai-assistant-evals.md) for details on configuration, +> multi-model runs, and how to interpret results. + +### 3. Manual: Tool smoke tests + +Open the assistant in the UI and verify each tool works end-to-end. Suggested +prompts: + +| Tool | Prompt | +|------|--------| +| `navigate` | "Go to the Customers table" | +| `list_builders` | "What builders do I have?" | +| `create_builders` | "Create a new application called Test App" | +| `list_tables` | "What tables are in my database?" | +| `get_tables_schema` | "Show me the schema of the Customers table" | +| `list_rows` | "Show me the first rows of the Customers table" | +| `list_views` | "What views does the Customers table have?" 
| +| `create_tables` | "Create a table called Projects with columns: Name (text), Status (single select: Active/Done), Due date (date)" | +| `create_fields` | "Add an email field to the Customers table" | +| `create_views` | "Create a kanban view grouped by Status on the Projects table" | +| `create_view_filters` | "Add a filter on the Projects grid view to only show Active rows" | +| `generate_formula` | "Add a formula field that concatenates first name and last name" | +| `update_fields` | "Rename the email field to Contact Email in the Customers table" | +| `delete_fields` | "Delete the Contact Email field from the Customers table" | +| `load_row_tools` | "Add a row to the Projects table: Name=Launch, Status=Active" (this implicitly triggers load_row_tools first) | +| `update_rows_in_table_X` | "Change the Status of the Launch row in Projects to Done" | +| `delete_rows_in_table_X` | "Delete the Launch row from the Projects table" | +| `list_workflows` | "What automations do I have?" | +| `create_workflows` | "Create an automation that sends a notification when a row is created in Projects" | +| `list_nodes` | "What nodes are in my first workflow?" | +| `add_nodes` | "Add a Slack notification action after the trigger in my workflow" | +| `update_nodes` | "Rename the trigger node to New Project Trigger" | +| `delete_nodes` | "Delete the Slack notification node from my workflow" | +| `search_user_docs` | "How do I create a lookup field?"* | + +* Make sure you synced the knowledge base first, look at [ai-assistant.md](../installation/ai-assistant.md) for more info. + +### 4. Manual: Feedback + +- Send a message, then click the thumbs-up/thumbs-down on the response +- Verify the feedback is recorded (no errors in the console/network tab) +- Refresh the page, the previously selected thumb up/down button must be highlighted + +### 5. Manual: Conversation memory (history) + +Test that the agent retains multi-turn context: + +1. Send: "My name is Mario" +2. 
Agent responds acknowledging +3. Send: "What's my name?" +4. Agent should respond "Mario" (proves history serialization/deserialization + via `message_history` field works) + +Also test a longer conversation (3+ turns) to verify the compaction doesn't +lose essential context. + +### 6. Manual: Telemetry / PostHog traces + +Requires PostHog configured (`POSTHOG_PROJECT_API_KEY`, `POSTHOG_HOST` etc.): + +1. Send a few messages exercising different tools +2. Go to PostHog > LLM Analytics > Traces +3. Verify: + - Each conversation turn appears as a `$ai_trace` + - Tool calls appear as `$ai_span` children + - LLM generations appear as `$ai_generation` with model name, token counts, + latency + - Input/output content is captured (not empty) + +### 7. Manual: Knowledge base (search_user_docs) + +Requires an embeddings server and synced KB (look at [ai-assistant.md](../installation/ai-assistant.md) for more info). Verify: + +- Ask a Baserow how-to question (e.g. "How do I set up SSO?") -> agent should + call `search_user_docs` and cite sources +- Ask a creative task (e.g. "Create a table for tracking expenses") -> agent + should NOT call search_user_docs, should just act +- Ask a question about the agent's own tools (e.g. "What tools do you have?") + -> agent should NOT search docs, should answer from its own knowledge + +### 8. Manual: Do vs. Describe + +Verify the agent acts rather than describes: + +- "Create a table called Invoices" -> should actually create it (call + `create_tables`), not describe how to do it +- "How would I create a table?" -> should describe the manual UI steps (no + tools available for this meta-question) or search docs +- After creating something, the agent should navigate to it to show the result + +### 9. Manual: Cancellation + +1. Send a long-running request (e.g. "Create a table with 10 fields") +2. Click cancel mid-execution +3. Verify the stream stops cleanly without error toasts + +### 10. 
Manual: Error handling + +- Misconfigure the LLM API key and try to chat -> should show a clear error, + not a stack trace +- Send a prompt referencing a non-existent table/database/any other resource -> agent should + handle gracefully diff --git a/enterprise/backend/pytest.ini b/enterprise/backend/pytest.ini index 9d3346f478..28c968f590 100644 --- a/enterprise/backend/pytest.ini +++ b/enterprise/backend/pytest.ini @@ -1,5 +1,7 @@ [pytest] DJANGO_SETTINGS_MODULE = baserow.config.settings.test python_files = test_*.py +markers = + eval: mark test as an eval test (requires LLM API key) env = DJANGO_SETTINGS_MODULE = baserow.config.settings.test diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py index 1987e6f224..0c117df6cc 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py @@ -23,10 +23,7 @@ from baserow.api.sessions import set_client_undo_redo_action_group_id from baserow.core.exceptions import UserNotInWorkspace, WorkspaceDoesNotExist from baserow.core.handler import CoreHandler -from baserow_enterprise.assistant.assistant import ( - check_lm_ready_or_raise, - set_assistant_cancellation_key, -) +from baserow_enterprise.assistant.assistant import set_assistant_cancellation_key from baserow_enterprise.assistant.exceptions import ( AssistantChatDoesNotExist, AssistantChatMessagePredictionDoesNotExist, @@ -34,6 +31,7 @@ AssistantModelNotSupportedError, ) from baserow_enterprise.assistant.handler import AssistantHandler +from baserow_enterprise.assistant.model_profiles import check_lm_ready_or_raise from baserow_enterprise.assistant.models import AssistantChatPrediction from baserow_enterprise.assistant.operations import ChatAssistantChatOperationType from baserow_enterprise.assistant.types import ( diff --git a/enterprise/backend/src/baserow_enterprise/apps.py 
b/enterprise/backend/src/baserow_enterprise/apps.py index 7be1301b82..9a482941bb 100755 --- a/enterprise/backend/src/baserow_enterprise/apps.py +++ b/enterprise/backend/src/baserow_enterprise/apps.py @@ -1,4 +1,5 @@ from django.apps import AppConfig +from django.conf import settings from django.db.models.signals import post_migrate from tqdm import tqdm @@ -282,11 +283,12 @@ def ready(self): # Make sure that the assistant knowledge base is up to date after running the # migrations. - post_migrate.connect( - sync_assistant_knowledge_base, - sender=self, - dispatch_uid="sync_assistant_knowledge_base", - ) + if not settings.TESTS: + post_migrate.connect( + sync_assistant_knowledge_base, + sender=self, + dispatch_uid="sync_assistant_knowledge_base", + ) from baserow_enterprise.teams.receivers import ( connect_to_post_delete_signals_to_cascade_deletion_to_team_subjects, @@ -313,43 +315,6 @@ def ready(self): notification_type_registry.register(TwoWaySyncUpdateFailedNotificationType()) notification_type_registry.register(TwoWaySyncDeactivatedNotificationType()) - from baserow_enterprise.assistant.tools import ( - CreateBuildersToolType, - GenerateDatabaseFormulaToolType, - GetTablesSchemaToolType, - ListBuildersToolType, - ListRowsToolType, - ListTablesToolType, - ListViewsToolType, - ListWorkflowsToolType, - NavigationToolType, - RowsToolFactoryToolType, - SearchDocsToolType, - TableAndFieldsToolFactoryToolType, - ViewsToolFactoryToolType, - WorkflowToolFactoryToolType, - ) - from baserow_enterprise.assistant.tools.registries import ( - assistant_tool_registry, - ) - - assistant_tool_registry.register(SearchDocsToolType()) - assistant_tool_registry.register(NavigationToolType()) - - assistant_tool_registry.register(ListBuildersToolType()) - assistant_tool_registry.register(CreateBuildersToolType()) - assistant_tool_registry.register(ListTablesToolType()) - assistant_tool_registry.register(GetTablesSchemaToolType()) - 
assistant_tool_registry.register(TableAndFieldsToolFactoryToolType()) - assistant_tool_registry.register(GenerateDatabaseFormulaToolType()) - assistant_tool_registry.register(ListRowsToolType()) - assistant_tool_registry.register(RowsToolFactoryToolType()) - assistant_tool_registry.register(ListViewsToolType()) - assistant_tool_registry.register(ViewsToolFactoryToolType()) - - assistant_tool_registry.register(ListWorkflowsToolType()) - assistant_tool_registry.register(WorkflowToolFactoryToolType()) - from baserow_enterprise.views.operations import ( ListenToAllRestrictedViewEventsOperationType, ) @@ -376,6 +341,29 @@ def ready(self): page_registry.register(RestrictedViewPageType()) view_realtime_rows_registry.register(RestrictedViewRealtimeRowsType()) + from baserow_enterprise.assistant.tools.automation.tool_types import ( + AutomationToolType, + ) + from baserow_enterprise.assistant.tools.core.tool_types import CoreToolType + from baserow_enterprise.assistant.tools.database.tool_types import ( + DatabaseToolType, + ) + from baserow_enterprise.assistant.tools.navigation.tool_types import ( + NavigationToolType, + ) + from baserow_enterprise.assistant.tools.registries import ( + assistant_tool_registry, + ) + from baserow_enterprise.assistant.tools.search_user_docs.tool_types import ( + SearchDocsToolType, + ) + + assistant_tool_registry.register(NavigationToolType()) + assistant_tool_registry.register(CoreToolType()) + assistant_tool_registry.register(DatabaseToolType()) + assistant_tool_registry.register(AutomationToolType()) + assistant_tool_registry.register(SearchDocsToolType()) + # The signals must always be imported last because they use the registries # which need to be filled first. 
import baserow_enterprise.assistant.tasks # noqa: F401 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/agents.py new file mode 100644 index 0000000000..1af2110615 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/agents.py @@ -0,0 +1,81 @@ +from pydantic_ai import Agent, RunContext +from pydantic_ai.toolsets import FunctionToolset + +from baserow_enterprise.assistant.deps import AssistantDeps +from baserow_enterprise.assistant.prompts import AGENT_SYSTEM_PROMPT +from baserow_enterprise.assistant.tools.toolset import tool_manifest_line_compact + +main_agent: Agent[AssistantDeps, str] = Agent( + deps_type=AssistantDeps, + output_type=str, + instructions=AGENT_SYSTEM_PROMPT, + retries=3, + name="main_agent", +) + + +@main_agent.instructions +def dynamic_ui_context(ctx) -> str: + """Inject the UI context into the system prompt dynamically.""" + + ui_context = ctx.deps.tool_helpers.request_context.get("ui_context") + if ui_context: + return f"\n\n{ui_context}\n" + return "" + + +@main_agent.instructions +def dynamic_mode(ctx) -> str: + """Inject the current agent mode into the system prompt.""" + + return f"\n{ctx.deps.mode.value}" + + +@main_agent.instructions +def dynamic_current_task(ctx) -> str: + """Pin the original user request as immutable context.""" + + if ctx.deps.original_request: + return f"\n\n{ctx.deps.original_request}\n" + return "" + + +@main_agent.instructions +def dynamic_tool_manifest(ctx) -> str: + """ + Inject the available tools manifest into the system prompt, including both + static and dynamically loaded tools name and description. + """ + + manifest = ctx.deps.active_manifest + if not manifest: + return "" + + # Append dynamically loaded tools (e.g. 
row tools from load_row_tools) + if ctx.deps.dynamic_tools: + extra = "\n".join( + tool_manifest_line_compact(tool.name, tool.description or "") + for tool in ctx.deps.dynamic_tools + ) + manifest = manifest + "\n" + extra + + return f"\n\n{manifest}\n" + + +@main_agent.toolset +def dynamic_toolset(ctx: RunContext[AssistantDeps]): + """Make dynamically loaded tools available to the agent.""" + + if ctx.deps.dynamic_tools: + ts = FunctionToolset() + for tool in ctx.deps.dynamic_tools: + ts.add_tool(tool) + return ts + return None + + +title_agent: Agent[None, str] = Agent( + output_type=str, + instructions="Create a short title (max 50 chars) for the following user request.", + name="title_agent", +) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py index a7679b4805..f5fbb5a423 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py @@ -1,30 +1,56 @@ -from dataclasses import dataclass -from functools import lru_cache -from typing import Any, AsyncGenerator, Callable, Tuple, TypedDict +import asyncio +from typing import Any, AsyncGenerator -from django.conf import settings from django.core.cache import cache from django.utils import translation -import udspy -from udspy.callback import BaseCallback +from loguru import logger +from pydantic_ai._thinking_part import split_content_into_text_and_thinking +from pydantic_ai.messages import ( + FunctionToolCallEvent, + FunctionToolResultEvent, + ModelMessage, + ModelMessagesTypeAdapter, + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, +) +from pydantic_ai.run import AgentRunResultEvent +from pydantic_ai.usage import UsageLimits from baserow.api.sessions import get_client_undo_redo_action_group_id -from baserow_enterprise.assistant.exceptions import ( - AssistantMessageCancelled, - 
AssistantModelNotSupportedError, +from baserow_enterprise.assistant.agents import main_agent, title_agent +from baserow_enterprise.assistant.deps import ( + AgentMode, + AssistantDeps, + EventBus, + QueueEvent, + QueueEventKind, + ToolHelpers, +) +from baserow_enterprise.assistant.exceptions import AssistantMessageCancelled +from baserow_enterprise.assistant.history import compact_message_history +from baserow_enterprise.assistant.model_profiles import ( + ORCHESTRATOR, + TITLE, + get_model_settings, + get_model_string, +) +from baserow_enterprise.assistant.retrying_model import RetryingModel +from baserow_enterprise.assistant.telemetry import ( + PosthogTracingCallback, + setup_instrumentation, ) -from baserow_enterprise.assistant.telemetry import PosthogTracingCallback -from baserow_enterprise.assistant.tools.navigation.types import AnyNavigationRequestType from baserow_enterprise.assistant.tools.navigation.utils import unsafe_navigate_to from baserow_enterprise.assistant.tools.registries import assistant_tool_registry from .models import AssistantChat, AssistantChatMessage, AssistantChatPrediction -from .signatures import ChatSignature from .types import ( AiMessage, AiMessageChunk, - AiNavigationMessage, AiReasoningChunk, AiStartedMessage, AiThinkingMessage, @@ -33,176 +59,119 @@ HumanMessage, ) +_CANCELLATION_KEY_TTL = 300 # seconds +_THINKING_TAGS = ("", "") -@dataclass -class ToolHelpers: - update_status: Callable[[str], None] - navigate_to: Callable[["AnyNavigationRequestType"], str] - - -class AssistantMessagePair(TypedDict): - question: str - answer: str +def _strip_think_tags(text: str) -> str: + """Remove ``...`` blocks from *text*, returning only the + non-thinking content. Uses pydantic-ai's own tag parser. 
-class AssistantCallbacks(BaseCallback): - def __init__(self, tool_helpers: ToolHelpers | None = None): - self.tool_helpers = tool_helpers - self.tool_calls = {} - self.sources = [] - - def extend_sources(self, sources: list[str]) -> None: - """ - Extends the current list of sources with new ones, avoiding duplicates. - - :param sources: The list of new source URLs to add. - :return: None - """ - - self.sources.extend([s for s in sources if s not in self.sources]) - - def on_tool_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ) -> None: - """ - Called when a tool starts. It records the tool call and invokes the - corresponding tool's on_tool_start method if it exists. - - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the tool being called. - :param inputs: The inputs provided to the tool. - """ - - try: - assistant_tool_registry.get(instance.name).on_tool_start( - call_id, instance, inputs - ) - self.tool_calls[call_id] = (instance, inputs) - except assistant_tool_registry.does_not_exist_exception_class: - pass - - def on_tool_end( - self, - call_id: str, - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ) -> None: - """ - Called when a tool ends. It invokes the corresponding tool's on_tool_end - method if it exists and updates the sources if the tool produced any. - - :param call_id: The unique identifier of the tool call. - :param outputs: The outputs returned by the tool, or None if there was an - exception. - :param exception: The exception raised by the tool, or None if it was - successful. - """ + Also strips any trailing unclosed ```` block that may appear + during streaming (the closing tag hasn't arrived yet). 
+ """ - if call_id not in self.tool_calls: - return + if "" not in text: + return text - instance, inputs = self.tool_calls.pop(call_id) - assistant_tool_registry.get(instance.name).on_tool_end( - call_id, instance, inputs, outputs, exception - ) + # Strip any trailing unclosed block (common during streaming) + last_open = text.rfind("") + last_close = text.rfind("") + if last_open > last_close: + text = text[:last_open] - if exception is not None and self.tool_helpers is not None: - self.tool_helpers.update_status( - f"Calling the {instance.name} tool encountered an error." - ) + if "" not in text: + return text.strip() - # If the tool produced sources, add them to the overall list of sources. - if isinstance(outputs, dict) and "sources" in outputs: - self.extend_sources(outputs["sources"]) + parts = split_content_into_text_and_thinking(text, _THINKING_TAGS) + return "".join(p.content for p in parts if not isinstance(p, ThinkingPart)).strip() def get_assistant_cancellation_key(chat_uuid: str) -> str: - """ - Get the Redis cache key for cancellation tracking. - - :param chat_uuid: The UUID of the assistant chat. - :return: The cache key as a string. - """ + """Return the cache key used to signal cancellation for a chat session.""" return f"assistant:chat:{chat_uuid}:cancelled" -def set_assistant_cancellation_key(chat_uuid: str, timeout: int = 300) -> None: - """ - Set the cancellation flag in the cache for the given chat UUID. +def set_assistant_cancellation_key( + chat_uuid: str, timeout: int = _CANCELLATION_KEY_TTL +) -> None: + """Set the cancellation flag in the cache for a chat session.""" - :param chat_uuid: The UUID of the assistant chat. - :param timeout: The time in seconds after which the cancellation flag expires. 
- """ + cache.set(get_assistant_cancellation_key(chat_uuid), True, timeout=timeout) - cache_key = get_assistant_cancellation_key(chat_uuid) - cache.set(cache_key, True, timeout=timeout) +def _extract_tool_thought(event: FunctionToolCallEvent) -> str | None: + """Extract the chain-of-thought ``thought`` argument from a tool call + event, if present and non-empty.""" -def get_lm_client( - model: str | None = None, -) -> "Assistant": - """ - Returns a udspy.LM client configured with the specified model or the default model. - - :param model: The language model to use. If None, the default model from settings - will be used. - :return: A udspy.LM instance. - """ + try: + args = event.part.args_as_dict() + except Exception: + return None + thought = args.get("thought") + return thought if isinstance(thought, str) and thought.strip() else None - return udspy.LM(model=model or settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL) +class Assistant: + """Orchestrates a single assistant chat session. -@lru_cache(maxsize=1) -def check_lm_ready_or_raise() -> None: - """ - Checks if the configured LLM is ready by making a test call. Raises - AssistantModelNotSupportedError if the model is not supported or accessible. + Wires together the pydantic-ai agent, toolsets, telemetry, event + streaming, and message persistence for one ``AssistantChat``. 
""" - lm = get_lm_client() - try: - lm("Respond in JSON: {'response': 'ok'}") - except Exception as e: - raise AssistantModelNotSupportedError( - f"The model '{lm.model}' is not supported or accessible: {e}" - ) - - -class Assistant: def __init__(self, chat: AssistantChat): self._chat = chat self._user = chat.user self._workspace = chat.workspace + self._model_string = get_model_string() + self._model = RetryingModel(self._model_string) + self._event_bus = EventBus() + self._tool_helpers = self._build_tool_helpers() + self._telemetry = PosthogTracingCallback() + + self._deps = AssistantDeps( + user=self._user, + workspace=self._workspace, + tool_helpers=self._tool_helpers, + ) + self._toolset, db_m, app_m, auto_m, explain_m = ( + assistant_tool_registry.build_toolset( + user=self._user, + workspace=self._workspace, + model=self._model_string, + deps=self._deps, + ) + ) + self._deps.database_manifest = db_m + self._deps.application_manifest = app_m + self._deps.automation_manifest = auto_m + self._deps.explain_manifest = explain_m - self._lm_client = get_lm_client() - self._init_assistant() + setup_instrumentation() - def _init_assistant(self): - self.history = None - self.tool_helpers = self.get_tool_helpers() - tools = [ - t if isinstance(t, udspy.Tool) else udspy.Tool(t) - for t in assistant_tool_registry.list_all_usable_tools( - self._user, self._workspace, self.tool_helpers - ) - ] - - self._assistant_callbacks = AssistantCallbacks(self.tool_helpers) - self._telemetry_callbacks = PosthogTracingCallback() - self._callbacks = [self._assistant_callbacks, self._telemetry_callbacks] - - module_kwargs = { - "temperature": settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE, - "response_format": {"type": "json_object"}, - } - self._assistant = udspy.ReAct( - ChatSignature, tools=tools, max_iters=20, **module_kwargs + # ------------------------------------------------------------------ + # Setup + # ------------------------------------------------------------------ + 
+ def _build_tool_helpers(self) -> ToolHelpers: + """Create the ``ToolHelpers`` that tools use for status updates, + navigation, and cancellation during the agent run.""" + + def update_status(status: str): + with translation.override(self._user.profile.language): + self._event_bus.emit(AiThinkingMessage(content=status)) + + return ToolHelpers( + update_status=update_status, + navigate_to=lambda loc: unsafe_navigate_to(loc, self._event_bus), + event_bus=self._event_bus, ) + # ------------------------------------------------------------------ + # Message persistence + # ------------------------------------------------------------------ + async def acreate_chat_message( self, role: AssistantChatMessage.Role, @@ -210,37 +179,24 @@ async def acreate_chat_message( artifacts: dict[str, Any] | None = None, **kwargs, ) -> AssistantChatMessage: - """ - Creates and saves a new chat message. - - :param role: The role of the message (human or AI). - :param content: The content of the message. - :param artifacts: Optional artifacts associated with the message. - :return: The created AssistantChatMessage instance. - """ + """Persist a new chat message to the database.""" message = AssistantChatMessage( - chat=self._chat, - role=role, - content=content, - **kwargs, + chat=self._chat, role=role, content=content, **kwargs ) if artifacts: message.artifacts = artifacts - await message.asave() return message def list_chat_messages( self, last_message_id: int | None = None, limit: int = 100 - ) -> list[AssistantChatMessage]: - """ - Lists all chat messages in chronological order. + ) -> list[AssistantMessageUnion]: + """Return recent chat messages, oldest-first. - :param last_message_id: The ID of the last message received. If provided, only - messages before this ID will be returned. - :param limit: The maximum number of messages to return. - :return: A list of AssistantChatMessage instances. 
+ :param last_message_id: If set, only return messages with ``id`` + below this value (cursor-based pagination). + :param limit: Maximum number of messages to return. """ queryset = ( @@ -251,7 +207,7 @@ def list_chat_messages( if last_message_id is not None: queryset = queryset.filter(id__lt=last_message_id) - messages = [] + messages: list[AssistantMessageUnion] = [] for msg in queryset[:limit]: if msg.role == AssistantChatMessage.Role.HUMAN: messages.append( @@ -276,267 +232,393 @@ def list_chat_messages( ) return list(reversed(messages)) - async def afetch_chat_history(self, limit: int = 50) -> udspy.History: - """ - Loads the chat history into a udspy.History object. It only loads complete - message pairs (human + AI). The history will be in chronological order and must - respect the module signature (question, answer). + async def _save_ai_response( + self, human_msg: AssistantChatMessage, answer: str + ) -> AiMessage: + """Persist the AI answer and create a prediction record for + feedback tracking.""" - :param limit: The maximum number of message pairs to load. - :return: A udspy.History instance containing the chat history. - """ + sources = self._deps.sources + ai_msg = await self.acreate_chat_message( + AssistantChatMessage.Role.AI, + answer, + artifacts={"sources": sources}, + action_group_id=get_client_undo_redo_action_group_id(self._user), + ) + await AssistantChatPrediction.objects.acreate( + human_message=human_msg, + ai_response=ai_msg, + prediction={"answer": answer}, + ) + return AiMessage( + id=ai_msg.id, + content=answer, + sources=sources, + can_submit_feedback=True, + ) - history = udspy.History() - last_saved_messages: list[AssistantChatMessage] = [ - msg async for msg in self._chat.messages.order_by("-created_on")[:limit] - ] - - while len(last_saved_messages) >= 2: - # Pop the oldest message pair to respect chronological order. 
- first_message = last_saved_messages.pop() - next_message = last_saved_messages[-1] - if ( - first_message.role != AssistantChatMessage.Role.HUMAN - or next_message.role != AssistantChatMessage.Role.AI - ): - continue + # ------------------------------------------------------------------ + # Message history (pydantic-ai ModelMessage round-trips) + # ------------------------------------------------------------------ - history.add_user_message(first_message.content) - assistant_answer = last_saved_messages.pop() - history.add_assistant_message(assistant_answer.content) + async def _save_message_history(self, messages_json: bytes) -> None: + """Persist the serialised pydantic-ai message history on the chat.""" - return history + self._chat.message_history = messages_json + await self._chat.asave(update_fields=["message_history", "updated_on"]) - def get_tool_helpers(self) -> ToolHelpers: - def update_status_localized(status: str): - """ - Sends a localized message to the frontend to update the assistant status. + async def _load_message_history(self) -> list[ModelMessage] | None: + """Deserialise and compact the stored message history, returning + ``None`` if absent or corrupt.""" - :param status: The status message to send. 
- """ + raw = self._chat.message_history + if not raw: + return None + try: + messages = ModelMessagesTypeAdapter.validate_json(bytes(raw)) + return compact_message_history(messages) + except Exception: + logger.opt(exception=True).warning( + "Failed to load message history for chat {}, starting fresh", + self._chat.pk, + ) + return None - with translation.override(self._user.profile.language): - udspy.emit_event(AiThinkingMessage(content=status)) + # ------------------------------------------------------------------ + # Agent execution + # ------------------------------------------------------------------ - return ToolHelpers( - update_status=update_status_localized, - navigate_to=unsafe_navigate_to, + async def _generate_chat_title(self, user_message: str) -> str: + """Ask the title agent to summarise a user message into a short + chat title.""" + + result = await title_agent.run( + user_message, + model=self._model, + model_settings=get_model_settings(self._model_string, TITLE), ) + return result.output - async def _generate_chat_title(self, user_message: str) -> str: - """ - Generates a title for the chat based on the user message and AI response. + _MAX_TOOL_CALL_AS_TEXT_RETRIES = 2 - :param user_message: The latest user message in the chat. - :return: The generated chat title. - """ + _TOOL_CALL_CORRECTION_PROMPT = ( + "Your previous response contained a raw JSON tool call instead of " + "actually invoking the tool. The malformed output was:\n\n" + "{malformed_output}\n\n" + "Please call the tool directly using the proper tool-calling " + "mechanism instead of outputting JSON text. Make sure the " + "arguments conform to the tool's schema." 
+ ) - title_generator = udspy.Predict( - udspy.Signature.from_string( - "user_message -> chat_title", - "Create a short title for the following user request.", + async def _emit_answer( + self, + answer: str, + run_result: Any, + queue: asyncio.Queue[QueueEvent], + ) -> None: + """Push the final answer and result events onto *queue*.""" + + await queue.put( + QueueEvent( + kind=QueueEventKind.STREAM, + message=AiMessageChunk(content=answer, sources=self._deps.sources), ) ) - rsp = await title_generator.aforward( - user_message=user_message, + queue.put_nowait( + QueueEvent( + kind=QueueEventKind.RESULT, + answer=answer, + messages_json=run_result.all_messages_json(), + ) ) - return rsp.chat_title - async def _acreate_ai_message_response( + async def _run_agent( self, - human_msg: HumanMessage, - prediction: udspy.Prediction, - ) -> AiMessage: - """ - Creates and saves an AI chat message response based on the prediction. Stores - the prediction in AssistantChatPrediction, linking it to the human message, so - it can be referenced later to provide feedback. + user_prompt: str, + message_history: list[ModelMessage] | None, + queue: asyncio.Queue[QueueEvent], + ) -> None: + """Execute the main agent, retrying if it outputs tool calls as text. - :param human_msg: The human message instance. - :param prediction: The udspy.Prediction instance containing the AI response. - :return: The created AiMessage instance to return to the user. - """ + Delegates each streaming pass to ``_stream_agent_run``. If the + final output looks like a raw JSON tool call, re-runs the agent + with the conversation history and a corrective prompt (up to + ``_MAX_TOOL_CALL_AS_TEXT_RETRIES`` times) so the model can + self-correct and invoke the tool properly. 
- sources = self._assistant_callbacks.sources - ai_msg = await self.acreate_chat_message( - AssistantChatMessage.Role.AI, - prediction.answer, - artifacts={"sources": sources}, - action_group_id=get_client_undo_redo_action_group_id(self._user), - ) + Pushes ``STREAM``, ``RESULT``, ``ERROR``, and ``DONE`` events + onto *queue* for the consumer in ``astream_messages``. + """ - await AssistantChatPrediction.objects.acreate( - human_message=human_msg, - ai_response=ai_msg, - prediction={k: v for k, v in prediction.items() if k != "module"}, - ) + try: + with self._telemetry.trace(self._chat, user_prompt) as tracer: + answer, run_result = await self._run_agent_with_retries( + user_prompt, message_history, queue + ) + tracer.set_trace_output(answer) + await self._emit_answer(answer, run_result, queue) + except Exception as exc: + logger.exception("Error running main agent") + queue.put_nowait(QueueEvent(kind=QueueEventKind.ERROR, error=exc)) + finally: + queue.put_nowait(QueueEvent(kind=QueueEventKind.DONE)) + + async def _run_agent_with_retries( + self, + user_prompt: str, + message_history: list[ModelMessage] | None, + queue: asyncio.Queue[QueueEvent], + ) -> tuple[str, Any]: + """Stream the agent, retrying on tool-call-as-text outputs. - # Yield final complete message - return AiMessage( - id=ai_msg.id, - content=prediction.answer, - sources=sources, - can_submit_feedback=True, - ) + Returns ``(answer, run_result)`` — either the model's valid + answer or a fallback message after exhausting retries. - def _get_cancellation_cache_key(self) -> str: + :raises RuntimeError: if the stream ends without a result event. """ - Get the Redis cache key for cancellation tracking. - :return: The cache key as a string. 
- """ + current_prompt = user_prompt + current_history = message_history - return get_assistant_cancellation_key(self._chat.uuid) + for attempt in range(1 + self._MAX_TOOL_CALL_AS_TEXT_RETRIES): + result = await self._stream_agent_run( + current_prompt, current_history, queue + ) + if result is None: + raise RuntimeError("Agent stream ended without a result event") - def _check_cancellation(self, cache_key: str, message_id: str) -> None: - """ - Check if the message generation has been cancelled. + answer, run_result = result - :param cache_key: The cache key to check for cancellation. - :param message_id: The ID of the message being generated. - :raises AssistantMessageCancelled: If the message generation has been cancelled. - """ + if not self._looks_like_json_tool_call(answer): + return answer, run_result - if cache.get(cache_key): - cache.delete(cache_key) - raise AssistantMessageCancelled(message_id=message_id) + logger.warning( + "[assistant] Model output tool call as text (attempt {}/{}): {}", + attempt + 1, + 1 + self._MAX_TOOL_CALL_AS_TEXT_RETRIES, + answer[:200], + ) - async def _process_agent_stream( - self, - event: Any, - human_msg: AssistantChatMessage, - ) -> Tuple[list[AssistantMessageUnion], udspy.Prediction | None]: - """ - Process a single event from the output stream. + if attempt < self._MAX_TOOL_CALL_AS_TEXT_RETRIES: + # Replace the malformed JSON visible in the UI with a + # reasoning indicator so the user doesn't see garbage. + await queue.put( + QueueEvent( + kind=QueueEventKind.STREAM, + message=AiReasoningChunk(content=""), + ) + ) + current_history = run_result.all_messages() + current_prompt = self._TOOL_CALL_CORRECTION_PROMPT.format( + malformed_output=answer[:500] + ) + + # Exhausted retries — give up gracefully. + logger.error( + "[assistant] Model persisted outputting tool " + "calls as text after {} retries", + self._MAX_TOOL_CALL_AS_TEXT_RETRIES, + ) + fallback = ( + "I ran into a temporary issue processing " + "your request. 
Could you please try again?" + ) + return fallback, run_result - :param event: The event to process. - :param human_msg: The human message instance. - :return: a tuple of (messages_to_yield, prediction). + async def _stream_agent_run( + self, + user_prompt: str, + message_history: list[ModelMessage] | None, + queue: asyncio.Queue[QueueEvent], + ) -> tuple[str, Any] | None: + """Run a single agent streaming pass. + + Streams reasoning/text chunks to *queue* and returns + ``(answer, run_result)`` when an ``AgentRunResultEvent`` is + received, or ``None`` if the stream ends without one. """ - messages = [] - prediction = None + reasoning_so_far = "" - if isinstance(event, (AiThinkingMessage, AiNavigationMessage)): - messages.append(event) - return messages, prediction + async for event in main_agent.run_stream_events( + user_prompt=user_prompt, + deps=self._deps, + model=self._model, + message_history=message_history, + usage_limits=UsageLimits(request_limit=200), + toolsets=[self._toolset], + model_settings=get_model_settings(self._model_string, ORCHESTRATOR), + ): + if isinstance(event, AgentRunResultEvent): + answer = event.result.output + if isinstance(answer, str): + answer = _strip_think_tags(answer) + return (answer, event.result) + + if isinstance(event, FunctionToolCallEvent): + thought = _extract_tool_thought(event) + if thought: + reasoning_so_far += thought + cleaned = _strip_think_tags(reasoning_so_far) + await self._enqueue_reasoning(queue, cleaned) + continue - # Stream the final answer - if isinstance(event, udspy.OutputStreamChunk): - if ( - event.field_name == "answer" - and event.module is self._assistant.extract_module - ): - messages.append( - AiMessageChunk( - content=event.content, - sources=self._assistant_callbacks.sources, - ) - ) + if isinstance(event, FunctionToolResultEvent): + reasoning_so_far = "" # reset on tool results, to show the reasoning leading up to the next tool call + continue - elif isinstance(event, udspy.Prediction): - # 
final prediction contains the answer to the user question - if event.module is self._assistant: - prediction = event - ai_msg = await self._acreate_ai_message_response(human_msg, prediction) - messages.append(ai_msg) + # Accumulate text/thinking deltas and send full reasoning. + # The frontend replaces content on each chunk, so we must + # send the complete text every time. + content = self._get_content_delta(event) + if content: + reasoning_so_far += content + cleaned = _strip_think_tags(reasoning_so_far) + if cleaned: + await self._enqueue_reasoning(queue, cleaned) - elif reasoning := getattr(event, "next_thought", None): - messages.append(AiReasoningChunk(content=reasoning)) + return None - return messages, prediction + @staticmethod + def _get_content_delta(event: Any) -> str | None: + """Extract text or thinking content from a stream event delta.""" - def get_agent_stream( - self, message: HumanMessage, conversation_history: udspy.History | None = None - ) -> AsyncGenerator[Any, None]: - """ - Returns an async generator that streams the ReAct agent's response to a user - message. + if isinstance(event, PartStartEvent) and isinstance( + event.part, (TextPart, ThinkingPart) + ): + return event.part.content or None + if isinstance(event, PartDeltaEvent) and isinstance( + event.delta, (TextPartDelta, ThinkingPartDelta) + ): + return event.delta.content_delta or None + return None - :param user_message: The message from the user. - :return: An async generator that yields stream events. 
- """ + @staticmethod + async def _enqueue_reasoning( + queue: asyncio.Queue[QueueEvent], content: str + ) -> None: + """Push an ``AiReasoningChunk`` onto *queue*.""" - formatted_history = ( - ChatSignature.format_conversation_history(conversation_history) - if conversation_history - else [] - ) - formatted_ui_context = ( - message.ui_context.format() if message.ui_context else None + await queue.put( + QueueEvent( + kind=QueueEventKind.STREAM, + message=AiReasoningChunk(content=content), + ) ) - return self._assistant.astream( - question=message.content, - conversation_history=formatted_history, - ui_context=formatted_ui_context, - ) + @staticmethod + def _looks_like_json_tool_call(text: str) -> bool: + """Return True if *text* looks like a tool call dumped as JSON. - async def _process_stream( - self, - human_msg: HumanMessage, - stream: AsyncGenerator[Any, None], - process_event_func: Callable[ - [Any, AssistantChatMessage], - Tuple[list[AssistantMessageUnion], udspy.Prediction | None], - ], - ) -> AsyncGenerator[Tuple[AssistantMessageUnion, udspy.Prediction | None], None]: - chunk_count = 0 - cancellation_key = self._get_cancellation_cache_key() - message_id = str(human_msg.id) + Checks for ``{"name": ..., "arguments": ...}`` pattern in the first + 200 chars. Does not require valid JSON (the output may be truncated). 
+ """ - async for event in stream: - # Periodically check for cancellation - chunk_count += 1 - if chunk_count % 10 == 0: - self._check_cancellation(cancellation_key, message_id) + stripped = text.strip() + return ( + bool(stripped) + and stripped[0] == "{" + and '"name"' in stripped[:200] + and '"arguments"' in stripped[:200] + ) + + # ------------------------------------------------------------------ + # Cancellation + # ------------------------------------------------------------------ - messages, prediction = await process_event_func(event, human_msg) + async def _monitor_cancellation(self, task: asyncio.Task) -> None: + """Poll the cache for a cancellation flag and cancel *task* if + set. Runs as a concurrent task alongside the agent.""" - if messages: # Don't return responses if cancelled - self._check_cancellation(cancellation_key, message_id) + cache_key = get_assistant_cancellation_key(self._chat.uuid) + while not task.done(): + await asyncio.sleep(0.2) + if cache.get(cache_key): + cache.delete(cache_key) + self._tool_helpers.cancel() + task.cancel() + return - for msg in messages: - yield msg, prediction + # ------------------------------------------------------------------ + # Public streaming API + # ------------------------------------------------------------------ async def astream_messages( self, message: HumanMessage ) -> AsyncGenerator[AssistantMessageUnion, None]: - """ - Streams the response to a user message. + """Stream the full response lifecycle for a user message. - :param human_message: The message from the user. - :return: An async generator that yields the response messages. + Yields events in order: ``AiStartedMessage``, zero or more + streaming chunks (``AiMessageChunk`` / ``AiReasoningChunk`` / + ``AiThinkingMessage``), and finally an ``AiMessage`` with the + persisted answer. A ``ChatTitleMessage`` is appended on the first + message in a chat. 
""" + # Sticky task: capture on first message of the session + if not self._deps.original_request: + self._deps.original_request = message.content + + # Auto-detect starting mode from UI context (only on first message) + if message.ui_context: + if message.ui_context.application or message.ui_context.page: + self._deps.mode = AgentMode.APPLICATION + elif message.ui_context.automation or message.ui_context.workflow: + self._deps.mode = AgentMode.AUTOMATION + # else stays DATABASE (default) + human_msg = await self.acreate_chat_message( - AssistantChatMessage.Role.HUMAN, - message.content, + AssistantChatMessage.Role.HUMAN, message.content ) - default_callbacks = udspy.settings.callbacks - - with ( - udspy.settings.context( - lm=self._lm_client, - callbacks=[*default_callbacks, *self._callbacks], - ), - self._telemetry_callbacks.trace(self._chat, human_msg.content), - ): - message_id = str(human_msg.id) - yield AiStartedMessage(message_id=message_id) + message_id = str(human_msg.id) + yield AiStartedMessage(message_id=message_id) - history = await self.afetch_chat_history(limit=30) + ui_context = message.ui_context.format() if message.ui_context else None + self._tool_helpers.request_context["ui_context"] = ui_context + message_history = await self._load_message_history() - agent_stream = self.get_agent_stream(message, history) + queue: asyncio.Queue[QueueEvent] = asyncio.Queue() + self._event_bus.set_queue(queue) - async for msg, __ in self._process_stream( - human_msg, agent_stream, self._process_agent_stream - ): - yield msg + agent_task = asyncio.create_task( + self._run_agent(message.content, message_history, queue) + ) + monitor_task = asyncio.create_task(self._monitor_cancellation(agent_task)) - # Generate chat title if needed - if not self._chat.title: - chat_title = await self._generate_chat_title(human_msg.content) - self._chat.title = chat_title + try: + answer = None + messages_json = None + + while True: + event = await queue.get() + if event.kind == 
QueueEventKind.DONE: + break + elif event.kind == QueueEventKind.RESULT: + answer, messages_json = event.answer, event.messages_json + elif event.kind == QueueEventKind.ERROR: + raise event.error + else: + yield event.message + + if agent_task.cancelled(): + raise AssistantMessageCancelled(message_id=message_id) + + if answer is not None: + yield await self._save_ai_response(human_msg, answer) + if messages_json: + await self._save_message_history(messages_json) + finally: + monitor_task.cancel() + if not agent_task.done(): + agent_task.cancel() + await asyncio.gather(monitor_task, agent_task, return_exceptions=True) + self._event_bus.set_queue(None) + + if not self._chat.title: + try: + title = await self._generate_chat_title(human_msg.content) + self._chat.title = title[: AssistantChat.TITLE_MAX_LENGTH] await self._chat.asave(update_fields=["title", "updated_on"]) - yield ChatTitleMessage(content=chat_title) + yield ChatTitleMessage(content=self._chat.title) + except Exception: + logger.exception("Failed to generate chat title") diff --git a/enterprise/backend/src/baserow_enterprise/assistant/deps.py b/enterprise/backend/src/baserow_enterprise/assistant/deps.py new file mode 100644 index 0000000000..f4bb95db97 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/deps.py @@ -0,0 +1,148 @@ +import asyncio +import threading +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Callable + +from pydantic_ai import Tool + +if TYPE_CHECKING: + from django.contrib.auth.models import AbstractUser + + from baserow.core.models import Workspace + from baserow_enterprise.assistant.tools.navigation.types import ( + AnyNavigationRequestType, + ) + + +class AgentMode(str, Enum): + """Operating mode that controls which tools are available to the agent.""" + + DATABASE = "database" + APPLICATION = "application" + AUTOMATION = "automation" + EXPLAIN = "explain" + + +class QueueEventKind(Enum): + STREAM = 
auto() + RESULT = auto() + ERROR = auto() + DONE = auto() + + +@dataclass +class QueueEvent: + kind: QueueEventKind + message: Any = None + answer: str = "" + messages_json: bytes = b"" + error: Exception | None = None + + +@dataclass +class EventBus: + """ + Pushes streaming events into the queue consumed by + Assistant.astream_messages(). Events are silently dropped when no + queue is attached. + """ + + _queue: asyncio.Queue[QueueEvent] | None = None + + def set_queue(self, queue: asyncio.Queue[QueueEvent] | None): + self._queue = queue + + def emit(self, event): + if self._queue is not None: + self._queue.put_nowait( + QueueEvent(kind=QueueEventKind.STREAM, message=event) + ) + + +@dataclass +class ToolHelpers: + """ + Contextual helpers available to every tool via ``RunContext[AssistantDeps]``. + + Provides status updates (shown in the UI), navigation actions, + cancellation support, and an event bus for emitting custom streaming + events (thinking messages, navigation messages, etc.). + """ + + update_status: Callable[[str], None] + navigate_to: Callable[["AnyNavigationRequestType"], str] + request_context: dict = field(default_factory=dict) + event_bus: EventBus = field(default_factory=EventBus) + _cancel_event: threading.Event = field(default_factory=threading.Event) + + def raise_if_cancelled(self) -> None: + """Check cancellation and raise if set. Thread-safe. + + Call this in tool loops or between expensive operations. + Raises ``CancelledError`` (``BaseException``) which escapes the + agent's ``except Exception`` handler and propagates through the + async chain. + """ + + if self._cancel_event.is_set(): + raise asyncio.CancelledError() + + @property + def is_cancelled(self) -> bool: + """Check if cancelled without raising. Thread-safe.""" + + return self._cancel_event.is_set() + + def cancel(self) -> None: + """Signal cancellation to running tools. 
Thread-safe.""" + + self._cancel_event.set() + + +@dataclass +class AssistantDeps: + """ + Typed dependency container for the pydantic-ai agent. + + Every agent run operates on behalf of a user in a given workspace. + This runtime-context also allows tools to share information (e.g. + sources), provide helpers for emitting events or requesting navigation, + switch between domain modes, and dynamically extend the toolset by + adding tools to ``dynamic_tools`` during a run (e.g. row tools loaded + by the database agent). + + Passed via ``deps=`` to every ``agent.run()`` / ``agent.run_stream()`` + call and accessible in tools via ``RunContext[AssistantDeps].deps``. + """ + + user: "AbstractUser" + workspace: "Workspace" + tool_helpers: ToolHelpers + mode: AgentMode = AgentMode.DATABASE + sources: list[str] = field(default_factory=list) + dynamic_tools: list[Tool] = field(default_factory=list) + database_manifest: str = "" + application_manifest: str = "" + automation_manifest: str = "" + explain_manifest: str = "" + original_request: str = "" + + @property + def active_manifest(self) -> str: + return { + AgentMode.DATABASE: self.database_manifest, + AgentMode.APPLICATION: self.application_manifest, + AgentMode.AUTOMATION: self.automation_manifest, + AgentMode.EXPLAIN: self.explain_manifest, + }[self.mode] + + def extend_sources(self, new_sources: list[str]): + """ + Extend the current list of sources with new ones, avoiding + duplicates. + + :param new_sources: The list of new source URLs to add. 
+ """ + + self.sources.extend(s for s in new_sources if s not in self.sources) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/handler.py b/enterprise/backend/src/baserow_enterprise/assistant/handler.py index bfe8b2bf72..6d98e98935 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/handler.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/handler.py @@ -147,12 +147,11 @@ async def astream_assistant_messages( :param chat: The AI assistant chat to get the assistant for. :param human_message: The new message from the user. - :param ui_ontext: The UI context where the message was sent. + :param ui_context: The UI context where the message was sent. :return: An async generator yielding messages from the assistant. """ assistant = self.get_assistant(chat) - async for message in assistant.astream_messages( - human_message, ui_context=ui_context - ): - yield message + message = HumanMessage(content=human_message, ui_context=ui_context) + async for msg in assistant.astream_messages(message): + yield msg diff --git a/enterprise/backend/src/baserow_enterprise/assistant/history.py b/enterprise/backend/src/baserow_enterprise/assistant/history.py new file mode 100644 index 0000000000..c25e699550 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/history.py @@ -0,0 +1,118 @@ +""" +Utilities for compacting and trimming pydantic-ai message histories. + +The assistant persists the full message history (including intermediate tool +calls) across turns. Before feeding it back into the agent we compact each +turn down to (user prompt, final answer) and trim to a fixed window so the +context doesn't grow unboundedly. +""" + +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) + +# The number of messages to keep in the compacted history for context. This is +# a simple safeguard to prevent excessively long histories from bloating the +# context. 
+MAX_HISTORY_MESSAGES = 20 + + +def _has_user_prompt(msg: ModelMessage) -> bool: + """Check if a ModelRequest contains a UserPromptPart.""" + + return isinstance(msg, ModelRequest) and any( + isinstance(p, UserPromptPart) for p in msg.parts + ) + + +def _get_final_text_response(turn: list[ModelMessage]) -> ModelResponse | None: + """ + Return the last ModelResponse in the turn that contains a TextPart, + or None if no such response exists. + """ + + for msg in reversed(turn): + if isinstance(msg, ModelResponse) and any( + isinstance(p, TextPart) for p in msg.parts + ): + return msg + return None + + +def _split_into_turns( + messages: list[ModelMessage], +) -> list[list[ModelMessage]]: + """ + Split a flat message list into turns. Each turn starts at a ModelRequest + that contains a UserPromptPart. + + Messages before the first UserPromptPart (e.g. initial system instructions) + are grouped into a leading "turn 0". + """ + + turns: list[list[ModelMessage]] = [] + current: list[ModelMessage] = [] + + for msg in messages: + if _has_user_prompt(msg) and current: + turns.append(current) + current = [] + current.append(msg) + + if current: + turns.append(current) + + return turns + + +def _compact_turn(turn: list[ModelMessage]) -> list[ModelMessage]: + """ + Compact a single turn. If the turn has more than 2 messages (user prompt + + final answer), strip intermediate tool call/return messages and keep + only the user prompt request and the final text response. + + Returns the turn unchanged if it has no tool calls or no final text + response. 
+ """ + + if len(turn) <= 2: + return turn + + # Find the user prompt request (first message) and final text response + user_request = turn[0] if _has_user_prompt(turn[0]) else None + final_response = _get_final_text_response(turn) + + if user_request and final_response: + return [user_request, final_response] + + # Cannot compact -- return as-is + return turn + + +def compact_message_history( + messages: list[ModelMessage], + max_messages: int = MAX_HISTORY_MESSAGES, +) -> list[ModelMessage]: + """ + Compact and trim a pydantic-ai message history for multi-turn context. + + 1. Splits messages into turns (delimited by UserPromptPart). + 2. For each turn with intermediate tool calls, collapses to just the + user prompt and final text answer. + 3. Trims to the last ``max_messages`` messages if still too long. + """ + + turns = _split_into_turns(messages) + + compacted: list[ModelMessage] = [] + for turn in turns: + compacted.extend(_compact_turn(turn)) + + if len(compacted) > max_messages: + compacted = compacted[-max_messages:] + + return compacted diff --git a/enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py b/enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py new file mode 100644 index 0000000000..25f19a625b --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py @@ -0,0 +1,175 @@ +""" +Centralized model configuration and per-model settings for all agents. + +Contains: +- ``get_model_string()``: Resolves the active LLM model identifier. +- ``check_lm_ready_or_raise()``: Quick connectivity check. +- ``get_model_settings(model, role)``: Per-model, per-role settings. 
+ +Usage:: + + from baserow_enterprise.assistant.model_profiles import ( + get_model_string, get_model_settings, ORCHESTRATOR, + ) + + model = get_model_string() + settings = get_model_settings(model, ORCHESTRATOR) +""" + +from functools import lru_cache + +from django.conf import settings + +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings + +from baserow_enterprise.assistant.exceptions import AssistantModelNotSupportedError +from baserow_enterprise.assistant.models import AssistantChat + +# --------------------------------------------------------------------------- +# Agent roles +# --------------------------------------------------------------------------- + +ORCHESTRATOR = "orchestrator" +SUBAGENT = "subagent" # database, builder, automations +UTILITY = "utility" # formula, fixer (precision-oriented) +SAMPLE = "sample" # sample row generation (creative) +TITLE = "title" # title generation + +# --------------------------------------------------------------------------- +# Per-model profiles +# --------------------------------------------------------------------------- + +# Fallback when the model isn't in _MODEL_PROFILES +_DEFAULT_PROFILE: dict[str, ModelSettings] = { + ORCHESTRATOR: { + "temperature": 0.3, + "timeout": 30, + "parallel_tool_calls": False, + "max_tokens": 16384, + }, + SUBAGENT: { + "temperature": 0.3, + "timeout": 20, + "parallel_tool_calls": False, + "max_tokens": 16384, + }, + UTILITY: { + "temperature": 0.1, + "timeout": 20, + }, + SAMPLE: { + "temperature": 0.5, + "timeout": 20, + }, + TITLE: { + "temperature": 0.7, + "timeout": 10, + "max_tokens": AssistantChat.TITLE_MAX_LENGTH, + }, +} + +_MODEL_PROFILES: dict[str, dict[str, ModelSettings]] = { + "gpt-oss-120b": { + ORCHESTRATOR: { + **_DEFAULT_PROFILE[ORCHESTRATOR], + "groq_reasoning_format": "parsed", + }, + SUBAGENT: { + **_DEFAULT_PROFILE[SUBAGENT], + "groq_reasoning_format": "parsed", + }, + UTILITY: { + # No groq_reasoning_format here: formula generation 
is a precise + # structured-output task where reasoning tokens pollute the output. + **_DEFAULT_PROFILE[UTILITY], + }, + SAMPLE: { + **_DEFAULT_PROFILE[SAMPLE], + "groq_reasoning_format": "parsed", + }, + TITLE: { + **_DEFAULT_PROFILE[TITLE], + }, + }, +} + + +def get_model_settings(model: str, role: str) -> ModelSettings: + """ + Return the ModelSettings for a given model string and agent role. + + The model string is the pydantic-ai format (e.g. ``"groq:openai/gpt-oss-120b"``). + We match on the last path segment (e.g. ``"gpt-oss-120b"``) to find the profile. + + For the ``ORCHESTRATOR`` role the temperature defaults to the value of + ``BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE`` (if set), allowing + operators to override it without changing code. + + :param model: pydantic-ai model string (e.g. ``"groq:openai/gpt-oss-120b"``). + :param role: One of ORCHESTRATOR, SUBAGENT, UTILITY, TITLE. + :return: A ModelSettings dict suitable for ``model_settings=`` parameter. + """ + + # Extract model name after the provider prefix: + # "groq:openai/gpt-oss-120b" -> "gpt-oss-120b" + # "ollama:kimi-2.5:cloud" -> "kimi-2.5:cloud" + _, sep, after_provider = model.partition(":") + model_name = after_provider.rsplit("/", 1)[-1] if sep else model + + profile = _MODEL_PROFILES.get(model_name, _DEFAULT_PROFILE) + result = dict(profile.get(role, _DEFAULT_PROFILE.get(role, {}))) + + # Allow the env-var-driven setting to override the orchestrator temperature. + if role == ORCHESTRATOR: + env_temp = getattr( + settings, "BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE", None + ) + if env_temp is not None: + result["temperature"] = env_temp + + return result + + +# --------------------------------------------------------------------------- +# Model resolution +# --------------------------------------------------------------------------- + + +def get_model_string(model: str | None = None) -> str: + """ + Returns the model string for the pydantic-ai agent. 
+ + :param model: The language model to use. If None, the default model from + settings will be used. + :return: A model string compatible with pydantic-ai. + """ + + value = model or settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL + # pydantic-ai expects "provider:model" (e.g. "groq:openai/gpt-oss-120b"). + # Convert "provider/model" to "provider:model" when the first "/" comes + # before the first ":" (or there is no ":"). This handles cases like + # "ollama/kimi-2.5:cloud" where the colon is part of the model tag. + slash_pos = value.find("/") + colon_pos = value.find(":") + if slash_pos != -1 and (colon_pos == -1 or slash_pos < colon_pos): + value = value.replace("/", ":", 1) + elif slash_pos == -1 and colon_pos == -1: + # No provider prefix at all (e.g. "gpt-4o") — default to OpenAI + # for backward compatibility with old UDSPY_LM_MODEL values. + value = f"openai:{value}" + return value + + +@lru_cache(maxsize=1) +def check_lm_ready_or_raise() -> None: + model = get_model_string() + test_agent = Agent( + output_type=str, instructions="Respond with 'ok'.", name="test_agent" + ) + try: + test_agent.run_sync("Test", model=model) + except Exception as e: + raise AssistantModelNotSupportedError( + f"The model '{model}' is not supported or accessible: {e}" + ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/models.py b/enterprise/backend/src/baserow_enterprise/assistant/models.py index 9c48402a96..250bc5e1c0 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/models.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/models.py @@ -42,6 +42,14 @@ class Status(models.TextChoices): status = models.CharField( max_length=20, choices=Status.choices, default=Status.IDLE ) + message_history = models.BinaryField( + null=True, + blank=True, + help_text=( + "Serialized pydantic-ai message history (JSON bytes) for " + "multi-turn conversation context." 
+ ), + ) class Meta: indexes = [ diff --git a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py index 84d949a61c..bc8ff080f0 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py @@ -1,167 +1,57 @@ from django.conf import settings -CORE_CONCEPTS = """ -### BASEROW STRUCTURE - -**Structure**: Workspace → Databases, Applications, Automations, Dashboards, Snapshots - -**Key concepts**: -• **Roles**: Free (admin, member) | Advanced/Enterprise (admin, builder, editor, viewer, no access) -• **Features**: Real-time collaboration, SSO (SAML2/OIDC/OAuth2), MCP integration, API access, Audit logs -• **Plans**: Free, Premium, Advanced, Enterprise (https://baserow.io/pricing) -• **Open Source**: Core is open source (https://github.com/baserow/baserow) -• **Snapshots**: Application-level backups -""" - -DATABASE_BUILDER_CONCEPTS = """ -### DATABASE BUILDER (no-code database) - -**Structure**: Database → Tables → Fields + Views + Webhooks + Rows. Rows → comments. - -**Key concepts**: -• **Fields**: Define schema (30+ types including link_row for relationships); one primary field per table -• **Views**: Present data with filters/sorts/grouping/colors; can be shared, personal, or public -• **Rows**: Data records following the table schema; support for rich content (files, long text, formulas, numbers, dates, etc.). Changes are tracked in history. -• **Comments**: Threaded discussions on rows; mentions. 
-• **Formulas**: Computed fields using functions/operators; support for cross-table lookups -• **Permissions**: RBAC at workspace/database/table/field levels; database tokens for API -• **Data sync**: Table replication; **Webhooks**: Row/field/view event triggers -""" - -APPLICATION_BUILDER_CONCEPTS = """ -### APPLICATION BUILDER (visual app builder) - -**Structure**: Application → Pages → Elements + Data Sources + Workflows - -**Key concepts**: -• **Pages**: Routes with UI elements (buttons, tables, forms, etc.) -• **Data Sources**: Connect to database tables/views; elements bind to them for dynamic content -• **Formulas**: Reference data from previous nodes and compute values using functions/operators in nodes attributes -• **Workflows**: Event-driven actions (create/update rows, navigate, notifications) -• **Publishing**: Requires domain configuration +AGENT_IDENTITY = """\ + +You are Kuma, an AI expert for Baserow (open-source no-code platform). \ +You are an autonomous tool-calling agent. Whenever possible, you act — you do not describe. + +""" + +RULES = """\ + +1. Use the `thought` parameter on EVERY tool call to state your reasoning. +2. Have tools → call them. No tools in current mode → check other modes before saying something is not possible. If another mode has the tool, switch_mode and use it. Only explain manual UI steps if no mode covers the action. +3. One tool per turn. Wait for the result. Never reply and call a tool in same turn. +4. Verify after create/modify — navigate to show the result. +5. Request priority: action > follow-up (reuse prior IDs, never search docs) > question. When a tool result contains next_steps, act on them immediately — do not ask for permission to continue. +6. You start in the mode matching your UI context (database/application/automation). If the user asks a how-to or feature question, call switch_mode("explain"), then search_user_docs. +7. 
After finishing the tool calls in a different mode (not just after switching — after the actual work is done and results received), switch back to the original domain mode (check `<current_mode>` and `<original_request>`).
+8. Reply in concise Markdown. Never expose raw JSON or internal IDs unless asked.
+9. When a request references resources by name/ID, verify they exist (list_*) before building on them. If not found, ask — don't guess. But when the task *requires* creating resources in another domain (e.g. building an app that needs new tables), switch_mode and create them yourself — don't ask the user to do it manually.
+10. Before responding to the user, verify ALL parts of `<original_request>` are addressed. If anything is missing, continue working.
+11. Before adding a table to a database or a page to an application, check that the target is semantically related. If the name/purpose doesn't match, ask the user which target to use or whether to create a new one. Examples of mismatches: adding "Inquiries" table to a "Project Management" DB; adding "Event Registration" pages to a "Portfolio Website" app. This applies to ALL resource creation — tables, pages, and the applications/databases themselves. Remember their answer — only re-ask when a new, different mismatch arises.
+ +""" + +HANDLING_AMBIGUITY = """\ + +Ambiguous terms — pick by context, confirm only if truly unclear: +- "table" → App Builder: Table element | Database: database table +- "form" → App Builder: Form element | Database: Form view +- "workflow action" → App Builder: element action | Automations: action node + +""" + +BASEROW_KNOWLEDGE = """\ + +Workspace → Databases, Applications, Automations, Dashboards +Database → Tables → Fields (30+ types, link_row for relations) + Views (grid, form, kanban, calendar, gallery, timeline) + Rows +Application → Pages → Elements + Data Sources + Actions +Automation → Workflows → Trigger + Action/Router/Iterator nodes (use {{ node.ref }} for formulas) + +""" + +LIMITATIONS_AND_SOURCES = f"""\ + +Cannot create/modify/delete: user accounts, workspaces, dashboards, widgets, snapshots, webhooks, integrations, roles, permissions. +Docs: search_user_docs | API: {settings.PUBLIC_BACKEND_URL}/api/schema.json | Web: https://baserow.io | Community: https://community.baserow.io + """ -AUTOMATION_BUILDER_CONCEPTS = """ -### AUTOMATIONS (no-code automation builder) - -**Structure**: Automation → Workflows → Trigger + Actions + Routers (Nodes) - -**Key concepts**: -• **Trigger**: The single event that starts the workflow (e.g., row created/updated/deleted) -• **Actions**: Tasks performed (e.g., create/update rows, send emails, call webhooks) -• **Routers**: Conditional logic (if/else, switch) to control flow -• **Iterators**: Loop over lists of items -• **Formulas**: Reference data from previous nodes and compute values using functions/operators in nodes attributes -• **Execution**: Runs in the background; monitor via logs -• **History**: Track runs, successes, failures -• **Publishing**: Requires at least one configured action -""" - -AGENT_LIMITATIONS = """ -## LIMITATIONS - -### CANNOT CREATE: -• User accounts, workspaces -• Applications, pages -• Dashboards, widgets -• Snapshots, webhooks, integrations -• Roles, permissions - -### CANNOT 
UPDATE/MODIFY: -• User, workspace, or integration settings -• Roles, permissions -• Applications, pages -• Dashboards, widgets - -### CANNOT DELETE: -• Users, workspaces -• Roles, permissions -• Applications, pages -• Dashboards, widgets -""" - -ASSISTANT_SYSTEM_PROMPT_BASE = ( - f""" -You are Kuma, an AI expert for Baserow (open-source no-code platform). - -## YOUR KNOWLEDGE -1. **Core concepts** (below) -2. **Detailed docs** - use search_user_docs tool to search when needed -3. **API specs** - guide users to "{settings.PUBLIC_BACKEND_URL}/api/schema.json" -4. **Official website** - "https://baserow.io" -5. **Community support** - "https://community.baserow.io" -6. **Direct support** - for Advanced/Enterprise plan users - -## ANSWER FORMATTING GUIDELINES -• Use American English spelling and grammar -• Only use Markdown (bold, italics, lists, code blocks) -• Prefer lists in explanations. Numbered lists for steps; bulleted for others. -• Use code blocks for examples, commands, snippets -• Be concise and clear in your response - -## BASEROW CONCEPTS -""" - + CORE_CONCEPTS - + DATABASE_BUILDER_CONCEPTS - + APPLICATION_BUILDER_CONCEPTS - + AUTOMATION_BUILDER_CONCEPTS -) - AGENT_SYSTEM_PROMPT = ( - ASSISTANT_SYSTEM_PROMPT_BASE - + """ -## YOUR TOOLS - -**CRITICAL - Understanding your tools:** -- Learn what each tool does ONLY from its **name** and **description** -- **NEVER use `search_user_docs` to learn about your tools** - it contains end-user documentation, NOT information about your available tools or how to call them -- `search_user_docs` is ONLY for answering user questions about Baserow features and providing manual instructions - -## REQUEST HANDLING - -### ACTION REQUESTS - CHECK FIRST - -**CRITICAL: Before treating a request as a question, determine if it's an action you can perform.** - -Recognize action requests by: -- Imperative verbs: "Show...", "Filter...", "Create...", "Add...", "Delete...", "Update...", "Sort...", "Hide..." 
-- Desired states: "I want only...", "I need a field that...", "Make it show..." -- Example: "Show only rows where the primary field is empty" → This is an ACTION (create a filter), not a question about filtering - -**DO vs EXPLAIN:** -- If you have tools to do it → **DO IT** -- If you lack tools → **THEN explain** how to do it manually -- **NEVER explain how to do something you can do yourself** - -**Workflow:** -1. Check your tools - can you fulfill this? -2. **YES**: Execute (ask for clarification only if request is ambiguous) -3. **NO** (see LIMITATIONS): Explain you can't, then provide manual instructions from docs - -### QUESTIONS (only after ruling out action requests) - -**FACTUAL QUESTIONS** - asking what Baserow IS or HAS: -- Examples: "Does Baserow have X feature?", "How does Y work?", "What options exist for Z?" -- These have objectively correct/incorrect answers that must come from documentation -- **ALWAYS search documentation first** using `search_user_docs` -- Check the `reliability_note` in the response: - - **HIGH CONFIDENCE**: Present the answer confidently with sources - - **PARTIAL MATCH**: Provide the answer but note some details may be incomplete - - **LOW CONFIDENCE / NOTHING FOUND**: Tell the user you couldn't find this in the documentation. **DO NOT guess or assume features exist** - if docs don't mention it (e.g., a "barcode field"), it likely doesn't exist. Suggest checking the community forum or contacting support. -- **NEVER fabricate Baserow features or capabilities** - -**ADVISORY QUESTIONS** - asking how to USE or APPLY Baserow: -- Examples: "How should I structure X?", "What's a good approach for Y?", "Help me build Z", "Which field type works best for W?" -- These ask for your expertise in applying Baserow to solve problems - there's no single correct answer -- **Use your knowledge** of Baserow's real capabilities (field types, views, formulas, automations, linking, etc.) 
to provide helpful recommendations -- You may search docs for reference, but can also directly advise based on your understanding of Baserow -- Focus on practical solutions using actual Baserow functionality - -**Key principle**: Never fabricate what Baserow CAN do. Freely advise on HOW to use what Baserow actually offers. -""" - + AGENT_LIMITATIONS - + """ - -## TASK INSTRUCTIONS: -""" + AGENT_IDENTITY + + RULES + + HANDLING_AMBIGUITY + + BASEROW_KNOWLEDGE + + LIMITATIONS_AND_SOURCES ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py b/enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py new file mode 100644 index 0000000000..007f55398d --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py @@ -0,0 +1,448 @@ +""" +A pydantic-ai Model wrapper that retries on transient provider errors. + +Provider SDKs (Groq, Anthropic, OpenAI) sometimes raise exceptions that +are transient — e.g. ``groq.APIError: Failed to parse tool call arguments +as JSON``. pydantic-ai handles *some* of these (e.g. ``tool_use_failed`` +with a structured body), but others slip through. + +``RetryingModel`` wraps any pydantic-ai ``Model`` and adds retry logic +around ``request()`` with configurable back-off. + +Streaming recovery +------------------ +pydantic-ai's ``GroqStreamedResponse`` catches ``APIError`` with +``tool_use_failed`` bodies, but only when the ``failed_generation`` JSON +is valid. When Groq sends **truly malformed** JSON (not just +schema-invalid), pydantic-ai's ``Json[...]`` type fails to parse it and +re-raises the raw ``APIError``. + +Since this error occurs *during* stream consumption (after yield), +``@asynccontextmanager`` cannot yield a replacement. Instead we wrap +the stream in ``_ErrorRecoveringStream`` which intercepts ``APIError`` +in its ``_get_event_iterator`` and emits a ``ToolCallPart`` (or +``TextPart``) so pydantic-ai's validation loop can tell the model +what was wrong. 
+ +For errors that occur *before* the stream is established (during +``request_stream`` setup), we fall back to the retrying ``request()`` +method and wrap the result in ``_PreFetchedResponse``. +""" + +from __future__ import annotations + +import asyncio +import json +import re +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from typing import Any + +from loguru import logger +from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart +from pydantic_ai.models import ( + KnownModelName, + Model, + ModelRequestParameters, + ModelResponseStreamEvent, + StreamedResponse, + infer_model, +) +from pydantic_ai.models.wrapper import WrapperModel +from pydantic_ai.settings import ModelSettings + +# Transient Groq errors that are safe to retry. +_RETRYABLE_MESSAGES = frozenset( + { + "Failed to parse tool call arguments as JSON", + "Tool call validation failed", + } +) + + +def _is_transient_provider_error(exc: Exception) -> bool: + """Return True for provider errors that are transient and safe to retry.""" + + msg = str(exc) + return any(needle in msg for needle in _RETRYABLE_MESSAGES) + + +def _extract_tool_use_failed(body: dict) -> dict | None: + """Extract ``tool_use_failed`` error dict from an error body. + + Handles both wrapped (``{"error": {...}}``) and unwrapped layouts + (the Groq SDK streaming path sets ``body=data["error"]``). 
+ """ + + error = body.get("error", body) + if not isinstance(error, dict): + return None + if error.get("code") != "tool_use_failed": + return None + return error + + +_TOOL_NAME_RE = re.compile(r'"name"\s*:\s*"([^"]+)"') + + +def _extract_tool_name(failed_gen: str) -> str: + """Best-effort tool name extraction from truncated/malformed JSON.""" + + m = _TOOL_NAME_RE.search(failed_gen) + return m.group(1) if m else "unknown" + + +def _recover_failed_generation(failed_gen: str, model_name: str = "") -> ModelResponse: + """Turn a ``failed_generation`` string into a synthetic ``ModelResponse``. + + If the JSON is valid and contains ``name`` + ``arguments``, returns a + ``ToolCallPart`` so pydantic-ai's validation loop can tell the model + what was wrong. For truly malformed JSON, extracts the tool name + (best-effort) and returns a ``ToolCallPart`` with empty args so + pydantic-ai's validation rejects it and sends a retry prompt. + """ + + try: + parsed = json.loads(failed_gen) + if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed: + return ModelResponse( + parts=[ + ToolCallPart( + tool_name=parsed["name"], + args=json.dumps(parsed["arguments"]), + ) + ], + model_name=model_name, + ) + except (json.JSONDecodeError, TypeError): + pass + + # JSON is truly malformed (e.g. truncated). We must NOT fall back to a + # TextPart here because the stream may have already started emitting + # tool-call events — mixing TextPart into a tool-call stream causes + # pydantic-ai's AgentStream to fail with "unable to find output". + # + # Instead, try to extract the tool name from partial JSON and emit a + # ToolCallPart with empty args. pydantic-ai's validation will reject + # the args and send a retry prompt to the model. 
+ tool_name = _extract_tool_name(failed_gen) + return ModelResponse( + parts=[ + ToolCallPart( + tool_name=tool_name, + args="{}", + ) + ], + model_name=model_name, + ) + + +def _try_recover_tool_use_failed(exc: Exception) -> ModelResponse | None: + """Try to recover a ``tool_use_failed`` error into a ``ModelResponse``. + + Works with both ``ModelHTTPError`` (non-streaming path) and raw + provider ``APIError`` (streaming path where pydantic-ai's handler + couldn't parse the malformed JSON). + """ + + if isinstance(exc, ModelHTTPError): + body = exc.body + model_name = exc.model_name + elif hasattr(exc, "body"): + # Raw provider APIError (e.g. groq.APIError). + body = exc.body # type: ignore[union-attr] + model_name = "" + else: + return None + + if not isinstance(body, dict): + return None + + error = _extract_tool_use_failed(body) + if error is None: + return None + + failed_gen = error.get("failed_generation") + if not failed_gen or not isinstance(failed_gen, str): + return None + + return _recover_failed_generation(failed_gen, model_name) + + +def _resolve_model(model_name: str) -> Model: + """Resolve a model name to a pydantic-ai Model instance. + + For Google models, constructs the model with a fresh + ``httpx.AsyncClient`` instead of relying on ``infer_model()`` which + uses a process-global cached client. That cached client binds to the + event loop at creation time and breaks when reused on a different loop + (common in Django async views). 
+ See: https://github.com/pydantic/pydantic-ai/issues/3240 + """ + + if model_name.startswith(("google-gla:", "google:", "google-vertex:")): + import httpx + from pydantic_ai.models.google import GoogleModel + from pydantic_ai.providers.google import GoogleProvider + + vertexai = model_name.startswith("google-vertex:") + google_model_name = model_name.split(":", 1)[1] + return GoogleModel( + google_model_name, + provider=GoogleProvider(http_client=httpx.AsyncClient(), vertexai=vertexai), + ) + + return infer_model(model_name) + + +class RetryingModel(WrapperModel): + """Model wrapper that retries ``request()`` on transient provider errors. + + Model resolution is deferred until the first actual call so that + constructing a ``RetryingModel`` from a model name string does not + require provider API keys to be available at import/init time. + + Only ``request()`` has a retry loop. ``request_stream()`` falls back + to ``request()`` when the stream raises a retryable error, since + ``@asynccontextmanager`` only allows a single ``yield``. + """ + + def __init__( + self, + wrapped: Model | KnownModelName, + *, + max_attempts: int = 3, + base_delay: float = 1.0, + max_delay: float = 10.0, + ): + # Bypass WrapperModel.__init__ to defer infer_model. 
+ Model.__init__(self) + self._wrapped_or_name = wrapped + self._resolved: Model | None = None + self.max_attempts = max_attempts + self.base_delay = base_delay + self.max_delay = max_delay + + @property + def wrapped(self) -> Model: + if self._resolved is None: + self._resolved = ( + self._wrapped_or_name + if isinstance(self._wrapped_or_name, Model) + else _resolve_model(self._wrapped_or_name) + ) + return self._resolved + + @wrapped.setter + def wrapped(self, value: Model) -> None: + self._resolved = value + + def _delay_for(self, attempt: int) -> float: + """Exponential back-off delay capped at ``max_delay``.""" + return min(self.base_delay * (2 ** (attempt - 1)), self.max_delay) + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + for attempt in range(1, self.max_attempts + 1): + try: + return await self.wrapped.request( + messages, model_settings, model_request_parameters + ) + except Exception as exc: + # Try to recover tool_use_failed into a response so + # pydantic-ai's validation loop can tell the model what + # was wrong (instead of blindly retrying the same request). 
+ recovered = _try_recover_tool_use_failed(exc) + if recovered is not None: + logger.info( + "[assistant] Recovered tool_use_failed error into ModelResponse" + ) + return recovered + + if ( + not _is_transient_provider_error(exc) + or attempt == self.max_attempts + ): + raise + delay = self._delay_for(attempt) + logger.warning( + "[assistant] Model request failed (attempt {}/{}), " + "retrying in {:.1f}s: {}", + attempt, + self.max_attempts, + delay, + repr(exc), + ) + await asyncio.sleep(delay) + raise RuntimeError("Exhausted retries") # pragma: no cover + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, + ) -> AsyncIterator[StreamedResponse]: + yielded = False + try: + async with self.wrapped.request_stream( + messages, model_settings, model_request_parameters, run_context + ) as stream: + yielded = True + # Wrap the stream so that errors *during* chunk iteration + # (e.g. groq.APIError with malformed failed_generation) + # are caught and converted to recovery events rather than + # crashing the entire agent run. + yield _ErrorRecoveringStream(stream) + except Exception as exc: + if yielded: + # Error during stream consumption that + # _ErrorRecoveringStream couldn't handle. + raise + + # Setup error — try to recover tool_use_failed. + recovered = _try_recover_tool_use_failed(exc) + if recovered is not None: + logger.info( + "[assistant] Recovered tool_use_failed error " + "in stream into ModelResponse" + ) + yield _PreFetchedResponse(recovered, model_request_parameters) + return + + if not _is_transient_provider_error(exc): + raise + # Stream failed with a retryable error. Fall back to a + # non-streaming request which has its own retry loop. 
+ logger.warning( + "[assistant] Stream failed with retryable error, " + "falling back to non-streaming request: {}", + repr(exc), + ) + response = await self.request( + messages, model_settings, model_request_parameters + ) + yield _PreFetchedResponse(response, model_request_parameters) + + +class _ErrorRecoveringStream(StreamedResponse): + """Transparent proxy around a ``StreamedResponse`` that catches provider + errors during chunk iteration and converts ``tool_use_failed`` errors + (even with malformed JSON) into recovery events. + + pydantic-ai's ``GroqStreamedResponse`` already handles ``tool_use_failed`` + when the ``failed_generation`` JSON is *valid*, but fails when it is + truly malformed because ``Json[_GroqToolUseFailedGeneration]`` raises + ``ValidationError``. This wrapper catches the re-raised ``APIError`` + and emits a ``ToolCallPart`` or ``TextPart`` so pydantic-ai's + validation loop can tell the model what was wrong. + """ + + # Dataclass fields on StreamedResponse that have class-level defaults + # (e.g. ``final_result_event = None``). These shadow ``__getattr__`` + # because Python finds the class attribute before calling __getattr__. + # We override them as properties so reads delegate to ``_inner``. + final_result_event = property(lambda self: self._inner.final_result_event) # type: ignore[assignment] + provider_response_id = property(lambda self: self._inner.provider_response_id) # type: ignore[assignment] + provider_details = property(lambda self: self._inner.provider_details) # type: ignore[assignment] + finish_reason = property(lambda self: self._inner.finish_reason) # type: ignore[assignment] + + def __init__(self, inner: StreamedResponse): + # Don't call super().__init__() — delegate everything to *inner*. + # Only store our own _inner and _event_iterator on the instance. 
+ object.__setattr__(self, "_inner", inner) + object.__setattr__(self, "_event_iterator", None) + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_inner", "_event_iterator"): + object.__setattr__(self, name, value) + else: + setattr(self._inner, name, value) + + async def _get_event_iterator( + self, + ) -> AsyncIterator[ModelResponseStreamEvent]: + try: + async for event in self._inner._get_event_iterator(): + yield event + except Exception as exc: + recovered = _try_recover_tool_use_failed(exc) + if recovered is None: + raise + logger.info( + "[assistant] Recovered tool_use_failed error during stream consumption" + ) + for i, part in enumerate(recovered.parts): + yield self._parts_manager.handle_part( + vendor_part_id=f"recovered-{i}", part=part + ) + + # Abstract properties — delegate to inner stream. + + @property + def model_name(self) -> str: + return self._inner.model_name + + @property + def provider_name(self) -> str | None: + return self._inner.provider_name + + @property + def provider_url(self) -> str | None: + return self._inner.provider_url + + @property + def timestamp(self) -> datetime: + return self._inner.timestamp + + +class _PreFetchedResponse(StreamedResponse): + """A ``StreamedResponse`` backed by an already-complete ``ModelResponse``. + + Used when ``request_stream`` falls back to ``request()`` after a + retryable streaming error. Emits all response parts as immediate + ``PartStartEvent`` s so pydantic-ai can process them normally. 
+ """ + + def __init__( + self, + response: ModelResponse, + model_request_parameters: ModelRequestParameters, + ): + super().__init__(model_request_parameters=model_request_parameters) + self._response = response + self._usage.input_tokens = response.usage.input_tokens + self._usage.output_tokens = response.usage.output_tokens + + async def _get_event_iterator( + self, + ) -> AsyncIterator[ModelResponseStreamEvent]: + for i, part in enumerate(self._response.parts): + yield self._parts_manager.handle_part(vendor_part_id=i, part=part) + + @property + def model_name(self) -> str: + return self._response.model_name or "" + + @property + def provider_name(self) -> str | None: + return self._response.provider_name + + @property + def provider_url(self) -> str | None: + return self._response.provider_url + + @property + def timestamp(self) -> datetime: + return self._response.timestamp or datetime.now(tz=timezone.utc) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/signatures.py b/enterprise/backend/src/baserow_enterprise/assistant/signatures.py deleted file mode 100644 index 60bd981266..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/signatures.py +++ /dev/null @@ -1,33 +0,0 @@ -import udspy - -from .prompts import AGENT_SYSTEM_PROMPT - - -class ChatSignature(udspy.Signature): - __doc__ = AGENT_SYSTEM_PROMPT - - question: str = udspy.InputField() - conversation_history: list[str] = udspy.InputField( - desc="Previous messages formatted as '[index] (role): content', ordered chronologically" - ) - ui_context: str | None = udspy.InputField( - default=None, - description=( - "The JSON serialized context the user is currently in. " - "It contains information about the user, the timezone, the workspace, etc." - "Whenever make sense, use it to ground your answer." 
- ), - ) - answer: str = udspy.OutputField() - - @classmethod - def format_conversation_history(cls, history: udspy.History) -> list[str]: - """ - Format the conversation history into a list of strings for the signature. - """ - - formatted_history = [] - for i, msg in enumerate(history.messages): - formatted_history.append(f"[{i}] ({msg['role']}): {msg['content']}") - - return formatted_history diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tasks.py b/enterprise/backend/src/baserow_enterprise/assistant/tasks.py index ef72768cfc..babb9f6b30 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tasks.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tasks.py @@ -3,7 +3,7 @@ from baserow.config.celery import app from .handler import AssistantHandler -from .tools import KnowledgeBaseHandler +from .tools.search_user_docs.handler import KnowledgeBaseHandler @app.task(bind=True) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py index 9b5259092d..f1242c39a9 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py @@ -1,21 +1,50 @@ """ Posthog telemetry integration for the Baserow Assistant. -This module provides tracing callbacks that capture DSPy execution flows -and send structured events to Posthog for LLM analytics. +Hooks into pydantic-ai's OpenTelemetry instrumentation to capture LLM +generation and tool call events, mapping them to PostHog's AI analytics +event schema (``$ai_trace``, ``$ai_generation``, ``$ai_span``). + +Architecture: + + PosthogTracingCallback -- per-request context manager that emits the + top-level ``$ai_trace`` event and publishes + trace metadata via a ``ContextVar`` for the + span exporter. 
+ + PosthogSpanProcessor -- OpenTelemetry ``SpanProcessor`` that maps + pydantic-ai spans to PostHog events: + ``chat ...`` -> ``$ai_generation`` + ``running tool`` -> ``$ai_span`` + ``agent run`` -> ``$ai_span`` + The ``running tools`` grouping span is + transparently skipped; child tool spans have + their parent remapped to the grandparent + (typically the ``agent run`` span). + + setup_instrumentation() -- one-time wiring of the span processor into a + ``TracerProvider`` + ``Agent.instrument_all()``. """ +from __future__ import annotations + +import json from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any from uuid import uuid4 -import udspy -from udspy.callback import BaseCallback +from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, TracerProvider +from opentelemetry.trace import SpanKind from baserow.core.posthog import get_posthog_client from baserow_enterprise.assistant.models import AssistantChat +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + def _utc_now() -> datetime: return datetime.now(tz=timezone.utc) @@ -25,23 +54,436 @@ def _uuid() -> str: return str(uuid4()) -class PosthogTracingCallback(BaseCallback): +def _posthog_capture(distinct_id: str, event: str, properties: dict, **kwargs): + """Send a single event to PostHog with standardised error handling.""" + + posthog_client = get_posthog_client() + try: + posthog_client.capture( + distinct_id=distinct_id, event=event, properties=properties, **kwargs + ) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Trace context (ContextVars shared between callback and span exporter) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class 
_TraceContext: + """Immutable snapshot of per-request trace metadata.""" + + trace_id: str + user_id: str + workspace_id: str + chat_uuid: str + + +_trace_ctx: ContextVar[_TraceContext | None] = ContextVar("_trace_ctx", default=None) + +# Tool names collected during a trace for the $ai_trace summary. +_tool_calls: ContextVar[list[str]] = ContextVar("_tool_calls") + + +# --------------------------------------------------------------------------- +# Message format conversion (pydantic-ai -> PostHog) +# --------------------------------------------------------------------------- + + +def _parse_arguments(value): + """Ensure tool call arguments are a dict, parsing JSON strings if needed.""" + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + return value + + +# pydantic-ai key names -> PostHog key names +_PART_TRANSFORMS = { + "text": lambda p: { + "type": "text", + "text": p.get("content", ""), + }, + "tool_call": lambda p: { + "type": "tool_call", + "tool_call_id": p.get("id", ""), + "name": p.get("name", ""), + "arguments": _parse_arguments(p.get("arguments", {})), + }, + "tool_return": lambda p: { + "type": "tool_result", + "tool_call_id": p.get("tool_call_id", ""), + "content": p.get("content", ""), + }, + "thinking": lambda p: { + "type": "thinking", + "thinking": p.get("content", ""), + }, +} + + +def _safe_json_attr(attrs: dict, key: str) -> list | dict | None: + """Extract a JSON-serialised span attribute, returning None if missing or + unparseable.""" + + val = attrs.get(key) + if val is None: + return None + if isinstance(val, str): + try: + return json.loads(val) + except (json.JSONDecodeError, TypeError): + return None + return val + + +def _pydantic_messages_to_posthog(messages: list[dict]) -> list[dict]: + """Convert pydantic-ai message dicts to PostHog's expected format. 
+ + pydantic-ai: ``{"role": ..., "parts": [{"type": "text", "content": ...}]}`` + PostHog: ``{"role": ..., "content": [{"type": "text", "text": ...}]}`` """ - Captures uDSPy execution traces and sends events to Posthog. - This callback tracks: - - uDSPy module execution (ChainOfThought, ReAct, Predict) - - LLM API calls (OpenAI, Groq, etc.) - - Tool invocations - - Performance metrics and token usage + result = [] + for msg in messages: + content_parts = [] + for part in msg.get("parts", []): + ptype = part.get("type", "text") + transform = _PART_TRANSFORMS.get(ptype) + content_parts.append(transform(part) if transform else part) + result.append({"role": msg.get("role", "unknown"), "content": content_parts}) + return result + + +# --------------------------------------------------------------------------- +# Span helpers (shared by _emit_generation and _emit_tool_span) +# --------------------------------------------------------------------------- + + +def _span_latency(span: ReadableSpan) -> float | None: + """Compute span duration in seconds from OTel nanosecond timestamps.""" + + if span.start_time and span.end_time: + return (span.end_time - span.start_time) / 1e9 + return None + + +def _base_properties(ctx: _TraceContext) -> dict: + """Properties common to every PostHog event within a trace.""" + + return { + "$ai_trace_id": ctx.trace_id, + "$ai_session_id": ctx.chat_uuid, + "workspace_id": ctx.workspace_id, + } + + +def _extract_reasoning(output_messages: list[dict]) -> str | None: + """Join all ``thinking`` parts and tool-call ``thought`` fields from output + messages into a single string.""" + + parts: list[str] = [] + for msg in output_messages: + for part in msg.get("parts", []): + ptype = part.get("type") + if ptype == "thinking": + if content := part.get("content"): + parts.append(content) + elif ptype == "tool_call": + args = _parse_arguments(part.get("arguments", {})) + if isinstance(args, dict) and (thought := args.get("thought")): + 
parts.append(thought) + return "\n".join(parts) if parts else None + - Each instance is created per Assistant call with trace context, so - multiple concurrent traces can be captured independently. +# --------------------------------------------------------------------------- +# PosthogSpanExporter +# --------------------------------------------------------------------------- + +# Model setting keys emitted by pydantic-ai as ``gen_ai.request.*`` attrs. +_MODEL_PARAM_KEYS = ( + "temperature", + "max_tokens", + "top_p", + "seed", + "presence_penalty", + "frequency_penalty", +) + + +class PosthogSpanProcessor(SpanProcessor): + """Maps pydantic-ai OTel spans to PostHog LLM analytics events. + + ``chat {model}`` -> ``$ai_generation`` + ``running tool`` -> ``$ai_span`` (parent remapped past ``running tools``) + ``agent run`` -> ``$ai_span`` + ``running tools`` -> skipped (children re-parented to grandparent) """ def __init__(self): - super().__init__() + # "running tools" span_id -> its parent span_id. + # Populated on_start so child tool spans (which end first) can + # look up the grandparent during on_end. + self._tools_group_parents: dict[int, int | None] = {} + + # -- SpanProcessor interface ------------------------------------------- + + def on_start(self, span, parent_context=None): + if span.name == "running tools": + parent_id = span.parent.span_id if span.parent else None + self._tools_group_parents[span.context.span_id] = parent_id + + def on_end(self, span: ReadableSpan): + ctx = _trace_ctx.get() + if ctx is None: + return + + try: + self._process_span(span, ctx) + except Exception: + pass + + # Clean up mapping once the grouping span itself ends. 
+ if span.name == "running tools": + self._tools_group_parents.pop(span.context.span_id, None) + + def shutdown(self): + pass + + def force_flush(self, timeout_millis: int = 0) -> bool: + return True + + # -- internal ---------------------------------------------------------- + + def _resolve_parent_id(self, span: ReadableSpan) -> str | None: + """Return the hex ``$ai_parent_id``, skipping ``running tools``.""" + + if not span.parent: + return None + parent_id = span.parent.span_id + # If the direct parent is a "running tools" span, jump to its parent. + grandparent = self._tools_group_parents.get(parent_id) + if grandparent is not None: + parent_id = grandparent + return f"{parent_id:016x}" + + def _span_id_props(self, span: ReadableSpan) -> dict: + props: dict = {"$ai_span_id": f"{span.context.span_id:016x}"} + parent_hex = self._resolve_parent_id(span) + if parent_hex: + props["$ai_parent_id"] = parent_hex + return props + + def _process_span(self, span: ReadableSpan, ctx: _TraceContext): + attrs = dict(span.attributes or {}) + + if span.kind == SpanKind.CLIENT and span.name.startswith("chat "): + self._emit_generation(span, attrs, ctx) + elif span.name == "running tool": + self._emit_tool_span(span, attrs, ctx) + elif span.name == "agent run": + self._emit_agent_span(span, attrs, ctx) + # "running tools" is intentionally not emitted. 
+ + def _emit_generation(self, span: ReadableSpan, attrs: dict, ctx: _TraceContext): + """Map a ``chat {model}`` span to ``$ai_generation``.""" + + input_messages = _safe_json_attr(attrs, "gen_ai.input.messages") + output_messages = _safe_json_attr(attrs, "gen_ai.output.messages") + + properties = { + **_base_properties(ctx), + "$ai_model": ( + attrs.get("gen_ai.response.model") or attrs.get("gen_ai.request.model") + ), + "$ai_provider": ( + attrs.get("gen_ai.provider.name") or attrs.get("gen_ai.system") + ), + "$ai_input_tokens": attrs.get("gen_ai.usage.input_tokens"), + "$ai_output_tokens": attrs.get("gen_ai.usage.output_tokens"), + } + + # Model parameters + model_params = { + key: val + for key in _MODEL_PARAM_KEYS + if (val := attrs.get(f"gen_ai.request.{key}")) is not None + } + if model_params: + properties["$ai_model_parameters"] = model_params + + # System prompt + system_instructions = _safe_json_attr(attrs, "gen_ai.system_instructions") + if system_instructions and isinstance(system_instructions, list): + system_text = "\n".join( + p.get("content", "") for p in system_instructions if isinstance(p, dict) + ) + if system_text: + properties["$ai_system_prompt"] = system_text + + # Input / output messages + if input_messages: + properties["$ai_input"] = _pydantic_messages_to_posthog(input_messages) + if output_messages: + properties["$ai_output_choices"] = _pydantic_messages_to_posthog( + output_messages + ) + + latency = _span_latency(span) + if latency is not None: + properties["$ai_latency"] = latency + + # Tool definitions and names + tool_definitions = _safe_json_attr(attrs, "gen_ai.tool.definitions") + if tool_definitions and isinstance(tool_definitions, list): + tool_names = [ + t.get("name", "?") for t in tool_definitions if isinstance(t, dict) + ] + if tool_names: + properties["$ai_tools"] = tool_names + properties["$ai_tool_definitions"] = tool_definitions + + # Reasoning / thinking + if output_messages and isinstance(output_messages, list): + 
reasoning = _extract_reasoning(output_messages) + if reasoning: + properties["$ai_reasoning"] = reasoning + + properties.update(self._span_id_props(span)) + _posthog_capture(ctx.user_id, "$ai_generation", properties) + + def _emit_agent_span(self, span: ReadableSpan, attrs: dict, ctx: _TraceContext): + """Map an ``agent run`` span to ``$ai_span`` with the agent name.""" + + agent_name = attrs.get("agent_name", "unknown_agent") + + properties = { + **_base_properties(ctx), + "$ai_span_name": f"Agent: {agent_name}", + } + + # System prompt + system_instructions = _safe_json_attr(attrs, "gen_ai.system_instructions") + if system_instructions and isinstance(system_instructions, list): + system_text = "\n".join( + p.get("content", "") for p in system_instructions if isinstance(p, dict) + ) + if system_text: + properties["$ai_input_state"] = {"system_prompt": system_text} + + # User input (first user message) and final output + all_messages = _safe_json_attr(attrs, "pydantic_ai.all_messages") + if all_messages and isinstance(all_messages, list): + for msg in all_messages: + if msg.get("role") == "user": + parts = msg.get("parts", []) + user_texts = [ + p.get("content", "") + for p in parts + if isinstance(p, dict) and p.get("type") == "text" + ] + if user_texts: + input_state = properties.get("$ai_input_state", {}) + input_state["user_prompt"] = "\n".join(user_texts) + properties["$ai_input_state"] = input_state + break + + final_result = attrs.get("final_result") + if final_result is not None: + properties["$ai_output_state"] = _parse_arguments(final_result) + + latency = _span_latency(span) + if latency is not None: + properties["$ai_latency"] = latency + + properties.update(self._span_id_props(span)) + _posthog_capture(ctx.user_id, "$ai_span", properties) + + def _emit_tool_span(self, span: ReadableSpan, attrs: dict, ctx: _TraceContext): + """Map a ``running tool`` span to ``$ai_span``.""" + + tool_name = attrs.get("gen_ai.tool.name", "unknown_tool") + + # Record for 
the trace summary. + try: + _tool_calls.get().append(tool_name) + except LookupError: + pass + + tool_args = _safe_json_attr(attrs, "tool_arguments") + + properties = { + **_base_properties(ctx), + "$ai_span_name": f"Tool: {tool_name}", + "$ai_input_state": tool_args or {}, + "$ai_output_state": _parse_arguments(attrs.get("tool_response")), + } + + # Chain-of-thought reasoning from the "thought" argument + if isinstance(tool_args, dict) and tool_args.get("thought"): + properties["$ai_reasoning"] = tool_args["thought"] + + latency = _span_latency(span) + if latency is not None: + properties["$ai_latency"] = latency + + properties.update(self._span_id_props(span)) + _posthog_capture(ctx.user_id, "$ai_span", properties) + + +# --------------------------------------------------------------------------- +# One-time instrumentation setup +# --------------------------------------------------------------------------- + +_instrumentation_ready = False + + +def setup_instrumentation(): + """Activate pydantic-ai's OTel instrumentation with PostHog export. + Safe to call multiple times (subsequent calls are no-ops). + Does nothing when PostHog is disabled. 
+ """ + + global _instrumentation_ready + if _instrumentation_ready: + return + + from django.conf import settings as django_settings + + posthog_enabled = getattr(django_settings, "POSTHOG_ENABLED", False) + if not posthog_enabled: + return + + from pydantic_ai import Agent, InstrumentationSettings + + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(PosthogSpanProcessor()) + + Agent.instrument_all( + InstrumentationSettings( + tracer_provider=tracer_provider, + include_content=True, + ) + ) + + _instrumentation_ready = True + + +# --------------------------------------------------------------------------- +# PosthogTracingCallback — per-request trace lifecycle +# --------------------------------------------------------------------------- + + +class PosthogTracingCallback: + """Per-request trace lifecycle. Creates the ``$ai_trace`` event and + publishes ``_TraceContext`` for the span exporter.""" + + def __init__(self): self.chat: AssistantChat | None = None self.human_msg: str | None = None self.trace_id: str | None = None @@ -49,50 +491,36 @@ def __init__(self): self.user_id: str | None = None self.workspace_id: str | None = None self.chat_uuid: str | None = None - self.spans: dict[str, dict] = {} - self.span_ids: list[str] = [] + self.trace_outputs = None @contextmanager def trace(self, chat: AssistantChat, human_message: str): - """ - Context manager for tracing an assistant execution. - Initializes trace context and captures the overall trace event. - It also patches the OpenAI client to auto-capture generation events. + """Context manager that scopes a single assistant execution. - :param chat: The AssistantChat instance - :param human_message: The initial user message + Publishes ``_trace_ctx`` so ``PosthogSpanExporter`` can attach trace + metadata to child ``$ai_generation`` / ``$ai_span`` events. 
""" - from posthog.ai.openai import AsyncOpenAI - self.chat = chat self.human_msg = human_message - self.trace_id = _uuid() self.span_id = _uuid() self.user_id = str(chat.user_id) self.workspace_id = str(chat.workspace_id) self.chat_uuid = str(chat.uuid) + self.trace_outputs = None start_time = _utc_now() - self.spans = {} - self.span_ids = [self.span_id] - self.trace_outputs = None - # patch the OpenAI client to automatically send the generation event - lm = udspy.settings._context_lm.get() - openai_client = lm.client - - # Check if client is already a PostHog-wrapped client by checking its - # module. We avoid isinstance() here because it can fail when the class - # is mocked in tests. - is_posthog_client = "posthog" in type(openai_client).__module__ - if not is_posthog_client: - lm.client = AsyncOpenAI( - api_key=openai_client.api_key, - base_url=openai_client.base_url, - posthog_client=get_posthog_client(), + token = _trace_ctx.set( + _TraceContext( + trace_id=self.trace_id, + user_id=self.user_id, + workspace_id=self.workspace_id, + chat_uuid=self.chat_uuid, ) + ) + tools_token = _tool_calls.set([]) exception = None try: @@ -101,7 +529,17 @@ def trace(self, chat: AssistantChat, human_message: str): exception = exc raise finally: - # Stop trace + tool_call_names = _tool_calls.get([]) + _trace_ctx.reset(token) + _tool_calls.reset(tools_token) + + output_state = self.trace_outputs if exception is None else str(exception) + if tool_call_names: + if output_state is None: + output_state = {} + if isinstance(output_state, dict): + output_state["tool_calls"] = tool_call_names + self._capture_event( "$ai_trace", timestamp=start_time, @@ -112,190 +550,27 @@ def trace(self, chat: AssistantChat, human_message: str): "$ai_latency": (_utc_now() - start_time).total_seconds(), "$ai_is_error": exception is not None, "$ai_input_state": {"user_message": human_message}, - "$ai_output_state": self.trace_outputs - if exception is None - else str(exception), + "$ai_output_state": 
output_state, }, ) - def _capture_event(self, event: str, **kwargs): - """ - Capture a Posthog event if Posthog is enabled. + try: + get_posthog_client().flush() + except Exception: + pass - :param event: Event name (e.g., "$ai_generation") - :param properties: Event properties dictionary - """ + def set_trace_output(self, output: str): + """Record the agent's final answer for the ``$ai_trace`` event.""" - default_props = { - "$ai_trace_id": self.trace_id, - "$ai_session_id": self.chat_uuid, - "workspace_id": self.workspace_id, - } - if "properties" in kwargs: - kwargs["properties"].update(default_props) - else: - kwargs["properties"] = default_props - - posthog_client = get_posthog_client() - posthog_client.capture( - distinct_id=str(self.user_id), - event=event, - **kwargs, - ) # noqa: W505 - - def on_module_start(self, call_id: str, instance: Any, inputs: dict): - """ - Track the start of a DSPy module execution. + self.trace_outputs = {"answer": output} - Captures ChainOfThought, ReAct, Predict, and other module types. 
- - :param call_id: Unique identifier for this call - :param instance: The DSPy module instance - :param inputs: Input dictionary passed to the module - """ - - module_type = instance.__class__.__name__ - parent_span_id = self.span_ids[-1] if self.span_ids else None - span_id = call_id - self.span_ids.append(span_id) - span = { - "start_time": _utc_now(), - "properties": { - "$ai_span_name": module_type, - "$ai_span_id": span_id, - "$ai_parent_span_id": parent_span_id, - }, - } - self.spans[span_id] = span - - def _update_span_with_signature_data(signature): - adapter = udspy.ChatAdapter() - input_fields = ", ".join(signature.get_input_fields().keys()) - output_fields = ", ".join(signature.get_output_fields()) - span["properties"]["$ai_input_state"] = { - "signature": f"{input_fields} -> {output_fields}", - "instructions": adapter.format_instructions(signature), - **inputs["kwargs"], - } - - if isinstance(instance, (udspy.Predict, udspy.ReAct)): - _update_span_with_signature_data(instance.signature) - elif isinstance(instance, udspy.ChainOfThought): - _update_span_with_signature_data(instance.original_signature) - - def on_module_end(self, call_id: str, outputs: Any, exception: Exception | None): - """ - Remove the span from the stack together with all the started $ai_generation - spans appended in `on_lm_start` - - Args: - call_id: Unique identifier for this call - outputs: Module output (if successful) - exception: Exception raised (if failed) - """ - - while (span_id := self.span_ids.pop()) != call_id: - continue - - span = self.spans.pop(span_id) - start_time = span.pop("start_time") - span["properties"].update( - { - "$ai_latency": (_utc_now() - start_time).total_seconds(), - "$ai_is_error": exception is not None, - "$ai_output_state": outputs if exception is None else str(exception), - } - ) - - if isinstance(outputs, dict) and "answer" in outputs: - self.trace_outputs = { - k: v - for k, v in outputs.items() - if k not in ["module", "native_tool_calls"] - } 
- - self._capture_event("$ai_span", timestamp=start_time, **span) - - def on_lm_start(self, call_id: str, instance: Any, inputs: dict): - """ - Only enrich posthog properties that will be sent automatically - by the patched openai client. - Add the span_id to the stack so any tool call will be shown - as a child span. - - Args: - call_id: Unique identifier for this call - instance: The LM instance - inputs: API call parameters (model, messages, temperature, etc.) - """ + def _capture_event(self, event: str, **kwargs): + """Capture a PostHog event, merging in default trace properties.""" - parent_span_id = self.span_ids[-1] if self.span_ids else None - kwargs = inputs["kwargs"] - span_id = call_id - self.span_ids.append(span_id) - kwargs["posthog_distinct_id"] = self.user_id - kwargs["posthog_trace_id"] = self.trace_id - kwargs["posthog_properties"] = { + kwargs["properties"] = { + **kwargs.get("properties", {}), + "$ai_trace_id": self.trace_id, "$ai_session_id": self.chat_uuid, - "$ai_parent_span_id": parent_span_id, - "$ai_span_id": span_id, "workspace_id": self.workspace_id, - "$ai_provider": instance.provider, - } - - def on_lm_end(self, call_id: str, outputs: Any, exception: Exception | None): - """ - Automatically tracked by the patched openai client. - - :param call_id: Unique identifier for this call - :param outputs: LLM response object - :param exception: Exception raised (if failed) - """ - - pass - - def on_tool_start(self, call_id: str, instance: Any, inputs: dict): - """ - Track the start of a tool invocation. 
- - Args: - call_id: Unique identifier for this call - instance: The tool instance - inputs: Tool input parameters - """ - - tool_name = getattr(instance, "name", instance.__class__.__name__) - - span_id = call_id - parent_span_id = self.span_ids[-1] if self.span_ids else None - self.spans[span_id] = { - "start_time": _utc_now(), - "properties": { - "$ai_span_name": f"Tool: {tool_name}", - "$ai_span_id": span_id, - "$ai_parent_span_id": parent_span_id, - "$ai_input_state": inputs, - }, } - - def on_tool_end(self, call_id: str, outputs: Any, exception: Exception | None): - """ - Track the completion of a tool invocation. - - Args: - call_id: Unique identifier for this call - outputs: Tool output - exception: Exception raised (if failed) - """ - - span_id = call_id - span = self.spans.pop(span_id) - start_time = span.pop("start_time") - span["properties"].update( - { - "$ai_latency": (_utc_now() - start_time).total_seconds(), - "$ai_is_error": exception is not None, - "$ai_output_state": outputs if exception is None else str(exception), - } - ) - self._capture_event("$ai_span", timestamp=start_time, **span) + _posthog_capture(str(self.user_id), event, **kwargs) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py index 4ac95ed235..8b13789179 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py @@ -1,5 +1 @@ -from .automation.tools import * # noqa: F401, F403 -from .core.tools import * # noqa: F401, F403 -from .database.tools import * # noqa: F401, F403 -from .navigation.tools import * # noqa: F401, F403 -from .search_user_docs.tools import * # noqa: F401, F403 + diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py index 
ace1c221c3..8b13789179 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py @@ -1,6 +1 @@ -from .tools import ListWorkflowsToolType, WorkflowToolFactoryToolType -__all__ = [ - "ListWorkflowsToolType", - "WorkflowToolFactoryToolType", -] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py new file mode 100644 index 0000000000..8f56f79152 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py @@ -0,0 +1,156 @@ +""" +Sub-agents for the automation assistant tools. + +Contains: +- ``AssistantFormulaContext``: Automation-specific formula context. +- ``get_generate_formulas_tool()``: Gets the automation formula generator. +- ``update_workflow_formulas()``: Generates formulas for workflow nodes. +""" + +from typing import TYPE_CHECKING, Any + +from django.db import transaction +from django.utils.translation import gettext as _ + +from loguru import logger + +from baserow.contrib.automation.nodes.models import AutomationNode +from baserow_enterprise.assistant.tools.shared.agents import get_formula_generator +from baserow_enterprise.assistant.tools.shared.formula_utils import ( + BaseFormulaContext, + create_example_from_json_schema, + minimize_json_schema, +) + +from .prompts import GENERATE_FORMULA_PROMPT +from .types import ActionNodeCreate, NodeUpdate, WorkflowCreate + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import ToolHelpers + + +class AssistantFormulaContext(BaseFormulaContext): + """ + Automation-specific formula context. + + Wraps node data in the ``{"previous_node": {...}}`` structure expected + by automation formula ``get()`` paths. 
+ """ + + def add_node_context( + self, + node_id: int | str, + node_context: dict[str, Any], + context_metadata: dict[str, dict[str, str]] | None = None, + ): + """Add a node's output values to the formula context.""" + self.add_context(str(node_id), node_context, context_metadata) + + def get_formula_context(self) -> dict[str, Any]: + """Return context wrapped in ``previous_node`` for automation formulas.""" + return {"previous_node": self.context} + + def __getitem__(self, key) -> Any: + """Resolve paths like ``previous_node.1.0.field_name``.""" + return self._resolve_path(key, "previous_node") + + +def get_generate_formulas_tool(): + """Get the automation formula generator using the shared factory.""" + return get_formula_generator(GENERATE_FORMULA_PROMPT) + + +def update_workflow_formulas( + workflow: "WorkflowCreate", + node_mapping: dict[int | str, Any], + tool_helpers: "ToolHelpers", +) -> None: + """ + Generate and apply formulas for all nodes in a newly created workflow. + + Walks nodes in order, building up the available formula context as it goes. + For each node that has ``$formula:`` values, delegates to the formula + generation agent and writes the results back to the ORM service. 
+ """ + + context = AssistantFormulaContext() + generate_formula = get_generate_formulas_tool() + + def _build_node_context(orm_node: AutomationNode, node_create): + """Extract schema/example from a node and add it to the formula context.""" + + schema = orm_node.service.get_type().generate_schema(orm_node.service.specific) + example = create_example_from_json_schema(schema) + metadata = minimize_json_schema(schema) + metadata["node_id"] = orm_node.id + metadata["node_ref"] = node_create.ref + if getattr(node_create, "previous_node_ref", None): + metadata["previous_node_ref"] = node_create.previous_node_ref + context.add_node_context(orm_node.id, example, metadata) + + def _generate_node_formulas(node: ActionNodeCreate, orm_node: AutomationNode): + """Generate formulas for a single node and write them to the service.""" + + formulas_to_create = node.get_formulas_to_create(orm_node) + if formulas_to_create is None: + return + result = generate_formula(formulas_to_create, context) + if result: + node.update_service_with_formulas(orm_node.service, result) + + # Seed context with the trigger + orm_trigger, trigger_create = node_mapping[workflow.trigger.ref] + _build_node_context(orm_trigger, trigger_create) + + # Process action nodes in order + for node in workflow.nodes: + orm_node, _node_create = node_mapping[node.ref] + node.apply_direct_values(orm_node.service) + + if node.get_formulas_to_create(orm_node) is not None: + tool_helpers.update_status( + _("Generating formulas for node '%(label)s'..." % {"label": node.label}) + ) + with transaction.atomic(): + try: + _generate_node_formulas(node, orm_node) + except Exception as exc: + logger.exception( + "Failed to generate formulas for node {}: {}", orm_node.id, exc + ) + + _build_node_context(orm_node, node) + + +def update_single_node_formulas( + node_update: "NodeUpdate", + orm_node: AutomationNode, + tool_helpers: "ToolHelpers", +) -> None: + """ + Generate and apply formulas for a single node being updated. 
+ + Builds formula context from the node's workflow, then generates + formulas for the $formula: fields in the update. + """ + + context = AssistantFormulaContext() + generate_formula = get_generate_formulas_tool() + + # Build context from the workflow's existing nodes + workflow = orm_node.workflow + all_nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + for wf_node in all_nodes: + schema = wf_node.service.get_type().generate_schema(wf_node.service.specific) + example = create_example_from_json_schema(schema) + metadata = minimize_json_schema(schema) + metadata["node_id"] = wf_node.id + context.add_node_context(wf_node.id, example, metadata) + + formulas_to_create = node_update.get_formulas_to_update(orm_node) + if formulas_to_create is None: + return + + result = generate_formula(formulas_to_create, context) + if result: + node_update.update_service_with_formulas(orm_node.service, result) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py new file mode 100644 index 0000000000..c7bfe5fe10 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py @@ -0,0 +1,288 @@ +""" +Shared helpers for the automation assistant tools. + +Contains permission-checked accessors and the workflow creation orchestrator +used by ``tools.py`` and ``agents.py``. 
+""" + +from typing import TYPE_CHECKING, Any + +from django.contrib.auth.models import AbstractUser +from django.utils.translation import gettext as _ + +from baserow.contrib.automation.models import Automation +from baserow.contrib.automation.nodes.registries import automation_node_type_registry +from baserow.contrib.automation.nodes.service import AutomationNodeService +from baserow.contrib.automation.workflows.models import AutomationWorkflow +from baserow.contrib.automation.workflows.service import AutomationWorkflowService +from baserow.core.models import Workspace +from baserow.core.service import CoreService + +from .types import NodeUpdate, WorkflowCreate + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import ToolHelpers + + from .types import ActionNodeCreate + + +def get_automation( + automation_id: int, user: AbstractUser, workspace: Workspace +) -> Automation: + """Fetch an automation scoped to the user's workspace.""" + + base_queryset = Automation.objects.filter(workspace=workspace) + return CoreService().get_application( + user, automation_id, base_queryset=base_queryset + ) + + +def get_workflow( + workflow_id: int, user: AbstractUser, workspace: Workspace +) -> AutomationWorkflow: + """Fetch a workflow with a workspace-level permission check.""" + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + if workflow.automation.workspace_id != workspace.id: + raise ValueError("Workflow not in workspace") + return workflow + + +def get_nodes_in_order(user: AbstractUser, workflow: AutomationWorkflow) -> list[dict]: + """ + Return the nodes of a workflow in graph traversal order. + + Walks the workflow graph starting from the trigger, following ``next`` + edges (all outputs) and ``children`` to produce a flat, ordered list. 
+ """ + + nodes = AutomationNodeService().get_nodes(user, workflow) + node_map = {n.id: n for n in nodes} + graph = workflow.get_graph().graph + + trigger_id = graph.get("0") + if trigger_id is None: + return [] + + ordered_ids: list[int] = [] + visited: set[int] = set() + + def walk(node_id: int): + if node_id in visited or node_id not in node_map: + return + visited.add(node_id) + ordered_ids.append(node_id) + info = graph.get(str(node_id), {}) + # Follow children first (for container nodes like iterators) + for child_id in info.get("children", []): + walk(child_id) + # Then follow next edges in order + for output_uid, next_ids in info.get("next", {}).items(): + for nid in next_ids: + walk(nid) + + walk(trigger_id) + + result = [] + for nid in ordered_ids: + node = node_map[nid] + node_type = node.get_type() + entry = { + "id": node.id, + "label": node.get_label(), + "type": node_type.type, + } + result.append(entry) + + return result + + +def add_nodes_to_workflow( + user: AbstractUser, + workflow: AutomationWorkflow, + nodes: list["ActionNodeCreate"], + tool_helpers: "ToolHelpers", +) -> tuple[list[Any], dict[int | str, Any]]: + """ + Add action nodes to an existing workflow. + + The ``previous_node_ref`` on each node can reference: + - An existing node ID as a string (e.g. "49") + - A temp ref from an earlier node in the same ``nodes`` list + + Returns a list of created ORM nodes and the node mapping. 
+ """ + + # Seed the mapping with existing nodes in the workflow + existing_nodes = AutomationNodeService().get_nodes(user, workflow) + node_mapping: dict[int | str, Any] = {} + for n in existing_nodes: + # Create a stub for the node_create part that has type and edges info + stub = _ExistingNodeStub(n) + node_mapping[str(n.id)] = (n, stub) + node_mapping[n.id] = (n, stub) + + created = [] + for node in nodes: + tool_helpers.raise_if_cancelled() + reference_node_id, output = node.to_orm_reference_node(node_mapping) + orm_node = _create_node( + user, + workflow, + node, + tool_helpers, + reference_node_id=reference_node_id, + output=output, + ) + node_mapping[node.ref] = node_mapping[orm_node.id] = (orm_node, node) + created.append(orm_node) + + return created, node_mapping + + +class _EdgeStub: + """Bridges ORM edge ``uid`` to the ``_uid`` attribute expected by ``to_orm_reference_node``.""" + + def __init__(self, orm_edge): + self.label = orm_edge.label + self._uid = str(orm_edge.uid) + + +class _ExistingNodeStub: + """ + Lightweight stub exposing ``type`` and ``edges`` from an existing ORM node, + so ``ActionNodeCreate.to_orm_reference_node`` can resolve router edge labels. + """ + + def __init__(self, orm_node): + self.type = orm_node.get_type().type + self.edges = [] + if self.type == "router" and hasattr(orm_node.service, "specific"): + service = orm_node.service.specific + if hasattr(service, "edges"): + self.edges = [_EdgeStub(e) for e in service.edges.all()] + + +def create_workflow( + user: AbstractUser, + automation: Automation, + workflow: "WorkflowCreate", + tool_helpers: "ToolHelpers", +) -> tuple[AutomationWorkflow, dict[int | str, Any]]: + """ + Create a workflow with its trigger and action nodes. + + Returns the ORM workflow and a mapping of ``{ref_or_id: (orm_node, node_create)}`` + for every created node, usable by downstream formula generation. + """ + + tool_helpers.update_status( + _("Creating workflow '%(name)s'..." 
% {"name": workflow.name}) + ) + + orm_wf = AutomationWorkflowService().create_workflow( + user, automation.id, workflow.name + ) + + node_mapping: dict[int | str, Any] = {} + + # -- Trigger -- + orm_trigger = _create_node(user, orm_wf, workflow.trigger, tool_helpers) + node_mapping[workflow.trigger.ref] = node_mapping[orm_trigger.id] = ( + orm_trigger, + workflow.trigger, + ) + + # -- Action / router / iterator nodes -- + for node in workflow.nodes: + try: + reference_node_id, output = node.to_orm_reference_node(node_mapping) + except ValueError as exc: + from pydantic_ai import ModelRetry + + raise ModelRetry(str(exc)) from exc + orm_node = _create_node( + user, + orm_wf, + node, + tool_helpers, + reference_node_id=reference_node_id, + output=output, + ) + node_mapping[node.ref] = node_mapping[orm_node.id] = (orm_node, node) + + return orm_wf, node_mapping + + +def _create_node(user, workflow, node_create, tool_helpers, **extra_kwargs): + """Create a single automation node (trigger or action).""" + + tool_helpers.update_status( + _("Creating node '%(label)s'..." % {"label": node_create.label}) + ) + node_type = automation_node_type_registry.get(node_create.type) + return AutomationNodeService().create_node( + user, + node_type, + workflow, + label=node_create.label, + service=node_create.to_orm_service_dict(), + **extra_kwargs, + ) + + +def update_node( + user: "AbstractUser", + workspace: "Workspace", + node_update: "NodeUpdate", + tool_helpers: "ToolHelpers", +): + """ + Update an automation node's label and/or service config. + + :param user: The acting user. + :param workspace: Workspace for permission check. + :param node_update: The update definition. + :param tool_helpers: Provides status updates. + :returns: The updated ORM node. 
+ """ + + node = AutomationNodeService().get_node(user, node_update.node_id) + if node.workflow.automation.workspace_id != workspace.id: + raise ValueError("Node not in workspace") + + kwargs = {} + if node_update.label is not None: + kwargs["label"] = node_update.label + + node_type = node.service.get_type().type if node.service else None + service_dict = node_update.to_update_service_dict(node_type) if node_type else None + if service_dict is not None: + kwargs["service"] = service_dict + + if kwargs: + tool_helpers.update_status( + _("Updating node '%(label)s'..." % {"label": node.label}) + ) + AutomationNodeService().update_node(user, node.id, **kwargs) + + return AutomationNodeService().get_node(user, node_update.node_id) + + +def delete_node( + user: "AbstractUser", + workspace: "Workspace", + node_id: int, +): + """ + Delete an automation node. + + :param user: The acting user. + :param workspace: Workspace for permission check. + :param node_id: ID of the node to delete. + """ + + node = AutomationNodeService().get_node(user, node_id) + if node.workflow.automation.workspace_id != workspace.id: + raise ValueError("Node not in workspace") + AutomationNodeService().delete_node(user, node_id) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py index a0e54c5ab1..55904a5b7d 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py @@ -1,31 +1,13 @@ -GENERATE_FORMULA_PROMPT = """ -You are a formula builder. 
Generate formulas using these functions: +from baserow_enterprise.assistant.tools.shared.formula_prompt import FORMULA_LANGUAGE -**Comparison operators** (for router conditions only): -equal, not_equal, greater_than, less_than, greater_than_equal, less_than_equal -- Arguments: numbers, 'strings', or get() functions -- Returns: boolean -- Example: greater_than(get('age'), 18) +GENERATE_FORMULA_PROMPT = ( + FORMULA_LANGUAGE + + """ +## Context: Automation Workflows -**concat(...args)** - Joins arguments into a string -- Arguments: 'string literals' or get() functions -- Example: concat('Hello ', get('name'), '!') - -**get(path)** - Retrieves values from context using path notation -- Objects: get('user.name') -- Arrays: get('items.0'), get('orders.2.total') -- Nested: get('users.0.address.city') -- All: get('users.*.email') returns a list of emails from all users - -**if(condition, true_value, false_value)** - Conditional expression -- Arguments: a boolean condition, value if true, value if false -- Example: if(greater_than(get('score'), 50), 'pass', 'fail') - -**today()** - Returns the current date -**now()** - Returns the current date and time - -**constants**: -- A string literal enclosed in single quotes (e.g., 'hello world', '123') +In automation formulas, data is accessed through the previous_node structure: +- Path format: get('previous_node..0.') +- Each node ID maps to an array of rows; use index 0 for the first (and usually only) row. **Example 1 - String Fields:** Input: @@ -84,3 +66,4 @@ 3. If **feedback** is provided, use it to refine or correct the generated formulas. 4. Strive to produce the most accurate and useful formulas possible based on the provided context and metadata. 
""" +) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py new file mode 100644 index 0000000000..d8a63f2f61 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py @@ -0,0 +1,20 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class AutomationToolType(AssistantToolType): + type = "automation" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import automation_toolset + + return automation_toolset + + def get_routing_rules(self): + from .tools import ROUTING_RULES + + return ROUTING_RULES diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py index 6046ca5ffa..5338ddf07b 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py @@ -1,150 +1,370 @@ -from typing import TYPE_CHECKING, Any, Callable +from typing import Annotated, Any -from django.contrib.auth.models import AbstractUser from django.db import transaction from django.utils.translation import gettext as _ -import udspy +from pydantic import Field +from pydantic_ai import RunContext +from pydantic_ai.toolsets import FunctionToolset from baserow.contrib.automation.workflows.service import AutomationWorkflowService -from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.registries import AssistantToolType +from baserow_enterprise.assistant.deps import AssistantDeps from baserow_enterprise.assistant.types import WorkflowNavigationType -from . 
import utils -from .types import WorkflowCreate - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - - -def get_list_workflows_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], dict[str, list[dict]]]: - """ - List all workflows in an automation. +from . import agents, helpers +from .types import ActionNodeCreate, NodeUpdate, WorkflowCreate + + +def list_workflows( + ctx: RunContext[AssistantDeps], + automation_id: Annotated[ + int, Field(description="The ID of the automation to list workflows for.") + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List workflows in an automation. + + WHEN to use: Check existing workflows in an automation, or find workflow IDs before creating new ones. + WHAT it does: Lists all workflows in an automation with their id, name, and state. + RETURNS: Workflows array with id, name, state. + DO NOT USE when: You already have the workflow IDs you need. """ - def list_workflows(automation_id: int) -> dict[str, Any]: - """ - List all workflows in an automation application. 
+ user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - :param automation_id: The ID of the automation application - :return: Dictionary with workflows list - """ + tool_helpers.update_status(_("Listing workflows...")) - nonlocal user, workspace, tool_helpers + automation = helpers.get_automation(automation_id, user, workspace) + workflows = AutomationWorkflowService().list_workflows(user, automation.id) - tool_helpers.update_status(_("Listing workflows...")) + return { + "workflows": [{"id": w.id, "name": w.name, "state": w.state} for w in workflows] + } - automation = utils.get_automation(automation_id, user, workspace) - workflows = AutomationWorkflowService().list_workflows(user, automation.id) - return { - "workflows": [ - {"id": w.id, "name": w.name, "state": w.state} for w in workflows - ] - } +def list_nodes( + ctx: RunContext[AssistantDeps], + workflow_id: Annotated[ + int, Field(description="The ID of the workflow to list nodes for.") + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List nodes in a workflow in execution order. - return list_workflows + WHEN to use: Inspect the nodes in a workflow, find node IDs before updating or deleting. + WHAT it does: Lists all nodes (trigger + actions) in graph traversal order with id, label, and type. + RETURNS: Nodes array with id, label, type. + DO NOT USE when: You already have the node IDs you need. 
+ """ + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + + tool_helpers.update_status(_("Listing nodes...")) + + workflow = helpers.get_workflow(workflow_id, user, workspace) + nodes = helpers.get_nodes_in_order(user, workflow) + + return {"nodes": nodes} + + +def add_nodes( + ctx: RunContext[AssistantDeps], + workflow_id: Annotated[ + int, Field(description="The ID of the workflow to add nodes to.") + ], + nodes: Annotated[ + list[ActionNodeCreate], + Field( + description="Nodes to add. previous_node_ref can be an existing node ID (as string) or a temp ref from an earlier node in this list." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Add action/router nodes to an existing workflow. + + WHEN to use: User wants to insert or append nodes in an existing workflow — e.g. add a router between trigger and action, or add a new action after an existing one. + WHAT it does: Creates new nodes attached to existing ones. Use previous_node_ref with the string ID of an existing node (e.g. "49") or a temp ref of a node being created in the same call. + RETURNS: Created nodes array with id, label, type. + DO NOT USE when: You want to create an entirely new workflow — use create_workflows instead. + HOW: Use list_nodes first to find the existing node IDs, then specify previous_node_ref to place new nodes. Use router_edge_label when attaching to a router branch. + """ -def get_workflow_tool_factory( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int, list[WorkflowCreate]], dict[str, list[dict]]]: - def create_workflows( - automation_id: int, workflows: list[WorkflowCreate] - ) -> dict[str, Any]: - """ - Create one or more workflows in an automation. Always use {{ node.ref }} to - reference previous nodes values inside the workflow. 
+ user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - :param automation_id: The automation application ID - :param workflows: List of workflows to create - :return: Dictionary with created workflows - """ + if not nodes: + return {"created_nodes": []} - nonlocal user, workspace, tool_helpers + tool_helpers.update_status(_("Adding nodes to workflow...")) - created = [] + workflow = helpers.get_workflow(workflow_id, user, workspace) - automation = utils.get_automation(automation_id, user, workspace) - for wf in workflows: - with transaction.atomic(): - orm_workflow, node_mapping = utils.create_workflow( - user, automation, wf, tool_helpers - ) - created.append( - { - "id": orm_workflow.id, - "name": orm_workflow.name, - "state": orm_workflow.state, - } - ) - - # In separate transactions, try to update the formulas inside the workflow, - # so we don't block the main creation if something goes wrong here. - utils.update_workflow_formulas(wf, node_mapping, tool_helpers) - - # Navigate to the last created workflow - tool_helpers.navigate_to( - WorkflowNavigationType( - type="automation-workflow", - automation_id=automation.id, - workflow_id=orm_workflow.id, - workflow_name=orm_workflow.name, - ) + with transaction.atomic(): + created_nodes, node_mapping = helpers.add_nodes_to_workflow( + user, workflow, nodes, tool_helpers ) - return {"created_workflows": created} - - def load_workflow_automation_tools(): - """ - TOOL LOADER: Loads tools to manage workflows in an automation. 
- - After calling this loader, you will have access to: - - create_workflows: Create workflows with triggers, actions, and routers + # Generate formulas for nodes that need them + for orm_node, node_create in [(n, nodes[i]) for i, n in enumerate(created_nodes)]: + formulas = node_create.get_formulas_to_create(orm_node) + if formulas: + node_create.apply_direct_values(orm_node.service) + tool_helpers.update_status( + _( + "Generating formulas for node '%(label)s'..." + % {"label": orm_node.label} + ) + ) + with transaction.atomic(): + try: + agents.update_single_node_formulas( + node_create, orm_node, tool_helpers + ) + except Exception: + from loguru import logger + + logger.exception( + "Failed to generate formulas for node {}", orm_node.id + ) + + return { + "created_nodes": [ + {"id": n.id, "label": n.get_label(), "type": n.get_type().type} + for n in created_nodes + ] + } + + +def create_workflows( + ctx: RunContext[AssistantDeps], + automation_id: Annotated[ + int, Field(description="The ID of the automation to create workflows in.") + ], + workflows: Annotated[ + list[WorkflowCreate], + Field( + description="List of workflows to create, each with a trigger and action nodes." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Create workflows with triggers and action nodes. + + WHEN to use: User wants automated workflows with triggers and action nodes. + WHAT it does: Creates workflows with a trigger and action/router/iterator nodes. Use {{ node.ref }} for referencing values from previous nodes. + RETURNS: Created workflows with id, name, state. + DO NOT USE when: Workflows with those names already exist — check with list_workflows first. + HOW: Each workflow needs exactly one trigger and one or more actions/routers. Use {{ node.ref }} syntax to reference previous node values in action formulas. Know the table_id and field_ids for row-based triggers and actions. 
+ + ## Workflow Structure + + Each workflow has a trigger (the starting event) and action nodes (tasks to perform). + Nodes execute in sequence. Use {{ node.ref }} template syntax to reference + values from previous nodes. + + ## Dynamic Values with $formula: + + Any string field marked "Supports $formula:" can use dynamic values. + Prefix with '$formula:' + a natural-language description to auto-generate a formula + from context data. Otherwise the value is used as a literal. + - {"field_id": 123, "value": "$formula: the customer name from the trigger data"} + - {"field_id": 456, "value": "$formula: today's date"} + - {"field_id": 789, "value": "pending"} ← literal, no prefix + """ - Use this when you need to create workflows in an automation but don't have the tool. - """ # noqa: W505 + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - @udspy.module_callback - def _load_workflow_automation_tools(context): - nonlocal user, workspace, tool_helpers + if not workflows: + return {"created_workflows": []} - observation = ["New tools are now available.\n"] + created = [] - create_tool = udspy.Tool(create_workflows) - new_tools = [create_tool] - observation.append( - "- Use `create_workflows` to create workflows in an automation." + automation = helpers.get_automation(automation_id, user, workspace) + for wf in workflows: + tool_helpers.raise_if_cancelled() + with transaction.atomic(): + orm_workflow, node_mapping = helpers.create_workflow( + user, automation, wf, tool_helpers + ) + created.append( + { + "id": orm_workflow.id, + "name": orm_workflow.name, + "state": orm_workflow.state, + } ) - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) + # In separate transactions, try to update the formulas inside the workflow, + # so we don't block the main creation if something goes wrong here. 
+ agents.update_workflow_formulas(wf, node_mapping, tool_helpers) + + # Navigate to the last created workflow + tool_helpers.navigate_to( + WorkflowNavigationType( + type="automation-workflow", + automation_id=automation.id, + workflow_id=orm_workflow.id, + workflow_name=orm_workflow.name, + ) + ) + + return {"created_workflows": created} + + +def update_nodes( + ctx: RunContext[AssistantDeps], + workflow_id: Annotated[ + int, Field(description="The ID of the workflow containing the nodes.") + ], + nodes: Annotated[ + list[NodeUpdate], + Field( + description="List of node updates, each with a node_id and properties to change." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Update automation node labels and service configuration. + + WHEN to use: User wants to rename a node, change email subject/body, update slack channel, etc. + WHAT it does: Updates node label and/or service config. Supports $formula: prefix for dynamic values. + RETURNS: Updated node IDs and any errors. + DO NOT USE when: You need to change a node's type — delete and recreate it instead. + HOW: Use list_workflows first to find the workflow and node IDs. 
+ """ - return _load_workflow_automation_tools + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - return load_workflow_automation_tools + if not nodes: + return {"updated_nodes": []} + # Verify workflow belongs to workspace + helpers.get_workflow(workflow_id, user, workspace) -# ============================================================================ -# TOOL TYPE REGISTRY -# ============================================================================ + updated = [] + errors = [] + nodes_needing_formulas = [] + with transaction.atomic(): + for node_update in nodes: + tool_helpers.raise_if_cancelled() + try: + orm_node = helpers.update_node( + user, workspace, node_update, tool_helpers + ) + updated.append({"node_id": orm_node.id, "label": orm_node.label}) + + # Check if any fields need formula generation + formulas = node_update.get_formulas_to_update(orm_node) + if formulas: + nodes_needing_formulas.append((node_update, orm_node, formulas)) + except Exception as e: + errors.append(f"Error updating node {node_update.node_id}: {e}") + + # Apply direct values and generate formulas outside the main transaction + for node_update, orm_node, formulas in nodes_needing_formulas: + node_update.apply_direct_values(orm_node.service) + tool_helpers.update_status( + _("Generating formulas for node '%(label)s'..." 
% {"label": orm_node.label}) + ) + with transaction.atomic(): + try: + agents.update_single_node_formulas(node_update, orm_node, tool_helpers) + except Exception as exc: + from loguru import logger + + logger.exception( + "Failed to generate formulas for node {}: {}", orm_node.id, exc + ) -class ListWorkflowsToolType(AssistantToolType): - type = "list_workflows" + result: dict[str, Any] = {"updated_nodes": updated} + if errors: + result["errors"] = errors + return result + + +def delete_nodes( + ctx: RunContext[AssistantDeps], + node_ids: Annotated[ + list[int], + Field(description="List of node IDs to delete."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Delete automation nodes. + + WHEN to use: User wants to remove nodes from a workflow. + WHAT it does: Deletes the specified automation nodes. + RETURNS: Deleted node IDs and any errors. + DO NOT USE when: You want to modify a node — use update_nodes instead. 
+ """ - @classmethod - def get_tool(cls, user, workspace, tool_helpers): - return get_list_workflows_tool(user, workspace, tool_helpers) + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + if not node_ids: + return {"deleted_node_ids": []} -class WorkflowToolFactoryToolType(AssistantToolType): - type = "workflow_tool_factory" + deleted = [] + errors = [] - @classmethod - def get_tool(cls, user, workspace, tool_helpers): - return get_workflow_tool_factory(user, workspace, tool_helpers) + for node_id in node_ids: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Deleting node %(node_id)s...") % {"node_id": node_id} + ) + try: + helpers.delete_node(user, workspace, node_id) + deleted.append(node_id) + except Exception as e: + errors.append(f"Error deleting node {node_id}: {e}") + + result: dict[str, Any] = {"deleted_node_ids": deleted} + if errors: + result["errors"] = errors + return result + + +TOOL_FUNCTIONS = [ + list_workflows, + list_nodes, + create_workflows, + add_nodes, + update_nodes, + delete_nodes, +] +automation_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) + +ROUTING_RULES = """\ +- Check list_* before create_* to avoid duplicates. +- switch_mode: switch domain if task needs tools not in the current mode. +- create_workflows: use {{ node.ref }} for node refs, $formula: prefix for dynamic field values. +- add_nodes: insert/append nodes. 
Use list_nodes first to find existing node IDs.""" diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py index f2c9159123..5358c6af80 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py @@ -1,26 +1,16 @@ from .node import ( - AiAgentNodeCreate, - CreateRowActionCreate, - DeleteRowActionCreate, - HasFormulasToCreateMixin, - NodeBase, - RouterNodeCreate, - SendEmailActionCreate, + ActionNodeCreate, + ActionNodeItem, + NodeUpdate, TriggerNodeCreate, - UpdateRowActionCreate, ) from .workflow import WorkflowCreate, WorkflowItem __all__ = [ "WorkflowCreate", "WorkflowItem", - "NodeBase", - "RouterNodeCreate", - "CreateRowActionCreate", - "UpdateRowActionCreate", - "DeleteRowActionCreate", - "SendEmailActionCreate", - "AiAgentNodeCreate", + "ActionNodeCreate", + "ActionNodeItem", + "NodeUpdate", "TriggerNodeCreate", - "HasFormulasToCreateMixin", ] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py index 1c4b916eb2..0c2f425324 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py @@ -1,10 +1,17 @@ -from abc import ABC, abstractmethod -from typing import Annotated, Any, Literal, Optional +""" +Automation node type models and ORM conversion logic. + +Defines ``TriggerNodeCreate``, ``ActionNodeCreate``, and their read-back +counterparts (``TriggerNodeItem``, ``ActionNodeItem``), plus the dispatch +tables that convert between Pydantic models and Django ORM representations. 
+""" + +from typing import Any, Callable, Literal, Optional from uuid import uuid4 from django.conf import settings -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, model_serializer, model_validator from baserow.contrib.automation.nodes.models import AutomationNode from baserow.core.formula.types import ( @@ -13,81 +20,82 @@ ) from baserow.core.services.handler import ServiceHandler from baserow.core.services.models import Service +from baserow_enterprise.assistant.tools.shared.formula_utils import ( + FORMULA_PREFIX, + formula_desc, + literal_or_placeholder, + needs_formula, +) from baserow_enterprise.assistant.types import BaseModel +# Short marker appended to fields that support $formula: dynamic values. +# The full explanation lives in the create_workflows tool description. +SUPPORTS_FORMULA = f" Supports {FORMULA_PREFIX} prefix." -class NodeBase(BaseModel): - """Base node model.""" - - label: str = Field(..., description="The human readable name of the node") - type: str +# --------------------------------------------------------------------------- +# Field-mapping helpers (shared by apply_direct / update_formulas) +# --------------------------------------------------------------------------- -class RefCreate(BaseModel): - """Base node creation model.""" - - ref: str = Field( - ..., description="A reference ID for the node, only used during creation" - ) - - -class Item(BaseModel): - id: str +def _upsert_field_mappings( + service: Service, + values: dict[int, tuple[str, bool]], +): + """ + Bulk-upsert field mappings on a service. + + ``values`` maps ``field_id → (formula_value, enabled)``. + Existing mappings are updated in place; missing ones are created. 
+ """ + + if not values: + return + + existing = {m.field_id: m for m in service.field_mappings.all()} + FieldMapping = service.field_mappings.model + to_create, to_update = [], [] + + for field_id, (formula, enabled) in values.items(): + if field_id in existing: + mapping = existing[field_id] + mapping.value = formula + mapping.enabled = enabled + to_update.append(mapping) + else: + to_create.append( + FieldMapping( + field_id=field_id, + value=formula, + enabled=enabled, + service_id=service.id, + ) + ) -class HasFormulasToCreateMixin(ABC): - @abstractmethod - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - """ - Creates and returns a mapping between field names and formulas to be created - for the given ORM node. Every value needs to contain instructions or description - on how to generate the formula for that field. - Prefix optional fields with "[optional]: " in the description to indicate they - are not mandatory. - """ + if to_create: + service.field_mappings.bulk_create(to_create) + if to_update: + FieldMapping.objects.bulk_update(to_update, ["value", "enabled"]) - pass - def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): - save = False - for field_name, formula in formulas.items(): - if hasattr(service, field_name): - setattr( - service, - field_name, - BaserowFormulaObject.create(formula=formula), - ) - save = True - if save: - ServiceHandler().update_service(service.get_type(), service) +# --------------------------------------------------------------------------- +# Sub-models +# --------------------------------------------------------------------------- class PeriodicTriggerSettings(BaseModel): - interval: Literal["MINUTE", "HOUR", "DAY", "WEEK", "MONTH"] = Field( - ..., description="The interval for the periodic trigger" - ) + """All times in UTC — remove timezone offsets.""" + + interval: Literal["MINUTE", "HOUR", "DAY", "WEEK", "MONTH"] minute: int = Field( default=0, - description=( - 
"If interval=MINUTE, the number of minutes between each trigger. " - f"Minimum is set to {settings.INTEGRATIONS_PERIODIC_MINUTE_MIN} minutes. " - "If interval=HOUR, the UTC minute for the periodic trigger. " - ), - ) - hour: int = Field( - default=0, - description=( - "The UTC hour for the periodic trigger. " - "ALWAYS remove timezone offset from the context." - ), - ) - day_of_week: int = Field( - default=0, - description="The day of the week for the periodic trigger (0=Monday, 6=Sunday)", - ) - day_of_month: int = Field( - default=1, description="The day of the month for the periodic trigger (1-31)" + ge=0, + le=59, + description=f"MINUTE: minutes between triggers (min {settings.INTEGRATIONS_PERIODIC_MINUTE_MIN}). HOUR: minute of the hour.", ) + hour: int = Field(default=0, ge=0, le=23, description="UTC hour (0-23).") + day_of_week: int = Field(default=0, ge=0, le=6, description="0=Monday, 6=Sunday.") + day_of_month: int = Field(default=1, ge=1, le=31, description="1-31.") class RowsTriggersSettings(BaseModel): @@ -96,9 +104,46 @@ class RowsTriggersSettings(BaseModel): table_id: int = Field(..., description="The ID of the table to monitor") -class TriggerNodeCreate(NodeBase, RefCreate): +class RouterEdgeCreate(BaseModel): + """Router branch. 
Order matters: first matching branch is taken.""" + + label: str = Field(description="Branch label.") + condition: str = Field( + description="Boolean condition using comparison operators and get() functions.", + ) + + _uid: str = PrivateAttr(default_factory=lambda: str(uuid4())) + + def to_orm_service_dict(self) -> dict[str, Any]: + return {"uid": self._uid, "label": self.label} + + +class RouterBranch(RouterEdgeCreate): + """Existing router branch with ID.""" + + id: str + + +class AutomationFieldValue(BaseModel): + """Field ID → value mapping for row actions.""" + + field_id: int = Field(..., description="Database field ID.") + value: str = Field(..., description=f"Field value.{SUPPORTS_FORMULA}") + + +# --------------------------------------------------------------------------- +# Trigger +# --------------------------------------------------------------------------- + + +_PERIODIC_KEYS = {"interval", "minute", "hour", "day_of_week", "day_of_month"} + + +class TriggerNodeCreate(BaseModel): """Create a trigger node in a workflow.""" + ref: str = Field(..., description="Temporary reference ID for creation.") + label: str = Field(..., description="Display name.") type: Literal[ "periodic", "http_trigger", @@ -107,16 +152,38 @@ class TriggerNodeCreate(NodeBase, RefCreate): "rows_deleted", ] - # periodic trigger specific periodic_interval: Optional[PeriodicTriggerSettings] = Field( default=None, - description="UTC configuration for periodic trigger. ALWAYS remove timezone offset from the context.", + description="(periodic) Schedule settings in UTC.", ) rows_triggers_settings: Optional[RowsTriggersSettings] = Field( default=None, - description="Configuration for rows trigger", + description="(rows_*) Table to monitor.", ) + @model_validator(mode="before") + @classmethod + def _fold_flat_periodic(cls, data): + """Accept flat periodic fields (interval, hour, ...) 
and nest them.""" + + if not isinstance(data, dict): + return data + if data.get("periodic_interval") is not None: + return data + flat = {k: data.pop(k) for k in list(data) if k in _PERIODIC_KEYS} + if flat: + data["periodic_interval"] = flat + return data + + @model_validator(mode="after") + def _validate_trigger_settings(self): + if self.type == "periodic" and self.periodic_interval is None: + raise ValueError("periodic trigger requires periodic_interval") + if self.type in ("rows_created", "rows_updated", "rows_deleted"): + if self.rows_triggers_settings is None: + raise ValueError(f"{self.type} trigger requires rows_triggers_settings") + return self + def to_orm_service_dict(self) -> dict[str, Any]: """Convert to ORM dict for node creation service.""" @@ -138,28 +205,128 @@ def to_orm_service_dict(self) -> dict[str, Any]: return {} -class TriggerNodeItem(TriggerNodeCreate, Item): +class TriggerNodeItem(TriggerNodeCreate): """Existing trigger node with ID.""" + id: str http_trigger_url: str | None = Field( default=None, description="The URL to trigger the HTTP request" ) -class EdgeCreate(BaseModel): - previous_node_ref: str = Field( - ..., - description="The reference ID of the previous node to link from. 
Every node can have only one previous node.", - ) +# --------------------------------------------------------------------------- +# Action node +# --------------------------------------------------------------------------- + +ActionNodeType = Literal[ + "router", + "smtp_email", + "slack_write_message", + "create_row", + "update_row", + "delete_row", + "ai_agent", +] + + +class ActionNodeCreate(BaseModel): + """Flat model for creating an action node: type + type-specific fields.""" + + ref: str = Field(..., description="Temporary reference ID for creation.") + label: str = Field(..., description="Display name.") + type: ActionNodeType + previous_node_ref: str = Field(..., description="Ref of the preceding node.") router_edge_label: str = Field( default="", - description="If the previous node is a router, the edge label to link from if different from default", + description="Branch label if previous node is a router.", + ) + + # -- router -- + edges: list[RouterEdgeCreate] | None = Field( + default=None, + description="(router) Branches. 
A default branch is auto-created.", + ) + + # -- smtp_email -- + to_emails: str | None = Field( + default=None, description=f"(smtp_email) Recipients.{SUPPORTS_FORMULA}" + ) + cc_emails: str | None = Field( + default=None, description=f"(smtp_email) CC.{SUPPORTS_FORMULA}" + ) + bcc_emails: str | None = Field( + default=None, description=f"(smtp_email) BCC.{SUPPORTS_FORMULA}" + ) + subject: str | None = Field( + default=None, description=f"(smtp_email) Subject.{SUPPORTS_FORMULA}" + ) + body: str | None = Field( + default=None, description=f"(smtp_email) Body.{SUPPORTS_FORMULA}" ) + body_type: Literal["plain", "html"] = "plain" - def to_orm_reference_node( - self, node_mapping: dict - ) -> tuple[Optional[int], Optional[str]]: - """Get the ORM node ID and output label from the previous node reference.""" + # -- slack_write_message -- + channel: str | None = None + text: str | None = Field( + default=None, description=f"(slack) Message.{SUPPORTS_FORMULA}" + ) + + # -- create_row / update_row / delete_row -- + table_id: int | None = None + row_id: str | None = Field( + default=None, description=f"(update/delete_row) Row ID.{SUPPORTS_FORMULA}" + ) + values: list[AutomationFieldValue] | None = None + + # -- ai_agent -- + output_type: Literal["text", "choice"] = Field( + default="text", + description="(ai_agent) Chain another action to use the output.", + ) + choices: list[str] | None = Field( + default=None, + description="(ai_agent) Choices if output_type='choice'.", + ) + prompt: str | None = Field( + default=None, description=f"(ai_agent) Prompt.{SUPPORTS_FORMULA}" + ) + + # Required fields per type + _REQUIRED_FIELDS: dict[str, list[tuple[str, str]]] = { + "router": [("edges", "edges")], + "smtp_email": [ + ("to_emails", "to_emails"), + ("subject", "subject"), + ("body", "body"), + ], + "slack_write_message": [("channel", "channel"), ("text", "text")], + "create_row": [("table_id", "table_id"), ("values", "values")], + "update_row": [ + ("table_id", "table_id"), + 
("row_id", "row_id"), + ("values", "values"), + ], + "delete_row": [("table_id", "table_id"), ("row_id", "row_id")], + "ai_agent": [("prompt", "prompt")], + } + + @model_validator(mode="after") + def _validate_required_for_type(self): + required = self._REQUIRED_FIELDS.get(self.type) + if required: + missing = [name for attr, name in required if getattr(self, attr) is None] + if missing: + raise ValueError(f"{self.type} requires {', '.join(missing)}") + return self + + # -- ORM conversion -- + + def to_orm_service_dict(self) -> dict[str, Any]: + """Convert type-specific fields to an ORM service dict.""" + return _TO_ORM_SERVICE[self.type](self) + + def to_orm_reference_node(self, node_mapping: dict) -> tuple[Optional[int], str]: + """Resolve the previous node reference into an ORM node ID and output label.""" if self.previous_node_ref not in node_mapping: raise ValueError( @@ -169,13 +336,13 @@ def to_orm_reference_node( previous_orm_node, previous_node_create = node_mapping[self.previous_node_ref] output = "" - if self.router_edge_label and previous_node_create.type == "router": + if ( + self.router_edge_label + and getattr(previous_node_create, "type", None) == "router" + ): + edges = getattr(previous_node_create, "edges", None) or [] output = next( - ( - edge._uid - for edge in previous_node_create.edges - if edge.label == self.router_edge_label - ), + (edge._uid for edge in edges if edge.label == self.router_edge_label), None, ) if output is None: @@ -185,356 +352,541 @@ def to_orm_reference_node( return previous_orm_node.id, output + # -- Formula lifecycle -- -class RouterEdgeCreate(BaseModel): - """Router branch configuration.""" + def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str] | None: + """Return a ``{key: description}`` dict of formulas to generate, or None.""" - label: str = Field( - description="The label of the router branch. 
Order of branches matters: first matching branch is taken.", - ) - condition: str = Field( - description=( - "The condition formula to evaluate for this branch as boolean. " - "Use comparison operators and get(...) functions to build the formula with a boolean result. " - "Always mentions the field values using get(...) functions." - ), - ) + fn = _GET_FORMULAS.get(self.type) + return fn(self, orm_node) if fn else None - _uid: str = PrivateAttr(default_factory=lambda: str(uuid4())) + def apply_direct_values(self, service: Service): + """Apply literal (non-$formula) values directly to the service.""" - def to_orm_service_dict(self) -> dict[str, Any]: - return { - "uid": self._uid, - "label": self.label, - } + fn = _APPLY_DIRECT.get(self.type) + if fn is not None: + fn(self, service) + def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): + """Write generated formulas back to the ORM service.""" -class RouterBranch(RouterEdgeCreate, Item): - """Existing router branch with ID.""" + fn = _UPDATE_FORMULAS.get(self.type) + if fn is not None: + fn(self, service, formulas) + else: + _default_update_formulas(service, formulas) -class RouterNodeBase(NodeBase): - """Create a router node with branches.""" +# --------------------------------------------------------------------------- +# to_orm_service dispatch: (ActionNodeCreate) -> dict +# --------------------------------------------------------------------------- - type: Literal["router"] - edges: list[RouterEdgeCreate] = Field( - ..., - description="List of branches for the router node. 
A default branch is created automatically.", - ) +def _router_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return {"edges": [branch.to_orm_service_dict() for branch in n.edges]} -class RouterNodeCreate(RouterNodeBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin): - """Create a router node with branches and link configuration.""" - def to_orm_service_dict(self) -> dict[str, Any]: - return {"edges": [branch.to_orm_service_dict() for branch in self.edges]} +def _email_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return { + "to_email": literal_or_placeholder(n.to_emails), + "cc_email": literal_or_placeholder(n.cc_emails), + "bcc_email": literal_or_placeholder(n.bcc_emails), + "subject": literal_or_placeholder(n.subject), + "body": literal_or_placeholder(n.body), + "body_type": f"'{n.body_type}'", + } - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - return {edge.label: edge.condition for edge in self.edges} - def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): - orm_edges = service.specific.edges.all() - formulas = {k.lower(): v for k, v in formulas.items()} - EdgeModel = service.specific.edges.model - updates = [] - for orm_edge in orm_edges: - label = orm_edge.label.lower() - if label in formulas: - orm_edge.condition["mode"] = BASEROW_FORMULA_MODE_ADVANCED - orm_edge.condition["formula"] = formulas[label] - updates.append(orm_edge) - if updates: - EdgeModel.objects.bulk_update(updates, ["condition"]) - - -class RouterNodeItem(RouterNodeBase, Item): - """Existing router node with ID.""" - - -class SendEmailActionBase(NodeBase): - """Send email action configuration.""" - - type: Literal["smtp_email"] - to_emails: str - cc_emails: Optional[str] - bcc_emails: Optional[str] - subject: str - body: str - body_type: Literal["plain", "html"] = Field(default="plain") - - -class SendEmailActionCreate( - SendEmailActionBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin -): - """Create a send email 
action with edge configuration.""" +def _slack_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + channel = (n.channel or "").lstrip("#") + return { + "channel": channel, + "text": literal_or_placeholder(n.text), + } - def to_orm_service_dict(self) -> dict[str, Any]: - return { - "to_email": f"'{self.to_emails}'", - "cc_email": f"'{self.cc_emails or ''}'", - "bcc_email": f"'{self.bcc_emails or ''}'", - "subject": f"'{self.subject}'", - "body": f"'{self.body}'", - "body_type": f"'{self.body_type}'", - } - - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - values = {} - to_emails_base = ( - "A comma separated list of email addresses to send the email to." - ) - if self.to_emails: - values["to_emails"] = ( - to_emails_base + f" Value to resolve: {self.to_emails}" - ) - else: - values["to_emails"] = "[optional]: " + to_emails_base - cc_emails_base = "A comma separated list of email addresses to CC the email to." - if self.cc_emails: - values["cc_emails"] = ( - cc_emails_base + f" Value to resolve: {self.cc_emails}" - ) - else: - values["cc_emails"] = "[optional]: " + cc_emails_base - - bcc_emails_base = ( - "A comma separated list of email addresses to BCC the email to." - ) - if self.bcc_emails: - values["bcc_emails"] = ( - bcc_emails_base + f" Value to resolve: {self.bcc_emails}" - ) - else: - values["bcc_emails"] = "[optional]: " + bcc_emails_base +def _row_action_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return {"table_id": n.table_id} - values["subject"] = "The subject of the email." - if self.subject: - values["subject"] += f" Value to resolve: {self.subject}" - values["body"] = f"The {self.body_type} body content of the email." 
- if self.body: - values["body"] += f" Value to resolve: {self.body}" - return values +def _ai_agent_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return { + "ai_choices": (n.choices or []) if n.output_type == "choice" else [], + "ai_prompt": literal_or_placeholder(n.prompt), + "ai_output_type": n.output_type, + } -class SendEmailActionItem(SendEmailActionBase, Item): - """Existing send email action with ID.""" +_TO_ORM_SERVICE: dict[str, Callable] = { + "router": _router_to_orm, + "smtp_email": _email_to_orm, + "slack_write_message": _slack_to_orm, + "create_row": _row_action_to_orm, + "update_row": _row_action_to_orm, + "delete_row": _row_action_to_orm, + "ai_agent": _ai_agent_to_orm, +} -class SlackWriteMessageActionBase(NodeBase): - """Send Slack message action configuration.""" +# --------------------------------------------------------------------------- +# get_formulas_to_create dispatch: (ActionNodeCreate, AutomationNode) -> dict | None +# --------------------------------------------------------------------------- - type: Literal["slack_write_message"] - channel: str - text: str +def _router_formulas(n: ActionNodeCreate, orm_node: AutomationNode) -> dict[str, str]: + return {edge.label: edge.condition for edge in n.edges} -class SlackWriteMessageActionCreate( - SlackWriteMessageActionBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin -): - """Create a send Slack message action with edge configuration.""" - def to_orm_service_dict(self) -> dict[str, Any]: +def _email_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + fields = { + "to_emails": ( + "A comma separated list of email addresses to send the email to.", + n.to_emails, + ), + "cc_emails": ( + "A comma separated list of email addresses to CC the email to.", + n.cc_emails, + ), + "bcc_emails": ( + "A comma separated list of email addresses to BCC the email to.", + n.bcc_emails, + ), + "subject": ("The subject of the email.", n.subject), + "body": (f"The 
{n.body_type} body content of the email.", n.body), + } + values = { + key: f"{base_desc} Value to resolve: {formula_desc(val)}" + for key, (base_desc, val) in fields.items() + if needs_formula(val) + } + return values or None + + +def _slack_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.text): return { - "channel": self.channel, - "text": f"'{self.text}'", + "text": f"The message content. Value to resolve: {formula_desc(n.text)}" } + return None - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - values = {} - message_base = "The message content." - if self.text: - values["text"] = message_base + f" Value to resolve: '{self.text}'" - else: - values["text"] = "[optional]: " + message_base - return values +def _row_action_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + from baserow_enterprise.assistant.tools.shared.formula_utils import ( + minimize_json_schema, + ) + service = orm_node.service.specific + schema = service.get_type().generate_schema(service.specific) + values_by_id = {fv.field_id: fv.value for fv in (n.values or [])} + values = {} -class CreateRowActionBase(NodeBase): - """Create row action configuration.""" + if needs_formula(n.row_id): + values["row_id"] = ( + f"the row ID to update. 
Value to resolve: {formula_desc(n.row_id)}" + ) - type: Literal["create_row"] - table_id: int - values: dict[int, Any] = Field( - ..., description="A mapping of field IDs to values or formulas to update" - ) + for v in minimize_json_schema(schema).values(): + value = values_by_id.get(int(v["id"])) + if needs_formula(value): + desc = v["desc"] + f" Value to resolve: {formula_desc(value)}" + values[int(v["id"])] = {**v, "desc": desc} + return values or None -class RowActionService: - def to_orm_service_dict(self) -> dict[str, Any]: + +def _ai_agent_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.prompt): return { - "table_id": self.table_id, + "ai_prompt": f"The AI prompt. Value to resolve: {formula_desc(n.prompt)}" } + return None -class RowActionFormulaToCreate(HasFormulasToCreateMixin): - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - from baserow_enterprise.assistant.tools.automation.utils import ( - _minimize_json_schema, - ) +_GET_FORMULAS: dict[str, Callable] = { + "router": _router_formulas, + "smtp_email": _email_formulas, + "slack_write_message": _slack_formulas, + "create_row": _row_action_formulas, + "update_row": _row_action_formulas, + "delete_row": _row_action_formulas, + "ai_agent": _ai_agent_formulas, +} - service = orm_node.service.specific - schema = service.get_type().generate_schema(service.specific) - values = {"row_id": "the row ID to update"} - for v in _minimize_json_schema(schema).values(): - desc = v["desc"] - value = self.values.get(int(v["id"])) - if value: - desc += f" Value to resolve: {value}" - else: - desc = "[optional]: " + desc - values[int(v["id"])] = {**v, "desc": desc} - return values - def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): - row_id_formula = formulas.pop("row_id", None) - - field_mappings = {m.field_id: m for m in service.field_mappings.all()} - field_mapping_to_create = [] - 
field_mapping_to_update = [] - FieldMapping = service.field_mappings.model - for field_id, formula in formulas.items(): - if field_id in field_mappings: - field_mappings[field_id].value = formula - field_mappings[field_id].enabled = True - field_mapping_to_update.append(field_mappings[field_id]) - else: - field_mapping_to_create.append( - FieldMapping( - field_id=field_id, - value=formula, - enabled=True, - service_id=service.id, - ) - ) - if field_mapping_to_create: - service.field_mappings.bulk_create(field_mapping_to_create) - if field_mapping_to_update: - FieldMapping.objects.bulk_update( - field_mapping_to_update, ["value", "enabled"] - ) +# --------------------------------------------------------------------------- +# update_service_with_formulas dispatch +# --------------------------------------------------------------------------- + - if row_id_formula: - service.row_id = row_id_formula - ServiceHandler().update_service(service.get_type(), service) +def _default_update_formulas(service: Service, formulas: dict[str, str]): + """Set ``BaserowFormulaObject`` on named service fields.""" + save = False + for field_name, formula in formulas.items(): + if hasattr(service, field_name): + setattr(service, field_name, BaserowFormulaObject.create(formula=formula)) + save = True + if save: + ServiceHandler().update_service(service.get_type(), service) -class CreateRowActionCreate( - RowActionService, - CreateRowActionBase, - RefCreate, - EdgeCreate, - RowActionFormulaToCreate, + +def _router_update_formulas( + n: ActionNodeCreate, service: Service, formulas: dict[str, str] +): + """Write generated condition formulas to router edges.""" + + formulas_lower = {k.lower(): v for k, v in formulas.items()} + EdgeModel = service.specific.edges.model + updates = [] + for orm_edge in service.specific.edges.all(): + label = orm_edge.label.lower() + if label in formulas_lower: + orm_edge.condition["mode"] = BASEROW_FORMULA_MODE_ADVANCED + orm_edge.condition["formula"] = 
formulas_lower[label] + updates.append(orm_edge) + if updates: + EdgeModel.objects.bulk_update(updates, ["condition"]) + + +def _row_action_update_formulas( + n: ActionNodeCreate, service: Service, formulas: dict[str, str] ): - """Create a create row action with edge configuration.""" + """Write generated formulas to row action field mappings and row_id.""" + row_id_formula = formulas.pop("row_id", None) -class CreateRowActionItem(CreateRowActionBase, Item): - """Existing create row action with ID.""" + _upsert_field_mappings( + service, + {field_id: (formula, True) for field_id, formula in formulas.items()}, + ) + + if row_id_formula: + service.row_id = row_id_formula + ServiceHandler().update_service(service.get_type(), service) + + +_UPDATE_FORMULAS: dict[str, Callable] = { + "router": _router_update_formulas, + "create_row": _row_action_update_formulas, + "update_row": _row_action_update_formulas, + "delete_row": _row_action_update_formulas, +} + + +# --------------------------------------------------------------------------- +# apply_direct_values dispatch +# --------------------------------------------------------------------------- -class UpdateRowActionBase(NodeBase): - """Update row action configuration.""" +def _row_action_apply_direct(n: ActionNodeCreate, service: Service): + """Write literal (non-$formula) field values as quoted formulas.""" - type: Literal["update_row"] - table_id: int - row_id: str = Field(..., description="The row ID or a formula to identify the row") - values: dict[int, Any] = Field( - ..., description="A mapping of field IDs to values or formulas to update" + _upsert_field_mappings( + service, + { + fv.field_id: (f"'{fv.value}'", True) + for fv in (n.values or []) + if not needs_formula(fv.value) + }, ) + if n.row_id and not needs_formula(n.row_id): + service.row_id = f"'{n.row_id}'" + ServiceHandler().update_service(service.get_type(), service) -class UpdateRowActionCreate( - RowActionService, - UpdateRowActionBase, - RefCreate, - 
EdgeCreate, - RowActionFormulaToCreate, -): - """Create an update row action with edge configuration.""" +_APPLY_DIRECT: dict[str, Callable] = { + "create_row": _row_action_apply_direct, + "update_row": _row_action_apply_direct, + "delete_row": _row_action_apply_direct, +} -class UpdateRowActionItem(UpdateRowActionBase, Item): - """Existing update row action with ID.""" +# --------------------------------------------------------------------------- +# ActionNodeItem (read-back) +# --------------------------------------------------------------------------- -class DeleteRowActionBase(NodeBase): - """Delete row action configuration.""" - type: Literal["delete_row"] - table_id: int - row_id: str = Field(..., description="The row ID or a formula to identify the row") +# --------------------------------------------------------------------------- +# NodeUpdate (for update_nodes tool) +# --------------------------------------------------------------------------- -class DeleteRowActionCreate( - RowActionService, - DeleteRowActionBase, - RefCreate, - EdgeCreate, - RowActionFormulaToCreate, -): - """Create a delete row action with edge configuration.""" +class NodeUpdate(BaseModel): + """Flat model for updating an automation node.""" + node_id: int = Field(..., description="The ID of the node to update.") + label: str | None = Field(None, description="New display name.") -class DeleteRowActionItem(DeleteRowActionBase, Item): - """Existing delete row action with ID.""" + # -- smtp_email -- + to_emails: str | None = Field( + default=None, description=f"(smtp_email) Recipients.{SUPPORTS_FORMULA}" + ) + cc_emails: str | None = Field( + default=None, description=f"(smtp_email) CC.{SUPPORTS_FORMULA}" + ) + bcc_emails: str | None = Field( + default=None, description=f"(smtp_email) BCC.{SUPPORTS_FORMULA}" + ) + subject: str | None = Field( + default=None, description=f"(smtp_email) Subject.{SUPPORTS_FORMULA}" + ) + body: str | None = Field( + default=None, description=f"(smtp_email) 
Body.{SUPPORTS_FORMULA}" + ) + body_type: Literal["plain", "html"] | None = None + # -- slack_write_message -- + channel: str | None = None + text: str | None = Field( + default=None, description=f"(slack) Message.{SUPPORTS_FORMULA}" + ) -class AiAgentNodeBase(NodeBase): - """AI Agent action configuration.""" + # -- create_row / update_row / delete_row -- + table_id: int | None = None + row_id: str | None = Field( + default=None, description=f"(update/delete_row) Row ID.{SUPPORTS_FORMULA}" + ) + values: list[AutomationFieldValue] | None = None - type: Literal["ai_agent"] = Field( - ..., - description="Don't stop at this node. Chain some other action to use the AI output.", + # -- ai_agent -- + output_type: Literal["text", "choice"] | None = None + choices: list[str] | None = None + prompt: str | None = Field( + default=None, description=f"(ai_agent) Prompt.{SUPPORTS_FORMULA}" ) - output_type: Literal["text", "choice"] = Field(default="text") - choices: Optional[list[str]] = Field( - default=None, - description="List of choices if output_type is 'choice'", + + def to_update_service_dict(self, current_type: str) -> dict[str, Any] | None: + """Build a service kwargs dict from non-None fields. 
Returns None if no service fields set.""" + builder = _TO_UPDATE_SERVICE.get(current_type) + if builder is None: + return None + result = builder(self) + return result if result else None + + def get_formulas_to_update(self, orm_node: AutomationNode) -> dict[str, str] | None: + """Return a {key: description} dict of formulas to generate, or None.""" + fn = _GET_UPDATE_FORMULAS.get( + orm_node.service.get_type().type if orm_node.service else None + ) + return fn(self, orm_node) if fn else None + + def apply_direct_values(self, service: Service): + """Apply literal (non-$formula) values directly to the service.""" + fn = _APPLY_UPDATE_DIRECT.get(service.get_type().type if service else None) + if fn is not None: + fn(self, service) + + def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): + """Write generated formulas back to the ORM service.""" + stype = service.get_type().type if service else None + fn = _UPDATE_FORMULAS.get(stype) + if fn is not None: + # Reuse the existing dispatch (expects ActionNodeCreate-like but works for our purposes) + fn(self, service, formulas) + else: + _default_update_formulas(service, formulas) + + +# -- to_update_service dispatch -- + + +def _email_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.to_emails is not None: + d["to_email"] = literal_or_placeholder(n.to_emails) + if n.cc_emails is not None: + d["cc_email"] = literal_or_placeholder(n.cc_emails) + if n.bcc_emails is not None: + d["bcc_email"] = literal_or_placeholder(n.bcc_emails) + if n.subject is not None: + d["subject"] = literal_or_placeholder(n.subject) + if n.body is not None: + d["body"] = literal_or_placeholder(n.body) + if n.body_type is not None: + d["body_type"] = f"'{n.body_type}'" + return d + + +def _slack_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.channel is not None: + d["channel"] = n.channel.lstrip("#") + if n.text is not None: + d["text"] = literal_or_placeholder(n.text) + return d + + 
+def _row_action_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.table_id is not None: + d["table_id"] = n.table_id + return d + + +def _ai_agent_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.prompt is not None: + d["ai_prompt"] = literal_or_placeholder(n.prompt) + if n.output_type is not None: + d["ai_output_type"] = n.output_type + if n.choices is not None: + d["ai_choices"] = n.choices + return d + + +_TO_UPDATE_SERVICE: dict[str, Callable] = { + "smtp_email": _email_update_service, + "slack_write_message": _slack_update_service, + "create_row": _row_action_update_service, + "update_row": _row_action_update_service, + "delete_row": _row_action_update_service, + "ai_agent": _ai_agent_update_service, +} + + +# -- get_formulas_to_update dispatch -- + + +def _email_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + fields = { + "to_emails": ("Recipients.", n.to_emails), + "cc_emails": ("CC.", n.cc_emails), + "bcc_emails": ("BCC.", n.bcc_emails), + "subject": ("Subject.", n.subject), + "body": ("Body.", n.body), + } + values = { + key: f"{base_desc} Value to resolve: {formula_desc(val)}" + for key, (base_desc, val) in fields.items() + if needs_formula(val) + } + return values or None + + +def _slack_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.text): + return { + "text": f"The message content. 
Value to resolve: {formula_desc(n.text)}" + } + return None + + +def _row_action_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + from baserow_enterprise.assistant.tools.shared.formula_utils import ( + minimize_json_schema, ) - prompt: str + service = orm_node.service.specific + schema = service.get_type().generate_schema(service.specific) + values_by_id = {fv.field_id: fv.value for fv in (n.values or [])} + values = {} -class AiAgentNodeCreate( - AiAgentNodeBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin -): - """Create an AI Agent action with edge configuration.""" + if needs_formula(n.row_id): + values["row_id"] = f"the row ID. Value to resolve: {formula_desc(n.row_id)}" - def to_orm_service_dict(self) -> dict[str, Any]: + for v in minimize_json_schema(schema).values(): + value = values_by_id.get(int(v["id"])) + if needs_formula(value): + desc = v["desc"] + f" Value to resolve: {formula_desc(value)}" + values[int(v["id"])] = {**v, "desc": desc} + + return values or None + + +def _ai_agent_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.prompt): return { - "ai_choices": (self.choices or []) if self.output_type == "choice" else [], - "ai_prompt": f"'{self.prompt}'", - "ai_output_type": self.output_type, + "ai_prompt": f"The AI prompt. 
Value to resolve: {formula_desc(n.prompt)}" } + return None - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - return {"ai_prompt": self.prompt} +_GET_UPDATE_FORMULAS: dict[str, Callable] = { + "smtp_email": _email_update_formulas, + "slack_write_message": _slack_update_formulas, + "create_row": _row_action_update_formulas, + "update_row": _row_action_update_formulas, + "delete_row": _row_action_update_formulas, + "ai_agent": _ai_agent_update_formulas, +} -class AiAgentNodeItem(AiAgentNodeBase, Item): - """Existing AI Agent action with ID.""" +# -- apply_direct_values dispatch for update -- -AnyNodeCreate = Annotated[ - RouterNodeCreate - # actions - | SendEmailActionCreate - | SlackWriteMessageActionCreate - | CreateRowActionCreate - | UpdateRowActionCreate - | DeleteRowActionCreate - | AiAgentNodeCreate, - Field(discriminator="type"), -] -AnyNodeItem = ( - RouterNodeItem - # actions - | SendEmailActionItem - | CreateRowActionItem - | UpdateRowActionItem - | DeleteRowActionItem - | AiAgentNodeItem -) +def _row_action_update_apply_direct(n: "NodeUpdate", service: Service): + """Write literal (non-$formula) field values as quoted formulas.""" + _upsert_field_mappings( + service, + { + fv.field_id: (f"'{fv.value}'", True) + for fv in (n.values or []) + if not needs_formula(fv.value) + }, + ) + if n.row_id and not needs_formula(n.row_id): + service.row_id = f"'{n.row_id}'" + ServiceHandler().update_service(service.get_type(), service) + + +_APPLY_UPDATE_DIRECT: dict[str, Callable] = { + "create_row": _row_action_update_apply_direct, + "update_row": _row_action_update_apply_direct, + "delete_row": _row_action_update_apply_direct, +} + + +class ActionNodeItem(BaseModel): + """Existing action node with ID — flat structure, excludes None values.""" + + id: str + label: str + type: str + previous_node_ref: str | None = None + router_edge_label: str | None = None + + # (router) + edges: list[RouterBranch] | None = None + + # (smtp_email) + 
to_emails: str | None = None + cc_emails: str | None = None + bcc_emails: str | None = None + subject: str | None = None + body: str | None = None + body_type: str | None = None + + # (slack_write_message) + channel: str | None = None + text: str | None = None + + # (create_row, update_row, delete_row) + table_id: int | None = None + row_id: str | None = None + values: list[AutomationFieldValue] | None = None + + # (ai_agent) + output_type: str | None = None + choices: list[str] | None = None + prompt: str | None = None + + @model_serializer(mode="wrap") + def _exclude_none(self, handler): + return {k: v for k, v in handler(self).items() if v is not None} diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py index 5470d91648..5fe625816c 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py @@ -1,66 +1,18 @@ -from typing import Annotated, Literal - from pydantic import Field from baserow_enterprise.assistant.types import BaseModel -from .node import AnyNodeCreate, TriggerNodeCreate - - -class WorkflowEdgeCreate(BaseModel): - """Workflow edge connecting two nodes.""" - - type: Literal["edge"] - from_node_label: str = Field( - ..., - description="The label of the node where the edge starts", - ) - to_node_label: str = Field( - ..., - description="The label of the node where the edge ends", - ) - - -class WorkflowRouterEdgeCreate(WorkflowEdgeCreate): - """Workflow edge connecting to a router node with a branch label.""" - - type: Literal["router_branch"] - router_branch_label: str = Field( - default="", - description="The branch label for the router node edge", - ) - - -AnyWorkflowEdgeCreate = Annotated[ - WorkflowEdgeCreate, - WorkflowRouterEdgeCreate, - Field( - discriminator="type", - 
default="edge", - description=( - "The type of workflow edge. Use 'edge' in normal linear (a follows b) connections. " - "Use 'router_branch' when connecting to a router node with a branch label. " - ), - ), -] +from .node import ActionNodeCreate, TriggerNodeCreate class WorkflowCreate(BaseModel): """Base workflow model.""" - name: str = Field(..., description="The name of the workflow") - trigger: TriggerNodeCreate = Field( - ..., - description="The trigger node configuration for the workflow", - ) - nodes: list[AnyNodeCreate] = Field( + name: str = Field(..., description="Workflow name.") + trigger: TriggerNodeCreate = Field(..., description="The trigger node.") + nodes: list[ActionNodeCreate] = Field( default_factory=list, - description=( - "The nodes executed or evaluated once the trigger fires. " - "Every node must have only one incoming edge. If the previous node is a router, " - "the branch label must be specified for non-default branches. " - "Only if explicitly requested, this list can be empty." - ), + description="Action nodes executed after the trigger. 
Each node has one previous_node_ref.", ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py deleted file mode 100644 index 2a2dcac63a..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py +++ /dev/null @@ -1,390 +0,0 @@ -from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Tuple - -from django.contrib.auth.models import AbstractUser -from django.db import transaction -from django.utils.translation import gettext as _ - -import udspy -from loguru import logger -from pydantic import ConfigDict - -from baserow.contrib.automation.models import Automation -from baserow.contrib.automation.nodes.models import AutomationNode -from baserow.contrib.automation.nodes.registries import automation_node_type_registry -from baserow.contrib.automation.nodes.service import AutomationNodeService -from baserow.contrib.automation.workflows.models import AutomationWorkflow -from baserow.contrib.automation.workflows.service import AutomationWorkflowService -from baserow.core.formula import resolve_formula -from baserow.core.formula.registries import formula_runtime_function_registry -from baserow.core.formula.types import ( - BASEROW_FORMULA_MODE_ADVANCED, - BaserowFormulaObject, - FormulaContext, -) -from baserow.core.models import Workspace -from baserow.core.service import CoreService -from baserow.core.utils import to_path - -from .prompts import GENERATE_FORMULA_PROMPT -from .types import HasFormulasToCreateMixin, NodeBase, WorkflowCreate - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - - -def _minimize_json_schema(schema) -> dict[str, dict[str, str]]: - """ - Generate a mapping between field ids and names/types from a JSON schema. - Useful when generating formulas to understand the provided context. 
- """ - - field_type_descriptions = { - "link_row": "the row ID as number or the primary field value as string", - "single_select": "the option ID as number or the value as string", - "multiple_select": "a comma separated list of option IDs or values as string", - "date": "a date string in ISO 8601 format", - "date_time": "a date-time string in ISO 8601 format", - "boolean": "true or false", - } - field_type_extra_info = { - "single_select": lambda meta: { - "select_options": meta.get("select_options", []) - }, - "multiple_select": lambda meta: { - "select_options": meta.get("select_options", []) - }, - "multiple_collaborators": lambda meta: { - "available_collaborators": meta.get("available_collaborators", []) - }, - } - - if schema.get("type") == "array": - return _minimize_json_schema(schema.get("items")) - elif schema.get("type") != "object": - raise ValueError("Schema must be of type object or array of objects") - - properties = schema.get("properties", {}) - mapping = {} - for key, prop in properties.items(): - metadata = prop.get("metadata") - if metadata: - field_type = metadata["type"] - mapping[key] = { - "id": metadata["id"], - "name": metadata["name"], - "type": field_type, - "desc": field_type_descriptions.get(field_type, ""), - } - if field_type in field_type_extra_info: - get_extra_info = field_type_extra_info[field_type] - mapping[key].update(get_extra_info(metadata)) - return mapping - - -def _create_example_from_json_schema(schema) -> Tuple[dict, dict]: - """ - Generate example data from a JSON schema. - Useful when generating formulas to provide example context data. 
- """ - - examples = { - "string": "text", - "number": 1, - "boolean": True, - "null": None, - "object": lambda prop: _create_example_from_json_schema(prop), - "array": lambda prop: [_create_example_from_json_schema(prop["items"])], - } - - if schema.get("type") == "array": - return [_create_example_from_json_schema(schema.get("items"))] - elif schema.get("type") != "object": - raise ValueError("Schema must be of type object or array of objects") - - properties = schema.get("properties", {}) - example = {} - for key, prop in properties.items(): - value = examples[prop.get("type")] - if callable(value): - example[key] = value(prop) - else: - example[key] = value - return example - - -class AssistantFormulaContext(FormulaContext): - def __init__(self): - self.context = {} - self.context_metadata = {} - super().__init__() - - def add_node_context( - self, - node_id: int | str, - node_context: dict[str, any], - context_metadata: dict[str, dict[str, str]] | None = None, - ): - """Update the formula context with new values.""" - - self.context.update({str(node_id): node_context}) - if context_metadata: - self.context_metadata.update({str(node_id): context_metadata}) - - def get_formula_context(self) -> dict[str, any]: - return {"previous_node": self.context} - - def get_context_metadata(self) -> dict[str, any]: - return self.context_metadata - - def __getitem__(self, key) -> any: - start, *key_parts = to_path(key) - if start != "previous_node": - raise KeyError( - f"Key '{key}' not found in context. Only 'previous_node' is supported at the root level." 
- ) - value = self.context - for kp in key_parts: - try: - value = value[int(kp) if isinstance(value, list) else kp] - except (KeyError, TypeError, ValueError): - available_keys = ( - list(value.keys()) - if isinstance(value, dict) - else ", ".join(map(str, range(len(value)))) - ) - raise KeyError( - f"Key '{kp}' of '{key}' not found in {value}, Available keys: {available_keys}" - ) - if not isinstance(value, (int, float, str, bool, date, datetime)): - raise ValueError( - f"Value for key '{key}' is not a valid type. " - f"Expected int, float, str, bool, date, or datetime. " - f"Got {type(value).__name__} instead. " - f"Make sure to only reference primitive types in the formula context." - ) - return value - - -def get_generate_formulas_tool(): - class RuntimeFormulaGenerator(udspy.Signature): - __doc__ = GENERATE_FORMULA_PROMPT - - fields_to_resolve: dict[str, dict[str, str]] = udspy.InputField( - desc=( - "The fields that need formulas to be generated. " - "If prefixed with [optional], the field is not mandatory." - ) - ) - context: dict[str, Any] = udspy.InputField( - desc="The available context to use in formula generation composed of previous nodes results." - ) - context_metadata: dict[str, Any] = udspy.InputField( - desc="Metadata about the context fields, with refs and names to assist in formula generation." - ) - feedback: str = udspy.InputField( - desc="Validation errors from previous attempt. Empty if first attempt." 
- ) - generated_formulas: dict[str, Any] = udspy.OutputField() - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def check_formula(generated_formula: str, context: AssistantFormulaContext) -> str: - try: - resolve_formula( - BaserowFormulaObject.create( - formula=generated_formula, mode=BASEROW_FORMULA_MODE_ADVANCED - ), - formula_runtime_function_registry, - context, - ) - except Exception as exc: - raise ValueError(f"Generated formula is invalid: {str(exc)}") - return "ok, the formula is valid" - - def generate_node_formulas( - fields_to_resolve: dict, - context: AssistantFormulaContext, - max_retries: int = 3, - ) -> str: - """ - For every non-null input field in the node's schema, generate a formula - that fulfills the request, using the provided context object. - """ - - predict = udspy.Predict(RuntimeFormulaGenerator) - feedback = "" - for __ in range(max_retries): - result = predict( - fields_to_resolve=fields_to_resolve, - context=context.get_formula_context(), - context_metadata=context.get_context_metadata(), - feedback=feedback, - ) - # Ensure all the generated formulas are valid - valid_formulas = {} - generated_formulas = result.generated_formulas - for field_id, formula in generated_formulas.items(): - try: - check_formula(formula, context) - valid_formulas[field_id] = formula - except ValueError as exc: - feedback += f"Error for {field_id}, formula {formula} not valid: {str(exc)}\n" - - if len(valid_formulas) == len(generated_formulas): - return valid_formulas - - # Any valid formula is better than none - if valid_formulas: - return valid_formulas - else: - raise ValueError( - "Failed to generate any valid formulas after " - f"{max_retries} attempts. 
Feedback:\n{feedback}" - ) - - return generate_node_formulas - - -def get_automation( - automation_id: int, user: AbstractUser, workspace: Workspace -) -> Automation: - """Get automation with permission check.""" - - base_queryset = Automation.objects.filter(workspace=workspace) - automation = CoreService().get_application( - user, automation_id, base_queryset=base_queryset - ) - return automation - - -def get_workflow( - workflow_id: int, user: AbstractUser, workspace: Workspace -) -> AutomationWorkflow: - """Get workflow with permission check.""" - - workflow = AutomationWorkflowService().get_workflow(user, workflow_id) - if workflow.automation.workspace_id != workspace.id: - raise ValueError("Workflow not in workspace") - return workflow - - -def create_workflow( - user: AbstractUser, - automation: Automation, - workflow: "WorkflowCreate", - tool_helpers: "ToolHelpers", -) -> Tuple[AutomationWorkflow, dict[int | str, Any]]: - """ - Creates a new workflow in the given automation based on the provided definition. - """ - - tool_helpers.update_status( - _("Creating workflow '%(name)s'..." % {"name": workflow.name}) - ) - - orm_wf = AutomationWorkflowService().create_workflow( - user, automation.id, workflow.name - ) - - node_mapping = {} - - # First create the trigger node - orm_service_data = workflow.trigger.to_orm_service_dict() - node_type = automation_node_type_registry.get(workflow.trigger.type) - tool_helpers.update_status( - _("Creating trigger '%(label)s'..." 
% {"label": workflow.trigger.label}) - ) - orm_trigger = AutomationNodeService().create_node( - user, - node_type, - orm_wf, - label=workflow.trigger.label, - service=orm_service_data, - ) - - node_mapping[workflow.trigger.ref] = node_mapping[orm_trigger.id] = ( - orm_trigger, - workflow.trigger, - ) - - for node in workflow.nodes: - orm_service_data = node.to_orm_service_dict() - reference_node_id, output = node.to_orm_reference_node(node_mapping) - node_type = automation_node_type_registry.get(node.type) - tool_helpers.update_status( - _("Creating node '%(label)s'..." % {"label": node.label}) - ) - orm_node = AutomationNodeService().create_node( - user, - node_type, - orm_wf, - reference_node_id=reference_node_id, - output=output, - label=node.label, - service=orm_service_data, - ) - node_mapping[node.ref] = node_mapping[orm_node.id] = (orm_node, node) - - return orm_wf, node_mapping - - -def update_workflow_formulas( - workflow: "WorkflowCreate", - node_mapping: dict[int | str, Any], - tool_helpers: "ToolHelpers", -) -> None: - """ - Loop over all nodes and verify if they have formulas to update. If so, update the - formulas in the ORM node service providing the available context up to that node and - the user request for that node. 
- """ - - context = AssistantFormulaContext() - - def _get_service_schema(orm_node: AutomationNode): - return orm_node.service.get_type().generate_schema(orm_node.service.specific) - - def _update_context_with_node_data( - orm_node: AutomationNode, node_to_create: NodeBase - ): - schema = _get_service_schema(orm_node) - example = _create_example_from_json_schema(schema) - descr = _minimize_json_schema(schema) - descr["node_id"] = orm_node.id - descr["node_ref"] = node_to_create.ref - if getattr(node_to_create, "previous_node_ref", None): - descr["previous_node_ref"] = node_to_create.previous_node_ref - context.add_node_context(orm_node.id, example, descr) - - # Add the trigger context first - trigger_node = workflow.trigger - orm_trigger, __ = node_mapping[trigger_node.ref] - _update_context_with_node_data(orm_trigger, trigger_node) - - generate_formula_tool = get_generate_formulas_tool() - - def _generate_and_update_node_formulas( - node: HasFormulasToCreateMixin, orm_node: AutomationNode - ): - formulas_to_create = node.get_formulas_to_create(orm_node) - result = generate_formula_tool(formulas_to_create, context) - if result: - node.update_service_with_formulas(orm_node.service, result) - - # Node by node, generate formulas if needed and update the context with the node - # data, so following nodes can use it. - for node in workflow.nodes: - orm_node, __ = node_mapping[node.ref] - if isinstance(node, HasFormulasToCreateMixin): - tool_helpers.update_status( - _("Generating formulas for node '%(label)s'..." 
% {"label": node.label}) - ) - with transaction.atomic(): - try: - _generate_and_update_node_formulas(node, orm_node) - except Exception as exc: - logger.exception( - "Failed to generate formulas for node %s: %s", orm_node.id, exc - ) - _update_context_with_node_data(orm_node, node) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py new file mode 100644 index 0000000000..83b1ca9391 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py @@ -0,0 +1,15 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class CoreToolType(AssistantToolType): + type = "core" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import core_toolset + + return core_toolset diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py index 8d94e02e70..f8dc64e0e2 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py @@ -1,130 +1,173 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import Annotated, Any, Literal -from django.contrib.auth.models import AbstractUser from django.db import transaction from django.utils.translation import gettext as _ +from pydantic import Field +from pydantic_ai import RunContext +from pydantic_ai.toolsets import FunctionToolset + from baserow.core.actions import CreateApplicationActionType -from baserow.core.models import Workspace -from baserow.core.registries import application_type_registry from baserow.core.service import CoreService -from baserow_enterprise.assistant.tools.registries import AssistantToolType - -from .types import AnyBuilderItem, 
BuilderItem, BuilderItemCreate - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - +from baserow_enterprise.assistant.deps import AgentMode, AssistantDeps + +from .types import BuilderItem, BuilderItemCreate, builder_type_registry + + +def list_builders( + ctx: RunContext[AssistantDeps], + builder_types: Annotated[ + list[Literal["database", "application", "automation", "dashboard"]] | None, + Field( + description="Filter: only return builders of these types. null to return all types." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List databases, applications, automations, dashboards in the workspace. + + WHEN to use: You need to find databases, applications, automations, or dashboards in the workspace. Call this before creating builders to avoid duplicates. + WHAT it does: Lists all builders the user can access, optionally filtered by type. Max 20 results. + RETURNS: Dict of builders grouped by type, each with id, name, type. + DO NOT USE when: You already know the builder ID you need. 
+ """ -def get_list_builders_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[], list[AnyBuilderItem]]: + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + + tool_helpers.update_status( + _("Listing %(builder_types)ss...") + % { + "builder_types": builder_types[0] + if builder_types and len(builder_types) == 1 + else "builder" + } + ) + + applications_qs = CoreService().list_applications_in_workspace( + user, workspace, specific=False + ) + + builders = {} + for app in applications_qs: + try: + item = builder_type_registry.from_django_orm(app) + except KeyError: + continue + if not builder_types or item.type in builder_types: + builders.setdefault(item.type, []).append(item.model_dump()) + + if not builders: + return {} + + total = sum(len(v) for v in builders.values()) + max_items = 20 + if total > max_items: + truncated = {} + remaining = max_items + for btype, items in builders.items(): + truncated[btype] = items[:remaining] + remaining -= len(truncated[btype]) + if remaining <= 0: + break + return { + **truncated, + "_info": f"Showing {max_items} of {total} builders. " + "Use builder_types to filter.", + } + + return builders + + +def create_builders( + ctx: RunContext[AssistantDeps], + builders: Annotated[ + list[BuilderItemCreate], + Field(description="List of builders to create, each with a name and type."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Create a new database, application, or automation. + + WHEN to use: User wants a new database, application, or automation created in the workspace. + WHAT it does: Creates one or more builders with the specified names and types. + RETURNS: List of created builders with id, name, type. + DO NOT USE when: A builder with that name may already exist — check with list_builders first. + HOW: Pick a unique, descriptive name. 
Check existing builders with list_builders to avoid duplicates. """ - Returns a function that lists all the builders the user has access to in the - current workspace. + + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + + created_builders = [] + with transaction.atomic(): + for builder in builders: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Creating %(builder_type)s %(builder_name)s...") + % {"builder_type": builder.type, "builder_name": builder.name} + ) + builder_orm_instance = CreateApplicationActionType.do( + user, workspace, builder.get_orm_type(), name=builder.name + ) + builder.post_creation_hook(user, builder_orm_instance) + created_builders.append( + BuilderItem( + id=builder_orm_instance.id, + name=builder_orm_instance.name, + type=builder.type, + ).model_dump() + ) + + return {"created_builders": created_builders} + + +def switch_mode( + ctx: RunContext[AssistantDeps], + mode: Annotated[ + Literal["database", "application", "automation", "explain"], + Field( + description=( + "Target mode: 'database' for table/field/view/row ops, " + "'application' for page/element/data-source ops, " + "'automation' for workflow/node ops, " + "'explain' for answering Baserow questions." + ) + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> str: + """\ + Switch between domain modes (database, application, automation, explain). + + WHEN to use: Task needs tools from a different domain, or user asks a how-to question (→ "explain"). + WHAT it does: Changes the available toolset to the target domain's tools. + RETURNS: Confirmation of mode switch. + DO NOT USE when: Already in the requested mode. 
""" - def list_builders( - builder_types: list[ - Literal["database", "application", "automation", "dashboard"] - ] - | None = None, - ) -> list[AnyBuilderItem] | str: - """ - Lists all the builders the user can access (databases, applications, - automations, dashboards) in the current workspace. - - If `builder_types` is provided, only builders of that type are returned, - otherwise all builders are returned (default). - """ - - nonlocal user, workspace, tool_helpers - - tool_helpers.update_status( - _("Listing %(builder_types)ss...") - % { - "builder_types": builder_types[0] - if builder_types and len(builder_types) == 1 - else "builder" - } - ) + target = AgentMode(mode) + if ctx.deps.mode == target: + return f"Already in {target.value} mode." - applications_qs = CoreService().list_applications_in_workspace( - user, workspace, specific=False + ctx.deps.mode = target + if target == AgentMode.EXPLAIN: + return ( + "Switched to explain mode. " + "Call search_user_docs now to answer the user's question from the Baserow documentation." ) + return f"Switched to {target.value} mode." 
- builders = {} - for builder in applications_qs: - builder_type = application_type_registry.get_by_model( - builder.specific_class - ).type - if not builder_types or builder_type in builder_types: - builders.setdefault(builder_type, []).append( - BuilderItem( - id=builder.id, name=builder.name, type=builder_type - ).model_dump() - ) - - return builders if builders else "no builders found" - - return list_builders - - -class ListBuildersToolType(AssistantToolType): - type = "list_builders" - - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_builders_tool(user, workspace, tool_helpers) - - -def get_create_modules_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[str], dict[str, Any]]: - """ - Returns a function that creates a module in the current workspace. - """ - def create_builders(builders: list[BuilderItemCreate]) -> dict[str, Any]: - """ - Create a builder in the current workspace and return its ID and name. 
- - - name: desired name for the builder (better if unique in the workspace) - """ - - nonlocal user, workspace, tool_helpers - - created_builders = [] - with transaction.atomic(): - for builder in builders: - tool_helpers.update_status( - _("Creating %(builder_type)s %(builder_name)s...") - % {"builder_type": builder.type, "builder_name": builder.name} - ) - builder_orm_instance = CreateApplicationActionType.do( - user, workspace, builder.get_orm_type(), name=builder.name - ) - builder.post_creation_hook(user, builder_orm_instance) - created_builders.append( - BuilderItem( - id=builder_orm_instance.id, - name=builder_orm_instance.name, - type=builder.type, - ).model_dump() - ) - - return {"created_builders": created_builders} - - return create_builders - - -class CreateBuildersToolType(AssistantToolType): - type = "create_builders" - - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_create_modules_tool(user, workspace, tool_helpers) +TOOL_FUNCTIONS = [list_builders, create_builders, switch_mode] +core_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py index 87183d68dc..0612620349 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py @@ -30,10 +30,14 @@ def get_orm_type(self) -> str: def from_django_orm(cls, orm_app: BaserowApplication) -> "BuilderItem": """Creates a BuilderItem instance from a Django ORM Application instance.""" + orm_type = application_type_registry.get_by_model(orm_app.specific_class).type + # The application_type_registry uses "builder" internally, but our + # Literal type expects "application". 
+ type_mapping = {"builder": "application"} return cls( id=orm_app.id, name=orm_app.name, - type=application_type_registry.get_by_model(orm_app.specific_class).type, + type=type_mapping.get(orm_type, orm_type), ) def _post_creation_hook(self, user, builder_orm_instance): diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py new file mode 100644 index 0000000000..a92d7c2000 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py @@ -0,0 +1,284 @@ +from typing import Any, Callable + +from django.contrib.auth.models import AbstractUser +from django.utils.translation import gettext as _ + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic_ai import Agent, Tool +from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.usage import UsageLimits + +from baserow.contrib.database.api.formula.serializers import TypeFormulaResultSerializer +from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.models import FormulaField +from baserow.core.models import Workspace +from baserow_premium.prompts import get_formula_docs + +from . import helpers +from .prompts import ( + FORMULA_AGENT_INSTRUCTIONS, + SAMPLE_ROW_AGENT_INSTRUCTIONS, + format_formula_fixer_prompt, + format_sample_rows_prompt, +) + +# --------------------------------------------------------------------------- +# Formula generation agent +# --------------------------------------------------------------------------- + + +class FormulaGenerationResult(PydanticBaseModel): + """Output model for the formula generation agent.""" + + table_id: int = Field( + description=( + "The ID of the table the formula is intended for. " + "Should be the same as current_table_id, unless the formula can " + "only be created in a different table." 
+ ) + ) + field_name: str = Field( + description="The name of the formula field to be created. For a new field, it must be unique in the table." + ) + formula: str = Field( + description="The generated formula. Must be a valid Baserow formula." + ) + formula_type: str = Field( + description=( + "The type of the generated formula. Must be one of: text, long_text, " + "number, boolean, date, link_row, single_select, multiple_select, duration, array." + ) + ) + is_formula_valid: bool = Field( + description="Whether the generated formula is valid or not." + ) + error_message: str = Field( + default="", + description="If the formula is not valid, an error message explaining why.", + ) + + +formula_generation_agent: Agent[None, FormulaGenerationResult] = Agent( + output_type=FormulaGenerationResult, + instructions=FORMULA_AGENT_INSTRUCTIONS, + name="formula_generation_agent", +) + + +def get_formula_type_tool( + user: AbstractUser, workspace: Workspace +) -> Callable[[str], str]: + """ + Returns a function that validates a formula and returns its type. + """ + + def get_formula_type(table_id: int, field_name: str, formula: str) -> str: + """ + Returns the type of a formula. Raises an exception if the formula + is not valid. + **ALWAYS** call this to validate a formula is valid before returning it. + """ + + nonlocal user, workspace + + table = helpers.filter_tables(user, workspace).filter(id=table_id).first() + if not table: + raise ValueError(f"Table with ID {table_id} not found in workspace.") + + field = FormulaField(formula=formula, table=table, name=field_name, order=0) + field.recalculate_internal_fields(raise_if_invalid=True) + + result = TypeFormulaResultSerializer(field).data + if result["error"]: + field_names = list( + FieldHandler() + .get_base_fields_queryset() + .filter(table=table) + .values_list("name", flat=True) + ) + raise TypeError( + f"Invalid formula: {result['error']}. 
" + f"Available fields in table '{table.name}': {', '.join(field_names)}" + ) + + return result["formula_type"] + + return get_formula_type + + +def make_formula_fixer( + user: AbstractUser, workspace: Workspace, tool_helpers +) -> Callable: + """ + Returns a callback that tries to auto-generate a valid formula when the + LLM-provided one is invalid. Uses the ``formula_generation_agent``. + """ + + def fix_formula(table, field_name: str, original_formula: str) -> str | None: + database_tables = helpers.filter_tables(user, workspace).filter( + database_id=table.database_id + ) + schema = [ + t.model_dump() for t in helpers.get_tables_schema(database_tables, True) + ] + tool_helpers.update_status( + _("Fixing formula for %(name)s...") % {"name": field_name} + ) + + formula_type_tool = Tool(get_formula_type_tool(user, workspace)) + formula_toolset = FunctionToolset([formula_type_tool]) + prompt = format_formula_fixer_prompt( + field_name, original_formula, schema, get_formula_docs() + ) + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + get_model_string, + ) + + model = get_model_string() + result = formula_generation_agent.run_sync( + prompt, + model=model, + model_settings=get_model_settings(model, UTILITY), + toolsets=[formula_toolset], + usage_limits=UsageLimits(request_limit=20), + ) + if result.output.is_formula_valid: + return result.output.formula + return None + + return fix_formula + + +# --------------------------------------------------------------------------- +# Sample-row generation agent +# --------------------------------------------------------------------------- + + +def _find_reverse_link_row_fields(tables: list) -> dict[int, set[int]]: + """ + Identify auto-created reverse link_row fields across a set of tables. + + When a link_row field is created between two tables, Baserow auto-creates + a reverse field on the linked table. 
For sample-row generation we only + want the "owning" side (the explicitly created field) so the agent doesn't + face circular dependencies. + + For any bidirectional pair the field with the **higher** ID is the + auto-created reverse (it's created immediately after the explicit one). + + :returns: ``{table_id: {field_id, ...}}`` of reverse field IDs to exclude. + """ + + from baserow.contrib.database.fields.models import LinkRowField + + table_ids = {t.id for t in tables} + link_fields = LinkRowField.objects.filter( + table_id__in=table_ids, link_row_table_id__in=table_ids + ).select_related("link_row_related_field") + + reverse_ids: dict[int, set[int]] = {} + seen_pairs: set[tuple[int, int]] = set() + + for lf in link_fields: + related = lf.link_row_related_field + if related is None: + continue + pair = (min(lf.id, related.id), max(lf.id, related.id)) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + + # The field with the higher ID is the auto-created reverse. + reverse = lf if lf.id > related.id else related + reverse_ids.setdefault(reverse.table_id, set()).add(reverse.id) + + return reverse_ids + + +def generate_sample_rows( + user: AbstractUser, + workspace: Workspace, + tool_helpers, + created_tables: list, + data_brief: str | None = None, +) -> dict[int, list[Any]]: + """ + Use an agent with ``create_rows`` tools to generate and insert + realistic sample rows for newly created tables. + + Instead of building one giant structured-output schema for all tables, + this gives the agent a ``create_rows_in_table_`` tool per table. + The agent decides the insertion order itself — it naturally creates + rows in linked-to tables first, sees the returned row IDs, and uses + them in link_row fields of dependent tables. 
+ """ + + from baserow_enterprise.assistant.model_profiles import ( + SAMPLE, + get_model_settings, + get_model_string, + ) + + from .tools import _build_row_tools + + tool_helpers.update_status(_("Generating example rows for these new tables...")) + + # Build a create_rows tool for every table in the database (not just + # the newly created ones) so link_row fields can reference rows in + # pre-existing tables too. + database = created_tables[0].database + all_db_tables = list(database.table_set.all()) + + # Identify reverse (auto-created) link_row fields to exclude from the + # create schema. When a link_row is created between two tables in the + # same batch, Baserow auto-creates a reverse field. Including both + # sides creates a circular dependency the sample-row agent cannot + # resolve. For any bidirectional pair, the field with the higher ID + # is the auto-created reverse — we exclude it. + reverse_field_ids = _find_reverse_link_row_fields(all_db_tables) + + create_tools = [] + for table in all_db_tables: + # Exclude reverse link_row fields for this table + exclude = reverse_field_ids.get(table.id) + field_ids = None + if exclude: + all_field_ids = [ + fo["field"].id for fo in table.get_model().get_field_objects() + ] + field_ids = [fid for fid in all_field_ids if fid not in exclude] + row_tools = _build_row_tools( + user, workspace, tool_helpers, table, field_ids=field_ids + ) + create_tools.append(row_tools["create"]) + + # Build a description of each table so the agent knows the schemas. 
+ schemas = helpers.get_tables_schema(created_tables, full_schema=True) + table_info = "\n".join(f"- {schema.model_dump()}" for schema in schemas) + + model = get_model_string() + sample_row_agent = Agent( + output_type=str, + instructions=SAMPLE_ROW_AGENT_INSTRUCTIONS, + tools=create_tools, + name="sample_row_agent", + ) + sample_row_agent.run_sync( + format_sample_rows_prompt(table_info, data_brief=data_brief), + model=model, + model_settings=get_model_settings(model, SAMPLE), + usage_limits=UsageLimits(request_limit=len(all_db_tables) * 3 + 2), + ) + + # Collect the rows that were actually inserted. + rows_created: dict[int, list] = {} + for table in created_tables: + table_model = table.get_model() + rows = list(table_model.objects.all()) + if rows: + rows_created[table.id] = rows + + return rows_created diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py new file mode 100644 index 0000000000..1652f2a207 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py @@ -0,0 +1,388 @@ +""" +Shared helpers for the database assistant tools. + +Contains query helpers, schema builders, and action orchestration used by +``tools.py`` and ``agents.py``. 
+""" + +from itertools import groupby +from typing import TYPE_CHECKING, Any, Callable + +from django.contrib.auth.models import AbstractUser +from django.db.models import Q, QuerySet +from django.utils.translation import gettext as _ + +from baserow.contrib.database.fields.actions import ( + CreateFieldActionType, + DeleteFieldActionType, + UpdateFieldActionType, +) +from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.models import Field +from baserow.contrib.database.fields.registries import field_type_registry +from baserow.contrib.database.table.handler import TableHandler +from baserow.contrib.database.table.models import Table +from baserow.contrib.database.views.actions import CreateViewFilterActionType +from baserow.contrib.database.views.handler import ViewHandler +from baserow.contrib.database.views.models import View, ViewFilter +from baserow.core.db import specific_iterator +from baserow.core.models import Workspace +from baserow_enterprise.assistant.tools.database.types.table import TableItem + +from .types import ( + AnyViewFilterItemCreate, + FieldItem, + FieldItemCreate, + FieldItemUpdate, + InvalidFormulaFieldError, +) + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import ToolHelpers + + +class ToolInputError(Exception): + """Raised when tool input is invalid — returned to the model as an error message.""" + + +def filter_tables(user: AbstractUser, workspace: Workspace) -> QuerySet[Table]: + """Return all tables visible to the user in the given workspace.""" + + return TableHandler().list_workspace_tables(user, workspace) + + +def get_table(user: AbstractUser, workspace: Workspace, table_id: int) -> Table: + """Get a single table by ID, raising ToolInputError if not found.""" + + try: + return filter_tables(user, workspace).get(id=table_id) + except Table.DoesNotExist: + raise ToolInputError( + f"Table with ID {table_id} not found. " + "Use get_tables_schema to find valid table IDs." 
+ ) + + +def get_tables_schema( + tables: list[Table], + full_schema: bool = False, +) -> list[TableItem]: + """ + Build serialised schema descriptions for the given tables. + + :param tables: Tables to describe. + :param full_schema: If True include all fields, otherwise only primary + fields and relationships. + :returns: List of table descriptions, in the same order as the input tables. + """ + + q = Q(table__in=tables) + if not full_schema: + q &= Q(linkrowfield__isnull=False) | Q(primary=True) + + base_field_queryset = FieldHandler().get_base_fields_queryset() + fields = specific_iterator( + base_field_queryset.filter(q).order_by("table_id", "order"), + per_content_type_queryset_hook=( + lambda field, queryset: field_type_registry.get_by_model( + field + ).enhance_field_queryset(queryset, field) + ), + ) + + table_items: list[TableItem] = [] + tables_by_id = {table.id: table for table in tables} + for table_id, fields_in_table in groupby(fields, lambda f: f.table_id): + table_items.append(_get_table_schema(tables_by_id, table_id, fields_in_table)) + + # Preserve the input order + input_order = {t.id: i for i, t in enumerate(tables)} + table_items.sort(key=lambda t: input_order[t.id]) + return table_items + + +def _get_table_schema( + tables_by_id: dict[int, Table], table_id: int, fields_in_table: list[Field] +) -> TableItem: + """ + Build a TableItem schema description for a single table given its fields. + + :param tables_by_id: Mapping of table ID → table instance for all tables. + :param table_id: ID of the table to describe. + :param fields_in_table: Iterable of field instances belonging to the table. + :returns: TableItem describing the table and its fields. 
+ """ + + fields_in_table = list(fields_in_table) + primary_field = next((f for f in fields_in_table if f.primary), None) + if primary_field is None: + raise ValueError(f"Table {table_id} has no primary field") + primary_field_item = FieldItem.from_django_orm(primary_field) + + table = tables_by_id[table_id] + + return TableItem( + id=table_id, + name=table.name, + primary_field=primary_field_item, + fields=[ + FieldItem.from_django_orm(f) + for f in fields_in_table + if f.id != primary_field.id + ], + ) + + +def create_fields( + user: AbstractUser, + table: Table, + field_items: list[FieldItemCreate], + tool_helpers: "ToolHelpers", + formula_fixer: Callable[[Table, str, str], str | None] | None = None, +) -> tuple[list[FieldItem], list[str], list[dict]]: + """ + Create fields in a table, handling formula errors with optional auto-fix. + + Fields are sorted so that dependencies are satisfied: regular fields first, + then link_row, lookup, and formula last. + + :param user: The acting user. + :param table: Target table. + :param field_items: Field definitions to create. + :param tool_helpers: Provides status updates and cancellation. + :param formula_fixer: Optional callback ``(table, name, formula) -> fixed`` + invoked when a formula field fails validation. + :returns: Tuple of (created fields, field error messages, formula error dicts). + """ + + from .types import InvalidFormulaFieldError + from .types.fields import FIELD_ORDER + + created_fields: list[FieldItem] = [] + formula_errors: list[dict] = [] + field_errors: list[str] = [] + + # Creation order: regular → link_row → lookup → formula. + # link_row before lookup so auto-created links exist for lookups. + # formula last so they can reference fields created earlier. 
+ field_items = sorted(field_items, key=lambda f: FIELD_ORDER.get(f.type, 0)) + + for field_item in field_items: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Creating field %(field_name)s...") % {"field_name": field_item.name} + ) + + try: + new_field = CreateFieldActionType.do( + user, + table, + field_item.type, + **field_item.to_django_orm_kwargs(table, user=user), + ) + created_fields.append(FieldItem.from_django_orm(new_field)) + except InvalidFormulaFieldError as exc: + _fix_formula_field( + user, table, formula_fixer, created_fields, formula_errors, exc + ) + except Exception as e: + field_errors.append( + f"Error creating field {field_item.name} in table_{table.id}: {e}.\n" + f"Please retry recreating this field later, if important." + ) + return created_fields, field_errors, formula_errors + + +def _fix_formula_field( + user: AbstractUser, + table: Table, + formula_fixer: Callable[[Table, str, str], str | None] | None, + created_fields: list[FieldItem], + formula_errors: list[dict], + exc: InvalidFormulaFieldError, +): + """ + Attempt to fix an invalid formula field using the provided formula_fixer callback. + If successful, creates the field with the fixed formula. Otherwise, records the error. + + :param user: The acting user. + :param table: The table the field belongs to. + :param formula_fixer: Callback to attempt formula fixing. + :param created_fields: List to append successfully created fields to. + :param formula_errors: List to append error details to if fixing fails. + :param exc: The exception containing details about the invalid formula. 
+ """ + + fixed = False + if formula_fixer: + try: + new_formula = formula_fixer(exc.table, exc.field_name, exc.formula) + if new_formula: + new_field = CreateFieldActionType.do( + user, + table, + "formula", + name=exc.field_name, + formula=new_formula, + ) + created_fields.append(FieldItem.from_django_orm(new_field)) + fixed = True + except Exception: + pass + if not fixed: + formula_errors.append( + { + "field_name": exc.field_name, + "formula": exc.formula, + "error": exc.error, + } + ) + + +def get_view(user: AbstractUser, workspace: Workspace, view_id: int) -> View: + """ + Fetch a view scoped to the user's workspace. + + :param user: The acting user. + :param workspace: Workspace the view must belong to. + :param view_id: ID of the view to retrieve. + """ + + return ViewHandler().get_view_as_user( + user, + view_id, + base_queryset=View.objects.filter(table__database__workspace=workspace), + ) + + +def create_view_filter( + user: AbstractUser, + orm_view: View, + table_fields: dict[int, Any], + view_filter_item: AnyViewFilterItemCreate, +) -> ViewFilter: + """ + Create a single view filter after validating the field type matches. + + :param user: The acting user. + :param orm_view: The view to add the filter to. + :param table_fields: Mapping of field ID → field instance for the table. + :param view_filter_item: The filter definition to create. + :raises ValueError: If the field is not found or its type doesn't match. + """ + + field = table_fields.get(view_filter_item.field_id) + if field is None: + raise ValueError( + f"Field {view_filter_item.field_id} not found for filter. 
" + f"Available field IDs: {sorted(table_fields.keys())}" + ) + field_type = field_type_registry.get_by_model(field.specific_class) + if field_type.type != view_filter_item.type: + raise ValueError( + f"Field '{field.name}' (id={field.id}) is type '{field_type.type}', " + f"but filter declared type '{view_filter_item.type}'" + ) + + filter_type = view_filter_item.get_django_orm_type(field) + filter_value = view_filter_item.get_django_orm_value( + field, timezone=user.profile.timezone + ) + + return CreateViewFilterActionType.do( + user, + orm_view, + field, + filter_type, + filter_value, + filter_group_id=None, + ) + + +def update_field( + user: AbstractUser, + workspace: Workspace, + field_update: "FieldItemUpdate", + formula_fixer: Callable[[Table, str, str], str | None] | None = None, +) -> FieldItem: + """ + Update an existing field. + + :param user: The acting user. + :param workspace: Workspace the field must belong to. + :param field_update: The update definition. + :param formula_fixer: Optional callback for fixing invalid formulas. + :returns: Updated field as FieldItem. 
+ """ + + base_field = FieldHandler().get_field(field_update.field_id) + field = base_field.specific + + # Verify workspace access + filter_tables(user, workspace).filter(id=base_field.table_id).get() + field_type = field_type_registry.get_by_model(field).type + kwargs = field_update.to_update_kwargs(field_type) + + if not kwargs: + return FieldItem.from_django_orm(field) + + # Validate formula before updating + if "formula" in kwargs and kwargs["formula"]: + from baserow.contrib.database.fields.models import FormulaField + from baserow.core.formula.parser.exceptions import BaserowFormulaException + + try: + tmp = FormulaField( + formula=kwargs["formula"], + table=field.table, + name=kwargs.get("name", field.name), + order=0, + ) + tmp.recalculate_internal_fields(raise_if_invalid=True) + except BaserowFormulaException as e: + if formula_fixer: + fixed = formula_fixer( + field.table, + kwargs.get("name", field.name), + kwargs["formula"], + ) + if fixed: + kwargs["formula"] = fixed + else: + raise InvalidFormulaFieldError( + kwargs.get("name", field.name), + kwargs["formula"], + field.table, + str(e), + ) + else: + raise InvalidFormulaFieldError( + kwargs.get("name", field.name), + kwargs["formula"], + field.table, + str(e), + ) + + UpdateFieldActionType.do(user, field, **kwargs) + # Re-fetch the specific field to get the updated state + updated_field = FieldHandler().get_field(field_update.field_id).specific + return FieldItem.from_django_orm(updated_field) + + +def delete_field( + user: AbstractUser, + workspace: Workspace, + field_id: int, +) -> None: + """ + Delete (soft-delete / trash) a field. + + :param user: The acting user. + :param workspace: Workspace the field must belong to. + :param field_id: ID of the field to delete. 
+ """ + + base_field = FieldHandler().get_field(field_id) + # Verify workspace access + filter_tables(user, workspace).filter(id=base_field.table_id).get() + DeleteFieldActionType.do(user, base_field.specific) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py new file mode 100644 index 0000000000..cc1808ef37 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py @@ -0,0 +1,64 @@ +""" +Prompt strings and templates for database sub-agents. +""" + +# --------------------------------------------------------------------------- +# Agent instructions +# --------------------------------------------------------------------------- + +FORMULA_AGENT_INSTRUCTIONS = ( + "Generates a Baserow formula based on the provided description and table schema. " + "Always validate the formula using the get_formula_type tool before returning it." +) + +SAMPLE_ROW_AGENT_INSTRUCTIONS = ( + "Create 5 realistic sample rows for each table using the " + "create_rows tools provided. " + "IMPORTANT: Fill EVERY field for every row. Do NOT leave any field " + "empty or null unless the data genuinely requires it. " + "Insertion order: start with tables that have NO link_row fields, " + "so you have real row IDs to reference. " + "Then create rows in dependent tables, using those IDs in link_row fields. " + "Reply with a short summary when done." 
+) + +# --------------------------------------------------------------------------- +# Prompt formatters +# --------------------------------------------------------------------------- + + +def format_formula_fixer_prompt( + field_name: str, + original_formula: str, + schema: list[dict], + formula_docs: str, +) -> str: + return ( + f"Fix this formula for field '{field_name}': {original_formula}\n\n" + f"Tables schema: {schema}\n\n" + f"Formula documentation: {formula_docs}" + ) + + +def format_formula_generation_prompt( + description: str, + schema: list[dict], + formula_docs: str, +) -> str: + return ( + f"Description: {description}\n\n" + f"Tables schema: {schema}\n\n" + f"Formula documentation: {formula_docs}" + ) + + +def format_sample_rows_prompt(table_info: str, data_brief: str | None = None) -> str: + prompt = ( + f"Create 5 sample rows for each of these tables:\n{table_info}" + "\n\nREMINDER: Fill ALL fields for every row — especially link_row " + "(relationship) fields. Use the row IDs returned by previous " + "create_rows calls as values for link_row fields in dependent tables." 
+ ) + if data_brief: + prompt += f"\n\nUser instructions for the data: {data_brief}" + return prompt diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py new file mode 100644 index 0000000000..2ee9cbd5cc --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py @@ -0,0 +1,20 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class DatabaseToolType(AssistantToolType): + type = "database" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import database_toolset + + return database_toolset + + def get_routing_rules(self): + from .tools import ROUTING_RULES + + return ROUTING_RULES diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py index 56d2ed7a37..08f221d1cf 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py @@ -1,23 +1,29 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple +from typing import TYPE_CHECKING, Annotated, Any, Literal from django.contrib.auth.models import AbstractUser from django.db import transaction from django.utils.translation import gettext as _ -import udspy from loguru import logger -from pydantic import create_model +from pydantic import Field, create_model +from pydantic_ai import RunContext, Tool +from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.usage import UsageLimits -from baserow.contrib.database.api.formula.serializers import TypeFormulaResultSerializer from baserow.contrib.database.fields.actions import ( CreateFieldActionType, DeleteFieldActionType, UpdateFieldActionType, ) -from 
baserow.contrib.database.fields.models import FormulaField from baserow.contrib.database.fields.registries import field_type_registry from baserow.contrib.database.models import Database +from baserow.contrib.database.rows.actions import ( + CreateRowsActionType, + DeleteRowsActionType, + UpdateRowsActionType, +) from baserow.contrib.database.table.actions import CreateTableActionType +from baserow.contrib.database.table.models import Table from baserow.contrib.database.views.actions import ( CreateViewActionType, UpdateViewFieldOptionsActionType, @@ -25,848 +31,1200 @@ from baserow.contrib.database.views.handler import ViewHandler from baserow.core.models import Workspace from baserow.core.service import CoreService -from baserow_enterprise.assistant.tools.registries import AssistantToolType +from baserow_enterprise.assistant.deps import AssistantDeps +from baserow_enterprise.assistant.tools.toolset import inline_refs from baserow_enterprise.assistant.types import TableNavigationType, ViewNavigationType from baserow_premium.prompts import get_formula_docs -from . import utils +from . 
import helpers +from .agents import ( + formula_generation_agent, + generate_sample_rows, + get_formula_type_tool, + make_formula_fixer, +) +from .prompts import format_formula_generation_prompt from .types import ( - AnyFieldItem, - AnyFieldItemCreate, - AnyViewFilterItem, - AnyViewItemCreate, - BaseTableItem, + FieldItemCreate, + FieldItemUpdate, ListTablesFilterArg, TableItemCreate, ViewFiltersArgs, - view_item_registry, + ViewItem, + ViewItemCreate, + get_create_row_model, + get_link_row_hints, + get_update_row_model, ) if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers + from baserow_enterprise.assistant.deps import ToolHelpers +MAX_HINT_TABLES = 10 -def get_list_tables_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], list[str]]: - """ - Returns a function that lists all the tables in a given database the user has - access to in the current workspace. - """ - def list_tables(filters: ListTablesFilterArg) -> list[dict[str, Any]]: - """ - List tables that verifies the filters +def _no_tables_found_hint( + user: AbstractUser, workspace: Workspace, filters: "ListTablesFilterArg" +) -> str: + """Build an informative message when no tables match the filters. - - Always call this before creating new tables to avoid duplicates. - - Always call this to link existing tables when table IDs are not known. - """ + When the caller supplied a ``database_id`` that doesn't correspond to any + real database in the workspace, say so explicitly and list the first + available tables so the model can self-correct. + """ - nonlocal user, workspace, tool_helpers + parts: list[str] = [] - tables = ( - utils.filter_tables(user, workspace) - .filter(filters.to_orm_filter()) - .select_related("database") + # Check whether the requested database actually exists. 
+ db_ref = filters.database_id_or_name + if db_ref is not None: + if isinstance(db_ref, int): + db_exists = Database.objects.filter(workspace=workspace, id=db_ref).exists() + else: + db_exists = Database.objects.filter( + workspace=workspace, name__icontains=db_ref + ).exists() + if not db_exists: + parts.append( + f"No database matching '{db_ref}' exists in this " + f"workspace. Note: workspace, application, and database IDs " + f"are different — make sure you are using a database ID." + ) + else: + parts.append( + f"Database '{db_ref}' exists but has no tables " + f"matching the provided filters." + ) + else: + parts.append("No tables found matching the provided filters.") + + # Fetch a sample of available tables across the workspace. + all_tables = ( + helpers.filter_tables(user, workspace) + .select_related("database") + .order_by("database_id", "id") + ) + total_tables = all_tables.count() + + if total_tables == 0: + parts.append("This workspace has no database tables at all.") + return " ".join(parts) + + sample = all_tables[:MAX_HINT_TABLES] + db_ids_seen: set[int] = set() + table_lines: list[str] = [] + for t in sample: + db_ids_seen.add(t.database_id) + table_lines.append( + f' - table_id={t.id}, table_name="{t.name}", ' + f'database_id={t.database_id}, database_name="{t.database.name}"' ) - databases = {} - database_names = [] - for table in tables: - if table.database_id not in databases: - databases[table.database_id] = { - "id": table.database_id, - "name": table.database.name, - "tables": [], - } - database_names.append(table.database.name) - databases[table.database_id]["tables"].append( - { - "id": table.id, - "name": table.name, - "database_id": table.database_id, - } - ) + total_dbs = Database.objects.filter(workspace=workspace).count() - tool_helpers.update_status( - _("Listing tables in %(database_names)s...") - % {"database_names": ", ".join(database_names)} + parts.append( + f"Available tables ({total_tables} table(s) across " + f"{total_dbs} 
database(s) in this workspace):" + ) + parts.append("\n".join(table_lines)) + + remaining_tables = total_tables - len(sample) + remaining_dbs = total_dbs - len(db_ids_seen) + if remaining_tables > 0: + parts.append( + f" ... and {remaining_tables} more table(s) in " + f"{remaining_dbs} more database(s)." ) - if len(databases) == 0: - return "No tables found" - elif len(databases) == 1: - # Return just the tables array when there's only one database - return list(databases.values())[0]["tables"] - else: - return list(databases.values()) - - return list_tables + return "\n".join(parts) -class ListTablesToolType(AssistantToolType): - type = "list_tables" - thinking_message = "Looking for tables..." +# --------------------------------------------------------------------------- +# Tool 1: list_tables +# --------------------------------------------------------------------------- - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_tables_tool(user, workspace, tool_helpers) +def list_tables( + ctx: RunContext[AssistantDeps], + filters: Annotated[ + ListTablesFilterArg, + Field(description="Filter criteria to narrow down which tables to list."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> list[dict[str, Any]] | dict[str, Any]: + """\ + List tables, optionally filtered by database or name. -def get_tables_schema_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], list[str]]: - """ - Returns a function that lists all the fields in a given table the user has - access to in the current workspace. + WHEN to use: Before creating tables (to avoid duplicates), when you need table IDs, or to discover what tables exist in the workspace. + WHAT it does: Lists tables matching the filter criteria (database_id, name, starred), grouped by database. 
+ RETURNS: Tables with id, name, database_id. Includes a hint with available tables if no match found. + DO NOT USE when: You already have the table IDs you need. """ - def get_tables_schema( - table_ids: list[int], - full_schema: bool, - ) -> list[dict[str, Any]]: - """ - Returns the schema of the specified tables, including their fields if requested. - Use `full_schema=True` to get all the fields, otherwise only the table names, - IDs, primary keys, and relationships will be included. + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - When to use: - Understanding table structure before creating/modifying fields - - Checking existing field names to avoid duplicates - Understanding table - relationships when creating link_row fields + tables = ( + helpers.filter_tables(user, workspace) + .filter(filters.to_orm_filter()) + .select_related("database") + ) - Remember: - Always call this before creating fields to avoid duplicate names - - Use get_rows_tools() for any row operations - not this one - """ + databases = {} + database_names = [] + for table in tables: + if table.database_id not in databases: + databases[table.database_id] = { + "id": table.database_id, + "name": table.database.name, + "tables": [], + } + database_names.append(table.database.name) + databases[table.database_id]["tables"].append( + { + "id": table.id, + "name": table.name, + "database_id": table.database_id, + } + ) - nonlocal user, workspace, tool_helpers + tool_helpers.update_status( + _("Listing tables in %(database_names)s...") + % {"database_names": ", ".join(database_names)} + ) - if not table_ids: - return [] + if len(databases) == 0: + return {"tables": [], "_info": _no_tables_found_hint(user, workspace, filters)} + elif len(databases) == 1: + # Return just the tables array when there's only one database + return list(databases.values())[0]["tables"] + else: + return list(databases.values()) + + +# 
--------------------------------------------------------------------------- +# Tool 2: get_tables_schema +# --------------------------------------------------------------------------- + + +def get_tables_schema( + ctx: RunContext[AssistantDeps], + table_ids: Annotated[ + list[int], Field(description="List of table IDs to retrieve schemas for.") + ], + full_schema: Annotated[ + bool, + Field( + description="If True, include all fields. If False, only table names, IDs, primary keys, and relationships." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Get field definitions for tables (full_schema=True for all fields). + + WHEN to use: Before creating/modifying fields to understand table structure and avoid duplicates. Also for understanding relationships when creating link_row fields. + WHAT it does: Returns the schema of specified tables. full_schema=True returns all fields with types and configs. full_schema=False returns only names, IDs, primary keys, and relationships. + RETURNS: Table schemas with field names, types, IDs, primary keys, and relationships. + DO NOT USE when: You need row data — use list_rows instead. For row operations, use load_row_tools, those tools already provide the necessary schema info in their instructions. 
+ """ - tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - tool_helpers.update_status( - _("Inspecting %(table_names)s schema...") - % {"table_names": ", ".join(t.name for t in tables)} - ) + if not table_ids: + return {"tables_schema": []} - return { - "tables_schema": [ - ts.model_dump() for ts in utils.get_tables_schema(tables, full_schema) - ] - } + tables = helpers.filter_tables(user, workspace).filter(id__in=table_ids) - return get_tables_schema + tool_helpers.update_status( + _("Inspecting %(table_names)s schema...") + % {"table_names": ", ".join(t.name for t in tables)} + ) + return { + "tables_schema": [ + ts.model_dump() for ts in helpers.get_tables_schema(tables, full_schema) + ] + } + + +# --------------------------------------------------------------------------- +# Tool 3: list_rows +# --------------------------------------------------------------------------- + + +def list_rows( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to list rows from.") + ], + offset: Annotated[ + int, + Field( + description="Number of rows to skip for pagination. Use 0 for the first page." + ), + ], + limit: Annotated[ + int, Field(description="Maximum number of rows to return (max 20).") + ], + field_ids: Annotated[ + list[int] | None, + Field(description="List of field IDs to include, or null for all fields."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Read rows from a table with pagination (max 20 per call). + + WHEN to use: User wants to see data in a table, or you need to check existing row values. + WHAT it does: Lists rows from a table with pagination (offset/limit) and optional field filtering. Max 20 rows per call. + RETURNS: Rows array with field values, plus total row count for pagination. 
+ DO NOT USE when: You need to create, update, or delete rows — call load_row_tools first to get row manipulation tools. + """ -class GetTablesSchemaToolType(AssistantToolType): - type = "get_tables_schema" + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_tables_schema_tool(user, workspace, tool_helpers) + table = helpers.get_table(user, workspace, table_id) + tool_helpers.update_status( + _("Listing rows in %(table_name)s ") % {"table_name": table.name} + ) -def get_table_and_fields_tools_factory( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[list[TableItemCreate]], list[dict[str, Any]]]: - def create_fields( - table_id: int, fields: list[AnyFieldItemCreate] - ) -> list[AnyFieldItem]: - """ - Creates fields in the specified table. + rows_qs = table.get_model().objects.all() + rows = rows_qs[offset : offset + limit] - - Choose the most appropriate field type for each field. - - Field names must be unique within a table: check existing names - when needed and skip duplicates. - - For link_row fields, ensure the linked table already exists in - the same database; create it first if needed. 
- """ + response_model = create_model( + f"ResponseTable{table.id}RowWithFieldFilter", + id=(int, ...), + __base__=get_create_row_model(table, field_ids=field_ids), + ) - nonlocal user, workspace, tool_helpers + return { + "rows": [ + response_model.from_django_orm(row, field_ids).model_dump() for row in rows + ], + "total": rows_qs.count(), + } + + +# --------------------------------------------------------------------------- +# Tool 4: list_views +# --------------------------------------------------------------------------- + + +def list_views( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to list views for.") + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List views in a table. + + WHEN to use: Before creating views (to avoid duplicate names), or to find existing view IDs. + WHAT it does: Lists all views in a table with their id, name, and type. + RETURNS: Views array with id, name, type configuration. + DO NOT USE when: You already have the view IDs you need. + """ - if not fields: - return [] + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - table = utils.filter_tables(user, workspace).get(id=table_id) + table = helpers.get_table(user, workspace, table_id) - with transaction.atomic(): - created_fields = utils.create_fields(user, table, fields, tool_helpers) - return {"created_fields": [field.model_dump() for field in created_fields]} - - def create_tables( - database_id: int, tables: list[TableItemCreate], add_sample_rows: bool = True - ) -> list[dict[str, Any]]: - """ - Creates tables with fields and rows in a database. **ALWAYS** add sample rows - unless explicitly asked otherwise. - - - table names should be unique in a database - - add meaningful fields with the appropriate types and relationships to other - existing tables. 
The reversed link_row fields will be created automatically. - - if add_sample_rows is True (default), add some example rows to each table - """ - - nonlocal user, workspace, tool_helpers - - if not tables: - return {"created_tables": []} - - database = CoreService().get_application( - user, - database_id, - specific=False, - base_queryset=Database.objects.filter(workspace=workspace), - ) + tool_helpers.update_status( + _("Listing views in %(table_name)s...") % {"table_name": table.name} + ) - created_tables = [] - with transaction.atomic(): - for i, table in enumerate(tables): - tool_helpers.update_status( - _("Creating table %(table_name)s...") % {"table_name": table.name} - ) + views = ViewHandler().list_views( + user, + table, + filters=False, + sortings=False, + decorations=False, + group_bys=False, + limit=100, + ) - created_table, __ = CreateTableActionType.do( - user, database, table.name, fill_example=False - ) - created_tables.append(created_table) + return {"views": [ViewItem.from_django_orm(view).model_dump() for view in views]} - primary_field_item = table.primary_field - primary_field = created_table.get_primary_field().specific - new_field_type = field_type_registry.get(primary_field_item.type) - UpdateFieldActionType.do( - user, - primary_field, - new_type_name=new_field_type.type, - name=primary_field_item.name, - ) - # Now that we have all the tables created, we can create the fields - notes = [] - for table, created_table in zip(tables, created_tables): - with transaction.atomic(): - try: - utils.create_fields(user, created_table, table.fields, tool_helpers) - except Exception as e: - notes.append( - f"Error creating fields for table_{created_table.id}: {e}.\n" - f"Please retry recreating fields for table_{created_table.id} manually." 
- ) +# --------------------------------------------------------------------------- +# Tool 5: create_tables +# --------------------------------------------------------------------------- - tool_helpers.navigate_to( - TableNavigationType( - type="database-table", - database_id=database.id, - table_id=created_table.id, - table_name=created_table.name, - ) - ) - if add_sample_rows: - instructions = [] +def _create_empty_tables( + user: AbstractUser, + database: Database, + tables: list[TableItemCreate], + tool_helpers: "ToolHelpers", +) -> list[Table]: + """Create bare tables and rename each one's auto-created primary field.""" + created: list[Table] = [] + with transaction.atomic(): + for table in tables: + tool_helpers.raise_if_cancelled() tool_helpers.update_status( - _("Preparing example rows for these new tables...") + _("Creating table %(table_name)s...") % {"table_name": table.name} ) - tools = [] - for table, created_table in zip(tables, created_tables): - create_rows_tool = utils.get_table_rows_tools( - user, workspace, tool_helpers, created_table - )["create"] - tools.append(create_rows_tool) - instructions.append( - f"- Create 5 example rows with realistic data for {created_table.name} (Id: {created_table.id}). " - "Fill every relationship with valid data when possible." - ) - - predictor = udspy.ReAct( - "instructions -> result", tools=tools, max_iters=len(tables * 2) + created_table, __ = CreateTableActionType.do( + user, database, table.name, fill_example=False ) - result = predictor(instructions=("\n".join(instructions))) - notes.append(result) - - return { - "created_tables": [ - BaseTableItem(id=t.id, name=t.name).model_dump() for t in created_tables - ], - "notes": notes, - } - - def load_table_and_fields_tools(): - """ - TOOL LOADER: Loads table and field creation tools for a database. 
- - After calling this loader, you will have access to: - - create_tables: Create new tables in a database with fields and sample rows - - create_fields: Add new fields to an existing table + created.append(created_table) + primary_field = created_table.get_primary_field().specific + UpdateFieldActionType.do(user, primary_field, name=table.primary_field_name) + return created - Use this when you need to create tables or add fields but don't have the tools. - """ - @udspy.module_callback - def _load_table_and_fields_tools(context): - nonlocal user, workspace, tool_helpers - - observation = ["New tools are now available.\n"] - - create_tool = udspy.Tool(create_tables) - new_tools = [create_tool] - observation.append("- Use `create_tables` to create tables in a database.") - - create_fields_tool = udspy.Tool(create_fields) - new_tools.append(create_fields_tool) - observation.append("- Use `create_fields` to create fields in a table.") +def _create_table_fields( + user: AbstractUser, + tables: list[TableItemCreate], + created_tables: list[Table], + tool_helpers: "ToolHelpers", + formula_fixer, +) -> list[str]: + """Create non-primary fields for each table; return collected notes/errors.""" + notes: list[str] = [] + for table, created_table in zip(tables, created_tables): + tool_helpers.raise_if_cancelled() + with transaction.atomic(): + # Drop any field whose name matches the primary field name — it's + # already set via UpdateFieldActionType.do() above. Including it in + # fields too is a common model mistake that would otherwise produce + # a "field already exists" error note. 
+ non_primary_fields = [ + f + for f in table.fields + if f.name.lower() != table.primary_field_name.lower() + ] + _created, field_errors, formula_errors = helpers.create_fields( + user, + created_table, + non_primary_fields, + tool_helpers, + formula_fixer=formula_fixer, + ) + notes.extend(field_errors) + for err in formula_errors: + notes.append( + f"Invalid formula for field '{err['field_name']}' " + f"in table_{created_table.id}: {err['error']}. " + f"Use generate_formula to fix it." + ) + return notes + + +def create_tables( + ctx: RunContext[AssistantDeps], + database_id: Annotated[ + int, + Field( + ..., + description="The ID of the database to create tables in.", + ), + ], + tables: Annotated[ + list[TableItemCreate], + Field( + ..., + description="List of tables to create, each with a name, primary field, fields and relationships.", + ), + ], + add_sample_rows: Annotated[ + bool | str, + Field( + ..., + description="Controls sample row generation. True (default): generate realistic example rows. " + "A string: a brief describing what kind of data to create (e.g. 'Italian recipes with calorie counts'). " + "False: create empty tables, only use when the user explicitly asks for no sample data.", + ), + ], + thought: Annotated[ + str, + Field( + ..., + description="Brief reasoning for calling this tool.", + ), + ], +) -> dict[str, Any]: + """\ + Create tables with fields; generates sample rows by default. + + WHEN to use: User wants new tables created in a database. Always set add_sample_rows=true (or a descriptive string) unless explicitly asked for empty tables. + WHAT it does: Creates tables with fields, generates sample rows by default. Pass add_sample_rows=false ONLY when the user explicitly asks for empty tables. + Pass a string to guide the kind of sample data generated (e.g. "Italian recipes with calorie counts"). Table names must be unique. Reversed link_row fields are auto-created. 
+ At the end, this tool automatically navigates the user to the last created table. + RETURNS: Created table schemas with all field IDs. Notes on any errors. + DO NOT USE when: Tables already exist — check with list_tables first. + HOW: Pass ALL related tables in a single call — link_row fields can reference other tables in the same call by name (they are created internally before fields are added). Choose appropriate field types for each column. + Use single_select/multiple_select with select_options for categorical data. The primary field is always text — pick a meaningful name for it. + """ - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - return _load_table_and_fields_tools + if not tables: + return {"created_tables": []} - return load_table_and_fields_tools + database = CoreService().get_application( + user, + database_id, + specific=False, + base_queryset=Database.objects.filter(workspace=workspace), + ) + created_tables = _create_empty_tables(user, database, tables, tool_helpers) -class TableAndFieldsToolFactoryToolType(AssistantToolType): - type = "table_and_fields_tool_factory" + formula_fixer = make_formula_fixer(user, workspace, tool_helpers) + notes = _create_table_fields( + user, tables, created_tables, tool_helpers, formula_fixer + ) - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_table_and_fields_tools_factory(user, workspace, tool_helpers) + last_table = created_tables[-1] + tool_helpers.navigate_to( + TableNavigationType( + type="database-table", + database_id=database.id, + table_id=last_table.id, + table_name=last_table.name, + ) + ) + created_rows = {} + if add_sample_rows: + try: + data_brief = add_sample_rows if 
isinstance(add_sample_rows, str) else None + created_rows = generate_sample_rows( + user, workspace, tool_helpers, created_tables, data_brief=data_brief + ) + except Exception as e: + logger.exception( + "[assistant] generate_sample_rows raised unexpectedly: {}", e + ) + notes.append(f"Error creating sample rows: {e}") + + # Return the full schema so callers don't need a separate + # get_tables_schema call to learn field IDs. + tables_schema = [ + ts.model_dump() + for ts in helpers.get_tables_schema(created_tables, full_schema=True) + ] + + response: dict[str, Any] = {"created_tables": tables_schema, "notes": notes} + if created_rows: + response["created_rows"] = { + f"Row IDs for newly created rows in table_{table_id}": [ + row.id for row in rows + ] + for table_id, rows in created_rows.items() + } -def get_list_rows_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int, int, int, list[int] | None], list[dict[str, Any]]]: + return response + + +# --------------------------------------------------------------------------- +# Tool 6: create_fields +# --------------------------------------------------------------------------- + + +def create_fields( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to add fields to.") + ], + fields: Annotated[ + list[FieldItemCreate], + Field( + description="List of fields to create with their types and configurations." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Add fields to an existing table. + + WHEN to use: Adding fields to an existing table, or retrying failed field creation after create_tables partial failure. + WHAT it does: Creates fields in the specified table. Field names must be unique. For link_row fields, the linked table must already exist. + RETURNS: Created fields with id, name, type. Formula errors with hints if any. 
+ DO NOT USE when: Creating a brand new table — use create_tables instead, which handles fields as part of table creation. + HOW: Call get_tables_schema first to see existing fields and avoid duplicates. For link_row fields, ensure the target table already exists. """ - Returns a function that lists rows in a given table the user has access to in the - current workspace. - """ - - def list_rows( - table_id: int, - offset: int = 0, - limit: int = 20, - field_ids: list[int] | None = None, - ) -> list[dict[str, Any]]: - """ - Lists rows in the specified table. - - - Use offset and limit for pagination. - - Use field_ids to limit the response to specific fields. - """ - nonlocal user, workspace, tool_helpers + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - table = utils.filter_tables(user, workspace).get(id=table_id) - - tool_helpers.update_status( - _("Listing rows in %(table_name)s ") % {"table_name": table.name} - ) + if not fields: + return {"created_fields": []} - rows_qs = table.get_model().objects.all() - rows = rows_qs[offset : offset + limit] + table = helpers.get_table(user, workspace, table_id) - response_model = create_model( - f"ResponseTable{table.id}RowWithFieldFilter", - id=(int, ...), - __base__=utils.get_create_row_model(table, field_ids=field_ids), + with transaction.atomic(): + formula_fixer = make_formula_fixer(user, workspace, tool_helpers) + created_fields, field_errors, formula_errors = helpers.create_fields( + user, table, fields, tool_helpers, formula_fixer=formula_fixer ) + result = {"created_fields": [field.model_dump() for field in created_fields]} + if field_errors: + result["field_errors"] = field_errors + if formula_errors: + for err in formula_errors: + err["hint"] = ( + "Use generate_formula to create a valid formula for this field." 
+ ) + result["formula_errors"] = formula_errors + return result + + +# --------------------------------------------------------------------------- +# Tool 7: update_fields +# --------------------------------------------------------------------------- + + +def update_fields( + ctx: RunContext[AssistantDeps], + fields: Annotated[ + list[FieldItemUpdate], + Field( + description="List of field updates, each with a field_id and the properties to change." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Update existing fields (rename, change properties). + + WHEN to use: User wants to rename a field, change decimal places, update select options, or modify other field properties. + WHAT it does: Updates field properties. Cannot change field type or link_row targets — create a new field instead. + RETURNS: Updated fields with id, name, type and current properties. + DO NOT USE when: You need to change a field's type — delete and recreate it instead. + HOW: Call get_tables_schema first to see current field IDs and types. 
+ """ - return { - "rows": [ - response_model.from_django_orm(row, field_ids).model_dump() - for row in rows - ], - "total": rows_qs.count(), - } - - return list_rows - - -class ListRowsToolType(AssistantToolType): - type = "list_rows" + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_rows_tool(user, workspace, tool_helpers) + if not fields: + return {"updated_fields": [], "errors": []} + updated = [] + errors = [] + formula_fixer = make_formula_fixer(user, workspace, tool_helpers) -def get_rows_tools_factory( - user: AbstractUser, - workspace: Workspace, - tool_helpers: "ToolHelpers", -) -> Callable[[int, list[dict[str, Any]]], list[Any]]: - def load_rows_tools( - table_ids: list[int], - operations: list[Literal["create", "update", "delete"]], - ) -> Tuple[str, list[Callable[[Any], Any]]]: - """ - TOOL LOADER: Loads row manipulation tools for specified tables. - Make sure to have the correct table IDs. - - After calling this loader, you will have access to table-specific tools: - - create_rows_in_table_X: Create new rows in table X - - update_rows_in_table_X: Update existing rows in table X by their IDs - - delete_rows_in_table_X: Delete rows from table X by their IDs - - Use this when you need to create, update, or delete rows but don't have - the tools. - Call with the table IDs and desired operations (create/update/delete). - """ - - @udspy.module_callback - def _load_rows_tools(context): - nonlocal user, workspace, tool_helpers - - tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) - if not tables: - observation = [ - "No valid tables found for the given IDs. 
", - "Make sure the table IDs are correct.", - ] - return "\n".join(observation) - - new_tools = [] - observation = ["New tools are now available.\n"] - for table in tables: - table_tools = utils.get_table_rows_tools( - user, workspace, tool_helpers, table + with transaction.atomic(): + for field_update in fields: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Updating field %(field_id)s...") + % {"field_id": field_update.field_id} + ) + try: + field_item = helpers.update_field( + user, workspace, field_update, formula_fixer=formula_fixer ) + updated.append(field_item.model_dump()) + except Exception as e: + errors.append(f"Error updating field {field_update.field_id}: {e}") + + result: dict[str, Any] = {"updated_fields": updated} + if errors: + result["errors"] = errors + return result + + +# --------------------------------------------------------------------------- +# Tool 8: delete_fields +# --------------------------------------------------------------------------- + + +def delete_fields( + ctx: RunContext[AssistantDeps], + field_ids: Annotated[ + list[int], + Field(description="List of field IDs to delete."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Delete fields (moves them to trash). + + WHEN to use: User wants to remove fields from a table. + WHAT it does: Soft-deletes fields (moves to trash, can be restored). Primary fields cannot be deleted. + RETURNS: List of deleted field IDs. + DO NOT USE when: You want to change a field — use update_fields instead. + HOW: Call get_tables_schema first to confirm field IDs. 
+ """ - observation.append(f"Table '{table.name}' (ID: {table.id}):") + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - if "create" in operations: - create_rows = table_tools["create"] - new_tools.append(create_rows) - observation.append(f"- Use {create_rows.name} to create new rows.") + if not field_ids: + return {"deleted_field_ids": [], "errors": []} - if "update" in operations: - update_rows = table_tools["update"] - new_tools.append(update_rows) - observation.append( - f"- Use {update_rows.name} to update existing rows by their IDs." - ) + deleted = [] + errors = [] - if "delete" in operations: - delete_rows = table_tools["delete"] - new_tools.append(delete_rows) - observation.append( - f"- Use {delete_rows.name} to delete rows by their IDs." - ) + with transaction.atomic(): + for field_id in field_ids: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Deleting field %(field_id)s...") % {"field_id": field_id} + ) + try: + helpers.delete_field(user, workspace, field_id) + deleted.append(field_id) + except Exception as e: + errors.append(f"Error deleting field {field_id}: {e}") + + result: dict[str, Any] = {"deleted_field_ids": deleted} + if errors: + result["errors"] = errors + return result + + +# --------------------------------------------------------------------------- +# Tool 9: create_views +# --------------------------------------------------------------------------- + + +def create_views( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to create views for.") + ], + views: Annotated[ + list[ViewItemCreate], + Field( + description="List of views to create (grid, form, gallery, kanban, calendar, timeline)." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Create views (grid, form, gallery, kanban, calendar, timeline). 
+ + WHEN to use: User wants a new view (grid, form, gallery, kanban, calendar, timeline) on a table. + WHAT it does: Creates views in the table. View names must be unique. A default grid view is auto-created with every new table — no need to recreate it. + RETURNS: Created views with id, name, type configuration. + DO NOT USE when: The default grid view already meets the user's needs. Check existing views with list_views to avoid duplicates. + HOW: Each view type requires specific config. Form views: provide field_options listing every field to show (field_id, name, order, required). Kanban: set column_field_id to a single_select field. Calendar: set date_field_id to a date field. Timeline: set both start/end date fields. Gallery: optionally set cover_field_id to a file field. Call get_tables_schema first to get the field IDs you need. + """ - observation.append("") + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) + if not views: + return {"created_views": []} - return _load_rows_tools + table = helpers.get_table(user, workspace, table_id) - return load_rows_tools + created_views = [] + with transaction.atomic(): + for view in views: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Creating %(view_type)s view %(view_name)s") + % {"view_type": view.type, "view_name": view.name} + ) + orm_view = CreateViewActionType.do( + user, + table, + view.type, + **view.to_django_orm_kwargs(table), + ) -class RowsToolFactoryToolType(AssistantToolType): - type = "rows_tool_factory" + field_options = view.field_options_to_django_orm() + if field_options: + UpdateViewFieldOptionsActionType.do(user, orm_view, field_options) - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> 
Callable[[Any], Any]: - return get_rows_tools_factory(user, workspace, tool_helpers) + created_views.append({"id": orm_view.id, **view.model_dump()}) + tool_helpers.navigate_to( + ViewNavigationType( + type="database-view", + database_id=table.database_id, + table_id=table.id, + view_id=created_views[0]["id"], + view_name=created_views[0]["name"], + view_type=created_views[0]["type"], + ) + ) -def get_list_views_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], list[dict[str, Any]]]: + return {"created_views": created_views} + + +# --------------------------------------------------------------------------- +# Tool 8: create_view_filters +# --------------------------------------------------------------------------- + + +def create_view_filters( + ctx: RunContext[AssistantDeps], + view_filters: Annotated[ + list[ViewFiltersArgs], + Field( + description="List of view filter configurations, each specifying a view ID and its filters." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Add filters to views. + + WHEN to use: User wants to filter a view to show only specific rows matching conditions. + WHAT it does: Creates filter conditions on one or more views. Supports multiple filters per view. + RETURNS: Created filters with id and configuration per view. + DO NOT USE when: The view doesn't exist yet — create it first with create_views. + HOW: Get the table schema first to know field IDs and types. Match filter type to field type. + + ## Value formats by type + + - text: string + - number: number + - date: ISO date string (mode=exact_date) or integer (mode=nr_days_ago etc.) or "" (mode=today etc.) 
+ - single_select / multiple_select: list of option label strings (matched case-insensitively) + - link_row: row ID (integer) + - boolean: true / false """ - Returns a function that lists all the views in a given table the user has - access to in the current workspace. - """ - - def list_views(table_id: int) -> list[dict[str, Any]]: - """ - List views in the specified table. - - - Always call this for existing tables to avoid creating views with duplicate - names. - """ - nonlocal user, workspace, tool_helpers + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - table = utils.filter_tables(user, workspace).get(id=table_id) + if not view_filters: + return {"created_view_filters": []} + created_view_filters = [] + for vf in view_filters: + tool_helpers.raise_if_cancelled() + orm_view = helpers.get_view(user, workspace, vf.view_id) tool_helpers.update_status( - _("Listing views in %(table_name)s...") % {"table_name": table.name} + _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} ) - views = ViewHandler().list_views( - user, - table, - filters=False, - sortings=False, - decorations=False, - group_bys=False, - limit=100, - ) - - return { - "views": [ - view_item_registry.from_django_orm(view).model_dump() for view in views - ] - } - - return list_views - - -class ListViewsToolType(AssistantToolType): - type = "list_views" - - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_views_tool(user, workspace, tool_helpers) - + fields = {f.id: f for f in orm_view.table.field_set.all()} + created_filters = [] + with transaction.atomic(): + for filter in vf.filters: + try: + orm_filter = helpers.create_view_filter( + user, orm_view, fields, filter + ) + except ValueError as e: + logger.warning(f"Skipping filter creation: {e}") + continue + + created_filters.append({"id": orm_filter.id, **filter.model_dump()}) + 
created_view_filters.append({"view_id": vf.view_id, "filters": created_filters}) + + return {"created_view_filters": created_view_filters} + + +# --------------------------------------------------------------------------- +# Tool 9: generate_formula +# --------------------------------------------------------------------------- + + +def generate_formula( + ctx: RunContext[AssistantDeps], + database_id: Annotated[ + int, + Field( + description="The ID of the database containing the tables for the formula." + ), + ], + description: Annotated[ + str, + Field( + description="A natural language description of what the formula should compute." + ), + ], + save_to_field: Annotated[ + bool, + Field( + description="If true, save the formula to a field. If false, only return it. Should be true unless explicitly asked otherwise." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, str]: + """\ + Generate a formula from a natural-language description and save it. + + WHEN to use: User needs a computed field (formulas, calculations, cross-table lookups). No need to inspect the schema first — this tool does it automatically. + WHAT it does: Generates a valid Baserow formula from a natural-language description. Finds the best table and fields automatically. Saves to a formula field by default (save_to_field=true). + RETURNS: Generated formula string, formula type, and field details (name, table, operation). + DO NOT USE when: The user wants a simple non-formula field — use create_fields instead. + HOW: Describe what the formula should compute in plain language. The tool auto-discovers the table schema — no need to inspect it first. 
+ """ + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + get_model_string, + ) -def get_views_tool_factory( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int, list[str]], list[str]]: - def create_view_filters( - view_filters: list[ViewFiltersArgs], - ) -> list[AnyViewFilterItem]: - """ - Creates filters in the specified views. - """ + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - nonlocal user, workspace, tool_helpers + database_tables = helpers.filter_tables(user, workspace).filter( + database_id=database_id + ) + database_tables_schema = [ + t.model_dump() for t in helpers.get_tables_schema(database_tables, True) + ] - if not view_filters: - return [] + tool_helpers.update_status(_("Generating formula...")) - created_view_filters = [] - for vf in view_filters: - orm_view = utils.get_view(user, vf.view_id) - tool_helpers.update_status( - _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} - ) + formula_docs = get_formula_docs() + formula_type_tool = Tool(get_formula_type_tool(user, workspace)) + formula_toolset = FunctionToolset([formula_type_tool]) - fields = {f.id: f for f in orm_view.table.field_set.all()} - created_filters = [] - with transaction.atomic(): - for filter in vf.filters: - try: - orm_filter = utils.create_view_filter( - user, orm_view, fields, filter - ) - except ValueError as e: - logger.warning(f"Skipping filter creation: {e}") - continue - - created_filters.append({"id": orm_filter.id, **filter.model_dump()}) - created_view_filters.append( - {"view_id": vf.view_id, "filters": created_filters} - ) - - return {"created_view_filters": created_view_filters} + prompt = format_formula_generation_prompt( + description, database_tables_schema, formula_docs + ) - def create_views( - table_id: int, views: list[AnyViewItemCreate] - ) -> list[dict[str, Any]]: - """ - Creates views in the specified 
table. A default grid view showing all the rows - is created automatically when a table is created, no need to recreate it. + model = get_model_string() + agent_result = formula_generation_agent.run_sync( + prompt, + model=model, + model_settings=get_model_settings(model, UTILITY), + toolsets=[formula_toolset], + usage_limits=UsageLimits(request_limit=20), + ) + result = agent_result.output - - Choose the most appropriate view type for each view. - - View names must be unique within a table: check existing names when needed and - avoid duplicates. - """ + if not result.is_formula_valid: + raise Exception(f"Error generating formula: {result.error_message}") - nonlocal user, workspace, tool_helpers + table = next((t for t in database_tables if t.id == result.table_id), None) + if table is None: + raise Exception( + "The generated formula is intended for a different table " + f"than the current one. Table with ID {result.table_id} not found." + ) - if not views: - return [] + data = { + "formula": result.formula, + "formula_type": result.formula_type, + } + field_name = result.field_name - table = utils.filter_tables(user, workspace).get(id=table_id) + if save_to_field: + field = table.field_set.filter(name=field_name).first() + if field: + field = field.specific - created_views = [] with transaction.atomic(): - for view in views: - tool_helpers.update_status( - _("Creating %(view_type)s view %(view_name)s") - % {"view_type": view.type, "view_name": view.name} - ) - - orm_view = CreateViewActionType.do( + # Trash any existing non-formula field so it can be replaced, allowing + # the user to easily restore the original field if needed. 
+ if field and field_type_registry.get_by_model(field).type != "formula": + DeleteFieldActionType.do(user, field) + field = None + + if field is None: + CreateFieldActionType.do( user, table, - view.type, - **view.to_django_orm_kwargs(table), + type_name="formula", + name=field_name, + formula=result.formula, + ) + operation = "field created" + else: + # Only update the formula of an existing formula field. + UpdateFieldActionType.do( + user, + field, + formula=result.formula, + ) + operation = "field updated" + + tool_helpers.navigate_to( + TableNavigationType( + type="database-table", + database_id=table.database_id, + table_id=table.id, + table_name=table.name, ) - - field_options = view.field_options_to_django_orm() - if field_options: - UpdateViewFieldOptionsActionType.do(user, orm_view, field_options) - - created_views.append({"id": orm_view.id, **view.model_dump()}) - - tool_helpers.navigate_to( - ViewNavigationType( - type="database-view", - database_id=table.database_id, - table_id=table.id, - view_id=created_views[0]["id"], - view_name=created_views[0]["name"], - view_type=created_views[0]["type"], ) - ) - - return {"created_views": created_views} - - def load_views_tools(): - """ - TOOL LOADER: Loads tools to manage views and filters - (grid, gallery, form, kanban, calendar and timeline). - - After calling this loader, you will be able to: - - create_views: Create grid, gallery, form, kanban, calendar and timeline views - - create_view_filters: Create filters for specific views to filter rows - - Use this when you need to create views or filters but don't have the tools yet. 
- """ - @udspy.module_callback - def _load_views_tools(context): - nonlocal user, workspace, tool_helpers - - observation = ["New tools are now available.\n"] - - create_tool = udspy.Tool(create_views) - new_tools = [create_tool] - observation.append("- Use `create_views` to create views.") - - create_filters_tool = udspy.Tool(create_view_filters) - new_tools.append(create_filters_tool) - observation.append( - "- Use `create_view_filters` to create filters in views." + data.update( + { + "table_id": table.id, + "table_name": table.name, + "field_name": result.field_name, + "operation": operation, + } ) - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) - - return _load_views_tools + return data - return load_views_tools +# --------------------------------------------------------------------------- +# Dynamic row tools (create / update / delete) +# --------------------------------------------------------------------------- -class ViewsToolFactoryToolType(AssistantToolType): - type = "views_tool_factory" - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_views_tool_factory(user, workspace, tool_helpers) - - -def get_formula_type_tool( - user: AbstractUser, workspace: Workspace -) -> Callable[[str], str]: +def _build_row_tools( + user: AbstractUser, + workspace: Workspace, + tool_helpers: "ToolHelpers", + table: Table, + field_ids: list[int] | None = None, +) -> dict[str, Tool]: """ - Returns a function that returns the type of a formula. + Build pydantic-ai Tool objects for row CRUD on a single table. + + Returns a dict with keys ``"create"``, ``"update"``, ``"delete"``, each + containing a ready-to-use ``Tool`` whose schema is derived from the table's + fields. + + :param user: The acting user. + :param workspace: Current workspace. 
+ :param tool_helpers: Provides status updates and cancellation. + :param table: The table to build row tools for. + :param field_ids: If given, only include these field IDs in the + create model (useful for excluding reverse link_row fields). """ - def get_formula_type(table_id: int, field_name: str, formula: str) -> str: - """ - Returns the type of a formula. Raises an exception if the formula is not valid. - **ALWAYS** call this to validate a formula is valid before returning it. - """ - - nonlocal user, workspace + row_model_for_create = get_create_row_model(table, field_ids=field_ids) + row_model_for_update = get_update_row_model(table) + link_row_hints = get_link_row_hints(row_model_for_create) - table = utils.filter_tables(user, workspace).get(id=table_id) - field = FormulaField(formula=formula, table=table, name=field_name, order=0) - field.recalculate_internal_fields(raise_if_invalid=True) + def _create_rows( + rows: list[row_model_for_create], + thought: Annotated[str, "Brief reasoning for calling this tool."], + ) -> dict[str, Any]: + """Create new rows in the specified table.""" - result = TypeFormulaResultSerializer(field).data - if result["error"]: - raise Exception(f"Invalid formula: {result['error']}") - - return result["formula_type"] - - return get_formula_type + if not rows: + return {"created_row_ids": []} + tool_helpers.update_status( + _("Creating rows in %(table_name)s ") % {"table_name": table.name} + ) -class FormulaGenerationSignature(udspy.Signature): - """ - Generates a Baserow formula based on the provided description and table schema. - """ + validated_rows = [row.to_django_orm() for row in rows] - description: str = udspy.InputField( - desc="A brief description of what the formula should do." - ) - tables_schema: dict = udspy.InputField( - desc="The schema of all the tables in the database." - ) - formula_documentation: str = udspy.InputField( - desc="Documentation about Baserow formulas and their syntax." 
- ) - table_id: int = udspy.OutputField( - desc=( - "The ID of the table the formula is intended for. " - "Should be the same as current_table_id, unless the formula can " - "only be created in a different table." - ) - ) - field_name: str = udspy.OutputField( - desc="The name of the formula field to be created. For a new field, it must be unique in the table." - ) - formula: str = udspy.OutputField( - desc="The generated formula. Must be a valid Baserow formula." - ) - formula_type: str = udspy.OutputField( - desc="The type of the generated formula. Must be one of: text, long_text, " - "number, boolean, date, link_row, single_select, multiple_select, duration, array." - ) - is_formula_valid: bool = udspy.OutputField( - desc="Whether the generated formula is valid or not." + with transaction.atomic(): + orm_rows = CreateRowsActionType.do(user, table, validated_rows) + + return {"created_row_ids": [r.id for r in orm_rows]} + + create_rows_tool = Tool( + _create_rows, + name=f"create_rows_in_table_{table.id}", + description=( + f"WHEN: Creating new rows in '{table.name}' (ID: {table.id}). " + f"WHAT: Inserts up to 20 rows with field values matching the table schema. " + f"RETURNS: Created row IDs. " + f"DO NOT USE: For other tables — each table has its own create tool. " + f"HOW: Fill EVERY field including ALL link_row (relationship) fields. Never skip a field unless data is genuinely unavailable." + f"{link_row_hints}" + ), + max_retries=2, ) - error_message: str = udspy.OutputField( - desc="If the formula is not valid, an error message explaining why." 
+ create_rows_tool.function_schema.json_schema = inline_refs( + create_rows_tool.function_schema.json_schema ) + def _update_rows( + rows: list[row_model_for_update], + thought: Annotated[str, "Brief reasoning for calling this tool."], + ) -> dict[str, Any]: + """Update existing rows in the specified table.""" -def get_generate_database_formula_tool( - user: AbstractUser, - workspace: Workspace, - tool_helpers: "ToolHelpers", -) -> Callable[[str, int], dict[str, str]]: - """ - Returns a function that generates a formula for a given field in a table. - """ + if not rows: + return {"updated_row_ids": []} - def generate_database_formula( - database_id: int, - description: str, - save_to_field: bool = True, - ) -> dict[str, str]: - """ - Generate a database formula for a formula field. No need to inspect the schema - before, this tool will do it automatically and find the best table and fields to - use. - - - table_id: The database ID where the formula field is located. - - description: A brief description of what the formula should do. - - save_to_field: Whether to save the generated formula to a field with the given - name (default: True). If False, the formula will be generated but not saved - into a field. 
- """ - - nonlocal user, workspace, tool_helpers - - database_tables = utils.filter_tables(user, workspace).filter( - database_id=database_id + tool_helpers.update_status( + _("Updating rows in %(table_name)s ") % {"table_name": table.name} ) - database_tables_schema = [ - t.model_dump() for t in utils.get_tables_schema(database_tables, True) - ] - - tool_helpers.update_status(_("Generating formula...")) - - formula_docs = get_formula_docs() - formula_generator = udspy.ReAct( - FormulaGenerationSignature, - tools=[get_formula_type_tool(user, workspace)], - max_iters=20, - ) - result = formula_generator( - description=description, - tables_schema={"tables": database_tables_schema}, - formula_documentation=formula_docs, - ) + validated_rows = [row.to_django_orm() for row in rows] - if not result.is_formula_valid: - raise Exception(f"Error generating formula: {result.error_message}") + with transaction.atomic(): + orm_rows = UpdateRowsActionType.do(user, table, validated_rows).updated_rows + + return {"updated_row_ids": [r.id for r in orm_rows]} + + update_rows_tool = Tool( + _update_rows, + name=f"update_rows_in_table_{table.id}", + description=( + f"WHEN: Updating existing rows in '{table.name}' (ID: {table.id}) by row ID. " + f"WHAT: Updates specified fields on up to 20 rows. Only include fields you want to change — omit fields to keep them unchanged. " + f"RETURNS: Updated row IDs. " + f"DO NOT USE: For other tables — each table has its own update tool." + f"{link_row_hints}" + ), + max_retries=2, + ) + update_rows_tool.function_schema.json_schema = inline_refs( + update_rows_tool.function_schema.json_schema + ) - table = next((t for t in database_tables if t.id == result.table_id), None) - if table is None: - raise Exception( - "The generated formula is intended for a different table " - f"than the current one. Table with ID {result.table_id} not found." 
- ) + def _delete_rows( + row_ids: list[int], + thought: Annotated[str, "Brief reasoning for calling this tool."], + ) -> dict[str, Any]: + """Delete rows in the specified table.""" - data = { - "formula": result.formula, - "formula_type": result.formula_type, - } - field_name = result.field_name - - if save_to_field: - field = table.field_set.filter(name=field_name).first() - if field: - field = field.specific - - with transaction.atomic(): - # Trash any existing non-formula field so it can be replaced, allowing - # the user to easily restore the original field if needed. - if field and field_type_registry.get_by_model(field).type != "formula": - DeleteFieldActionType.do(user, field) - field = None - - if field is None: - CreateFieldActionType.do( - user, - table, - type_name="formula", - name=field_name, - formula=result.formula, - ) - operation = "field created" - else: - # Only update the formula of an existing formula field. - UpdateFieldActionType.do( - user, - field, - formula=result.formula, - ) - operation = "field updated" - - tool_helpers.navigate_to( - TableNavigationType( - type="database-table", - database_id=table.database_id, - table_id=table.id, - table_name=table.name, - ) - ) + if not row_ids: + return {"deleted_row_ids": []} - data.update( - { - "table_id": table.id, - "table_name": table.name, - "field_name": result.field_name, - "operation": operation, - } - ) + tool_helpers.update_status( + _("Deleting rows in %(table_name)s ") % {"table_name": table.name} + ) - return data + with transaction.atomic(): + DeleteRowsActionType.do(user, table, row_ids) + + return {"deleted_row_ids": row_ids} + + delete_rows_tool = Tool( + _delete_rows, + name=f"delete_rows_in_table_{table.id}", + description=( + f"WHEN: Deleting rows from '{table.name}' (ID: {table.id}) by row ID. " + f"WHAT: Permanently removes up to 20 specified rows. " + f"RETURNS: Deleted row IDs. " + f"DO NOT USE: For other tables — each table has its own delete tool." 
+ ), + ) - return generate_database_formula + return { + "create": create_rows_tool, + "update": update_rows_tool, + "delete": delete_rows_tool, + } + + +# --------------------------------------------------------------------------- +# Tool 10: load_row_tools +# --------------------------------------------------------------------------- + + +def load_row_tools( + ctx: RunContext[AssistantDeps], + table_ids: Annotated[ + list[int], Field(description="List of table IDs to load row tools for.") + ], + operations: Annotated[ + list[Literal["create", "update", "delete"]], + Field( + description="Which row operations to enable: 'create', 'update', and/or 'delete'." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> str: + """\ + TOOL LOADER — unlocks create/update/delete row tools for directly manipulating DATABASE rows. No need to know the schema beforehand, the loaded tools include it. + + WHEN to use: You need to directly create, update, or delete rows in a database table. Must be called before any row manipulation. + WHAT it does: Unlocks table-specific tools and their schema: create_rows_in_table_X, update_rows_in_table_X, delete_rows_in_table_X for each table ID provided. The loaded tools include the full field schema — no need to call get_tables_schema. + RETURNS: Names of newly available tools. + DO NOT USE when: Row tools for these tables are already loaded from a previous call in this session. + DO NOT USE for builder workflow actions — if you want a button/form in an Application Builder page to create/update/delete rows, use create_actions instead. load_row_tools is for direct database manipulation, NOT for configuring app behavior. + HOW: Just call this with the table ID(s) and operations you need. The loaded row tools already contain the complete field schema in their parameters — do NOT call get_tables_schema or search_user_docs before or after this tool. 
+ + EXAMPLES: + - "Create 5 rows" → load_row_tools([table_id], ["create"]) → create_rows_in_table_X(rows=[...]) + - "Update row 7" → load_row_tools([table_id], ["update"]) → update_rows_in_table_X(rows=[{id: 7, ...}]) + - "Delete rows 1-3" → load_row_tools([table_id], ["delete"]) → delete_rows_in_table_X(row_ids=[1,2,3]) + - To find linked row values, use list_rows with field_ids filter on the linked table. + """ + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers -class GenerateDatabaseFormulaToolType(AssistantToolType): - type = "generate_database_formula" + tables = helpers.filter_tables(user, workspace).filter(id__in=table_ids) + if not tables: + return ( + "No valid tables found for the given IDs. " + "Make sure the table IDs are correct." + ) - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_generate_database_formula_tool(user, workspace, tool_helpers) + new_tools: list[Tool] = [] + for table in tables: + table_tools = _build_row_tools(user, workspace, tool_helpers, table) + + if "create" in operations: + new_tools.append(table_tools["create"]) + if "update" in operations: + new_tools.append(table_tools["update"]) + if "delete" in operations: + new_tools.append(table_tools["delete"]) + + # Store new tools in dynamic_tools for the dynamic toolset + # to pick up on the next agent step + ctx.deps.dynamic_tools.extend(new_tools) + + tool_names = [t.name for t in new_tools] + return f"Tools loaded: {', '.join(tool_names)}" + + +# --------------------------------------------------------------------------- +# Module-level toolset +# --------------------------------------------------------------------------- + + +TOOL_FUNCTIONS = [ + list_tables, + get_tables_schema, + list_rows, + list_views, + create_tables, + create_fields, + update_fields, + delete_fields, + create_views, + create_view_filters, + generate_formula, + 
load_row_tools, +] +database_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) + +ROUTING_RULES = """\ +- Check list_* before create_* to avoid duplicates. +- switch_mode: switch domain if task needs tools not in the current mode. +- Database row CRUD → call load_row_tools first (includes schema — skip get_tables_schema). +- create_tables: include ALL related tables in one call so link_row fields connect properly. Add sample rows unless told otherwise. +- create_rows: fill EVERY field including ALL link_row fields. +- After creating tables for an app/automation task, switch_mode back to continue building.""" diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py index 7406ef8204..ddc43397e1 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py @@ -1,5 +1,6 @@ from .base import * # noqa: F401, F403 from .fields import * # noqa: F401, F403 +from .rows import * # noqa: F401, F403 from .table import * # noqa: F401, F403 from .view_filters import * # noqa: F401, F403 from .views import * # noqa: F401, F403 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py index 83e0219545..63d34e3656 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py @@ -1,35 +1,49 @@ from datetime import date, datetime -from baserow_enterprise.assistant.types import BaseModel +from dateutil import parser as _dateutil_parser -# Somehow LLMs struggle with dates -class Date(BaseModel): - year: int - month: int - day: int +def _normalize(value: str) -> str: + """Replace common separator 
variants so fromisoformat can parse them.""" - def to_django_orm(self): - return date(self.year, self.month, self.day).isoformat() + return value.replace("/", "-").strip() - @classmethod - def from_django_orm(cls, orm_date: date) -> "Date": - d = orm_date - return cls(year=d.year, month=d.month, day=d.day) +def parse_date(value: str) -> date: + """ + Parse a date string into a date object. -class Datetime(Date): - hour: int - minute: int + Tries ISO 8601 first, then falls back to dateutil for fuzzy formats + like ``Jan 5, 2023`` or ``05/01/2023``. + """ - def to_django_orm(self): - return datetime( - self.year, self.month, self.day, self.hour, self.minute - ).isoformat() + try: + return date.fromisoformat(_normalize(value)) + except ValueError: + return _dateutil_parser.parse(value).date() - @classmethod - def from_django_orm(cls, orm_datetime: datetime) -> "Datetime": - dt = orm_datetime - return cls( - year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, minute=dt.minute - ) + +def parse_datetime(value: str) -> datetime: + """ + Parse a datetime string into a datetime object. + + Tries ISO 8601 first, then falls back to dateutil for fuzzy formats + like ``Jan 5, 2023 10:00 AM``. 
+ """ + + try: + return datetime.fromisoformat(_normalize(value)) + except ValueError: + return _dateutil_parser.parse(value) + + +def format_date(value: date) -> str: + """Format a date as ISO 8601 (``YYYY-MM-DD``).""" + + return value.isoformat() + + +def format_datetime(value: datetime) -> str: + """Format a datetime as ISO 8601 (``YYYY-MM-DDTHH:MM``), without seconds.""" + + return value.strftime("%Y-%m-%dT%H:%M") diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py index 742e212c6c..8829a91639 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py @@ -1,253 +1,18 @@ -from typing import Annotated, Literal, Type +import json +from typing import Any, Literal from django.db.models import Q -from pydantic import Field - -from baserow.contrib.database.fields.models import ( - DateField, - FormulaField, - LinkRowField, - LookupField, - MultipleSelectField, - NumberField, - RatingField, - SingleSelectField, -) +from pydantic import Field, model_serializer, model_validator + from baserow.contrib.database.fields.models import Field as BaserowField from baserow.contrib.database.fields.registries import field_type_registry from baserow_enterprise.assistant.types import BaseModel -from baserow_enterprise.data_sync.hubspot_contacts_data_sync import LongTextField from baserow_premium.permission_manager import Table - -class FieldItemCreate(BaseModel): - """Base model for creating a new field (no ID).""" - - name: str = Field(...) - type: str = Field(...) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return {k: v for k, v in self.model_dump().items() if k not in {"id", "type"}} - - -class FieldItem(FieldItemCreate): - """Model for an existing field (with ID).""" - - id: int = Field(...) 
- - @classmethod - def from_django_orm(cls, orm_field: BaserowField) -> "FieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type=field_type_registry.get_by_model(orm_field).type, - ) - - -# Event if type could be inferred, certain models (i.e. openai-gpt-oss-120b) requires -# all the fields to be required and can cause issues with optional fields, so we -# explicitly set them as required, even if seems unnecessary. - - -class BaseTextFieldItem(FieldItemCreate): - type: Literal["text"] = Field(..., description="Single line text field.") - - -class TextFieldItemCreate(BaseTextFieldItem): - """Model for creating a text field.""" - - -class TextFieldItem(BaseTextFieldItem, FieldItem): - """Model for an existing text field.""" - - -class BaseLongTextFieldItem(FieldItemCreate): - type: Literal["long_text"] = Field( - ..., - description="Multi-line text field. Ideal for descriptions, notes and long-form content.", - ) - rich_text: bool = Field( - default=True, - description="Whether the long text field supports rich text.", - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "long_text_enable_rich_text": self.rich_text, - } - - -class LongTextFieldItemCreate(BaseLongTextFieldItem): - """Model for creating a long text field.""" - - -class LongTextFieldItem(BaseLongTextFieldItem, FieldItem): - """Model for an existing long text field.""" - - @classmethod - def from_django_orm(cls, orm_field: LongTextField) -> "LongTextFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="long_text", - rich_text=orm_field.long_text_enable_rich_text, - ) - - -class BaseNumberFieldItem(FieldItemCreate): - type: Literal["number"] = Field( - ..., description="Numeric field, with decimals and optional prefix/suffix." 
- ) - decimal_places: int = Field(default=2, description="The number of decimal places.") - suffix: str = Field( - default="", - description="An optional suffix to display after the number.", - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "number_decimal_places": self.decimal_places, - "number_suffix": self.suffix, - } - - -class NumberFieldItemCreate(BaseNumberFieldItem): - """Model for creating a number field.""" - - -class NumberFieldItem(BaseNumberFieldItem, FieldItem): - """Model for an existing number field.""" - - @classmethod - def from_django_orm(cls, orm_field: NumberField) -> "NumberFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="number", - decimal_places=orm_field.number_decimal_places, - suffix=orm_field.number_suffix, - ) - - -class BaseRatingFieldItem(FieldItemCreate): - type: Literal["rating"] = Field( - ..., description="Rating field. Ideal for reviews or scores." - ) - max_value: int = Field( - default=5, description="The maximum value of the rating field." 
- ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "max_value": self.max_value, - } - - -class RatingFieldItemCreate(BaseRatingFieldItem): - """Model for creating a rating field.""" - - -class RatingFieldItem(BaseRatingFieldItem, FieldItem): - """Model for an existing rating field.""" - - @classmethod - def from_django_orm(cls, orm_field: RatingField) -> "RatingFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="rating", - max_value=orm_field.max_value, - ) - - -class BaseBooleanFieldItem(FieldItemCreate): - type: Literal["boolean"] = Field(..., description="Boolean field.") - - -class BooleanFieldItemCreate(BaseBooleanFieldItem): - """Model for creating a boolean field.""" - - -class BooleanFieldItem(BaseBooleanFieldItem, FieldItem): - """Model for an existing boolean field.""" - - -class BaseDateFieldItem(FieldItemCreate): - type: Literal["date"] = Field(..., description="Date or datetime field.") - include_time: bool = Field( - default=False, description="Whether the date field includes time." - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "date_include_time": self.include_time, - } - - -class DateFieldItemCreate(BaseDateFieldItem): - """Model for creating a date field.""" - - -class DateFieldItem(BaseDateFieldItem, FieldItem): - """Model for an existing date field.""" - - @classmethod - def from_django_orm(cls, orm_field: DateField) -> "DateFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="date", - include_time=orm_field.date_include_time, - ) - - -class BaseLinkRowFieldItem(FieldItemCreate): - type: Literal["link_row"] = Field( - ..., description="Link row field. It creates relationships between tables." - ) - linked_table: str | int = Field( - ..., description="The ID or the name of the table this field links to." 
- ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - if isinstance(self.linked_table, str): - q = Q(name=self.linked_table, database=table.database) - else: - q = Q(id=self.linked_table, database=table.database) - - try: - link_row_table = Table.objects.get(q) - except Table.DoesNotExist: - raise ValueError( - f"The linked_table '{self.linked_table}' does not exist in the database." - "Ensure you provide a valid table name or ID." - ) - - return {"name": self.name, "link_row_table": link_row_table} - - -class LinkRowFieldItemCreate(BaseLinkRowFieldItem): - """Model for creating a link row field.""" - - -class LinkRowFieldItem(BaseLinkRowFieldItem, FieldItem): - """Model for an existing link row field.""" - - @classmethod - def from_django_orm(cls, orm_field: LinkRowField) -> "BaseLinkRowFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="link_row", - linked_table=orm_field.link_row_table_id, - ) - +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- OptionColor = Literal[ "light-blue", @@ -301,8 +66,7 @@ class SelectOption(BaseModel): color: OptionColor -# Define a subset of colors to use when creating fields, so we don't confuse the model -# with too many options. 
+# Subset of colors for creation to avoid confusing the model OptionColorCreate = Literal[ "blue", "green", @@ -319,248 +83,591 @@ class SelectOption(BaseModel): class SelectOptionCreate(BaseModel): value: str - color: OptionColorCreate + color: OptionColorCreate | None = None + + +class InvalidFormulaFieldError(Exception): + """Raised when a formula field has an invalid formula.""" + + def __init__(self, field_name: str, formula: str, table: Table, error: str): + self.field_name = field_name + self.formula = formula + self.table = table + self.error = error + super().__init__(f"Invalid formula for field '{field_name}': {error}") + + +# --------------------------------------------------------------------------- +# Flat field types — single model, all type-specific fields optional +# --------------------------------------------------------------------------- + +FieldType = Literal[ + "text", + "long_text", + "number", + "rating", + "boolean", + "date", + "link_row", + "single_select", + "multiple_select", + "file", + "formula", + "lookup", +] +_TYPE_ALIASES: dict[str, str] = { + "string": "text", + "varchar": "text", + "rich_text": "long_text", + "richtext": "long_text", + "textarea": "long_text", + "integer": "number", + "int": "number", + "float": "number", + "decimal": "number", + "numeric": "number", + "checkbox": "boolean", + "bool": "boolean", + "datetime": "date", + "link": "link_row", + "relation": "link_row", + "relationship": "link_row", + "foreign_key": "link_row", + "fk": "link_row", + "select": "single_select", + "dropdown": "single_select", + "enum": "single_select", + "multi_select": "multiple_select", + "multiselect": "multiple_select", + "tags": "multiple_select", + "attachment": "file", + "upload": "file", + "image": "file", +} + +_SELECT_COLORS: list[str] = [ + "blue", + "green", + "cyan", + "orange", + "yellow", + "red", + "brown", + "purple", + "pink", + "gray", +] -class BaseSingleSelectFieldItem(FieldItemCreate): - type: 
Literal["single_select"] = Field( - ..., - description="Single select field. Allows users to choose one option from a list.", - ) +_KEY_ALIASES: dict[str, str] = { + "long_text_enable_rich_text": "rich_text", + "number_decimal_places": "decimal_places", + "number_suffix": "suffix", + "date_include_time": "include_time", + "link_row_table": "linked_table", + "link_row_table_id": "linked_table", + "through_field": "linked_table", + "through_field_id": "linked_table", + "target_field_id": "target_field", +} + +# Creation order: regular → link_row → lookup → formula +FIELD_ORDER: dict[str, int] = {"link_row": 1, "lookup": 2, "formula": 3} + +_FIELD_EXAMPLES: dict[str, dict] = { + "text": {"name": "Title", "type": "text"}, + "long_text": {"name": "Notes", "type": "long_text"}, + "number": {"name": "Price", "type": "number", "decimal_places": 2}, + "rating": {"name": "Stars", "type": "rating", "max_value": 5}, + "boolean": {"name": "Active", "type": "boolean"}, + "date": {"name": "Due Date", "type": "date"}, + "link_row": { + "name": "Project", + "type": "link_row", + "linked_table": "Projects", + }, + "single_select": { + "name": "Status", + "type": "single_select", + "options": [{"value": "Open", "color": "green"}], + }, + "multiple_select": { + "name": "Tags", + "type": "multiple_select", + "options": [{"value": "Important", "color": "red"}], + }, + "file": {"name": "Attachment", "type": "file"}, + "formula": { + "name": "Total", + "type": "formula", + "formula": "field('Price') * 2", + }, + "lookup": { + "name": "Client Name", + "type": "lookup", + "linked_table": "Clients", + "target_field": "Name", + }, +} + + +# --------------------------------------------------------------------------- +# to_django_orm builders: (FieldItemCreate, Table, user | None) -> dict +# --------------------------------------------------------------------------- + + +def _resolve_linked_table(linked_table_ref, table): + """Resolve a linked_table reference (name or ID) to a Table object.""" 
+ + if isinstance(linked_table_ref, str): + q = Q(name=linked_table_ref, database=table.database) + else: + q = Q(id=linked_table_ref, database=table.database) + result = Table.objects.filter(q).order_by("id").first() + if not result: + raise ValueError( + f"Table '{linked_table_ref}' not found in the database. " + f"Ensure you provide a valid table name or ID." + ) + return result -class SingleSelectFieldItemCreate(BaseSingleSelectFieldItem): - options: list[SelectOptionCreate] = Field( - description="The list of options for the field. Use appropriate colors for each option.", - ) +def _simple_to_orm(f, table, user): + return {"name": f.name} - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "select_options": [ - {"id": -i, "value": option.value, "color": option.color} - for (i, option) in enumerate(self.options, start=1) - ], - } +def _long_text_to_orm(f, table, user): + return {"name": f.name, "long_text_enable_rich_text": f.rich_text} -class SingleSelectFieldItem(BaseSingleSelectFieldItem, FieldItem): - options: list[SelectOption] = Field( - description="The list of options for the field.", - ) - @classmethod - def from_django_orm( - cls, orm_field: SingleSelectField - ) -> "BaseSingleSelectFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="single_select", - options=[ - SelectOption( - id=opt.id, - value=opt.value, - color=opt.color, - ) - for opt in field.select_options.all() - ], - ) +def _number_to_orm(f, table, user): + return { + "name": f.name, + "number_decimal_places": f.decimal_places, + "number_suffix": f.suffix, + "number_negative": True, + } -class BaseMultipleSelectFieldItem(FieldItemCreate): - type: Literal["multiple_select"] = Field( - ..., - description="Multiple select field. 
Allows users to choose multiple options from a list.", - ) +def _rating_to_orm(f, table, user): + return {"name": f.name, "max_value": f.max_value} - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "select_options": [ - {"id": -i, "value": option.value, "color": option.color} - for (i, option) in enumerate(self.options, start=1) - ], - } +def _date_to_orm(f, table, user): + return {"name": f.name, "date_include_time": f.include_time} -class MultipleSelectFieldItemCreate(BaseMultipleSelectFieldItem): - options: list[SelectOptionCreate] = Field( - description="The list of options for the field. Use appropriate colors for each option.", - ) +def _link_row_to_orm(f, table, user): + linked = _resolve_linked_table(f.linked_table, table) + return {"name": f.name, "link_row_table": linked} -class MultipleSelectFieldItem(BaseMultipleSelectFieldItem, FieldItem): - options: list[SelectOption] = Field( - description="The list of options for the field.", - ) - @classmethod - def from_django_orm( - cls, orm_field: MultipleSelectField - ) -> "BaseMultipleSelectFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="multiple_select", - options=[ - SelectOption( - id=opt.id, - value=opt.value, - color=opt.color, - ) - for opt in field.select_options.all() - ], - ) +def _select_to_orm(f, table, user): + return { + "name": f.name, + "select_options": [ + { + "id": -i, + "value": opt.value, + "color": opt.color or _SELECT_COLORS[(i - 1) % len(_SELECT_COLORS)], + } + for i, opt in enumerate(f.options, start=1) + ], + } -class BaseFileFieldItem(FieldItemCreate): - type: Literal["file"] = Field(..., description="File field.") +def _formula_to_orm(f, table, user): + if f.formula: + from baserow.contrib.database.fields.models import FormulaField + from baserow.core.formula.parser.exceptions import BaserowFormulaException + try: + tmp = FormulaField(formula=f.formula, table=table, name=f.name, 
order=0) + tmp.recalculate_internal_fields(raise_if_invalid=True) + except BaserowFormulaException as e: + raise InvalidFormulaFieldError(f.name, f.formula, table, str(e)) -class FileFieldItemCreate(BaseFileFieldItem): - pass + return {"name": f.name, "formula": f.formula} -class FileFieldItem(BaseFileFieldItem, FieldItem): - pass +def _lookup_to_orm(f, table, user): + from baserow.contrib.database.fields.models import LinkRowField + linked = _resolve_linked_table(f.linked_table, table) -class FormulaFieldItemCreate(FieldItemCreate): - type: Literal["formula"] = Field(..., description="Formula field.") - formula: str = Field( - ..., - description="The formula to use in the field. It needs to be generated via the appropriate tool or use '' as placeholder.", + # Find existing link_row field pointing to linked table + through = ( + LinkRowField.objects.filter(table=table, link_row_table=linked) + .order_by("id") + .first() ) - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "formula": self.formula, - } + # Auto-create link_row if missing and user is available + if not through and user: + from baserow.contrib.database.fields.actions import CreateFieldActionType + through = CreateFieldActionType.do( + user, + table, + "link_row", + name=linked.name, + link_row_table=linked, + ) -class FormulaFieldItem(FormulaFieldItemCreate, FieldItem): - formula_type: str = Field(..., description="The type of the formula.") - array_formula_type: str | None = Field( - ..., - description=("If the formula type is 'array', the type of the array items."), - ) + if not through: + raise ValueError( + f"No link_row field to '{f.linked_table}' exists on this table. " + f"Create a link_row field first." 
+ ) - @classmethod - def from_django_orm(cls, orm_field: FormulaField) -> "FormulaFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="formula", - formula=field.formula, - formula_type=field.formula_type, - array_formula_type=field.array_formula_type, + data: dict[str, Any] = {"name": f.name, "through_field_id": through.id} + if isinstance(f.target_field, str): + data["target_field_name"] = f.target_field + else: + data["target_field_id"] = f.target_field + return data + + +_TO_DJANGO_ORM = { + "text": _simple_to_orm, + "boolean": _simple_to_orm, + "file": _simple_to_orm, + "long_text": _long_text_to_orm, + "number": _number_to_orm, + "rating": _rating_to_orm, + "date": _date_to_orm, + "link_row": _link_row_to_orm, + "single_select": _select_to_orm, + "multiple_select": _select_to_orm, + "formula": _formula_to_orm, + "lookup": _lookup_to_orm, +} + + +# --------------------------------------------------------------------------- +# from_django_orm builders: (orm_field) -> dict of extra kwargs +# --------------------------------------------------------------------------- + + +def _select_options_from_orm(orm_field): + from typing import get_args + + valid_colors = set(get_args(OptionColor)) + return [ + SelectOption( + id=opt.id, + value=opt.value, + color=opt.color if opt.color in valid_colors else "blue", ) + for opt in orm_field.specific.select_options.all() + ] + + +_FROM_DJANGO_ORM: dict[str, Any] = { + "long_text": lambda f: {"rich_text": f.specific.long_text_enable_rich_text}, + "number": lambda f: { + "decimal_places": f.number_decimal_places, + "suffix": f.number_suffix, + }, + "rating": lambda f: {"max_value": f.max_value}, + "date": lambda f: {"include_time": f.date_include_time}, + "link_row": lambda f: {"linked_table": f.link_row_table_id}, + "single_select": lambda f: {"options": _select_options_from_orm(f)}, + "multiple_select": lambda f: {"options": _select_options_from_orm(f)}, + "formula": lambda f: { + 
"formula": f.specific.formula, + "formula_type": f.specific.formula_type, + "array_formula_type": f.specific.array_formula_type, + }, + "lookup": lambda f: { + "through_field": f.specific.through_field_id, + "target_field": f.specific.target_field_id, + "through_field_name": f.specific.through_field_name, + "target_field_name": f.specific.target_field_name, + }, +} + + +# --------------------------------------------------------------------------- +# FieldItemCreate +# --------------------------------------------------------------------------- -class LookupFieldItemCreate(FieldItemCreate): - type: Literal["lookup"] = Field(..., description="Lookup field.") - through_field: int | str = Field( - ..., description="The ID of the link row field to lookup through." +class FieldItemCreate(BaseModel): + """Flat model for creating a field: name + type + type-specific options.""" + + name: str = Field(..., description="The name of the field.") + type: FieldType = Field(..., description="The field type.") + + # (long_text) + rich_text: bool = Field( + True, description="(long_text) Whether the field supports rich text." ) - target_field: int | str = Field( - ..., description="The ID of the field to lookup on the linked table." + # (number) + decimal_places: int = Field( + 0, description="(number) Decimal places (0, 1, 2, ...)." + ) + suffix: str = Field( + "", description="(number) Suffix displayed after the number, or ''." + ) + # (rating) + max_value: int = Field(5, description="(rating) Maximum rating value.") + # (date) + include_time: bool = Field( + False, description="(date) Whether the date includes time." 
+ ) + # (link_row, lookup) + linked_table: str | int | None = Field( + None, + description="(link_row, lookup) ID or name of the linked table.", + ) + # (single_select, multiple_select) + options: list[SelectOptionCreate] | None = Field( + None, + description="(single_select, multiple_select) List of options with colors.", + ) + # (formula) + formula: str = Field( + "", description="(formula) The formula expression, or '' as placeholder." + ) + # (lookup) + target_field: int | str | None = Field( + None, description="(lookup) ID or name of the field to look up." ) - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - data = {"name": self.name} - if isinstance(self.through_field, str): - data["through_field_name"] = self.through_field - else: - data["through_field_id"] = self.through_field - - if isinstance(self.target_field, str): - data["target_field_name"] = self.target_field - else: - data["target_field_id"] = self.target_field + @model_validator(mode="before") + @classmethod + def _normalize(cls, data): + if not isinstance(data, dict): + return data + + # Normalize type aliases + raw_type = data.get("type") + if isinstance(raw_type, str): + data["type"] = _TYPE_ALIASES.get(raw_type, raw_type) + + # Normalize key aliases + for old_key, new_key in _KEY_ALIASES.items(): + if old_key in data and new_key not in data: + data[new_key] = data.pop(old_key) + + # Convert string options to SelectOptionCreate dicts + if "options" in data and isinstance(data["options"], list): + normalized = [] + for i, opt in enumerate(data["options"]): + if isinstance(opt, str): + normalized.append( + {"value": opt, "color": _SELECT_COLORS[i % len(_SELECT_COLORS)]} + ) + else: + normalized.append(opt) + data["options"] = normalized return data + # Required fields per type: {type: [(attr_name, display_name), ...]} + _REQUIRED_FIELDS: dict[str, list[tuple[str, str]]] = { + "link_row": [("linked_table", "linked_table")], + "single_select": [("options", "options")], + 
"multiple_select": [("options", "options")], + "lookup": [("linked_table", "linked_table"), ("target_field", "target_field")], + } -class LookupFieldItem(LookupFieldItemCreate, FieldItem): - through_field_name: str = Field( - ..., description="The name of the link row field to lookup through." - ) - target_field_name: str = Field( - ..., description="The name of the field to lookup on the linked table." - ) + @model_validator(mode="after") + def _validate_required_for_type(self): + required = self._REQUIRED_FIELDS.get(self.type) + if required: + missing = [name for attr, name in required if not getattr(self, attr)] + if missing: + raise ValueError( + f"{self.type} requires {', '.join(missing)}. " + f"Example: {json.dumps(_FIELD_EXAMPLES[self.type])}" + ) + return self - @classmethod - def from_django_orm(cls, orm_field: LookupField) -> "LookupFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="lookup", - through_field=field.through_field_id, - target_field=field.target_field_id, - through_field_name=field.through_field_name, - target_field_name=field.target_field_name, - ) + def to_django_orm_kwargs(self, table: Table, user=None) -> dict[str, Any]: + builder = _TO_DJANGO_ORM.get(self.type, _simple_to_orm) + return builder(self, table, user) -AnyFieldItemCreate = Annotated[ - TextFieldItemCreate - | LongTextFieldItemCreate - | NumberFieldItemCreate - | RatingFieldItemCreate - | BooleanFieldItemCreate - | DateFieldItemCreate - | LinkRowFieldItemCreate - | SingleSelectFieldItemCreate - | MultipleSelectFieldItemCreate - | FileFieldItemCreate - | FormulaFieldItemCreate - | LookupFieldItemCreate, - Field(discriminator="type"), -] +# --------------------------------------------------------------------------- +# FieldItem (read-back) +# --------------------------------------------------------------------------- -AnyFieldItem = ( - TextFieldItem - | LongTextFieldItem - | NumberFieldItem - | RatingFieldItem - | BooleanFieldItem - 
| DateFieldItem - | LinkRowFieldItem - | SingleSelectFieldItem - | MultipleSelectFieldItem - | FileFieldItem - | FormulaFieldItem - | LookupFieldItem - | FieldItem -) - - -class FieldItemsRegistry: - _registry = { - "text": TextFieldItem, - "long_text": LongTextFieldItem, - "number": NumberFieldItem, - "date": DateFieldItem, - "boolean": BooleanFieldItem, - "rating": RatingFieldItem, - "link_row": LinkRowFieldItem, - "single_select": SingleSelectFieldItem, - "multiple_select": MultipleSelectFieldItem, - "file": FileFieldItem, - "formula": FormulaFieldItem, - "lookup": LookupFieldItem, - } - def from_django_orm(self, orm_field: Type[BaserowField]) -> FieldItem: +class FieldItem(BaseModel): + """Existing field with ID — flat structure matching FieldItemCreate.""" + + id: int = Field(...) + name: str = Field(..., description="The name of the field.") + type: str = Field(..., description="The field type.") + + # Type-specific (populated per type, others excluded via exclude_none) + rich_text: bool | None = None + decimal_places: int | None = None + suffix: str | None = None + max_value: int | None = None + include_time: bool | None = None + linked_table: int | None = None + options: list[SelectOption] | None = None + formula: str | None = None + formula_type: str | None = None + array_formula_type: str | None = None + through_field: int | None = None + target_field: int | None = None + through_field_name: str | None = None + target_field_name: str | None = None + + @model_serializer(mode="wrap") + def _exclude_none(self, handler): + return {k: v for k, v in handler(self).items() if v is not None} + + @classmethod + def from_django_orm(cls, orm_field: BaserowField) -> "FieldItem": field_type = field_type_registry.get_by_model(orm_field).type - field_class: FieldItem = self._registry.get(field_type, FieldItem) - return field_class.from_django_orm(orm_field) + kwargs: dict[str, Any] = { + "id": orm_field.id, + "name": orm_field.name, + "type": field_type, + } + builder = 
_FROM_DJANGO_ORM.get(field_type) + if builder: + kwargs.update(builder(orm_field)) + return cls(**kwargs) + + +# --------------------------------------------------------------------------- +# FieldItemUpdate +# --------------------------------------------------------------------------- + + +def _update_simple(f, field_type): + kwargs = {} + if f.name is not None: + kwargs["name"] = f.name + return kwargs + + +def _update_long_text(f, field_type): + kwargs = _update_simple(f, field_type) + if f.rich_text is not None: + kwargs["long_text_enable_rich_text"] = f.rich_text + return kwargs + + +def _update_number(f, field_type): + kwargs = _update_simple(f, field_type) + if f.decimal_places is not None: + kwargs["number_decimal_places"] = f.decimal_places + if f.suffix is not None: + kwargs["number_suffix"] = f.suffix + return kwargs + + +def _update_rating(f, field_type): + kwargs = _update_simple(f, field_type) + if f.max_value is not None: + kwargs["max_value"] = f.max_value + return kwargs + + +def _update_date(f, field_type): + kwargs = _update_simple(f, field_type) + if f.include_time is not None: + kwargs["date_include_time"] = f.include_time + return kwargs + + +def _update_select(f, field_type): + kwargs = _update_simple(f, field_type) + if f.options is not None: + kwargs["select_options"] = [ + { + "id": -i, + "value": opt.value, + "color": opt.color or _SELECT_COLORS[(i - 1) % len(_SELECT_COLORS)], + } + for i, opt in enumerate(f.options, start=1) + ] + return kwargs + + +def _update_formula(f, field_type): + kwargs = _update_simple(f, field_type) + if f.formula is not None: + kwargs["formula"] = f.formula + return kwargs + + +_TO_UPDATE_ORM = { + "text": _update_simple, + "boolean": _update_simple, + "file": _update_simple, + "long_text": _update_long_text, + "number": _update_number, + "rating": _update_rating, + "date": _update_date, + "link_row": _update_simple, + "single_select": _update_select, + "multiple_select": _update_select, + "formula": 
_update_formula, + "lookup": _update_simple, +} + + +class FieldItemUpdate(BaseModel): + """Flat model for updating a field: field_id + optional type-specific fields.""" + + field_id: int = Field(..., description="The ID of the field to update.") + name: str | None = Field(None, description="New name for the field.") + + # (long_text) + rich_text: bool | None = Field( + None, description="(long_text) Whether the field supports rich text." + ) + # (number) + decimal_places: int | None = Field( + None, description="(number) Decimal places (0, 1, 2, ...)." + ) + suffix: str | None = Field( + None, description="(number) Suffix displayed after the number." + ) + # (rating) + max_value: int | None = Field(None, description="(rating) Maximum rating value.") + # (date) + include_time: bool | None = Field( + None, description="(date) Whether the date includes time." + ) + # (single_select, multiple_select) + options: list[SelectOptionCreate] | None = Field( + None, + description="(single_select, multiple_select) List of options with colors.", + ) + # (formula) + formula: str | None = Field(None, description="(formula) The formula expression.") + @model_validator(mode="before") + @classmethod + def _normalize_keys(cls, data): + if not isinstance(data, dict): + return data + for old_key, new_key in _KEY_ALIASES.items(): + if old_key in data and new_key not in data: + data[new_key] = data.pop(old_key) + # Convert string options to SelectOptionCreate dicts + if "options" in data and isinstance(data["options"], list): + normalized = [] + for i, opt in enumerate(data["options"]): + if isinstance(opt, str): + normalized.append( + {"value": opt, "color": _SELECT_COLORS[i % len(_SELECT_COLORS)]} + ) + else: + normalized.append(opt) + data["options"] = normalized + return data -field_item_registry = FieldItemsRegistry() + def to_update_kwargs(self, field_type: str) -> dict[str, Any]: + """Build kwargs for UpdateFieldActionType.do() based on the field's current type.""" + builder = 
_TO_UPDATE_ORM.get(field_type, _update_simple) + return builder(self, field_type) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py new file mode 100644 index 0000000000..f393275804 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py @@ -0,0 +1,383 @@ +""" +Dynamic Pydantic models for table row CRUD. + +Builds per-table create and update models whose fields mirror the table's +database columns, with converters to/from Django ORM representations. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Literal, Type + +from django.core.exceptions import ValidationError +from django.db.models import Q + +from pydantic import ConfigDict, Field, create_model + +from baserow.contrib.database.fields.field_types import LinkRowFieldType +from baserow.contrib.database.fields.models import SelectOption as OrmSelectOption +from baserow.contrib.database.table.models import ( + FieldObject, + GeneratedTableModel, + Table, +) +from baserow_enterprise.assistant.types import BaseModel + +from .base import format_date, format_datetime, parse_date, parse_datetime + + +@dataclass +class FieldDefinition: + """ + Pydantic field specification for a single table column. + + When ``type`` is None the field is unsupported and will be skipped + during model construction. 
+ """ + + type: Type | None = None + field_def: Any | None = None + to_django_orm: Callable[[Any], Any] | None = None + from_django_orm: Callable[[Any], Any] | None = None + + +# --------------------------------------------------------------------------- +# Per-type builder functions +# --------------------------------------------------------------------------- + +# Shared converters for text-like fields +_none_to_empty = lambda v: v if v is not None else "" # noqa: E731 + + +def _text_field_def(orm_field, orm_field_type): + return FieldDefinition( + str | None, + Field(..., description="Single-line text", title=orm_field.name), + _none_to_empty, + _none_to_empty, + ) + + +def _long_text_field_def(orm_field, orm_field_type): + return FieldDefinition( + str | None, + Field(..., description="Multi-line text", title=orm_field.name), + _none_to_empty, + _none_to_empty, + ) + + +def _number_field_def(orm_field, orm_field_type): + return FieldDefinition( + float | None, + Field(..., description="Number or None", title=orm_field.name), + ) + + +def _boolean_field_def(orm_field, orm_field_type): + return FieldDefinition( + bool, Field(..., description="Boolean", title=orm_field.name) + ) + + +def _date_field_def(orm_field, orm_field_type): + if orm_field.date_include_time: + return FieldDefinition( + str | None, + Field( + ..., + description="ISO datetime (YYYY-MM-DDTHH:MM) or None", + title=orm_field.name, + ), + lambda v: parse_datetime(v).isoformat() if v else None, + lambda v: format_datetime(v) if v is not None else None, + ) + return FieldDefinition( + str | None, + Field(..., description="ISO date (YYYY-MM-DD) or None", title=orm_field.name), + lambda v: parse_date(v).isoformat() if v else None, + lambda v: format_date(v) if v is not None else None, + ) + + +def _single_select_field_def(orm_field, orm_field_type): + choices = [option.value for option in orm_field.select_options.all()] + if not choices: + return FieldDefinition() # Unsupported: no options defined + + 
return FieldDefinition( + Literal[*choices] | None, + Field( + ..., + description=f"One of: {', '.join(choices)} or None", + title=orm_field.name, + ), + lambda v: v if v in choices else None, + lambda v: v.value if isinstance(v, OrmSelectOption) else v, + ) + + +def _multiple_select_field_def(orm_field, orm_field_type): + choices = [option.value for option in orm_field.select_options.all()] + if not choices: + return FieldDefinition() # Unsupported: no options defined + + return FieldDefinition( + list[Literal[*choices]], + Field( + ..., + description=f"List of any of: {', '.join(choices)} or empty list", + title=orm_field.name, + ), + lambda v: [opt for opt in v if opt in choices], + lambda v: [opt.value for opt in v.all()] if v is not None else [], + ) + + +def _link_row_field_def(orm_field, orm_field_type): + linked_model = orm_field.link_row_table.get_model() + linked_primary_key = linked_model.get_primary_field() + if linked_primary_key is None: + return FieldDefinition() + + linked_pk = linked_primary_key.db_column + examples = list( + linked_model.objects.exclude( + Q(**{f"{linked_pk}__isnull": True}) | Q(**{f"{linked_pk}__exact": ""}) + ).values_list("id", linked_pk)[:10] + ) + + def to_django_orm(value): + if isinstance(value, (str, int)): + value = [value] + if value is not None: + try: + return LinkRowFieldType().prepare_value_for_db(orm_field, value) + except ValidationError: + pass + return [] + + def from_django_orm(value): + values = [str(v) for v in value.all()] + if orm_field.link_row_multiple_relationships: + return values + return values[0] if values else None + + if orm_field.link_row_multiple_relationships: + desc = "List of values (as strings) or IDs (as integers) from the linked table or empty list." + field_type = list[str | int] | None + else: + desc = "Single value (as string) or ID (as integer) from the linked table." 
+ field_type = str | int | None + if examples: + desc += ( + " Examples: " + + ", ".join(f"{{id:{v[0]}, value: `{v[1]}`}}" for v in examples) + + ", .." + ) + return FieldDefinition( + field_type, + Field(..., description=desc, title=orm_field.name), + to_django_orm, + from_django_orm, + ) + + +_FIELD_DEF_BUILDERS: dict[str, Callable] = { + "text": _text_field_def, + "long_text": _long_text_field_def, + "number": _number_field_def, + "boolean": _boolean_field_def, + "date": _date_field_def, + "single_select": _single_select_field_def, + "multiple_select": _multiple_select_field_def, + "link_row": _link_row_field_def, +} + + +def get_field_definition(field_object: FieldObject) -> FieldDefinition: + """ + Return a :class:`FieldDefinition` for a table field, or an empty + (unsupported) definition if the field type has no registered builder. + """ + + orm_field_type = field_object["type"] + builder = _FIELD_DEF_BUILDERS.get(orm_field_type.type) + if builder is None: + return FieldDefinition() + return builder(field_object["field"], orm_field_type) + + +# --------------------------------------------------------------------------- +# Helpers shared by create / update models +# --------------------------------------------------------------------------- + +# field_conversions maps field names to (db_column, to_orm, from_orm) tuples. +FieldConversions = dict[str, tuple[str, Callable | None, Callable | None]] + + +def _scan_table_fields( + table: Table, field_ids: list[int] | None = None +) -> tuple[dict[str, tuple], FieldConversions]: + """ + Scan a table's fields and return Pydantic field specs plus ORM converters. + + :param table: The table to scan. + :param field_ids: If given, only include fields with these IDs. + :returns: ``(field_definitions, field_conversions)`` dicts keyed by field name. 
+ """ + + field_definitions: dict[str, tuple] = {} + field_conversions: FieldConversions = {} + + for field_object in table.get_model().get_field_objects(): + fd = get_field_definition(field_object) + if fd.type is None: + continue + if field_ids is not None and field_object["field"].id not in field_ids: + continue + + field = field_object["field"] + field_definitions[field.name] = (fd.type, fd.field_def) + field_conversions[field.name] = ( + field.db_column, + fd.to_django_orm, + fd.from_django_orm, + ) + + return field_definitions, field_conversions + + +def _convert_fields( + items: dict[str, Any], field_conversions: FieldConversions +) -> dict[str, Any]: + """Convert a {field_name: value} mapping to {db_column: orm_value}.""" + + orm_data: dict[str, Any] = {} + for key, value in items.items(): + if key == "id": + orm_data["id"] = value + continue + if key not in field_conversions: + continue + orm_key, converter, _ = field_conversions[key] + orm_data[orm_key] = converter(value) if converter else value + return orm_data + + +# --------------------------------------------------------------------------- +# Row models +# --------------------------------------------------------------------------- + + +def get_create_row_model( + table: Table, field_ids: list[int] | None = None +) -> type[BaseModel]: + """ + Build a Pydantic model for creating rows in the given table. + + The returned model has a field for each supported column, with + ``to_django_orm()`` and ``from_django_orm()`` for ORM conversion. + + :param table: The table whose columns define the model fields. + :param field_ids: If given, only include these field IDs. 
+ """ + + field_definitions, field_conversions = _scan_table_fields(table, field_ids) + + class CreateRowModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_django_orm(self) -> dict[str, Any]: + return _convert_fields(self.__dict__, field_conversions) + + @classmethod + def from_django_orm( + cls, orm_row: GeneratedTableModel, field_ids: list[int] | None = None + ) -> "CreateRowModel": + init_data = {} + if "id" in cls.model_fields: + init_data["id"] = orm_row.id + for field_object in orm_row.get_field_objects(): + field = field_object["field"] + if field.name not in field_conversions: + continue + if field_ids is not None and field.id not in field_ids: + continue + db_column, _, from_django_orm = field_conversions[field.name] + value = getattr(orm_row, db_column) + init_data[field.name] = ( + from_django_orm(value) if from_django_orm else value + ) + return cls(**init_data) + + return create_model( + f"Table{table.id}Row", + __module__=__name__, + __base__=CreateRowModel, + **field_definitions, + ) + + +def get_update_row_model(table: Table) -> type[BaseModel]: + """ + Build a Pydantic model for updating rows in the given table. + + All fields are optional with ``default=None``; only fields explicitly + provided during construction are included in ``to_django_orm()`` output, + so omitting a field means "don't change". + + :param table: The table whose columns define the model fields. 
+ """ + + create_model_class = get_create_row_model(table) + _, field_conversions = _scan_table_fields(table) + + # All fields become Optional with default=None + update_fields = { + name: ( + info.annotation | None, + Field(default=None, description=info.description, title=info.title), + ) + for name, info in create_model_class.model_fields.items() + } + update_fields["id"] = (int, Field(..., description="The ID of the row to update")) + + class UpdateRowModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_django_orm(self) -> dict[str, Any]: + # Only convert explicitly provided fields (pydantic tracks this) + explicitly_set = { + k: getattr(self, k) for k in self.model_fields_set if k != "id" + } + orm_data = _convert_fields(explicitly_set, field_conversions) + orm_data["id"] = self.id + return orm_data + + return create_model( + f"UpdateTable{table.id}Row", + __module__=__name__, + __base__=UpdateRowModel, + **update_fields, + ) + + +def get_link_row_hints(row_model: type[BaseModel]) -> str: + """ + Collect link_row example hints from a row model's field descriptions. + + Returns a formatted string for inclusion in tool descriptions, or an + empty string if no link_row fields with examples are found. + + :param row_model: A row model built by :func:`get_create_row_model`. + """ + + hints: list[str] = [] + for name, info in row_model.model_fields.items(): + desc = info.description or "" + if "linked table" in desc and "Examples:" in desc: + hints.append(f"{name} ({info.title}): {desc}") + + if not hints: + return "" + return " LINK_ROW fields: " + "; ".join(hints) + "." 
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py index 3703e877bc..8203c162dd 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py @@ -1,10 +1,12 @@ +import json + from django.db.models import Q -from pydantic import Field +from pydantic import Field, ValidationError, model_validator from baserow_enterprise.assistant.types import BaseModel -from .fields import AnyFieldItem, AnyFieldItemCreate +from .fields import _FIELD_EXAMPLES, _TYPE_ALIASES, FieldItem, FieldItemCreate class BaseTableItemCreate(BaseModel): @@ -22,50 +24,91 @@ class BaseTableItem(BaseTableItemCreate): class TableItemCreate(BaseTableItemCreate): """Model for creating a table with fields.""" - primary_field: AnyFieldItemCreate = Field( + primary_field_name: str = Field( ..., - description="The primary field of the table. Preferbly a text field with a sensible name for a primary field of the table.", - ) - fields: list[AnyFieldItemCreate] = Field( - ..., description="The fields of the table." 
+ description="The name of the primary field (text field).", ) + fields: list[FieldItemCreate] = Field(..., description="The fields of the table.") + + @model_validator(mode="wrap") + @classmethod + def _validate_with_field_examples(cls, data, handler): + try: + return handler(data) + except ValidationError as exc: + if not isinstance(data, dict): + raise + + table_name = data.get("name", "unknown") + fields_data = data.get("fields", []) + if not isinstance(fields_data, list): + raise + + # Collect field indices that have errors + error_field_indices: set[int] = set() + for error in exc.errors(): + loc = error.get("loc", ()) + if len(loc) >= 2 and loc[0] == "fields" and isinstance(loc[1], int): + error_field_indices.add(loc[1]) + + if not error_field_indices: + raise # No field-level errors, re-raise as-is + + error_fields = [] + error_types: set[str] = set() + for idx in sorted(error_field_indices): + if idx < len(fields_data) and isinstance(fields_data[idx], dict): + fd = fields_data[idx] + fname = fd.get("name", f"fields[{idx}]") + ftype = str(fd.get("type", "unknown")) + ftype = _TYPE_ALIASES.get(ftype, ftype) + error_fields.append(f"'{fname}' ({ftype})") + if ftype in _FIELD_EXAMPLES: + error_types.add(ftype) + + if not error_fields: + raise + + parts = [ + f"Table '{table_name}': invalid fields: {', '.join(error_fields)}." + ] + for ft in sorted(error_types): + parts.append(f" {ft}: {json.dumps(_FIELD_EXAMPLES[ft])}") + + raise ValueError("\n".join(parts)) from None class TableItem(BaseTableItem): """Model for an existing table with fields.""" - primary_field: AnyFieldItem = Field( - ..., description="The primary field of the table." 
- ) - fields: list[AnyFieldItem] = Field(..., description="The fields of the table.") + primary_field: FieldItem = Field(..., description="The primary field of the table.") + fields: list[FieldItem] = Field(..., description="The fields of the table.") class ListTablesFilterArg(BaseModel): - database_ids: list[int] | None = Field( - default=None, - description="A list of database_ids to filter. None to exclude this filter", - ) - database_names: list[str] | None = Field( - default=None, - description="A list of database_names to filter. None to exclude this filter", - ) - table_ids: list[int] | None = Field( + database_id_or_name: int | str | None = Field( default=None, - description="A list of table ids to filter. None to exclude this filter", + description="The ID or name of the database to filter. null to exclude this filter.", ) - table_names: list[str] | None = Field( + table_ids_or_names: list[int | str] | None = Field( default=None, - description="A list of table names to filter. None to exclude this filter", + description="A list of table ids or names to filter in an OR fashion. 
null to exclude this filter.", ) def to_orm_filter(self) -> Q: q_filter = Q() - if self.database_ids: - q_filter &= Q(database_id__in=self.database_ids) - if self.database_names: - q_filter &= Q(database__name__in=self.database_names) - if self.table_ids: - q_filter &= Q(id__in=self.table_ids) - if self.table_names: - q_filter &= Q(name__in=self.table_names) + if isinstance(self.database_id_or_name, int): + q_filter &= Q(database_id=self.database_id_or_name) + elif isinstance(self.database_id_or_name, str): + q_filter &= Q(database__name__icontains=self.database_id_or_name) + if self.table_ids_or_names: + combined = Q() + ids = [item for item in self.table_ids_or_names if isinstance(item, int)] + names = [item for item in self.table_ids_or_names if isinstance(item, str)] + if ids: + combined |= Q(id__in=ids) + if names: + for name in names: + combined |= Q(name__icontains=name) + q_filter &= combined return q_filter diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py index 343fc14773..f80a9875c3 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py @@ -1,558 +1,238 @@ from typing import Literal -from pydantic import Field +from pydantic import Field, model_validator -from baserow_enterprise.assistant.types import Annotated, BaseModel +from baserow_enterprise.assistant.types import BaseModel -from .base import Date +from .base import parse_date +# --------------------------------------------------------------------------- +# Flat filter model +# --------------------------------------------------------------------------- -class ViewFilterItemCreate(BaseModel): - """Model for creating a new view filter (no ID).""" - - field_id: int = Field(...) - type: str = Field(...) 
- operator: str = Field(...) - value: str = Field(...) - - def get_django_orm_type(self, field, **kwargs) -> str: - return self.operator - - def get_django_orm_value(self, field, **kwargs) -> str: - return self.value - - -class ViewFilterItem(ViewFilterItemCreate): - """Model for an existing view filter (with ID).""" - - id: int = Field(..., description="The unique identifier of the view filter.") - - -class TextViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["text"] = Field(..., description="A text filter.") - value: str = Field(..., description="The text value to filter on.") - - -class TextEqualViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["equal"] = Field( - ..., description="Checks if the field is equal to the value." - ) - - -class TextEqualViewFilterItem(TextEqualViewFilterItemCreate, ViewFilterItem): - pass - - -class TextNotEqualViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["not_equal"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - -class TextNotEqualViewFilterItem(TextNotEqualViewFilterItemCreate, ViewFilterItem): - pass - - -class TextContainsViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["contains"] = Field( - ..., description="Checks if the field contains the value." - ) - - -class TextContainsViewFilterItem(TextContainsViewFilterItemCreate, ViewFilterItem): - pass - - -class TextNotContainsViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["contains_not"] = Field( - ..., description="Checks if the field does not contain the value." 
- ) - - -class TextNotContainsViewFilterItem( - TextNotContainsViewFilterItemCreate, ViewFilterItem -): - pass - - -class TextEmptyViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["empty"] = Field(..., description="Checks if the field is empty.") - - -class TextEmptyViewFilterItem(TextEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -class TextNotEmptyViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["not_empty"] = Field( - ..., description="Checks if the field is not empty." - ) - - -class TextNotEmptyViewFilterItem(TextNotEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -AnyTextViewFilterItemCreate = Annotated[ - TextEqualViewFilterItemCreate - | TextNotEqualViewFilterItemCreate - | TextContainsViewFilterItemCreate - | TextNotContainsViewFilterItemCreate - | TextEmptyViewFilterItemCreate - | TextNotEmptyViewFilterItemCreate, - Field(discriminator="operator"), +FilterType = Literal[ + "text", "number", "date", "single_select", "multiple_select", "link_row", "boolean" ] -AnyTextViewFilterItem = Annotated[ - TextEqualViewFilterItem - | TextNotEqualViewFilterItem - | TextContainsViewFilterItem - | TextNotContainsViewFilterItem - | TextEmptyViewFilterItem - | TextNotEmptyViewFilterItem, - Field(discriminator="operator"), +_OPERATORS: dict[str, tuple[str, ...]] = { + "text": ("equal", "not_equal", "contains", "contains_not", "empty", "not_empty"), + "number": ("equal", "not_equal", "higher_than", "lower_than", "empty", "not_empty"), + "date": ("equal", "not_equal", "after", "before"), + "single_select": ("is_any_of", "is_none_of"), + "multiple_select": ("is_any_of", "is_none_of"), + "link_row": ("has", "has_not"), + "boolean": ("equal",), +} + +# Operator aliases: normalize LLM-natural names to Baserow names before validation. 
+_OPERATOR_ALIASES: dict[str, str] = { + "equals": "equal", + "is": "equal", + "not_equals": "not_equal", + "is_not": "not_equal", + "greater_than": "higher_than", + "greater_than_or_equal": "higher_than", # or_equal flag handles the rest + "less_than": "lower_than", + "less_than_or_equal": "lower_than", # or_equal flag handles the rest + "gte": "higher_than", + "lte": "lower_than", + "gt": "higher_than", + "lt": "lower_than", + "neq": "not_equal", + "ne": "not_equal", + "eq": "equal", +} + +DateFilterMode = Literal[ + "today", + "yesterday", + "tomorrow", + "this_week", + "last_week", + "next_week", + "this_month", + "last_month", + "next_month", + "this_year", + "last_year", + "next_year", + "nr_days_ago", + "nr_days_from_now", + "nr_weeks_ago", + "nr_weeks_from_now", + "nr_months_ago", + "nr_months_from_now", + "nr_years_ago", + "nr_years_from_now", + "exact_date", ] -class NumberViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["number"] = Field(..., description="A number filter.") - value: float = Field(..., description="The number value to filter on.") - - def get_django_orm_value(self, field, **kwargs) -> str: - return str(self.value) - - -class NumberViewFilterItem(NumberViewFilterItemCreate, ViewFilterItem): - pass +# --------------------------------------------------------------------------- +# ORM type dispatch: (filter, field, **kwargs) -> str +# --------------------------------------------------------------------------- + +_NUMBER_OR_EQUAL = { + "higher_than": "higher_than_or_equal", + "lower_than": "lower_than_or_equal", +} + +_DATE_ORM_TYPE = { + "equal": "date_is", + "not_equal": "date_is_not", + "after": "date_is_after", + "before": "date_is_before", +} + +_DATE_OR_EQUAL = { + "after": "date_is_on_or_after", + "before": "date_is_on_or_before", +} + +_SINGLE_SELECT_ORM_TYPE = { + "is_any_of": "single_select_is_any_of", + "is_none_of": "single_select_is_none_of", +} + +_MULTIPLE_SELECT_ORM_TYPE = { + "is_any_of": "multiple_select_has", + 
"is_none_of": "multiple_select_has_not", +} + +_LINK_ROW_ORM_TYPE = { + "has": "link_row_has", + "has_not": "link_row_has_not", +} + +_GET_ORM_TYPE = { + "text": lambda f, field, **kw: f.operator, + "number": lambda f, field, **kw: ( + _NUMBER_OR_EQUAL.get(f.operator, f.operator) if f.or_equal else f.operator + ), + "date": lambda f, field, **kw: ( + _DATE_OR_EQUAL[f.operator] + if f.or_equal and f.operator in _DATE_OR_EQUAL + else _DATE_ORM_TYPE[f.operator] + ), + "single_select": lambda f, field, **kw: _SINGLE_SELECT_ORM_TYPE[f.operator], + "multiple_select": lambda f, field, **kw: _MULTIPLE_SELECT_ORM_TYPE[f.operator], + "link_row": lambda f, field, **kw: _LINK_ROW_ORM_TYPE[f.operator], + "boolean": lambda f, field, **kw: "equal", +} + + +# --------------------------------------------------------------------------- +# ORM value dispatch: (filter, field, **kwargs) -> str +# --------------------------------------------------------------------------- + + +def _select_orm_value(f, field, **kwargs): + values = set(v.lower() for v in f.value) + valid_option_ids = [ + option.id + for option in field.select_options.all() + if option.value.lower() in values + ] + return ",".join(str(v) for v in valid_option_ids) + + +def _date_orm_value(f, field, **kwargs): + timezone = kwargs.get("timezone", "UTC") + if isinstance(f.value, str): + value = parse_date(f.value).isoformat() + elif isinstance(f.value, int): + value = str(f.value) + else: + value = "" + return f"{timezone}?{value}?{f.mode}" + + +_GET_ORM_VALUE = { + "text": lambda f, field, **kw: f.value + if isinstance(f.value, str) + else str(f.value or ""), + "number": lambda f, field, **kw: str(f.value), + "date": _date_orm_value, + "single_select": _select_orm_value, + "multiple_select": _select_orm_value, + "link_row": lambda f, field, **kw: str(f.value), + "boolean": lambda f, field, **kw: "1" if f.value else "0", +} + + +# --------------------------------------------------------------------------- +# 
ViewFilterItemCreate +# --------------------------------------------------------------------------- -class NumberEqualsViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["equal"] = Field( - ..., description="Checks if the field is equal to the value." - ) - - -class NumberEqualsViewFilterItem(NumberEqualsViewFilterItemCreate, ViewFilterItem): - pass - - -class NumberNotEqualsViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["not_equal"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - -class NumberNotEqualsViewFilterItem( - NumberNotEqualsViewFilterItemCreate, ViewFilterItem -): - pass - - -class NumberHigherThanViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["higher_than"] = Field( - ..., description="Checks if the field is higher than the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is higher than or equal to the value.", - ) - - -class NumberHigherThanViewFilterItem( - NumberHigherThanViewFilterItemCreate, ViewFilterItem -): - pass - - -class NumberLowerThanViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["lower_than"] = Field( - ..., description="Checks if the field is lower than the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is lower than or equal to the value.", - ) - - -class NumberLowerThanViewFilterItem( - NumberLowerThanViewFilterItemCreate, ViewFilterItem -): - pass - - -class NumberEmptyViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["empty"] = Field(..., description="Checks if the field is empty.") - - -class NumberEmptyViewFilterItem(NumberEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -class NumberNotEmptyViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["not_empty"] = Field( - ..., description="Checks if the field is not empty." 
- ) - - -class NumberNotEmptyViewFilterItem(NumberNotEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -AnyNumberViewFilterItemCreate = Annotated[ - NumberEqualsViewFilterItemCreate - | NumberNotEqualsViewFilterItemCreate - | NumberHigherThanViewFilterItemCreate - | NumberLowerThanViewFilterItemCreate - | NumberEmptyViewFilterItemCreate - | NumberNotEmptyViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnyNumberViewFilterItem = Annotated[ - NumberEqualsViewFilterItem - | NumberNotEqualsViewFilterItem - | NumberHigherThanViewFilterItem - | NumberLowerThanViewFilterItem - | NumberEmptyViewFilterItem - | NumberNotEmptyViewFilterItem, - Field(discriminator="operator"), -] - +class ViewFilterItemCreate(BaseModel): + """Flat model for creating a view filter: field_id + type + operator + value.""" -class DateViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["date"] = Field(..., description="A date filter.") - value: Date | int | None = Field( + field_id: int = Field(..., description="Field ID to filter on.") + type: FilterType = Field(..., description="Must match field type.") + operator: str = Field( ..., - description="\n".join( - [ - "The date value to filter on.", - "Use an integer for days/weeks/months/years ago/from now.", - "Use a date object for an exact date.", - "None otherwise.", - ] + description=( + "Filter operator. " + "text: equal/not_equal/contains/contains_not/empty/not_empty. " + "number: equal/not_equal/greater_than/less_than/empty/not_empty " + "(use or_equal=true for ≥/≤). " + "date: equal/not_equal/after/before (use or_equal=true for on_or_after/on_or_before). " + "single_select/multiple_select: is_any_of/is_none_of. " + "link_row: has/has_not. " + "boolean: equal." 
), ) - mode: Literal[ - "today", - "yesterday", - "tomorrow", - "this_week", - "last_week", - "next_week", - "this_month", - "last_month", - "next_month", - "this_year", - "last_year", - "next_year", - "nr_days_ago", - "nr_days_from_now", - "nr_weeks_ago", - "nr_weeks_from_now", - "nr_months_ago", - "nr_months_from_now", - "nr_years_ago", - "nr_years_from_now", - "exact_date", - ] = Field( - "exact_date", - description="The mode to use for the date filter. ALWAYS use the right mode if available. Use 'exact_date' if you have an exact date.", - ) - - def get_django_orm_value(self, field, **kwargs) -> str: - timezone = kwargs.get("timezone", "UTC") - - if isinstance(self.value, Date): - value = self.value.to_django_orm() - elif isinstance(self.value, int): - value = str(self.value) - else: - value = "" - - return f"{timezone}?{value}?{self.mode}" - - -class DateEqualsViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["equal"] = Field( - ..., description="Checks if the field is equal to the value." - ) + value: str | float | int | bool | list[str] | None = Field( + None, + description="Filter value (type-dependent).", + ) + mode: DateFilterMode | None = Field(None, description="(date) Date filter mode.") + or_equal: bool = Field(False, description="(number, date) Include equal values.") + + @model_validator(mode="before") + @classmethod + def _normalize_operator(cls, data): + if isinstance(data, dict) and "operator" in data: + op = data["operator"] + normalized = _OPERATOR_ALIASES.get(op) + if normalized: + data = dict(data) + data["operator"] = normalized + # Auto-set or_equal for _or_equal variants + if "or_equal" in op: + data.setdefault("or_equal", True) + return data + + @model_validator(mode="after") + def _validate_per_type(self): + valid = _OPERATORS.get(self.type) + if valid and self.operator not in valid: + raise ValueError( + f"Invalid operator '{self.operator}' for type '{self.type}'. 
" + f"Valid operators: {', '.join(valid)}" + ) + if self.type == "date" and self.mode is None: + raise ValueError("date filter requires 'mode'.") + return self def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is" - - -class DateEqualsViewFilterItem(DateEqualsViewFilterItemCreate, ViewFilterItem): - pass - - -class DateNotEqualsViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["not_equal"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is_not" - - -class DateNotEqualsViewFilterItem(DateNotEqualsViewFilterItemCreate, ViewFilterItem): - pass - - -class DateAfterViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["after"] = Field( - ..., description="Checks if the field is after the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is after or equal to the value.", - ) - - def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is_on_or_after" if self.or_equal else "date_is_after" - - -class DateAfterViewFilterItem(DateAfterViewFilterItemCreate, ViewFilterItem): - pass - - -class DateBeforeViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["before"] = Field( - ..., description="Checks if the field is before the value." 
- ) - or_equal: bool = Field( - False, - description="If true, checks if the field is before or equal to the value.", - ) - - def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is_on_or_before" if self.or_equal else "date_is_before" - - -class DateBeforeViewFilterItem(DateBeforeViewFilterItemCreate, ViewFilterItem): - pass - - -AnyDateViewFilterItemCreate = Annotated[ - DateEqualsViewFilterItemCreate - | DateNotEqualsViewFilterItemCreate - | DateAfterViewFilterItemCreate - | DateBeforeViewFilterItemCreate, - Field(discriminator="operator"), -] -AnyDateViewFilterItem = Annotated[ - DateEqualsViewFilterItem - | DateNotEqualsViewFilterItem - | DateAfterViewFilterItem - | DateBeforeViewFilterItem, - Field(discriminator="operator"), -] - - -class SingleSelectViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["single_select"] = Field(..., description="A single select filter.") - value: list[str] = Field( - ..., description="The select option value(s) to filter on." - ) - - def get_django_orm_value(self, field, **kwargs) -> str: - values = set(v.lower() for v in self.value) - valid_option_ids = [ - option.id - for option in field.select_options.all() - if option.value.lower() in values - ] - return ",".join([str(v) for v in valid_option_ids]) - - -class SingleSelectIsAnyViewFilterItemCreate(SingleSelectViewFilterItemCreate): - operator: Literal["is_any_of"] = Field( - ..., description="Checks if the field is equal to any of the values " - ) - - def get_django_orm_type(self, field, **kwargs): - return "single_select_is_any_of" - - -class SingleSelectIsAnyViewFilterItem( - SingleSelectIsAnyViewFilterItemCreate, ViewFilterItem -): - pass - - -class SingleSelectIsNoneOfNotViewFilterItemCreate(SingleSelectViewFilterItemCreate): - operator: Literal["is_none_of"] = Field( - ..., description="Checks if the field is not equal to the value." 
- ) - - def get_django_orm_type(self, field, **kwargs): - return "single_select_is_none_of" - - -class SingleSelectIsNoneOfNotViewFilterItem( - SingleSelectIsNoneOfNotViewFilterItemCreate, ViewFilterItem -): - pass - - -AnySingleSelectViewFilterItemCreate = Annotated[ - SingleSelectIsAnyViewFilterItemCreate | SingleSelectIsNoneOfNotViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnySingleSelectViewFilterItem = Annotated[ - SingleSelectIsAnyViewFilterItem | SingleSelectIsNoneOfNotViewFilterItem, - Field(discriminator="operator"), -] - - -class MultipleSelectViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["multiple_select"] = Field( - ..., description="A multiple select filter." - ) - value: list[str] = Field( - ..., description="The select option value(s) to filter on." - ) + return _GET_ORM_TYPE[self.type](self, field, **kwargs) def get_django_orm_value(self, field, **kwargs) -> str: - values = set(v.lower() for v in self.value) - valid_option_ids = [ - option.id - for option in field.select_options.all() - if option.value.lower() in values - ] - return ",".join([str(v) for v in valid_option_ids]) - - -class MultipleSelectIsAnyViewFilterItemCreate(MultipleSelectViewFilterItemCreate): - operator: Literal["is_any_of"] = Field( - ..., description="Checks if the field is equal to any of the values " - ) - - def get_django_orm_type(self, field, **kwargs): - return "multiple_select_has" - - -class MultipleSelectIsAnyViewFilterItem( - MultipleSelectIsAnyViewFilterItemCreate, ViewFilterItem -): - pass - - -class MultipleSelectIsNoneOfNotViewFilterItemCreate(MultipleSelectViewFilterItemCreate): - operator: Literal["is_none_of"] = Field( - ..., description="Checks if the field is not equal to the value." 
- ) - - def get_django_orm_type(self, field, **kwargs): - return "multiple_select_has_not" - - -class MultipleSelectIsNoneOfNotViewFilterItem( - MultipleSelectIsNoneOfNotViewFilterItemCreate, ViewFilterItem -): - pass - - -AnyMultipleSelectViewFilterItemCreate = Annotated[ - MultipleSelectIsAnyViewFilterItemCreate - | MultipleSelectIsNoneOfNotViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnyMultipleSelectViewFilterItem = Annotated[ - MultipleSelectIsAnyViewFilterItem | MultipleSelectIsNoneOfNotViewFilterItem, - Field(discriminator="operator"), -] - - -class LinkRowViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["link_row"] = Field(..., description="A link row filter.") - value: int = Field(..., description="The linked record ID to filter on.") - - def get_django_orm_value(self, field, **kwargs) -> str: - return str(self.value) - - -class LinkRowHasViewFilterItemCreate(LinkRowViewFilterItemCreate): - operator: Literal["has"] = Field( - ..., description="Checks if the field has the linked record." - ) - - def get_django_orm_type(self, field, **kwargs): - return "link_row_has" - + return _GET_ORM_VALUE[self.type](self, field, **kwargs) -class LinkRowHasViewFilterItem(LinkRowHasViewFilterItemCreate, ViewFilterItem): - pass +class ViewFilterItem(ViewFilterItemCreate): + """Existing view filter with ID.""" -class LinkRowHasNotViewFilterItemCreate(LinkRowViewFilterItemCreate): - operator: Literal["has_not"] = Field( - ..., description="Checks if the field does not have the linked record." 
- ) - - def get_django_orm_type(self, field, **kwargs): - return "link_row_has_not" - - -class LinkRowHasNotViewFilterItem(LinkRowHasNotViewFilterItemCreate, ViewFilterItem): - pass - - -AnyLinkRowViewFilterItemCreate = Annotated[ - LinkRowHasViewFilterItemCreate | LinkRowHasNotViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnyLinkRowViewFilterItem = Annotated[ - LinkRowHasViewFilterItem | LinkRowHasNotViewFilterItem, - Field(discriminator="operator"), -] - - -class BooleanViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["boolean"] = Field(..., description="A boolean filter.") - value: bool = Field(..., description="The boolean value to filter on.") - - def get_django_orm_value(self, field, **kwargs) -> str: - return "1" if self.value else "0" - - -class BooleanIsViewFilterItemCreate(BooleanViewFilterItemCreate): - operator: Literal["is"] = Field(..., description="Checks if the field is true.") - value: bool = Field(..., description="The boolean value to filter on.") - - def get_django_orm_type(self, field, **kwargs) -> str: - return "boolean" - - -class BooleanIsTrueViewFilterItem(BooleanIsViewFilterItemCreate, ViewFilterItem): - pass - + id: int = Field(..., description="The unique identifier of the view filter.") -AnyViewFilterItemCreate = Annotated[ - AnyTextViewFilterItemCreate - | AnyNumberViewFilterItemCreate - | AnyDateViewFilterItemCreate - | AnySingleSelectViewFilterItemCreate - | AnyLinkRowViewFilterItemCreate - | BooleanViewFilterItemCreate - | MultipleSelectViewFilterItemCreate, - Field(discriminator="type"), -] -AnyViewFilterItem = Annotated[ - AnyTextViewFilterItem - | AnyNumberViewFilterItem - | AnyDateViewFilterItem - | AnySingleSelectViewFilterItem - | AnyLinkRowViewFilterItem - | BooleanIsTrueViewFilterItem - | MultipleSelectIsAnyViewFilterItem, - Field(discriminator="type"), -] +AnyViewFilterItemCreate = ViewFilterItemCreate +AnyViewFilterItem = ViewFilterItem class ViewFiltersArgs(BaseModel): view_id: int - filters: 
list[AnyViewFilterItemCreate] + filters: list[ViewFilterItemCreate] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py index 021e4916a1..72d9efd74b 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py @@ -1,300 +1,269 @@ -from typing import Annotated, Literal, Type +import json +from typing import Any, Literal -from pydantic import Field +from pydantic import Field, model_serializer, model_validator from baserow.contrib.database.fields.models import ( DateField, FileField, SingleSelectField, ) -from baserow.contrib.database.views.models import FormView, GalleryView, GridView from baserow.contrib.database.views.models import View as BaserowView from baserow.contrib.database.views.registries import view_type_registry from baserow_enterprise.assistant.types import BaseModel from baserow_premium.permission_manager import Table -from baserow_premium.views.models import CalendarView, KanbanView, TimelineView +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- -class ViewItemCreate(BaseModel): - name: str = Field( - ..., - description="A sensible name for the view (i.e. 'Pending payments', 'Completed tasks', etc.).", - ) - public: bool = Field( - default=False, - description="Whether the view is publicly accessible. False unless specified.", - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "public": self.public, - } - - def field_options_to_django_orm(self) -> dict[str, any]: - return {} - - -class ViewItem(BaseModel): - id: int = Field(...) 
- @classmethod - def from_django_orm(cls, orm_view: Type[BaserowView]) -> "ViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - public=orm_view.public, - ) +class FormFieldOption(BaseModel): + field_id: int = Field(..., description="Field ID.") + name: str = Field(..., description="Display name in form.") + description: str = Field(..., description="Field description, or ''.") + required: bool = Field(..., description="Required?") + order: int = Field(..., description="Sort order.") class GridFieldOption(BaseModel): field_id: int = Field(...) width: int = Field( - default=200, - description="The width of the field in the grid view. Default is 200.", + ..., + description="The width of the field in the grid view (e.g. 200).", ) hidden: bool = Field( - default=False, - description="Whether the field is hidden in the grid view. Default is False.", - ) - - -class GridViewItemCreate(ViewItemCreate): - type: Literal["grid"] = Field(..., description="A grid view.") - row_height: Literal["small", "medium", "large"] = Field( - default="small", - description=( - "The height of the rows in the view. Can be 'small', 'medium' or 'large'. Default is 'small'." - ), - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - **super().to_django_orm_kwargs(table), - "row_height": self.row_height, - } - - -class GridViewItem(GridViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: GridView) -> "GridViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="grid", - row_height="small", - public=orm_view.public, - ) - - -class KanbanViewItemCreate(ViewItemCreate): - type: Literal["kanban"] = Field(..., description="A kanban view.") - column_field_id: int | None = Field( ..., - description="The ID of the field to use for the kanban columns. Must be a single select field. 
None if no single select field is available.", + description="Whether the field is hidden in the grid view.", ) - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - model = table.get_model() - column_field = model.get_field_object_by_id(self.column_field_id)["field"] - if not isinstance(column_field, SingleSelectField): - raise ValueError("The column_field_id must be a Single Select field.") - return { - **super().to_django_orm_kwargs(table), - "single_select_field": column_field, - } +# --------------------------------------------------------------------------- +# Flat view types +# --------------------------------------------------------------------------- + +ViewType = Literal["grid", "kanban", "calendar", "gallery", "timeline", "form"] + +_VIEW_EXAMPLES: dict[str, dict] = { + "grid": {"name": "All Items", "type": "grid", "row_height": "small"}, + "kanban": {"name": "Board", "type": "kanban", "column_field_id": 123}, + "calendar": {"name": "Schedule", "type": "calendar", "date_field_id": 456}, + "gallery": {"name": "Photos", "type": "gallery", "cover_field_id": 789}, + "timeline": { + "name": "Project Timeline", + "type": "timeline", + "start_date_field_id": 111, + "end_date_field_id": 222, + }, + "form": { + "name": "Contact Form", + "type": "form", + "title": "Contact Us", + "description": "", + "submit_button_label": "Submit", + "receive_notification_on_submit": False, + "submit_action": "MESSAGE", + "submit_action_message": "Thank you!", + "submit_action_redirect_url": "", + "field_options": [ + { + "field_id": 1, + "name": "Name", + "description": "", + "required": True, + "order": 1, + } + ], + }, +} + + +# --------------------------------------------------------------------------- +# to_django_orm builders: (ViewItemCreate, Table) -> dict +# --------------------------------------------------------------------------- + + +def _grid_to_orm(v, table): + return {"row_height": v.row_height} + + +def _kanban_to_orm(v, table): + model = 
table.get_model() + column_field = model.get_field_object_by_id(v.column_field_id)["field"] + if not isinstance(column_field, SingleSelectField): + raise ValueError("The column_field_id must be a Single Select field.") + return {"single_select_field": column_field} + + +def _calendar_to_orm(v, table): + model = table.get_model() + date_field = model.get_field_object_by_id(v.date_field_id)["field"] + if not isinstance(date_field, DateField): + raise ValueError("The date_field_id must be a Date field.") + return {"date_field": date_field} + + +def _gallery_to_orm(v, table): + model = table.get_model() + cover_field = model.get_field_object_by_id(v.cover_field_id)["field"] + if not isinstance(cover_field, FileField): + raise ValueError("The cover_field_id must be a File field.") + return {"card_cover_image_field_id": v.cover_field_id} + + +def _timeline_to_orm(v, table): + model = table.get_model() + start_field = model.get_field_object_by_id(v.start_date_field_id)["field"] + end_field = model.get_field_object_by_id(v.end_date_field_id)["field"] + if ( + not isinstance(start_field, DateField) + or not isinstance(end_field, DateField) + or start_field.id == end_field.id + or start_field.date_include_time != end_field.date_include_time + ): + raise ValueError( + "Invalid timeline configuration: both start and end fields must be Date fields " + "and they must have the same include_time setting (either both include time or " + "both are date-only). 
" + ) + return {"start_date_field": start_field, "end_date_field": end_field} -class KanbanViewItem(KanbanViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: KanbanView) -> "KanbanViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="kanban", - column_field_id=orm_view.single_select_field_id, - public=orm_view.public, - ) +def _form_to_orm(v, table): + return {"title": v.title, "description": v.description} -class CalendarViewItemCreate(ViewItemCreate): - type: Literal["calendar"] = Field(..., description="A calendar view.") - date_field_id: int | None = Field( - ..., - description="The ID of the field to use for the calendar dates. Must be a date field. None if no date field is available.", - ) +_TO_DJANGO_ORM = { + "grid": _grid_to_orm, + "kanban": _kanban_to_orm, + "calendar": _calendar_to_orm, + "gallery": _gallery_to_orm, + "timeline": _timeline_to_orm, + "form": _form_to_orm, +} - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - model = table.get_model() - date_field = model.get_field_object_by_id(self.date_field_id)["field"] - if not isinstance(date_field, DateField): - raise ValueError("The date_field_id must be a Date field.") - return { - **super().to_django_orm_kwargs(table), - "date_field": date_field, - } +# --------------------------------------------------------------------------- +# from_django_orm builders: (orm_view) -> dict of extra kwargs +# --------------------------------------------------------------------------- -class CalendarViewItem(CalendarViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: CalendarView) -> "CalendarViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="calendar", - date_field_id=orm_view.date_field_id, - public=orm_view.public, +def _form_field_options_from_orm(orm_view): + return [ + FormFieldOption( + field_id=fo.field_id, + name=fo.name, + description=fo.description, + required=fo.required, + order=fo.order, ) 
+ for fo in orm_view.active_field_options.all() + ] -class BaseGalleryViewItem(ViewItemCreate): - type: Literal["gallery"] = Field(..., description="A gallery view.") - cover_field_id: int | None = Field( - default=None, - description=( - "The ID of the field to use for the gallery cover image. Must be a file field. None if no file field is available." - ), - ) +_FROM_DJANGO_ORM: dict[str, Any] = { + "grid": lambda v: {"row_height": "small"}, + "kanban": lambda v: {"column_field_id": v.single_select_field_id}, + "calendar": lambda v: {"date_field_id": v.date_field_id}, + "gallery": lambda v: {"cover_field_id": v.card_cover_image_field_id}, + "timeline": lambda v: { + "start_date_field_id": v.start_date_field_id, + "end_date_field_id": v.end_date_field_id, + }, + "form": lambda v: { + "title": v.title, + "description": v.description, + "field_options": _form_field_options_from_orm(v), + }, +} -class GalleryViewItemCreate(BaseGalleryViewItem): - def to_django_orm_kwargs(self, table): - model = table.get_model() - cover_field = model.get_field_object_by_id(self.cover_field_id)["field"] - if not isinstance(cover_field, FileField): - raise ValueError("The cover_field_id must be a File field.") - - return { - **super().to_django_orm_kwargs(table), - "card_cover_image_field_id": self.cover_field_id, - } +# --------------------------------------------------------------------------- +# ViewItemCreate +# --------------------------------------------------------------------------- -class GalleryViewItem(BaseGalleryViewItem, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: GalleryView) -> "GalleryViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="gallery", - cover_field_id=orm_view.card_cover_image_field_id, - public=orm_view.public, - ) +class ViewItemCreate(BaseModel): + """Flat model for creating a view: name + type + type-specific options.""" + name: str = Field(..., description="Descriptive view name.") + public: bool = Field(..., 
description="Publicly accessible? Default false.") + type: ViewType = Field(..., description="View type.") -class BaseTimelineViewItem(ViewItemCreate): - type: Literal["timeline"] = Field(..., description="A timeline view.") - start_date_field_id: int | None = Field( - ..., - description="The ID of the field to use for the timeline dates. Must be a date field. None if no date field is available.", + # -- grid -- + row_height: Literal["small", "medium", "large"] = Field( + "small", description="(grid) Row height." ) - end_date_field_id: int | None = Field( - ..., - description=( - "The ID of the field to use for the timeline end dates. Must be a date field. None if no date field is available." - ), + # -- kanban -- + column_field_id: int | None = Field( + None, description="(kanban) Single-select field ID for columns." ) - - -class TimelineViewItemCreate(BaseTimelineViewItem): - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - model = table.get_model() - start_field = model.get_field_object_by_id(self.start_date_field_id)["field"] - end_field = model.get_field_object_by_id(self.end_date_field_id)["field"] - if ( - not isinstance(start_field, DateField) - or not isinstance(end_field, DateField) - or start_field.id == end_field.id - or start_field.date_include_time != end_field.date_include_time - ): - raise ValueError( - "Invalid timeline configuration: both start and end fields must be Date fields " - "and they must have the same include_time setting (either both include time or " - "both are date-only). 
" - ) - - return { - **super().to_django_orm_kwargs(table), - "start_date_field": start_field, - "end_date_field": end_field, - } - - -class TimelineViewItem(BaseTimelineViewItem, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: TimelineView) -> "TimelineViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="timeline", - start_date_field_id=orm_view.start_date_field_id, - end_date_field_id=orm_view.end_date_field_id, - public=orm_view.public, - ) - - -class FormFieldOption(BaseModel): - field_id: int = Field(..., description="The ID of the field.") - name: str = Field(..., description="The name to show for the field in the form.") - description: str = Field( - default="", description="The description to show for the field in the form." + # -- calendar -- + date_field_id: int | None = Field(None, description="(calendar) Date field ID.") + # -- gallery -- + cover_field_id: int | None = Field( + None, description="(gallery) File field ID for covers." ) - required: bool = Field( - default=True, - description="Whether the field is required in the form. Default is True.", + # -- timeline -- + start_date_field_id: int | None = Field( + None, description="(timeline) Start date field ID." ) - order: int = Field(..., description="The order of the field in the form.") - - -class BaseFormViewItem(ViewItemCreate): - type: Literal["form"] = Field(..., description="A form view.") - title: str = Field(..., description="The title of the form.") - description: str = Field(..., description="The description of the form.") - submit_button_label: str = Field( - default="Submit", description="The label of the submit button." + end_date_field_id: int | None = Field( + None, description="(timeline) End date field ID." 
) + # -- form -- + title: str = Field("", description="(form) Title, or ''.") + description: str = Field("", description="(form) Description, or ''.") + submit_button_label: str = Field("Submit", description="(form) Button label.") receive_notification_on_submit: bool = Field( - default=False, - description=( - "Whether to receive an email notification when the form is submitted." - ), + False, description="(form) Email on submit." ) submit_action: Literal["MESSAGE", "REDIRECT"] = Field( - default="MESSAGE", - description="The action to perform when the form is submitted.", - ) - submit_action_message: str = Field( - default="", - description=( - "The message to display when the form is submitted and the action is 'MESSAGE'." - ), + "MESSAGE", description="(form) 'MESSAGE' or 'REDIRECT'." ) + submit_action_message: str = Field("", description="(form) Message after submit.") submit_action_redirect_url: str = Field( - default="", - description=( - "The URL to redirect to when the form is submitted and the action is 'REDIRECT'." - ), + "", description="(form) Redirect URL after submit." ) - - field_options: list[FormFieldOption] = Field( - ..., - description=( - "The list of fields to show in the form, along with their options. The fields must be part of the table." 
- ), + field_options: list[FormFieldOption] | None = Field( + None, + description="(form) Fields to show (OPT-IN: include all you want visible).", ) + # Required fields per type: {type: [(attr_name, display_name), ...]} + _REQUIRED_FIELDS: dict[str, list[tuple[str, str]]] = { + "kanban": [("column_field_id", "column_field_id")], + "calendar": [("date_field_id", "date_field_id")], + "gallery": [("cover_field_id", "cover_field_id")], + "timeline": [ + ("start_date_field_id", "start_date_field_id"), + ("end_date_field_id", "end_date_field_id"), + ], + "form": [("field_options", "field_options")], + } -class FormViewItemCreate(BaseFormViewItem): - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - **super().to_django_orm_kwargs(table), - "title": self.title, - "description": self.description, - } - - def field_options_to_django_orm(self): + @model_validator(mode="after") + def _validate_required_for_type(self): + required = self._REQUIRED_FIELDS.get(self.type) + if required: + missing = [name for attr, name in required if not getattr(self, attr)] + if missing: + raise ValueError( + f"{self.type} requires {', '.join(missing)}. 
" + f"Example: {json.dumps(_VIEW_EXAMPLES[self.type])}" + ) + return self + + def to_django_orm_kwargs(self, table: Table) -> dict[str, Any]: + base = {"name": self.name, "public": self.public} + builder = _TO_DJANGO_ORM.get(self.type) + if builder: + base.update(builder(self, table)) + return base + + def field_options_to_django_orm(self) -> dict[str, Any]: + if self.type != "form" or not self.field_options: + return {} return { fo.field_id: { "enabled": True, @@ -307,64 +276,44 @@ def field_options_to_django_orm(self): } -class FormViewItem(FormViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: FormView) -> "FormViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="form", - public=orm_view.public, - title=orm_view.title, - description=orm_view.description, - field_options=[ - FormFieldOption( - field_id=fo.field_id, - name=fo.name, - description=fo.description, - required=fo.required, - order=fo.order, - ) - for fo in orm_view.active_field_options.all() - ], - ) +# --------------------------------------------------------------------------- +# ViewItem (read-back) +# --------------------------------------------------------------------------- -AnyViewItemCreate = Annotated[ - GridViewItemCreate - | KanbanViewItemCreate - | CalendarViewItemCreate - | GalleryViewItemCreate - | TimelineViewItemCreate - | FormViewItemCreate, - Field(discriminator="type"), -] - -AnyViewItem = Annotated[ - GridViewItem - | KanbanViewItem - | CalendarViewItem - | GalleryViewItem - | TimelineViewItem - | FormViewItem, - Field(discriminator="type"), -] - - -class ViewItemsRegistry: - _registry = { - "grid": GridViewItem, - "kanban": KanbanViewItem, - "calendar": CalendarViewItem, - "gallery": GalleryViewItem, - "timeline": TimelineViewItem, - "form": FormViewItem, - } - - def from_django_orm(self, orm_view: Type[BaserowView]) -> ViewItem: - view_type = view_type_registry.get_by_model(orm_view).type - view_class: ViewItem = 
self._registry.get(view_type, ViewItem) - return view_class.from_django_orm(orm_view) +class ViewItem(BaseModel): + """Existing view with ID — flat structure matching ViewItemCreate.""" + id: int = Field(...) + name: str = Field(...) + public: bool = Field(...) + type: str = Field(...) + + # Type-specific (populated per type, others excluded via serializer) + row_height: str | None = None + column_field_id: int | None = None + date_field_id: int | None = None + cover_field_id: int | None = None + start_date_field_id: int | None = None + end_date_field_id: int | None = None + title: str | None = None + description: str | None = None + field_options: list[FormFieldOption] | None = None + + @model_serializer(mode="wrap") + def _exclude_none(self, handler): + return {k: v for k, v in handler(self).items() if v is not None} -view_item_registry = ViewItemsRegistry() + @classmethod + def from_django_orm(cls, orm_view: BaserowView) -> "ViewItem": + view_type = view_type_registry.get_by_model(orm_view).type + kwargs: dict[str, Any] = { + "id": orm_view.id, + "name": orm_view.name, + "public": orm_view.public, + "type": view_type, + } + builder = _FROM_DJANGO_ORM.get(view_type) + if builder: + kwargs.update(builder(orm_view)) + return cls(**kwargs) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py deleted file mode 100644 index 2cc186c093..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py +++ /dev/null @@ -1,559 +0,0 @@ -from dataclasses import dataclass -from itertools import groupby -from typing import TYPE_CHECKING, Any, Callable, Literal, Type, Union - -from django.contrib.auth.models import AbstractUser -from django.core.exceptions import ValidationError -from django.db import transaction -from django.db.models import Q, QuerySet -from django.utils.translation import gettext as _ - -import udspy -from pydantic 
import ConfigDict, Field, create_model -from udspy.utils import minimize_schema, resolve_json_schema_reference - -from baserow.contrib.database.fields.actions import CreateFieldActionType -from baserow.contrib.database.fields.field_types import LinkRowFieldType -from baserow.contrib.database.fields.handler import FieldHandler -from baserow.contrib.database.fields.models import SelectOption as OrmSelectOption -from baserow.contrib.database.fields.registries import field_type_registry -from baserow.contrib.database.rows.actions import ( - CreateRowsActionType, - DeleteRowsActionType, - UpdateRowsActionType, -) -from baserow.contrib.database.table.handler import TableHandler -from baserow.contrib.database.table.models import ( - FieldObject, - GeneratedTableModel, - Table, -) -from baserow.contrib.database.views.actions import CreateViewFilterActionType -from baserow.contrib.database.views.handler import ViewHandler -from baserow.contrib.database.views.models import View, ViewFilter -from baserow.core.db import specific_iterator -from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.database.types.table import ( - BaseTableItem, - TableItem, -) - -from .types import ( - AnyFieldItem, - AnyFieldItemCreate, - AnyViewFilterItemCreate, - BaseModel, - Date, - Datetime, - field_item_registry, -) - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - -NoChange = Literal["__NO_CHANGE__"] - - -def filter_tables(user: AbstractUser, workspace: Workspace) -> QuerySet[Table]: - return TableHandler().list_workspace_tables(user, workspace) - - -def list_tables( - user: AbstractUser, workspace: Workspace, database_id: int -) -> list[BaseTableItem]: - tables_qs = filter_tables(user, workspace).filter(database_id=database_id) - - return [BaseTableItem(id=table.id, name=table.name) for table in tables_qs] - - -def get_tables_schema( - tables: list[Table], - full_schema: bool = False, -) -> list[TableItem]: - """Returns the 
schema of the specified tables.""" - - q = Q(table__in=tables) - if not full_schema: # Only the primary fields and relationships - q &= Q(linkrowfield__isnull=False) | Q(primary=True) - - base_field_queryset = FieldHandler().get_base_fields_queryset() - fields = specific_iterator( - base_field_queryset.filter(q).order_by("table_id", "order"), - per_content_type_queryset_hook=( - lambda field, queryset: field_type_registry.get_by_model( - field - ).enhance_field_queryset(queryset, field) - ), - ) - - table_items = [] - for table_id, fields_in_table in groupby(fields, lambda f: f.table_id): - fields_in_table = list(fields_in_table) - table = next(t for t in tables if t.id == table_id) - primary_field = next(f for f in fields if f.primary) - primary_field_item = field_item_registry.from_django_orm(primary_field) - - table_items.append( - TableItem( - id=table.id, - name=table.name, - primary_field=primary_field_item, - fields=[ - field_item_registry.from_django_orm(f) - for f in fields_in_table - if f.id != primary_field.id - ], - ) - ) - - # Make sure the order is the same as the input - tables = list(tables) - table_items.sort( - key=lambda t: tables.index(next(tb for tb in tables if tb.id == t.id)) - ) - - return table_items - - -def create_fields( - user: AbstractUser, - table: Table, - field_items: list[AnyFieldItemCreate], - tool_helpers: "ToolHelpers", -) -> list[AnyFieldItem]: - created_fields = [] - for field_item in field_items: - tool_helpers.update_status( - _("Creating field %(field_name)s...") % {"field_name": field_item.name} - ) - - new_field = CreateFieldActionType.do( - user, - table, - field_item.type, - **field_item.to_django_orm_kwargs(table), - ) - created_fields.append(field_item_registry.from_django_orm(new_field)) - return created_fields - - -@dataclass -class FieldDefinition: - type: Type | None = None - field_def: Any | None = None - to_django_orm: Callable[[Any], Any] | None = None - from_django_orm: Callable[[Any], Any] | None = None - - 
-def _get_pydantic_field_definition( - field_object: FieldObject, -) -> FieldDefinition: - """ - Returns the Pydantic field type and definition for the given field object. - """ - - orm_field = field_object["field"] - orm_field_type = field_object["type"] - - match orm_field_type.type: - case "text": - return FieldDefinition( - str | None, - Field(..., description="Single-line text", title=orm_field.name), - lambda v: v if v is not None else "", - lambda v: v if v is not None else "", - ) - - case "long_text": - return FieldDefinition( - str | None, - Field(..., description="Multi-line text", title=orm_field.name), - lambda v: v if v is not None else "", - lambda v: v if v is not None else "", - ) - case "number": - return FieldDefinition( - float | None, - Field(..., description="Number or None", title=orm_field.name), - ) - case "boolean": - return FieldDefinition( - bool, Field(..., description="Boolean", title=orm_field.name) - ) - case "date": - if orm_field.date_include_time: - return FieldDefinition( - Datetime | None, - Field(..., description="Datetime or None", title=orm_field.name), - lambda v: v.to_django_orm() if v else None, - lambda v: Datetime.from_django_orm(v) if v is not None else None, - ) - else: - return FieldDefinition( - Date | None, - Field(..., description="Date or None", title=orm_field.name), - lambda v: v.to_django_orm() if v else None, - lambda v: Date.from_django_orm(v) if v is not None else None, - ) - case "single_select": - choices = [option.value for option in orm_field.select_options.all()] - - return FieldDefinition( - Literal[*choices] | None, - Field( - ..., - description=f"One of: {', '.join(choices)} or None", - title=orm_field.name, - ), - lambda v: v if v in choices else None, - lambda v: v.value if isinstance(v, OrmSelectOption) else v, - ) - case "multiple_select": - choices = [option.value for option in orm_field.select_options.all()] - - return FieldDefinition( - list[Literal[*choices]], - Field( - ..., - 
description=f"List of any of: {', '.join(choices)} or empty list", - title=orm_field.name, - ), - lambda v: [opt for opt in v if opt in choices], - lambda v: [opt.value for opt in v.all()] if v is not None else None, - ) - case "link_row": - linked_model = orm_field.link_row_table.get_model() - linked_primary_key = linked_model.get_primary_field() - - # If there's no primary key, we can't safely work with this field - if linked_primary_key is None: - return FieldDefinition() # Unsupported field type - - # Avoid null or empty values - linked_pk = linked_primary_key.db_column - linked_values = list( - linked_model.objects.exclude( - Q(**{f"{linked_pk}__isnull": True}) - | Q(**{f"{linked_pk}__exact": ""}) - ).values_list(linked_pk, flat=True)[:10] - ) - examples = f"Examples: {', '.join([str(v) for v in linked_values])}" - - def to_django_orm(value): - if isinstance(value, str) or isinstance(value, int): - value = [value] - if value is not None: - try: - return LinkRowFieldType().prepare_value_for_db(orm_field, value) - except ValidationError: - pass - return [] - - def from_django_orm(value): - values = [str(v) for v in value.all()] - if orm_field.link_row_multiple_relationships: - return values - else: - return values[0] if values else None - - # TODO: verify this can work with every possible primary field type - if orm_field.link_row_multiple_relationships: - desc = "List of values (as strings) or IDs (as integers) from the linked table or empty list." - field_type = list[str | int] | None - else: - desc = "Single value (as string) or ID (as integer) from the linked table or empty list." 
- field_type = str | int | None - if examples: - desc += " " + examples - return FieldDefinition( - field_type, - Field(None, description=desc, title=orm_field.name), - to_django_orm, - from_django_orm, - ) - - case _: - return FieldDefinition() # Unsupported field type - - -def get_create_row_model(table: Table, field_ids: list[int] | None = None) -> BaseModel: - """ - Dynamically creates a Pydantic model for the given table based on its fields, to be - used for row creation and validation. - """ - - model_name = f"Table{table.id}Row" - - field_definitions = {} - field_conversions = {} - - table_model = table.get_model() - for field_object in table_model.get_field_objects(): - field_definition = _get_pydantic_field_definition(field_object) - if field_definition.type is None: - continue # Skip unsupported field types - if field_ids is not None and field_object["field"].id not in field_ids: - continue # Skip fields not in the specified list - - field = field_object["field"] - field_definitions[field.name] = ( - field_definition.type, - field_definition.field_def, - ) - field_conversions[field.name] = ( - field.db_column, - field_definition.to_django_orm, - field_definition.from_django_orm, - ) - - class TableRowModel(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - - def to_django_orm(self) -> dict[str, Any]: - orm_data = {} - for key, value in self.__dict__.items(): - if key == "id": - orm_data["id"] = value - continue - - if key not in field_conversions or value == "__NO_CHANGE__": - continue - - orm_key, to_django_orm, _ = field_conversions[key] - if to_django_orm: - orm_data[orm_key] = to_django_orm(value) - else: - orm_data[orm_key] = value - return orm_data - - @classmethod - def from_django_orm( - cls, orm_row: GeneratedTableModel, field_ids: list[int] | None = None - ) -> "TableRowModel": - init_data = {"id": orm_row.id} - for field_object in orm_row.get_field_objects(): - field = field_object["field"] - if field.name not in 
field_conversions: - continue - if field_ids is not None and field.id not in field_ids: - continue - db_column, _, from_django_orm = field_conversions[field.name] - value = getattr(orm_row, db_column) - if from_django_orm: - init_data[field.name] = from_django_orm(value) - else: - init_data[field.name] = value - return cls(**init_data) - - return create_model( - model_name, - __module__=__name__, - __base__=TableRowModel, - **field_definitions, - ) - - -def get_update_row_model(table) -> BaseModel: - """Creates an update model where all fields can be NoChange.""" - - create_model_class = get_create_row_model(table) - - # Build update fields - all fields become Union[OriginalType, NoChange] - update_fields = {} - - for field_name, field_info in create_model_class.model_fields.items(): - original_type = field_info.annotation - - update_fields[field_name] = ( - Union[NoChange, original_type], - Field( - ..., - description=f"Use '__NO_CHANGE__' to keep current value. To update, use a {field_info.description}", - ), - ) - - update_fields["id"] = (int, Field(..., description="The ID of the row to update")) - - # Create the update model - UpdateRowModel = create_model( - f"UpdateTable{table.id}Row", - __base__=create_model_class, - **update_fields, - ) - - return UpdateRowModel - - -def get_view(user, view_id: int): - return ViewHandler().get_view_as_user(user, view_id) - - -def get_table_rows_tools( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers", table: Table -): - row_model_for_create = get_create_row_model(table) - row_model_for_update = get_update_row_model(table) - row_model_for_response = create_model( - f"ResponseTable{table.id}Row", - id=(int, ...), - __base__=row_model_for_create, - ) - - def _create_rows( - rows: list[dict[str, Any]], - ) -> list[dict[str, Any]]: - """ - Create new rows in the specified table. 
- """ - - nonlocal \ - user, \ - workspace, \ - tool_helpers, \ - row_model_for_create, \ - row_model_for_response - - if not rows: - return [] - - tool_helpers.update_status( - _("Creating rows in %(table_name)s ") % {"table_name": table.name} - ) - - with transaction.atomic(): - orm_rows = CreateRowsActionType.do( - user, - table, - [row_model_for_create(**row).to_django_orm() for row in rows], - ) - - return {"created_row_ids": [r.id for r in orm_rows]} - - create_row_model_schema = minimize_schema( - resolve_json_schema_reference(row_model_for_create.model_json_schema()) - ) - create_rows_tool = udspy.Tool( - func=_create_rows, - name=f"create_rows_in_table_{table.id}", - description=f"Creates new rows in the table {table.name} (ID: {table.id}). Max 20 rows at a time.", - args={ - "rows": { - "items": create_row_model_schema, - "type": "array", - "maxItems": 20, - } - }, - ) - - def _update_rows( - rows: list[dict[str, Any]], - ) -> list[dict[str, Any]]: - """ - Update existing rows in the specified table. - """ - - nonlocal user, workspace, tool_helpers, row_model_for_update - - if not rows: - return [] - - tool_helpers.update_status( - _("Updating rows in %(table_name)s ") % {"table_name": table.name} - ) - - with transaction.atomic(): - orm_rows = UpdateRowsActionType.do( - user, - table, - [row_model_for_update(**row).to_django_orm() for row in rows], - ).updated_rows - - return {"updated_row_ids": [r.id for r in orm_rows]} - - update_row_model_schema = minimize_schema( - resolve_json_schema_reference(row_model_for_update.model_json_schema()) - ) - update_rows_tool = udspy.Tool( - func=_update_rows, - name=f"update_rows_in_table_{table.id}", - description=f"Updates existing rows in the table {table.name} (ID: {table.id}), identified by their row IDs. 
Max 20 at a time.", - args={ - "rows": { - "items": update_row_model_schema, - "type": "array", - "maxItems": 20, - } - }, - ) - - def _delete_rows(row_ids: list[int]) -> str: - """ - Delete rows in the specified table. - """ - - nonlocal user, workspace, tool_helpers - - if not row_ids: - return - - tool_helpers.update_status( - _("Deleting rows in %(table_name)s ") % {"table_name": table.name} - ) - - with transaction.atomic(): - DeleteRowsActionType.do(user, table, row_ids) - - return {"deleted_row_ids": row_ids} - - delete_rows_tool = udspy.Tool( - func=_delete_rows, - name=f"delete_rows_in_table_{table.id}", - description=f"Deletes rows in the table {table.name} (ID: {table.id}). Max 20 at a time.", - args={ - "row_ids": { - "items": {"type": "integer"}, - "type": "array", - "maxItems": 20, - } - }, - ) - - return { - "create": create_rows_tool, - "update": update_rows_tool, - "delete": delete_rows_tool, - } - - -def create_view_filter( - user: AbstractUser, - orm_view: View, - table_fields: list[Field], - view_filter_item: AnyViewFilterItemCreate, -) -> ViewFilter: - """ - Creates a view filter from the given view filter item. 
- """ - - field = table_fields.get(view_filter_item.field_id) - if field is None: - raise ValueError("Field not found for filter") - field_type = field_type_registry.get_by_model(field.specific_class) - if field_type.type != view_filter_item.type: - raise ValueError("Field type mismatch for filter") - - filter_type = view_filter_item.get_django_orm_type(field) - filter_value = view_filter_item.get_django_orm_value( - field, timezone=user.profile.timezone - ) - - return CreateViewFilterActionType.do( - user, - orm_view, - field, - filter_type, - filter_value, - filter_group_id=None, - ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py new file mode 100644 index 0000000000..23864ac773 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py @@ -0,0 +1,15 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class NavigationToolType(AssistantToolType): + type = "navigation" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import navigation_toolset + + return navigation_toolset diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py index a9ad456ac7..53b867d789 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py @@ -1,52 +1,53 @@ -from typing import TYPE_CHECKING, Callable +from typing import Annotated -from django.contrib.auth.models import AbstractUser +from django.core.exceptions import ObjectDoesNotExist from django.utils.translation import gettext as _ -from baserow.core.models import Workspace -from 
baserow_enterprise.assistant.tools.registries import AssistantToolType +from pydantic import Field +from pydantic_ai import RunContext +from pydantic_ai.toolsets import FunctionToolset -from .types import AnyNavigationRequestType +from baserow_enterprise.assistant.deps import AssistantDeps -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers +from .types import AnyNavigationRequestType -def get_navigation_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[AnyNavigationRequestType], str]: +def navigate( + ctx: RunContext[AssistantDeps], + request: Annotated[ + AnyNavigationRequestType, + Field( + description="The navigation target: either a specific table or the workspace home." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> str: + """\ + Navigate the UI to a table, view, automation, page, or workspace home. + + WHEN to use: User asks to open, go to, or see something in the workspace. Also after creating new resources (views, fields, rows) in an existing database or table. + WHAT it does: Navigates the UI to a table, view, automation workflow, builder page, or workspace home. + RETURNS: Confirmation of navigation. + DO NOT USE when: You need data — use list/get tools instead. Navigation only changes the UI focus. """ - Returns a function that provides navigation instructions to the user based on - their current workspace context. - """ - - def navigate(request: AnyNavigationRequestType) -> str: - """ - Navigate within the workspace. 
- Use when: - - the user asks to open, go, to be brought to something - - the user asks to see something from their workspace - - if something new has been created in a previously existing database or table, - like a view, a field or some rows - """ - - nonlocal user, workspace + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + try: location = request.to_location(user, workspace, request) + except ObjectDoesNotExist: + return "Error: could not navigate — the target was not found. Check that the ID is correct." - tool_helpers.update_status( - _("Navigating to %(location)s...") - % {"location": location.to_localized_string()} - ) - return tool_helpers.navigate_to(location) - - return navigate - + tool_helpers.update_status( + _("Navigating to %(location)s...") + % {"location": location.to_localized_string()} + ) + return tool_helpers.navigate_to(location) -class NavigationToolType(AssistantToolType): - type = "navigation" - @classmethod - def get_tool(cls, user, workspace, tool_helpers): - return get_navigation_tool(user, workspace, tool_helpers) +TOOL_FUNCTIONS = [navigate] +navigation_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py index f0421eb7bf..49219c4da0 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py @@ -5,11 +5,10 @@ from pydantic import Field from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.database.utils import filter_tables +from baserow_enterprise.assistant.tools.database.helpers import filter_tables from baserow_enterprise.assistant.types import ( BaseModel, TableNavigationType, - WorkspaceNavigationType, ) @@ -51,22 +50,7 @@ def to_location( ) -class 
WorkspaceNavigationRequestType(NavigationRequestType): - type: Literal["workspace"] = Field( - ..., description="The home page of the workspace" - ) - - @classmethod - def to_location( - cls, - user: AbstractUser, - workspace: Workspace, - request: "WorkspaceNavigationRequestType", - ) -> WorkspaceNavigationType: - return WorkspaceNavigationType(type="workspace", id=workspace.id) - - AnyNavigationRequestType = Annotated[ - TableNavigationRequestType | WorkspaceNavigationRequestType, + TableNavigationRequestType, Field(discriminator="type"), ] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py index ee2fd6c157..d6753d751a 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py @@ -1,16 +1,16 @@ +from baserow_enterprise.assistant.deps import EventBus from baserow_enterprise.assistant.types import AiNavigationMessage, AnyNavigationType -def unsafe_navigate_to(location: AnyNavigationType) -> str: +def unsafe_navigate_to(location: AnyNavigationType, event_bus: EventBus) -> str: """ Navigate to a specific table or view without any safety checks. Make sure all the IDs provided are valid and can be accessed by the user before calling this function. - :param navigation_type: The type of navigation to perform. + :param location: The type of navigation to perform. + :param event_bus: The event bus to emit the navigation event on. """ - from udspy.streaming import emit_event - - emit_event(AiNavigationMessage(location=location)) + event_bus.emit(AiNavigationMessage(location=location)) return "Navigated successfully." 
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py index df3a7dea16..e0d413dbed 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py @@ -1,111 +1,190 @@ -from typing import TYPE_CHECKING, Any, Callable +""" +Baserow registry for assistant tool types. -from django.contrib.auth.models import AbstractUser +Each tool module (navigation, database, etc.) registers an +``AssistantToolType`` instance. The registry assembles the combined +toolset at runtime, filtering by ``can_use(user, workspace)`` so +individual tool groups can be gated on permissions or feature flags. +""" -from baserow.core.exceptions import ( - InstanceTypeAlreadyRegistered, - InstanceTypeDoesNotExist, +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from pydantic_ai.toolsets import AbstractToolset, CombinedToolset + +from baserow.core.registry import Instance, Registry +from baserow_enterprise.assistant.deps import AgentMode + +from .toolset import ( + InlineRefsToolset, + ModeAwareToolset, + generate_tool_manifest_compact, ) -from baserow.core.models import Workspace -from baserow.core.registries import Instance, Registry if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers + from django.contrib.auth.models import AbstractUser + from baserow.core.models import Workspace + from baserow_enterprise.assistant.deps import AssistantDeps -class AssistantToolType(Instance): - name: str = "" - - @classmethod - def can_use(cls, user: AbstractUser, workspace: Workspace, *args, **kwargs) -> bool: - """ - Returns whether or not the given user can use this tool in the given workspace. - :param user: The user to check if they can use this tool. - :param workspace: The workspace where to check if the tool can be used. 
- :return: True if the user can use this tool, False otherwise. - """ +class AssistantToolType(Instance): + """ + Base class for assistant tool groups. - return True + Each subclass represents a logical group of tools (e.g. "database", + "navigation"). Override ``can_use`` to gate availability on user + permissions or feature flags. + """ - @classmethod - def on_tool_start( - cls, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - """ - Called when the tool is started. It can be used to stream status messages. + type: str = "" - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the udspy tool being called. - :param inputs: The inputs provided to the tool. + def can_use(self, user: "AbstractUser", workspace: "Workspace") -> bool: """ + Permission gate. Override in subclasses for conditional availability. - pass - - @classmethod - def on_tool_end( - cls, - call_id: str, - instance: Any, - inputs: dict[str, Any], - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ): - """ - Called when the tool has finished, either successfully or with an exception. - - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the udspy tool being called. - :param inputs: The inputs provided to the tool. - :param outputs: The outputs returned by the tool, or None if there was an - exception. - :param exception: The exception raised by the tool, or None if it was - successful. + :param user: The requesting user. + :param workspace: The current workspace. + :return: ``True`` if this tool group should be included. """ - pass + return True - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - """ - Returns the actual tool function to be called to pass to the udspy react agent. 
+ def get_tool_functions(self) -> list[Callable]: + """Return the raw tool functions for manifest generation.""" - :param user: The user that will be using the tool. - :param workspace: The workspace the user is currently in. - :param tool_helpers: A dataclass containing helper functions that can be used by - the tool function. - """ + raise NotImplementedError - raise NotImplementedError("Subclasses must implement this method.") + def get_toolset(self) -> AbstractToolset: + """Return the pydantic-ai ``FunctionToolset`` for this group.""" + raise NotImplementedError -class AssistantToolDoesNotExist(InstanceTypeDoesNotExist): - pass + def get_routing_rules(self) -> str: + """Return routing rules text for this tool group's manifest. + Override in subclasses that define mode-specific routing rules. + Returns empty string by default (no rules). + """ -class AssistantToolAlreadyRegistered(InstanceTypeAlreadyRegistered): - pass + return "" class AssistantToolRegistry(Registry[AssistantToolType]): name = "assistant_tool" - does_not_exist_exception_class = AssistantToolDoesNotExist - already_registered_exception_class = AssistantToolAlreadyRegistered + def build_toolset( + self, + user: "AbstractUser", + workspace: "Workspace", + model: str, + deps: "AssistantDeps", + ) -> tuple[AbstractToolset, str, str, str, str]: + """ + Assemble the combined assistant toolset, filtering by ``can_use()``. + + :param user: The requesting user. + :param workspace: The current workspace. + :param model: The pydantic-ai model string. + :param deps: The assistant deps (used for mode-aware filtering). + :return: ``(toolset, database_manifest, application_manifest, + automation_manifest, explain_manifest)``. 
+ """ + + toolsets: list[AbstractToolset] = [] + module_groups: list[tuple[str, list[Callable]]] = [] + + for tool_type in self.get_all(): + if not tool_type.can_use(user, workspace): + continue + toolsets.append(tool_type.get_toolset()) + module_groups.append((tool_type.type, tool_type.get_tool_functions())) + + combined = CombinedToolset(toolsets) + mode_aware = ModeAwareToolset(combined, deps) + + from .toolset import _get_mode_tool_map + + # Build a routing-rules lookup from registered tool types so each + # module owns its own rules (no hardcoded imports here). + routing_rules_by_type: dict[str, str] = { + tt.type: tt.get_routing_rules() + for tt in self.get_all() + if tt.get_routing_rules() + } + + mode_map = _get_mode_tool_map() + shared = mode_map[AgentMode.DATABASE] & mode_map[AgentMode.APPLICATION] + + _mode_config: list[tuple[str, AgentMode, str]] = [ + ("database", AgentMode.DATABASE, routing_rules_by_type.get("database", "")), + ( + "application", + AgentMode.APPLICATION, + routing_rules_by_type.get("builder", ""), + ), + ( + "automation", + AgentMode.AUTOMATION, + routing_rules_by_type.get("automation", ""), + ), + ] - def list_all_usable_tools( - self, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> list[AssistantToolType]: - return [ - tool_type.get_tool(user, workspace, tool_helpers) - for tool_type in self.get_all() - if tool_type.can_use(user, workspace) + manifests = {} + for mode_key, mode, rules in _mode_config: + allowed = mode_map[mode] + groups = [ + (label, [f for f in funcs if f.__name__ in allowed]) + for label, funcs in module_groups + ] + manifest = generate_tool_manifest_compact(groups, routing_rules=rules) + + # Append a compact cross-mode summary so the agent knows what + # capabilities exist in other modes (and can switch_mode to use them). 
+ other_lines = [] + for other_key, other_mode, _ in _mode_config: + if other_key == mode_key: + continue + specific = mode_map[other_mode] - shared + other_lines.append(f"- {other_key}: {', '.join(sorted(specific))}") + if other_lines: + manifest += "\n\n## Other modes (switch_mode to access)\n" + "\n".join( + other_lines + ) + + manifests[mode_key] = manifest + + explain_allowed = mode_map[AgentMode.EXPLAIN] + explain_groups = [ + (label, [f for f in funcs if f.__name__ in explain_allowed]) + for label, funcs in module_groups ] + manifests["explain"] = generate_tool_manifest_compact(explain_groups) + + return ( + InlineRefsToolset(mode_aware, model=model), + manifests["database"], + manifests["application"], + manifests["automation"], + manifests["explain"], + ) assistant_tool_registry = AssistantToolRegistry() + + +def get_shared_read_funcs() -> list[Callable]: + """ + Return read-only tool functions shared across sub-agents. + + Uses deferred imports to avoid circular dependencies. + """ + + from baserow_enterprise.assistant.tools.database.tools import ( + get_tables_schema, + list_rows, + list_tables, + ) + + return [list_tables, get_tables_schema, list_rows] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py new file mode 100644 index 0000000000..95860d2695 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py @@ -0,0 +1,20 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class SearchDocsToolType(AssistantToolType): + type = "search_user_docs" + + def can_use(self, user, workspace) -> bool: + from .handler import KnowledgeBaseHandler + + return KnowledgeBaseHandler().can_search() + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import 
search_docs_toolset + + return search_docs_toolset diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py index 4d337e7685..bb1eba26f9 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py @@ -1,64 +1,70 @@ -from typing import TYPE_CHECKING, Annotated, Any, Callable +import re +from typing import Annotated, Any -from django.contrib.auth.models import AbstractUser from django.utils.translation import gettext as _ -import udspy from asgiref.sync import sync_to_async +from loguru import logger +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic_ai import Agent, RunContext +from pydantic_ai.toolsets import FunctionToolset -from baserow.core.models import Workspace +from baserow_enterprise.assistant.deps import AssistantDeps from baserow_enterprise.assistant.models import KnowledgeBaseChunk -from baserow_enterprise.assistant.tools.registries import AssistantToolType from .handler import KnowledgeBaseHandler -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - - -class SearchDocsSignature(udspy.Signature): - """ - Given a user question and documentation chunks as context, provide an accurate - and concise answer along with a reliability score. - - CRITICAL: The context may contain documents retrieved by keyword similarity that - are NOT actually relevant to the user's question. You MUST carefully evaluate - each document's ACTUAL TOPIC before using it: - - 1. First, identify the SPECIFIC FEATURE or concept the user is asking about - 2. For each document, check if it DIRECTLY explains that specific feature - 3. 
IGNORE documents that merely mention similar keywords but cover different topics - (e.g., if asked about "webhooks in Baserow", ignore docs about external - webhook services or third-party integrations - only use docs about - Baserow's native webhook feature) - 4. Only use documents that would genuinely help answer THIS specific question - - If no documents in the context actually address the user's question (even if - they contain similar words), respond with "Nothing found in the documentation." - - Include instructions and URLs from the documentation when relevant. - Never fabricate answers or URLs. - """ - - question: str = udspy.InputField() - context: dict[str, str] = udspy.InputField( - desc=( - "A mapping of source URLs to documents. WARNING: These documents were " - "retrieved by keyword similarity and may include irrelevant results. " - "Carefully filter to only use documents that DIRECTLY address the question." - ) - ) - - answer: str = udspy.OutputField() - sources: list[str] = udspy.OutputField( - desc=( +# Regex that matches assistant tool names in a search query. Used to +# short-circuit search_user_docs when the model is trying to look up how +# its own tools work instead of answering a user question. +_TOOL_QUERY_RE = re.compile( + r"(?:list|create|get|update|delete|generate|load|add)_" + r"(?:tables?|fields?|views?|rows?|pages?|elements?|actions?|data_sources?|" + r"theme|workflows?|view_filters?|formula|row_tools|" + r"action_field_mapping|rows_in_table)" + r"|search_user_docs" + r"|\bnavigate\s+(?:tool|function|param)", + re.IGNORECASE, +) + + +SEARCH_DOCS_INSTRUCTIONS = """\ +Given a user question and documentation chunks as context, provide an accurate +and concise answer along with a reliability score. + +CRITICAL: The context may contain documents retrieved by keyword similarity that +are NOT actually relevant to the user's question. You MUST carefully evaluate +each document's ACTUAL TOPIC before using it: + +1. 
First, identify the SPECIFIC FEATURE or concept the user is asking about +2. For each document, check if it DIRECTLY explains that specific feature +3. IGNORE documents that merely mention similar keywords but cover different topics + (e.g., if asked about "webhooks in Baserow", ignore docs about external + webhook services or third-party integrations - only use docs about + Baserow's native webhook feature) +4. Only use documents that would genuinely help answer THIS specific question + +If no documents in the context actually address the user's question (even if +they contain similar words), respond with "Nothing found in the documentation." + +Include instructions and URLs from the documentation when relevant. +Never fabricate answers or URLs. +""" + + +class SearchDocsResult(PydanticBaseModel): + answer: str = Field(description="The answer to the user's question.") + sources: list[str] = Field( + default_factory=list, + description=( "URLs of documents that were ACTUALLY USED to form the answer. " "Only include sources that directly addressed the question topic. " "Leave empty if no documents were relevant. Maximum 3 URLs, ordered by relevance." - ) + ), ) - reliability: float = udspy.OutputField( - desc=( + reliability: float = Field( + description=( "How well the RELEVANT documents (not all documents) support the answer. " "1.0 = found documents that directly and completely answer the question. " "0.5 = found partially relevant information. " @@ -66,155 +72,195 @@ class SearchDocsSignature(udspy.Signature): ) ) - @classmethod - def format_context(cls, chunks: list[KnowledgeBaseChunk]) -> dict[str, str]: - """ - Formats the context as a list of strings for the signature. - Each string is formatted as "Source URL: content". - :param chunks: The list of knowledge base chunks. - :return: A dictionary mapping source URLs to their combined content. 
- """ +search_docs_agent: Agent[None, SearchDocsResult] = Agent( + output_type=SearchDocsResult, + instructions=SEARCH_DOCS_INSTRUCTIONS, + name="search_docs_agent", +) - context = {} - for chunk in chunks: - url = chunk.source_document.source_url - content = chunk.content - if url not in context: - context[url] = content - else: - context[url] += "\n" + content - - return context +def format_context(chunks: list[KnowledgeBaseChunk]) -> dict[str, str]: + """ + Formats the context as a mapping of source URLs to their combined content. -def get_search_user_docs_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[str], dict[str, Any]]: + :param chunks: The list of knowledge base chunks. + :return: A dictionary mapping source URLs to their combined content. """ - Returns a tool function that searches Baserow's knowledge base and uses an LLM - to filter and synthesize relevant documentation into a focused answer. - The search retrieves documents by keyword similarity, then the LLM evaluates - each document's actual relevance to the question before generating an answer. + context = {} + for chunk in chunks: + url = chunk.source_document.source_url + content = chunk.content + if url not in context: + context[url] = content + else: + context[url] += "\n" + content + + return context + + +async def search_user_docs( + ctx: RunContext[AssistantDeps], + question: Annotated[ + str, + ( + "A precise search query in English using Baserow terminology. " + "Focus on the SPECIFIC Baserow feature being asked about. " + "Include the feature name and action, e.g., 'How to create webhooks in Baserow' " + "or 'Baserow table linking feature'. Avoid generic terms that could match " + "unrelated documentation about third-party services or integrations." + ), + ], + thought: Annotated[str, "Brief reasoning for calling this tool."], +) -> dict[str, Any]: + """\ + Search Baserow end-user docs for feature guides. NOT for tool introspection. 
It doesn't provide any information about your own tools. + + WHEN to use: User explicitly asks how to do something in Baserow's UI, or wants to learn about a specific Baserow feature (e.g., linking tables, webhooks, forms). + WHAT it does: Searches official Baserow end-user documentation and returns an answer with reliability score and source URLs. + RETURNS: Answer, reliability score (0.0-1.0), reliability_note (HIGH/PARTIAL/LOW), source URLs. Always check reliability_note before using the answer. + DO NOT USE when: Looking up how YOUR OWN tools work — you already know your tools from their names, descriptions, and schemas. Also not for API/programming documentation. + + IMPORTANT: Frame the question to target Baserow's NATIVE features specifically. + For example, ask about "Baserow webhooks" not just "webhooks" to avoid getting + results about external webhook services that integrate WITH Baserow. """ - async def search_user_docs( - question: Annotated[ - str, - ( - "A precise search query in English using Baserow terminology. " - "Focus on the SPECIFIC Baserow feature being asked about. " - "Include the feature name and action, e.g., 'How to create webhooks in Baserow' " - "or 'Baserow table linking feature'. Avoid generic terms that could match " - "unrelated documentation about third-party services or integrations." + tool_helpers = ctx.deps.tool_helpers + + # Guard: reject queries about the model's own tools. + if _TOOL_QUERY_RE.search(question): + logger.info("search_user_docs: rejected tool-introspection query: {}", question) + return { + "answer": ( + "STOP. This tool searches END-USER documentation only — " + "it has no information about your tools. " + "You already know how to use your tools from their names, " + "descriptions, and parameter schemas. " + "If a tool call failed, read the error message carefully " + "and adjust the parameters." 
), - ], - ) -> dict[str, Any]: - """ - Search Baserow's official documentation for user guides and feature - explanations. - - PURPOSE: Provides end-user documentation about Baserow's built-in - features and how to use them through the UI. - - USE WHEN: The user asks how to do something in Baserow, wants to learn - about a Baserow feature, or needs step-by-step instructions. - - DO NOT USE FOR: Agent tool usage, API implementation details, or - programming help. - - IMPORTANT: Frame the question to target Baserow's NATIVE features - specifically. For example, ask about "Baserow webhooks" not just - "webhooks" to avoid getting results about external webhook services that - integrate WITH Baserow. - """ - - nonlocal tool_helpers - - tool_helpers.update_status(_("Exploring the knowledge base...")) - - @sync_to_async - def _search(question: str) -> list[KnowledgeBaseChunk]: - chunks = KnowledgeBaseHandler().search(question, 15) - return list(chunks) - - searcher = udspy.ChainOfThought(SearchDocsSignature) - relevant_chunks = await _search(question) - prediction = await searcher.aexecute( - question=question, - context=SearchDocsSignature.format_context(relevant_chunks), - stream=True, - ) + "reliability": 0.0, + "reliability_note": "REJECTED: Tool-introspection query.", + "sources": [], + } - sources = [] - available_urls = {chunk.source_document.source_url for chunk in relevant_chunks} - for url in prediction.sources: - # somehow LLMs sometimes return sources as objects - if isinstance(url, dict) and "url" in url: - url = url["url"] - - if not isinstance(url, str): - continue - - if url in available_urls and url not in sources: - sources.append(url) - if len(sources) >= 3: - break - - # Only fallback to available URLs if reliability is high AND we have a - # real answer. Don't populate sources if the model indicated no relevant - # docs were found. 
- nothing_found = "nothing found" in prediction.answer.lower() - if not sources and prediction.reliability > 0.8 and not nothing_found: - sources = list(available_urls)[:3] - - # Override reliability to 0 if the model explicitly said nothing was - # found. The model sometimes returns high reliability for "nothing - # found" answers, which is semantically incorrect - we want reliability - # to reflect whether we actually found useful information. - reliability = 0.0 if nothing_found else prediction.reliability - - if reliability >= 0.7: - reliability_note = ( - "HIGH CONFIDENCE: Answer is well-supported by the documentation." - ) - elif reliability >= 0.4: - reliability_note = ( - "PARTIAL MATCH: Some relevant information was found, but the " - "documentation may not fully cover this topic. Supplement with " - "general knowledge but warn the user that details may be incomplete." - ) - else: - reliability_note = ( + tool_helpers.update_status(_("Exploring the knowledge base...")) + + try: + return await _search_user_docs_impl(ctx, question) + except Exception: + logger.exception("search_user_docs failed for question: {}", question) + return { + "answer": "An error occurred while searching the documentation.", + "reliability": 0.0, + "reliability_note": ( + "LOW CONFIDENCE: The documentation search encountered an error. " + "Inform the user that documentation search is temporarily " + "unavailable and suggest they check baserow.io/docs directly." 
+ ), + "sources": [], + } + + +async def _search_user_docs_impl( + ctx: RunContext[AssistantDeps], + question: str, +) -> dict[str, Any]: + """Inner implementation of search_user_docs, separated for error handling.""" + + @sync_to_async + def _search(question: str) -> list[KnowledgeBaseChunk]: + chunks = KnowledgeBaseHandler().search(question, 15) + return list(chunks) + + relevant_chunks = await _search(question) + + if not relevant_chunks: + return { + "answer": "Nothing found in the documentation.", + "reliability": 0.0, + "reliability_note": ( "LOW CONFIDENCE: The documentation does not contain information about " "this topic. DO NOT provide an answer based on general knowledge or " "assumptions - the feature may not exist in Baserow. Tell the user: " "'I couldn't find information about this in the official Baserow " "documentation.' and suggest they check the community forum or " "contact support." - ) - - return { - "answer": prediction.answer, - "reliability": reliability, - "reliability_note": reliability_note, - "sources": sources, + ), + "sources": [], } - return search_user_docs + context = format_context(relevant_chunks) + + prompt = ( + f"Question: {question}\n\n" + f"Documentation context (source URL -> content):\n{context}" + ) + from baserow_enterprise.assistant.model_profiles import get_model_string + + agent_result = await search_docs_agent.run(prompt, model=get_model_string()) + prediction = agent_result.output + + sources = [] + available_urls = {chunk.source_document.source_url for chunk in relevant_chunks} + for url in prediction.sources: + # somehow LLMs sometimes return sources as objects + if isinstance(url, dict) and "url" in url: + url = url["url"] + + if not isinstance(url, str): + continue + + if url in available_urls and url not in sources: + sources.append(url) + if len(sources) >= 3: + break + + # Only fallback to available URLs if reliability is high AND we have a + # real answer. 
Don't populate sources if the model indicated no relevant + # docs were found. + nothing_found = "nothing found" in prediction.answer.lower() + if not sources and prediction.reliability > 0.8 and not nothing_found: + sources = list(available_urls)[:3] + + # Override reliability to 0 if the model explicitly said nothing was + # found. The model sometimes returns high reliability for "nothing + # found" answers, which is semantically incorrect - we want reliability + # to reflect whether we actually found useful information. + reliability = 0.0 if nothing_found else prediction.reliability + + if reliability >= 0.7: + reliability_note = ( + "HIGH CONFIDENCE: Answer is well-supported by the documentation." + ) + elif reliability >= 0.4: + reliability_note = ( + "PARTIAL MATCH: Some relevant information was found, but the " + "documentation may not fully cover this topic. Supplement with " + "general knowledge but warn the user that details may be incomplete." + ) + else: + reliability_note = ( + "LOW CONFIDENCE: The documentation does not contain information about " + "this topic. DO NOT provide an answer based on general knowledge or " + "assumptions - the feature may not exist in Baserow. Tell the user: " + "'I couldn't find information about this in the official Baserow " + "documentation.' and suggest they check the community forum or " + "contact support." 
+ ) + if sources: + ctx.deps.extend_sources(sources) -class SearchDocsToolType(AssistantToolType): - type = "search_user_docs" + return { + "answer": prediction.answer, + "reliability": reliability, + "reliability_note": reliability_note, + "sources": sources, + } - def can_use( - self, user: AbstractUser, workspace: Workspace, *args, **kwargs - ) -> bool: - return KnowledgeBaseHandler().can_search() - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_search_user_docs_tool(user, workspace, tool_helpers) +TOOL_FUNCTIONS = [search_user_docs] +search_docs_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py new file mode 100644 index 0000000000..b586400c55 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py @@ -0,0 +1,25 @@ +from .agents import get_formula_generator +from .formula_utils import ( + FORMULA_PREFIX, + RAW_FORMULA_RE, + BaseFormulaContext, + create_example_from_json_schema, + formula_desc, + literal_or_placeholder, + minimize_json_schema, + needs_formula, + wrap_static_string, +) + +__all__ = [ + "FORMULA_PREFIX", + "RAW_FORMULA_RE", + "needs_formula", + "formula_desc", + "literal_or_placeholder", + "wrap_static_string", + "minimize_json_schema", + "create_example_from_json_schema", + "BaseFormulaContext", + "get_formula_generator", +] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py new file mode 100644 index 0000000000..058e5f8ecd --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py @@ -0,0 +1,141 @@ +""" +Shared formula generation agent factory. 
+ +Contains: +- ``FormulaGeneratorOutput``: Output model for the formula generator agent. +- ``get_formula_generator()``: Factory to create a formula generator with a custom prompt. +""" + +from typing import Callable + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic_ai import Agent + +from baserow.core.formula import resolve_formula +from baserow.core.formula.registries import formula_runtime_function_registry +from baserow.core.formula.types import ( + BASEROW_FORMULA_MODE_ADVANCED, + BaserowFormulaObject, +) + +from .formula_utils import BaseFormulaContext + + +class FormulaGeneratorOutput(PydanticBaseModel): + """Output model for the formula generator agent.""" + + generated_formulas: dict[str, str] = Field( + description=( + "A mapping of field identifiers to their generated formulas. " + "Each key is a field id/name from `fields_to_resolve` and the value " + "is the generated formula string." + ) + ) + + +def get_formula_generator( + prompt: str, +) -> Callable[[dict, BaseFormulaContext, int], dict[str, str]]: + """ + Factory to create a formula generator with a custom prompt. + + :param prompt: The system prompt for the LLM describing available functions. + :return: A function that generates formulas from field descriptions. 
+ """ + + formula_agent = Agent( + output_type=FormulaGeneratorOutput, + instructions=prompt, + name="formula_agent", + ) + + def check_formula(generated_formula: str, context: BaseFormulaContext) -> str: + """Validate a generated formula against the context.""" + try: + resolve_formula( + BaserowFormulaObject.create( + formula=generated_formula, mode=BASEROW_FORMULA_MODE_ADVANCED + ), + formula_runtime_function_registry, + context, + ) + except Exception as exc: + raise ValueError(f"Generated formula is invalid: {str(exc)}") + return "ok, the formula is valid" + + def generate_formulas( + fields_to_resolve: dict, + context: BaseFormulaContext, + max_retries: int = 3, + ) -> dict[str, str]: + """ + Generate formulas for the given field descriptions. + + :param fields_to_resolve: Dict mapping field names to descriptions. + :param context: Formula context with available data. + :param max_retries: Number of retry attempts on validation failure. + :return: Dict mapping field names to generated formulas. + :raises ValueError: If no valid formulas could be generated. 
+ """ + feedback = "" + valid_formulas = {} + remaining = dict(fields_to_resolve) + + for __ in range(max_retries): + if not remaining: + break + + user_prompt = ( + f"Fields to resolve: {remaining}\n" + f"(If prefixed with [optional], the field is not mandatory.)\n\n" + f"Context: {context.get_formula_context()}\n\n" + f"Context metadata: {context.get_context_metadata()}\n" + f"(Metadata about the context fields, with refs and names " + f"to assist in formula generation.)\n\n" + f"Feedback: {feedback or 'None (first attempt)'}" + ) + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + get_model_string, + ) + + model = get_model_string() + try: + result = formula_agent.run_sync( + user_prompt, + model=model, + model_settings=get_model_settings(model, UTILITY), + ) + except Exception as exc: + feedback += f"Formula agent error: {str(exc)}\n" + continue + + generated_formulas = result.output.generated_formulas + for field_id, formula in generated_formulas.items(): + if field_id not in remaining: + continue + try: + check_formula(formula, context) + valid_formulas[field_id] = formula + remaining.pop(field_id, None) + except ValueError as exc: + feedback += ( + f"Error for {field_id}, formula {formula} not valid: " + f"{str(exc)}\n" + ) + + if not remaining: + return valid_formulas + + # Return any valid formulas we have, or raise if none + if valid_formulas: + return valid_formulas + else: + raise ValueError( + f"Failed to generate any valid formulas after " + f"{max_retries} attempts. 
Feedback:\n{feedback}" + ) + + return generate_formulas diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py new file mode 100644 index 0000000000..30d87fcfdd --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py @@ -0,0 +1,84 @@ +""" +Shared formula language reference for formula generation prompts. + +This module contains the common formula language documentation shared between +the automation and builder formula generators. Context-specific sections +(automation paths, builder data providers) are appended by each consumer. +""" + +FORMULA_LANGUAGE = """\ +You are a formula builder. Generate formulas using the Baserow formula language. + +## Value Access + +**get(path)** - Retrieves values from context using dot-separated path notation +- Objects: get('user.name') +- Arrays by index: get('items.0'), get('orders.2.total') +- Nested: get('users.0.address.city') +- Wildcard: get('users.*.email') returns a list of values from all items + +## Field Type Suffixes + +When accessing database fields via get(), certain field types require a suffix +to extract the display value. Use the correct suffix based on the field type +reported in context_metadata: + +| Field type | Suffix | Example path | +|---|---|---| +| text, number, boolean, date, url, email, phone_number, rating, long_text, uuid | *(none)* | `field_10` | +| single_select | `.value` | `field_10.value` | +| multiple_select | `.*.value` | `field_10.*.value` | +| link_row | `.*.value` | `field_10.*.value` | +| last_modified_by | `.name` | `field_10.name` | +| created_by | `.name` | `field_10.name` | +| multiple_collaborators | `.*.name` | `field_10.*.name` | +| file | `.*.url` or `.*.visible_name` | `field_10.*.url` | + +Always check the field type in context_metadata and apply the matching suffix. 
+
+## Operators
+
+**Comparison** (return boolean):
+- equal(a, b), not_equal(a, b)
+- greater_than(a, b), less_than(a, b)
+- greater_than_or_equal(a, b), less_than_or_equal(a, b)
+- Infix: a==b, a!=b, a<b, a>b, a<=b, a>=b
+
+**Arithmetic:**
+- add(a, b) or a+b, minus(a, b) or a-b
+- multiply(a, b) or a*b, divide(a, b) or a/b
+
+**Logic:**
+- and(a, b), or(a, b)
+
+## Functions
+
+**Core:**
+- concat(...args) - Join arguments into a string: concat('Hello ', get('name'), '!')
+- if(condition, true_value, false_value) - Conditional expression
+
+**String:**
+- upper(text), lower(text), capitalize(text)
+- strip(text), replace(text, old, new), length(text), contains(text, search)
+- split(text, separator), join(array, separator)
+
+**Number:**
+- round(num, decimals), is_even(num), is_odd(num)
+
+**Date:**
+- today() - Current date
+- now() - Current date and time
+- day(date), month(date), year(date), hour(datetime), minute(datetime), second(datetime)
+- datetime_format(datetime, format)
+
+**Array:**
+- sum(array), avg(array), at(array, index)
+
+**Utility:**
+- is_empty(value), get_property(object, key)
+
+## Constants
+
+- String literals in single quotes: 'hello world', '123'
+- Numbers: 42, 3.14
+"""
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py
new file mode 100644
index 0000000000..3b92631e00
--- /dev/null
+++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py
@@ -0,0 +1,272 @@
+import re
+from abc import ABC, abstractmethod
+from datetime import date, datetime
+from typing import Any
+
+from baserow.core.formula.types import FormulaContext
+from baserow.core.utils import to_path
+
+# =============================================================================
+# Formula Detection Constants and Helpers
+# =============================================================================
+
+FORMULA_PREFIX = 
"$formula:" + +# Detects raw formula syntax the LLM might write instead of using $formula:. +# Matches: get('...'), concat(...), {{ ... }}, comparison operators, if(...), +# today(), now(). +RAW_FORMULA_RE = re.compile( + r"\bget\s*\(|\bconcat\s*\(|\{\{.*\}\}" + r"|\b(?:equal|not_equal|greater_than|less_than" + r"|greater_than_(?:or_)?equal|less_than_(?:or_)?equal)\s*\(" + r"|\bif\s*\(|\btoday\s*\(|\bnow\s*\(" +) + + +def needs_formula(value: str | None) -> bool: + """ + Check if a value requires formula processing. + + Returns True for explicit ``$formula:`` prefixed values *and* for raw + formula expressions the LLM may write inline (e.g. ``get('field')`` + or ``{{ get('field') }}``). + + :param value: The string value to check, or None. + :return: True if the value needs formula generation. + """ + + if not value: + return False + stripped = value.strip() + return stripped.lower().startswith(FORMULA_PREFIX) or bool( + RAW_FORMULA_RE.search(stripped) + ) + + +def formula_desc(value: str) -> str: + """ + Extract the formula description from a value. + + For ``$formula:`` prefixed values, strips the prefix. + For raw formula expressions, returns the value as-is so the + formula generator can convert it to a proper formula. + + :param value: A string containing a formula description or raw formula. + :return: The description text or raw formula expression. + """ + + stripped = value.strip() + if stripped.lower().startswith(FORMULA_PREFIX): + return stripped[len(FORMULA_PREFIX) :].strip() + # Raw formula expression — pass through for the generator to fix up + return stripped + + +def literal_or_placeholder(value: str | None) -> str: + """ + Return a quoted literal formula, or empty placeholder for formula values. + + Used when creating ORM objects: formula fields get a ``''`` placeholder + that will be replaced later by the formula generator, while literal + values are wrapped in single quotes. + + :param value: The string value, or None. 
+ :return: A single-quoted formula literal or ``''`` placeholder. + """ + + if not value or needs_formula(value): + return "''" + return wrap_static_string(value) + + +def wrap_static_string(value: str) -> str: + """ + Wrap a static string as a Baserow formula literal. + + If the value is already a quoted formula literal (e.g. ``'Submit'``), + it is returned unchanged to avoid double-wrapping which would produce + escaped quotes visible in the UI (e.g. ``'\\'Submit\\''``). + + :param value: Plain text string or already-quoted formula literal. + :return: Formula-compatible string literal with proper escaping. + """ + + if len(value) >= 2 and value[0] == "'" and value[-1] == "'": + return value + escaped = value.replace("'", "\\'") + return f"'{escaped}'" + + +# ============================================================================= +# JSON Schema Utilities +# ============================================================================= + + +def minimize_json_schema(schema: dict) -> dict[str, dict[str, str]]: + """ + Generate a mapping between field ids and names/types from a JSON schema. + Useful when generating formulas to understand the provided context. + + :param schema: JSON schema dict with properties and metadata. + :return: Mapping of field_key -> {id, name, type, desc, ...}. 
+ """ + field_type_descriptions = { + "link_row": "the row ID as number or the primary field value as string", + "single_select": "the option ID as number or the value as string", + "multiple_select": "a comma separated list of option IDs or values as string", + "date": "a date string in ISO 8601 format", + "date_time": "a date-time string in ISO 8601 format", + "boolean": "true or false", + } + field_type_extra_info = { + "single_select": lambda meta: { + "select_options": meta.get("select_options", []) + }, + "multiple_select": lambda meta: { + "select_options": meta.get("select_options", []) + }, + "multiple_collaborators": lambda meta: { + "available_collaborators": meta.get("available_collaborators", []) + }, + } + + if schema.get("type") == "array": + return minimize_json_schema(schema.get("items")) + elif schema.get("type") != "object": + raise ValueError("Schema must be of type object or array of objects") + + properties = schema.get("properties", {}) + mapping = {} + for key, prop in properties.items(): + metadata = prop.get("metadata") + if metadata: + field_type = metadata["type"] + mapping[key] = { + "id": metadata["id"], + "name": metadata["name"], + "type": field_type, + "desc": field_type_descriptions.get(field_type, ""), + } + if field_type in field_type_extra_info: + get_extra_info = field_type_extra_info[field_type] + mapping[key].update(get_extra_info(metadata)) + return mapping + + +def create_example_from_json_schema(schema: dict) -> Any: + """ + Generate example data from a JSON schema. + Useful when generating formulas to provide example context data. + + :param schema: JSON schema dict. + :return: Example data matching the schema structure. 
+ """ + examples = { + "string": "1", + "number": 1, + "boolean": True, + "null": None, + "object": lambda prop: create_example_from_json_schema(prop), + "array": lambda prop: [create_example_from_json_schema(prop["items"])], + } + + if schema.get("type") == "array": + return [create_example_from_json_schema(schema.get("items"))] + elif schema.get("type") != "object": + raise ValueError("Schema must be of type object or array of objects") + + properties = schema.get("properties", {}) + example = {} + for key, prop in properties.items(): + value = examples[prop.get("type")] + if callable(value): + example[key] = value(prop) + else: + example[key] = value + return example + + +# ============================================================================= +# Base Formula Context +# ============================================================================= + + +class BaseFormulaContext(FormulaContext, ABC): + """ + Base context for formula generation, shared between automation and builder. + + Subclasses must implement get_formula_context() and __getitem__ for + path resolution. + """ + + def __init__(self): + self.context: dict[str, Any] = {} + self.context_metadata: dict[str, Any] = {} + super().__init__() + + def add_context( + self, + key: str, + example_data: Any, + metadata: dict[str, Any] | None = None, + ): + """ + Add data to the formula context. + + :param key: Context key (e.g., "data_source.5" or "1" for node ID). + :param example_data: Example data for this context entry. + :param metadata: Optional metadata describing the structure. 
+ """ + self.context[key] = example_data + if metadata: + self.context_metadata[key] = metadata + + @abstractmethod + def get_formula_context(self) -> dict[str, Any]: + """Return the context dict for formula generation.""" + pass + + def get_context_metadata(self) -> dict[str, Any]: + """Return metadata about the context.""" + return self.context_metadata + + def _resolve_path(self, key: str, root_key: str) -> Any: + """ + Resolve a dotted path through the context. + + :param key: Full path like "data_source.5.field_name". + :param root_key: Expected root key to validate against. + :return: The resolved value. + :raises KeyError: If path cannot be resolved. + :raises ValueError: If resolved value is not a primitive type. + """ + start, *key_parts = to_path(key) + if start != root_key: + raise KeyError( + f"Key '{key}' not found in context. " + f"Only '{root_key}' is supported at the root level." + ) + + value = self.context + for kp in key_parts: + try: + value = value[int(kp) if isinstance(value, list) else kp] + except (KeyError, TypeError, ValueError): + available_keys = ( + list(value.keys()) + if isinstance(value, dict) + else ", ".join(map(str, range(len(value)))) + ) + raise KeyError( + f"Key '{kp}' of '{key}' not found in {value}, " + f"Available keys: {available_keys}" + ) + + if not isinstance(value, (int, float, str, bool, date, datetime)): + raise ValueError( + f"Value for key '{key}' is not a valid type. " + f"Expected int, float, str, bool, date, or datetime. " + f"Got {type(value).__name__} instead. " + f"Make sure to only reference primitive types in the formula context." + ) + return value diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py new file mode 100644 index 0000000000..f4e4497a4c --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py @@ -0,0 +1,438 @@ +""" +Pydantic-ai toolset utilities for the assistant. 
# Keys that are JSON Schema / Pydantic metadata the LLM doesn't need.
_STRIP_KEYS = frozenset({"$defs", "discriminator", "title"})


def inline_refs(schema: dict) -> dict:
    """
    Recursively resolve all ``$ref`` pointers in a JSON schema, producing a
    self-contained schema with no ``$defs`` section.

    Also strips ``discriminator`` and ``title`` metadata that LLMs don't need
    and that can contain dangling ``$defs`` references. Property names that
    happen to collide with stripped keys (e.g. a field literally named
    ``title``) are preserved.

    Sibling keys next to a ``$ref`` (legal since JSON Schema draft 2019-09,
    e.g. ``{"$ref": ..., "description": ...}``) are merged on top of the
    resolved target instead of being silently dropped.

    Many LLM providers (especially open-weight models behind Groq) struggle
    with ``$ref`` / ``$defs`` indirection. Inlining makes the schema
    directly readable by the model.

    :param schema: The JSON schema to inline. Not mutated.
    :return: A new, self-contained schema dict.
    """

    defs = schema.get("$defs", {})
    _seen: set[str] = set()  # guard against circular refs

    def _resolve(node, *, _inside_properties=False):
        if isinstance(node, dict):
            if "$ref" in node:
                ref_name = node["$ref"].rsplit("/", 1)[-1]
                if ref_name not in defs:
                    return node  # dangling ref: leave untouched
                if ref_name in _seen:
                    base = {"type": "object"}  # break circular ref
                else:
                    _seen.add(ref_name)
                    base = _resolve(defs[ref_name])
                    _seen.discard(ref_name)
                # Merge resolved sibling keys (e.g. "description") over the
                # resolved target so their information is not lost.
                extras = {
                    k: _resolve(v)
                    for k, v in node.items()
                    if k != "$ref" and k not in _STRIP_KEYS
                }
                if extras and isinstance(base, dict):
                    base = {**base, **extras}
                return base
            result = {}
            for k, v in node.items():
                # Strip JSON Schema metadata keys, but never strip property
                # names inside a "properties" dict (e.g. a field literally
                # named "title" or "description").
                if k in _STRIP_KEYS and not _inside_properties:
                    continue
                result[k] = _resolve(v, _inside_properties=(k == "properties"))
            return result
        if isinstance(node, list):
            return [_resolve(item) for item in node]
        return node

    return _resolve(schema)


_FIXER_PROMPT = """\
You are a JSON repair tool. You receive a JSON object that failed schema \
validation, the validation errors, and the target JSON schema. Return ONLY \
the fixed JSON object — no explanation, no markdown fences. Preserve the \
original values as much as possible; only change what is needed to satisfy \
the schema."""
+ """ + + def validate_json(self, input, *, allow_partial="off", **kwargs): + if isinstance(input, (str, bytes, bytearray)): + return json.loads(input) if input else {} + return input + + def validate_python(self, input, *, allow_partial="off", **kwargs): + return input if input is not None else {} + + +_LENIENT_VALIDATOR = _LenientValidator() + + +# --------------------------------------------------------------------------- +# InlineRefsToolset +# --------------------------------------------------------------------------- + + +class InlineRefsToolset(AbstractToolset[AgentDepsT]): + """ + Wraps another toolset with two responsibilities: + + 1. **Inline $ref/$defs** in tool parameter schemas so open-weight models + can parse them directly. + 2. **Fix broken tool args** via a lightweight structured-output call + instead of going through the full agent retry loop (which is slow + and rarely succeeds). + """ + + def __init__(self, inner: AbstractToolset[AgentDepsT], model: str): + self._inner = inner + self._model = model + self._original_validators: dict[str, Any] = {} + self._schemas: dict[str, dict] = {} + + @property + def id(self) -> str: + return self._inner.id + + # --- Delegation methods (match WrapperToolset pattern) --- + + async def __aenter__(self) -> Self: + await self._inner.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return await self._inner.__aexit__(*args) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: + self._inner.apply(visitor) + + def visit_and_replace( + self, + visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]], + ) -> AbstractToolset[AgentDepsT]: + new = InlineRefsToolset( + self._inner.visit_and_replace(visitor), model=self._model + ) + return new + + # --- Tool interception --- + + async def get_tools(self, ctx) -> dict[str, ToolsetTool[AgentDepsT]]: + tools = await self._inner.get_tools(ctx) + for name, tool in tools.items(): + # Inline 
def _strip_code_fences(text: str) -> str:
    """
    Remove a surrounding markdown code fence from *text*, if present.

    The fixer prompt forbids fences, but LLMs frequently add them anyway;
    stripping them salvages an otherwise valid JSON repair instead of
    failing on ``json.loads``.
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        stripped = stripped[3:]
        # Drop an optional language tag (e.g. ``json``) on the opening fence.
        first_newline = stripped.find("\n")
        if first_newline != -1:
            stripped = stripped[first_newline + 1 :]
        if stripped.rstrip().endswith("```"):
            stripped = stripped.rstrip()[:-3]
    return stripped.strip()


class InlineRefsToolset(AbstractToolset[AgentDepsT]):
    """
    Wraps another toolset with two responsibilities:

    1. **Inline $ref/$defs** in tool parameter schemas so open-weight models
       can parse them directly.
    2. **Fix broken tool args** via a lightweight structured-output call
       instead of going through the full agent retry loop (which is slow
       and rarely succeeds).
    """

    def __init__(self, inner: AbstractToolset[AgentDepsT], model: str):
        """
        :param inner: The wrapped toolset.
        :param model: Model string used for the lightweight fixer agent.
        """
        self._inner = inner
        self._model = model
        # Strict validators and inlined schemas, captured the first time
        # get_tools() sees each tool.
        self._original_validators: dict[str, Any] = {}
        self._schemas: dict[str, dict] = {}

    @property
    def id(self) -> str:
        return self._inner.id

    # --- Delegation methods (match WrapperToolset pattern) ---

    async def __aenter__(self) -> Self:
        await self._inner.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        return await self._inner.__aexit__(*args)

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        self._inner.apply(visitor)

    def visit_and_replace(
        self,
        visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]],
    ) -> AbstractToolset[AgentDepsT]:
        return InlineRefsToolset(
            self._inner.visit_and_replace(visitor), model=self._model
        )

    # --- Tool interception ---

    async def get_tools(self, ctx) -> dict[str, ToolsetTool[AgentDepsT]]:
        tools = await self._inner.get_tools(ctx)
        for name, tool in tools.items():
            # Inline $ref/$defs in the JSON schema (idempotent).
            tool.tool_def.parameters_json_schema = inline_refs(
                tool.tool_def.parameters_json_schema
            )
            # Save the original validator and schema once, then replace with
            # lenient passthrough so validation failures reach call_tool()
            # where we can attempt an async fix. Guard against multiple calls
            # so we don't overwrite the real validator with _LENIENT_VALIDATOR.
            if name not in self._original_validators:
                self._original_validators[name] = tool.args_validator
                self._schemas[name] = tool.tool_def.parameters_json_schema
            tool.args_validator = _LENIENT_VALIDATOR
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: Any,
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        """
        Validate *tool_args* against the original strict validator, attempting
        a one-shot LLM fix on failure, then delegate to the wrapped toolset.
        """
        original_validator = self._original_validators.get(name)
        if original_validator:
            try:
                tool_args = original_validator.validate_python(tool_args)
            except ValidationError as e:
                tool_args = await self._fix_tool_args(name, tool_args, e)
        return await self._inner.call_tool(name, tool_args, ctx, tool)

    async def _fix_tool_args(
        self,
        tool_name: str,
        wrong_args: dict[str, Any],
        error: ValidationError,
    ) -> dict[str, Any]:
        """
        Attempt to fix invalid tool arguments via a lightweight structured-
        output call. If the fix also fails validation, raises ``ModelRetry``
        so pydantic-ai can handle it normally.
        """

        schema = self._schemas.get(tool_name, {})
        error_details = error.errors(include_url=False, include_context=False)

        logger.warning(
            "[assistant] Tool '{}' args failed validation, attempting fix. Errors: {}",
            tool_name,
            error_details,
        )

        prompt = (
            f"Tool: {tool_name}\n\n"
            f"Schema:\n{json.dumps(schema, indent=2)}\n\n"
            f"Invalid input:\n{json.dumps(wrong_args, indent=2)}\n\n"
            f"Validation errors:\n{json.dumps(error_details, indent=2)}"
        )

        try:
            fix_agent = Agent(
                output_type=str,
                instructions=_FIXER_PROMPT,
                name="fix_agent",
            )
            from baserow_enterprise.assistant.model_profiles import (
                UTILITY,
                get_model_settings,
            )

            fixer_settings = get_model_settings(self._model, UTILITY)
            result = await fix_agent.run(
                prompt,
                model=self._model,
                model_settings={
                    **fixer_settings,
                    "response_format": {"type": "json_object"},
                },
            )
            # Models often wrap the JSON in ```fences``` despite the prompt;
            # strip them so a valid repair isn't discarded.
            fixed_args = json.loads(_strip_code_fences(result.output))
        except Exception as exc:
            logger.warning(
                "[assistant] Fixer call failed for tool '{}': {}",
                tool_name,
                exc,
            )
            raise ModelRetry(
                f"Tool arguments invalid and fix attempt failed: {error_details}"
            ) from exc

        # Re-validate with original schema
        original_validator = self._original_validators[tool_name]
        try:
            validated = original_validator.validate_python(fixed_args)
        except ValidationError as e2:
            logger.warning(
                "[assistant] Fixed args for tool '{}' still invalid: {}",
                tool_name,
                e2.errors(include_url=False, include_context=False),
            )
            raise ModelRetry(
                f"Tool arguments still invalid after fix attempt: "
                f"{e2.errors(include_url=False, include_context=False)}"
            ) from e2

        return validated
+ """ + + from .automation.tools import TOOL_FUNCTIONS as AUTO_FN + from .core.tools import create_builders, list_builders, switch_mode + from .database.tools import TOOL_FUNCTIONS as DB_FN + from .navigation.tools import navigate + from .search_user_docs.tools import search_user_docs + + try: + from .builder.tools import TOOL_FUNCTIONS as BUILDER_FN + except ImportError: + BUILDER_FN = [] + + n = frozenset # alias for readability + + def names(*funcs): + return n(f.__name__ for f in funcs) + + shared = names( + navigate, + switch_mode, + list_builders, + # Read-only database tools available in every mode + *[f for f in DB_FN if f.__name__.startswith(("list_", "get_"))], + ) + + return { + AgentMode.DATABASE: shared | names(*DB_FN, create_builders), + AgentMode.APPLICATION: shared | names(*BUILDER_FN, create_builders), + AgentMode.AUTOMATION: shared | names(*AUTO_FN, create_builders), + AgentMode.EXPLAIN: shared + | names( + *[f for f in BUILDER_FN if f.__name__.startswith("list_")], + *[f for f in AUTO_FN if f.__name__.startswith("list_")], + search_user_docs, + ), + } + + +_MODE_TOOL_MAP: dict[AgentMode, frozenset[str]] | None = None + + +def _get_mode_tool_map() -> dict[AgentMode, frozenset[str]]: + global _MODE_TOOL_MAP + if _MODE_TOOL_MAP is None: + _MODE_TOOL_MAP = _build_mode_tool_map() + return _MODE_TOOL_MAP + + +class ModeAwareToolset(AbstractToolset[AgentDepsT]): + """ + Filters the inner toolset based on the current :class:`AgentMode`. + + Each domain mode (DATABASE, APPLICATION, AUTOMATION) exposes only its + relevant tools plus shared read-only tools. EXPLAIN mode exposes + read-only tools plus ``search_user_docs``. 
class ModeAwareToolset(AbstractToolset[AgentDepsT]):
    """
    Filters the inner toolset based on the current :class:`AgentMode`.

    Each domain mode (DATABASE, APPLICATION, AUTOMATION) exposes only its
    relevant tools plus shared read-only tools. EXPLAIN mode exposes
    read-only tools plus ``search_user_docs``.
    """

    def __init__(self, inner: AbstractToolset[AgentDepsT], deps: "AssistantDeps"):
        self._inner = inner
        self._deps = deps

    @property
    def id(self) -> str:
        return self._inner.id

    async def __aenter__(self) -> Self:
        await self._inner.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> bool | None:
        return await self._inner.__aexit__(*args)

    def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None:
        self._inner.apply(visitor)

    def visit_and_replace(
        self,
        visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]],
    ) -> AbstractToolset[AgentDepsT]:
        replaced_inner = self._inner.visit_and_replace(visitor)
        return ModeAwareToolset(replaced_inner, self._deps)

    async def get_tools(self, ctx) -> dict[str, ToolsetTool[AgentDepsT]]:
        # Expose only the tools whitelisted for the currently active mode.
        available = await self._inner.get_tools(ctx)
        allowed_names = _get_mode_tool_map()[self._deps.mode]
        filtered: dict[str, ToolsetTool[AgentDepsT]] = {}
        for tool_name, tool in available.items():
            if tool_name in allowed_names:
                filtered[tool_name] = tool
        return filtered

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: Any,
        tool: ToolsetTool[AgentDepsT],
    ) -> Any:
        # Convert well-known domain exceptions into structured error payloads
        # the LLM can act on, instead of aborting the agent run.
        from baserow.core.exceptions import UserNotInWorkspace
        from baserow_enterprise.assistant.tools.database.helpers import ToolInputError

        try:
            return await self._inner.call_tool(name, tool_args, ctx, tool)
        except ToolInputError as exc:
            return {"error": str(exc)}
        except UserNotInWorkspace:
            return {
                "error": (
                    "One or more IDs reference a resource outside the current "
                    "workspace. Use the appropriate list_* tool to find "
                    "the correct IDs and retry."
                )
            }
+ """ + + lines: list[str] = [] + if routing_rules: + lines.append(routing_rules.strip()) + lines.append("") + for module_type, funcs in module_groups: + if not funcs: + continue + label = _MODULE_LABELS.get(module_type, module_type) + lines.append(f"## {label}") + for func in funcs: + lines.append(tool_manifest_line_compact(func.__name__, func.__doc__ or "")) + lines.append("") + return "\n".join(lines).rstrip() diff --git a/enterprise/backend/src/baserow_enterprise/assistant/types.py b/enterprise/backend/src/baserow_enterprise/assistant/types.py index 080dbee730..21be1de4bb 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/types.py @@ -4,7 +4,6 @@ from django.utils.translation import gettext as _ -import udspy from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, Field @@ -12,6 +11,7 @@ class BaseModel(PydanticBaseModel): model_config = ConfigDict( extra="forbid", + coerce_numbers_to_str=True, ) @@ -157,7 +157,7 @@ class AiMessage(AiMessageChunk): ) -class AiThinkingMessage(BaseModel, udspy.StreamEvent): +class AiThinkingMessage(BaseModel): type: Literal["ai/thinking"] = AssistantMessageType.AI_THINKING.value content: str = Field( default="", @@ -236,16 +236,27 @@ def to_localized_string(self): return _("workflow %(workflow_name)s") % {"workflow_name": self.workflow_name} +class BuilderPageNavigationType(BaseModel): + type: Literal["builder-page"] + application_id: int + page_id: int + page_name: str + + def to_localized_string(self): + return _("page %(page_name)s") % {"page_name": self.page_name} + + AnyNavigationType = Annotated[ TableNavigationType | WorkspaceNavigationType | ViewNavigationType - | WorkflowNavigationType, + | WorkflowNavigationType + | BuilderPageNavigationType, Field(discriminator="type"), ] -class AiNavigationMessage(BaseModel, udspy.StreamEvent): +class AiNavigationMessage(BaseModel): type: Literal["ai/navigation"] = 
"ai/navigation" location: AnyNavigationType diff --git a/enterprise/backend/src/baserow_enterprise/config/settings/settings.py b/enterprise/backend/src/baserow_enterprise/config/settings/settings.py index d3e8852e11..595288e713 100644 --- a/enterprise/backend/src/baserow_enterprise/config/settings/settings.py +++ b/enterprise/backend/src/baserow_enterprise/config/settings/settings.py @@ -79,6 +79,35 @@ def setup(settings): settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = os.getenv( "BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL", "" ) - settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE = float( - os.getenv("BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE", "") or 0.3 + _temp_raw = os.getenv("BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE", "") + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE = ( + float(_temp_raw) if _temp_raw else None ) + + # Backward compatibility: bridge old UDSPY_LM_* env vars so existing + # deployments continue to work without config changes. + _udspy_model = os.getenv("UDSPY_LM_MODEL", "") + if _udspy_model and not settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL: + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = _udspy_model + + _udspy_api_key = os.getenv("UDSPY_LM_API_KEY", "") + if _udspy_api_key: + # pydantic-ai reads provider-specific env vars. Set them all as + # fallbacks so the old catch-all key works regardless of provider. + for _key in ( + "OPENAI_API_KEY", + "GROQ_API_KEY", + "ANTHROPIC_API_KEY", + "GEMINI_API_KEY", + ): + os.environ.setdefault(_key, _udspy_api_key) + + _udspy_base_url = os.getenv("UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL", "") + if _udspy_base_url: + # pydantic-ai's OpenAI provider reads OPENAI_BASE_URL. + os.environ.setdefault("OPENAI_BASE_URL", _udspy_base_url) + + # Bridge old AWS_REGION_NAME to boto3's standard AWS_DEFAULT_REGION. 
+ _aws_region = os.getenv("AWS_REGION_NAME", "") + if _aws_region: + os.environ.setdefault("AWS_DEFAULT_REGION", _aws_region) diff --git a/enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py b/enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py new file mode 100644 index 0000000000..1ebf92109c --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.13 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("baserow_enterprise", "0057_role_hidden"), + ] + + operations = [ + migrations.AddField( + model_name="assistantchat", + name="message_history", + field=models.BinaryField( + blank=True, + help_text=( + "Serialized pydantic-ai message history (JSON bytes) for " + "multi-turn conversation context." + ), + null=True, + ), + ), + ] diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py @@ -0,0 +1 @@ + diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py new file mode 100644 index 0000000000..ea0299c75a --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py @@ -0,0 +1,153 @@ +import asyncio +import logging +import os +import sys + +from django.conf import settings + +import pytest +from loguru import logger + +from baserow.config.settings.test import TEST_ENV_VARS + +# Suppress DEBUG-level loguru output during evals. Baserow's cache layer logs +# every cache hit/miss at DEBUG, which floods the output when using -s. 
Agent +# message history is printed via print() and is captured by pytest: it appears +# in the failure report automatically without needing -s. +logger.remove() +logger.add(sys.stderr, level="WARNING") + +# Expose API keys from TEST_ENV_FILE to os.environ so that LLM provider +# SDKs (which read os.getenv() at import/construction time) can find them. +# test.py already parses TEST_ENV_FILE via dotenv_values but deliberately +# does NOT inject non-allowlisted keys into os.environ. We bridge that +# gap here for the small set of keys the eval suite needs. +_API_KEY_NAMES = ("GROQ_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY") +for _k in _API_KEY_NAMES: + if (_v := TEST_ENV_VARS.get(_k)) and not os.environ.get(_k): + os.environ[_k] = _v + + +_EVALS_DIR = os.path.dirname(__file__) + + +def _evals_explicitly_requested(config): + """Return True when the user intentionally targeted eval tests.""" + + # ``-m eval`` on the command line + marker_expr = config.getoption("-m", default="") + if "eval" in marker_expr: + return True + + # User pointed pytest at an eval file/directory (e.g. VSCode test runner) + for arg in config.args: + if os.path.abspath(arg).startswith(_EVALS_DIR): + return True + + return False + + +def pytest_collection_modifyitems(config, items): + """Skip eval tests unless explicitly requested (``-m eval`` or by path). + + Also wires up ``EVAL_RETRIES``: when set to a positive integer, every eval + test is automatically marked with ``pytest.mark.retry(N)`` so that failing + tests are re-run up to N times. A test that passes on retry is a flake + (LLM non-determinism); one that fails all N retries is a consistent bug. 
+ """ + + if not _evals_explicitly_requested(config): + skip_eval = pytest.mark.skip(reason="eval tests only run with -m eval") + for item in items: + if item.get_closest_marker("eval"): + item.add_marker(skip_eval) + return + + eval_retries = int(os.environ.get("EVAL_RETRIES", "0")) + if eval_retries > 0: + for item in items: + if item.get_closest_marker("eval"): + item.add_marker(pytest.mark.retry(eval_retries)) + + +def pytest_generate_tests(metafunc): + """Auto-parametrize tests that use the ``eval_model`` fixture.""" + + if "eval_model" in metafunc.fixturenames: + from .eval_utils import get_eval_model + + model_str = get_eval_model() + models = [m.strip() for m in model_str.split(",") if m.strip()] + metafunc.parametrize("eval_model", models, scope="session") + + +@pytest.fixture(scope="session") +def synced_knowledge_base(django_db_blocker): + """ + Sync the knowledge base once per pytest session if not already populated. + + With ``--reuse-db`` the DB persists across sessions, so the (slow) + embedding + sync step only runs the very first time. Subsequent + sessions detect that the KB is already populated and return immediately. + """ + + with django_db_blocker.unblock(): + if not getattr(settings, "BASEROW_EMBEDDINGS_API_URL", ""): + return # No embeddings server → nothing to sync + + from baserow_enterprise.assistant.tools.search_user_docs.handler import ( + KnowledgeBaseHandler, + ) + + handler = KnowledgeBaseHandler() + + if handler.can_search(): + return # Already populated (e.g. --reuse-db from a previous run) + + if not handler.can_have_knowledge_base(): + return # pgvector not available + + print("\n[eval] Syncing knowledge base (first run — this may take a while)...") + handler.sync_knowledge_base() + print("[eval] Knowledge base sync complete.") + + +@pytest.fixture(autouse=True) +def suppress_asyncio_stopiteration_error(): + """ + Suppress the 'StopIteration interacts badly with generators' asyncio error. 
+ + This is a known Python issue when generators raise StopIteration in contexts + where asyncio futures are involved. The error is harmless but noisy. + """ + original_handler = None + + def custom_exception_handler(loop, context): + exception = context.get("exception") + if isinstance(exception, TypeError) and "StopIteration" in str(exception): + return # Suppress this specific error + if original_handler: + original_handler(loop, context) + else: + loop.default_exception_handler(context) + + try: + loop = asyncio.get_event_loop() + original_handler = loop.get_exception_handler() + loop.set_exception_handler(custom_exception_handler) + except RuntimeError: + pass # No event loop + + # Also suppress the log message + asyncio_logger = logging.getLogger("asyncio") + original_level = asyncio_logger.level + asyncio_logger.setLevel(logging.CRITICAL) + + yield + + asyncio_logger.setLevel(original_level) + try: + loop = asyncio.get_event_loop() + loop.set_exception_handler(original_handler) + except RuntimeError: + pass diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py new file mode 100644 index 0000000000..e4d9963fca --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py @@ -0,0 +1,372 @@ +""" +Shared utilities for assistant evals (single-agent architecture). 
+ +These utilities are used by multiple eval test files and provide: +- LLM configuration +- UIContext building +- Callback tracking for assertions +- Assistant creation helpers +- Message history formatting for inspection +""" + +import json +import os + +from pydantic_ai.usage import UsageLimits + +from baserow_enterprise.assistant.agents import main_agent +from baserow_enterprise.assistant.deps import AssistantDeps, ToolHelpers +from baserow_enterprise.assistant.tools.registries import assistant_tool_registry +from baserow_enterprise.assistant.types import ( + ApplicationUIContext, + TableUIContext, + UIContext, + UserUIContext, + WorkspaceUIContext, +) + +# Default model for evals - can be overridden via EVAL_LLM_MODEL env var +DEFAULT_EVAL_MODEL = "groq:openai/gpt-oss-120b" + + +def build_database_ui_context(user, workspace, database=None, table=None) -> str: + """ + Build a UIContext for a database, formatted as JSON string. + + This tells the agent which workspace/database/table the user is viewing. + """ + ctx = UIContext( + workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), + database=ApplicationUIContext(id=str(database.id), name=database.name) + if database + else None, + table=TableUIContext(id=table.id, name=table.name) if table else None, + user=UserUIContext(id=user.id, name=user.first_name, email=user.email), + ) + return ctx.format() + + +def format_message_history(result) -> list[dict]: + """ + Format the full message history from an agent run for inspection. 
+ + Returns a list of dicts with structured info about each message: + - role: system/user/assistant/tool + - type: the pydantic-ai message class name + - content: text content (if any) + - tool_calls: list of tool call info (if any) + - tool_name: name of tool that returned this result (for tool results) + - timestamp: message timestamp (if available) + """ + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + ) + + messages = getattr(result, "all_messages", lambda: [])() or [] + formatted = [] + + for msg in messages: + if isinstance(msg, ModelRequest): + for part in msg.parts: + part_type = type(part).__name__ + entry = {"role": "user", "type": part_type} + + if hasattr(part, "content"): + entry["content"] = part.content + if hasattr(part, "tool_name"): + entry["tool_name"] = part.tool_name + if hasattr(part, "tool_call_id"): + entry["tool_call_id"] = part.tool_call_id + if hasattr(part, "timestamp"): + entry["timestamp"] = str(part.timestamp) + + formatted.append(entry) + + elif isinstance(msg, ModelResponse): + for part in msg.parts: + part_type = type(part).__name__ + entry = {"role": "assistant", "type": part_type} + + if hasattr(part, "content"): + entry["content"] = part.content + if hasattr(part, "tool_name"): + entry["tool_name"] = part.tool_name + if hasattr(part, "tool_call_id"): + entry["tool_call_id"] = part.tool_call_id + if hasattr(part, "args"): + # Tool call arguments + args = part.args + if isinstance(args, str): + try: + args = json.loads(args) + except (json.JSONDecodeError, TypeError): + pass + entry["args"] = args + + formatted.append(entry) + + return formatted + + +def print_message_history(result, max_content_len=1000): + """ + Print a human-readable summary of the full message history. + + Shows all LLM requests, responses, tool calls, and tool results + in chronological order. 
+ """ + history = format_message_history(result) + + print("\n" + "=" * 80) + print("MESSAGE HISTORY") + print("=" * 80) + + for i, entry in enumerate(history): + role = entry["role"].upper() + msg_type = entry.get("type", "unknown") + print(f"\n--- [{i + 1}] {role} ({msg_type}) ---") + + if "content" in entry: + content = str(entry["content"]) + if len(content) > max_content_len: + content = content[:max_content_len] + "..." + print(f" Content: {content}") + + if "tool_name" in entry: + print(f" Tool: {entry['tool_name']}") + + if "args" in entry: + args_str = json.dumps(entry["args"], indent=2, default=str) + if len(args_str) > max_content_len: + args_str = args_str[:max_content_len] + "..." + print(f" Args: {args_str}") + + if "tool_call_id" in entry: + print(f" Call ID: {entry['tool_call_id']}") + + print("\n" + "=" * 80) + print(f"Total entries: {len(history)}") + print("=" * 80 + "\n") + + +def print_trajectory(result, max_obs_len=500): + """Debug helper to print the agent's trajectory.""" + print("\n=== TRAJECTORY ===") + # pydantic-ai stores messages differently + for i, msg in enumerate(getattr(result, "all_messages", lambda: [])() or []): + print(f"\n--- Message {i + 1} ---") + print(f" {type(msg).__name__}: {str(msg)[:max_obs_len]}") + print("\n=== END TRAJECTORY ===\n") + + +def get_eval_model() -> str: + """ + Get the model string for evals. + + Configure via EVAL_LLM_MODEL environment variable. + API keys should be set via standard env vars (OPENAI_API_KEY, GROQ_API_KEY). + """ + return os.environ.get("EVAL_LLM_MODEL", DEFAULT_EVAL_MODEL) + + +class EvalToolTracker: + """ + Placeholder for future tool-call instrumentation. + + Currently eval assertions rely on inspecting the pydantic-ai message + history (``RetryPromptPart`` entries) rather than wrapping individual + tools, so this class is intentionally minimal. 
+ """ + + def __init__(self, verbose: bool = True): + self.verbose = verbose + + +def create_eval_assistant(user, workspace, max_iters=15, model=None): + """ + Create an assistant configured like production for evals. + + Returns (agent, deps, tracker, model, usage_limits, toolset) so tests + can run the agent. Uses the single-agent architecture with the full + monolithic toolset from build_assistant_toolset(). + + :param model: Override the LLM model string. Falls back to + ``get_eval_model()`` (i.e. the ``EVAL_LLM_MODEL`` env var). + """ + from django.conf import settings + + tool_helpers = ToolHelpers(lambda x: None, lambda x: None) + tracker = EvalToolTracker() + model = model or get_eval_model() + + # Ensure sub-agents (e.g. formula_agent) also use the eval model. + # get_model_string() does .replace("/", ":", 1) on the setting value, + # so store in "/" format (e.g. "groq/openai/gpt-oss-120b"). + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = model.replace(":", "/", 1) + + deps = AssistantDeps( + user=user, + workspace=workspace, + tool_helpers=tool_helpers, + ) + + # Build the single-agent toolset (navigation + core + database + automation) + toolset, db_manifest, app_manifest, auto_manifest, explain_manifest = ( + assistant_tool_registry.build_toolset(user, workspace, model, deps) + ) + deps.database_manifest = db_manifest + deps.application_manifest = app_manifest + deps.automation_manifest = auto_manifest + deps.explain_manifest = explain_manifest + usage_limits = UsageLimits(request_limit=max_iters) + + return main_agent, deps, tracker, model, usage_limits, toolset + + +def get_tool_call_sequence(result) -> list[str]: + """ + Return the ordered list of tool names called during an agent run. + + Extracts assistant-side tool call entries from the message history, + preserving chronological order. 
+ """ + + history = format_message_history(result) + return [ + e["tool_name"] + for e in history + if e["role"] == "assistant" and "tool_name" in e and "args" in e + ] + + +def assert_tool_call_order(result, expected_order: list[str]): + """ + Assert that tools were called in the expected relative order. + + For each consecutive pair (A, B) in *expected_order*, verifies that the + **last** call to A comes before the **first** call to B. This guarantees + that all A work is fully completed before any B work begins. + + Example:: + + assert_tool_call_order(result, [ + "create_pages", + "create_layout_elements", + "create_display_elements", + ]) + """ + + sequence = get_tool_call_sequence(result) + + def _all_indices(tool_name: str) -> list[int]: + indices = [i for i, name in enumerate(sequence) if name == tool_name] + if not indices: + raise AssertionError( + f"Expected tool '{tool_name}' was never called. " + f"Actual sequence: {sequence}" + ) + return indices + + for i in range(len(expected_order) - 1): + name_a = expected_order[i] + name_b = expected_order[i + 1] + last_a = _all_indices(name_a)[-1] + first_b = _all_indices(name_b)[0] + assert last_a < first_b, ( + f"Expected all '{name_a}' calls to finish before any '{name_b}' call, " + f"but last '{name_a}' at pos {last_a} >= first '{name_b}' at pos {first_b}. " + f"Actual sequence: {sequence}" + ) + + +class EvalChecklist: + """ + Soft-assertion context manager for eval tests. + + Collects labelled checks without raising immediately. On exit it prints a + score table (visible with ``-s``) and raises a single AssertionError that + lists every failed check. This lets you see "4/6 (66%)" instead of the + binary "FAIL at first assertion" behaviour of plain ``assert``. 
+ + Usage:: + + with EvalChecklist("creates Bookstore database") as checks: + checks.check("Books table exists", any("book" in n for n in names)) + checks.check("Authors table exists", any("author" in n for n in names), + hint=f"got: {names}") + """ + + def __init__(self, name: str): + self.name = name + self._checks: list[tuple[str, bool, str]] = [] + + def check(self, label: str, condition: bool, hint: str = "") -> bool: + """Record a soft check. Returns the condition value for further use.""" + self._checks.append((label, bool(condition), hint)) + return bool(condition) + + @property + def score(self) -> tuple[int, int]: + passed = sum(1 for _, ok, _ in self._checks if ok) + return passed, len(self._checks) + + def assert_all(self): + passed, total = self.score + pct = 100 * passed // total if total else 0 + lines = [ + f" {'✓' if ok else '✗'} {label}" + + (f" ({hint})" if not ok and hint else "") + for label, ok, hint in self._checks + ] + summary = ( + f"\nEVAL SCORE [{self.name}]: {passed}/{total} ({pct}%)\n" + + "\n".join(lines) + ) + print(summary) + failed = [label for label, ok, _ in self._checks if not ok] + assert not failed, summary + + def __enter__(self): + return self + + def __exit__(self, exc_type, *_): + if exc_type is None: + self.assert_all() + return False + + +def count_tool_errors(result) -> tuple[int, str]: + """ + Count tool validation errors in the agent result. + + Inspects the pydantic-ai message history for ``RetryPromptPart`` entries, + which indicate the LLM sent invalid arguments that failed pydantic + validation. "Unknown tool name" retries are excluded — the LLM explored a + non-existent tool and recovered on its own, which is acceptable. + + Returns ``(error_count, hint)`` suitable for use with + :meth:`EvalChecklist.check`. 
+ """ + from pydantic_ai.messages import ModelRequest, RetryPromptPart + + if result is None: + return 0, "" + + messages = getattr(result, "all_messages", lambda: [])() or [] + retry_errors = [] + for msg in messages: + if isinstance(msg, ModelRequest): + for part in msg.parts: + if isinstance(part, RetryPromptPart): + content = str(part.content) + if "Unknown tool name" in content: + continue + retry_errors.append( + { + "tool_name": getattr(part, "tool_name", None), + "content": content, + } + ) + hint = "\n".join(f" - {e['tool_name']}: {e['content']}" for e in retry_errors) + return len(retry_errors), hint diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py new file mode 100644 index 0000000000..6982c1a284 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py @@ -0,0 +1,845 @@ +import pytest + +from baserow.contrib.automation.workflows.models import AutomationWorkflow + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + format_message_history, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_LISTS_WORKFLOWS = "List the workflows in automation ID {automation_id}" + +PROMPT_CREATES_WORKFLOW = ( + "Create a workflow in automation {automation_name} that " + "triggers when a row is created in table '{table_name}', " + "and updates the Status field to 'Processing'." 
+) + +PROMPT_CREATES_WEEKLY_SLACK_REMINDER = ( + "In automation '{automation_name}', create a workflow that sends a " + "Slack message to #general every Tuesday at 9am UTC asking " + "'Is there anything to demo this week?'" +) + +PROMPT_CREATES_ROUTER_WORKFLOW = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is created in table '{table_name}'. " + "Add a router: if Priority is 'High', send a Slack message to " + "#urgent saying 'High priority ticket created'. " + "If Priority is 'Low', do nothing (just the router branch is fine)." +) + +PROMPT_CREATES_ROW_WITH_FIELD_VALUES = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is created in '{source_table_name}'. " + "Then create a row in '{log_table_name}' with Entry set to " + "the new contact's Name and Source set to 'automation'." +) + +PROMPT_CREATES_UPDATE_ROW_WORKFLOW = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is updated in '{table_name}'. " + "Then update the same row: set Status to 'Reviewed' and " + "Notes to 'Automatically reviewed by automation'." +) + +PROMPT_CREATES_EMAIL_NOTIFICATION_WORKFLOW = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is created in '{table_name}'. " + "Send an email to admin@example.com with subject 'New Order' " + "and body 'A new order has been placed'." 
+) + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + deps.tool_helpers.request_context["ui_context"] = ui_context + return agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + +def _get_create_workflows_args(result) -> list[dict]: + """Return the parsed ``args`` dicts of every ``create_workflows`` tool call + the agent made (assistant-side entries have ``args``).""" + + history = format_message_history(result) + return [ + e["args"] + for e in history + if e["role"] == "assistant" + and e.get("tool_name") == "create_workflows" + and "args" in e + ] + + +def _get_workflow_nodes(automation): + """Return (workflow, trigger, action_nodes) for the first workflow.""" + + workflow = AutomationWorkflow.objects.filter(automation=automation).first() + assert workflow is not None, "No workflow was created" + trigger = workflow.get_trigger() + action_nodes = list( + workflow.automation_workflow_nodes.exclude(id=trigger.id).order_by("id") + ) + return workflow, trigger, action_nodes + + +# --------------------------------------------------------------------------- +# Existing evals +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_lists_workflows(data_fixture, eval_model): + """Agent should call list_workflows when asked about automation workflows.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + automation = data_fixture.create_automation_application( + workspace=workspace, name="My Automation" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=10, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + 
agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_LISTS_WORKFLOWS.format(automation_id=automation.id), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "list_workflows" and e["role"] == "user" + ] + + with EvalChecklist("lists workflows") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called list_workflows", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_workflow(data_fixture, eval_model): + """Agent should create a workflow when asked to automate a process.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Orders") + data_fixture.create_text_field(table=table, name="Order ID", primary=True) + data_fixture.create_text_field(table=table, name="Status") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Order Processing" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") 
== "create_workflows" and e["role"] == "user" + ] + workflows = AutomationWorkflow.objects.filter(automation=automation) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + nodes_args = wf_args.get("nodes", []) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + update_nodes_args = [n for n in nodes_args if n.get("type") == "update_row"] + ur_values = update_nodes_args[0].get("values", []) if update_nodes_args else [] + ur_has_processing = any( + "processing" in str(v.get("value", "")).lower() for v in ur_values + ) + + db_ok = workflows.exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_trigger_type = trigger_node.service.get_type().type + db_update_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "local_baserow_upsert_row" + ] + else: + db_trigger_type = None + db_update_actions = [] + + with EvalChecklist("creates workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called create_workflows", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "workflow created in DB", + db_ok, + ) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Orders", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "update_row node in args", + len(update_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "update_row sets field to 'Processing'", + ur_has_processing, + hint=f"values: {ur_values}", + ) + checks.check( + "DB trigger is 
rows_created", + db_trigger_type == "local_baserow_rows_created", + hint=f"got {db_trigger_type}", + ) + checks.check( + "update_row action in DB", + len(db_update_actions) >= 1, + ) + + +# --------------------------------------------------------------------------- +# Periodic trigger + Slack message +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_weekly_slack_reminder(data_fixture, eval_model): + """Agent should create a periodic-WEEK trigger firing on Tuesday with a + Slack message node asking about demos.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + automation = data_fixture.create_automation_application( + workspace=workspace, name="Team Reminders" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_WEEKLY_SLACK_REMINDER.format( + automation_name=automation.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + interval_args = trigger_args.get("periodic_interval", {}) + nodes_args = wf_args.get("nodes", []) + slack_nodes_args = [n for n in nodes_args if n.get("type") == "slack_write_message"] + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + 
db_trigger_type = trigger_node.get_type().type + db_slack_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "slack_write_message" + ] + else: + db_trigger_type = None + db_slack_actions = [] + + slack_node = slack_nodes_args[0] if slack_nodes_args else {} + slack_channel = slack_node.get("channel", "") + slack_text = slack_node.get("text", "") + + with EvalChecklist("creates weekly Slack reminder") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called create_workflows", + len(call_args_list) >= 1, + ) + checks.check( + "trigger type is periodic", + trigger_args.get("type") == "periodic", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "interval is WEEK", + interval_args.get("interval") == "WEEK", + hint=f"got {interval_args.get('interval')}", + ) + checks.check( + "day_of_week is 1 (Tuesday)", + interval_args.get("day_of_week") == 1, + hint=f"got {interval_args.get('day_of_week')}", + ) + checks.check( + "slack_write_message node in args", + len(slack_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "workflow created in DB with periodic trigger", + db_trigger_type == "periodic", + hint=f"got {db_trigger_type}", + ) + checks.check( + "Slack action exists in DB", + len(db_slack_actions) >= 1, + ) + checks.check( + "Slack channel is #general", + "general" in slack_channel.lower(), + hint=f"got channel: '{slack_channel}'", + ) + checks.check( + "Slack message mentions demo", + "demo" in slack_text.lower(), + hint=f"got text: '{slack_text}'", + ) + + +# --------------------------------------------------------------------------- +# Router node +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_router_workflow(data_fixture, eval_model): + """Agent should create a workflow with a router node that branches + based on a 
condition.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tickets") + data_fixture.create_text_field(table=table, name="Title", primary=True) + priority_field = data_fixture.create_single_select_field( + table=table, name="Priority" + ) + data_fixture.create_select_option(field=priority_field, value="High", order=0) + data_fixture.create_select_option(field=priority_field, value="Low", order=1) + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Ticket Router" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_ROUTER_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + nodes_args = wf_args.get("nodes", []) + router_nodes_args = [n for n in nodes_args if n.get("type") == "router"] + router_edges_args = ( + router_nodes_args[0].get("edges", []) if router_nodes_args else [] + ) + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_router_actions = [ + n for n in action_nodes if n.service.get_type().type == "router" + ] + db_edges_count = ( + db_router_actions[0].service.specific.edges.count() + if db_router_actions + else 0 + ) + else: + 
db_router_actions = [] + db_edges_count = 0 + + trigger_args = wf_args.get("trigger", {}) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + slack_nodes_in_nodes = [ + n for n in nodes_args if n.get("type") == "slack_write_message" + ] + slack_channel = ( + slack_nodes_in_nodes[0].get("channel", "") if slack_nodes_in_nodes else "" + ) + + with EvalChecklist("creates router workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Tickets", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "router node in args", + len(router_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "router has >=2 edges in args", + len(router_edges_args) >= 2, + hint=f"got {len(router_edges_args)}", + ) + checks.check( + "router node in DB", + len(db_router_actions) >= 1, + ) + checks.check( + "router has >=2 edges in DB", + db_edges_count >= 2, + hint=f"got {db_edges_count}", + ) + checks.check( + "Slack node exists for High branch", + len(slack_nodes_in_nodes) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "Slack channel is #urgent", + "urgent" in slack_channel.lower(), + hint=f"got channel: '{slack_channel}'", + ) + + +# --------------------------------------------------------------------------- +# Create-row / update-row with field value formulas +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_row_with_field_values(data_fixture, eval_model): + """Agent should create a workflow with a create_row node that maps + 
specific field values (including formula-style references).""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + source_table = data_fixture.create_database_table( + database=database, name="Contacts" + ) + data_fixture.create_text_field(table=source_table, name="Name", primary=True) + data_fixture.create_email_field(table=source_table, name="Email") + + log_table = data_fixture.create_database_table(database=database, name="Log") + data_fixture.create_text_field(table=log_table, name="Entry", primary=True) + data_fixture.create_text_field(table=log_table, name="Source") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Contact Logger" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, source_table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_ROW_WITH_FIELD_VALUES.format( + automation_name=automation.name, + source_table_name=source_table.name, + log_table_name=log_table.name, + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + nodes_args = wf_args.get("nodes", []) + create_row_nodes_args = [n for n in nodes_args if n.get("type") == "create_row"] + cr_values = ( + create_row_nodes_args[0].get("values", []) if create_row_nodes_args else [] + ) + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = 
_get_workflow_nodes(automation) + db_trigger_type = trigger_node.service.get_type().type + db_create_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "local_baserow_upsert_row" + ] + else: + db_trigger_type = None + db_create_actions = [] + + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + cr_node = create_row_nodes_args[0] if create_row_nodes_args else {} + cr_table_id = cr_node.get("table_id") + cr_has_literal_automation = any( + "automation" in str(v.get("value", "")).lower() for v in cr_values + ) + + with EvalChecklist("creates row with field values") as checks: + checks.check("<=1 tool errors", err_count <= 1, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Contacts (source_table)", + trigger_table_id == source_table.id, + hint=f"got table_id={trigger_table_id}, expected={source_table.id}", + ) + checks.check( + "create_row node in args", + len(create_row_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "create_row targets Log table", + cr_table_id == log_table.id, + hint=f"got table_id={cr_table_id}, expected={log_table.id}", + ) + checks.check( + "create_row has >=1 field value", + len(cr_values) >= 1, + hint=f"got {len(cr_values)}", + ) + checks.check( + "create_row has 'automation' literal value (Source field)", + cr_has_literal_automation, + hint=f"values: {cr_values}", + ) + checks.check( + "DB trigger is rows_created", + db_trigger_type == "local_baserow_rows_created", + hint=f"got {db_trigger_type}", + ) + checks.check( + "create_row action in DB", + len(db_create_actions) >= 1, + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_update_row_workflow(data_fixture, eval_model): + """Agent should 
create a workflow with an update_row node that references + field values from the trigger.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Task", primary=True) + data_fixture.create_text_field(table=table, name="Status") + data_fixture.create_long_text_field(table=table, name="Notes") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Task Processor" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_UPDATE_ROW_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + nodes_args = wf_args.get("nodes", []) + update_nodes_args = [n for n in nodes_args if n.get("type") == "update_row"] + ur = update_nodes_args[0] if update_nodes_args else {} + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_trigger_type = trigger_node.service.get_type().type + db_update_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "local_baserow_upsert_row" + ] + else: + db_trigger_type = None + db_update_actions = [] + + ur_values = 
ur.get("values", []) + ur_has_reviewed = any( + "reviewed" in str(v.get("value", "")).lower() for v in ur_values + ) + ur_has_notes = any( + "automation" in str(v.get("value", "")).lower() + or "review" in str(v.get("value", "")).lower() + for v in ur_values + ) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + + with EvalChecklist("creates update-row workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_updated", + trigger_args.get("type") == "rows_updated", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Tasks", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "update_row node in args", + len(update_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "update_row has >=1 field value", + len(ur_values) >= 1, + ) + checks.check( + "update_row has row_id", + bool(ur.get("row_id")), + ) + checks.check( + "update_row sets Status to 'Reviewed'", + ur_has_reviewed, + hint=f"values: {ur_values}", + ) + checks.check( + "update_row sets Notes (automation/reviewed text)", + ur_has_notes, + hint=f"values: {ur_values}", + ) + checks.check( + "DB trigger is rows_updated", + db_trigger_type == "local_baserow_rows_updated", + hint=f"got {db_trigger_type}", + ) + checks.check( + "update_row action in DB", + len(db_update_actions) >= 1, + ) + + +# --------------------------------------------------------------------------- +# Send email node +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_email_notification_workflow(data_fixture, eval_model): + """Agent should create a workflow with an smtp_email node.""" + + user = data_fixture.create_user() + 
workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Orders") + data_fixture.create_text_field(table=table, name="Order ID", primary=True) + data_fixture.create_text_field(table=table, name="Customer Email") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Order Notifications" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_EMAIL_NOTIFICATION_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + nodes_args = wf_args.get("nodes", []) + email_nodes_args = [n for n in nodes_args if n.get("type") == "smtp_email"] + email_node = email_nodes_args[0] if email_nodes_args else {} + email_to = email_node.get("to_emails", "") + email_subject = email_node.get("subject", "") + email_body = email_node.get("body", "") + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_email_actions = [ + n for n in action_nodes if n.service.get_type().type == "smtp_email" + ] + else: + db_email_actions = [] + + with EvalChecklist("creates email notification workflow") as checks: + 
checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Orders", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "smtp_email node in args", + len(email_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "email to admin@example.com", + "admin@example.com" in email_to, + hint=f"got to: '{email_to}'", + ) + checks.check( + "email subject mentions 'Order'", + "order" in email_subject.lower(), + hint=f"got subject: '{email_subject}'", + ) + checks.check( + "email body mentions order being placed", + "order" in email_body.lower() or "placed" in email_body.lower(), + hint=f"got body: '{email_body}'", + ) + checks.check("workflow created in DB", db_ok) + checks.check( + "smtp_email action in DB", + len(db_email_actions) >= 1, + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py new file mode 100644 index 0000000000..db0a28ccd0 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py @@ -0,0 +1,201 @@ +import pytest + +from baserow.contrib.automation.models import Automation +from baserow.contrib.database.models import Database + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + format_message_history, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# 
--------------------------------------------------------------------------- + +PROMPT_LISTS_DATABASES = "What databases do I have in this workspace?" + +PROMPT_CREATES_DATABASE = "Create a new database called 'Customer Portal'" + +PROMPT_CREATES_AUTOMATION = "Create an empty automation called 'Overdue Task Reminder'." + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + deps.tool_helpers.request_context["ui_context"] = ui_context + return agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_lists_databases(data_fixture, eval_model): + """Agent should call list_builders when asked what databases exist.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Inventory" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=10, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_LISTS_DATABASES, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "list_builders" and e["role"] == "user" + ] + + with EvalChecklist("lists databases") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called list_builders", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "answer mentions 'Inventory'", + "Inventory" in result.output, + hint=f"answer: {result.output[:200]}", + ) 
+ + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_database(data_fixture, eval_model): + """Agent should create a new database when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_DATABASE, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "create_builders" and e["role"] == "user" + ] + created = Database.objects.filter(workspace=workspace, name__icontains="customer") + + with EvalChecklist("creates database") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called create_builders", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "database 'Customer Portal' exists", + created.exists(), + hint=f"databases: {list(Database.objects.filter(workspace=workspace).values_list('name', flat=True))}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_automation(data_fixture, eval_model): + """Agent should create a new automation when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = 
build_database_ui_context(user, workspace) + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_AUTOMATION, + ui_context=ui_context, + ) + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "create_builders" and e["role"] == "user" + ] + created = list(Automation.objects.all()) + automation = created[0] if created else None + + with EvalChecklist("creates automation") as checks: + checks.check("<=1 tool errors", err_count <= 1, hint=err_hint) + checks.check( + "called create_builders", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "exactly 1 automation created", + len(created) == 1, + hint=f"found {len(created)}: {[a.name for a in created]}", + ) + checks.check( + "automation named 'Overdue Task Reminder'", + automation is not None and "overdue" in automation.name.lower(), + hint=f"got: '{automation.name if automation else None}'", + ) + checks.check( + "automation in correct workspace", + automation is not None and automation.workspace_id == workspace.id, + hint=f"workspace_id={automation.workspace_id if automation else None} vs {workspace.id}", + ) + checks.check( + "automation has no workflows", + automation is not None and automation.workflows.count() == 0, + hint=f"workflows: {list(automation.workflows.all()) if automation else []}", + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py new file mode 100644 index 0000000000..bad28ec3fa --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py @@ -0,0 +1,214 @@ +import pytest + +from baserow.contrib.database.rows.handler 
import RowHandler + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_CREATES_ROWS_WITH_ALL_FIELD_TYPES = ( + "Create 5 rows with diverse sample data in table {table_name}. " + "Fill in ALL fields with realistic values." +) + + +def _create_rich_table(data_fixture): + """ + Create a table with all managed field types plus a linked table + with sample data. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + # Linked table (target for link_row fields) + linked_table = data_fixture.create_database_table( + database=database, name="Categories" + ) + linked_primary = data_fixture.create_text_field( + table=linked_table, name="Name", primary=True + ) + + # Populate linked table with sample rows + RowHandler().force_create_rows( + user, + linked_table, + [ + {linked_primary.db_column: "Work"}, + {linked_primary.db_column: "Personal"}, + {linked_primary.db_column: "Urgent"}, + ], + ) + + # Main table with all managed field types + table = data_fixture.create_database_table(database=database, name="Tasks") + title = data_fixture.create_text_field(table=table, name="Title", primary=True) + description = data_fixture.create_long_text_field(table=table, name="Description") + estimated_hours = data_fixture.create_number_field( + table=table, name="Estimated Hours", number_decimal_places=1 + ) + completed = data_fixture.create_boolean_field(table=table, name="Completed") + due_date = data_fixture.create_date_field(table=table, name="Due Date") + created_at = data_fixture.create_date_field( + table=table, name="Created At", 
date_include_time=True + ) + + status_field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=status_field, value="To Do", order=0) + data_fixture.create_select_option(field=status_field, value="In Progress", order=1) + data_fixture.create_select_option(field=status_field, value="Done", order=2) + + tags_field = data_fixture.create_multiple_select_field(table=table, name="Tags") + data_fixture.create_select_option(field=tags_field, value="Bug", order=0) + data_fixture.create_select_option(field=tags_field, value="Feature", order=1) + data_fixture.create_select_option(field=tags_field, value="Docs", order=2) + + category_field = data_fixture.create_link_row_field( + table=table, + link_row_table=linked_table, + name="Category", + link_row_multiple_relationships=False, + ) + related_categories_field = data_fixture.create_link_row_field( + table=table, + link_row_table=linked_table, + name="Related Categories", + link_row_multiple_relationships=True, + ) + + return { + "user": user, + "workspace": workspace, + "database": database, + "table": table, + "linked_table": linked_table, + "fields": { + "title": title, + "description": description, + "estimated_hours": estimated_hours, + "completed": completed, + "due_date": due_date, + "created_at": created_at, + "status": status_field, + "tags": tags_field, + "category": category_field, + "related_categories": related_categories_field, + }, + } + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_rows_with_all_field_types(data_fixture, eval_model, db): + """ + Agent should create rows with sensible data for every field type. + + This tests the full flow: + 1. Agent calls get_tables_schema to learn the table structure + 2. Agent calls load_row_tools to unlock create_rows_in_table_X + 3. 
Agent calls create_rows_in_table_X with all fields populated + """ + + res = _create_rich_table(data_fixture) + user = res["user"] + workspace = res["workspace"] + database = res["database"] + table = res["table"] + fields = res["fields"] + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table=table) + deps.tool_helpers.request_context["ui_context"] = ui_context + + result = agent.run_sync( + user_prompt=PROMPT_CREATES_ROWS_WITH_ALL_FIELD_TYPES.format( + table_name=table.name + ), + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + table_model = table.get_model() + row_count = table_model.objects.count() + sample_rows = list(table_model.objects.all()) + + def _get_field_value(row, field_name): + return getattr(row, fields[field_name].db_column, None) + + def _any_row(check_fn): + return any(check_fn(r) for r in sample_rows) + + with EvalChecklist("creates rows with all field types") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("5 rows created", row_count == 5, hint=f"got {row_count}") + checks.check( + "title populated", + _any_row(lambda r: bool(_get_field_value(r, "title"))), + ) + checks.check( + "description populated", + _any_row(lambda r: bool(_get_field_value(r, "description"))), + ) + checks.check( + "estimated_hours populated", + _any_row(lambda r: _get_field_value(r, "estimated_hours") is not None), + ) + checks.check( + "estimated_hours > 0 in at least one row", + _any_row(lambda r: (_get_field_value(r, "estimated_hours") or 0) > 0), + ) + checks.check( + "completed has at least one True", + _any_row(lambda r: _get_field_value(r, "completed") is True), + ) + checks.check( + "due_date populated", + _any_row(lambda r: _get_field_value(r, 
"due_date") is not None), + ) + checks.check( + "created_at populated", + _any_row(lambda r: _get_field_value(r, "created_at") is not None), + ) + checks.check( + "status is a known option", + _any_row( + lambda r: bool(_get_field_value(r, "status")) + and _get_field_value(r, "status").value + in ["To Do", "In Progress", "Done"] + ), + ) + checks.check( + "tags has at least one known option", + _any_row( + lambda r: bool( + set(_get_field_value(r, "tags").values_list("value", flat=True)) + & {"Bug", "Feature", "Docs"} + ) + ), + ) + checks.check( + "category linked", + _any_row(lambda r: len(_get_field_value(r, "category").all()) > 0), + ) + checks.check( + "related_categories linked", + _any_row( + lambda r: len(_get_field_value(r, "related_categories").all()) > 0 + ), + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py new file mode 100644 index 0000000000..1761d79a46 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py @@ -0,0 +1,1164 @@ +import pytest + +from baserow.contrib.database.fields.models import ( + BooleanField, + DateField, + LinkRowField, + LongTextField, + NumberField, + SingleSelectField, + TextField, +) +from baserow.contrib.database.models import Table +from baserow.contrib.database.views.models import View, ViewFilter +from baserow.core.db import specific_iterator + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_CREATES_SIMPLE_TABLE = ( + "Create a Recipes table in database {database_name} with these fields: 
" + "Name, Description, Prep Time in Minutes, Servings, and Vegetarian. " + "Don't add sample rows." +) + +PROMPT_CREATES_TABLE_WITH_SELECT_FIELDS = ( + "Create a Tasks table in database {database_name} with: " + "Title, Status with options: To Do, In Progress, Done, " + "Priority with options: Low, Medium, High, " + "and Due Date. Don't add sample rows." +) + +PROMPT_CREATES_RELATED_TABLES = ( + "Create a simple project management system in database {database_name} with: " + "1. A Projects table with Name and Description. " + "2. A Tasks table with Title, Status with options: To Do, In Progress, Done, " + "and a link to the Projects table. " + "Don't add sample rows." +) + +PROMPT_CREATES_DATABASE_FROM_DESCRIPTION = ( + "Set up a Bookstore database to manage a bookstore. " + "I need tables for Books and Authors. " + "Books should have title, description, price, publication date, and a link to Authors. " + "Authors should have name and bio. " + "Don't add sample rows." +) + +PROMPT_CREATE_RELATED_TABLES_WITH_SAMPLE_ROWS = ( + "Set up the Bookstore database {database_name} with: " + "1. An Authors table with Name and Bio. " + "2. A Books table with Title, Genre " + "(single select: Fiction, Non-Fiction, Science, History), " + "Price, and a link to the Authors table." +) + +# -- View creation prompts -------------------------------------------------- + +PROMPT_CREATE_GRID_VIEW = ( + "Create a grid view called 'All Tasks' for table {table_name}." +) + +PROMPT_CREATE_KANBAN_VIEW = ( + "Create a kanban view called 'Task Board' for table {table_name}. " + "Use the Status field (id: {status_field_name}) as the column field." +) + +PROMPT_CREATE_CALENDAR_VIEW = ( + "Create a calendar view called 'Schedule' for table {table_name}. " + "Use the Due Date field (id: {date_field_name}) as the date field." +) + +PROMPT_CREATE_GALLERY_VIEW = ( + "Create a gallery view called 'Image Gallery' for table {table_name}. 
" + "Use the Cover Image field (id: {file_field_name}) as the cover image." +) + +PROMPT_CREATE_TIMELINE_VIEW = ( + "Create a timeline view called 'Project Timeline' for table {table_name}. " + "Use Start Date (id: {start_field_name}) and End Date (id: {end_field_name})." +) + +PROMPT_CREATE_FORM_VIEW = ( + "Create a form view called 'Submit Task' for table {table_name}. " + "Include the Name field in the form." +) + +# -- View filter prompts ---------------------------------------------------- + +PROMPT_FILTER_TEXT_CONTAINS = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Description field (id: {text_field_name}) " + "to only show rows where it contains 'important'." +) + +PROMPT_FILTER_NUMBER_GREATER_THAN = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Amount field (id: {number_field_name}) " + "to only show rows where it is greater than 100." +) + +PROMPT_FILTER_DATE_AFTER = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Due Date field (id: {date_field_name}) " + "to only show rows where the date is after today." +) + +PROMPT_FILTER_SINGLE_SELECT_ANY_OF = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Status field (id: {select_field_name}) " + "to only show rows where Status is any of 'Active' or 'Pending'." +) + +PROMPT_FILTER_MULTIPLE_SELECT_HAS = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Tags field (id: {multi_field_name}) " + "to only show rows where Tags has 'Important'." +) + +PROMPT_FILTER_BOOLEAN_IS = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Active field (id: {bool_field_name}) " + "to only show rows where Active is true." 
+) + +# -- Field update/delete prompts -------------------------------------------- + +PROMPT_UPDATE_FIELD_RENAME = ( + "Rename the Description field to Summary in the {table_name} table." +) + +PROMPT_UPDATE_FIELD_SELECT_OPTIONS = ( + "Add an 'In Progress' option to the Status field in the {table_name} table." +) + +PROMPT_DELETE_FIELD = "Delete the Notes field from the {table_name} table." + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + """Helper to run the agent with standard configuration.""" + deps.tool_helpers.request_context["ui_context"] = ui_context + + result = agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + return result + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_simple_table(data_fixture, eval_model): + """Agent should create a table with basic field types when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Recipe Database" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_SIMPLE_TABLE.format(database_name=database.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + recipe_tables = [t for t in tables if "recipe" in t.name.lower()] + table = recipe_tables[0] if recipe_tables else None + fields = list(specific_iterator(table.field_set.all())) if table else [] + field_names = {f.name.lower(): f for f in fields} + text_fields = [f for f in fields if 
isinstance(f, (TextField, LongTextField))] + number_fields = [f for f in fields if isinstance(f, NumberField)] + boolean_fields = [f for f in fields if isinstance(f, BooleanField)] + + prep_number = next( + ( + f + for f in number_fields + if any(kw in f.name.lower() for kw in ("prep", "time", "minute")) + ), + None, + ) + veg_bool = next((f for f in boolean_fields if "vegetarian" in f.name.lower()), None) + + with EvalChecklist("creates Recipes table") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Recipes table created", + len(recipe_tables) == 1, + hint=f"got {len(recipe_tables)}: {[t.name for t in tables]}", + ) + checks.check( + "Name field exists", + any("name" in n for n in field_names), + hint=f"fields: {list(field_names.keys())}", + ) + checks.check( + "Description field exists", + any("description" in n for n in field_names), + hint=f"fields: {list(field_names.keys())}", + ) + checks.check( + ">=2 text/long_text fields", + len(text_fields) >= 2, + hint=f"got {len(text_fields)}", + ) + checks.check( + ">=2 number fields", + len(number_fields) >= 2, + hint=f"got {len(number_fields)}", + ) + checks.check( + ">=1 boolean field", + len(boolean_fields) >= 1, + hint=f"got {len(boolean_fields)}", + ) + checks.check( + "Prep Time/Minutes field exists (number)", + prep_number is not None, + hint=f"number fields: {[f.name for f in number_fields]}", + ) + checks.check( + "Vegetarian field exists (boolean)", + veg_bool is not None, + hint=f"boolean fields: {[f.name for f in boolean_fields]}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_table_with_select_fields(data_fixture, eval_model): + """Agent should create a table with single select and appropriate options.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Task Management" + ) + + agent, 
deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_TABLE_WITH_SELECT_FIELDS.format( + database_name=database.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + task_tables = [t for t in tables if "task" in t.name.lower()] + table = task_tables[0] if task_tables else None + fields = list(specific_iterator(table.field_set.all())) if table else [] + select_fields = [f for f in fields if isinstance(f, SingleSelectField)] + status_field = next((f for f in select_fields if "status" in f.name.lower()), None) + status_options = ( + list(status_field.select_options.values_list("value", flat=True)) + if status_field + else [] + ) + date_fields = [f for f in fields if isinstance(f, DateField)] + field_names_lower = {f.name.lower(): f for f in fields} + priority_field = next( + (f for f in select_fields if "priority" in f.name.lower()), None + ) + priority_options = ( + list(priority_field.select_options.values_list("value", flat=True)) + if priority_field + else [] + ) + status_option_values = {o.lower() for o in status_options} + priority_option_values = {o.lower() for o in priority_options} + + with EvalChecklist("creates Tasks table with selects") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Tasks table created", + len(task_tables) == 1, + hint=f"got {len(task_tables)}: {[t.name for t in tables]}", + ) + checks.check( + ">=2 single select fields (Status, Priority)", + len(select_fields) >= 2, + hint=f"got {len(select_fields)}: {[f.name for f in select_fields]}", + ) + checks.check( + "Status field exists", + status_field is not None, + hint=f"select 
fields: {[f.name for f in select_fields]}", + ) + checks.check( + "Status has >=3 options", + len(status_options) >= 3, + hint=f"got: {status_options}", + ) + checks.check( + ">=1 date field", + len(date_fields) >= 1, + hint=f"got {len(date_fields)}", + ) + checks.check( + "Title text field exists", + any("title" in n for n in field_names_lower), + hint=f"fields: {list(field_names_lower.keys())}", + ) + checks.check( + "Priority field exists", + priority_field is not None, + hint=f"select fields: {[f.name for f in select_fields]}", + ) + checks.check( + "Status has To Do / In Progress / Done", + {"to do", "in progress", "done"} <= status_option_values, + hint=f"got: {status_options}", + ) + checks.check( + "Priority has Low / Medium / High", + {"low", "medium", "high"} <= priority_option_values, + hint=f"got: {priority_options}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_related_tables(data_fixture, eval_model): + """Agent should create multiple tables with link_row relationships.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Project Management" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_RELATED_TABLES.format(database_name=database.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + table_names = {t.name.lower(): t for t in tables} + project_tables = [name for name in table_names if "project" in name] + task_tables = [name for name in table_names if "task" in name] + + task_table = 
table_names[task_tables[0]] if task_tables else None + task_fields = ( + list(specific_iterator(task_table.field_set.all())) if task_table else [] + ) + link_fields = [f for f in task_fields if isinstance(f, LinkRowField)] + + project_table = table_names[project_tables[0]] if project_tables else None + link_to_projects = ( + [f for f in link_fields if f.link_row_table_id == project_table.id] + if project_table + else [] + ) + project_fields = ( + list(specific_iterator(project_table.field_set.all())) if project_table else [] + ) + project_text_fields = [ + f for f in project_fields if isinstance(f, (TextField, LongTextField)) + ] + task_select_fields = [f for f in task_fields if isinstance(f, SingleSelectField)] + status_field_in_tasks = next( + (f for f in task_select_fields if "status" in f.name.lower()), None + ) + status_opts_in_tasks = ( + list(status_field_in_tasks.select_options.values_list("value", flat=True)) + if status_field_in_tasks + else [] + ) + status_opt_values = {o.lower() for o in status_opts_in_tasks} + + with EvalChecklist("creates related tables") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Projects table exists", + len(project_tables) >= 1, + hint=f"got tables: {list(table_names.keys())}", + ) + checks.check( + "Tasks table exists", + len(task_tables) >= 1, + hint=f"got tables: {list(table_names.keys())}", + ) + checks.check( + ">=1 link_row field in Tasks", + len(link_fields) >= 1, + hint=f"fields: {[(f.name, type(f).__name__) for f in task_fields]}", + ) + checks.check( + "link_row points to Projects table", + len(link_to_projects) >= 1, + hint=f"links to: {[(f.name, f.link_row_table_id) for f in link_fields]}", + ) + checks.check( + "Projects has >=2 text fields (Name, Description)", + len(project_text_fields) >= 2, + hint=f"project text fields: {[f.name for f in project_text_fields]}", + ) + checks.check( + "Tasks has Status single_select field", + status_field_in_tasks is not None, + 
hint=f"task select fields: {[f.name for f in task_select_fields]}", + ) + checks.check( + "Tasks Status has To Do / In Progress / Done", + {"to do", "in progress", "done"} <= status_opt_values, + hint=f"got: {status_opts_in_tasks}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_database_from_description(data_fixture, eval_model): + """ + Agent should create a full database structure from a high-level description. + + This tests the agent's ability to interpret a vague request and create + appropriate tables, fields, and relationships. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=25, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_DATABASE_FROM_DESCRIPTION, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + from baserow.contrib.database.models import Database + + databases = Database.objects.filter(workspace=workspace) + tables = list(Table.objects.filter(database__in=databases)) + table_names_lower = [t.name.lower() for t in tables] + + books_table = next((t for t in tables if "book" in t.name.lower()), None) + books_fields = ( + list(specific_iterator(books_table.field_set.all())) if books_table else [] + ) + books_field_types = {type(f) for f in books_fields} + + authors_table_obj = next((t for t in tables if "author" in t.name.lower()), None) + authors_fields = ( + list(specific_iterator(authors_table_obj.field_set.all())) + if authors_table_obj + else [] + ) + authors_field_types = {type(f) for f in authors_fields} + books_link_fields = [f for f in books_fields if isinstance(f, LinkRowField)] + link_to_authors = ( + [f for f in books_link_fields 
if f.link_row_table_id == authors_table_obj.id] + if authors_table_obj + else [] + ) + + with EvalChecklist("creates Bookstore database") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "database created", + databases.exists(), + hint="no database found in workspace", + ) + checks.check( + "Books table exists", + any("book" in n for n in table_names_lower), + hint=f"got: {[t.name for t in tables]}", + ) + checks.check( + "Authors table exists", + any("author" in n for n in table_names_lower), + hint=f"got: {[t.name for t in tables]}", + ) + checks.check( + "Books has text/long_text field", + TextField in books_field_types or LongTextField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books has number field (price)", + NumberField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books has date field", + DateField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books has link_row field to Authors", + LinkRowField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books link_row points to Authors table", + len(link_to_authors) >= 1, + hint=f"link targets: {[f.link_row_table_id for f in books_link_fields]}", + ) + checks.check( + "Authors has text field (name/bio)", + TextField in authors_field_types or LongTextField in authors_field_types, + hint=f"authors field types: {[t.__name__ for t in authors_field_types]}", + ) + checks.check( + "Books has >=2 text/long_text fields (title + description)", + sum(1 for f in books_fields if isinstance(f, (TextField, LongTextField))) + >= 2, + hint=f"books text fields: {[f.name for f in books_fields if isinstance(f, (TextField, LongTextField))]}", + ) + + +# --------------------------------------------------------------------------- +# 
Parametrized view creation eval +# --------------------------------------------------------------------------- + + +def _setup_grid(data_fixture, table): + """Grid view needs no special fields.""" + return {} + + +def _setup_kanban(data_fixture, table): + """Kanban needs a single_select field.""" + field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=field, value="To Do", order=1) + data_fixture.create_select_option(field=field, value="In Progress", order=2) + data_fixture.create_select_option(field=field, value="Done", order=3) + return {"status_field": field} + + +def _setup_calendar(data_fixture, table): + """Calendar needs a date field.""" + field = data_fixture.create_date_field(table=table, name="Due Date") + return {"date_field": field} + + +def _setup_gallery(data_fixture, table): + """Gallery needs a file field.""" + field = data_fixture.create_file_field(table=table, name="Cover Image") + return {"file_field": field} + + +def _setup_timeline(data_fixture, table): + """Timeline needs two date fields with matching include_time.""" + start = data_fixture.create_date_field( + table=table, name="Start Date", date_include_time=False + ) + end = data_fixture.create_date_field( + table=table, name="End Date", date_include_time=False + ) + return {"start_field": start, "end_field": end} + + +def _setup_form(data_fixture, table): + """Form view uses existing fields; no extra setup beyond what's already there.""" + return {} + + +_VIEW_TEST_CASES = [ + pytest.param("grid", _setup_grid, PROMPT_CREATE_GRID_VIEW, id="grid"), + pytest.param("kanban", _setup_kanban, PROMPT_CREATE_KANBAN_VIEW, id="kanban"), + pytest.param( + "calendar", _setup_calendar, PROMPT_CREATE_CALENDAR_VIEW, id="calendar" + ), + pytest.param("gallery", _setup_gallery, PROMPT_CREATE_GALLERY_VIEW, id="gallery"), + pytest.param( + "timeline", _setup_timeline, PROMPT_CREATE_TIMELINE_VIEW, id="timeline" + ), + pytest.param("form", 
_setup_form, PROMPT_CREATE_FORM_VIEW, id="form"), +] + + +_EXPECTED_VIEW_NAMES = { + "grid": "all tasks", + "kanban": "task board", + "calendar": "schedule", + "gallery": "image gallery", + "timeline": "project timeline", + "form": "submit task", +} + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize("view_type,setup_fn,prompt_template", _VIEW_TEST_CASES) +def test_agent_creates_view( + data_fixture, eval_model, view_type, setup_fn, prompt_template +): + """Agent should create a view of the given type without tool errors.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + # Set up type-specific fields + extra = setup_fn(data_fixture, table) + + # Build prompt with field IDs injected + fmt_kwargs = {"table_name": table.name} + for key, field in extra.items(): + fmt_kwargs[f"{key}_name"] = field.name + prompt = prompt_template.format(**fmt_kwargs) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=prompt, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + views = View.objects.filter(table=table) + typed_views = [ + v for v in views if v.get_type().type == view_type and v.name != "Grid" + ] + + view_name_ok = any( + _EXPECTED_VIEW_NAMES[view_type] in v.name.lower() for v in typed_views + ) + + with EvalChecklist(f"creates {view_type} view") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + 
f"{view_type} view created", + len(typed_views) >= 1, + hint=f"got views: {[(v.name, v.get_type().type) for v in views]}", + ) + checks.check( + "view name matches expected", + view_name_ok, + hint=f"expected '{_EXPECTED_VIEW_NAMES[view_type]}', got: {[v.name for v in typed_views]}", + ) + + +# --------------------------------------------------------------------------- +# Parametrized view filter creation eval +# --------------------------------------------------------------------------- + + +def _setup_text_filter(data_fixture, table): + field = data_fixture.create_text_field(table=table, name="Description") + return {"text_field": field} + + +def _setup_number_filter(data_fixture, table): + field = data_fixture.create_number_field(table=table, name="Amount") + return {"number_field": field} + + +def _setup_date_filter(data_fixture, table): + field = data_fixture.create_date_field(table=table, name="Due Date") + return {"date_field": field} + + +def _setup_single_select_filter(data_fixture, table): + field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=field, value="Active", order=1) + data_fixture.create_select_option(field=field, value="Pending", order=2) + data_fixture.create_select_option(field=field, value="Closed", order=3) + return {"select_field": field} + + +def _setup_multiple_select_filter(data_fixture, table): + field = data_fixture.create_multiple_select_field(table=table, name="Tags") + data_fixture.create_select_option(field=field, value="Important", order=1) + data_fixture.create_select_option(field=field, value="Urgent", order=2) + data_fixture.create_select_option(field=field, value="Low", order=3) + return {"multi_field": field} + + +def _setup_boolean_filter(data_fixture, table): + field = data_fixture.create_boolean_field(table=table, name="Active") + return {"bool_field": field} + + +_FILTER_TEST_CASES = [ + pytest.param( + "text", + _setup_text_filter, + 
PROMPT_FILTER_TEXT_CONTAINS, + "contains", + "important", + id="text_contains", + ), + pytest.param( + "number", + _setup_number_filter, + PROMPT_FILTER_NUMBER_GREATER_THAN, + "higher_than", + "100", + id="number_greater_than", + ), + pytest.param( + "date", + _setup_date_filter, + PROMPT_FILTER_DATE_AFTER, + "date_is_after", + None, # value contains UTC?date_mode format — fragile to check + id="date_after", + ), + pytest.param( + "single_select", + _setup_single_select_filter, + PROMPT_FILTER_SINGLE_SELECT_ANY_OF, + "single_select_is_any_of", + None, # value is comma-separated option IDs — fragile to check + id="single_select_is_any_of", + ), + pytest.param( + "multiple_select", + _setup_multiple_select_filter, + PROMPT_FILTER_MULTIPLE_SELECT_HAS, + "multiple_select_has", + None, # value is option ID — fragile to check + id="multiple_select_has", + ), + pytest.param( + "boolean", + _setup_boolean_filter, + PROMPT_FILTER_BOOLEAN_IS, + "equal", + "1", + id="boolean_equal", + ), +] + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + "filter_type,setup_fn,prompt_template,expected_orm_type,expected_value_fragment", + _FILTER_TEST_CASES, +) +def test_agent_creates_view_filter( + data_fixture, + eval_model, + filter_type, + setup_fn, + prompt_template, + expected_orm_type, + expected_value_fragment, +): + """Agent should create a view with the correct filter type without tool errors.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + # Set up type-specific fields + extra = setup_fn(data_fixture, table) + + # Build prompt with field IDs injected + fmt_kwargs = {"table_name": table.name} + for key, field in extra.items(): + fmt_kwargs[f"{key}_name"] = field.name 
+ prompt = prompt_template.format(**fmt_kwargs) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=prompt, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + filters = ViewFilter.objects.filter(view__table=table, type=expected_orm_type) + all_filter_types = list( + ViewFilter.objects.filter(view__table=table).values_list("type", flat=True) + ) + filter_obj = filters.first() + setup_field = list(extra.values())[0] if extra else None + + with EvalChecklist(f"creates {filter_type} view filter") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + f"ViewFilter type='{expected_orm_type}' exists", + filters.exists(), + hint=f"got filter types: {all_filter_types}", + ) + checks.check( + "filter is on the correct field", + filter_obj is not None + and setup_field is not None + and filter_obj.field_id == setup_field.id, + hint=f"filter field_id={filter_obj.field_id if filter_obj else None}, expected={setup_field.id if setup_field else None}", + ) + if expected_value_fragment is not None: + checks.check( + "filter value is correct", + filter_obj is not None + and expected_value_fragment in (filter_obj.value or ""), + hint=f"filter value='{filter_obj.value if filter_obj else None}', expected fragment='{expected_value_fragment}'", + ) + + +# --------------------------------------------------------------------------- +# Field update/delete evals +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_renames_field(data_fixture, eval_model): + """Agent should rename a field when asked.""" + + user = 
data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + data_fixture.create_long_text_field(table=table, name="Description") + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_UPDATE_FIELD_RENAME.format(table_name=table.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + field_names = list(table.field_set.all().values_list("name", flat=True)) + + with EvalChecklist("renames field") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Summary field exists", + any("summary" in n.lower() for n in field_names), + hint=f"fields: {field_names}", + ) + checks.check( + "Description field gone", + not any(n.lower() == "description" for n in field_names), + hint=f"fields: {field_names}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_updates_select_options(data_fixture, eval_model): + """Agent should add a new option to a single_select field.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + status_field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=status_field, value="To Do", 
order=1) + data_fixture.create_select_option(field=status_field, value="Done", order=2) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_UPDATE_FIELD_SELECT_OPTIONS.format(table_name=table.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + status_field.refresh_from_db() + options = list(status_field.select_options.values_list("value", flat=True)) + + with EvalChecklist("updates select options") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "In Progress option added", + any("in progress" in o.lower() for o in options), + hint=f"options: {options}", + ) + checks.check( + "existing options preserved", + {"to do", "done"} <= {o.lower() for o in options}, + hint=f"options: {options}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_deletes_field(data_fixture, eval_model): + """Agent should delete a field when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + data_fixture.create_long_text_field(table=table, name="Notes") + data_fixture.create_text_field(table=table, name="Priority") + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + 
question=PROMPT_DELETE_FIELD.format(table_name=table.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + field_names = list(table.field_set.all().values_list("name", flat=True)) + + with EvalChecklist("deletes field") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Notes field gone", + not any(n.lower() == "notes" for n in field_names), + hint=f"fields: {field_names}", + ) + checks.check( + "other fields preserved", + any("name" in n.lower() for n in field_names) + and any("priority" in n.lower() for n in field_names), + hint=f"fields: {field_names}", + ) + + +# --------------------------------------------------------------------------- +# Sample rows eval +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_create_related_tables_with_sample_rows(data_fixture, eval_model): + """ + Agent creates two related tables (Authors → Books) and sample rows + are generated for both, including link_row references. 
+ """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Bookstore" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=25, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATE_RELATED_TABLES_WITH_SAMPLE_ROWS.format( + database_name=database.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + table_names = {t.name.lower(): t for t in tables} + author_tables = [name for name in table_names if "author" in name] + book_tables = [name for name in table_names if "book" in name] + + authors_count = ( + table_names[author_tables[0]].get_model().objects.count() + if author_tables + else 0 + ) + books_count = ( + table_names[book_tables[0]].get_model().objects.count() if book_tables else 0 + ) + books_table_obj = table_names[book_tables[0]] if book_tables else None + books_fields_list = ( + list(specific_iterator(books_table_obj.field_set.all())) + if books_table_obj + else [] + ) + genre_field = next( + ( + f + for f in books_fields_list + if isinstance(f, SingleSelectField) and "genre" in f.name.lower() + ), + None, + ) + genre_options = ( + list(genre_field.select_options.values_list("value", flat=True)) + if genre_field + else [] + ) + genre_option_values = {o.lower() for o in genre_options} + price_field = next( + ( + f + for f in books_fields_list + if isinstance(f, NumberField) and "price" in f.name.lower() + ), + None, + ) + books_link_fields_list = [ + f for f in books_fields_list if isinstance(f, LinkRowField) + ] + + with EvalChecklist("creates Bookstore with sample rows") as checks: + checks.check("no tool 
errors", err_count == 0, hint=err_hint) + checks.check( + "Authors table exists", + len(author_tables) >= 1, + hint=f"got: {list(table_names.keys())}", + ) + checks.check( + "Books table exists", + len(book_tables) >= 1, + hint=f"got: {list(table_names.keys())}", + ) + checks.check( + "Authors has >=1 sample row", + authors_count >= 1, + hint=f"got {authors_count}", + ) + checks.check( + "Books has >=2 sample rows", + books_count >= 2, + hint=f"got {books_count}", + ) + checks.check( + "Books has Genre single_select field", + genre_field is not None, + hint=f"books select fields: {[f.name for f in books_fields_list if isinstance(f, SingleSelectField)]}", + ) + checks.check( + "Genre has Fiction / Non-Fiction / Science / History options", + {"fiction", "non-fiction", "science", "history"} <= genre_option_values, + hint=f"got: {genre_options}", + ) + checks.check( + "Books has Price (number) field", + price_field is not None, + hint=f"books number fields: {[f.name for f in books_fields_list if isinstance(f, NumberField)]}", + ) + checks.check( + "Books has link_row to Authors", + len(books_link_fields_list) >= 1, + hint=f"books fields: {[f.name for f in books_fields_list]}", + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py new file mode 100644 index 0000000000..2aa153e221 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py @@ -0,0 +1,276 @@ +from django.conf import settings + +import pytest + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + create_eval_assistant, + format_message_history, + print_message_history, +) + + +@pytest.fixture(autouse=True) +def _require_knowledge_base(synced_knowledge_base): + """Skip search docs tests when the knowledge base is not available. 
+ + Depends on the session-scoped ``synced_knowledge_base`` fixture + (conftest.py) which syncs the KB once per session if needed. + """ + + if not getattr(settings, "BASEROW_EMBEDDINGS_API_URL", ""): + pytest.skip( + "BASEROW_EMBEDDINGS_API_URL not set. " + "See docs/testing/ai-assistant-evals.md for setup instructions." + ) + + from baserow_enterprise.assistant.tools.search_user_docs.handler import ( + KnowledgeBaseHandler, + ) + + if not KnowledgeBaseHandler().can_search(): + pytest.skip( + "Knowledge base not available. " + "Requires: pgvector extension and synced KB data. " + "See docs/testing/ai-assistant-evals.md for setup instructions." + ) + + +# --------------------------------------------------------------------------- +# Test cases: (id, question, expected_source_patterns, expected_answer_keywords) +# +# expected_source_patterns: at least ONE returned source URL must contain +# one of these substrings. +# expected_answer_keywords: the agent's final answer must contain at least +# ONE of these substrings (case-insensitive). +# --------------------------------------------------------------------------- + +SEARCH_DOCS_CASES = [ + pytest.param( + ( + "I'm trying to do a VLOOKUP to pull the 'Client Email' from my " + "'Clients' tab into my 'Projects' tab based on the client name. " + "I can't find the formula for this. Does it exist in Baserow?" + ), + ["link-to-table", "lookup-field"], + ["link row", "lookup", "link_row", "relationship"], + id="vlookup-to-link-row", + ), + pytest.param( + ( + "I need to run a raw SQL query to join three tables for a report. " + "I'm on the standard cloud hosted plan. Where do I find my database " + "host, port, and credentials to connect my BI tool?" + ), + ["technical", "set-up-baserow"], + ["api", "self-host", "rest api", "not available", "cannot"], + id="raw-sql-cloud-plan", + ), + pytest.param( + ( + "I'm trying to calculate the days between two dates. 
I typed " + "=DAYS(field('End'), field('Start')) like I do in Google Sheets " + "but it says 'Invalid Syntax'. What am I doing wrong?" + ), + ["formula", "understanding-formulas"], + ["date_diff", "date diff", "datediff"], + id="date-diff-formula", + ), + pytest.param( + "Where is the save button? I don't want to lose my work.", + ["baserow-basics"], + ["auto", "automatically", "saved"], + id="auto-save", + ), + pytest.param( + "How can I put a form on my website that sends data to my table?", + ["creating-forms", "guide-to-creating-forms"], + ["form", "embed", "share"], + id="form-embed", + ), + pytest.param( + "I deleted a bunch of rows by mistake. Is there a recycling bin?", + ["data-recovery", "deletion"], + ["trash", "recover", "undo", "restore"], + id="data-recovery", + ), + pytest.param( + ( + "I want to share a specific view with my client so they can see " + "the progress, but I don't want them to edit anything or see the " + "other tables. Is that possible?" + ), + ["public-sharing", "permissions"], + ["share", "public", "read-only", "read only", "view"], + id="share-view-read-only", + ), + pytest.param( + "I need to lock a column so my team can see it but not mess it up.", + ["field-level-permissions", "permissions"], + ["permission", "field", "read", "lock"], + id="field-permissions", + ), + pytest.param( + ( + "How can I create a calendar that shows my tasks, but only the ones assigned to me." + ), + ["calendar-view", "calendar", "filters"], + ["calendar", "filter", "view"], + id="calendar-with-filter", + ), + pytest.param( + ( + "I'm trying to combine the first name and last name columns " + "into one, but I want to make sure it's uppercase. Can you tell me how to " + "write that formula?" + ), + ["formula", "understanding-formulas"], + ["concat", "upper", "formula"], + id="concat-upper-formula", + ), + pytest.param( + ( + "I'm running Baserow on my own server with Docker. 
A new version " + "came out yesterday, how do I install it without losing my data?" + ), + ["set-up-baserow", "configuration"], + ["docker", "pull", "upgrade", "update", "volume"], + id="docker-upgrade", + ), + pytest.param( + ( + "I want to write a script so that whenever I tick a checkbox, " + "it sends an email to the client. Do I need to build a custom " + "plugin for this?" + ), + ["webhook", "workflow-automation", "automation"], + ["automation", "webhook", "trigger", "workflow"], + id="checkbox-email-automation", + ), + pytest.param( + ( + "I want to embed my inventory sheet on my website so clients " + "can search it. Do they need a Baserow account to see it? " + "How do I generate the code?" + ), + ["public-sharing"], + ["embed", "public", "share", "account"], + id="embed-public-view", + ), + pytest.param( + "Can Baserow integrate with Google AI Studio?", + ["configure-generative-ai", "database-api"], + ["ai", "generative", "integration", "api"], + id="google-ai-studio", + ), + pytest.param( + ( + "I'm trying to fetch data from my table using curl but I keep " + "getting a 401 error. I generated a token in my settings, but it " + "says I don't have permissions. Do I need to use my login email " + "and password instead?" + ), + ["rest-api", "database-api"], + ["token", "api", "permission", "authentication"], + id="api-401-error", + ), + pytest.param( + ( + "Is there a way to only get rows where the 'Status' field is " + "set to 'Done' via the API? I don't want to download the whole " + "JSON and filter it in my script." 
+ ), + ["rest-api", "database-api"], + ["filter", "api", "parameter", "field"], + id="api-filter-rows", + ), +] + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + deps.tool_helpers.request_context["ui_context"] = ui_context + return agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + +@pytest.mark.eval +@pytest.mark.django_db +@pytest.mark.parametrize( + "question,expected_source_patterns,expected_keywords", SEARCH_DOCS_CASES +) +def test_search_user_docs( + data_fixture, + eval_model, + question, + expected_source_patterns, + expected_keywords, +): + """ + Agent should call search_user_docs for user-docs questions and return + an answer with relevant sources and content. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=10, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=question, + ui_context=ui_context, + ) + + print_message_history(result) + + history = format_message_history(result) + search_calls = [ + e + for e in history + if e.get("tool_name") == "search_user_docs" and e["role"] == "assistant" + ] + sources = deps.sources + answer = result.output.lower() + keyword_match = any(kw.lower() in answer for kw in expected_keywords) + + # Source URL matching is non-fatal — URLs change and the retrieval may + # return valid alternative sources. Print a warning but don't score it. 
+ if expected_source_patterns and sources: + source_match = any( + any(pattern in url for pattern in expected_source_patterns) + for url in sources + ) + if not source_match: + print( + f"\n WARNING: No source matched {expected_source_patterns}.\n" + f" Returned sources: {sources}" + ) + + with EvalChecklist("search user docs") as checks: + checks.check( + "called search_user_docs", + len(search_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + f"answer mentions one of {expected_keywords}", + keyword_match, + hint=f"answer (first 300 chars): {result.output[:300]}", + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py index 1d5bbec55f..62e5e77d54 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py @@ -1,15 +1,21 @@ from unittest.mock import MagicMock, patch -from django.core.cache import cache +from django.test.utils import override_settings import pytest from asgiref.sync import async_to_sync -from udspy import OutputStreamChunk, Prediction +from pydantic_ai.messages import PartStartEvent +from pydantic_ai.messages import TextPart as PaiTextPart -from baserow_enterprise.assistant.assistant import Assistant, AssistantCallbacks -from baserow_enterprise.assistant.exceptions import AssistantMessageCancelled +from baserow_enterprise.assistant.assistant import ( + Assistant, + compact_message_history, + get_model_string, +) +from baserow_enterprise.assistant.deps import AssistantDeps from baserow_enterprise.assistant.models import AssistantChat, AssistantChatMessage from baserow_enterprise.assistant.types import ( + AiMessage, AiMessageChunk, AiStartedMessage, AiThinkingMessage, @@ -23,133 +29,113 @@ WorkspaceUIContext, ) +TEST_MODEL = "groq:test-model" + 
@pytest.fixture(autouse=True) -def mock_posthog_openai(): - with patch("posthog.ai.openai.AsyncOpenAI") as mock: - # Configure the mock if needed +def mock_posthog(): + with patch("baserow_enterprise.assistant.telemetry.get_posthog_client") as mock: mock.return_value = MagicMock() - mock.return_value.model = "test-model" yield mock -@pytest.mark.django_db -class TestAssistantCallbacks: - """Test the AssistantCallbacks class for handling tool execution""" - - def test_extend_sources_deduplicates(self): - """Test that sources are deduplicated when extended""" - - callbacks = AssistantCallbacks() - - # Add initial sources - callbacks.extend_sources( - ["https://example.com/doc1", "https://example.com/doc2"] - ) - assert callbacks.sources == [ - "https://example.com/doc1", - "https://example.com/doc2", - ] +@pytest.fixture(autouse=True) +def _set_test_model(settings): + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = "groq/test-model" - # Add sources with duplicates - callbacks.extend_sources( - ["https://example.com/doc2", "https://example.com/doc3"] - ) - # Should only add the new source, not the duplicate - assert callbacks.sources == [ - "https://example.com/doc1", - "https://example.com/doc2", - "https://example.com/doc3", - ] +# --------------------------------------------------------------------------- +# Mock helpers for pydantic-ai's run_stream_events async generator +# --------------------------------------------------------------------------- - def test_extend_sources_preserves_order(self): - """Test that source order is preserved (first occurrence wins)""" - callbacks = AssistantCallbacks() +async def _mock_run_stream_events(answer: str, messages_json: bytes = b"[]"): + """ + Async generator that mimics ``main_agent.run_stream_events()`` + yielding PartStartEvent, then AgentRunResultEvent. 
+ """ + from pydantic_ai.run import AgentRunResultEvent - callbacks.extend_sources(["https://example.com/a"]) - callbacks.extend_sources(["https://example.com/b"]) - callbacks.extend_sources(["https://example.com/a"]) # Duplicate + # Emit a text part start with the full answer + yield PartStartEvent(index=0, part=PaiTextPart(content=answer)) - # 'a' should remain first - assert callbacks.sources == ["https://example.com/a", "https://example.com/b"] + # Emit the final result event + mock_result = MagicMock() + mock_result.output = answer + mock_result.all_messages_json.return_value = messages_json + yield AgentRunResultEvent(result=mock_result) - def test_on_tool_end_extracts_sources_from_outputs(self): - """Test that sources are extracted from tool outputs""" - callbacks = AssistantCallbacks() +def make_mock_run_stream_events_side_effect(answer: str, messages_json: bytes = b"[]"): + """Return a side_effect callable that returns the mock async generator.""" - # Mock tool instance and inputs - tool_instance = MagicMock() - tool_instance.name = "search_user_docs" - inputs = {"query": "test"} + def side_effect(*args, **kwargs): + return _mock_run_stream_events(answer, messages_json) - # Register tool call - callbacks.tool_calls["call_123"] = (tool_instance, inputs) + return side_effect - # Mock registry - with patch( - "baserow_enterprise.assistant.assistant.assistant_tool_registry" - ) as mock_registry: - mock_tool = MagicMock() - mock_registry.get.return_value = mock_tool - # Tool returns outputs with sources - outputs = { - "result": "Some documentation", - "sources": ["https://baserow.io/docs/api"], - } +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- - callbacks.on_tool_end("call_123", outputs) - # Sources should be extracted - assert callbacks.sources == ["https://baserow.io/docs/api"] +@pytest.mark.django_db +class TestAssistantDeps: + 
"""Test the AssistantDeps class for source tracking.""" - def test_on_tool_end_handles_missing_sources(self): - """Test that tool outputs without sources don't cause errors""" + def test_extend_sources_deduplicates(self): + deps = AssistantDeps( + user=MagicMock(), + workspace=MagicMock(), + tool_helpers=MagicMock(), + ) - callbacks = AssistantCallbacks() + deps.extend_sources(["https://example.com/doc1", "https://example.com/doc2"]) + assert deps.sources == [ + "https://example.com/doc1", + "https://example.com/doc2", + ] - tool_instance = MagicMock() - tool_instance.name = "some_tool" - callbacks.tool_calls["call_123"] = (tool_instance, {}) + deps.extend_sources(["https://example.com/doc2", "https://example.com/doc3"]) - with patch( - "baserow_enterprise.assistant.assistant.assistant_tool_registry" - ) as mock_registry: - mock_tool = MagicMock() - mock_registry.get.return_value = mock_tool + assert deps.sources == [ + "https://example.com/doc1", + "https://example.com/doc2", + "https://example.com/doc3", + ] - # Tool returns outputs without sources - outputs = {"result": "Some result"} + def test_extend_sources_preserves_order(self): + deps = AssistantDeps( + user=MagicMock(), + workspace=MagicMock(), + tool_helpers=MagicMock(), + ) - callbacks.on_tool_end("call_123", outputs) + deps.extend_sources(["https://example.com/a"]) + deps.extend_sources(["https://example.com/b"]) + deps.extend_sources(["https://example.com/a"]) - # Should not raise, sources should remain empty - assert callbacks.sources == [] + assert deps.sources == ["https://example.com/a", "https://example.com/b"] @pytest.mark.django_db class TestAssistantChatHistory: - """Test chat history loading and formatting""" + """Test chat history loading and formatting.""" def test_list_chat_messages_returns_in_chronological_order( self, enterprise_data_fixture ): - """Test that list_chat_messages returns messages oldest to newest""" - user = enterprise_data_fixture.create_user() workspace = 
enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Create messages in order - msg1 = AssistantChatMessage.objects.create( + AssistantChatMessage.objects.create( chat=chat, role=AssistantChatMessage.Role.HUMAN, content="First question" ) - msg2 = AssistantChatMessage.objects.create( + AssistantChatMessage.objects.create( chat=chat, role=AssistantChatMessage.Role.AI, content="First answer" ) msg3 = AssistantChatMessage.objects.create( @@ -159,77 +145,38 @@ def test_list_chat_messages_returns_in_chronological_order( assistant = Assistant(chat) messages = assistant.list_chat_messages() - # Should be in chronological order (oldest first) assert len(messages) == 3 assert messages[0].content == "First question" assert messages[1].content == "First answer" assert messages[2].content == "Second question" - # It's possible to skip messages using last_message_id messages = assistant.list_chat_messages(last_message_id=msg3.id, limit=1) assert len(messages) == 1 assert messages[0].content == "First answer" - def test_aload_chat_history_formats_as_question_answer_pairs( - self, enterprise_data_fixture - ): - """Test that chat history is loaded as user/assistant message pairs for UDSPy""" - + def test_load_message_history_returns_none_for_empty(self, enterprise_data_fixture): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Create conversation history - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="What is Baserow?" 
- ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="Baserow is a no-code database platform.", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.HUMAN, - content="How do I create a table?", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="You can create a table by clicking the + button.", - ) - assistant = Assistant(chat) - assistant.history = async_to_sync(assistant.afetch_chat_history)() - - # History should contain user/assistant message pairs - assert assistant.history is not None - assert len(assistant.history.messages) == 4 - - # First pair - assert assistant.history.messages[0] == { - "role": "user", - "content": "What is Baserow?", - } - assert assistant.history.messages[1] == { - "role": "assistant", - "content": "Baserow is a no-code database platform.", - } - - # Second pair - assert assistant.history.messages[2] == { - "role": "user", - "content": "How do I create a table?", - } - assert assistant.history.messages[3] == { - "role": "assistant", - "content": "You can create a table by clicking the + button.", - } - - def test_aload_chat_history_respects_limit(self, enterprise_data_fixture): - """Test that history loading respects the limit parameter""" + history = async_to_sync(assistant._load_message_history)() + assert history is None + + def test_load_message_history_deserializes_and_compacts( + self, enterprise_data_fixture + ): + from pydantic_ai.messages import ( + ModelMessagesTypeAdapter, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) @@ -237,178 +184,147 @@ def test_aload_chat_history_respects_limit(self, enterprise_data_fixture): user=user, workspace=workspace, title="Test Chat" ) - # Create 10 message pairs (20 messages) - for i in 
range(10): - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.HUMAN, - content=f"Question {i}", - ) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.AI, content=f"Answer {i}" - ) + messages = [ + ModelRequest(parts=[UserPromptPart(content="create a database")]), + ModelResponse( + parts=[ + ToolCallPart( + tool_name="create_tables", + args={"thought": "creating", "tables": ["recipes"]}, + tool_call_id="tc1", + ) + ] + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name="create_tables", + content="Created", + tool_call_id="tc1", + ) + ] + ), + ModelResponse(parts=[TextPart(content="Done!")]), + ] + chat.message_history = ModelMessagesTypeAdapter.dump_json(messages) + chat.save(update_fields=["message_history"]) assistant = Assistant(chat) - assistant.history = async_to_sync(assistant.afetch_chat_history)( - limit=6 - ) # Last 6 messages - - # Should only load the most recent 6 messages (3 pairs) - assert len(assistant.history.messages) == 6 + history = async_to_sync(assistant._load_message_history)() - def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixture): - """ - Test that incomplete message pairs (e.g., orphaned human messages) are skipped - """ + assert history is not None + assert len(history) == 2 + assert isinstance(history[0], ModelRequest) + assert isinstance(history[1], ModelResponse) + def test_load_message_history_handles_corrupt_data(self, enterprise_data_fixture): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Create complete pair - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question 1" - ) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.AI, content="Answer 1" - ) - - # Create orphaned human message (no AI 
response yet) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question 2" - ) + chat.message_history = b"not valid json" + chat.save(update_fields=["message_history"]) assistant = Assistant(chat) - assistant.history = async_to_sync(assistant.afetch_chat_history)() - - # Should only include the complete pair (2 messages: user + assistant) - assert len(assistant.history.messages) == 2 - assert assistant.history.messages[0] == { - "role": "user", - "content": "Question 1", - } - assert assistant.history.messages[1] == { - "role": "assistant", - "content": "Answer 1", - } - - @patch("udspy.ReAct.astream") - def test_history_is_passed_to_astream_as_context( - self, mock_react_astream, enterprise_data_fixture - ): - """ - Test that chat history is loaded correctly and passed to the agent as context - """ + history = async_to_sync(assistant._load_message_history)() + assert history is None + + +class TestCompactMessageHistory: + """Test the message history compaction logic.""" + + def test_compacts_tool_calls_in_older_turns(self): + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) + + messages = [ + ModelRequest(parts=[UserPromptPart(content="create a database")]), + ModelResponse( + parts=[ + ToolCallPart( + tool_name="create_tables", + args={"thought": "creating"}, + tool_call_id="tc1", + ) + ] + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name="create_tables", + content="Created", + tool_call_id="tc1", + ) + ] + ), + ModelResponse(parts=[TextPart(content="Done!")]), + ModelRequest(parts=[UserPromptPart(content="add a field")]), + ModelResponse(parts=[TextPart(content="Added!")]), + ] - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test Chat" - ) + compacted = 
compact_message_history(messages) + assert len(compacted) == 4 - # Create conversation history (2 complete pairs) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="What is Baserow?" - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="Baserow is a no-code database", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.HUMAN, - content="How do I create a table?", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="Click the Create Table button", + def test_trims_to_max_messages(self): + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, ) - assistant = Assistant(chat) + messages = [] + for i in range(20): + messages.append( + ModelRequest(parts=[UserPromptPart(content=f"Question {i}")]) + ) + messages.append(ModelResponse(parts=[TextPart(content=f"Answer {i}")])) - # Mock the agent stream to verify conversation history is passed - def mock_agent_stream_factory(*args, **kwargs): - # Verify conversation history is passed to the agent - assert kwargs["conversation_history"] == [ - "[0] (user): What is Baserow?", - "[1] (assistant): Baserow is a no-code database", - "[2] (user): How do I create a table?", - "[3] (assistant): Click the Create Table button", - ] - - async def _stream(): - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta="Answer", - content="Answer", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, - answer="Answer", - trajectory=[], - reasoning="", - ) - - return _stream() - - mock_react_astream.side_effect = mock_agent_stream_factory - - message = HumanMessage(content="How to add a view?") - - # Consume the stream to trigger assertions - async def consume_stream(): - async for _ in assistant.astream_messages(message): - pass + compacted = 
compact_message_history(messages, max_messages=6) + assert len(compacted) == 6 - async_to_sync(consume_stream)() + def test_preserves_simple_conversations(self): + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, + ) + + messages = [ + ModelRequest(parts=[UserPromptPart(content="hello")]), + ModelResponse(parts=[TextPart(content="hi")]), + ] + + compacted = compact_message_history(messages) + assert len(compacted) == 2 @pytest.mark.django_db class TestAssistantMessagePersistence: - """Test that messages are persisted correctly during streaming""" + """Test that messages are persisted correctly during streaming.""" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_persists_human_message( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that human messages are persisted to database before streaming""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock the agent streaming - async def mock_agent_stream(*args, **kwargs): - # Yield a simple response - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield Prediction(answer="Hello", trajectory=[], reasoning="") - - mock_react_astream.return_value = mock_agent_stream() + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello" + ) assistant = Assistant(chat) 
ui_context = UIContext( @@ -416,7 +332,6 @@ async def mock_agent_stream(*args, **kwargs): user=UserUIContext(id=user.id, name=user.first_name, email=user.email), ) - # Consume the stream async def consume_stream(): human_message = HumanMessage(content="Test message", ui_context=ui_context) async for _ in assistant.astream_messages(human_message): @@ -424,7 +339,6 @@ async def consume_stream(): async_to_sync(consume_stream)() - # Human message should be persisted human_messages = AssistantChatMessage.objects.filter( chat=chat, role=AssistantChatMessage.Role.HUMAN ).count() @@ -435,129 +349,64 @@ async def consume_stream(): ).first() assert saved_message.content == "Test message" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - def test_astream_messages_persists_ai_message_with_sources( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") + def test_astream_messages_persists_ai_message( + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that AI messages are persisted with sources in artifacts""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - assistant = Assistant(chat) - - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock the agent streaming with a Prediction at the end - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Based on docs", - content="Based on docs", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, - answer="Based on docs", - trajectory=[], - reasoning="", - ) + 
mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Based on docs" + ) - mock_react_astream.return_value = mock_agent_stream() + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), ) - # Manually add sources to callback manager (simulating tool execution) async def consume_stream(): - messages = [] human_message = HumanMessage(content="Question", ui_context=ui_context) - async for msg in assistant.astream_messages(human_message): - messages.append(msg) - return messages + async for _ in assistant.astream_messages(human_message): + pass async_to_sync(consume_stream)() - # AI message should be persisted ai_messages = AssistantChatMessage.objects.filter( chat=chat, role=AssistantChatMessage.Role.AI ).count() assert ai_messages == 1 - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - @patch("udspy.Predict") + @patch("baserow_enterprise.assistant.agents.title_agent.run") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_persists_chat_title( self, - mock_predict_class, - mock_react_astream, - mock_cot_astream, + mock_run_stream_events, + mock_title_run, enterprise_data_fixture, ): - """Test that chat titles are persisted to the database""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, - workspace=workspace, - title="", # New chat - ) + chat = AssistantChat.objects.create(user=user, workspace=workspace, title="") - # Mock title generator - async def mock_title_aforward(*args, **kwargs): - return Prediction(chat_title="Greeting") + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello" + ) - mock_title_generator = MagicMock() - mock_title_generator.aforward = mock_title_aforward - 
mock_predict_class.return_value = mock_title_generator + mock_title_result = MagicMock() + mock_title_result.output = "Greeting" + mock_title_run.return_value = mock_title_result assistant = Assistant(chat) - - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, answer="Hello", trajectory=[], reasoning="" - ) - - mock_react_astream.return_value = mock_agent_stream() ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), ) - # Consume the stream async def consume_stream(): human_message = HumanMessage(content="Hello", ui_context=ui_context) async for _ in assistant.astream_messages(human_message): @@ -565,187 +414,112 @@ async def consume_stream(): async_to_sync(consume_stream)() - # Refresh from DB chat.refresh_from_db() - - # Title should be persisted assert chat.title == "Greeting" @pytest.mark.django_db class TestAssistantStreaming: - """Test streaming behavior of the Assistant""" + """Test streaming behavior of the Assistant.""" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_yields_answer_chunks( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that answer chunks are yielded during streaming""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = 
AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello world" + ) assistant = Assistant(chat) - # Mock agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta=" world", - content="Hello world", - is_complete=False, - ) - yield Prediction(answer="Hello world", trajectory=[], reasoning="") - - mock_react_astream.return_value = mock_agent_stream() - async def consume_stream(): - chunks = [] + messages = [] human_message = HumanMessage(content="Test") async for msg in assistant.astream_messages(human_message): - if isinstance(msg, AiMessageChunk): - chunks.append(msg) - return chunks + messages.append(msg) + return messages - chunks = async_to_sync(consume_stream)() + messages = async_to_sync(consume_stream)() - # Should receive chunks with accumulated content - assert len(chunks) == 2 - assert chunks[0].content == "Hello" - assert chunks[1].content == "Hello world" + # Filter for final AiMessage + ai_messages = [m for m in messages if isinstance(m, AiMessage)] + assert len(ai_messages) == 1 + assert ai_messages[0].content == "Hello world" + assert ai_messages[0].id is not None - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - @patch("udspy.Predict") - def test_astream_messages_yields_title_chunks( - self, - mock_predict_class, - mock_react_astream, - mock_cot_astream, - enterprise_data_fixture, - ): - """Test that title chunks are 
yielded for new chats""" + # Should also have AiMessageChunk(s) + chunks = [ + m + for m in messages + if isinstance(m, AiMessageChunk) and not isinstance(m, AiMessage) + ] + assert len(chunks) >= 1 + @patch("baserow_enterprise.assistant.agents.title_agent.run") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") + def test_astream_messages_yields_title_for_new_chat( + self, mock_run_stream_events, mock_title_run, enterprise_data_fixture + ): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, - workspace=workspace, - title="", # New chat - ) + chat = AssistantChat.objects.create(user=user, workspace=workspace, title="") - # Mock title generator - async def mock_title_aforward(*args, **kwargs): - return Prediction(chat_title="Title") + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Answer" + ) - mock_title_generator = MagicMock() - mock_title_generator.aforward = mock_title_aforward - mock_predict_class.return_value = mock_title_generator + mock_title_result = MagicMock() + mock_title_result.output = "Title" + mock_title_run.return_value = mock_title_result assistant = Assistant(chat) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Answer", - content="Answer", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, - answer="Answer", - trajectory=[], - reasoning="", - ) - - mock_react_astream.return_value = mock_agent_stream() - async def consume_stream(): - title_messages = [] + msgs = [] human_message = HumanMessage(content="Test") async for 
msg in assistant.astream_messages(human_message): - if isinstance(msg, ChatTitleMessage): - title_messages.append(msg) - return title_messages + msgs.append(msg) + return msgs - title_messages = async_to_sync(consume_stream)() + messages = async_to_sync(consume_stream)() - # Should receive title chunks + title_messages = [m for m in messages if isinstance(m, ChatTitleMessage)] assert len(title_messages) == 1 assert title_messages[0].content == "Title" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_yields_thinking_messages( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that thinking messages from tools are yielded""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) + assistant = Assistant(chat) - mock_cot_astream.return_value = mock_router_stream() + async def mock_stream_with_thinking(*args, **kwargs): + from pydantic_ai.run import AgentRunResultEvent - assistant = Assistant(chat) + # Emit thinking message via the event bus during streaming + assistant._event_bus.emit(AiThinkingMessage(content="still thinking...")) - # Mock the agent streaming - async def mock_agent_stream(*args, **kwargs): - yield AiThinkingMessage(content="still thinking...") - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta="Answer", - content="Answer", - is_complete=False, - ) - yield Prediction(answer="Answer", trajectory=[], reasoning="") + # Yield text part then result + yield 
PartStartEvent(index=0, part=PaiTextPart(content="Answer")) + + mock_result = MagicMock() + mock_result.output = "Answer" + mock_result.all_messages_json.return_value = b"[]" + yield AgentRunResultEvent(result=mock_result) - mock_react_astream.return_value = mock_agent_stream() + mock_run_stream_events.side_effect = mock_stream_with_thinking ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -753,38 +527,60 @@ async def mock_agent_stream(*args, **kwargs): ) async def consume_stream(): - thinking_messages = [] + thinking = [] human_message = HumanMessage(content="Test", ui_context=ui_context) async for msg in assistant.astream_messages(human_message): if isinstance(msg, AiThinkingMessage): - thinking_messages.append(msg) - return thinking_messages + thinking.append(msg) + return thinking thinking_messages = async_to_sync(consume_stream)() - # Should receive the thinking message emitted by the agent stream assert len(thinking_messages) == 1 assert thinking_messages[0].content == "still thinking..." 
+ @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") + def test_astream_messages_yields_ai_started_message( + self, mock_run_stream_events, enterprise_data_fixture + ): + user = enterprise_data_fixture.create_user() + workspace = enterprise_data_fixture.create_workspace(user=user) + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test" + ) + + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello" + ) + + assistant = Assistant(chat) + human_message = HumanMessage(content="Hello") + + async def collect_messages(): + messages = [] + async for msg in assistant.astream_messages(human_message): + messages.append(msg) + return messages + + messages = async_to_sync(collect_messages)() + + assert len(messages) > 0 + assert isinstance(messages[0], AiStartedMessage) + assert messages[0].message_id is not None + @pytest.mark.django_db class TestUIContext: - """Test UI context handling and validation""" + """Test UI context handling and validation.""" def test_ui_context_from_validate_request_adds_user_info( self, enterprise_data_fixture ): - """ - Test that UIContext.from_validate_request adds user information - from request - """ - user = enterprise_data_fixture.create_user( email="test@example.com", first_name="Test User" ) workspace = enterprise_data_fixture.create_workspace(user=user) - # Mock request object class MockRequest: pass @@ -792,7 +588,6 @@ class MockRequest: request.user = user ui_context_data = {"workspace": {"id": workspace.id, "name": workspace.name}} - ui_context = UIContext.from_validate_request(request, ui_context_data) assert ui_context.workspace.id == workspace.id @@ -802,8 +597,6 @@ class MockRequest: assert ui_context.user.name == "Test User" def test_ui_context_with_database_builder_fields(self): - """Test that UIContext correctly stores database builder fields""" - ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test Workspace"), 
database=ApplicationUIContext(id="db-123", name="My Database"), @@ -814,123 +607,23 @@ def test_ui_context_with_database_builder_fields(self): assert ui_context.workspace.id == 1 assert ui_context.database.id == "db-123" - assert ui_context.database.name == "My Database" assert ui_context.table.id == 456 - assert ui_context.table.name == "Customers" assert ui_context.view.id == 789 - assert ui_context.view.name == "All Customers" - assert ui_context.view.type == "grid" - - def test_ui_context_with_application_builder_fields(self): - """Test that UIContext correctly stores application builder fields""" - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - application=ApplicationUIContext(id="app-123", name="My App"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - ) - - assert ui_context.application.id == "app-123" - assert ui_context.application.name == "My App" - assert ui_context.database is None - assert ui_context.table is None def test_ui_context_serialization_excludes_none_values(self): - """Test that UIContext serialization excludes None values""" - ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test Workspace"), user=UserUIContext(id=1, name="Test", email="test@test.com"), - # All other fields are None ) - # Serialize with exclude_none=True serialized = ui_context.model_dump(exclude_none=True) - assert "workspace" in serialized assert "user" in serialized assert "database" not in serialized assert "table" not in serialized - assert "view" not in serialized - assert "application" not in serialized - - def test_ui_context_json_serialization_excludes_none(self): - """Test that UIContext JSON serialization excludes None values""" - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - table=TableUIContext(id=456, name="Customers"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - # database and view are None - ) - - # Serialize to JSON 
with exclude_none=True - json_str = ui_context.model_dump_json(exclude_none=True) - - # Parse back to verify - import json - - parsed = json.loads(json_str) - - assert "workspace" in parsed - assert "table" in parsed - assert "user" in parsed - assert "database" not in parsed - assert "view" not in parsed - - def test_human_message_with_ui_context(self): - """Test that HumanMessage correctly stores ui_context""" - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - database=ApplicationUIContext(id="db-123", name="My Database"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - ) - - human_message = HumanMessage( - content="How do I create a field?", ui_context=ui_context - ) - - assert human_message.content == "How do I create a field?" - assert human_message.ui_context.workspace.id == 1 - assert human_message.ui_context.database.id == "db-123" - assert human_message.ui_context.database.name == "My Database" - - def test_human_message_ui_context_json_serialization(self): - """ - Test that HumanMessage ui_context serializes to JSON with None - values excluded - """ - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - database=ApplicationUIContext(id="db-123", name="My Database"), - table=TableUIContext(id=456, name="Customers"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - # view is None - ) - - human_message = HumanMessage( - content="How do I filter this view?", ui_context=ui_context - ) - - # Serialize ui_context as it would be in the prompt - ui_context_json = human_message.ui_context.model_dump_json(exclude_none=True) - - # Parse to verify - import json - - parsed = json.loads(ui_context_json) - - # Should have database and table but not view - assert "database" in parsed - assert parsed["database"]["name"] == "My Database" - assert "table" in parsed - assert parsed["table"]["name"] == "Customers" - assert "view" not in parsed # None values excluded 
def test_ui_context_has_default_timestamp(self): - """Test that UIContext has a default timestamp""" + from datetime import datetime ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test"), @@ -938,14 +631,9 @@ def test_ui_context_has_default_timestamp(self): ) assert ui_context.timestamp is not None - # Should be a datetime object - from datetime import datetime - assert isinstance(ui_context.timestamp, datetime) def test_ui_context_has_default_timezone(self): - """Test that UIContext has a default timezone of UTC""" - ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test"), user=UserUIContext(id=1, name="Test", email="test@test.com"), @@ -954,8 +642,6 @@ def test_ui_context_has_default_timezone(self): assert ui_context.timezone == "UTC" def test_user_ui_context_from_user(self, enterprise_data_fixture): - """Test UserUIContext.from_user factory method""" - user = enterprise_data_fixture.create_user( email="john@example.com", first_name="John Doe" ) @@ -966,173 +652,55 @@ def test_user_ui_context_from_user(self, enterprise_data_fixture): assert user_context.name == "John Doe" assert user_context.email == "john@example.com" - -@pytest.mark.django_db -class TestAssistantCancellation: - """Test cancellation functionality in Assistant""" - - def test_get_cancellation_cache_key(self, enterprise_data_fixture): - """Test that cancellation cache key is correctly formatted""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" - ) - - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() - - assert cache_key == f"assistant:chat:{chat.uuid}:cancelled" - - def test_check_cancellation_raises_when_flag_set(self, enterprise_data_fixture): - """Test that check_cancellation raises exception when flag is set""" - - user = enterprise_data_fixture.create_user() - workspace = 
enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" + def test_human_message_with_ui_context(self): + ui_context = UIContext( + workspace=WorkspaceUIContext(id=1, name="Test Workspace"), + database=ApplicationUIContext(id="db-123", name="My Database"), + user=UserUIContext(id=1, name="Test", email="test@test.com"), ) - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() - - # Set cancellation flag - cache.set(cache_key, True) - - # Should raise exception - with pytest.raises(AssistantMessageCancelled) as exc_info: - assistant._check_cancellation(cache_key, "msg123") - - assert exc_info.value.message_id == "msg123" - - # Flag should be cleaned up - assert cache.get(cache_key) is None - - def test_check_cancellation_does_nothing_when_no_flag( - self, enterprise_data_fixture - ): - """Test that check_cancellation does nothing when flag not set""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" + human_message = HumanMessage( + content="How do I create a field?", ui_context=ui_context ) - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() + assert human_message.content == "How do I create a field?" 
+ assert human_message.ui_context.workspace.id == 1 + assert human_message.ui_context.database.id == "db-123" - # Should not raise - assistant._check_cancellation(cache_key, "msg123") - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - def test_astream_messages_yields_ai_started_message( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture - ): - """Test that astream_messages yields AiStartedMessage at the beginning""" +@pytest.mark.django_db +class TestAssistantCancellation: + """Test cancellation functionality in Assistant.""" + def test_get_cancellation_cache_key(self, enterprise_data_fixture): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock the agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield Prediction(answer="Hello there!", trajectory=[], reasoning="") - - mock_react_astream.return_value = mock_agent_stream() - - assistant = Assistant(chat) - human_message = HumanMessage(content="Hello") - - # Collect messages - async def collect_messages(): - messages = [] - async for msg in assistant.astream_messages(human_message): - messages.append(msg) - return messages - - messages = async_to_sync(collect_messages)() - - # First message should be AiStartedMessage - assert len(messages) > 0 - assert isinstance(messages[0], AiStartedMessage) - assert messages[0].message_id is not None - - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - def 
test_astream_messages_checks_cancellation_periodically( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture - ): - """Test that astream_messages checks for cancellation every 10 chunks""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" + from baserow_enterprise.assistant.assistant import ( + get_assistant_cancellation_key, ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() + cache_key = get_assistant_cancellation_key(str(chat.uuid)) + assert cache_key == f"assistant:chat:{chat.uuid}:cancelled" - # Mock the stream to return many chunks - enough to trigger check at 10 - async def mock_agent_stream(*args, **kwargs): - # Yield 15 chunks - cancellation check happens at chunk 10 - for i in range(15): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta=f"word{i}", - content=f"word{i}", - is_complete=False, - ) - yield Prediction(answer="Complete response", trajectory=[], reasoning="") - mock_react_astream.return_value = mock_agent_stream() +class TestGetModelString: + """Test the model string conversion logic.""" - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() + @override_settings(BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL="groq/llama-3.3-70b") + def test_replaces_slash_with_colon(self): + assert get_model_string() == "groq:llama-3.3-70b" - # Set cancellation flag immediately - it should be detected at chunk 10 - cache.set(cache_key, True) - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), - user=UserUIContext(id=user.id, name=user.first_name, email=user.email), - ) - human_message = HumanMessage(content="Hello", 
ui_context=ui_context) + @override_settings(BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL="openai/gpt-4") + def test_openai_model(self): + assert get_model_string() == "openai:gpt-4" - # Should raise AssistantMessageCancelled when check happens at chunk 10 - async def stream_messages(): - async for msg in assistant.astream_messages(human_message): - pass + @override_settings(BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL="gpt-4o") + def test_bare_model_defaults_to_openai(self): + assert get_model_string() == "openai:gpt-4o" - with pytest.raises(AssistantMessageCancelled): - async_to_sync(stream_messages)() + def test_explicit_model_overrides_setting(self): + assert get_model_string("groq/custom-model") == "groq:custom-model" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py new file mode 100644 index 0000000000..c57f35cf1d --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py @@ -0,0 +1,304 @@ +import pytest + +from baserow.contrib.automation.nodes.models import AutomationNode +from baserow.contrib.automation.nodes.service import AutomationNodeService +from baserow_enterprise.assistant.tools.automation.tools import ( + add_nodes, + create_workflows, + delete_nodes, + list_nodes, + update_nodes, +) +from baserow_enterprise.assistant.tools.automation.types import ( + ActionNodeCreate, + NodeUpdate, + TriggerNodeCreate, + WorkflowCreate, +) + +from .utils import make_test_ctx + + +@pytest.fixture(autouse=True) +def mock_formula_generator(monkeypatch): + """Mock update_workflow_formulas and update_single_node_formulas to avoid LM calls.""" + + monkeypatch.setattr( + "baserow_enterprise.assistant.tools.automation.agents.update_workflow_formulas", + lambda workflow, node_mapping, tool_helpers: None, + ) + monkeypatch.setattr( + 
"baserow_enterprise.assistant.tools.automation.agents.update_single_node_formulas", + lambda node_update, orm_node, tool_helpers: None, + ) + + +def _create_test_workflow(data_fixture, user, workspace): + """Create a workflow with a trigger and an email action node.""" + automation = data_fixture.create_automation_application( + user=user, workspace=workspace + ) + + ctx = make_test_ctx(user, workspace) + result = create_workflows( + ctx, + automation_id=automation.id, + workflows=[ + WorkflowCreate( + name="Test Workflow", + trigger=TriggerNodeCreate( + ref="trigger1", + label="Periodic Trigger", + type="periodic", + periodic_interval={"interval": "DAY"}, + ), + nodes=[ + ActionNodeCreate( + ref="email1", + label="Send Email", + previous_node_ref="trigger1", + type="smtp_email", + to_emails="test@example.com", + subject="Hello", + body="World", + ), + ], + ) + ], + thought="test", + ) + + workflow_id = result["created_workflows"][0]["id"] + return automation, workflow_id + + +@pytest.mark.django_db(transaction=True) +def test_list_nodes(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + ctx = make_test_ctx(user, workspace) + result = list_nodes(ctx, workflow_id=workflow_id, thought="inspect") + + nodes = result["nodes"] + assert len(nodes) == 2 + + # First node is the trigger + assert nodes[0]["label"] == "Periodic Trigger" + assert nodes[0]["type"] == "periodic" + + # Second node is the email action + assert nodes[1]["label"] == "Send Email" + assert nodes[1]["type"] == "smtp_email" + + # All nodes have IDs + assert all("id" in n for n in nodes) + + +@pytest.mark.django_db(transaction=True) +def test_add_node_after_existing(data_fixture): + """Add a router node between the trigger and existing email node.""" + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = 
_create_test_workflow(data_fixture, user, workspace) + + # Get existing nodes + ctx = make_test_ctx(user, workspace) + existing = list_nodes(ctx, workflow_id=workflow_id, thought="check") + trigger_id = existing["nodes"][0]["id"] + email_id = existing["nodes"][1]["id"] + + # Delete the existing email node first (we'll re-add it after the router) + delete_nodes( + ctx, node_ids=[email_id], thought="remove email to re-add after router" + ) + + # Add a router after the trigger, then a new email after the router + result = add_nodes( + ctx, + workflow_id=workflow_id, + nodes=[ + ActionNodeCreate( + ref="router1", + label="My Router", + type="router", + previous_node_ref=str(trigger_id), + edges=[ + {"label": "always", "condition": "true"}, + ], + ), + ActionNodeCreate( + ref="slack1", + label="Send Slack After Router", + type="smtp_email", + previous_node_ref="router1", + router_edge_label="always", + to_emails="test@example.com", + subject="Hello", + body="Routed message", + ), + ], + thought="insert router between trigger and email", + ) + + assert len(result["created_nodes"]) == 2 + assert result["created_nodes"][0]["type"] == "router" + assert result["created_nodes"][0]["label"] == "My Router" + assert result["created_nodes"][1]["label"] == "Send Slack After Router" + + # Verify final workflow order + final = list_nodes(ctx, workflow_id=workflow_id, thought="verify") + assert len(final["nodes"]) == 3 + assert final["nodes"][0]["type"] == "periodic" + assert final["nodes"][1]["type"] == "router" + assert final["nodes"][2]["type"] == "smtp_email" + + +@pytest.mark.django_db(transaction=True) +def test_add_node_append_to_workflow(data_fixture): + """Append a new action node at the end of an existing workflow.""" + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + ctx = make_test_ctx(user, workspace) + existing = list_nodes(ctx, 
workflow_id=workflow_id, thought="check") + email_id = existing["nodes"][1]["id"] + + # Append a new email node after the existing email node + result = add_nodes( + ctx, + workflow_id=workflow_id, + nodes=[ + ActionNodeCreate( + ref="email1", + label="Follow-up Email", + type="smtp_email", + previous_node_ref=str(email_id), + to_emails="followup@example.com", + subject="Follow-up", + body="This is a follow-up.", + ), + ], + thought="append email after email", + ) + + assert len(result["created_nodes"]) == 1 + assert result["created_nodes"][0]["label"] == "Follow-up Email" + + # Verify workflow now has 3 nodes + final = list_nodes(ctx, workflow_id=workflow_id, thought="verify") + assert len(final["nodes"]) == 3 + assert final["nodes"][2]["type"] == "smtp_email" + assert final["nodes"][2]["label"] == "Follow-up Email" + + +@pytest.mark.django_db(transaction=True) +def test_update_node_label(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + # Get the action node + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] # The email action node + + ctx = make_test_ctx(user, workspace) + result = update_nodes( + ctx, + workflow_id=workflow_id, + nodes=[NodeUpdate(node_id=action_node.id, label="Updated Email")], + thought="rename node", + ) + + assert result["updated_nodes"][0]["label"] == "Updated Email" + + # Verify in DB + refreshed = AutomationNodeService().get_node(user, action_node.id) + assert refreshed.label == "Updated Email" + + +@pytest.mark.django_db(transaction=True) +def test_update_node_service_config(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, 
workflow_id = _create_test_workflow(data_fixture, user, workspace) + + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] + + ctx = make_test_ctx(user, workspace) + result = update_nodes( + ctx, + workflow_id=workflow_id, + nodes=[ + NodeUpdate( + node_id=action_node.id, + subject="New Subject", + ) + ], + thought="update email subject", + ) + + assert len(result["updated_nodes"]) == 1 + assert "errors" not in result + + +@pytest.mark.django_db(transaction=True) +def test_delete_node(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] + + ctx = make_test_ctx(user, workspace) + result = delete_nodes( + ctx, + node_ids=[action_node.id], + thought="delete node", + ) + + assert result["deleted_node_ids"] == [action_node.id] + + # Node should be gone + assert not AutomationNode.objects.filter(id=action_node.id).exists() + + +@pytest.mark.django_db(transaction=True) +def test_delete_node_wrong_workspace(data_fixture): + user = data_fixture.create_user() + workspace1 = data_fixture.create_workspace(user=user) + workspace2 = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace1) + + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + 
action_node = nodes[-1] + + # Try to delete from wrong workspace + ctx = make_test_ctx(user, workspace2) + result = delete_nodes( + ctx, + node_ids=[action_node.id], + thought="delete from wrong workspace", + ) + + assert result["deleted_node_ids"] == [] + assert len(result["errors"]) == 1 + + # Node should still exist + assert AutomationNode.objects.filter(id=action_node.id).exists() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py index 1e8ee8f8b0..05a7f75057 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py @@ -1,28 +1,25 @@ -from unittest.mock import Mock - import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.automation.workflows.handler import AutomationWorkflowHandler from baserow.core.formula import resolve_formula from baserow.core.formula.registries import formula_runtime_function_registry from baserow.core.formula.types import BASEROW_FORMULA_MODE_ADVANCED +from baserow_enterprise.assistant.tools.automation.agents import AssistantFormulaContext from baserow_enterprise.assistant.tools.automation.tools import ( - get_list_workflows_tool, - get_workflow_tool_factory, + create_workflows, + list_workflows, ) from baserow_enterprise.assistant.tools.automation.types import ( - CreateRowActionCreate, - DeleteRowActionCreate, - RouterNodeCreate, + ActionNodeCreate, TriggerNodeCreate, - UpdateRowActionCreate, WorkflowCreate, ) -from baserow_enterprise.assistant.tools.automation.types.node import RouterEdgeCreate -from baserow_enterprise.assistant.tools.automation.utils import AssistantFormulaContext +from baserow_enterprise.assistant.tools.automation.types.node import ( 
+ AutomationFieldValue, + RouterEdgeCreate, +) -from .utils import fake_tool_helpers +from .utils import make_test_ctx @pytest.fixture(autouse=True) @@ -38,7 +35,7 @@ def mock_update_workflow_formulas(workflow, node_mapping, tool_helpers): pass monkeypatch.setattr( - "baserow_enterprise.assistant.tools.automation.utils.update_workflow_formulas", + "baserow_enterprise.assistant.tools.automation.agents.update_workflow_formulas", mock_update_workflow_formulas, ) @@ -54,8 +51,8 @@ def test_list_workflows(data_fixture): automation=automation, name="Test Workflow" ) - tool = get_list_workflows_tool(user, workspace, fake_tool_helpers) - result = tool(automation_id=automation.id) + ctx = make_test_ctx(user, workspace) + result = list_workflows(ctx, automation_id=automation.id, thought="test") assert result == { "workflows": [{"id": workflow.id, "name": "Test Workflow", "state": "draft"}] @@ -76,8 +73,8 @@ def test_list_workflows_multiple(data_fixture): automation=automation, name="Workflow 2" ) - tool = get_list_workflows_tool(user, workspace, fake_tool_helpers) - result = tool(automation_id=automation.id) + ctx = make_test_ctx(user, workspace) + result = list_workflows(ctx, automation_id=automation.id, thought="test") assert result == { "workflows": [ @@ -97,25 +94,10 @@ def test_create_workflows(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) + ctx = make_test_ctx(user, workspace) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools 
if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -124,19 +106,21 @@ def test_create_workflows(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Create row", previous_node_ref="trigger1", type="create_row", table_id=table.id, - values={}, + values=[], ) ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 @@ -144,8 +128,6 @@ def test_create_workflows(data_fixture): assert result["created_workflows"][0]["state"] == "draft" # Verify workflow was created with a trigger - from baserow.contrib.automation.workflows.handler import AutomationWorkflowHandler - workflow_id = result["created_workflows"][0]["id"] workflow = AutomationWorkflowHandler().get_workflow(workflow_id) trigger = workflow.get_trigger() @@ -163,25 +145,10 @@ def test_create_multiple_workflows(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + ctx = make_test_ctx(user, workspace) - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ 
WorkflowCreate( @@ -190,15 +157,16 @@ def test_create_multiple_workflows(data_fixture): ref="trigger1", label="Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Action", previous_node_ref="trigger1", type="create_row", table_id=table.id, - values={}, + values=[], ) ], ), @@ -208,19 +176,21 @@ def test_create_multiple_workflows(data_fixture): ref="trigger2", label="Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action2", label="Action", previous_node_ref="trigger2", type="create_row", table_id=table.id, - values={}, + values=[], ) ], ), ], + thought="test", ) assert len(result["created_workflows"]) == 2 @@ -234,36 +204,45 @@ def test_create_multiple_workflows(data_fixture): [ ( TriggerNodeCreate( - type="rows_created", ref="trigger", label="Rows Created Trigger" + type="rows_created", + ref="trigger", + label="Rows Created Trigger", + rows_triggers_settings={"table_id": 999}, ), - CreateRowActionCreate( + ActionNodeCreate( type="create_row", ref="action", previous_node_ref="trigger", label="Create Row Action", table_id=999, - values={}, + values=[], ), ), ( TriggerNodeCreate( - type="rows_updated", ref="trigger", label="Rows Updated Trigger" + type="rows_updated", + ref="trigger", + label="Rows Updated Trigger", + rows_triggers_settings={"table_id": 999}, ), - UpdateRowActionCreate( + ActionNodeCreate( type="update_row", ref="action", previous_node_ref="trigger", label="Update Row Action", table_id=999, row_id="1", - values={}, + values=[], ), ), ( TriggerNodeCreate( - type="rows_deleted", ref="trigger", label="Rows Deleted Trigger" + type="rows_deleted", + ref="trigger", + label="Rows Deleted Trigger", + rows_triggers_settings={"table_id": 999}, ), - DeleteRowActionCreate( + ActionNodeCreate( type="delete_row", ref="action", previous_node_ref="trigger", @@ -285,25 +264,10 @@ def 
test_create_workflow_with_row_triggers_and_actions(data_fixture, trigger, ac table.pk = 999 # To match the action's table_id table.save() - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) + ctx = make_test_ctx(user, workspace) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -312,6 +276,7 @@ def test_create_workflow_with_row_triggers_and_actions(data_fixture, trigger, ac nodes=[action], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 @@ -328,7 +293,7 @@ def test_create_workflow_with_row_triggers_and_actions(data_fixture, trigger, ac @pytest.mark.django_db(transaction=True) def test_create_row_action_with_field_ids(data_fixture): - """Test CreateRowActionCreate uses field IDs in values dict, not field names.""" + """Test ActionNodeCreate uses field IDs in values dict, not field names.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -340,25 +305,10 @@ def test_create_row_action_with_field_ids(data_fixture): text_field = data_fixture.create_text_field(table=table, name="Name") number_field = data_fixture.create_number_field(table=table, name="Age") - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - 
tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None + ctx = make_test_ctx(user, workspace) - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -367,22 +317,26 @@ def test_create_row_action_with_field_ids(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Create row with field IDs", previous_node_ref="trigger1", type="create_row", table_id=table.id, - values={ - text_field.id: "John Doe", - number_field.id: 25, - }, + values=[ + AutomationFieldValue( + field_id=text_field.id, value="John Doe" + ), + AutomationFieldValue(field_id=number_field.id, value="25"), + ], ) ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 @@ -400,7 +354,7 @@ def test_create_row_action_with_field_ids(data_fixture): @pytest.mark.django_db(transaction=True) def test_update_row_action_with_row_id_and_field_ids(data_fixture): - """Test UpdateRowActionCreate uses row_id parameter and field IDs in values.""" + """Test ActionNodeCreate uses row_id parameter and field IDs in values.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -411,25 +365,10 @@ def test_update_row_action_with_row_id_and_field_ids(data_fixture): table = data_fixture.create_database_table(user=user, database=database) text_field = data_fixture.create_text_field(table=table, name="Status") - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - 
mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + ctx = make_test_ctx(user, workspace) - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -438,42 +377,44 @@ def test_update_row_action_with_row_id_and_field_ids(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - UpdateRowActionCreate( + ActionNodeCreate( ref="action1", label="Update row", previous_node_ref="trigger1", type="update_row", table_id=table.id, row_id="123", - values={text_field.id: "completed"}, + values=[ + AutomationFieldValue( + field_id=text_field.id, value="completed" + ) + ], ) ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 workflow_id = result["created_workflows"][0]["id"] workflow = AutomationWorkflowHandler().get_workflow(workflow_id) - # Get the action node and verify it was created with the correct table - # Note: row_id formula generation occurs in a separate transaction and may fail - # if DSPy is not configured, so we only verify basic service configuration action_nodes = workflow.automation_workflow_nodes.exclude( id=workflow.get_trigger().id ) assert action_nodes.count() == 1 action_node = action_nodes.first() assert action_node.service.specific.table_id == table.id - # Verify the service type is correct for upsert_row (update operation) assert action_node.service.get_type().type == "local_baserow_upsert_row" @pytest.mark.django_db(transaction=True) def test_delete_row_action_with_row_id(data_fixture): - """Test DeleteRowActionCreate uses row_id parameter.""" + """Test ActionNodeCreate uses 
row_id parameter.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -483,25 +424,10 @@ def test_delete_row_action_with_row_id(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) + ctx = make_test_ctx(user, workspace) - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -510,9 +436,10 @@ def test_delete_row_action_with_row_id(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - DeleteRowActionCreate( + ActionNodeCreate( ref="action1", label="Delete row", previous_node_ref="trigger1", @@ -523,28 +450,25 @@ def test_delete_row_action_with_row_id(data_fixture): ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 workflow_id = result["created_workflows"][0]["id"] workflow = AutomationWorkflowHandler().get_workflow(workflow_id) - # Get the action node and verify it was created with the correct table - # Note: row_id formula generation occurs in a separate transaction and may fail - # if DSPy is not configured, so we only verify basic service configuration action_nodes = workflow.automation_workflow_nodes.exclude( id=workflow.get_trigger().id ) assert action_nodes.count() == 1 
action_node = action_nodes.first() assert action_node.service.specific.table_id == table.id - # Verify the service type is correct for delete_row assert action_node.service.get_type().type == "local_baserow_delete_row" @pytest.mark.django_db(transaction=True) def test_router_node_with_required_conditions(data_fixture): - """Test RouterNodeCreate requires condition field for each edge.""" + """Test ActionNodeCreate requires condition field for each edge.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -554,25 +478,10 @@ def test_router_node_with_required_conditions(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None + ctx = make_test_ctx(user, workspace) - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -581,9 +490,10 @@ def test_router_node_with_required_conditions(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - RouterNodeCreate( + ActionNodeCreate( ref="router1", label="Router", previous_node_ref="trigger1", @@ -599,17 +509,18 @@ def test_router_node_with_required_conditions(data_fixture): ), ], ), - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Create row", 
previous_node_ref="router1", type="create_row", table_id=table.id, - values={}, + values=[], ), ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py new file mode 100644 index 0000000000..3054689274 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py @@ -0,0 +1,116 @@ +import pytest + +from baserow.test_utils.helpers import AnyInt +from baserow_enterprise.assistant.tools.core.tools import ( + create_builders, + list_builders, +) +from baserow_enterprise.assistant.tools.core.types import BuilderItemCreate + +from .utils import make_test_ctx + + +@pytest.mark.django_db +def test_list_builders_all(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + db = data_fixture.create_database_application(workspace=workspace, name="My DB") + automation = data_fixture.create_automation_application( + workspace=workspace, name="My Automation" + ) + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=None, thought="list all") + + assert "database" in result + assert any(b["name"] == "My DB" for b in result["database"]) + assert "automation" in result + assert any(b["name"] == "My Automation" for b in result["automation"]) + + +@pytest.mark.django_db +def test_list_builders_filter_by_type(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + data_fixture.create_database_application(workspace=workspace, name="DB 1") + data_fixture.create_automation_application(workspace=workspace, name="Auto 1") + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=["database"], thought="databases only") + + assert "database" in result + assert "automation" not in 
result + + +@pytest.mark.django_db +def test_list_builders_empty(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=None, thought="list all") + + assert result == {} + + +@pytest.mark.django_db +def test_list_builders_truncation(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + for i in range(25): + data_fixture.create_database_application(workspace=workspace, name=f"DB {i}") + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=None, thought="list all") + + assert "_info" in result + assert len(result["database"]) == 20 + + +@pytest.mark.django_db +def test_create_builders_database(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + builders = [BuilderItemCreate(name="New Database", type="database")] + result = create_builders(ctx, builders=builders, thought="create db") + + assert len(result["created_builders"]) == 1 + created = result["created_builders"][0] + assert created["name"] == "New Database" + assert created["type"] == "database" + assert created["id"] == AnyInt() + + +@pytest.mark.django_db +def test_create_builders_multiple(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + builders = [ + BuilderItemCreate(name="DB One", type="database"), + BuilderItemCreate(name="DB Two", type="database"), + ] + result = create_builders(ctx, builders=builders, thought="create two dbs") + + assert len(result["created_builders"]) == 2 + names = [b["name"] for b in result["created_builders"]] + assert "DB One" in names + assert "DB Two" in names + + +@pytest.mark.django_db +def test_create_database_ignores_theme(data_fixture): + """Creating a database 
should not fail even though databases have no theme.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + builders = [BuilderItemCreate(name="My DB", type="database")] + result = create_builders(ctx, builders=builders, thought="create db") + + assert len(result["created_builders"]) == 1 + assert result["created_builders"][0]["type"] == "database" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py new file mode 100644 index 0000000000..7a96240ecc --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py @@ -0,0 +1,151 @@ +import pytest + +from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.models import Field +from baserow_enterprise.assistant.tools.database.tools import ( + delete_fields, + update_fields, +) +from baserow_enterprise.assistant.tools.database.types import ( + FieldItemUpdate, + SelectOptionCreate, +) + +from .utils import make_test_ctx + + +@pytest.mark.django_db +def test_update_field_name(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="Old Name") + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[FieldItemUpdate(field_id=field.id, name="New Name")], + thought="rename field", + ) + + assert result["updated_fields"][0]["name"] == "New Name" + assert result["updated_fields"][0]["id"] == field.id + + # Verify in DB + refreshed = FieldHandler().get_field(field.id) + assert refreshed.name == "New Name" + + 
+@pytest.mark.django_db +def test_update_number_field_decimal_places(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_number_field( + table=table, name="Price", number_decimal_places=0 + ) + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[FieldItemUpdate(field_id=field.id, decimal_places=2)], + thought="change decimal places", + ) + + assert result["updated_fields"][0]["decimal_places"] == 2 + + +@pytest.mark.django_db +def test_update_select_field_options(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_single_select_field(table=table, name="Status") + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[ + FieldItemUpdate( + field_id=field.id, + options=[ + SelectOptionCreate(value="Open", color="green"), + SelectOptionCreate(value="Closed", color="red"), + ], + ) + ], + thought="add options", + ) + + updated = result["updated_fields"][0] + assert len(updated["options"]) == 2 + option_values = {o["value"] for o in updated["options"]} + assert option_values == {"Open", "Closed"} + + +@pytest.mark.django_db +def test_update_field_no_changes(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="Unchanged") + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + 
fields=[FieldItemUpdate(field_id=field.id)], + thought="no changes", + ) + + assert result["updated_fields"][0]["name"] == "Unchanged" + + +@pytest.mark.django_db +def test_delete_field(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="To Delete") + + ctx = make_test_ctx(user, workspace) + result = delete_fields( + ctx, + field_ids=[field.id], + thought="delete field", + ) + + assert result["deleted_field_ids"] == [field.id] + + # Field should be trashed + assert not Field.objects.filter(id=field.id).exists() + assert Field.objects_and_trash.filter(id=field.id, trashed=True).exists() + + +@pytest.mark.django_db +def test_delete_primary_field_fails(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + primary_field = data_fixture.create_text_field( + table=table, name="Primary", primary=True + ) + + ctx = make_test_ctx(user, workspace) + result = delete_fields( + ctx, + field_ids=[primary_field.id], + thought="try delete primary", + ) + + assert result["deleted_field_ids"] == [] + assert len(result["errors"]) == 1 + + # Primary field should still exist + refreshed = FieldHandler().get_field(primary_field.id) + assert refreshed.primary is True diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py index 004477d9e6..306c06289a 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py +++ 
b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py @@ -1,15 +1,12 @@ -from unittest.mock import Mock - import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.rows.handler import RowHandler from baserow_enterprise.assistant.tools.database.tools import ( - get_list_rows_tool, - get_rows_tools_factory, + list_rows, + load_row_tools, ) -from .utils import fake_tool_helpers +from .utils import make_test_ctx def _create_simple_database_with_linked_tables_and_rows(data_fixture): @@ -134,20 +131,20 @@ def test_list_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - list_table_rows = get_list_rows_tool(user, workspace, tool_helpers) - assert callable(list_table_rows) + ctx = make_test_ctx(user, workspace) - result = list_table_rows(table_id=table.id, offset=0, limit=50) + result = list_rows( + ctx, table_id=table.id, offset=0, limit=50, field_ids=None, thought="test" + ) rows = result["rows"] assert len(rows) == 3 assert rows[0] == { "primary": "Row A1", "Long text field": "Long text A1", "Number field": 10.123, - "Date field": {"year": 2023, "month": 1, "day": 1}, - "Datetime field": {"year": 2023, "month": 1, "day": 1, "hour": 10, "minute": 0}, + "Date field": "2023-01-01", + "Datetime field": "2023-01-01T10:00", "Single link to B": "Row B1", "Multiple select": ["Option A", "Option B"], "Text field": "Text A1", @@ -159,8 +156,8 @@ def test_list_rows(data_fixture): "primary": "Row A2", "Long text field": "Long text A2", "Number field": 20.456, - "Date field": {"year": 2023, "month": 2, "day": 1}, - "Datetime field": {"year": 2023, "month": 2, "day": 1, "hour": 11, "minute": 0}, + "Date field": "2023-02-01", + "Datetime field": "2023-02-01T11:00", "Single link to B": "Row B2", "Multiple select": ["Option B", "Option C"], "Text field": "Text A2", @@ -183,8 +180,13 @@ def 
test_list_rows(data_fixture): } # List a single field - result = list_table_rows( - table_id=table.id, offset=0, limit=50, field_ids=[table.get_primary_field().id] + result = list_rows( + ctx, + table_id=table.id, + offset=0, + limit=50, + field_ids=[table.get_primary_field().id], + thought="test", ) rows = result["rows"] assert len(rows) == 3 @@ -209,24 +211,19 @@ def test_create_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - meta_tool = get_rows_tools_factory(user, workspace, tool_helpers) - assert callable(meta_tool) + ctx = make_test_ctx(user, workspace) - tools_upgrade = meta_tool([table.id], ["create"]) - assert is_module_callback(tools_upgrade) + observation = load_row_tools(ctx, [table.id], ["create"], thought="test") + assert isinstance(observation, str) + assert f"create_rows_in_table_{table.id}" in observation - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + # Tools should be stored in ctx.deps.dynamic_tools + dynamic_tools = ctx.deps.dynamic_tools + assert len(dynamic_tools) == 1 - added_tools = mock_module.init_module.call_args[1]["tools"] - added_tools_names = [tool.name for tool in added_tools] - assert len(added_tools) == 1 - assert f"create_rows_in_table_{table.id}" in added_tools_names + create_tool = dynamic_tools[0] + assert create_tool.name == f"create_rows_in_table_{table.id}" table_model = table.get_model() assert table_model.objects.count() == 3 @@ -236,14 +233,8 @@ def test_create_rows(data_fixture): "Text field": "Text A3", "Long text field": "Long text A3", "Number field": 30.789, - "Date field": {"year": 2023, "month": 3, "day": 1}, - "Datetime field": { - "year": 2023, - "month": 3, - "day": 1, - "hour": 12, - "minute": 0, - }, + "Date field": "2023-03-01", + "Datetime field": "2023-03-01T12:00", "Single select": "Option 1", 
"Multiple select": ["Option A", "Option C"], "Single link to B": "Row B3", @@ -261,8 +252,12 @@ def test_create_rows(data_fixture): "Single link to B": None, "link": [], } - create_table_rows = added_tools[0] - result = create_table_rows(rows=[row_1, row_2]) + # Validate dicts through the tool's schema (as pydantic-ai would), + # then call the underlying function. + validated_args = create_tool.function_schema.validator.validate_python( + {"rows": [row_1, row_2], "thought": "test"} + ) + result = create_tool.function(**validated_args) created_row_ids = result["created_row_ids"] assert len(created_row_ids) == 2 assert created_row_ids == [4, 5] @@ -275,28 +270,22 @@ def test_update_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - meta_tool = get_rows_tools_factory(user, workspace, tool_helpers) - assert callable(meta_tool) - tools_upgrade = meta_tool([table.id], ["update"]) - assert is_module_callback(tools_upgrade) + ctx = make_test_ctx(user, workspace) + + observation = load_row_tools(ctx, [table.id], ["update"], thought="test") + assert isinstance(observation, str) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + dynamic_tools = ctx.deps.dynamic_tools + assert len(dynamic_tools) == 1 - added_tools = mock_module.init_module.call_args[1]["tools"] - added_tools_names = [tool.name for tool in added_tools] - assert len(added_tools) == 1 - assert f"update_rows_in_table_{table.id}" in added_tools_names + update_tool = dynamic_tools[0] + assert update_tool.name == f"update_rows_in_table_{table.id}" table_model = table.get_model() assert table_model.objects.count() == 3 - # Update row 1 with new values + # Update row 1 — only pass fields to change, omit the rest row_1_updates = { "id": 1, "primary": "Updated Row A1", @@ -305,41 +294,33 @@ def 
test_update_rows(data_fixture): "Single select": "Option 2", "link": ["Row B3"], "Single link to B": "Row B2", - "Datetime field": "__NO_CHANGE__", - "Date field": "__NO_CHANGE__", - "Multiple select": "__NO_CHANGE__", - "Long text field": "__NO_CHANGE__", } - # Update row 2 with new values + # Update row 2 — only pass fields to change row_2_updates = { "id": 2, - "Single link to B": "__NO_CHANGE__", "Long text field": "Updated Long text A2", - "Date field": {"year": 2024, "month": 12, "day": 31}, + "Date field": "2024-12-31", "Multiple select": ["Option A"], - "primary": "__NO_CHANGE__", - "Text field": "__NO_CHANGE__", - "Number field": "__NO_CHANGE__", - "Datetime field": "__NO_CHANGE__", - "Single select": "__NO_CHANGE__", - "link": "__NO_CHANGE__", } - update_table_rows = added_tools[0] - result = update_table_rows(rows=[row_1_updates, row_2_updates]) + validated_args = update_tool.function_schema.validator.validate_python( + {"rows": [row_1_updates, row_2_updates], "thought": "test"} + ) + result = update_tool.function(**validated_args) updated_row_ids = result["updated_row_ids"] assert len(updated_row_ids) == 2 assert updated_row_ids == [1, 2] # Verify the rows were updated correctly - list_table_rows = get_list_rows_tool(user, workspace, tool_helpers) - row_1, row_2 = list_table_rows(table_id=table.id, offset=0, limit=2)["rows"] + row_1, row_2 = list_rows( + ctx, table_id=table.id, offset=0, limit=2, field_ids=None, thought="test" + )["rows"] assert row_1 == { "primary": "Updated Row A1", "Long text field": "Long text A1", "Number field": 99.999, - "Date field": {"year": 2023, "month": 1, "day": 1}, - "Datetime field": {"year": 2023, "month": 1, "day": 1, "hour": 10, "minute": 0}, + "Date field": "2023-01-01", + "Datetime field": "2023-01-01T10:00", "Single link to B": "Row B2", "Multiple select": ["Option A", "Option B"], "Text field": "Updated Text A1", @@ -351,8 +332,8 @@ def test_update_rows(data_fixture): "primary": "Row A2", "Long text field": 
"Updated Long text A2", "Number field": 20.456, - "Date field": {"year": 2024, "month": 12, "day": 31}, - "Datetime field": {"year": 2023, "month": 2, "day": 1, "hour": 11, "minute": 0}, + "Date field": "2024-12-31", + "Datetime field": "2023-02-01T11:00", "Single link to B": "Row B2", "Multiple select": ["Option A"], "Text field": "Text A2", @@ -369,29 +350,23 @@ def test_delete_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - - meta_tool = get_rows_tools_factory(user, workspace, tool_helpers) - assert callable(meta_tool) - - tools_upgrade = meta_tool([table.id], ["delete"]) - assert is_module_callback(tools_upgrade) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - added_tools = mock_module.init_module.call_args[1]["tools"] - added_tools_names = [tool.name for tool in added_tools] - assert len(added_tools) == 1 - assert f"delete_rows_in_table_{table.id}" in added_tools_names - delete_table_rows = added_tools[0] + + ctx = make_test_ctx(user, workspace) + + observation = load_row_tools(ctx, [table.id], ["delete"], thought="test") + assert isinstance(observation, str) + + dynamic_tools = ctx.deps.dynamic_tools + assert len(dynamic_tools) == 1 + + delete_tool = dynamic_tools[0] + assert delete_tool.name == f"delete_rows_in_table_{table.id}" table_model = table.get_model() assert table_model.objects.count() == 3 # Delete rows with ids 1 and 3 - result = delete_table_rows(row_ids=[1, 3]) + result = delete_tool.function(row_ids=[1, 3], thought="test") assert result["deleted_row_ids"] == [1, 3] # Verify rows were deleted diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py index 0d1b052dd0..97c9560f35 100644 --- 
a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py @@ -1,35 +1,45 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.fields.models import FormulaField from baserow.contrib.database.formula.registries import formula_function_registry from baserow.contrib.database.table.models import Table from baserow.test_utils.helpers import AnyInt +from baserow_enterprise.assistant.tools.database.agents import FormulaGenerationResult from baserow_enterprise.assistant.tools.database.tools import ( - get_generate_database_formula_tool, - get_list_tables_tool, - get_table_and_fields_tools_factory, + create_fields, + create_tables, + generate_formula, + list_tables, ) from baserow_enterprise.assistant.tools.database.types import ( - BooleanFieldItemCreate, - DateFieldItemCreate, - FileFieldItemCreate, - LinkRowFieldItemCreate, + FieldItem, + FieldItemCreate, + InvalidFormulaFieldError, ListTablesFilterArg, - LongTextFieldItemCreate, - MultipleSelectFieldItemCreate, - NumberFieldItemCreate, - RatingFieldItemCreate, SelectOptionCreate, - SingleSelectFieldItemCreate, TableItemCreate, - TextFieldItemCreate, - field_item_registry, ) -from .utils import fake_tool_helpers +from .utils import make_test_ctx + + +def _make_mock_formula_result(**kwargs): + """Create a mock agent result with a FormulaGenerationResult output.""" + defaults = { + "table_id": 1, + "field_name": "test_formula", + "formula": "'ok'", + "formula_type": "text", + "is_formula_valid": True, + "error_message": "", + } + defaults.update(kwargs) + result = FormulaGenerationResult(**defaults) + mock_agent_result = MagicMock() + mock_agent_result.output = result + return mock_agent_result @pytest.mark.django_db @@ 
-46,138 +56,86 @@ def test_list_tables_tool(data_fixture): table_2 = data_fixture.create_database_table(database=database_1, name="Table 2") table_3 = data_fixture.create_database_table(database=database_2, name="Table 3") - tool = get_list_tables_tool(user, workspace, fake_tool_helpers) + ctx = make_test_ctx(user, workspace) - # Test 1: Filter by database_ids (single database) - returns flat list - response = tool( + # Test 1: Filter by database_id (single database) - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=[database_1.id], - database_names=None, - table_ids=None, - table_names=None, - ) + database_id_or_name=database_1.id, + table_ids_or_names=None, + ), ) assert response == [ {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, ] - # Test 2: Filter by database_names (single database) - returns flat list - response = tool( + # Test 2: Filter by database_name (single database) - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=None, - database_names=["Database 2"], - table_ids=None, - table_names=None, - ) + database_id_or_name="Database 2", + table_ids_or_names=None, + ), ) assert response == [ {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, ] - # Test 3: Filter by multiple database_ids - returns database wrapper structure - response = tool( + # Test 4: Filter by database + table_ids - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=[database_1.id, database_2.id], - database_names=None, - table_ids=None, - table_names=None, - ) - ) - assert response == [ - { - "id": database_1.id, - "name": "Database 1", - "tables": [ - {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, - {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, - 
], - }, - { - "id": database_2.id, - "name": "Database 2", - "tables": [ - {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, - ], - }, - ] - - # Test 4: Filter by table_ids (single database) - returns flat list - response = tool( - filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=[table_1.id, table_2.id], - table_names=None, - ) + database_id_or_name=database_1.id, + table_ids_or_names=[table_1.id, table_2.id], + ), ) assert response == [ {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, ] - # Test 5: Filter by table_names (single database) - returns flat list - response = tool( + # Test 5: Filter by database + table_names - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=None, - table_names=["Table 1"], - ) + database_id_or_name=database_1.id, + table_ids_or_names=["Table 1"], + ), ) assert response == [ {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, ] - # Test 6: Filter by table_ids across multiple databases - returns database wrapper - response = tool( - filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=[table_1.id, table_3.id], - table_names=None, - ) - ) - assert response == [ - { - "id": database_1.id, - "name": "Database 1", - "tables": [ - {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, - ], - }, - { - "id": database_2.id, - "name": "Database 2", - "tables": [ - {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, - ], - }, - ] - - # Test 7: Combined filters (database_ids + table_names) - returns flat list - response = tool( + # Test 6: Combined filters (database_id + table_names) - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=[database_1.id], - 
database_names=None, - table_ids=None, - table_names=["Table 2"], - ) + database_id_or_name=database_1.id, + table_ids_or_names=["Table 2"], + ), ) assert response == [ {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, ] - # Test 8: No matching tables - returns "No tables found" - response = tool( + # Test 7: No matching tables - returns hint with available tables + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=None, - table_names=["Nonexistent Table"], - ) + database_id_or_name=database_1.id, + table_ids_or_names=["Nonexistent Table"], + ), ) - assert response == "No tables found" + info = response["_info"] + assert "no tables matching" in info or "No tables found" in info @pytest.mark.django_db @@ -188,44 +146,29 @@ def test_create_simple_table_tool(data_fixture): workspace=workspace, name="Database 1" ) - factory = get_table_and_fields_tools_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + ctx = make_test_ctx(user, workspace) - added_tools = mock_module.init_module.call_args[1]["tools"] - assert len(added_tools) == 2 # create_tables and create_fields - - # Find the create_tables tool - create_tables_tool = next( - (tool for tool in added_tools if tool.name == "create_tables"), None - ) - assert create_tables_tool is not None - - # Call the underlying function directly (not through udspy.Tool wrapper) - response = create_tables_tool.func( + # Call the tool function directly + response = create_tables( + ctx, + thought="test", database_id=database.id, tables=[ TableItemCreate( name="New Table", - primary_field=TextFieldItemCreate(type="text", name="Name"), + primary_field_name="Name", 
fields=[], ) ], add_sample_rows=False, ) - assert response == { - "created_tables": [{"id": AnyInt(), "name": "New Table"}], - "notes": [], - } + assert len(response["created_tables"]) == 1 + assert response["created_tables"][0]["name"] == "New Table" + assert response["created_tables"][0]["id"] == AnyInt() + assert response["notes"] == [] + # Full schema is included so callers have field IDs + assert "primary_field" in response["created_tables"][0] # Ensure the table was actually created assert Table.objects.filter( @@ -242,66 +185,27 @@ def test_create_complex_table_tool(data_fixture): ) table = data_fixture.create_database_table(database=database, name="Table 1") - factory = get_table_and_fields_tools_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - assert len(added_tools) == 2 # create_tables and create_fields - - # Find the create_tables tool - create_tables_tool = next( - (tool for tool in added_tools if tool.name == "create_tables"), None - ) - assert create_tables_tool is not None + ctx = make_test_ctx(user, workspace) - primary_field = TextFieldItemCreate(type="text", name="Name") + primary_field_name = "Name" fields = [ - LongTextFieldItemCreate( - type="long_text", - name="Description", - rich_text=True, - ), - NumberFieldItemCreate( - type="number", - name="Amount", - decimal_places=2, - suffix="$", - ), - DateFieldItemCreate( - type="date", - name="Due Date", - include_time=False, - ), - DateFieldItemCreate( - type="date", - name="Event Time", - include_time=True, - ), - BooleanFieldItemCreate( - type="boolean", - name="Done?", - ), - SingleSelectFieldItemCreate( - type="single_select", + 
FieldItemCreate(name="Description", type="long_text", rich_text=True), + FieldItemCreate(name="Amount", type="number", decimal_places=2, suffix="$"), + FieldItemCreate(name="Due Date", type="date", include_time=False), + FieldItemCreate(name="Event Time", type="date", include_time=True), + FieldItemCreate(name="Done?", type="boolean"), + FieldItemCreate( name="Status", + type="single_select", options=[ SelectOptionCreate(value="New", color="blue"), SelectOptionCreate(value="In Progress", color="yellow"), SelectOptionCreate(value="Done", color="green"), ], ), - MultipleSelectFieldItemCreate( - type="multiple_select", + FieldItemCreate( name="Tags", + type="multiple_select", options=[ SelectOptionCreate(value="Red", color="red"), SelectOptionCreate(value="Yellow", color="yellow"), @@ -309,38 +213,36 @@ def test_create_complex_table_tool(data_fixture): SelectOptionCreate(value="Blue", color="blue"), ], ), - LinkRowFieldItemCreate( - type="link_row", + FieldItemCreate( name="Related Items", + type="link_row", linked_table=table.id, ), - RatingFieldItemCreate( - type="rating", - name="Rating", - max_value=5, - ), - FileFieldItemCreate( - type="file", - name="Attachments", - ), + FieldItemCreate(name="Rating", type="rating", max_value=5), + FieldItemCreate(name="Attachments", type="file"), ] - # Call the underlying function directly (not through udspy.Tool wrapper) - response = create_tables_tool.func( + # Call the tool function directly + response = create_tables( + ctx, + thought="test", database_id=database.id, tables=[ TableItemCreate( name="New Table", - primary_field=primary_field, + primary_field_name=primary_field_name, fields=fields, ) ], add_sample_rows=False, ) - assert response == { - "created_tables": [{"id": AnyInt(), "name": "New Table"}], - "notes": [], - } + assert len(response["created_tables"]) == 1 + assert response["created_tables"][0]["name"] == "New Table" + assert response["created_tables"][0]["id"] == AnyInt() + assert response["notes"] == [] + # 
Full schema is included with all field details + assert "primary_field" in response["created_tables"][0] + assert "fields" in response["created_tables"][0] # Ensure the table was actually created with all fields created_table = Table.objects.filter( @@ -351,28 +253,41 @@ def test_create_complex_table_tool(data_fixture): table_model = created_table.get_model() fields_map = {field.name: field for field in fields} - fields_map[primary_field.name] = primary_field for field_object in table_model.get_field_objects(): orm_field = field_object["field"] - assert orm_field.name in fields_map - field_item = fields_map.pop(orm_field.name).model_dump() - orm_field_to_item = field_item_registry.from_django_orm(orm_field).model_dump() + read_item = FieldItem.from_django_orm(orm_field).model_dump() + if orm_field.primary: - assert field_item["name"] == primary_field.name + assert orm_field.name == primary_field_name + continue + + assert orm_field.name in fields_map + create_item = fields_map.pop(orm_field.name) + create_dump = create_item.model_dump() + + # Both create and read are flat: type is top-level + assert create_dump["type"] == read_item["type"] - for key, value in orm_field_to_item.items(): - if key == "id": + # Compare type-specific fields present in both + skip_keys = {"name", "type"} + for key, value in create_dump.items(): + if key in skip_keys: continue + read_value = read_item.get(key) + if read_value is None: + continue # read model excludes None; defaults aren't relevant if key == "options": - # Saved options have an ID, so we need to remove them before comparison - for option in value: + # Saved options have an ID, so remove them before comparison + for option in read_value: option.pop("id") - - assert field_item[key] == value + assert read_value == value, ( + f"Field '{orm_field.name}' key '{key}': " + f"expected {value}, got {read_value}" + ) @pytest.mark.django_db -def test_generate_database_formula_no_save(data_fixture): +def 
test_generate_formula_no_save(data_fixture): """Test formula generation without saving to a field.""" user = data_fixture.create_user() @@ -381,20 +296,22 @@ def test_generate_database_formula_no_save(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to return a valid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return a simple text", save_to_field=False, @@ -409,7 +326,7 @@ def test_generate_database_formula_no_save(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_create_new_field(data_fixture): +def test_generate_formula_create_new_field(data_fixture): """Test formula generation creates a new field when none exists.""" user = data_fixture.create_user() @@ -418,20 +335,22 @@ def test_generate_database_formula_create_new_field(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to return a valid formula - 
mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return a simple text", save_to_field=True, @@ -453,7 +372,7 @@ def test_generate_database_formula_create_new_field(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_update_existing_formula_field(data_fixture): +def test_generate_formula_update_existing_formula_field(data_fixture): """Test formula generation updates an existing formula field.""" user = data_fixture.create_user() @@ -468,20 +387,22 @@ def test_generate_database_formula_update_existing_formula_field(data_fixture): ) existing_field_id = existing_field.id - # Mock the udspy.ReAct to return a new formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'new'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'new'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = 
mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return updated text", save_to_field=True, @@ -503,7 +424,7 @@ def test_generate_database_formula_update_existing_formula_field(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_replace_non_formula_field(data_fixture): +def test_generate_formula_replace_non_formula_field(data_fixture): """Test formula generation replaces a non-formula field.""" user = data_fixture.create_user() @@ -518,20 +439,22 @@ def test_generate_database_formula_replace_non_formula_field(data_fixture): ) existing_field_id = existing_text_field.id - # Mock the udspy.ReAct to return a valid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return a simple text", save_to_field=True, @@ -559,7 +482,7 @@ def 
test_generate_database_formula_replace_non_formula_field(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_invalid_formula(data_fixture): +def test_generate_formula_invalid_formula(data_fixture): """Test error handling when formula generation fails.""" user = data_fixture.create_user() @@ -568,23 +491,27 @@ def test_generate_database_formula_invalid_formula(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to return an invalid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = False - mock_prediction.formula = "" - mock_prediction.formula_type = "" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "Formula syntax error: invalid expression" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="", + formula_type="", + is_formula_valid=False, + error_message="Formula syntax error: invalid expression", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) + ctx = make_test_ctx(user, workspace) # Verify exception is raised with pytest.raises(Exception) as exc_info: - tool( + generate_formula( + ctx, + thought="test", database_id=database.id, description="Invalid formula test", save_to_field=True, @@ -598,7 +525,7 @@ def test_generate_database_formula_invalid_formula(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_documentation_completeness(data_fixture): +def test_generate_formula_documentation_completeness(data_fixture): """Test that formula 
documentation contains all required functions.""" user = data_fixture.create_user() @@ -607,39 +534,39 @@ def test_generate_database_formula_documentation_completeness(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to capture the formula_documentation argument - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" - - captured_formula_docs = None - - class MockReAct: - def __init__(self, signature, tools=None, max_iters=10): - nonlocal captured_formula_docs - # Don't capture anything here - wait for the call - self.mock_instance = MagicMock(return_value=mock_prediction) - - def __call__(self, **kwargs): - nonlocal captured_formula_docs - captured_formula_docs = kwargs.get("formula_documentation") - return mock_prediction - - with patch("udspy.ReAct", MockReAct): - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - tool( + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) + + captured_prompt = None + + def mock_run_sync(prompt, **kwargs): + nonlocal captured_prompt + captured_prompt = prompt + return mock_result + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync", + side_effect=mock_run_sync, + ): + ctx = make_test_ctx(user, workspace) + generate_formula( + ctx, + thought="test", database_id=database.id, description="Test documentation", save_to_field=False, ) - # Verify formula_documentation was provided - assert captured_formula_docs is not None - assert len(captured_formula_docs) > 0 + # Verify formula documentation was included in the 
prompt + assert captured_prompt is not None + assert len(captured_prompt) > 0 + + # The formula_documentation is now embedded in the prompt string + captured_formula_docs = captured_prompt # Known exceptions (internal functions not documented) formula_exceptions = [ @@ -689,3 +616,204 @@ def __call__(self, **kwargs): assert func in captured_formula_docs, ( f"Expected function '{func}' not found in documentation" ) + + +@pytest.mark.django_db +def test_formula_field_validation_raises_on_invalid_formula(data_fixture): + """Invalid formula in to_django_orm_kwargs raises InvalidFormulaFieldError.""" + + table = data_fixture.create_database_table(name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + item = FieldItemCreate( + name="Bad Formula", + type="formula", + formula="this is not a valid formula!!!", + ) + with pytest.raises(InvalidFormulaFieldError) as exc_info: + item.to_django_orm_kwargs(table) + + assert exc_info.value.field_name == "Bad Formula" + assert exc_info.value.formula == "this is not a valid formula!!!" 
+ assert exc_info.value.table == table + + +@pytest.mark.django_db +def test_formula_field_validation_passes_for_valid_formula(data_fixture): + """Valid formula in to_django_orm_kwargs returns kwargs without error.""" + + table = data_fixture.create_database_table(name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + item = FieldItemCreate( + name="Good Formula", + type="formula", + formula="field('Name')", + ) + result = item.to_django_orm_kwargs(table) + assert result == {"name": "Good Formula", "formula": "field('Name')"} + + +@pytest.mark.django_db +def test_formula_field_validation_passes_for_empty_formula(data_fixture): + """Empty formula string skips validation.""" + + table = data_fixture.create_database_table(name="Test") + + item = FieldItemCreate( + name="Empty Formula", + type="formula", + formula="", + ) + result = item.to_django_orm_kwargs(table) + assert result == {"name": "Empty Formula", "formula": ""} + + +@pytest.mark.django_db +def test_create_fields_tool_with_invalid_formula_auto_fixes(data_fixture): + """ + When a formula field has an invalid formula, create_fields + auto-fixes it via the formula generation pipeline. 
+ """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="Fixed Formula", + formula="field('Name')", + formula_type="text", + ) + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result + + ctx = make_test_ctx(user, workspace) + result = create_fields( + ctx, + thought="test", + table_id=table.id, + fields=[ + FieldItemCreate(name="Description", type="text"), + FieldItemCreate( + name="Bad Formula", + type="formula", + formula="invalid_stuff!!!", + ), + ], + ) + + # The text field should be created successfully + assert len(result["created_fields"]) == 2 + # No formula errors since auto-fix succeeded + assert "formula_errors" not in result + + # Verify the formula field was created with the original name and fixed formula + formula_field = table.field_set.filter(name="Bad Formula").first() + assert formula_field is not None + + +@pytest.mark.django_db +def test_create_fields_tool_reports_error_when_auto_fix_fails(data_fixture): + """ + When auto-fix also fails, create_fields reports the error + without failing the entire batch. 
+ """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + mock_result = _make_mock_formula_result( + is_formula_valid=False, + error_message="Could not fix formula", + ) + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result + + ctx = make_test_ctx(user, workspace) + result = create_fields( + ctx, + thought="test", + table_id=table.id, + fields=[ + FieldItemCreate(name="Description", type="text"), + FieldItemCreate( + name="Bad Formula", + type="formula", + formula="invalid_stuff!!!", + ), + ], + ) + + # The text field should still be created successfully + assert len(result["created_fields"]) == 1 + assert result["created_fields"][0]["name"] == "Description" + + # Formula errors should be reported + assert len(result["formula_errors"]) == 1 + assert result["formula_errors"][0]["field_name"] == "Bad Formula" + assert "hint" in result["formula_errors"][0] + + +@pytest.mark.django_db +def test_create_tables_with_invalid_formula_auto_fixes(data_fixture): + """ + When create_tables encounters an invalid formula, it auto-fixes + via the formula generation pipeline. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + def mock_run_sync(prompt, **kwargs): + # The table doesn't exist yet when the mock is created, so we + # dynamically set table_id on call. 
+ tables = Table.objects.filter(database=database).order_by("-id") + return _make_mock_formula_result( + table_id=tables.first().id, + field_name="My Formula", + formula="'fixed'", + formula_type="text", + ) + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync", + side_effect=mock_run_sync, + ): + ctx = make_test_ctx(user, workspace) + result = create_tables( + ctx, + thought="test", + database_id=database.id, + tables=[ + TableItemCreate( + name="Test Table", + primary_field_name="Name", + fields=[ + FieldItemCreate( + name="My Formula", + type="formula", + formula="bad formula!!!", + ), + ], + ) + ], + add_sample_rows=False, + ) + + assert len(result["created_tables"]) == 1 + # No formula error notes since auto-fix succeeded + assert result["notes"] == [] diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip deleted file mode 100644 index 2c74da289e..0000000000 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip +++ /dev/null @@ -1,52 +0,0 @@ -import pytest - -from baserow.contrib.database.models import Database -from baserow.test_utils.helpers import AnyInt -from baserow_enterprise.assistant.tools.database.tools import ( - get_create_database_tool, - get_list_databases_tool, -) - -from .utils import fake_tool_helpers - - -@pytest.mark.django_db -def test_list_databases_tool(data_fixture): - user = data_fixture.create_user() - workspace = data_fixture.create_workspace(user=user) - database = data_fixture.create_database_application( - workspace=workspace, name="Database 1" - ) - - tool = get_list_databases_tool(user, workspace, fake_tool_helpers) - response = tool() - - assert response == {"databases": [{"id": database.id, "name": "Database 1"}]} - - database_2 = data_fixture.create_database_application( - 
workspace=workspace, name="Database 2" - ) - response = tool() - assert response == { - "databases": [ - {"id": database.id, "name": "Database 1"}, - {"id": database_2.id, "name": "Database 2"}, - ] - } - - -@pytest.mark.django_db -def test_create_database_tool(data_fixture): - user = data_fixture.create_user() - workspace = data_fixture.create_workspace(user=user) - - tool = get_create_database_tool(user, workspace, fake_tool_helpers) - response = tool(name="New Database") - - assert response == {"created_database": {"id": AnyInt(), "name": "New Database"}} - - # Ensure the database was actually created - - assert Database.objects.filter( - id=response["created_database"]["id"], name="New Database" - ).exists() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py index 88726637f9..273df74d86 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py @@ -1,41 +1,20 @@ import pytest from baserow.contrib.database.views.models import ViewFilter -from baserow_enterprise.assistant.tools.database.types import ( - BooleanIsViewFilterItemCreate, - DateAfterViewFilterItemCreate, - DateBeforeViewFilterItemCreate, - DateEqualsViewFilterItemCreate, - DateNotEqualsViewFilterItemCreate, - LinkRowHasNotViewFilterItemCreate, - LinkRowHasViewFilterItemCreate, - MultipleSelectIsAnyViewFilterItemCreate, - MultipleSelectIsNoneOfNotViewFilterItemCreate, - NumberEmptyViewFilterItemCreate, - NumberEqualsViewFilterItemCreate, - NumberHigherThanViewFilterItemCreate, - NumberLowerThanViewFilterItemCreate, - NumberNotEmptyViewFilterItemCreate, - NumberNotEqualsViewFilterItemCreate, - SingleSelectIsAnyViewFilterItemCreate, - 
SingleSelectIsNoneOfNotViewFilterItemCreate, - TextContainsViewFilterItemCreate, - TextEmptyViewFilterItemCreate, - TextEqualViewFilterItemCreate, - TextNotContainsViewFilterItemCreate, - TextNotEmptyViewFilterItemCreate, - TextNotEqualViewFilterItemCreate, -) -from baserow_enterprise.assistant.tools.database.types.base import Date +from baserow_enterprise.assistant.tools.database.helpers import create_view_filter from baserow_enterprise.assistant.tools.database.types.view_filters import ( ViewFilterItemCreate, ) -from baserow_enterprise.assistant.tools.database.utils import create_view_filter + + +def _make_filter(field_id, **kwargs): + """Shortcut to build a ViewFilterItemCreate.""" + return ViewFilterItemCreate(field_id=field_id, **kwargs) @pytest.mark.django_db def test_all_text_filters_conversion(data_fixture): - """Test all text filter types can be converted to Baserow filters.""" + """Test all text filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -46,52 +25,29 @@ def test_all_text_filters_conversion(data_fixture): table_fields = {field.id: field} text_filters = [ + ({"type": "text", "operator": "equal", "value": "test"}, "equal", "test"), ( - TextEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="equal", value="test" - ), - "equal", - "test", - ), - ( - TextNotEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="not_equal", value="test" - ), + {"type": "text", "operator": "not_equal", "value": "test"}, "not_equal", "test", ), ( - TextContainsViewFilterItemCreate( - field_id=field.id, type="text", operator="contains", value="keyword" - ), + {"type": "text", "operator": "contains", "value": "keyword"}, "contains", "keyword", ), ( - TextNotContainsViewFilterItemCreate( - field_id=field.id, type="text", operator="contains_not", value="spam" - ), + {"type": "text", "operator": "contains_not", "value": "spam"}, "contains_not", 
"spam", ), - ( - TextEmptyViewFilterItemCreate( - field_id=field.id, type="text", operator="empty", value="" - ), - "empty", - "", - ), - ( - TextNotEmptyViewFilterItemCreate( - field_id=field.id, type="text", operator="not_empty", value="" - ), - "not_empty", - "", - ), + ({"type": "text", "operator": "empty", "value": ""}, "empty", ""), + ({"type": "text", "operator": "not_empty", "value": ""}, "not_empty", ""), ] - for filter_create, expected_type, expected_value in text_filters: - created_filter = create_view_filter(user, view, table_fields, filter_create) + for kwargs, expected_type, expected_value in text_filters: + filter_item = _make_filter(field.id, **kwargs) + created_filter = create_view_filter(user, view, table_fields, filter_item) assert created_filter is not None assert created_filter.view.id == view.id @@ -99,7 +55,6 @@ def test_all_text_filters_conversion(data_fixture): assert created_filter.type == expected_type assert created_filter.value == expected_value - # Verify in database assert ViewFilter.objects.filter( view=view, field=field, type=expected_type ).exists() @@ -107,7 +62,7 @@ def test_all_text_filters_conversion(data_fixture): @pytest.mark.django_db def test_all_number_filters_conversion(data_fixture): - """Test all number filter types can be converted to Baserow filters.""" + """Test all number filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -119,59 +74,80 @@ def test_all_number_filters_conversion(data_fixture): number_filters = [ ( - NumberEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="equal", value=42.0 - ), + {"type": "number", "operator": "equal", "value": 42.0, "or_equal": False}, "equal", "42.0", ), ( - NumberNotEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="not_equal", value=0.0 - ), + { + "type": "number", + "operator": "not_equal", + "value": 0.0, + "or_equal": False, + }, 
"not_equal", "0.0", ), ( - NumberHigherThanViewFilterItemCreate( - field_id=field.id, - type="number", - operator="higher_than", - value=100.0, - or_equal=False, - ), + { + "type": "number", + "operator": "higher_than", + "value": 100.0, + "or_equal": False, + }, "higher_than", "100.0", ), ( - NumberLowerThanViewFilterItemCreate( - field_id=field.id, - type="number", - operator="lower_than", - value=50.0, - or_equal=False, - ), + { + "type": "number", + "operator": "higher_than", + "value": 100.0, + "or_equal": True, + }, + "higher_than_or_equal", + "100.0", + ), + ( + { + "type": "number", + "operator": "lower_than", + "value": 50.0, + "or_equal": False, + }, "lower_than", "50.0", ), ( - NumberEmptyViewFilterItemCreate( - field_id=field.id, type="number", operator="empty", value=0.0 - ), + { + "type": "number", + "operator": "lower_than", + "value": 50.0, + "or_equal": True, + }, + "lower_than_or_equal", + "50.0", + ), + ( + {"type": "number", "operator": "empty", "value": 0.0, "or_equal": False}, "empty", "0.0", ), ( - NumberNotEmptyViewFilterItemCreate( - field_id=field.id, type="number", operator="not_empty", value=0.0 - ), + { + "type": "number", + "operator": "not_empty", + "value": 0.0, + "or_equal": False, + }, "not_empty", "0.0", ), ] - for filter_create, expected_type, expected_value in number_filters: - created_filter = create_view_filter(user, view, table_fields, filter_create) + for kwargs, expected_type, expected_value in number_filters: + filter_item = _make_filter(field.id, **kwargs) + created_filter = create_view_filter(user, view, table_fields, filter_item) assert created_filter is not None assert created_filter.type == expected_type @@ -183,7 +159,7 @@ def test_all_number_filters_conversion(data_fixture): @pytest.mark.django_db def test_all_date_filters_conversion(data_fixture): - """Test all date filter types can be converted to Baserow filters.""" + """Test all date filter operators can be converted to Baserow filters.""" user = 
data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -194,80 +170,86 @@ def test_all_date_filters_conversion(data_fixture): table_fields = {field.id: field} # Test with exact date - date_filter = DateEqualsViewFilterItemCreate( - field_id=field.id, + filter_item = _make_filter( + field.id, type="date", operator="equal", - value=Date(year=2024, month=1, day=15), + value="2024-01-15", mode="exact_date", + or_equal=False, ) - created_filter = create_view_filter(user, view, table_fields, date_filter) - assert created_filter.type == "date_is" - assert "2024-01-15" in created_filter.value - assert created_filter.value.endswith("?exact_date") + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "date_is" + assert "2024-01-15" in created.value + assert created.value.endswith("?exact_date") # Test with relative date (today) - date_filter2 = DateNotEqualsViewFilterItemCreate( - field_id=field.id, type="date", operator="not_equal", value=None, mode="today" + filter_item2 = _make_filter( + field.id, + type="date", + operator="not_equal", + value=None, + mode="today", + or_equal=False, ) - created_filter2 = create_view_filter(user, view, table_fields, date_filter2) - assert created_filter2.type == "date_is_not" - assert created_filter2.value.endswith("??today") + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "date_is_not" + assert created2.value.endswith("??today") # Test date_is_after - date_filter3 = DateAfterViewFilterItemCreate( - field_id=field.id, + filter_item3 = _make_filter( + field.id, type="date", operator="after", value=7, mode="nr_days_ago", or_equal=False, ) - created_filter3 = create_view_filter(user, view, table_fields, date_filter3) - assert created_filter3.type == "date_is_after" - assert "?7?" 
in created_filter3.value - assert created_filter3.value.endswith("nr_days_ago") + created3 = create_view_filter(user, view, table_fields, filter_item3) + assert created3.type == "date_is_after" + assert "?7?" in created3.value + assert created3.value.endswith("nr_days_ago") # Test date_is_on_or_after - date_filter4 = DateAfterViewFilterItemCreate( - field_id=field.id, + filter_item4 = _make_filter( + field.id, type="date", operator="after", value=30, mode="nr_days_from_now", or_equal=True, ) - created_filter4 = create_view_filter(user, view, table_fields, date_filter4) - assert created_filter4.type == "date_is_on_or_after" + created4 = create_view_filter(user, view, table_fields, filter_item4) + assert created4.type == "date_is_on_or_after" # Test date_is_before - date_filter5 = DateBeforeViewFilterItemCreate( - field_id=field.id, + filter_item5 = _make_filter( + field.id, type="date", operator="before", value=None, mode="tomorrow", or_equal=False, ) - created_filter5 = create_view_filter(user, view, table_fields, date_filter5) - assert created_filter5.type == "date_is_before" + created5 = create_view_filter(user, view, table_fields, filter_item5) + assert created5.type == "date_is_before" # Test date_is_on_or_before - date_filter6 = DateBeforeViewFilterItemCreate( - field_id=field.id, + filter_item6 = _make_filter( + field.id, type="date", operator="before", value=14, mode="nr_weeks_from_now", or_equal=True, ) - created_filter6 = create_view_filter(user, view, table_fields, date_filter6) - assert created_filter6.type == "date_is_on_or_before" + created6 = create_view_filter(user, view, table_fields, filter_item6) + assert created6.type == "date_is_on_or_before" @pytest.mark.django_db def test_all_single_select_filters_conversion(data_fixture): - """Test all single select filter types can be converted to Baserow filters.""" + """Test all single select filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = 
data_fixture.create_workspace(user=user) @@ -281,45 +263,38 @@ def test_all_single_select_filters_conversion(data_fixture): table_fields = {field.id: field} # Test is_any_of - filter_create = SingleSelectIsAnyViewFilterItemCreate( - field_id=field.id, + filter_item = _make_filter( + field.id, type="single_select", operator="is_any_of", value=["Active", "Pending"], ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "single_select_is_any_of" - # Value should contain option IDs - option_ids = created_filter.value.split(",") + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "single_select_is_any_of" + option_ids = created.value.split(",") assert str(option1.id) in option_ids assert str(option2.id) in option_ids assert len(option_ids) == 2 # Test case insensitive matching - filter_create2 = SingleSelectIsAnyViewFilterItemCreate( - field_id=field.id, - type="single_select", - operator="is_any_of", - value=["active"], # lowercase + filter_item2 = _make_filter( + field.id, type="single_select", operator="is_any_of", value=["active"] ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert str(option1.id) in created_filter2.value + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert str(option1.id) in created2.value # Test is_none_of - filter_create3 = SingleSelectIsNoneOfNotViewFilterItemCreate( - field_id=field.id, - type="single_select", - operator="is_none_of", - value=["Inactive"], + filter_item3 = _make_filter( + field.id, type="single_select", operator="is_none_of", value=["Inactive"] ) - created_filter3 = create_view_filter(user, view, table_fields, filter_create3) - assert created_filter3.type == "single_select_is_none_of" - assert str(option3.id) in created_filter3.value + created3 = create_view_filter(user, view, table_fields, filter_item3) + assert created3.type == "single_select_is_none_of" + 
assert str(option3.id) in created3.value @pytest.mark.django_db def test_all_multiple_select_filters_conversion(data_fixture): - """Test all multiple select filter types can be converted to Baserow filters.""" + """Test all multiple select filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -333,28 +308,25 @@ def test_all_multiple_select_filters_conversion(data_fixture): table_fields = {field.id: field} # Test is_any_of (has) - filter_create = MultipleSelectIsAnyViewFilterItemCreate( - field_id=field.id, + filter_item = _make_filter( + field.id, type="multiple_select", operator="is_any_of", value=["Important", "Urgent"], ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "multiple_select_has" - option_ids = created_filter.value.split(",") + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "multiple_select_has" + option_ids = created.value.split(",") assert str(option1.id) in option_ids assert str(option2.id) in option_ids # Test is_none_of (has_not) - filter_create2 = MultipleSelectIsNoneOfNotViewFilterItemCreate( - field_id=field.id, - type="multiple_select", - operator="is_none_of", - value=["Archived"], + filter_item2 = _make_filter( + field.id, type="multiple_select", operator="is_none_of", value=["Archived"] ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert created_filter2.type == "multiple_select_has_not" - assert str(option3.id) in created_filter2.value + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "multiple_select_has_not" + assert str(option3.id) in created2.value @pytest.mark.django_db @@ -362,7 +334,7 @@ def test_all_multiple_select_filters_conversion(data_fixture): reason="Link row filters have a bug in Baserow (UnboundLocalError in view_filters.py:1301)" ) def 
test_all_link_row_filters_conversion(data_fixture): - """Test all link row filter types can be converted to Baserow filters.""" + """Test all link row filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -374,25 +346,23 @@ def test_all_link_row_filters_conversion(data_fixture): table_fields = {field.id: field} # Test link_row_has - filter_create = LinkRowHasViewFilterItemCreate( - field_id=field.id, type="link_row", operator="has", value=123 - ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "link_row_has" - assert created_filter.value == "123" + filter_item = _make_filter(field.id, type="link_row", operator="has", value=123) + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "link_row_has" + assert created.value == "123" # Test link_row_has_not - filter_create2 = LinkRowHasNotViewFilterItemCreate( - field_id=field.id, type="link_row", operator="has_not", value=456 + filter_item2 = _make_filter( + field.id, type="link_row", operator="has_not", value=456 ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert created_filter2.type == "link_row_has_not" - assert created_filter2.value == "456" + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "link_row_has_not" + assert created2.value == "456" @pytest.mark.django_db def test_all_boolean_filters_conversion(data_fixture): - """Test all boolean filter types can be converted to Baserow filters.""" + """Test all boolean filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -403,109 +373,30 @@ def test_all_boolean_filters_conversion(data_fixture): table_fields = {field.id: field} # Test is true - filter_create = BooleanIsViewFilterItemCreate( - 
field_id=field.id, type="boolean", operator="is", value=True - ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "boolean" - assert created_filter.value == "1" + filter_item = _make_filter(field.id, type="boolean", operator="equal", value=True) + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "equal" + assert created.value == "1" # Test is false - filter_create2 = BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=False - ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert created_filter2.type == "boolean" - assert created_filter2.value == "0" - - -def get_all_concrete_filter_classes(): - """ - Recursively find all concrete ViewFilterItemCreate subclasses. Concrete classes are - those that have specific operators and are meant to be instantiated. - """ - - def get_all_subclasses(cls): - all_subclasses = [] - for subclass in cls.__subclasses__(): - all_subclasses.append(subclass) - all_subclasses.extend(get_all_subclasses(subclass)) - return all_subclasses - - all_subclasses = get_all_subclasses(ViewFilterItemCreate) - - # Filter to only concrete classes (those with specific operators defined as Literal) - # These are the classes that end with "Create" and have a specific operator - concrete_classes = [] - for cls in all_subclasses: - # Check if this class defines a specific operator (has Literal type annotation) - if hasattr(cls, "__annotations__") and "operator" in cls.__annotations__: - annotation = cls.__annotations__["operator"] - # Check if it's a Literal type (concrete operator) - if hasattr(annotation, "__origin__") or "Literal" in str(annotation): - concrete_classes.append(cls) - - return concrete_classes - - -def test_filter_class_discovery(): - """ - Test that the filter class discovery mechanism works correctly. 
This ensures our - introspection logic properly identifies concrete filter classes. - """ - - all_concrete_classes = get_all_concrete_filter_classes() - - # Verify we found a reasonable number of filter classes - # As of now, there should be at least 20+ concrete filter classes - assert len(all_concrete_classes) >= 20, ( - f"Expected at least 20 concrete filter classes, found {len(all_concrete_classes)}. " - f"Classes found: {[cls.__name__ for cls in all_concrete_classes]}" - ) - - # Verify that known concrete classes are discovered - class_names = {cls.__name__ for cls in all_concrete_classes} - expected_classes = { - "TextEqualViewFilterItemCreate", - "NumberEqualsViewFilterItemCreate", - "DateEqualsViewFilterItemCreate", - "BooleanIsViewFilterItemCreate", - "LinkRowHasViewFilterItemCreate", - "SingleSelectIsAnyViewFilterItemCreate", - "MultipleSelectIsAnyViewFilterItemCreate", - } - - missing = expected_classes - class_names - assert not missing, f"Expected classes not found: {missing}" - - # Verify that base/intermediate classes are NOT included - excluded_classes = { - "ViewFilterItemCreate", - "TextViewFilterItemCreate", - "NumberViewFilterItemCreate", - "DateViewFilterItemCreate", - } - - found_excluded = excluded_classes & class_names - assert not found_excluded, ( - f"Base/intermediate classes should not be included: {found_excluded}" - ) + filter_item2 = _make_filter(field.id, type="boolean", operator="equal", value=False) + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "equal" + assert created2.value == "0" @pytest.mark.django_db def test_comprehensive_all_filter_types_conversion(data_fixture): """ - Comprehensive test ensuring ALL filter types can be successfully converted to - Baserow filters with a table containing all supported field types. + Comprehensive test ensuring all filter config types can be successfully + converted to Baserow filters with a table containing all supported field types. 
""" - # Setup user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) database = data_fixture.create_database_application(workspace=workspace) table = data_fixture.create_database_table(database=database, name="All Fields") - # Create all field types text_field = data_fixture.create_text_field(table=table, name="Text", primary=True) number_field = data_fixture.create_number_field(table=table, name="Number") date_field = data_fixture.create_date_field(table=table, name="Date") @@ -513,16 +404,9 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): single_select = data_fixture.create_single_select_field(table=table, name="Status") multi_select = data_fixture.create_multiple_select_field(table=table, name="Tags") - linked_table = data_fixture.create_database_table(database=database, name="Linked") - data_fixture.create_text_field(table=linked_table, name="Linked Text", primary=True) - link_field = data_fixture.create_link_row_field( - table=table, link_row_table=linked_table - ) - data_fixture.create_select_option(field=single_select, value="Active", order=1) data_fixture.create_select_option(field=multi_select, value="Important", order=1) - # Create view and table_fields dict view = data_fixture.create_grid_view(table=table) table_fields = { text_field.id: text_field, @@ -531,82 +415,78 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): boolean_field.id: boolean_field, single_select.id: single_select, multi_select.id: multi_select, - link_field.id: link_field, } - # List of all filter types to test all_filters = [ # Text filters - TextEqualViewFilterItemCreate( - field_id=text_field.id, type="text", operator="equal", value="test" - ), - TextNotEqualViewFilterItemCreate( - field_id=text_field.id, type="text", operator="not_equal", value="test" - ), - TextContainsViewFilterItemCreate( - field_id=text_field.id, type="text", operator="contains", value="test" - ), - TextNotContainsViewFilterItemCreate( - 
field_id=text_field.id, type="text", operator="contains_not", value="test" - ), - TextEmptyViewFilterItemCreate( - field_id=text_field.id, type="text", operator="empty", value="" - ), - TextNotEmptyViewFilterItemCreate( - field_id=text_field.id, type="text", operator="not_empty", value="" - ), + _make_filter(text_field.id, type="text", operator="equal", value="test"), + _make_filter(text_field.id, type="text", operator="not_equal", value="test"), + _make_filter(text_field.id, type="text", operator="contains", value="test"), + _make_filter(text_field.id, type="text", operator="contains_not", value="test"), + _make_filter(text_field.id, type="text", operator="empty", value=""), + _make_filter(text_field.id, type="text", operator="not_empty", value=""), # Number filters - NumberEqualsViewFilterItemCreate( - field_id=number_field.id, type="number", operator="equal", value=42.0 + _make_filter( + number_field.id, type="number", operator="equal", value=42.0, or_equal=False ), - NumberNotEqualsViewFilterItemCreate( - field_id=number_field.id, type="number", operator="not_equal", value=0.0 + _make_filter( + number_field.id, + type="number", + operator="not_equal", + value=0.0, + or_equal=False, ), - NumberHigherThanViewFilterItemCreate( - field_id=number_field.id, + _make_filter( + number_field.id, type="number", operator="higher_than", value=10.0, or_equal=False, ), - NumberLowerThanViewFilterItemCreate( - field_id=number_field.id, + _make_filter( + number_field.id, type="number", operator="lower_than", value=100.0, or_equal=True, ), - NumberEmptyViewFilterItemCreate( - field_id=number_field.id, type="number", operator="empty", value=0.0 + _make_filter( + number_field.id, type="number", operator="empty", value=0.0, or_equal=False ), - NumberNotEmptyViewFilterItemCreate( - field_id=number_field.id, type="number", operator="not_empty", value=0.0 + _make_filter( + number_field.id, + type="number", + operator="not_empty", + value=0.0, + or_equal=False, ), # Date filters - 
DateEqualsViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="equal", - value=Date(year=2024, month=1, day=1), + value="2024-01-01", mode="exact_date", + or_equal=False, ), - DateNotEqualsViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="not_equal", value=None, mode="today", + or_equal=False, ), - DateAfterViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="after", value=7, mode="nr_days_ago", or_equal=False, ), - DateBeforeViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="before", value=None, @@ -614,44 +494,34 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): or_equal=True, ), # Select filters - SingleSelectIsAnyViewFilterItemCreate( - field_id=single_select.id, + _make_filter( + single_select.id, type="single_select", operator="is_any_of", value=["Active"], ), - SingleSelectIsNoneOfNotViewFilterItemCreate( - field_id=single_select.id, + _make_filter( + single_select.id, type="single_select", operator="is_none_of", value=["Active"], ), - MultipleSelectIsAnyViewFilterItemCreate( - field_id=multi_select.id, + _make_filter( + multi_select.id, type="multiple_select", operator="is_any_of", value=["Important"], ), - MultipleSelectIsNoneOfNotViewFilterItemCreate( - field_id=multi_select.id, + _make_filter( + multi_select.id, type="multiple_select", operator="is_none_of", value=["Important"], ), - # Link row filters - LinkRowHasViewFilterItemCreate( - field_id=link_field.id, type="link_row", operator="has", value=1 - ), - LinkRowHasNotViewFilterItemCreate( - field_id=link_field.id, type="link_row", operator="has_not", value=2 - ), # Boolean filter - BooleanIsViewFilterItemCreate( - field_id=boolean_field.id, type="boolean", operator="is", value=True - ), + _make_filter(boolean_field.id, type="boolean", operator="equal", value=True), ] - # Test 
that all filters can be created successfully created_filters = [] for filter_item in all_filters: try: @@ -662,11 +532,9 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): except Exception as e: pytest.fail(f"Failed to create filter {filter_item}: {e}") - # Verify all filters were created in the database assert len(created_filters) == len(all_filters) assert ViewFilter.objects.filter(view=view).count() == len(all_filters) - # Verify each filter type is represented filter_types = set(f.type for f in created_filters) expected_types = { "equal", @@ -676,7 +544,7 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): "empty", "not_empty", "higher_than", - "lower_than", + "lower_than_or_equal", "date_is", "date_is_not", "date_is_after", @@ -685,29 +553,6 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): "single_select_is_none_of", "multiple_select_has", "multiple_select_has_not", - "link_row_has", - "link_row_has_not", - "boolean", + "equal", # for boolean field } assert filter_types == expected_types - - # CRITICAL CHECK: Ensure all concrete filter classes are tested - all_concrete_classes = get_all_concrete_filter_classes() - tested_classes = {type(filter_item) for filter_item in all_filters} - - missing_classes = set(all_concrete_classes) - tested_classes - if missing_classes: - missing_names = [cls.__name__ for cls in missing_classes] - pytest.fail( - f"The following filter classes are not tested: {', '.join(missing_names)}. " - f"Please add test instances for these classes to the all_filters list." - ) - - # Ensure we're not testing non-existent classes - extra_classes = tested_classes - set(all_concrete_classes) - if extra_classes: - extra_names = [cls.__name__ for cls in extra_classes] - pytest.fail( - f"The following classes in the test don't exist as concrete filter classes: " - f"{', '.join(extra_names)}. Please remove them from the test." 
- ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py index bf3e940d08..4ac15bd2e7 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py @@ -1,91 +1,21 @@ -from unittest.mock import Mock - import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.views.models import View, ViewFilter from baserow_enterprise.assistant.tools.database.tools import ( - get_list_views_tool, - get_views_tool_factory, + create_view_filters, + create_views, + list_views, ) from baserow_enterprise.assistant.tools.database.types import ( - BooleanIsViewFilterItemCreate, - CalendarViewItemCreate, - DateAfterViewFilterItemCreate, - DateBeforeViewFilterItemCreate, - DateEqualsViewFilterItemCreate, - DateNotEqualsViewFilterItemCreate, FormFieldOption, - FormViewItemCreate, - GalleryViewItemCreate, - GridViewItemCreate, - KanbanViewItemCreate, - MultipleSelectIsAnyViewFilterItemCreate, - MultipleSelectIsNoneOfNotViewFilterItemCreate, - NumberEqualsViewFilterItemCreate, - NumberHigherThanViewFilterItemCreate, - NumberLowerThanViewFilterItemCreate, - NumberNotEqualsViewFilterItemCreate, - SingleSelectIsAnyViewFilterItemCreate, - SingleSelectIsNoneOfNotViewFilterItemCreate, - TextContainsViewFilterItemCreate, - TextEqualViewFilterItemCreate, - TextNotContainsViewFilterItemCreate, - TextNotEqualViewFilterItemCreate, - TimelineViewItemCreate, + ViewItemCreate, ) -from baserow_enterprise.assistant.tools.database.types.base import Date from baserow_enterprise.assistant.tools.database.types.view_filters import ( + ViewFilterItemCreate, ViewFiltersArgs, ) -from .utils import fake_tool_helpers - - -def get_create_views_tool(user, 
workspace): - """Helper to get the create_views tool from the factory""" - - factory = get_views_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_views_tool = next( - (tool for tool in added_tools if tool.name == "create_views"), None - ) - assert create_views_tool is not None - return create_views_tool - - -def get_create_view_filters_tool(user, workspace): - """Helper to get the create_view_filters tool from the factory""" - - factory = get_views_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_filters_tool = next( - (tool for tool in added_tools if tool.name == "create_view_filters"), None - ) - assert create_filters_tool is not None - return create_filters_tool +from .utils import make_test_ctx @pytest.mark.django_db @@ -96,23 +26,23 @@ def test_list_views_tool(data_fixture): table = data_fixture.create_database_table(database=database) view = data_fixture.create_grid_view(table=table, name="View 1", order=1) - tool = get_list_views_tool(user, workspace, fake_tool_helpers) - response = tool(table_id=table.id) + ctx = make_test_ctx(user, workspace) + response = list_views(ctx, thought="test", table_id=table.id) assert response == { "views": [ { "id": view.id, "name": "View 1", + "public": False, "type": "grid", "row_height": "small", - "public": False, } ] } view_2 = 
data_fixture.create_grid_view(table=table, name="View 2", order=2) - response = tool(table_id=table.id) + response = list_views(ctx, thought="test", table_id=table.id) assert len(response["views"]) == 2 assert response["views"][0]["name"] == "View 1" assert response["views"][1]["name"] == "View 2" @@ -125,12 +55,17 @@ def test_create_grid_view(data_fixture): database = data_fixture.create_database_application(workspace=workspace) table = data_fixture.create_database_table(database=database) - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - GridViewItemCreate( - type="grid", name="Grid View", public=False, row_height="medium" + ViewItemCreate( + name="Grid View", + public=False, + type="grid", + row_height="medium", ) ], ) @@ -148,14 +83,16 @@ def test_create_kanban_view(data_fixture): table = data_fixture.create_database_table(database=database) single_select = data_fixture.create_single_select_field(table=table, name="Status") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - KanbanViewItemCreate( - type="kanban", + ViewItemCreate( name="Kanban View", public=False, + type="kanban", column_field_id=single_select.id, ) ], @@ -174,14 +111,16 @@ def test_create_calendar_view(data_fixture): table = data_fixture.create_database_table(database=database) date_field = data_fixture.create_date_field(table=table, name="Date") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - CalendarViewItemCreate( - type="calendar", + ViewItemCreate( name="Calendar View", public=False, + type="calendar", date_field_id=date_field.id, ) ], @@ -200,14 +139,16 @@ def 
test_create_gallery_view(data_fixture): table = data_fixture.create_database_table(database=database) file_field = data_fixture.create_file_field(table=table, name="Files") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - GalleryViewItemCreate( - type="gallery", + ViewItemCreate( name="Gallery View", public=False, + type="gallery", cover_field_id=file_field.id, ) ], @@ -227,14 +168,16 @@ def test_create_timeline_view(data_fixture): start_date = data_fixture.create_date_field(table=table, name="Start Date") end_date = data_fixture.create_date_field(table=table, name="End Date") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - TimelineViewItemCreate( - type="timeline", + ViewItemCreate( name="Timeline View", public=False, + type="timeline", start_date_field_id=start_date.id, end_date_field_id=end_date.id, ) @@ -254,14 +197,16 @@ def test_create_form_view(data_fixture): table = data_fixture.create_database_table(database=database) field = data_fixture.create_text_field(table=table, name="Name", primary=True) - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - FormViewItemCreate( - type="form", + ViewItemCreate( name="Form View", public=True, + type="form", title="Contact Form", description="Fill out this form", submit_button_label="Submit", @@ -297,18 +242,23 @@ def test_create_text_equal_filter(data_fixture): field = data_fixture.create_text_field(table=table, name="Name") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + 
response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="equal", value="test" + ViewFilterItemCreate( + field_id=field.id, + type="text", + operator="equal", + value="test", ) ], ) - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -326,13 +276,15 @@ def test_create_text_not_equal_filter(data_fixture): field = data_fixture.create_text_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextNotEqualViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="text", operator="not_equal", @@ -340,7 +292,7 @@ def test_create_text_not_equal_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -356,13 +308,15 @@ def test_create_text_contains_filter(data_fixture): field = data_fixture.create_text_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextContainsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="text", operator="contains", @@ -370,7 +324,7 @@ def test_create_text_contains_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -386,13 +340,15 @@ def test_create_text_not_contains_filter(data_fixture): field = data_fixture.create_text_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = 
make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextNotContainsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="text", operator="contains_not", @@ -400,7 +356,7 @@ def test_create_text_not_contains_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -419,18 +375,24 @@ def test_create_number_equal_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="equal", value=42.0 + ViewFilterItemCreate( + field_id=field.id, + type="number", + operator="equal", + value=42.0, + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -446,21 +408,24 @@ def test_create_number_not_equal_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberNotEqualsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="number", operator="not_equal", value=42.0, + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -476,13 +441,15 @@ def test_create_number_higher_than_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - 
response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberHigherThanViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="number", operator="higher_than", @@ -509,13 +476,15 @@ def test_create_number_lower_than_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberLowerThanViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="number", operator="lower_than", @@ -524,7 +493,7 @@ def test_create_number_lower_than_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -541,22 +510,25 @@ def test_create_date_equal_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateEqualsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="equal", - value=Date(year=2024, month=1, day=15), + value="2024-01-15", mode="exact_date", + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -572,22 +544,25 @@ def test_create_date_not_equal_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response 
= create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateNotEqualsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="not_equal", value=None, mode="today", + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -605,13 +580,15 @@ def test_create_date_after_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateAfterViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="after", @@ -621,7 +598,7 @@ def test_create_date_after_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -639,13 +616,15 @@ def test_create_date_before_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateBeforeViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="before", @@ -655,7 +634,7 @@ def test_create_date_before_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -676,13 +655,15 @@ def test_create_single_select_is_any_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Option 2") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = 
create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - SingleSelectIsAnyViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="single_select", operator="is_any_of", @@ -690,7 +671,7 @@ def test_create_single_select_is_any_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -709,13 +690,15 @@ def test_create_single_select_is_none_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Bad Option") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - SingleSelectIsNoneOfNotViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="single_select", operator="is_none_of", @@ -723,7 +706,7 @@ def test_create_single_select_is_none_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -742,22 +725,27 @@ def test_create_boolean_is_true_filter(data_fixture): field = data_fixture.create_boolean_field(table=table, name="Active") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=True + ViewFilterItemCreate( + field_id=field.id, + type="boolean", + operator="equal", + value=True, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 - assert ViewFilter.objects.filter(view=view, field=field, type="boolean").exists() + assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() 
@pytest.mark.django_db @@ -769,22 +757,27 @@ def test_create_boolean_is_false_filter(data_fixture): field = data_fixture.create_boolean_field(table=table, name="Active") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=False + ViewFilterItemCreate( + field_id=field.id, + type="boolean", + operator="equal", + value=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 - assert ViewFilter.objects.filter(view=view, field=field, type="boolean").exists() + assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() # Multiple select filter tests @@ -799,13 +792,15 @@ def test_create_multiple_select_is_any_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Tag 2") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - MultipleSelectIsAnyViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="multiple_select", operator="is_any_of", @@ -813,7 +808,7 @@ def test_create_multiple_select_is_any_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -832,13 +827,15 @@ def test_create_multiple_select_is_none_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Bad Tag") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = 
create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - MultipleSelectIsNoneOfNotViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="multiple_select", operator="is_none_of", @@ -846,7 +843,7 @@ def test_create_multiple_select_is_none_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 assert ViewFilter.objects.filter( diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py new file mode 100644 index 0000000000..a1f757bd3c --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +import pytest + +from baserow_enterprise.assistant.tools.navigation.tools import navigate +from baserow_enterprise.assistant.tools.navigation.types import ( + TableNavigationRequestType, +) + +from .utils import make_test_ctx + + +@pytest.mark.django_db +def test_navigate_to_table(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + + navigate_mock = MagicMock(return_value="Navigated successfully.") + ctx = make_test_ctx(user, workspace) + ctx.deps.tool_helpers.navigate_to = navigate_mock + + request = TableNavigationRequestType(type="database-table", table_id=table.id) + result = navigate(ctx, request, thought="go to tasks table") + + assert result == "Navigated successfully." 
+ navigate_mock.assert_called_once() + location = navigate_mock.call_args[0][0] + assert location.type == "database-table" + assert location.table_id == table.id + assert location.database_id == database.id + assert location.table_name == "Tasks" + + +@pytest.mark.django_db +def test_navigate_to_nonexistent_table(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + + request = TableNavigationRequestType(type="database-table", table_id=999999) + result = navigate(ctx, request, thought="go to missing table") + + assert "Error" in result + assert "not found" in result diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py new file mode 100644 index 0000000000..dcc77492f9 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py @@ -0,0 +1,104 @@ +import os +from unittest.mock import patch + +import pytest + +from baserow_enterprise.assistant.tools.search_user_docs.tools import ( + _TOOL_QUERY_RE, + search_user_docs, +) + +from .utils import make_test_ctx + +# search_user_docs is async, so we need this to allow sync ORM calls from +# data_fixture inside async tests. 
+os.environ.setdefault("DJANGO_ALLOW_ASYNC_UNSAFE", "true") + + +class TestToolQueryGuard: + """Tests for the tool-introspection regex guard.""" + + @pytest.mark.parametrize( + "query", + [ + "list_tables", + "create_fields", + "get_tables_schema", + "update_rows", + "delete_rows", + "generate_formula", + "create_view_filters", + "search_user_docs", + "navigate tool parameters", + ], + ) + def test_rejects_tool_introspection_queries(self, query): + assert _TOOL_QUERY_RE.search(query) is not None + + @pytest.mark.parametrize( + "query", + [ + "How to create a webhook in Baserow", + "How to link tables in Baserow", + "Baserow form view", + "How do I import data into Baserow", + ], + ) + def test_allows_legitimate_queries(self, query): + assert _TOOL_QUERY_RE.search(query) is None + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_search_user_docs_rejects_tool_introspection(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + ctx = make_test_ctx(user, workspace) + + result = await search_user_docs( + ctx, question="list_tables", thought="looking up tool" + ) + + assert result["reliability"] == 0.0 + assert "REJECTED" in result["reliability_note"] + assert result["sources"] == [] + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_search_user_docs_handles_empty_results(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + ctx = make_test_ctx(user, workspace) + + with patch( + "baserow_enterprise.assistant.tools.search_user_docs.tools.KnowledgeBaseHandler" + ) as mock_handler_cls: + mock_handler_cls.return_value.search.return_value = [] + + result = await search_user_docs( + ctx, question="How to use webhooks in Baserow", thought="user asks" + ) + + assert result["reliability"] == 0.0 + assert "Nothing found" in result["answer"] + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def 
test_search_user_docs_handles_error(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + ctx = make_test_ctx(user, workspace) + + with patch( + "baserow_enterprise.assistant.tools.search_user_docs.tools.KnowledgeBaseHandler" + ) as mock_handler_cls: + mock_handler_cls.return_value.search.side_effect = RuntimeError("db error") + + result = await search_user_docs( + ctx, question="How to use webhooks", thought="user asks" + ) + + assert result["reliability"] == 0.0 + assert "error" in result["answer"].lower() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py new file mode 100644 index 0000000000..c83caf73ef --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py @@ -0,0 +1,472 @@ +"""Unit tests for RetryingModel.""" + +import pytest + +from baserow_enterprise.assistant.retrying_model import ( + RetryingModel, + _is_transient_provider_error, +) + + +class TestIsTransientProviderError: + def test_groq_parse_error(self): + exc = Exception("Failed to parse tool call arguments as JSON") + assert _is_transient_provider_error(exc) is True + + def test_tool_validation_failed(self): + exc = Exception("Tool call validation failed: something") + assert _is_transient_provider_error(exc) is True + + def test_auth_error_not_retryable(self): + exc = Exception("Invalid API key") + assert _is_transient_provider_error(exc) is False + + def test_generic_error_not_retryable(self): + exc = ValueError("something went wrong") + assert _is_transient_provider_error(exc) is False + + +def _make_retrying(inner_mock, **kwargs): + """Create a RetryingModel with a pre-resolved mock as the wrapped model.""" + + model = RetryingModel.__new__(RetryingModel) + model._wrapped_or_name = inner_mock + model._resolved = inner_mock + model.max_attempts = 
kwargs.get("max_attempts", 3) + model.base_delay = kwargs.get("base_delay", 0.01) + model.max_delay = kwargs.get("max_delay", 10.0) + return model + + +@pytest.mark.asyncio +async def test_request_retries_on_transient_error(): + """RetryingModel.request should retry transient errors.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.messages import ModelResponse, TextPart + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + response = ModelResponse(parts=[TextPart(content="hello")]) + inner.request = AsyncMock( + side_effect=[ + Exception("Failed to parse tool call arguments as JSON"), + response, + ] + ) + + model = _make_retrying(inner) + result = await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + assert result == response + assert inner.request.call_count == 2 + + +@pytest.mark.asyncio +async def test_request_raises_non_transient_error(): + """RetryingModel.request should not retry non-transient errors.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + inner.request = AsyncMock(side_effect=ValueError("bad input")) + + model = _make_retrying(inner) + with pytest.raises(ValueError, match="bad input"): + await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + assert inner.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_request_exhausts_retries(): + """RetryingModel should raise after exhausting max_attempts.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + inner.request = AsyncMock( + side_effect=Exception("Failed to parse tool call arguments as JSON") + ) + + model = _make_retrying(inner, max_attempts=2) + with pytest.raises(Exception, match="Failed to parse"): + await model.request( + [], None, 
ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + assert inner.request.call_count == 2 + + +def test_deferred_model_resolution(): + """RetryingModel should defer infer_model until first access.""" + + model = RetryingModel("groq:some-model") + # Should not raise at construction time + assert model._resolved is None + + +# --------------------------------------------------------------------------- +# tool_use_failed recovery +# --------------------------------------------------------------------------- + + +def _make_tool_use_failed_error( + failed_generation: str, + model_name: str = "test-model", +): + from pydantic_ai.exceptions import ModelHTTPError + + return ModelHTTPError( + status_code=400, + model_name=model_name, + body={ + "error": { + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": failed_generation, + } + }, + ) + + +class TestTryRecoverToolUseFailed: + def test_recovers_valid_tool_call(self): + from pydantic_ai.messages import ToolCallPart + + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = _make_tool_use_failed_error( + '{"name": "list_tables", "arguments": {"thought": "test"}}' + ) + result = _try_recover_tool_use_failed(exc) + + assert result is not None + assert len(result.parts) == 1 + part = result.parts[0] + assert isinstance(part, ToolCallPart) + assert part.tool_name == "list_tables" + assert "thought" in part.args + + def test_recovers_malformed_json_as_tool_call(self): + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = _make_tool_use_failed_error("{not valid json") + result = _try_recover_tool_use_failed(exc) + + assert result is not None + assert len(result.parts) == 1 + from pydantic_ai.messages import ToolCallPart + + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "unknown" + 
assert result.parts[0].args == "{}" + + def test_recovers_malformed_json_extracts_tool_name(self): + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = _make_tool_use_failed_error( + '{"name": "create_elements", "arguments": {"page_id": 1, "elements": [truncated' + ) + result = _try_recover_tool_use_failed(exc) + + assert result is not None + from pydantic_ai.messages import ToolCallPart + + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "create_elements" + assert result.parts[0].args == "{}" + + def test_returns_none_for_non_tool_use_failed(self): + from pydantic_ai.exceptions import ModelHTTPError + + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = ModelHTTPError( + status_code=400, + model_name="test", + body={"error": {"message": "other error", "code": "other"}}, + ) + assert _try_recover_tool_use_failed(exc) is None + + def test_returns_none_for_non_model_http_error(self): + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + assert _try_recover_tool_use_failed(ValueError("nope")) is None + + def test_recovers_raw_api_error_with_body(self): + """Handles raw provider APIError (e.g. 
groq.APIError) with body attr.""" + + from pydantic_ai.messages import ToolCallPart + + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + class FakeAPIError(Exception): + def __init__(self, message, body=None): + super().__init__(message) + self.body = body + + exc = FakeAPIError( + "Failed to parse tool call arguments as JSON", + body={ + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": '{"name": "create_rows", "arguments": {"table_id": 1}}', + }, + ) + result = _try_recover_tool_use_failed(exc) + assert result is not None + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "create_rows" + + +@pytest.mark.asyncio +async def test_request_recovers_tool_use_failed(): + """request() should recover tool_use_failed into a ModelResponse.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + inner.request = AsyncMock( + side_effect=_make_tool_use_failed_error( + '{"name": "create_rows", "arguments": {"thought": "hi"}}' + ) + ) + + model = _make_retrying(inner) + result = await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + from pydantic_ai.messages import ToolCallPart + + # Should return a recovered response, not raise + assert len(result.parts) == 1 + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "create_rows" + # Should NOT have retried — recovery is immediate + assert inner.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_request_stream_recovers_tool_use_failed(): + """request_stream() should recover tool_use_failed into a PreFetchedResponse.""" + + from unittest.mock import MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + + async def _failing_stream(*args, 
**kwargs): + raise _make_tool_use_failed_error( + '{"name": "list_rows", "arguments": {"table_id": 1}}' + ) + + # Make request_stream an async context manager that raises + from contextlib import asynccontextmanager + + @asynccontextmanager + async def failing_cm(*args, **kwargs): + raise _make_tool_use_failed_error( + '{"name": "list_rows", "arguments": {"table_id": 1}}' + ) + yield # pragma: no cover + + inner.request_stream = failing_cm + + model = _make_retrying(inner) + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + # Collect events from the pre-fetched response + events = [e async for e in stream] + + from pydantic_ai.models import PartStartEvent + + start_events = [e for e in events if isinstance(e, PartStartEvent)] + assert len(start_events) == 1 + assert start_events[0].part.tool_name == "list_rows" + + +@pytest.mark.asyncio +async def test_request_stream_recovers_mid_stream_api_error(): + """_ErrorRecoveringStream catches APIError during chunk iteration + and emits recovery events instead of crashing.""" + + from contextlib import asynccontextmanager + from unittest.mock import MagicMock + + from pydantic_ai._parts_manager import ModelResponsePartsManager + from pydantic_ai.models import ModelRequestParameters, PartStartEvent + + # Simulate a real StreamedResponse whose _get_event_iterator raises APIError + class FakeAPIError(Exception): + """Simulates groq.APIError with a body attribute.""" + + def __init__(self, message, body=None): + super().__init__(message) + self.body = body + + class FakeStreamedResponse: + """Minimal fake that raises during iteration.""" + + model_name = "test-model" + provider_name = "test" + provider_url = "http://test" + timestamp = None + _parts_manager = ModelResponsePartsManager() + model_request_parameters = ModelRequestParameters( + function_tools=[], output_tools=[] + ) + final_result_event = None + provider_response_id = None + provider_details = 
None + finish_reason = None + + async def _get_event_iterator(self): + raise FakeAPIError( + "Failed to parse tool call arguments as JSON", + body={ + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": '{"name": "create_elements", "arguments": {"bad": true}}', + }, + ) + yield # pragma: no cover — make it a generator + + inner = MagicMock() + + @asynccontextmanager + async def fake_request_stream(*args, **kwargs): + yield FakeStreamedResponse() + + inner.request_stream = fake_request_stream + + model = _make_retrying(inner) + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + events = [e async for e in stream] + + start_events = [e for e in events if isinstance(e, PartStartEvent)] + assert len(start_events) == 1 + assert start_events[0].part.tool_name == "create_elements" + + +@pytest.mark.asyncio +async def test_request_stream_recovers_mid_stream_malformed_json(): + """_ErrorRecoveringStream recovers even when failed_generation JSON is + unparseable — returns a ToolCallPart with empty args so pydantic-ai's + validation loop can retry.""" + + from contextlib import asynccontextmanager + from unittest.mock import MagicMock + + from pydantic_ai._parts_manager import ModelResponsePartsManager + from pydantic_ai.messages import ToolCallPart + from pydantic_ai.models import ModelRequestParameters, PartStartEvent + + class FakeAPIError(Exception): + def __init__(self, message, body=None): + super().__init__(message) + self.body = body + + class FakeStreamedResponse: + model_name = "test-model" + provider_name = "test" + provider_url = "http://test" + timestamp = None + _parts_manager = ModelResponsePartsManager() + model_request_parameters = ModelRequestParameters( + function_tools=[], output_tools=[] + ) + final_result_event = None + provider_response_id = None + provider_details = None + finish_reason = None 
+ + async def _get_event_iterator(self): + raise FakeAPIError( + "Failed to parse tool call arguments as JSON", + body={ + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": '{"name": "create_elements", "arguments": {truncated', + }, + ) + yield # pragma: no cover + + inner = MagicMock() + + @asynccontextmanager + async def fake_request_stream(*args, **kwargs): + yield FakeStreamedResponse() + + inner.request_stream = fake_request_stream + + model = _make_retrying(inner) + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + events = [e async for e in stream] + + start_events = [e for e in events if isinstance(e, PartStartEvent)] + assert len(start_events) == 1 + assert isinstance(start_events[0].part, ToolCallPart) + assert start_events[0].part.tool_name == "create_elements" + assert start_events[0].part.args == "{}" + + +@pytest.mark.asyncio +async def test_request_stream_reraises_after_yield(): + """Errors during stream __aexit__ (non-recoverable) must re-raise.""" + + from contextlib import asynccontextmanager + from unittest.mock import MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + + @asynccontextmanager + async def stream_that_fails_during_consumption(*args, **kwargs): + # Yield a mock stream, then raise on __aexit__ + yield MagicMock() + raise Exception("some unrelated error") + + inner.request_stream = stream_that_fails_during_consumption + + model = _make_retrying(inner) + with pytest.raises(Exception, match="some unrelated error"): + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + pass # stream consumed, then __aexit__ raises diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py 
b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py index 4e0b8b095a..d1a0835c91 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py @@ -1,10 +1,17 @@ +import json from unittest.mock import MagicMock, patch import pytest -import udspy from baserow_enterprise.assistant.models import AssistantChat -from baserow_enterprise.assistant.telemetry import PosthogTracingCallback +from baserow_enterprise.assistant.telemetry import ( + PosthogSpanProcessor, + PosthogTracingCallback, + _pydantic_messages_to_posthog, + _tool_calls, + _trace_ctx, + _TraceContext, +) @pytest.fixture @@ -16,18 +23,6 @@ def assistant_chat_fixture(enterprise_data_fixture): ) -@pytest.fixture(autouse=True) -def mock_posthog_openai(): - with ( - udspy.settings.context(lm=udspy.LM(model="fake-model")), - patch("posthog.ai.openai.AsyncOpenAI") as mock, - ): - # Configure the mock if needed - mock.return_value = MagicMock() - mock.return_value.model = "test-model" - yield mock - - @pytest.mark.django_db class TestPosthogTracingCallback: @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") @@ -89,209 +84,563 @@ def test_trace_context_manager_exception( assert call_args.kwargs["properties"]["$ai_is_error"] is True @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_module_start_end(self, mock_get_client, assistant_chat_fixture): - """Test module execution tracing.""" + def test_trace_with_output(self, mock_get_client, assistant_chat_fixture): + """Test that trace output is captured when set.""" mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog callback = PosthogTracingCallback() - # Initialize context manually - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = 
str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-123" - callback.span_ids = ["root-span"] - callback.spans = {} - callback.enabled = True - - # Mock a CoT module - mock_module = MagicMock(spec=udspy.ChainOfThought) - mock_module.__class__ = udspy.ChainOfThought - mock_signature = MagicMock() - mock_signature.get_input_fields.return_value = {"q": 1} - mock_signature.get_output_fields.return_value = { - "a": 1 - } # Should be dict, not list - mock_signature.get_instructions.return_value = "Test instructions" - mock_module.original_signature = mock_signature - - # Start module - callback.on_module_start( - call_id="call-1", instance=mock_module, inputs={"kwargs": {"q": "test"}} - ) - assert len(callback.span_ids) == 2 - assert len(callback.spans) == 1 + with callback.trace(assistant_chat_fixture, "Hello"): + callback.set_trace_output("The answer is 42") + + call_args = mock_posthog.capture.call_args + props = call_args.kwargs["properties"] + assert props["$ai_output_state"] == {"answer": "The answer is 42"} + + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_trace_sets_and_clears_context_var( + self, mock_get_client, assistant_chat_fixture + ): + """Test that the ContextVar is set inside the trace and cleared after.""" + + mock_get_client.return_value = MagicMock() + + callback = PosthogTracingCallback() + + # Before trace, context should be None + assert _trace_ctx.get() is None + + with callback.trace(assistant_chat_fixture, "Hello"): + ctx = _trace_ctx.get() + assert ctx is not None + assert ctx.trace_id == callback.trace_id + assert ctx.user_id == str(assistant_chat_fixture.user_id) + assert ctx.workspace_id == str(assistant_chat_fixture.workspace_id) + assert ctx.chat_uuid == str(assistant_chat_fixture.uuid) + + # After trace, context should be cleared + assert _trace_ctx.get() is None + + +class TestPydanticMessagesToPosthog: + """Test the message format conversion utility.""" + + def test_convert_text_message(self): + 
"""Test converting a simple text message.""" + + messages = [{"role": "user", "parts": [{"type": "text", "content": "Hello"}]}] + result = _pydantic_messages_to_posthog(messages) + + assert len(result) == 1 + assert result[0]["role"] == "user" + assert result[0]["content"] == [{"type": "text", "text": "Hello"}] + + def test_convert_tool_call(self): + """Test converting a tool call message.""" + + messages = [ + { + "role": "assistant", + "parts": [ + { + "type": "tool_call", + "id": "call_123", + "name": "list_tables", + "arguments": {"database_id": 1}, + } + ], + } + ] + result = _pydantic_messages_to_posthog(messages) + + assert result[0]["role"] == "assistant" + tc = result[0]["content"][0] + assert tc["type"] == "tool_call" + assert tc["tool_call_id"] == "call_123" + assert tc["name"] == "list_tables" + assert tc["arguments"] == {"database_id": 1} + + def test_convert_tool_return(self): + """Test converting a tool return message.""" + + messages = [ + { + "role": "tool", + "parts": [ + { + "type": "tool_return", + "tool_call_id": "call_123", + "content": "Tables: Users, Orders", + } + ], + } + ] + result = _pydantic_messages_to_posthog(messages) + + assert result[0]["content"][0]["type"] == "tool_result" + assert result[0]["content"][0]["tool_call_id"] == "call_123" + + +class TestPosthogSpanProcessor: + """Test the OpenTelemetry span processor for PostHog.""" + + def _make_mock_span( + self, + name, + kind, + attrs=None, + start_time=None, + end_time=None, + parent_span_id=None, + span_id=0x1234, + ): + """Create a mock ReadableSpan.""" + + span = MagicMock() + span.name = name + span.kind = kind + span.attributes = attrs or {} + span.start_time = start_time or 1000000000 # 1s in ns + span.end_time = end_time or 2000000000 # 2s in ns + span.events = [] + + # Context + span.context = MagicMock() + span.context.span_id = span_id + + # Parent + if parent_span_id is not None: + span.parent = MagicMock() + span.parent.span_id = parent_span_id + else: + span.parent 
= None + + return span - # End module - callback.on_module_end( - call_id="call-1", outputs={"a": "result"}, exception=None + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_generation_span(self, mock_get_client): + """Test that a 'chat' span is mapped to $ai_generation.""" + + from opentelemetry.trace import SpanKind + + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="chat groq:llama-3.3-70b", + kind=SpanKind.CLIENT, + attrs={ + "gen_ai.request.model": "llama-3.3-70b", + "gen_ai.response.model": "llama-3.3-70b", + "gen_ai.provider.name": "groq", + "gen_ai.usage.input_tokens": 100, + "gen_ai.usage.output_tokens": 50, + "gen_ai.input.messages": json.dumps( + [{"role": "user", "parts": [{"type": "text", "content": "Hi"}]}] + ), + "gen_ai.output.messages": json.dumps( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Hello!"}], + } + ] + ), + }, + parent_span_id=0xABCD, ) - assert len(callback.span_ids) == 1 - assert len(callback.spans) == 0 + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) - # Verify span event was called mock_posthog.capture.assert_called_once() - call_args = mock_posthog.capture.call_args - - # Check the event structure - assert call_args.kwargs["distinct_id"] == str(assistant_chat_fixture.user_id) - assert call_args.kwargs["event"] == "$ai_span" - assert "timestamp" in call_args.kwargs + call = mock_posthog.capture.call_args + assert call.kwargs["distinct_id"] == "user-456" + assert call.kwargs["event"] == "$ai_generation" - # Check properties - props = call_args.kwargs["properties"] + props = call.kwargs["properties"] assert props["$ai_trace_id"] == "trace-123" - assert props["$ai_session_id"] == 
str(assistant_chat_fixture.uuid) - assert props["workspace_id"] == str(assistant_chat_fixture.workspace_id) - assert props["$ai_span_name"] == "ChainOfThought" - assert props["$ai_span_id"] == "call-1" - assert props["$ai_parent_span_id"] == "root-span" - assert "$ai_input_state" in props - assert props["$ai_output_state"] == {"a": "result"} - assert props["$ai_latency"] >= 0 - assert props["$ai_is_error"] is False + assert props["$ai_session_id"] == "chat-abc" + assert props["workspace_id"] == "ws-789" + assert props["$ai_model"] == "llama-3.3-70b" + assert props["$ai_provider"] == "groq" + assert props["$ai_input_tokens"] == 100 + assert props["$ai_output_tokens"] == 50 + assert props["$ai_latency"] == pytest.approx(1.0, abs=0.01) + assert props["$ai_parent_id"] == f"{0xABCD:016x}" + + # Check message format conversion + assert len(props["$ai_input"]) == 1 + assert props["$ai_input"][0]["role"] == "user" + assert props["$ai_input"][0]["content"][0]["text"] == "Hi" - def test_on_lm_start(self, assistant_chat_fixture): - """Test LM start tracing.""" + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_tool_span(self, mock_get_client): + """Test that a 'running tool' span is mapped to $ai_span.""" - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = "user-1" - callback.workspace_id = "ws-1" - callback.chat_uuid = "chat-1" - callback.trace_id = "trace-1" - callback.span_ids = ["root"] + from opentelemetry.trace import SpanKind - mock_lm = MagicMock() - mock_lm.provider = "openai" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog - inputs = {"kwargs": {}} - callback.on_lm_start("call-1", mock_lm, inputs) + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="running tool", + kind=SpanKind.INTERNAL, + attrs={ + "gen_ai.tool.name": "list_tables", + "tool_arguments": '{"database_id": 1}', + "tool_response": "Found 3 tables: Users, Orders, Products", 
+ }, + parent_span_id=0x5678, + ) + + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) - assert len(callback.span_ids) == 2 - assert inputs["kwargs"]["posthog_distinct_id"] == "user-1" - assert inputs["kwargs"]["posthog_trace_id"] == "trace-1" - assert inputs["kwargs"]["posthog_properties"]["$ai_provider"] == "openai" + mock_posthog.capture.assert_called_once() + call = mock_posthog.capture.call_args + assert call.kwargs["event"] == "$ai_span" + + props = call.kwargs["properties"] + assert props["$ai_span_name"] == "Tool: list_tables" + assert props["$ai_input_state"] == {"database_id": 1} + assert "Found 3 tables" in props["$ai_output_state"] + assert props["$ai_latency"] == pytest.approx(1.0, abs=0.01) @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_tool_start_end(self, mock_get_client, assistant_chat_fixture): - """Test tool execution tracing.""" + def test_agent_run_span_is_exported(self, mock_get_client): + """Test that 'agent run' spans are exported as $ai_span with agent name, + system prompt, user input, and final output.""" + + from opentelemetry.trace import SpanKind mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-1" - callback.span_ids = ["root"] - callback.spans = {} - callback.enabled = True - - mock_tool = MagicMock() - mock_tool.name = "test_tool" - - # Start tool - callback.on_tool_start( - call_id="call-1", instance=mock_tool, inputs={"arg": "val"} + processor = PosthogSpanProcessor() + + system_instructions = json.dumps( + [{"type": "text", 
"content": "You are a helpful assistant."}] + ) + all_messages = json.dumps( + [ + { + "role": "user", + "parts": [{"type": "text", "content": "Create a table"}], + }, + { + "role": "model-response", + "parts": [{"type": "text", "content": "Done!"}], + }, + ] + ) + + span = self._make_mock_span( + name="agent run", + kind=SpanKind.INTERNAL, + attrs={ + "agent_name": "main_agent", + "gen_ai.system_instructions": system_instructions, + "pydantic_ai.all_messages": all_messages, + "final_result": '{"table_id": 1}', + }, + parent_span_id=0x9999, + ) + + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) + + mock_posthog.capture.assert_called_once() + call = mock_posthog.capture.call_args + assert call.kwargs["event"] == "$ai_span" + + props = call.kwargs["properties"] + assert props["$ai_span_name"] == "Agent: main_agent" + assert props["$ai_trace_id"] == "trace-123" + assert props["$ai_latency"] == pytest.approx(1.0, abs=0.01) + assert props["$ai_parent_id"] == f"{0x9999:016x}" + assert ( + props["$ai_input_state"]["system_prompt"] == "You are a helpful assistant." 
) + assert props["$ai_input_state"]["user_prompt"] == "Create a table" + assert props["$ai_output_state"] == {"table_id": 1} - assert len(callback.spans) == 1 + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_agent_run_span_subagent_label(self, mock_get_client): + """Test that sub-agent spans get their own distinct label and handle + string final_result.""" + + from opentelemetry.trace import SpanKind - # End tool - callback.on_tool_end(call_id="call-1", outputs="result", exception=None) + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="agent run", + kind=SpanKind.INTERNAL, + attrs={ + "agent_name": "sample_row_agent", + "final_result": "Rows created successfully", + }, + ) + + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) - # Verify event - mock_posthog.capture.assert_called() props = mock_posthog.capture.call_args.kwargs["properties"] - assert props["$ai_span_name"] == "Tool: test_tool" - assert props["$ai_input_state"] == {"arg": "val"} - assert props["$ai_output_state"] == "result" + assert props["$ai_span_name"] == "Agent: sample_row_agent" + assert props["$ai_output_state"] == "Rows created successfully" @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_module_end_with_exception( - self, mock_get_client, assistant_chat_fixture - ): - """Test that exception string is captured in $ai_output_state.""" + def test_running_tools_skipped_and_parent_remapped(self, mock_get_client): + """Test that 'running tools' is not emitted and child tool spans + have their parent remapped to the grandparent (agent span).""" + + from opentelemetry.trace import SpanKind mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog - 
callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-123" - callback.span_ids = ["root-span"] - callback.spans = {} - callback.enabled = True - - # Mock a module - mock_module = MagicMock(spec=udspy.ChainOfThought) - mock_module.__class__ = udspy.ChainOfThought - mock_signature = MagicMock() - mock_signature.get_input_fields.return_value = {"q": 1} - mock_signature.get_output_fields.return_value = {"a": 1} - mock_signature.get_instructions.return_value = "Test instructions" - mock_module.original_signature = mock_signature - - # Start module - callback.on_module_start( - call_id="call-1", instance=mock_module, inputs={"kwargs": {"q": "test"}} + processor = PosthogSpanProcessor() + + agent_span_id = 0xAAAA + tools_group_span_id = 0xBBBB + tool_span_id = 0xCCCC + + # 1) "running tools" starts — processor records the parent mapping. + tools_group_span = self._make_mock_span( + name="running tools", + kind=SpanKind.INTERNAL, + span_id=tools_group_span_id, + parent_span_id=agent_span_id, + ) + processor.on_start(tools_group_span) + + # 2) "running tool" ends — its direct parent is the tools group, + # but the processor should remap to the agent span. 
+ tool_span = self._make_mock_span( + name="running tool", + kind=SpanKind.INTERNAL, + attrs={ + "gen_ai.tool.name": "create_tables", + "tool_arguments": "{}", + "tool_response": "ok", + }, + span_id=tool_span_id, + parent_span_id=tools_group_span_id, ) - # End module with exception - test_exception = ValueError("Test error message") - callback.on_module_end(call_id="call-1", outputs=None, exception=test_exception) + ctx = _TraceContext(trace_id="t", user_id="u", workspace_id="w", chat_uuid="c") + token = _trace_ctx.set(ctx) + try: + processor.on_end(tool_span) - # Verify exception string is captured - mock_posthog.capture.assert_called_once() - call_args = mock_posthog.capture.call_args - props = call_args.kwargs["properties"] + # Tool span's parent should be the agent, not the tools group. + props = mock_posthog.capture.call_args.kwargs["properties"] + assert props["$ai_parent_id"] == f"{agent_span_id:016x}" - assert props["$ai_is_error"] is True - assert props["$ai_output_state"] == "Test error message" + mock_posthog.capture.reset_mock() + + # 3) "running tools" ends — should NOT emit anything. 
+ processor.on_end(tools_group_span) + mock_posthog.capture.assert_not_called() + finally: + _trace_ctx.reset(token) @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_tool_end_with_exception(self, mock_get_client, assistant_chat_fixture): - """Test that exception string is captured in $ai_output_state for tools.""" + def test_without_trace_context_is_noop(self, mock_get_client): + """Test that spans without a trace context are silently ignored.""" + + from opentelemetry.trace import SpanKind mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-1" - callback.span_ids = ["root"] - callback.spans = {} - callback.enabled = True - - mock_tool = MagicMock() - mock_tool.name = "test_tool" - - # Start tool - callback.on_tool_start( - call_id="call-1", instance=mock_tool, inputs={"arg": "val"} + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="chat groq:llama-3.3-70b", + kind=SpanKind.CLIENT, ) - # End tool with exception - test_exception = RuntimeError("Tool execution failed") - callback.on_tool_end(call_id="call-1", outputs=None, exception=test_exception) + # No trace context set + processor.on_end(span) - # Verify exception string is captured - mock_posthog.capture.assert_called_once() - call_args = mock_posthog.capture.call_args - props = call_args.kwargs["properties"] + mock_posthog.capture.assert_not_called() + + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_multiple_spans(self, mock_get_client): + """Test that multiple spans are all processed.""" + + from opentelemetry.trace import SpanKind + + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + 
processor = PosthogSpanProcessor() + + generation_span = self._make_mock_span( + name="chat openai:gpt-4o", + kind=SpanKind.CLIENT, + attrs={ + "gen_ai.request.model": "gpt-4o", + "gen_ai.provider.name": "openai", + "gen_ai.usage.input_tokens": 200, + "gen_ai.usage.output_tokens": 80, + }, + span_id=0x1111, + ) + tool_span = self._make_mock_span( + name="running tool", + kind=SpanKind.INTERNAL, + attrs={ + "gen_ai.tool.name": "create_table", + "tool_arguments": "{}", + "tool_response": "Created table", + }, + span_id=0x2222, + ) + + ctx = _TraceContext( + trace_id="trace-multi", + user_id="user-1", + workspace_id="ws-1", + chat_uuid="chat-1", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(generation_span) + processor.on_end(tool_span) + finally: + _trace_ctx.reset(token) + + assert mock_posthog.capture.call_count == 2 + events = [c.kwargs["event"] for c in mock_posthog.capture.call_args_list] + assert "$ai_generation" in events + assert "$ai_span" in events + + +class TestSetupInstrumentation: + """Test the one-time instrumentation setup.""" + + @patch("baserow_enterprise.assistant.telemetry._instrumentation_ready", False) + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_setup_skipped_when_posthog_disabled(self, mock_get_client): + """Test that setup is skipped when POSTHOG_ENABLED is False.""" + + from baserow_enterprise.assistant.telemetry import setup_instrumentation + + # POSTHOG_ENABLED is False in test settings + setup_instrumentation() + + # Should not have called get_posthog_client (nothing was set up) + mock_get_client.assert_not_called() + + +class TestEndToEndOtelPipeline: + """Integration: verify that a real pydantic-ai Agent run produces + PostHog events via the OTel span exporter.""" + + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_agent_run_produces_posthog_events(self, mock_get_client): + """A real Agent.run_sync() inside a trace() should emit both + $ai_trace and 
$ai_generation events via PostHog.""" + + from opentelemetry.sdk.trace import TracerProvider as _TP + from pydantic_ai import Agent, InstrumentationSettings + + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + # Wire up the same pipeline that setup_instrumentation() creates. + tp = _TP() + tp.add_span_processor(PosthogSpanProcessor()) + Agent.instrument_all( + InstrumentationSettings(tracer_provider=tp, include_content=True) + ) - assert props["$ai_is_error"] is True - assert props["$ai_output_state"] == "Tool execution failed" + try: + # Set trace context (simulates PosthogTracingCallback.trace()). + ctx = _TraceContext( + trace_id="e2e-trace", + user_id="e2e-user", + workspace_id="e2e-ws", + chat_uuid="e2e-chat", + ) + tok = _trace_ctx.set(ctx) + tools_tok = _tool_calls.set([]) + + try: + agent = Agent( + output_type=str, + instructions="Reply with 'pong'.", + name="e2e_test_agent", + ) + agent.run_sync("ping", model="test") + finally: + _trace_ctx.reset(tok) + _tool_calls.reset(tools_tok) + + # Verify PostHog received at least one $ai_generation event. + events = [c.kwargs["event"] for c in mock_posthog.capture.call_args_list] + assert "$ai_generation" in events, ( + f"Expected $ai_generation in captured events, got: {events}" + ) + + # Verify the trace metadata was attached. + gen_call = next( + c + for c in mock_posthog.capture.call_args_list + if c.kwargs["event"] == "$ai_generation" + ) + props = gen_call.kwargs["properties"] + assert props["$ai_trace_id"] == "e2e-trace" + assert props["$ai_session_id"] == "e2e-chat" + assert props["workspace_id"] == "e2e-ws" + finally: + # Clean up global instrumentation so other tests aren't affected. 
+ Agent.instrument_all(None) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py index eebb6175ab..e4cc97f51d 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py @@ -1,3 +1,27 @@ -from baserow_enterprise.assistant.assistant import ToolHelpers +from unittest.mock import MagicMock -fake_tool_helpers = ToolHelpers(lambda x: None, lambda x: None) +from baserow_enterprise.assistant.deps import AssistantDeps, ToolHelpers + + +def create_fake_tool_helpers() -> ToolHelpers: + """Create a fresh ToolHelpers instance for testing.""" + return ToolHelpers(lambda x: None, lambda x: None) + + +def make_test_ctx(user, workspace, tool_helpers=None): + """ + Build a mock ``RunContext[AssistantDeps]`` for unit-testing tool functions. + + Returns a ``MagicMock`` whose ``.deps`` attribute is a real + ``AssistantDeps`` instance. + """ + + if tool_helpers is None: + tool_helpers = create_fake_tool_helpers() + ctx = MagicMock() + ctx.deps = AssistantDeps( + user=user, + workspace=workspace, + tool_helpers=tool_helpers, + ) + return ctx diff --git a/enterprise/backend/tests/baserow_enterprise_tests/conftest.py b/enterprise/backend/tests/baserow_enterprise_tests/conftest.py index d309625c87..6596794695 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/conftest.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/conftest.py @@ -24,9 +24,9 @@ def set_openai_api_key_env_var(): """ Set a dummy OpenAI API key for tests to prevent client instantiation errors. - udspy.LM() creates an OpenAI client that raises an error if OPENAI_API_KEY is not - set during client instantiation. This fixture ensures tests don't fail due to - missing API key configuration, which is not needed anyway. 
+ Some pydantic-ai model backends create an OpenAI client that raises an error + if OPENAI_API_KEY is not set during client instantiation. This fixture ensures + tests don't fail due to missing API key configuration. """ if not os.getenv("OPENAI_API_KEY"): diff --git a/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py b/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py index 418f84e2b4..97d51450da 100755 --- a/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py @@ -258,7 +258,7 @@ def test_user_data_no_enterprise_features_instance_wide_not_active( } -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) @responses.activate def test_check_licenses_with_enterprise_license_sends_usage_data( @@ -303,7 +303,7 @@ def test_check_licenses_with_enterprise_license_sends_usage_data( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_enterprise_license_counts_viewers_as_free( enterprise_data_fixture, data_fixture @@ -350,7 +350,7 @@ def test_enterprise_license_counts_viewers_as_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_who_is_editor_in_one_workspace_and_viewer_in_another_is_not_free( enterprise_data_fixture, data_fixture @@ -393,7 +393,7 @@ def test_user_who_is_editor_in_one_workspace_and_viewer_in_another_is_not_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_marked_for_deletion_is_not_counted_as_a_paid_user( enterprise_data_fixture, data_fixture @@ -439,7 +439,7 @@ def test_user_marked_for_deletion_is_not_counted_as_a_paid_user( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def 
test_user_deactivated_user_is_not_counted_as_a_paid_user( enterprise_data_fixture, data_fixture @@ -584,7 +584,7 @@ def test_enterprise_license_being_unregistered_sends_signal_to_all( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_table_role_is_not_free( enterprise_data_fixture, data_fixture, synced_roles @@ -622,7 +622,7 @@ def test_user_with_paid_table_role_is_not_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_table_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -660,7 +660,7 @@ def test_user_with_free_table_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_database_role_is_not_free( enterprise_data_fixture, data_fixture, synced_roles @@ -698,7 +698,7 @@ def test_user_with_paid_database_role_is_not_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_database_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -736,7 +736,7 @@ def test_user_with_free_database_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_table_role_is_not_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -777,7 +777,7 @@ def test_user_with_paid_table_role_is_not_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_table_role_is_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -818,7 +818,7 @@ def test_user_with_free_table_role_is_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_database_role_is_not_free_from_team( enterprise_data_fixture, 
data_fixture, synced_roles @@ -859,7 +859,7 @@ def test_user_with_paid_database_role_is_not_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_database_role_is_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -900,7 +900,7 @@ def test_user_with_free_database_role_is_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_in_deleted_team_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -943,7 +943,7 @@ def test_user_in_deleted_team_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_inactive_user_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1002,7 +1002,7 @@ def test_inactive_user_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_inactive_user_in_team_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1064,7 +1064,7 @@ def test_inactive_user_in_team_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_to_be_deleted_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1123,7 +1123,7 @@ def test_user_to_be_deleted_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_to_be_deleted_in_team_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1185,7 +1185,7 @@ def test_user_to_be_deleted_in_team_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_complex_free_vs_paid_scenario( enterprise_data_fixture, data_fixture, synced_roles @@ -1281,7 
+1281,7 @@ def test_complex_free_vs_paid_scenario( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_role_paid_on_trashed_database_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1341,7 +1341,7 @@ def test_user_with_role_paid_on_trashed_database_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_role_paid_on_database_in_trashed_workspace_is_free( enterprise_data_fixture, data_fixture, synced_roles, django_assert_num_queries @@ -1401,7 +1401,7 @@ def test_user_with_role_paid_on_database_in_trashed_workspace_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_role_paid_on_trashed_table_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1461,7 +1461,7 @@ def test_user_with_role_paid_on_trashed_table_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_in_team_with_role_paid_on_trashed_database_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1524,7 +1524,7 @@ def test_user_in_team_with_role_paid_on_trashed_database_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_in_team_with_role_paid_on_trashed_table_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1587,7 +1587,7 @@ def test_user_in_team_with_role_paid_on_trashed_table_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_summary_calculation_for_enterprise_doesnt_do_n_plus_one_queries( enterprise_data_fixture, data_fixture, synced_roles, django_assert_num_queries @@ -1678,7 +1678,7 @@ def test_user_summary_calculation_for_enterprise_doesnt_do_n_plus_one_queries( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) 
@override_settings(DEBUG=True) def test_can_query_for_summary_per_workspace( enterprise_data_fixture, data_fixture, synced_roles @@ -1814,7 +1814,7 @@ def test_can_query_for_summary_per_workspace( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_team_and_user_role_picks_highest_of_either( enterprise_data_fixture, data_fixture, synced_roles @@ -1860,7 +1860,7 @@ def test_user_with_team_and_user_role_picks_highest_of_either( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_order_of_roles_is_as_expected( enterprise_data_fixture, data_fixture, synced_roles @@ -1904,7 +1904,7 @@ def test_order_of_roles_is_as_expected( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_weird_workspace_user_permission_doesnt_break_usage_check( enterprise_data_fixture, data_fixture, synced_roles @@ -1936,7 +1936,7 @@ def test_weird_workspace_user_permission_doesnt_break_usage_check( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_weird_ras_for_wrong_workspace_not_counted_when_querying_for_single_workspace_usage( enterprise_data_fixture, data_fixture, synced_roles @@ -1996,7 +1996,7 @@ def test_weird_ras_for_wrong_workspace_not_counted_when_querying_for_single_work ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_missing_roles_doesnt_cause_crash_and_members_admins_are_treated_as_non_free( enterprise_data_fixture, data_fixture, synced_roles @@ -2032,7 +2032,7 @@ def test_missing_roles_doesnt_cause_crash_and_members_admins_are_treated_as_non_ ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_orphaned_paid_role_assignments_dont_get_counted( enterprise_data_fixture, data_fixture, synced_roles diff --git 
a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss index 5a73880ef1..712d4780cd 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss @@ -463,6 +463,12 @@ border: 0; outline: 0; + &--collapsed { + max-height: 250px; + overflow: hidden; + mask-image: linear-gradient(to bottom, black 200px, transparent); + } + // Markdown styles p { margin: 0 0 8px; @@ -629,6 +635,32 @@ vertical-align: middle; } +.assistant__reasoning-toggle { + display: flex; + justify-content: center; + width: 100%; + padding: 2px 0 0; + margin: 0; + border: 0; + background: none; + cursor: pointer; + color: #16829c; + opacity: 0.6; + + &:hover { + opacity: 1; + } +} + +.assistant__reasoning-chevron { + font-size: 12px; + transition: transform 0.2s ease; + + &--expanded { + transform: rotate(180deg); + } +} + .assistant__chat-history-spacer { width: 300px; } diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue index 2734c3fc45..8bf587603b 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue @@ -28,10 +28,27 @@
+