提交 71a7240b 编写于 作者: Z zhangyunshu

add ut/st for histogram api

上级 b476d7d4
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test histograms restful api.
Usage:
pytest tests/st/func/datavisual
"""
import pytest
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/histograms'
class TestHistograms:
"""Test Histograms."""
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.usefixtures("init_summary_logs")
def test_histograms(self, client):
"""Test getting histogram data."""
plugin_name = PluginNameEnum.HISTOGRAM.value
train_id = gbl.get_train_ids()[0]
tag_name = gbl.get_tags(train_id, plugin_name)[0]
expected_histograms = gbl.get_metadata(train_id, tag_name)
params = dict(train_id=train_id, tag=tag_name)
url = get_url(BASE_URL, params)
response = client.get(url)
histograms = response.get_json().get("histograms")
for histograms, expected_histograms in zip(histograms, expected_histograms):
assert histograms.get("wall_time") == expected_histograms.get("wall_time")
assert histograms.get("step") == expected_histograms.get("step")
......@@ -54,5 +54,6 @@ TRAIN_ROUTES = dict(
graph_single_node='/v1/mindinsight/datavisual/graphs/single-node',
image_metadata='/v1/mindinsight/datavisual/image/metadata',
image_single_image='/v1/mindinsight/datavisual/image/single-image',
scalar_metadata='/v1/mindinsight/datavisual/scalar/metadata'
scalar_metadata='/v1/mindinsight/datavisual/scalar/metadata',
histograms='/v1/mindinsight/datavisual/histograms'
)
......@@ -26,6 +26,7 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
......@@ -432,3 +433,42 @@ class TestTrainVisual:
assert response.status_code == 200
results = response.get_json()
assert results == test_name
def test_histograms_with_params_miss(self, client):
"""Parsing missing params to get histogram data."""
params = dict()
url = get_url(TRAIN_ROUTES['histograms'], params)
response = client.get(url)
results = response.get_json()
assert response.status_code == 400
assert results['error_code'] == '50540003'
assert results['error_msg'] == "Param missing. 'train_id' is required."
train_id = "aa"
params = dict(train_id=train_id)
url = get_url(TRAIN_ROUTES['histograms'], params)
response = client.get(url)
results = response.get_json()
assert response.status_code == 400
assert results['error_code'] == '50540003'
assert results['error_msg'] == "Param missing. 'tag' is required."
@patch.object(HistogramProcessor, 'get_histograms')
def test_histograms_success(self, mock_histogram_processor, client):
"""Parsing available params to get histogram data."""
test_train_id = "aa"
test_tag = "bb"
expect_resp = {
'histograms': [{'buckets': [[1, 2, 3]]}],
'train_id': test_train_id,
'tag': test_tag
}
get_histograms = Mock(return_value=expect_resp)
mock_histogram_processor.side_effect = get_histograms
params = dict(train_id=test_train_id, tag=test_tag)
url = get_url(TRAIN_ROUTES['histograms'], params)
response = client.get(url)
assert response.status_code == 200
results = response.get_json()
assert results == expect_resp
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test histogram processor.
Usage:
pytest tests/ut/datavisual
"""
import tempfile
from unittest.mock import Mock
import pytest
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import HistogramNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
from mindinsight.datavisual.utils import crc32
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestHistogramProcessor:
"""Test histogram processor api."""
_steps_list = [1, 3, 5]
_tag_name = 'tag_name'
_plugin_name = 'histogram'
_complete_tag_name = f'{_tag_name}/{_plugin_name}'
_temp_path = None
_histograms = None
_mock_data_manager = None
_train_id = None
_generated_path = []
@classmethod
def setup_class(cls):
"""Mock common environment for histograms unittest."""
crc32.CheckValueAgainstData = Mock(return_value=True)
data_manager.logger = MockLogger
def teardown_class(self):
"""Delete temp files."""
delete_files_or_dirs(self._generated_path)
@pytest.fixture(scope='function')
def load_histogram_record(self):
"""Load histogram record."""
summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".")
log_operation = LogOperations()
self._temp_path, self._histograms, _ = log_operation.generate_log(
PluginNameEnum.HISTOGRAM.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5)
@pytest.mark.usefixtures('load_histogram_record')
def test_get_histograms_with_not_exist_id(self):
"""Get histogram data with not exist id."""
test_train_id = 'not_exist_id'
processor = HistogramProcessor(self._mock_data_manager)
with pytest.raises(TrainJobNotExistError) as exc_info:
processor.get_histograms(test_train_id, self._tag_name)
assert exc_info.value.error_code == '50545005'
assert exc_info.value.message == "Train job is not exist. Detail: Can not find the given train job in cache."
@pytest.mark.usefixtures('load_histogram_record')
def test_get_histograms_with_not_exist_tag(self):
"""Get histogram data with not exist tag."""
test_tag_name = 'not_exist_tag_name'
processor = HistogramProcessor(self._mock_data_manager)
with pytest.raises(HistogramNotExistError) as exc_info:
processor.get_histograms(self._train_id, test_tag_name)
assert exc_info.value.error_code == '5054500F'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
@pytest.mark.usefixtures('load_histogram_record')
def test_get_histograms_success(self):
"""Get histogram data success."""
test_tag_name = self._complete_tag_name
processor = HistogramProcessor(self._mock_data_manager)
results = processor.get_histograms(self._train_id, test_tag_name)
recv_metadata = results.get('histograms')
for recv_values, expected_values in zip(recv_metadata, self._histograms):
assert recv_values.get('wall_time') == expected_values.get('wall_time')
assert recv_values.get('step') == expected_values.get('step')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册