diff --git a/pyoverkiz/action_queue.py b/pyoverkiz/action_queue.py index 57ae5c49..54caf2cd 100644 --- a/pyoverkiz/action_queue.py +++ b/pyoverkiz/action_queue.py @@ -120,10 +120,15 @@ async def add( into a single action to respect the gateway limitation of one action per device in each action group. - :param actions: Actions to queue - :param mode: Command mode (will flush if different from pending mode) - :param label: Label for the action group - :return: QueuedExecution that resolves to exec_id when batch executes + Args: + actions: Actions to queue. + mode: Command mode, which triggers a flush if it differs from the + pending mode. + label: Label for the action group. + + Returns: + A `QueuedExecution` that resolves to the `exec_id` when the batch + executes. """ batches_to_execute: list[ tuple[list[Action], CommandMode | None, str | None, list[QueuedExecution]] diff --git a/pyoverkiz/auth/strategies.py b/pyoverkiz/auth/strategies.py index 761ba49c..6be1c68e 100644 --- a/pyoverkiz/auth/strategies.py +++ b/pyoverkiz/auth/strategies.py @@ -15,6 +15,7 @@ from aiohttp import ClientSession, FormData from botocore.client import BaseClient from botocore.config import Config +from botocore.exceptions import ClientError from warrant_lite import WarrantLite from pyoverkiz.auth.base import AuthContext, AuthStrategy @@ -212,17 +213,6 @@ async def _request_access_token( class CozytouchAuthStrategy(SessionLoginStrategy): """Authentication strategy using Cozytouch session-based login.""" - def __init__( - self, - credentials: UsernamePasswordCredentials, - session: ClientSession, - server: ServerConfig, - ssl_context: ssl.SSLContext | bool, - api_type: APIType, - ) -> None: - """Initialize CozytouchAuthStrategy with given parameters.""" - super().__init__(credentials, session, server, ssl_context, api_type) - async def login(self) -> None: """Perform login using Cozytouch username and password.""" form = FormData( @@ -265,20 +255,9 @@ async def login(self) -> None: class NexityAuthStrategy(SessionLoginStrategy): """Authentication strategy using Nexity session-based login.""" - def __init__( - self, - credentials: UsernamePasswordCredentials, - session: ClientSession, - server: ServerConfig, - ssl_context: ssl.SSLContext | bool, - api_type: APIType, - ) -> None: - """Initialize NexityAuthStrategy with given parameters.""" - super().__init__(credentials, session, server, ssl_context, api_type) - async def login(self) -> None: """Perform login using Nexity username and password.""" - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() def _client() -> BaseClient: return boto3.client( @@ -296,8 +275,11 @@ def _client() -> BaseClient: try: tokens = await loop.run_in_executor(None, aws.authenticate_user) - except Exception as error: - raise NexityBadCredentialsException() from error + except ClientError as error: + code = error.response.get("Error", {}).get("Code") + if code in {"NotAuthorizedException", "UserNotFoundException"}: + raise NexityBadCredentialsException() from error + raise id_token = tokens["AuthenticationResult"]["IdToken"] diff --git a/pyoverkiz/client.py b/pyoverkiz/client.py index fcbabc15..e40704be 100644 --- a/pyoverkiz/client.py +++ b/pyoverkiz/client.py @@ -3,10 +3,10 @@ from __future__ import annotations import logging -import os import ssl import urllib.parse from json import JSONDecodeError +from pathlib import Path from types import TracebackType from typing import Any, cast @@ -145,7 +145,7 @@ def _create_local_ssl_context() -> ssl.SSLContext: because it will load certificates from disk and do other blocking I/O. """ context = ssl.create_default_context( - cafile=os.path.dirname(os.path.realpath(__file__)) + "/overkiz-root-ca-2048.crt" + cafile=str(Path(__file__).resolve().parent / "overkiz-root-ca-2048.crt") ) # Disable strict validation introduced in Python 3.13, which doesn't work with @@ -278,7 +278,7 @@ async def close(self) -> None: async def login( self, - register_event_listener: bool | None = True, + register_event_listener: bool = True, ) -> bool: """Authenticate and create an API session allowing access to the other operations. @@ -321,7 +321,7 @@ async def get_setup(self, refresh: bool = False) -> Setup: if self.setup and not refresh: return self.setup - response = await self.__get("setup") + response = await self._get("setup") setup = Setup(**humps.decamelize(response)) @@ -343,7 +343,7 @@ async def get_diagnostic_data(self) -> JSON: This data will be masked to not return any confidential or PII data. """ - response = await self.__get("setup") + response = await self._get("setup") return obfuscate_sensitive_data(response) @@ -356,7 +356,7 @@ async def get_devices(self, refresh: bool = False) -> list[Device]: if self.devices and not refresh: return self.devices - response = await self.__get("setup/devices") + response = await self._get("setup/devices") devices = [Device(**d) for d in humps.decamelize(response)] # Cache response @@ -375,7 +375,7 @@ async def get_gateways(self, refresh: bool = False) -> list[Gateway]: if self.gateways and not refresh: return self.gateways - response = await self.__get("setup/gateways") + response = await self._get("setup/gateways") gateways = [Gateway(**g) for g in humps.decamelize(response)] # Cache response @@ -388,7 +388,7 @@ async def get_gateways(self, refresh: bool = False) -> list[Gateway]: @retry_on_auth_error async def get_execution_history(self) -> list[HistoryExecution]: """List execution history.""" - response = await self.__get("history/executions") + response = await self._get("history/executions") execution_history = [HistoryExecution(**h) for h in humps.decamelize(response)] return execution_history @@ -396,7 +396,7 @@ async def get_execution_history(self) -> list[HistoryExecution]: @retry_on_auth_error async def get_device_definition(self, deviceurl: str) -> JSON | None: """Retrieve a particular setup device definition.""" - response: dict = await self.__get( + response: dict = await self._get( f"setup/devices/{urllib.parse.quote_plus(deviceurl)}" ) @@ -405,7 +405,7 @@ async def get_device_definition(self, deviceurl: str) -> JSON | None: @retry_on_auth_error async def get_state(self, deviceurl: str) -> list[State]: """Retrieve states of requested device.""" - response = await self.__get( + response = await self._get( f"setup/devices/{urllib.parse.quote_plus(deviceurl)}/states" ) state = [State(**s) for s in humps.decamelize(response)] @@ -415,12 +415,12 @@ async def get_state(self, deviceurl: str) -> list[State]: @retry_on_auth_error async def refresh_states(self) -> None: """Ask the box to refresh all devices states for protocols supporting that operation.""" - await self.__post("setup/devices/states/refresh") + await self._post("setup/devices/states/refresh") @retry_on_auth_error async def refresh_device_states(self, deviceurl: str) -> None: """Ask the box to refresh all states of the given device for protocols supporting that operation.""" - await self.__post( + await self._post( f"setup/devices/{urllib.parse.quote_plus(deviceurl)}/states/refresh" ) @@ -435,7 +435,7 @@ async def register_event_listener(self) -> str: timeout : listening sessions are expected to call the /events/{listenerId}/fetch API on a regular basis. """ - response = await self.__post("events/register") + response = await self._post("events/register") listener_id = cast(str, response.get("id")) self.event_listener_id = listener_id @@ -453,7 +453,7 @@ async def fetch_events(self) -> list[Event]: operation (polling). """ await self._refresh_token_if_expired() - response = await self.__post(f"events/{self.event_listener_id}/fetch") + response = await self._post(f"events/{self.event_listener_id}/fetch") events = [Event(**e) for e in humps.decamelize(response)] return events @@ -464,13 +464,13 @@ async def unregister_event_listener(self) -> None: API response status is always 200, even on unknown listener ids. """ await self._refresh_token_if_expired() - await self.__post(f"events/{self.event_listener_id}/unregister") + await self._post(f"events/{self.event_listener_id}/unregister") self.event_listener_id = None @retry_on_auth_error async def get_current_execution(self, exec_id: str) -> Execution: """Get an action group execution currently running.""" - response = await self.__get(f"exec/current/{exec_id}") + response = await self._get(f"exec/current/{exec_id}") execution = Execution(**humps.decamelize(response)) return execution @@ -478,7 +478,7 @@ async def get_current_execution(self, exec_id: str) -> Execution: @retry_on_auth_error async def get_current_executions(self) -> list[Execution]: """Get all action groups executions currently running.""" - response = await self.__get("exec/current") + response = await self._get("exec/current") executions = [Execution(**e) for e in humps.decamelize(response)] return executions @@ -486,7 +486,7 @@ async def get_current_executions(self) -> list[Execution]: @retry_on_auth_error async def get_api_version(self) -> str: """Get the API version (local only).""" - response = await self.__get("apiVersion") + response = await self._get("apiVersion") return cast(str, response["protocolVersion"]) @@ -519,7 +519,7 @@ async def _execute_action_group_direct( else: url = "exec/apply" - response: dict = await self.__post(url, final_payload) + response: dict = await self._post(url, final_payload) return cast(str, response["execId"]) @@ -544,15 +544,19 @@ async def execute_action_group( The API is consistent regardless of queue configuration - always returns exec_id string directly. - :param actions: List of actions to execute - :param mode: Command mode (GEOLOCATED, INTERNAL, HIGH_PRIORITY, or None) - :param label: Label for the action group - :return: exec_id string from the executed action group + Args: + actions: List of actions to execute. + mode: Command mode (`GEOLOCATED`, `INTERNAL`, `HIGH_PRIORITY`, + or `None`). + label: Label for the action group. - Example usage:: + Returns: + The `exec_id` string from the executed action group. - # Works the same with or without queue + Example: + ```python exec_id = await client.execute_action_group([action]) + ``` """ if self._action_queue: queued = await self._action_queue.add(actions, mode, label) @@ -582,12 +586,12 @@ def get_pending_actions_count(self) -> int: @retry_on_auth_error async def cancel_command(self, exec_id: str) -> None: """Cancel a running setup-level execution.""" - await self.__delete(f"/exec/current/setup/{exec_id}") + await self._delete(f"exec/current/setup/{exec_id}") @retry_on_auth_error async def get_action_groups(self) -> list[ActionGroup]: """List the action groups (scenarios).""" - response = await self.__get("actionGroups") + response = await self._get("actionGroups") return [ ActionGroup(**action_group) for action_group in humps.decamelize(response) ] @@ -595,20 +599,20 @@ async def get_action_groups(self) -> list[ActionGroup]: @retry_on_auth_error async def get_places(self) -> Place: """List the places.""" - response = await self.__get("setup/places") + response = await self._get("setup/places") places = Place(**humps.decamelize(response)) return places @retry_on_auth_error async def execute_scenario(self, oid: str) -> str: """Execute a scenario.""" - response = await self.__post(f"exec/{oid}") + response = await self._post(f"exec/{oid}") return cast(str, response["execId"]) @retry_on_auth_error async def execute_scheduled_scenario(self, oid: str, timestamp: int) -> str: """Execute a scheduled scenario.""" - response = await self.__post(f"exec/schedule/{oid}/{timestamp}") + response = await self._post(f"exec/schedule/{oid}/{timestamp}") return cast(str, response["triggerId"]) @retry_on_auth_error @@ -618,7 +622,7 @@ async def get_setup_options(self) -> list[Option]: Per-session rate-limit : 1 calls per 1d period for this particular operation (bulk-load) Access scope : Full enduser API access (enduser/*). """ - response = await self.__get("setup/options") + response = await self._get("setup/options") options = [Option(**o) for o in humps.decamelize(response)] return options @@ -629,7 +633,7 @@ async def get_setup_option(self, option: str) -> Option | None: For example `developerMode-{gateway_id}` to understand if developer mode is on. """ - response = await self.__get(f"setup/options/{option}") + response = await self._get(f"setup/options/{option}") if response: return Option(**humps.decamelize(response)) @@ -647,7 +651,7 @@ async def get_setup_option_parameter( If the option is not available, an OverkizException will be thrown. If the parameter is not available you will receive None. """ - response = await self.__get(f"setup/options/{option}/{parameter}") + response = await self._get(f"setup/options/{option}/{parameter}") if response: return OptionParameter(**humps.decamelize(response)) @@ -657,19 +661,19 @@ async def get_setup_option_parameter( @retry_on_auth_error async def get_reference_controllable(self, controllable_name: str) -> JSON: """Get a controllable definition.""" - return await self.__get( + return await self._get( f"reference/controllable/{urllib.parse.quote_plus(controllable_name)}" ) @retry_on_auth_error async def get_reference_controllable_types(self) -> JSON: """Get details about all supported controllable types.""" - return await self.__get("reference/controllableTypes") + return await self._get("reference/controllableTypes") @retry_on_auth_error async def search_reference_devices_model(self, payload: JSON) -> JSON: """Search reference device models using a POST payload.""" - return await self.__post("reference/devices/search", payload) + return await self._post("reference/devices/search", payload) @retry_on_auth_error async def get_reference_protocol_types(self) -> list[ProtocolType]: @@ -681,23 +685,23 @@ async def get_reference_protocol_types(self) -> list[ProtocolType]: - name: Internal protocol name - label: Human-readable protocol label """ - response = await self.__get("reference/protocolTypes") + response = await self._get("reference/protocolTypes") return [ProtocolType(**protocol) for protocol in response] @retry_on_auth_error async def get_reference_timezones(self) -> JSON: """Get timezones list.""" - return await self.__get("reference/timezones") + return await self._get("reference/timezones") @retry_on_auth_error async def get_reference_ui_classes(self) -> list[str]: """Get a list of all defined UI classes.""" - return await self.__get("reference/ui/classes") + return await self._get("reference/ui/classes") @retry_on_auth_error async def get_reference_ui_classifiers(self) -> list[str]: """Get a list of all defined UI classifiers.""" - return await self.__get("reference/ui/classifiers") + return await self._get("reference/ui/classifiers") @retry_on_auth_error async def get_reference_ui_profile(self, profile_name: str) -> UIProfileDefinition: @@ -709,7 +713,7 @@ async def get_reference_ui_profile(self, profile_name: str) -> UIProfileDefiniti - states: Available states with value types and descriptions - form_factor: Whether profile is tied to a specific physical device type """ - response = await self.__get( + response = await self._get( f"reference/ui/profile/{urllib.parse.quote_plus(profile_name)}" ) return UIProfileDefinition(**humps.decamelize(response)) @@ -717,14 +721,14 @@ async def get_reference_ui_profile(self, profile_name: str) -> UIProfileDefiniti @retry_on_auth_error async def get_reference_ui_profile_names(self) -> list[str]: """Get a list of all defined UI profiles (and form-factor variants).""" - return await self.__get("reference/ui/profileNames") + return await self._get("reference/ui/profileNames") @retry_on_auth_error async def get_reference_ui_widgets(self) -> list[str]: """Get a list of all defined UI widgets.""" - return await self.__get("reference/ui/widgets") + return await self._get("reference/ui/widgets") - async def __get(self, path: str) -> Any: + async def _get(self, path: str) -> Any: """Make a GET request to the OverKiz API.""" await self._refresh_token_if_expired() headers = dict(self._auth.auth_headers(path)) @@ -737,7 +741,7 @@ async def __get(self, path: str) -> Any: await self.check_response(response) return await response.json() - async def __post( + async def _post( self, path: str, payload: JSON | None = None, data: JSON | None = None ) -> Any: """Make a POST request to the OverKiz API.""" @@ -754,7 +758,7 @@ async def __post( await self.check_response(response) return await response.json() - async def __delete(self, path: str) -> None: + async def _delete(self, path: str) -> None: """Make a DELETE request to the OverKiz API.""" await self._refresh_token_if_expired() headers = dict(self._auth.auth_headers(path)) diff --git a/pyoverkiz/enums/base.py b/pyoverkiz/enums/base.py index 9204194b..83c075bc 100644 --- a/pyoverkiz/enums/base.py +++ b/pyoverkiz/enums/base.py @@ -15,6 +15,18 @@ class UnknownEnumMixin: __missing_message__ = "Unsupported value %s has been returned for %s" + def __init_subclass__(cls, **kwargs: object) -> None: + """Validate that concrete enum subclasses define an `UNKNOWN` member.""" + super().__init_subclass__(**kwargs) + + # _member_map_ is only present on concrete Enum subclasses. + member_map: dict[str, object] | None = getattr(cls, "_member_map_", None) + if member_map is not None and "UNKNOWN" not in member_map: + raise TypeError( + f"{cls.__name__} uses UnknownEnumMixin but does not define " + f"an UNKNOWN member" + ) + @classmethod def _missing_(cls, value: object) -> Self: # type: ignore[override] """Return `UNKNOWN` and log unrecognized values. diff --git a/pyoverkiz/models.py b/pyoverkiz/models.py index dbe791aa..b6220b2e 100644 --- a/pyoverkiz/models.py +++ b/pyoverkiz/models.py @@ -380,6 +380,10 @@ def __init__( self.qualified_name = qualified_name elif name: self.qualified_name = name + else: + raise ValueError( + "StateDefinition requires either `name` or `qualified_name`." + ) @define(init=False, kw_only=True) diff --git a/tests/test_auth.py b/tests/test_auth.py index f7b274bb..e21ec30f 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -8,10 +8,11 @@ import base64 import datetime import json -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from aiohttp import ClientSession +from botocore.exceptions import ClientError from pyoverkiz.auth.base import AuthContext from pyoverkiz.auth.credentials import ( @@ -37,7 +38,7 @@ _decode_jwt_payload, ) from pyoverkiz.enums import APIType, Server -from pyoverkiz.exceptions import InvalidTokenException +from pyoverkiz.exceptions import InvalidTokenException, NexityBadCredentialsException from pyoverkiz.models import ServerConfig @@ -487,6 +488,74 @@ def test_auth_headers_with_token(self): assert headers == {"Authorization": "Bearer my_bearer_token"} +class TestNexityAuthStrategy: + """Tests for Nexity auth error mapping behavior.""" + + @pytest.mark.asyncio + async def test_login_maps_invalid_credentials_client_error(self): + """Map Cognito bad-credential errors to NexityBadCredentialsException.""" + server_config = ServerConfig( + server=Server.NEXITY, + name="Nexity", + endpoint="https://api.nexity.com", + manufacturer="Nexity", + type=APIType.CLOUD, + ) + credentials = UsernamePasswordCredentials("user", "pass") + session = AsyncMock(spec=ClientSession) + + bad_credentials_error = ClientError( + error_response={"Error": {"Code": "NotAuthorizedException"}}, + operation_name="InitiateAuth", + ) + warrant_instance = MagicMock() + warrant_instance.authenticate_user.side_effect = bad_credentials_error + + with ( + patch("pyoverkiz.auth.strategies.boto3.client", return_value=MagicMock()), + patch( + "pyoverkiz.auth.strategies.WarrantLite", return_value=warrant_instance + ), + pytest.raises(NexityBadCredentialsException), + ): + strategy = NexityAuthStrategy( + credentials, session, server_config, True, APIType.CLOUD + ) + await strategy.login() + + @pytest.mark.asyncio + async def test_login_propagates_non_auth_client_error(self): + """Propagate non-auth Cognito errors to preserve failure context.""" + server_config = ServerConfig( + server=Server.NEXITY, + name="Nexity", + endpoint="https://api.nexity.com", + manufacturer="Nexity", + type=APIType.CLOUD, + ) + credentials = UsernamePasswordCredentials("user", "pass") + session = AsyncMock(spec=ClientSession) + + service_error = ClientError( + error_response={"Error": {"Code": "InternalErrorException"}}, + operation_name="InitiateAuth", + ) + warrant_instance = MagicMock() + warrant_instance.authenticate_user.side_effect = service_error + + with ( + patch("pyoverkiz.auth.strategies.boto3.client", return_value=MagicMock()), + patch( + "pyoverkiz.auth.strategies.WarrantLite", return_value=warrant_instance + ), + pytest.raises(ClientError, match="InternalErrorException"), + ): + strategy = NexityAuthStrategy( + credentials, session, server_config, True, APIType.CLOUD + ) + await strategy.login() + + class TestRexelAuthStrategy: """Tests for Rexel auth specifics.""" diff --git a/tests/test_client.py b/tests/test_client.py index 8697cf5a..6c68f6a5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -116,7 +116,7 @@ async def test_backoff_relogin_on_auth_error(self, client: OverkizClient): patch("backoff._async.asyncio.sleep", new=AsyncMock()) as sleep_mock, patch.object( OverkizClient, - "_OverkizClient__get", + "_get", new=AsyncMock( side_effect=[ exceptions.NotAuthenticatedException("expired"), @@ -144,7 +144,7 @@ async def test_backoff_refresh_listener_on_listener_error( patch("backoff._async.asyncio.sleep", new=AsyncMock()) as sleep_mock, patch.object( OverkizClient, - "_OverkizClient__post", + "_post", new=AsyncMock( side_effect=[ exceptions.InvalidEventListenerIdException("bad listener"), @@ -169,7 +169,7 @@ async def test_backoff_retries_on_concurrent_requests( patch("backoff._async.asyncio.sleep", new=AsyncMock()) as sleep_mock, patch.object( OverkizClient, - "_OverkizClient__post", + "_post", new=AsyncMock( side_effect=[ exceptions.TooManyConcurrentRequestsException("busy"), diff --git a/tests/test_client_queue_integration.py b/tests/test_client_queue_integration.py index 01df9d12..667a8858 100644 --- a/tests/test_client_queue_integration.py +++ b/tests/test_client_queue_integration.py @@ -27,9 +27,7 @@ async def test_client_without_queue_executes_immediately(): ) # Mock the internal execution - with patch.object( - client, "_OverkizClient__post", new_callable=AsyncMock - ) as mock_post: + with patch.object(client, "_post", new_callable=AsyncMock) as mock_post: mock_post.return_value = {"execId": "exec-123"} result = await client.execute_action_group([action]) @@ -61,9 +59,7 @@ async def test_client_with_queue_batches_actions(): for i in range(3) ] - with patch.object( - client, "_OverkizClient__post", new_callable=AsyncMock - ) as mock_post: + with patch.object(client, "_post", new_callable=AsyncMock) as mock_post: mock_post.return_value = {"execId": "exec-batched"} # Queue multiple actions quickly - start them as tasks to allow batching @@ -110,9 +106,7 @@ async def test_client_manual_flush(): commands=[Command(name=OverkizCommand.CLOSE)], ) - with patch.object( - client, "_OverkizClient__post", new_callable=AsyncMock - ) as mock_post: + with patch.object(client, "_post", new_callable=AsyncMock) as mock_post: mock_post.return_value = {"execId": "exec-flushed"} # Start execution as a task to allow checking pending count @@ -152,9 +146,7 @@ async def test_client_close_flushes_queue(): commands=[Command(name=OverkizCommand.CLOSE)], ) - with patch.object( - client, "_OverkizClient__post", new_callable=AsyncMock - ) as mock_post: + with patch.object(client, "_post", new_callable=AsyncMock) as mock_post: mock_post.return_value = {"execId": "exec-closed"} # Start execution as a task @@ -193,9 +185,7 @@ async def test_client_queue_respects_max_actions(): for i in range(3) ] - with patch.object( - client, "_OverkizClient__post", new_callable=AsyncMock - ) as mock_post: + with patch.object(client, "_post", new_callable=AsyncMock) as mock_post: mock_post.return_value = {"execId": "exec-123"} # Add 2 actions as tasks to trigger flush diff --git a/tests/test_models.py b/tests/test_models.py index 811bed2d..8fc78a99 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,6 +11,7 @@ Definition, Device, State, + StateDefinition, States, ) @@ -495,6 +496,18 @@ def test_has_state_definition_empty_states(self): assert not definition.has_state_definition(["core:ClosureState"]) +class TestStateDefinition: + """Tests for StateDefinition initialization behavior.""" + + def test_requires_name_or_qualified_name(self): + """StateDefinition should reject payloads with neither identifier field.""" + with pytest.raises( + ValueError, + match=r"StateDefinition requires either `name` or `qualified_name`\.", + ): + StateDefinition() + + class TestState: """Unit tests for State value accessors and type validation."""