diff --git a/mindinsight/wizard/conf/constants.py b/mindinsight/wizard/conf/constants.py index 460d616ebbdda382dd263d0c1a797e7b83fc8cf5..64fa7f32788f54add59a8cda4ce5bb3169436d40 100644 --- a/mindinsight/wizard/conf/constants.py +++ b/mindinsight/wizard/conf/constants.py @@ -18,7 +18,5 @@ import os TEMPLATES_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates') SUPPORT_MINDSPORE_VERSION = '0.7.0' -SUPPORT_RUN_DRIVER_VERSION = 'C75' -SUPPORT_CUDA_VERSION = '10.1' QUESTION_START = '>>> ' diff --git a/mindinsight/wizard/network/generic_network.py b/mindinsight/wizard/network/generic_network.py index d1af378d72429c4b105ae3c68e19b0c615fc1294..05f41849f3d979342224d664d134663ae69dbabc 100644 --- a/mindinsight/wizard/network/generic_network.py +++ b/mindinsight/wizard/network/generic_network.py @@ -56,22 +56,21 @@ class GenericNetwork(BaseNetwork): dict, configuration value to network. """ if settings: - config = {'loss': settings['loss'], - 'optimizer': settings['optimizer'], - 'dataset': settings['dataset']} - self.settings.update(config) - return config - loss = self.ask_loss_function() - optimizer = self.ask_optimizer() - dataset = self.ask_dataset() - self._dataset_maker = load_dataset_maker(dataset) + config = dict(settings) + dataset_name = settings['dataset'] + self._dataset_maker = load_dataset_maker(dataset_name) + else: + loss = self.ask_loss_function() + optimizer = self.ask_optimizer() + dataset_name = self.ask_dataset() + self._dataset_maker = load_dataset_maker(dataset_name) + dataset_config = self._dataset_maker.configure() + + config = {'loss': loss, + 'optimizer': optimizer, + 'dataset': dataset_name} + config.update(dataset_config) self._dataset_maker.set_network(self) - dataset_config = self._dataset_maker.configure() - - config = {'loss': loss, - 'optimizer': optimizer, - 'dataset': dataset} - config.update(dataset_config) self.settings.update(config) return config diff --git a/tests/ut/wizard/__init__.py b/tests/ut/wizard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0803281393c406dbe96fd6a96818072d7c6d1622 --- /dev/null +++ b/tests/ut/wizard/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test the wizard module.""" diff --git a/tests/ut/wizard/test_create_project.py b/tests/ut/wizard/test_create_project.py new file mode 100644 index 0000000000000000000000000000000000000000..b727817e1d566c64bb0cf63b27f71e5adeb1dbd5 --- /dev/null +++ b/tests/ut/wizard/test_create_project.py @@ -0,0 +1,74 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test CreateProject class.""" +import os +import shutil +import tempfile +from unittest.mock import patch + +from mindinsight.wizard.base.source_file import SourceFile +from mindinsight.wizard.create_project import CreateProject +from mindinsight.wizard.network.generic_network import GenericNetwork +from tests.ut.wizard.utils import generate_file + + +class TestCreateProject: + """Test SourceFile""" + workspace_dir = None + + def setup_method(self): + """Setup before call test method.""" + self.workspace_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Tear down after call test method.""" + self._remove_dirs() + self.workspace_dir = None + + def _remove_dirs(self): + """Recursively delete a directory tree.""" + if self.workspace_dir and os.path.exists(self.workspace_dir): + shutil.rmtree(self.workspace_dir) + + @staticmethod + def _generate_file(file): + """Create a file and write content.""" + generate_file(file, "template file.") + + @patch.object(GenericNetwork, 'generate') + @patch.object(GenericNetwork, 'configure') + @patch.object(CreateProject, 'ask_network') + @patch.object(CreateProject, 'echo_notice') + @patch('os.getcwd') + def test_run(self, mock_getcwd, mock_echo_notice, mock_ask_network, mock_config, mock_generate): + """Test run method of CreateProject.""" + source_file = SourceFile() + source_file.template_file_path = os.path.join(self.workspace_dir, 'templates', 'train.py-tpl') + source_file.file_relative_path = 'train.py' + self._generate_file(source_file.template_file_path) + + # mock os.getcwd method + mock_getcwd.return_value = self.workspace_dir + mock_echo_notice.return_value = None + mock_ask_network.return_value = 'lenet' + mock_config.return_value = None + mock_generate.return_value = [source_file] + + project_name = 'test' + new_project = CreateProject() + new_project.run({'name': project_name}) + + assert os.path.exists(os.path.join(self.workspace_dir, project_name)) + assert os.access(os.path.join(self.workspace_dir, project_name, 'train.py'), mode=os.F_OK | os.R_OK | os.W_OK) diff --git a/tests/ut/wizard/test_generic_network.py b/tests/ut/wizard/test_generic_network.py new file mode 100644 index 0000000000000000000000000000000000000000..9e91f64bfa40ea0dc5beb86d79e57266c7035265 --- /dev/null +++ b/tests/ut/wizard/test_generic_network.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test GenericNetwork class.""" +import os + +import pytest + +from mindinsight.wizard.network import lenet + + +class TestGenericNetwork: + """Test SourceFile""" + + def test_generate_scripts(self): + """Test network object to generate network scripts""" + network_inst = lenet.Network() + network_inst.configure({ + "loss": "SoftmaxCrossEntropyWithLogits", + "optimizer": "Momentum", + "dataset": "mnist"}) + sources_files = network_inst.generate() + dataset_source_file = None + config_source_file = None + shell_script_dir_files = [] + out_files = [] + for sources_file in sources_files: + if sources_file.file_relative_path == 'src/dataset.py': + dataset_source_file = sources_file + elif sources_file.file_relative_path == 'src/config.py': + config_source_file = sources_file + elif sources_file.file_relative_path.startswith('scripts'): + shell_script_dir_files.append(sources_file) + elif not os.path.dirname(sources_file.file_relative_path): + out_files.append(sources_file) + else: + pass + + assert sources_files + assert dataset_source_file is not None + assert config_source_file is not None + assert shell_script_dir_files + assert out_files + + def test_config(self): + """Test network object to config.""" + network_inst = lenet.Network() + settings = { + "loss": "SoftmaxCrossEntropyWithLogits", + "optimizer": "Momentum", + "dataset": "mnist"} + configurations = network_inst.configure(settings) + assert configurations["dataset"] == settings["dataset"] + assert configurations["loss"] == settings["loss"] + assert configurations["optimizer"] == settings["optimizer"] + + settings["dataset"] = "mnist_another" + with pytest.raises(ModuleNotFoundError) as exec_info: + network_inst.configure(settings) + assert exec_info.value.name == f'mindinsight.wizard.dataset.{settings["dataset"]}' diff --git a/tests/ut/wizard/test_source_file.py b/tests/ut/wizard/test_source_file.py new file mode 100644 index 0000000000000000000000000000000000000000..9899ceab29c4231ff67f95240cfd82690e1ef1c1 --- /dev/null +++ b/tests/ut/wizard/test_source_file.py @@ -0,0 +1,98 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test SourceFile class.""" +import os +import shutil +import stat +import tempfile + +import pytest + +from mindinsight.wizard.base.source_file import SourceFile +from tests.ut.wizard.utils import generate_file + + +class TestSourceFile: + """Test SourceFile""" + + def setup_method(self): + """Setup before call test method.""" + self._input_dir = tempfile.mkdtemp() + self._output_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Tear down after call test method.""" + self._remove_dirs() + self._input_dir = None + self._output_dir = None + + def _remove_dirs(self): + """Recursively delete a directory tree.""" + for temp_dir in [self._input_dir, self._output_dir]: + if temp_dir and os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + @staticmethod + def _generate_file(file, stat_mode): + """Create a file and write content.""" + generate_file(file, "template file.", stat_mode) + + @pytest.mark.parametrize('params', [{ + 'file_relative_path': 'src/config.py', + 'template_file_path': 'src/config.py-tpl' + }, { + 'file_relative_path': 'src/lenet.py', + 'template_file_path': 'src/lenet.py-tpl' + }, { + 'file_relative_path': 'README.md', + 'template_file_path': 'README.md-tpl' + }, { + 'file_relative_path': 'train.py', + 'template_file_path': 'train.py-tpl' + }]) + def test_write_py(self, params): + """Test write python script file""" + source_file = SourceFile() + source_file.file_relative_path = params['file_relative_path'] + source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path']) + self._generate_file(source_file.template_file_path, stat.S_IRUSR) + + # start write + source_file.write(self._output_dir) + + output_file_path = os.path.join(self._output_dir, source_file.file_relative_path) + assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK) + assert stat.filemode(os.stat(output_file_path).st_mode) == '-rw-------' + + @pytest.mark.parametrize('params', [{ + 'file_relative_path': 'scripts/run_eval.sh', + 'template_file_path': 'scripts/run_eval.sh-tpl' + }, { + 'file_relative_path': 'run_distribute_train.sh', + 'template_file_path': 'run_distribute_train.sh-tpl' + }]) + def test_write_sh(self, params): + """Test write shell script file""" + source_file = SourceFile() + source_file.file_relative_path = params['file_relative_path'] + source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path']) + self._generate_file(source_file.template_file_path, stat.S_IRUSR) + + # start write + source_file.write(self._output_dir) + + output_file_path = os.path.join(self._output_dir, source_file.file_relative_path) + assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK | os.X_OK) + assert stat.filemode(os.stat(output_file_path).st_mode) == '-rwx------' diff --git a/tests/ut/wizard/test_templates.py b/tests/ut/wizard/test_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6143fb8a6d74934334f3281f9375f7891a5af0 --- /dev/null +++ b/tests/ut/wizard/test_templates.py @@ -0,0 +1,188 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test TemplateManager class.""" +import os +import shutil +import tempfile +import textwrap + +from mindinsight.wizard.base.templates import TemplateManager +from tests.ut.wizard.utils import generate_file + + +def create_template_files(template_dir): + """Create network template files.""" + all_template_files = [] + train_file = os.path.join(template_dir, 'train.py-tpl') + generate_file(train_file, + textwrap.dedent("""\ + {% if loss=='SoftmaxCrossEntropyWithLogits' %} + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + {% elif loss=='SoftmaxCrossEntropyExpand' %} + net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True) + {% endif %} + """)) + all_template_files.append(train_file) + + os.mkdir(os.path.join(template_dir, 'src')) + config_file = os.path.join(template_dir, 'src', 'config.py-tpl') + generate_file(config_file, + textwrap.dedent("""\ + { + 'num_classes': 10, + {% if optimizer=='Momentum' %} + 'lr': 0.01, + "momentum": 0.9, + {% elif optimizer=='SGD' %} + 'lr': 0.1, + {% else %} + 'lr': 0.001, + {% endif %} + 'epoch_size': 1 + } + """)) + all_template_files.append(config_file) + + os.mkdir(os.path.join(template_dir, 'scripts')) + run_standalone_train_file = os.path.join(template_dir, 'scripts', 'run_standalone_train.sh-tpl') + generate_file(run_standalone_train_file, + textwrap.dedent("""\ + python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & + """)) + all_template_files.append(run_standalone_train_file) + + os.mkdir(os.path.join(template_dir, 'dataset')) + os.mkdir(os.path.join(template_dir, 'dataset', 'mnist')) + dataset_file = os.path.join(template_dir, 'dataset', 'mnist', 'dataset.py-tpl') + generate_file(dataset_file, + textwrap.dedent("""\ + import mindspore.dataset as ds + import mindspore.dataset.transforms.vision.c_transforms as CV + """)) + all_template_files.append(dataset_file) + return all_template_files + + +class TestTemplateManager: + """Test TemplateManager""" + template_dir = None + all_template_files = [] + + def setup_method(self): + """Setup before call test method.""" + self.template_dir = tempfile.mkdtemp() + self.all_template_files = create_template_files(self.template_dir) + + def teardown_method(self): + """Tear down after call test method.""" + self._remove_dirs() + self.template_dir = None + + def _remove_dirs(self): + """Recursively delete a directory tree.""" + if self.template_dir and os.path.exists(self.template_dir): + shutil.rmtree(self.template_dir) + + def test_template_files(self): + """Test get_template_files method.""" + src_file_num = 1 + dataset_file_num = 1 + template_mgr = TemplateManager(self.template_dir) + all_files = template_mgr.get_template_files() + assert set(all_files) == set(self.all_template_files) + + template_mgr = TemplateManager(os.path.join(self.template_dir, 'src')) + all_files = template_mgr.get_template_files() + assert len(all_files) == src_file_num + + template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset')) + all_files = template_mgr.get_template_files() + assert len(all_files) == dataset_file_num + + template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src']) + all_files = template_mgr.get_template_files() + assert len(all_files) == len(self.all_template_files) - src_file_num + + template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset']) + all_files = template_mgr.get_template_files() + assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num + + template_mgr = TemplateManager(self.template_dir, + exclude_dirs=['src', 'dataset'], + exclude_files=['train.py-tpl']) + all_files = template_mgr.get_template_files() + assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num - 1 + + def test_src_render(self): + """Test render file in src directory.""" + template_mgr = TemplateManager(os.path.join(self.template_dir, 'src')) + source_files = template_mgr.render(optimizer='Momentum') + assert source_files[0].content == textwrap.dedent("""\ + { + 'num_classes': 10, + 'lr': 0.01, + "momentum": 0.9, + 'epoch_size': 1 + } + """) + + source_files = template_mgr.render(optimizer='SGD') + assert source_files[0].content == textwrap.dedent("""\ + { + 'num_classes': 10, + 'lr': 0.1, + 'epoch_size': 1 + } + """) + source_files = template_mgr.render() + assert source_files[0].content == textwrap.dedent("""\ + { + 'num_classes': 10, + 'lr': 0.001, + 'epoch_size': 1 + } + """) + + def test_dataset_render(self): + """Test render file in dataset directory.""" + template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset')) + source_files = template_mgr.render() + assert source_files[0].content == textwrap.dedent("""\ + import mindspore.dataset as ds + import mindspore.dataset.transforms.vision.c_transforms as CV + """) + assert source_files[0].file_relative_path == 'mnist/dataset.py' + assert source_files[0].template_file_path == os.path.join(self.template_dir, 'dataset', 'mnist/dataset.py-tpl') + + def test_assemble_render(self): + """Test render assemble files in template directory.""" + template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset']) + source_files = template_mgr.render(loss='SoftmaxCrossEntropyWithLogits') + unmatched_files = [] + for source_file in source_files: + if source_file.template_file_path == os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'): + assert source_file.content == textwrap.dedent("""\ + python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & + """) + assert source_file.file_relative_path == 'scripts/run_standalone_train.sh' + elif source_file.template_file_path == os.path.join(self.template_dir, 'train.py-tpl'): + assert source_file.content == textwrap.dedent("""\ + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + """) + assert source_file.file_relative_path == 'train.py' + else: + unmatched_files.append(source_file) + + assert not unmatched_files diff --git a/tests/ut/wizard/utils.py b/tests/ut/wizard/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f56c0e236838682287e8cc26dd9e6befa5fe754 --- /dev/null +++ b/tests/ut/wizard/utils.py @@ -0,0 +1,28 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils method.""" +import os +import stat + + +def generate_file(file, template_content, mode=None): + """Create a file and write content.""" + os.makedirs(os.path.dirname(file), mode=stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR, exist_ok=True) + with open(file, 'w') as fp: + fp.write(template_content) + if mode: + os.chmod(file, mode) + else: + os.chmod(file, stat.S_IRUSR)