diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..81572e8893 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -4092,6 +4092,10 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai if self._is_shutdown: return + # Track whether a specific connection was requested (e.g. after DDL, + # we must query the coordinator node). If not, we use (and can refresh) + # the control connection. + explicit_connection = bool(connection) if not connection: connection = self._connection @@ -4130,6 +4134,20 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai if self._is_shutdown: log.debug("[control connection] Aborting wait for schema match due to shutdown") return None + elif not explicit_connection: + # The control connection was closed (e.g. the node was stopped). + # Wait briefly and try again with whatever connection the control + # connection has by now (it may have reconnected to another node). + log.debug("[control connection] Connection closed during schema agreement " + "check; will retry with a new control connection") + self._time.sleep(0.2) + elapsed = self._time.time() - start + new_conn = self._connection + if new_conn is not None and new_conn is not connection: + connection = new_conn + select_peers_query = self._get_peers_query( + self.PeersQueryType.PEERS_SCHEMA, connection) + continue else: raise diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..dadb349dfa 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -18,6 +18,7 @@ from unittest.mock import Mock, ANY, call from cassandra import OperationTimedOut, SchemaTargetType, SchemaChangeType +from cassandra.connection import ConnectionShutdown from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS from cassandra.cluster import ControlConnection, _Scheduler, ProfileManager, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.pool import Host @@ -623,6 +624,48 @@ def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): assert self.cluster.added_hosts[0].datacenter == "dc1" assert self.cluster.added_hosts[0].rack == "rack1" + def test_wait_for_schema_agreement_recovers_from_closed_control_connection(self): + """ + Reproduces https://github.com/scylladb/python-driver/issues/604 + + When wait_for_schema_agreement() is called without an explicit connection + and the control connection is closed mid-wait (e.g. the node was stopped), + the function should transparently pick up the new control connection and + succeed rather than propagating ConnectionShutdown to the caller. + """ + # Build a second (replacement) connection whose schema queries succeed. + new_connection = MockConnection() + new_connection.endpoint = DefaultEndPoint("192.168.1.1") + + # The initial connection raises ConnectionShutdown on the first query, + # simulating the node being stopped while we wait. + def first_call_raises(*args, **kwargs): + # Swap the control connection to the replacement before raising, + # just as the driver's reconnect logic would do. + self.control_connection._connection = new_connection + raise ConnectionShutdown("Connection is already closed") + + self.connection.wait_for_responses = Mock(side_effect=first_call_raises) + + # Before the fix this raised ConnectionShutdown; after the fix it should + # recover and return True. + result = self.control_connection.wait_for_schema_agreement() + assert result is True + + # The replacement connection must have been queried exactly once. + assert new_connection.wait_for_responses.call_count == 1 + + def test_wait_for_schema_agreement_explicit_connection_raises_on_shutdown(self): + """ + When an explicit connection is passed (DDL path, coordinator must be queried), + ConnectionShutdown should still propagate — there is no safe fallback. + """ + self.connection.wait_for_responses = Mock( + side_effect=ConnectionShutdown("Connection is already closed")) + + with self.assertRaises(ConnectionShutdown): + self.control_connection.wait_for_schema_agreement(connection=self.connection) + class EventTimingTest(unittest.TestCase): """