Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 58 additions & 22 deletions tests/test_openml/test_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,29 @@
import pytest

import openml
from openml.config import ConfigurationForExamples
import openml.testing
from openml._api_calls import _download_minio_bucket, API_TOKEN_HELP_LINK


class TestConfig(openml.testing.TestBase):
@unittest.mock.patch("requests.Session")
def test_too_long_uri(self, Session_class_mock):
    """An HTTP 414 reply must surface as ``OpenMLServerError``.

    ``requests.Session`` is patched, so the response below is the only one
    the client ever sees; the test asserts the error message is propagated
    and that exactly one GET was issued.
    """
    # Canned "URI too long" reply handed back by the patched session.
    response_mock = unittest.mock.Mock()
    response_mock.status_code = 414
    response_mock.text = "URI too long!"
    response_mock.headers = {}

    # The stub is attached to the object yielded by ``with Session() as s:``.
    session_instance = Session_class_mock.return_value.__enter__.return_value
    session_instance.get.return_value = response_mock

    with pytest.raises(openml.exceptions.OpenMLServerError, match="URI too long!"):
        openml.datasets.list_datasets(data_id=list(range(10000)))

    # A single call means the 414 was not retried.
    session_instance.get.assert_called_once()

@unittest.mock.patch("time.sleep")
@unittest.mock.patch("requests.Session")
@pytest.mark.test_server()
def test_retry_on_database_error(self, Session_class_mock, _):
response_mock = unittest.mock.Mock()
response_mock.text = (
Expand All @@ -33,17 +42,19 @@ def test_retry_on_database_error(self, Session_class_mock, _):
"Please wait for N seconds and try again.</oml:message>\n"
"</oml:error>"
)
Session_class_mock.return_value.__enter__.return_value.get.return_value = response_mock

session_instance = Session_class_mock.return_value.__enter__.return_value
session_instance.get.return_value = response_mock

with pytest.raises(openml.exceptions.OpenMLServerException, match="/abc returned code 107"):
openml._api_calls._send_request("get", "/abc", {})

assert Session_class_mock.return_value.__enter__.return_value.get.call_count == 20
assert session_instance.get.call_count == 20


class FakeObject(NamedTuple):
    """Minimal stand-in for an object entry listed from a Minio bucket."""

    object_name: str
    # We use the etag of a Minio object as the name of a marker if we already
    # downloaded it.
    etag: str


class FakeMinio:
Expand All @@ -65,26 +76,30 @@ def test_download_all_files_observes_cache(mock_minio, tmp_path: Path) -> None:
some_prefix, some_filename = "some/prefix", "dataset.arff"
some_object_path = f"{some_prefix}/{some_filename}"
some_url = f"https://not.real.com/bucket/{some_object_path}"

mock_minio.return_value = FakeMinio(
objects=[
FakeObject(object_name=some_object_path, etag=str(hash(some_object_path))),
],
)

file_path = tmp_path / some_filename

_download_minio_bucket(source=some_url, destination=tmp_path)
time_created = (tmp_path / "dataset.arff").stat().st_ctime
mtime_first = file_path.stat().st_mtime

_download_minio_bucket(source=some_url, destination=tmp_path)
time_modified = (tmp_path / some_filename).stat().st_mtime
mtime_second = file_path.stat().st_mtime

assert time_created == time_modified
assert mtime_first == mtime_second


@mock.patch.object(minio, "Minio")
def test_download_minio_failure(mock_minio, tmp_path: Path) -> None:
some_prefix, some_filename = "some/prefix", "dataset.arff"
some_object_path = f"{some_prefix}/{some_filename}"
some_url = f"https://not.real.com/bucket/{some_object_path}"

mock_minio.return_value = FakeMinio(
objects=[
FakeObject(object_name=None, etag="tmp"),
Expand All @@ -104,25 +119,46 @@ def test_download_minio_failure(mock_minio, tmp_path: Path) -> None:
_download_minio_bucket(source=some_url, destination=tmp_path)


@unittest.mock.patch("requests.Session")
@pytest.mark.parametrize(
    "endpoint, method",
    [
        # Trailing numbers are the server-side error codes, see
        # https://github.com/openml/OpenML/blob/develop/openml_OS/views/pages/api_new/v1/xml/pre.php
        ("flow/exists", "post"),  # 102
        ("dataset", "post"),  # 137
        ("dataset/42", "delete"),  # 350
        # NOTE: 310 ("flow/owned", "post") is not covered - couldn't find
        # what would trigger it.
        ("flow/42", "delete"),  # 320
        ("run/42", "delete"),  # 400
        ("task/42", "delete"),  # 460
    ],
)
def test_authentication_endpoints_requiring_api_key_show_relevant_help_link(
    Session_class_mock,
    endpoint: str,
    method: str,
) -> None:
    """Calls made without an API key must fail with a pointer to the token help page.

    ``requests.Session`` is patched and answers with a canned authentication
    error document, so each parametrized endpoint/method pair is exercised
    without contacting a server.
    """
    # Canned server reply: an OpenML error document.
    response_mock = unittest.mock.Mock()
    response_mock.status_code = 200
    response_mock.headers = {}
    response_mock.text = (
        "<oml:error>"
        "<oml:code>401</oml:code>"
        "<oml:message>Authentication required</oml:message>"
        "</oml:error>"
    )

    # NOTE(review): the other tests in this file stub ``session.get``; confirm
    # that post/delete calls are dispatched through ``session.request`` and
    # that ``OpenMLNotAuthorizedError`` is the class raised for this reply.
    session_instance = Session_class_mock.return_value.__enter__.return_value
    session_instance.request.return_value = response_mock

    # We need to temporarily disable the API key to test the error message
    with openml.config.overwrite_config_context({"apikey": None}):
        with pytest.raises(
            openml.exceptions.OpenMLNotAuthorizedError,
            match=API_TOKEN_HELP_LINK,
        ):
            openml._api_calls._perform_api_call(
                call=endpoint,
                request_method=method,
                data=None,
            )

    session_instance.request.assert_called()