diff --git a/tests/test_tasks/test_clustering_task.py b/tests/test_tasks/test_clustering_task.py index 29f5663c4..b36dcd2f6 100644 --- a/tests/test_tasks/test_clustering_task.py +++ b/tests/test_tasks/test_clustering_task.py @@ -1,6 +1,8 @@ # License: BSD 3-Clause from __future__ import annotations +from unittest import mock + import pytest import openml @@ -20,51 +22,52 @@ def setUp(self, n_levels: int = 1): self.task_type = TaskType.CLUSTERING self.estimation_procedure = 17 - @pytest.mark.production_server() - def test_get_dataset(self): - # no clustering tasks on test server - self.use_production_server() + @mock.patch("openml.tasks.get_task") + @mock.patch("openml.datasets.get_dataset") + def test_get_dataset(self, mock_get_dataset, mock_get_task): + mock_task = mock.Mock() + mock_task.dataset_id = 36 + mock_task.get_dataset = lambda: openml.datasets.get_dataset(mock_task.dataset_id) + mock_get_task.return_value = mock_task + task = openml.tasks.get_task(self.task_id) task.get_dataset() - @pytest.mark.production_server() - @pytest.mark.test_server() - def test_download_task(self): - # no clustering tasks on test server - self.use_production_server() + mock_get_task.assert_called_once_with(self.task_id) + mock_get_dataset.assert_called_once_with(mock_task.dataset_id) + @mock.patch("tests.test_tasks.test_task.get_task") + def test_download_task(self, mock_get_task): + mock_task = mock.Mock() + mock_task.task_id = self.task_id + mock_task.task_type_id = TaskType.CLUSTERING + mock_task.dataset_id = 36 + mock_get_task.return_value = mock_task + task = super().test_download_task() assert task.task_id == self.task_id assert task.task_type_id == TaskType.CLUSTERING assert task.dataset_id == 36 + mock_get_task.assert_called_with(self.task_id) + + @mock.patch("openml.tasks.create_task") + def test_upload_task(self, mock_create_task): + mock_task = mock.Mock() + mock_task.publish.return_value = mock_task + mock_task.id = 1 + mock_create_task.return_value = mock_task + + # Triggering the test + task = openml.tasks.create_task( + task_type=self.task_type, + dataset_id=36, + estimation_procedure_id=self.estimation_procedure, + ) + task = task.publish() - @pytest.mark.test_server() - def test_upload_task(self): - compatible_datasets = self._get_compatible_rand_dataset() - for i in range(100): - try: - dataset_id = compatible_datasets[i % len(compatible_datasets)] - # Upload a clustering task without a ground truth. - task = openml.tasks.create_task( - task_type=self.task_type, - dataset_id=dataset_id, - estimation_procedure_id=self.estimation_procedure, - ) - task = task.publish() - TestBase._mark_entity_for_removal("task", task.id) - TestBase.logger.info( - f"collected from {__file__.split('/')[-1]}: {task.id}", - ) - # success - break - except OpenMLServerException as e: - # Error code for 'task already exists' - # Should be 533 according to the docs - # (# https://www.openml.org/api_docs#!/task/post_task) - if e.code == 614: - continue - else: - raise e - else: - raise ValueError( - f"Could not create a valid task for task type ID {self.task_type}", - ) + assert task.id == 1 + mock_create_task.assert_called_once_with( + task_type=self.task_type, + dataset_id=36, + estimation_procedure_id=self.estimation_procedure, + ) + mock_task.publish.assert_called_once()