Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
Upcoming (TBD)
==============

Features
---------
* Append dot to database name when completing on a table name.


1.71.0 (2026/05/01)
==============

Expand Down
2 changes: 1 addition & 1 deletion mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _emit_relation_name(ctx: SuggestContext) -> list[Suggestion]:
schema = _parent_name(ctx)
if schema:
return [{'type': rel_type, 'schema': schema}]
return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
return [{'type': 'database'}, {'type': rel_type, 'schema': []}]


def _emit_on(ctx: SuggestContext) -> list[Suggestion]:
Expand Down
23 changes: 20 additions & 3 deletions mycli/sqlcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from prompt_toolkit.completion.base import Document
from pygments.lexers._mysql_builtins import MYSQL_DATATYPES, MYSQL_FUNCTIONS, MYSQL_KEYWORDS
import rapidfuzz
import sqlparse

from mycli.packages.completion_engine import is_inside_quotes, suggest_type
from mycli.packages.filepaths import complete_path, parse_path, suggest_path
Expand Down Expand Up @@ -1419,6 +1420,18 @@ def get_completions(
last_for_len = last_word(word_before_cursor, include="most_punctuations")
text_for_len = last_for_len.lower()
last_for_len_paths = last_word(word_before_cursor, include='alphanum_underscore')
if statements := sqlparse.split(document.text):
total_len = 0
relevant_statement = ''
for statement in statements:
total_len = total_len + len(statement)
if document.cursor_position <= total_len:
relevant_statement = statement
break
if not relevant_statement:
relevant_statement = statements[-1]
else:
relevant_statement = document.text

if smart_completion is None:
smart_completion = self.smart_completion
Expand All @@ -1436,7 +1449,9 @@ def get_completions(
return (Completion(x[0], -len(text_for_len)) for x in matches)

completions: list[tuple[str, int, int]] = []
suggestions = suggest_type(document.text, document.text_before_cursor)
suggestions = suggest_type(relevant_statement, document.text_before_cursor)
database_is_qualifier = any(suggestion['type'] in ('table', 'view') for suggestion in suggestions)
database_is_qualifier = database_is_qualifier and not relevant_statement.lstrip().lower().startswith('show ')
rigid_sort = False
length_based_on_path = False

Expand Down Expand Up @@ -1530,15 +1545,15 @@ def get_completions(
# then only return tables that have one or more of the given columns.
# If no columns are given (or able to be parsed), return all tables
# as usual.
columns = extract_columns_from_select(document.text)
columns = extract_columns_from_select(relevant_statement)
if columns:
tables = self.populate_schema_objects(suggestion["schema"], "tables", columns)
else:
tables = self.populate_schema_objects(suggestion["schema"], "tables")

if suggestion.get("join"):
# For JOINs, suggest FK-related tables first (lower rank = higher priority)
current_tables = extract_tables(document.text)
current_tables = extract_tables(relevant_statement)
fk_map = self.dbmetadata["foreign_keys"].get(self.dbname, {}).get("tables", {})
fk_related: set[str] = set()
for tbl_schema, tbl, _alias in current_tables:
Expand Down Expand Up @@ -1602,6 +1617,8 @@ def get_completions(
self.databases,
text_before_cursor=document.text_before_cursor,
)
if database_is_qualifier:
dbs_m = ((f'{db}.', fuzziness) for db, fuzziness in dbs_m)
completions.extend([(*x, rank) for x in dbs_m])

elif suggestion["type"] == "keyword":
Expand Down
8 changes: 4 additions & 4 deletions test/pytests/test_completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def test_emit_relation_name_with_schema_parent():

def test_emit_relation_name_without_schema_parent():
context = _build_suggest_context('view', '', None, '', empty_identifier())
assert _emit_relation_name(context) == [{'type': 'schema'}, {'type': 'view', 'schema': []}]
assert _emit_relation_name(context) == [{'type': 'database'}, {'type': 'view', 'schema': []}]


@pytest.mark.xfail
Expand Down Expand Up @@ -925,9 +925,9 @@ def test_suggest_based_on_last_token_lparen_in_function_call_suggests_columns():
('database', 'drop database ', 'drop database ', [{'type': 'database'}]),
('template', 'create database foo with template ', 'create database foo with template ', [{'type': 'database'}]),
('collate', 'collate ', 'collate ', [{'type': 'collation'}]),
('table', 'drop table ', 'drop table ', [{'type': 'schema'}, {'type': 'table', 'schema': []}]),
('view', 'drop view ', 'drop view ', [{'type': 'schema'}, {'type': 'view', 'schema': []}]),
('function', 'drop function ', 'drop function ', [{'type': 'schema'}, {'type': 'function', 'schema': []}]),
('table', 'drop table ', 'drop table ', [{'type': 'database'}, {'type': 'table', 'schema': []}]),
('view', 'drop view ', 'drop view ', [{'type': 'database'}, {'type': 'view', 'schema': []}]),
('function', 'drop function ', 'drop function ', [{'type': 'database'}, {'type': 'function', 'schema': []}]),
],
)
def test_suggest_based_on_last_token_direct_keyword_branches(token, text_before_cursor, full_text, expected):
Expand Down
69 changes: 57 additions & 12 deletions test/pytests/test_smart_completion_public_schema_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def test_table_completion(completer, complete_event):
Completion(text="time_zone_name", start_position=0),
Completion(text="time_zone_transition", start_position=0),
Completion(text="time_zone_transition_type", start_position=0),
Completion(text="test", start_position=0),
Completion(text="`test 2`", start_position=0),
Completion(text="test.", start_position=0),
Completion(text="`test 2`.", start_position=0),
]


Expand All @@ -238,8 +238,8 @@ def test_select_filtered_table_completion(completer, complete_event):
Completion(text="time_zone_name", start_position=0),
Completion(text="time_zone_transition", start_position=0),
Completion(text="time_zone_transition_type", start_position=0),
Completion(text="test", start_position=0),
Completion(text="`test 2`", start_position=0),
Completion(text="test.", start_position=0),
Completion(text="`test 2`.", start_position=0),
]


Expand All @@ -257,8 +257,8 @@ def test_sub_select_filtered_table_completion(completer, complete_event):
Completion(text="time_zone_name", start_position=0),
Completion(text="time_zone_transition", start_position=0),
Completion(text="time_zone_transition_type", start_position=0),
Completion(text="test", start_position=0),
Completion(text="`test 2`", start_position=0),
Completion(text="test.", start_position=0),
Completion(text="`test 2`.", start_position=0),
]


Expand Down Expand Up @@ -512,8 +512,8 @@ def test_table_names_after_from(completer, complete_event):
Completion(text="time_zone_name", start_position=0),
Completion(text="time_zone_transition", start_position=0),
Completion(text="time_zone_transition_type", start_position=0),
Completion(text="test", start_position=0),
Completion(text="`test 2`", start_position=0),
Completion(text="test.", start_position=0),
Completion(text="`test 2`.", start_position=0),
]


Expand All @@ -530,6 +530,51 @@ def test_table_names_leading_partial(completer, complete_event):
]


def test_database_completion_for_table_appends_dot(completer, complete_event):
text = 'SELECT * FROM te'
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert Completion(text='test.', start_position=-2) in result


@pytest.mark.parametrize(
'text',
[
'SHOW TABLES FROM te',
'SHOW FULL TABLES FROM te',
'SHOW COLUMNS FROM users FROM te',
],
)
def test_show_database_completion_does_not_append_dot(completer, complete_event, text):
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert Completion(text='test', start_position=-2) in result
assert Completion(text='test.', start_position=-2) not in result


@pytest.mark.parametrize(
'text',
[
'SHOW TABLES FROM test; SELECT * FROM te',
],
)
def test_database_completion_for_table_ignores_previous_statement(completer, complete_event, text):
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert Completion(text='test.', start_position=-2) in result
assert Completion(text='test', start_position=-2) not in result


# todo: USE works interactively but not in the test suite
# this is also covered by test_show_database_completion_does_not_append_dot()
@pytest.mark.xfail
def test_database_completion_after_use_does_not_append_dot(completer, complete_event):
text = 'USE tes'
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert Completion(text='test', start_position=-2) in result


def test_table_names_inter_partial(completer, complete_event):
text = "SELECT * FROM time_leap"
position = len("SELECT * FROM time_leap")
Expand Down Expand Up @@ -586,8 +631,8 @@ def test_grant_on_suggets_tables_and_schemata(completer, complete_event):
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text="test", start_position=0),
Completion(text="`test 2`", start_position=0),
Completion(text="test.", start_position=0),
Completion(text="`test 2`.", start_position=0),
Completion(text='users', start_position=0),
Completion(text='orders', start_position=0),
Completion(text='`select`', start_position=0),
Expand Down Expand Up @@ -947,8 +992,8 @@ def test_backticked_table_completion_not_required(completer, complete_event):
position = len(text)
result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text='`test`', start_position=-2),
Completion(text='`test 2`', start_position=-2),
Completion(text='`test`.', start_position=-2),
Completion(text='`test 2`.', start_position=-2),
Completion(text='`time_zone`', start_position=-2),
Completion(text='`time_zone_name`', start_position=-2),
Completion(text='`time_zone_transition`', start_position=-2),
Expand Down