Skip to content

Commit ca4d7bc

Browse files
authored
Add statement-level query_tags support for SEA backend (#754)
* Add statement-level query_tags support for SEA backend

  Signed-off-by: Sreekanth Vadigi <sreekanth.vadigi@databricks.com>

* Simplify None handling in query_tags serialization

  Signed-off-by: Sreekanth Vadigi <sreekanth.vadigi@databricks.com>

---------

Signed-off-by: Sreekanth Vadigi <sreekanth.vadigi@databricks.com>
1 parent 36fb376 commit ca4d7bc

File tree

3 files changed

+118
-5
lines changed

3 files changed

+118
-5
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,9 +463,7 @@ def execute_command(
463463
async_op: bool,
464464
enforce_embedded_schema_correctness: bool,
465465
row_limit: Optional[int] = None,
466-
query_tags: Optional[
467-
Dict[str, Optional[str]]
468-
] = None, # TODO: implement query_tags for SEA backend
466+
query_tags: Optional[Dict[str, Optional[str]]] = None,
469467
) -> Union[SeaResultSet, None]:
470468
"""
471469
Execute a SQL command using the SEA backend.
@@ -532,6 +530,7 @@ def execute_command(
532530
row_limit=row_limit,
533531
parameters=sea_parameters if sea_parameters else None,
534532
result_compression=result_compression,
533+
query_tags=query_tags,
535534
)
536535

537536
response_data = self._http_client._make_request(

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ExecuteStatementRequest:
3131
wait_timeout: str = "10s"
3232
on_wait_timeout: str = "CONTINUE"
3333
row_limit: Optional[int] = None
34+
query_tags: Optional[Dict[str, Optional[str]]] = None
3435

3536
def to_dict(self) -> Dict[str, Any]:
3637
"""Convert the request to a dictionary for JSON serialization."""
@@ -60,6 +61,13 @@ def to_dict(self) -> Dict[str, Any]:
6061
for param in self.parameters
6162
]
6263

64+
# SEA API expects query_tags as an array of {key, value} objects.
65+
# None/empty values are left to the server to handle as key-only tags.
66+
if self.query_tags:
67+
result["query_tags"] = [
68+
{"key": k, "value": v} for k, v in self.query_tags.items()
69+
]
70+
6371
return result
6472

6573

tests/unit/test_sea_backend.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
185185
session_config = {
186186
"ANSI_MODE": "FALSE", # Supported parameter
187187
"STATEMENT_TIMEOUT": "3600", # Supported parameter
188-
"QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter
188+
"QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter
189189
"unsupported_param": "value", # Unsupported parameter
190190
}
191191
catalog = "test_catalog"
@@ -197,7 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
197197
"session_confs": {
198198
"ansi_mode": "FALSE",
199199
"statement_timeout": "3600",
200-
"query_tags": "team:marketing,dashboard:abc123",
200+
"query_tags": "team:marketing,dashboard:abc123",
201201
},
202202
"catalog": catalog,
203203
"schema": schema,
@@ -416,6 +416,112 @@ def test_command_execution_advanced(
416416
)
417417
assert "Command failed" in str(excinfo.value)
418418

419+
def _execute_response(self):
420+
return {
421+
"statement_id": "test-statement-123",
422+
"status": {"state": "SUCCEEDED"},
423+
"manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0},
424+
"result": {"data": []},
425+
}
426+
427+
def _run_execute_command(self, sea_client, sea_session_id, mock_cursor, **kwargs):
428+
"""Helper to invoke execute_command with default args."""
429+
return sea_client.execute_command(
430+
operation="SELECT 1",
431+
session_id=sea_session_id,
432+
max_rows=100,
433+
max_bytes=1000,
434+
lz4_compression=False,
435+
cursor=mock_cursor,
436+
use_cloud_fetch=False,
437+
parameters=[],
438+
async_op=False,
439+
enforce_embedded_schema_correctness=False,
440+
**kwargs,
441+
)
442+
443+
def test_execute_command_query_tags_string_values(
444+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
445+
):
446+
"""query_tags with string values are included in the request payload."""
447+
mock_http_client._make_request.return_value = self._execute_response()
448+
with patch.object(sea_client, "_response_to_result_set"):
449+
self._run_execute_command(
450+
sea_client,
451+
sea_session_id,
452+
mock_cursor,
453+
query_tags={"env": "prod", "team": "data"},
454+
)
455+
_, kwargs = mock_http_client._make_request.call_args
456+
assert kwargs["data"]["query_tags"] == [
457+
{"key": "env", "value": "prod"},
458+
{"key": "team", "value": "data"},
459+
]
460+
461+
def test_execute_command_query_tags_none_value(
462+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
463+
):
464+
"""A None tag value is passed through as value=None; the server treats it as a key-only tag."""
465+
mock_http_client._make_request.return_value = self._execute_response()
466+
with patch.object(sea_client, "_response_to_result_set"):
467+
self._run_execute_command(
468+
sea_client,
469+
sea_session_id,
470+
mock_cursor,
471+
query_tags={"env": "prod", "team": None},
472+
)
473+
_, kwargs = mock_http_client._make_request.call_args
474+
assert kwargs["data"]["query_tags"] == [
475+
{"key": "env", "value": "prod"},
476+
{"key": "team", "value": None},
477+
]
478+
479+
def test_execute_command_no_query_tags_omitted(
480+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
481+
):
482+
"""query_tags field is absent from the request when not provided."""
483+
mock_http_client._make_request.return_value = self._execute_response()
484+
with patch.object(sea_client, "_response_to_result_set"):
485+
self._run_execute_command(sea_client, sea_session_id, mock_cursor)
486+
_, kwargs = mock_http_client._make_request.call_args
487+
assert "query_tags" not in kwargs["data"]
488+
489+
def test_execute_command_empty_query_tags_omitted(
490+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
491+
):
492+
"""Empty query_tags dict is treated as absent — field omitted from request."""
493+
mock_http_client._make_request.return_value = self._execute_response()
494+
with patch.object(sea_client, "_response_to_result_set"):
495+
self._run_execute_command(
496+
sea_client, sea_session_id, mock_cursor, query_tags={}
497+
)
498+
_, kwargs = mock_http_client._make_request.call_args
499+
assert "query_tags" not in kwargs["data"]
500+
501+
def test_execute_command_async_query_tags(
502+
self, sea_client, mock_http_client, mock_cursor, sea_session_id
503+
):
504+
"""query_tags are included in async execute requests (execute_async path)."""
505+
mock_http_client._make_request.return_value = {
506+
"statement_id": "test-statement-async",
507+
"status": {"state": "PENDING"},
508+
}
509+
sea_client.execute_command(
510+
operation="SELECT 1",
511+
session_id=sea_session_id,
512+
max_rows=100,
513+
max_bytes=1000,
514+
lz4_compression=False,
515+
cursor=mock_cursor,
516+
use_cloud_fetch=False,
517+
parameters=[],
518+
async_op=True,
519+
enforce_embedded_schema_correctness=False,
520+
query_tags={"job": "nightly-etl"},
521+
)
522+
_, kwargs = mock_http_client._make_request.call_args
523+
assert kwargs["data"]["query_tags"] == [{"key": "job", "value": "nightly-etl"}]
524+
419525
def test_command_management(
420526
self,
421527
sea_client,

0 commit comments

Comments
 (0)