提交 ead3598a 编写于 作者: W wuzhuanke

add wizard ut

上级 3f69eec0
......@@ -18,7 +18,5 @@ import os
TEMPLATES_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
SUPPORT_MINDSPORE_VERSION = '0.7.0'
SUPPORT_RUN_DRIVER_VERSION = 'C75'
SUPPORT_CUDA_VERSION = '10.1'
QUESTION_START = '>>> '
......@@ -56,22 +56,21 @@ class GenericNetwork(BaseNetwork):
dict, configuration value to network.
"""
if settings:
config = {'loss': settings['loss'],
'optimizer': settings['optimizer'],
'dataset': settings['dataset']}
self.settings.update(config)
return config
loss = self.ask_loss_function()
optimizer = self.ask_optimizer()
dataset = self.ask_dataset()
self._dataset_maker = load_dataset_maker(dataset)
config = dict(settings)
dataset_name = settings['dataset']
self._dataset_maker = load_dataset_maker(dataset_name)
else:
loss = self.ask_loss_function()
optimizer = self.ask_optimizer()
dataset_name = self.ask_dataset()
self._dataset_maker = load_dataset_maker(dataset_name)
dataset_config = self._dataset_maker.configure()
config = {'loss': loss,
'optimizer': optimizer,
'dataset': dataset_name}
config.update(dataset_config)
self._dataset_maker.set_network(self)
dataset_config = self._dataset_maker.configure()
config = {'loss': loss,
'optimizer': optimizer,
'dataset': dataset}
config.update(dataset_config)
self.settings.update(config)
return config
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the wizard module."""
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test CreateProject class."""
import os
import shutil
import tempfile
from unittest.mock import patch
from mindinsight.wizard.base.source_file import SourceFile
from mindinsight.wizard.create_project import CreateProject
from mindinsight.wizard.network.generic_network import GenericNetwork
from tests.ut.wizard.utils import generate_file
class TestCreateProject:
"""Test SourceFile"""
workspace_dir = None
def setup_method(self):
"""Setup before call test method."""
self.workspace_dir = tempfile.mkdtemp()
def teardown_method(self):
"""Tear down after call test method."""
self._remove_dirs()
self.workspace_dir = None
def _remove_dirs(self):
"""Recursively delete a directory tree."""
if self.workspace_dir and os.path.exists(self.workspace_dir):
shutil.rmtree(self.workspace_dir)
@staticmethod
def _generate_file(file):
"""Create a file and write content."""
generate_file(file, "template file.")
@patch.object(GenericNetwork, 'generate')
@patch.object(GenericNetwork, 'configure')
@patch.object(CreateProject, 'ask_network')
@patch.object(CreateProject, 'echo_notice')
@patch('os.getcwd')
def test_run(self, mock_getcwd, mock_echo_notice, mock_ask_network, mock_config, mock_generate):
"""Test run method of CreateProject."""
source_file = SourceFile()
source_file.template_file_path = os.path.join(self.workspace_dir, 'templates', 'train.py-tpl')
source_file.file_relative_path = 'train.py'
self._generate_file(source_file.template_file_path)
# mock os.getcwd method
mock_getcwd.return_value = self.workspace_dir
mock_echo_notice.return_value = None
mock_ask_network.return_value = 'lenet'
mock_config.return_value = None
mock_generate.return_value = [source_file]
project_name = 'test'
new_project = CreateProject()
new_project.run({'name': project_name})
assert os.path.exists(os.path.join(self.workspace_dir, project_name))
assert os.access(os.path.join(self.workspace_dir, project_name, 'train.py'), mode=os.F_OK | os.R_OK | os.W_OK)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test GenericNetwork class."""
import os
import pytest
from mindinsight.wizard.network import lenet
class TestGenericNetwork:
"""Test SourceFile"""
def test_generate_scripts(self):
"""Test network object to generate network scripts"""
network_inst = lenet.Network()
network_inst.configure({
"loss": "SoftmaxCrossEntropyWithLogits",
"optimizer": "Momentum",
"dataset": "mnist"})
sources_files = network_inst.generate()
dataset_source_file = None
config_source_file = None
shell_script_dir_files = []
out_files = []
for sources_file in sources_files:
if sources_file.file_relative_path == 'src/dataset.py':
dataset_source_file = sources_file
elif sources_file.file_relative_path == 'src/config.py':
config_source_file = sources_file
elif sources_file.file_relative_path.startswith('scripts'):
shell_script_dir_files.append(sources_file)
elif not os.path.dirname(sources_file.file_relative_path):
out_files.append(sources_file)
else:
pass
assert sources_files
assert dataset_source_file is not None
assert config_source_file is not None
assert shell_script_dir_files
assert out_files
def test_config(self):
"""Test network object to config."""
network_inst = lenet.Network()
settings = {
"loss": "SoftmaxCrossEntropyWithLogits",
"optimizer": "Momentum",
"dataset": "mnist"}
configurations = network_inst.configure(settings)
assert configurations["dataset"] == settings["dataset"]
assert configurations["loss"] == settings["loss"]
assert configurations["optimizer"] == settings["optimizer"]
settings["dataset"] = "mnist_another"
with pytest.raises(ModuleNotFoundError) as exec_info:
network_inst.configure(settings)
assert exec_info.value.name == f'mindinsight.wizard.dataset.{settings["dataset"]}'
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test SourceFile class."""
import os
import shutil
import stat
import tempfile
import pytest
from mindinsight.wizard.base.source_file import SourceFile
from tests.ut.wizard.utils import generate_file
class TestSourceFile:
"""Test SourceFile"""
def setup_method(self):
"""Setup before call test method."""
self._input_dir = tempfile.mkdtemp()
self._output_dir = tempfile.mkdtemp()
def teardown_method(self):
"""Tear down after call test method."""
self._remove_dirs()
self._input_dir = None
self._output_dir = None
def _remove_dirs(self):
"""Recursively delete a directory tree."""
for temp_dir in [self._input_dir, self._output_dir]:
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
@staticmethod
def _generate_file(file, stat_mode):
"""Create a file and write content."""
generate_file(file, "template file.", stat_mode)
@pytest.mark.parametrize('params', [{
'file_relative_path': 'src/config.py',
'template_file_path': 'src/config.py-tpl'
}, {
'file_relative_path': 'src/lenet.py',
'template_file_path': 'src/lenet.py-tpl'
}, {
'file_relative_path': 'README.md',
'template_file_path': 'README.md-tpl'
}, {
'file_relative_path': 'train.py',
'template_file_path': 'train.py-tpl'
}])
def test_write_py(self, params):
"""Test write python script file"""
source_file = SourceFile()
source_file.file_relative_path = params['file_relative_path']
source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path'])
self._generate_file(source_file.template_file_path, stat.S_IRUSR)
# start write
source_file.write(self._output_dir)
output_file_path = os.path.join(self._output_dir, source_file.file_relative_path)
assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK)
assert stat.filemode(os.stat(output_file_path).st_mode) == '-rw-------'
@pytest.mark.parametrize('params', [{
'file_relative_path': 'scripts/run_eval.sh',
'template_file_path': 'scripts/run_eval.sh-tpl'
}, {
'file_relative_path': 'run_distribute_train.sh',
'template_file_path': 'run_distribute_train.sh-tpl'
}])
def test_write_sh(self, params):
"""Test write shell script file"""
source_file = SourceFile()
source_file.file_relative_path = params['file_relative_path']
source_file.template_file_path = os.path.join(self._input_dir, params['template_file_path'])
self._generate_file(source_file.template_file_path, stat.S_IRUSR)
# start write
source_file.write(self._output_dir)
output_file_path = os.path.join(self._output_dir, source_file.file_relative_path)
assert os.access(output_file_path, os.F_OK | os.R_OK | os.W_OK | os.X_OK)
assert stat.filemode(os.stat(output_file_path).st_mode) == '-rwx------'
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test TemplateManager class."""
import os
import shutil
import tempfile
import textwrap
from mindinsight.wizard.base.templates import TemplateManager
from tests.ut.wizard.utils import generate_file
def create_template_files(template_dir):
"""Create network template files."""
all_template_files = []
train_file = os.path.join(template_dir, 'train.py-tpl')
generate_file(train_file,
textwrap.dedent("""\
{% if loss=='SoftmaxCrossEntropyWithLogits' %}
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
{% elif loss=='SoftmaxCrossEntropyExpand' %}
net_loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
{% endif %}
"""))
all_template_files.append(train_file)
os.mkdir(os.path.join(template_dir, 'src'))
config_file = os.path.join(template_dir, 'src', 'config.py-tpl')
generate_file(config_file,
textwrap.dedent("""\
{
'num_classes': 10,
{% if optimizer=='Momentum' %}
'lr': 0.01,
"momentum": 0.9,
{% elif optimizer=='SGD' %}
'lr': 0.1,
{% else %}
'lr': 0.001,
{% endif %}
'epoch_size': 1
}
"""))
all_template_files.append(config_file)
os.mkdir(os.path.join(template_dir, 'scripts'))
run_standalone_train_file = os.path.join(template_dir, 'scripts', 'run_standalone_train.sh-tpl')
generate_file(run_standalone_train_file,
textwrap.dedent("""\
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
"""))
all_template_files.append(run_standalone_train_file)
os.mkdir(os.path.join(template_dir, 'dataset'))
os.mkdir(os.path.join(template_dir, 'dataset', 'mnist'))
dataset_file = os.path.join(template_dir, 'dataset', 'mnist', 'dataset.py-tpl')
generate_file(dataset_file,
textwrap.dedent("""\
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
"""))
all_template_files.append(dataset_file)
return all_template_files
class TestTemplateManager:
"""Test TemplateManager"""
template_dir = None
all_template_files = []
def setup_method(self):
"""Setup before call test method."""
self.template_dir = tempfile.mkdtemp()
self.all_template_files = create_template_files(self.template_dir)
def teardown_method(self):
"""Tear down after call test method."""
self._remove_dirs()
self.template_dir = None
def _remove_dirs(self):
"""Recursively delete a directory tree."""
if self.template_dir and os.path.exists(self.template_dir):
shutil.rmtree(self.template_dir)
def test_template_files(self):
"""Test get_template_files method."""
src_file_num = 1
dataset_file_num = 1
template_mgr = TemplateManager(self.template_dir)
all_files = template_mgr.get_template_files()
assert set(all_files) == set(self.all_template_files)
template_mgr = TemplateManager(os.path.join(self.template_dir, 'src'))
all_files = template_mgr.get_template_files()
assert len(all_files) == src_file_num
template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset'))
all_files = template_mgr.get_template_files()
assert len(all_files) == dataset_file_num
template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src'])
all_files = template_mgr.get_template_files()
assert len(all_files) == len(self.all_template_files) - src_file_num
template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset'])
all_files = template_mgr.get_template_files()
assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num
template_mgr = TemplateManager(self.template_dir,
exclude_dirs=['src', 'dataset'],
exclude_files=['train.py-tpl'])
all_files = template_mgr.get_template_files()
assert len(all_files) == len(self.all_template_files) - src_file_num - dataset_file_num - 1
def test_src_render(self):
"""Test render file in src directory."""
template_mgr = TemplateManager(os.path.join(self.template_dir, 'src'))
source_files = template_mgr.render(optimizer='Momentum')
assert source_files[0].content == textwrap.dedent("""\
{
'num_classes': 10,
'lr': 0.01,
"momentum": 0.9,
'epoch_size': 1
}
""")
source_files = template_mgr.render(optimizer='SGD')
assert source_files[0].content == textwrap.dedent("""\
{
'num_classes': 10,
'lr': 0.1,
'epoch_size': 1
}
""")
source_files = template_mgr.render()
assert source_files[0].content == textwrap.dedent("""\
{
'num_classes': 10,
'lr': 0.001,
'epoch_size': 1
}
""")
def test_dataset_render(self):
"""Test render file in dataset directory."""
template_mgr = TemplateManager(os.path.join(self.template_dir, 'dataset'))
source_files = template_mgr.render()
assert source_files[0].content == textwrap.dedent("""\
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
""")
assert source_files[0].file_relative_path == 'mnist/dataset.py'
assert source_files[0].template_file_path == os.path.join(self.template_dir, 'dataset', 'mnist/dataset.py-tpl')
def test_assemble_render(self):
"""Test render assemble files in template directory."""
template_mgr = TemplateManager(self.template_dir, exclude_dirs=['src', 'dataset'])
source_files = template_mgr.render(loss='SoftmaxCrossEntropyWithLogits')
unmatched_files = []
for source_file in source_files:
if source_file.template_file_path == os.path.join(self.template_dir, 'scripts/run_standalone_train.sh-tpl'):
assert source_file.content == textwrap.dedent("""\
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
""")
assert source_file.file_relative_path == 'scripts/run_standalone_train.sh'
elif source_file.template_file_path == os.path.join(self.template_dir, 'train.py-tpl'):
assert source_file.content == textwrap.dedent("""\
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
""")
assert source_file.file_relative_path == 'train.py'
else:
unmatched_files.append(source_file)
assert not unmatched_files
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils method."""
import os
import stat
def generate_file(file, template_content, mode=None):
"""Create a file and write content."""
os.makedirs(os.path.dirname(file), mode=stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR, exist_ok=True)
with open(file, 'w') as fp:
fp.write(template_content)
if mode:
os.chmod(file, mode)
else:
os.chmod(file, stat.S_IRUSR)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册