Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/zoo/src/pyearthtools/zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class MODEL():

Models = utils.AvailableModels()

from pyearthtools.zoo.commands import commands # pylint: disable=C0413 # noqa: E402
from pyearthtools.zoo import commands # pylint: disable=C0413 # noqa: E402
from pyearthtools.zoo.predict import ( # pylint: disable=C0413 # noqa: E402
data,
interactive,
Expand Down
6 changes: 3 additions & 3 deletions packages/zoo/src/pyearthtools/zoo/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
`pyearthtools.zoo` commands.
"""

from pyearthtools.zoo.commands import utils, commands
from pyearthtools.zoo.commands.commands import entry_point
# CLI entry point for 'pet' as defined in pyproject.toml for zoo
from ._commands import entry_point

__all__ = ["utils", "commands", "entry_point"]
__all__ = ["entry_point"]
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
import click

import pyearthtools.zoo
from pyearthtools.zoo.commands import utils as command_utils

from . import _cmd_utils as command_utils
from pyearthtools.zoo.utils import AvailableModels

available_models = AvailableModels()
Expand Down Expand Up @@ -62,6 +63,10 @@ def entry_point(debug, info):

@entry_point.command("models", help="Print available models.")
def models():
list_models()


def list_models():
from pyearthtools.zoo.register import dynamic_import

dynamic_import()
Expand All @@ -73,6 +78,7 @@ def models():

print(_models)
print("(Specify category with a '/' seperation.)")

sys.exit(0)


Expand Down Expand Up @@ -116,7 +122,11 @@ def models():
default=None,
help="Override for config path",
)
def run_predict(
def run_predict():
cmd_run_predict()


def cmd_run_predict(
ctx,
model: str,
time: str,
Expand All @@ -127,9 +137,6 @@ def run_predict(
):
ctx_kwargs = command_utils.get_keyword_from_ctx(ctx)

if "data" in ctx_kwargs:
raise ValueError("data has been deprecated as an argument for `predict`, use `data`.")

predictions = pyearthtools.zoo.predict(
model,
time,
Expand Down Expand Up @@ -200,7 +207,11 @@ def run_predict(
default=None,
help="Override for config path",
)
def interactive(
def interactive():
cmd_interactive()


def cmd_interactive(
ctx,
model: str,
time: str,
Expand Down Expand Up @@ -277,6 +288,10 @@ def interactive(
help="Override for config path",
)
def data(ctx, model, time, pipeline, data_cache, config_path: str | Path):
cmd_data(ctx, model, time, pipeline, data_cache, config_path)


def cmd_data(ctx, model, time, pipeline, data_cache, config_path: str | Path):
ctx_kwargs = command_utils.get_keyword_from_ctx(ctx)
_ = pyearthtools.zoo.data(model, time, pipeline, data_cache, config_path=config_path, **ctx_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion packages/zoo/src/pyearthtools/zoo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import warnings

import pyearthtools.zoo
from pyearthtools.zoo.commands import utils as command_utils
from pyearthtools.zoo.commands import _cmd_utils as command_utils
from pyearthtools.zoo import utils, exceptions
from pyearthtools.zoo.register import dynamic_import

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import pytest

from pyearthtools.zoo import commands

tests = [
(("--test", "value"), {"test": "value"}), # Single argument
(
Expand All @@ -37,28 +39,24 @@

@pytest.mark.parametrize("args, result", tests)
def test_parse_args_to_dict(args, result):
from pyearthtools.zoo.commands import utils

assert utils.parse_args_to_dict(*args) == result
assert commands._cmd_utils.parse_args_to_dict(*args) == result


@pytest.mark.parametrize("args, result", tests)
def test_get_keyword_from_ctx(args, result):
from pyearthtools.zoo.commands import utils
import click

class FakeContext(click.Context):
def __init__(self, args):
self.args = args

assert utils.get_keyword_from_ctx(FakeContext(args)) == result
assert commands._cmd_utils.get_keyword_from_ctx(FakeContext(args)) == result


@pytest.mark.parametrize("args, result", tests)
def test_parse_str_to_dict(args, result):
from pyearthtools.zoo.commands import utils

assert utils.parse_str_to_dict(" ".join(args)) == result
assert commands._cmd_utils.parse_str_to_dict(" ".join(args)) == result


@pytest.mark.parametrize(
Expand All @@ -72,7 +70,6 @@ def test_parse_str_to_dict(args, result):
],
)
def test_parse_args_to_dict_fail(args):
from pyearthtools.zoo.commands import utils

with pytest.raises(KeyError):
utils.parse_args_to_dict(*args)
commands._cmd_utils.parse_args_to_dict(*args)
60 changes: 60 additions & 0 deletions packages/zoo/tests/commands/test_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from collections import namedtuple
import pytest
from pyearthtools.zoo.commands import _commands as cmd
from unittest.mock import patch

from pyearthtools.zoo.exceptions import ModelException

# def test_entry_point():

# with pytest.raises(SystemExit):
# ep = cmd.entry_point(None, None)


def test_models():

with pytest.raises(SystemExit):
_m = cmd.list_models()


def test_run_predict():

ctx = namedtuple("Any", ["args", "kwargs"])([], [])
model = "nonexistent"
time = "2020T010100"
output = "tbc"
pipeline_name = "tbc"
data_cache = "tbc"
config_path = "tbc"

with pytest.raises(ModelException):
cmd.cmd_run_predict(ctx, model, time, output, pipeline_name, data_cache, config_path)


def test_interactive():

ctx = namedtuple("Any", ["args", "kwargs"])([], [])
model = "nonexistent"
time = "2020T010100"
output = "tbc"
pipeline_name = "tbc"
data_cache = "tbc"
config_path = "tbc"

with pytest.raises(AttributeError):
with patch("pyearthtools.zoo.available_models", return_value="fake_model"):
cmd.cmd_interactive(ctx, model, time, pipeline_name, output, data_cache, config_path)


def test_data():

ctx = namedtuple("Any", ["args", "kwargs"])([], [])
model = "nonexistent"
time = "2020T010100"
_output = "tbc"
pipeline_name = "tbc"
data_cache = "tbc"
config_path = "tbc"

with pytest.raises(ModelException):
cmd.cmd_data(ctx, model, time, pipeline_name, data_cache, config_path)
5 changes: 5 additions & 0 deletions packages/zoo/tests/test_zoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pyearthtools import zoo


def test_available_models():
_models = zoo.available_models() # smoke test for now