diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py
index fba814fc3..b9ff6a738 100644
--- a/src/databricks/sql/backend/kernel/client.py
+++ b/src/databricks/sql/backend/kernel/client.py
@@ -14,10 +14,6 @@
 
 Phase 1 gaps documented in the integration design:
 
-- Parameter binding (``parameters=[TSparkParameter, ...]``) is not
-  yet supported — the PyO3 ``Statement`` doesn't expose
-  ``bind_param``. ``execute_command(parameters=[...])`` raises
-  ``NotSupportedError``.
 - ``query_tags`` on execute is not supported (kernel exposes
   ``statement_conf`` but PyO3 doesn't surface it).
 - ``get_tables`` with a non-empty ``table_types`` filter applies
@@ -46,6 +42,7 @@
 )
 from databricks.sql.backend.kernel.auth_bridge import kernel_auth_kwargs
 from databricks.sql.backend.kernel.result_set import KernelResultSet
+from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
 from databricks.sql.backend.types import (
     BackendType,
     CommandId,
@@ -231,11 +228,6 @@ def execute_command(
     ) -> Union["ResultSet", None]:
         if self._kernel_session is None:
             raise InterfaceError("Cannot execute_command without an open session.")
-        if parameters:
-            raise NotSupportedError(
-                "Parameter binding is not yet supported on the kernel backend "
-                "(PyO3 Statement.bind_param lands in a follow-up PR)."
-            )
         if query_tags:
             raise NotSupportedError(
                 "Statement-level query_tags are not yet supported on the kernel backend."
@@ -248,6 +240,8 @@ def execute_command(
         try:
             try:
                 stmt.set_sql(operation)
+                if parameters:
+                    bind_tspark_params(stmt, parameters)
                 if async_op:
                     async_exec = stmt.submit()
                     command_id = CommandId.from_sea_statement_id(
diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py
index bedcdcebd..4bd38e621 100644
--- a/src/databricks/sql/backend/kernel/type_mapping.py
+++ b/src/databricks/sql/backend/kernel/type_mapping.py
@@ -13,19 +13,28 @@
 the kernel receives Arrow schemas directly), so the mapping function
 stays local but the names are shared.
 
-Parameter binding (``TSparkParameter`` → kernel ``TypedValue``) is
-not yet implemented — the PyO3 ``Statement`` doesn't expose a
-``bind_param`` method on this branch. It'll land in a follow-up
-once that PyO3 surface ships.
+Parameter binding (``TSparkParameter`` → kernel
+``Statement.bind_param``) is handled by ``bind_tspark_params`` —
+forwards the connector's already-string-encoded form to the kernel
+binding without an intermediate Python-typed round-trip.
 """
 
 from __future__ import annotations
 
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import pyarrow
 
 from databricks.sql.backend.sea.utils.conversion import SqlType
+from databricks.sql.exc import NotSupportedError
+from databricks.sql.thrift_api.TCLIService import ttypes
+
+# Type names that the connector emits as compound TSparkParameter
+# shapes (payload on ``arguments``, not ``value``). The kernel's
+# parameter parser doesn't accept these yet, and our binding path
+# only forwards ``value`` — so we reject them at the connector
+# layer to avoid silently binding a typed NULL.
+_COMPOUND_PARAM_TYPES = frozenset({"ARRAY", "MAP", "STRUCT"})
 
 
 def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str:
@@ -92,3 +101,75 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]:
         )
         for field in schema
     ]
+
+
+def _tspark_param_value_str(param: ttypes.TSparkParameter) -> Optional[str]:
+    """Extract the string-encoded value from a ``TSparkParameter``,
+    or ``None`` for SQL NULL.
+
+    Native parameters (``IntegerParameter`` etc.) wrap their value
+    in ``TSparkParameterValue(stringValue=str(self.value))``.
+    ``VoidParameter._tspark_param_value()`` returns Python ``None``,
+    so on the wire ``param.value`` is ``None`` and we surface that
+    as ``None`` here.
+    """
+    if param.value is None:
+        return None
+    return param.value.stringValue
+
+
+def bind_tspark_params(kernel_stmt, parameters: List[ttypes.TSparkParameter]) -> None:
+    """Bind a list of ``TSparkParameter`` onto a kernel ``Statement``.
+
+    The kernel expects positional bindings only (SEA v0 doesn't
+    accept named bindings on the wire). The connector's
+    ``TSparkParameter`` has an ``ordinal: bool`` flag; ``True`` means
+    "treat as positional in source-list order". Named-binding
+    parameters surface as ``NotSupportedError`` so the user gets a
+    clear message instead of a server-side rejection.
+
+    Compound types (``ARRAY`` / ``MAP`` / ``STRUCT``) build a
+    ``TSparkParameter`` with the payload on ``arguments`` and
+    ``value=None`` — forwarding that would silently bind a typed
+    NULL. Reject up front with ``NotSupportedError`` so callers get
+    a clear message instead of silent data loss.
+
+    Binding is all-or-nothing: every parameter is validated before
+    the first ``bind_param`` call, so a rejected parameter never
+    leaves the statement partially bound.
+    """
+    bindings: List[Tuple[Optional[str], str]] = []
+    for param in parameters:
+        # ``ordinal`` on connector-native params is a bool (True for
+        # positional, False for named). Thrift defaults to ``None``;
+        # treat any non-True value with a name as a named binding so
+        # a future caller that forgets to set ordinal=True still gets
+        # rejected instead of silently dropping the name.
+        name = getattr(param, "name", None)
+        if name and getattr(param, "ordinal", None) is not True:
+            raise NotSupportedError(
+                f"Named parameter binding (got name={name!r}) is not yet "
+                "supported on the kernel backend; pass parameters positionally."
+            )
+
+        sql_type = param.type or "STRING"
+        # Compound types put their payload on ``arguments``, not
+        # ``value``. The kernel parser doesn't accept them yet, and
+        # the binding path below only forwards ``value``. Detect
+        # both the SQL-type name (handles ``"ARRAY"``, ``"MAP(...)"``,
+        # ``"STRUCT<...>"``) and the presence of ``arguments`` so a
+        # hand-rolled compound TSparkParameter is also caught.
+        base_type = sql_type.split("(", 1)[0].split("<", 1)[0].upper()
+        if base_type in _COMPOUND_PARAM_TYPES or getattr(param, "arguments", None):
+            raise NotSupportedError(
+                f"Compound parameter types (got {sql_type!r}) are not yet "
+                "supported on the kernel backend."
+            )
+
+        bindings.append((_tspark_param_value_str(param), sql_type))
+
+    # The kernel takes 1-based ordinals, assigned in source-list order.
+    # Errors from the kernel side (bad literal, unsupported type,
+    # etc.) come up as KernelError and bubble through normally.
+    for i, (value_str, sql_type) in enumerate(bindings, start=1):
+        kernel_stmt.bind_param(i, value_str, sql_type)
diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py
index 67f6e858d..d2c0c9b9c 100644
--- a/tests/e2e/test_kernel_backend.py
+++ b/tests/e2e/test_kernel_backend.py
@@ -199,3 +199,69 @@ def test_bad_sql_surfaces_as_databaseerror(conn):
     # Structured fields copied off the kernel exception:
     assert getattr(err, "code", None) == "SqlError"
     assert getattr(err, "sql_state", None) == "42P01"
+
+
+# ── Parameter binding ─────────────────────────────────────────────
+
+
+def test_parameterized_query_round_trips(conn):
+    """Positional parameter binding via the kernel backend. The
+    connector's native parameter classes (IntegerParameter etc.)
+    serialize to TSparkParameter under the hood; the kernel
+    backend's mapper forwards them positionally to the kernel.
+    """
+    from databricks.sql.parameters.native import (
+        IntegerParameter,
+        StringParameter,
+        BooleanParameter,
+    )
+
+    with conn.cursor() as cur:
+        cur.execute(
+            "SELECT ? AS i, ? AS s, ? AS b",
+            [
+                IntegerParameter(42),
+                StringParameter("alice"),
+                BooleanParameter(True),
+            ],
+        )
+        rows = cur.fetchall()
+    assert len(rows) == 1
+    assert rows[0][0] == 42
+    assert rows[0][1] == "alice"
+    assert rows[0][2] is True
+
+
+def test_parameterized_query_with_null(conn):
+    """`None` in the parameter list flows through as VoidParameter
+    → kernel TypedValue::Null."""
+    with conn.cursor() as cur:
+        cur.execute("SELECT ? IS NULL AS is_null", [None])
+        rows = cur.fetchall()
+    assert rows[0][0] is True
+
+
+def test_parameterized_query_decimal(conn):
+    """DECIMAL parameters carry precision/scale in the SQL type
+    string ('DECIMAL(p,s)') — the kernel parser extracts them so
+    fractional digits survive the wire.
+
+    Uses the connector's auto-inference path
+    (`calculate_decimal_cast_string`) to derive precision/scale
+    from the value; the explicit-arg path
+    (`DecimalParameter(v, scale=, precision=)`) has a pre-existing
+    bug in this branch where the format-args are passed
+    `(scale, precision)` instead of `(precision, scale)` — out of
+    scope for this PR.
+    """
+    import decimal
+    from databricks.sql.parameters.native import DecimalParameter
+
+    with conn.cursor() as cur:
+        cur.execute(
+            "SELECT ? AS d",
+            [DecimalParameter(decimal.Decimal("-123.45"))],
+        )
+        rows = cur.fetchall()
+    # Server echoes back as decimal.Decimal.
+    assert str(rows[0][0]) == "-123.45"
diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py
index f43d8c7c7..a9e9c9090 100644
--- a/tests/unit/test_kernel_client.py
+++ b/tests/unit/test_kernel_client.py
@@ -234,25 +234,57 @@ def test_open_session_rejects_double_open(monkeypatch):
         c.open_session(session_configuration=None, catalog=None, schema=None)
 
 
-def test_execute_command_rejects_parameters():
+def test_execute_command_forwards_parameters_to_bind_param():
+    """``execute_command(parameters=[...])`` routes each parameter
+    through ``bind_tspark_params`` onto the kernel statement before
+    ``execute()`` is called. Replaces the prior ``NotSupportedError``
+    rejection now that the kernel-side ``Statement.bind_param`` is
+    live (kernel PR #18)."""
+    from databricks.sql.thrift_api.TCLIService import ttypes
+
     c = _make_client()
     c._kernel_session = MagicMock()
     cursor = MagicMock()
     cursor.arraysize = 100
     cursor.buffer_size_bytes = 1024
-    with pytest.raises(NotSupportedError, match="Parameter binding"):
-        c.execute_command(
-            operation="SELECT ?",
-            session_id=MagicMock(),
-            max_rows=1,
-            max_bytes=1,
-            lz4_compression=False,
-            cursor=cursor,
-            use_cloud_fetch=False,
-            parameters=[object()],  # any non-empty list
-            async_op=False,
-            enforce_embedded_schema_correctness=False,
-        )
+
+    # Stub the statement chain so we can observe bind_param calls
+    # without exercising the full ExecutedStatement → arrow_schema()
+    # path (that's covered elsewhere).
+    stmt = MagicMock()
+    stmt.bind_param = MagicMock()
+    stmt.execute.return_value = MagicMock(
+        statement_id="stmt-id",
+        arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])),
+    )
+    c._kernel_session.statement.return_value = stmt
+
+    p1 = ttypes.TSparkParameter(ordinal=True, name=None, type="INT")
+    p1.value = ttypes.TSparkParameterValue(stringValue="42")
+    p2 = ttypes.TSparkParameter(ordinal=True, name=None, type="STRING")
+    p2.value = ttypes.TSparkParameterValue(stringValue="hello")
+
+    c.execute_command(
+        operation="SELECT ?, ?",
+        session_id=MagicMock(),
+        max_rows=1,
+        max_bytes=1,
+        lz4_compression=False,
+        cursor=cursor,
+        use_cloud_fetch=False,
+        parameters=[p1, p2],
+        async_op=False,
+        enforce_embedded_schema_correctness=False,
+    )
+
+    # bind_param was called once per TSparkParameter, in order, with
+    # 1-based ordinals.
+    assert stmt.bind_param.call_args_list == [
+        ((1, "42", "INT"), {}),
+        ((2, "hello", "STRING"), {}),
+    ]
+    # …and execute fired after binding.
+    assert stmt.execute.called
 
 
 def test_execute_command_rejects_query_tags():
diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py
index 82f62559a..bee1ee7be 100644
--- a/tests/unit/test_kernel_type_mapping.py
+++ b/tests/unit/test_kernel_type_mapping.py
@@ -84,3 +84,173 @@ def test_description_from_schema_reports_non_nullable_fields():
     desc = description_from_arrow_schema(schema)
     assert desc[0][6] is False
     assert desc[1][6] is True
+
+
+# ─── bind_tspark_params ──────────────────────────────────────────────────
+
+
+def _mk_param(*, type, value, ordinal=True, name=None):
+    """Build a minimal TSparkParameter for tests."""
+    from databricks.sql.thrift_api.TCLIService import ttypes
+
+    p = ttypes.TSparkParameter(ordinal=ordinal, name=name, type=type)
+    p.value = ttypes.TSparkParameterValue(stringValue=value) if value is not None else None
+    return p
+
+
+class _RecordingStmt:
+    """Stand-in for the kernel `Statement` pyclass — records every
+    `bind_param` call so tests can assert the (ordinal, value, type)
+    triples the mapper forwarded."""
+
+    def __init__(self):
+        self.calls = []
+
+    def bind_param(self, ordinal, value_str, sql_type):
+        self.calls.append((ordinal, value_str, sql_type))
+
+
+def test_bind_tspark_params_forwards_each_param_positionally():
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+
+    params = [
+        _mk_param(type="INT", value="42"),
+        _mk_param(type="STRING", value="alice"),
+        _mk_param(type="DATE", value="2026-05-15"),
+    ]
+    stmt = _RecordingStmt()
+    bind_tspark_params(stmt, params)
+    assert stmt.calls == [
+        (1, "42", "INT"),
+        (2, "alice", "STRING"),
+        (3, "2026-05-15", "DATE"),
+    ]
+
+
+def test_bind_tspark_params_null_value():
+    """TSparkParameter with value=None → kernel sees value_str=None,
+    interpreted as SQL NULL regardless of the SQL type."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+
+    p = _mk_param(type="STRING", value=None)
+    stmt = _RecordingStmt()
+    bind_tspark_params(stmt, [p])
+    assert stmt.calls == [(1, None, "STRING")]
+
+
+def test_bind_tspark_params_void_passes_through():
+    """VoidParameter._tspark_param_value() returns Python None, so
+    on the wire ``param.value`` is None — the mapper forwards
+    value_str=None with type='VOID' and the kernel parser ignores
+    the value."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+
+    p = _mk_param(type="VOID", value=None)
+    stmt = _RecordingStmt()
+    bind_tspark_params(stmt, [p])
+    assert stmt.calls == [(1, None, "VOID")]
+
+
+def test_bind_tspark_params_named_param_rejected():
+    """The kernel doesn't accept named bindings on the SEA wire;
+    surface that at the connector layer so the user gets a pointed
+    error instead of a server-side rejection."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+    from databricks.sql.exc import NotSupportedError
+
+    p = _mk_param(type="INT", value="42", ordinal=False, name="my_param")
+    stmt = _RecordingStmt()
+    with pytest.raises(NotSupportedError, match="(?i)named"):
+        bind_tspark_params(stmt, [p])
+    # Nothing should have been forwarded before the rejection.
+    assert stmt.calls == []
+
+
+def test_bind_tspark_params_missing_type_defaults_to_string():
+    """Defensive: a TSparkParameter with no `type` shouldn't crash
+    the mapper — fall back to STRING and let the kernel parse."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+    from databricks.sql.thrift_api.TCLIService import ttypes
+
+    p = ttypes.TSparkParameter(ordinal=True, name=None, type=None)
+    p.value = ttypes.TSparkParameterValue(stringValue="hello")
+    stmt = _RecordingStmt()
+    bind_tspark_params(stmt, [p])
+    assert stmt.calls == [(1, "hello", "STRING")]
+
+
+def test_bind_tspark_params_empty_list_is_noop():
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+
+    stmt = _RecordingStmt()
+    bind_tspark_params(stmt, [])
+    assert stmt.calls == []
+
+
+@pytest.mark.parametrize(
+    "sql_type",
+    ["ARRAY", "MAP", "STRUCT", "array", "Map(string,int)", "STRUCT<a:int>"],
+)
+def test_bind_tspark_params_compound_types_rejected(sql_type):
+    """ArrayParameter / MapParameter / StructParameter build a
+    TSparkParameter with value=None and the payload on
+    ``arguments`` — forwarding that would silently bind a typed
+    NULL, so reject up front."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+    from databricks.sql.exc import NotSupportedError
+
+    p = _mk_param(type=sql_type, value=None)
+    stmt = _RecordingStmt()
+    with pytest.raises(NotSupportedError, match="(?i)compound"):
+        bind_tspark_params(stmt, [p])
+    assert stmt.calls == []
+
+
+def test_bind_tspark_params_arguments_field_rejected():
+    """A TSparkParameter with ``arguments`` set is the compound
+    shape regardless of how the type string looks — also reject."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+    from databricks.sql.exc import NotSupportedError
+    from databricks.sql.thrift_api.TCLIService import ttypes
+
+    # Scalar type name on purpose, so only ``arguments`` can trigger this.
+    p = ttypes.TSparkParameter(ordinal=True, name=None, type="INT")
+    p.value = None
+    p.arguments = [ttypes.TSparkParameterValueArg(type="INT")]
+    stmt = _RecordingStmt()
+    with pytest.raises(NotSupportedError, match="(?i)compound"):
+        bind_tspark_params(stmt, [p])
+    assert stmt.calls == []
+
+
+def test_bind_tspark_params_named_with_ordinal_none_rejected():
+    """Defensive: a TSparkParameter with a name and ordinal=None
+    (Thrift default) should also be rejected as a named binding —
+    not silently routed positionally with the name dropped."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+    from databricks.sql.exc import NotSupportedError
+    from databricks.sql.thrift_api.TCLIService import ttypes
+
+    p = ttypes.TSparkParameter(ordinal=None, name="my_param", type="INT")
+    p.value = ttypes.TSparkParameterValue(stringValue="42")
+    stmt = _RecordingStmt()
+    with pytest.raises(NotSupportedError, match="(?i)named"):
+        bind_tspark_params(stmt, [p])
+    assert stmt.calls == []
+
+
+def test_bind_tspark_params_rejection_binds_nothing():
+    """Validation covers the whole list before anything is bound: a
+    rejected parameter later in the list must not leave the earlier
+    ones already forwarded to the kernel statement."""
+    from databricks.sql.backend.kernel.type_mapping import bind_tspark_params
+    from databricks.sql.exc import NotSupportedError
+
+    params = [
+        _mk_param(type="INT", value="42"),
+        _mk_param(type="ARRAY", value=None),
+    ]
+    stmt = _RecordingStmt()
+    with pytest.raises(NotSupportedError, match="(?i)compound"):
+        bind_tspark_params(stmt, params)
+    assert stmt.calls == []