提交 f2586441 编写于 作者: P ph

Merge branch 'master' of gitee.com:mindspore/mindinsight into phdev

<!-- Thanks for sending a pull request! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md
-->
**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> /kind bug
> /kind task
> /kind feature
**What does this PR do / why do we need it**:
**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #
**Special notes for your reviewers**:
---
name: RFC
about: Use this template for the new feature or enhancement
labels: kind/feature or kind/enhancement
---
## Background
- Describe the status of the problem you wish to solve
- Attach the relevant issue if have
## Introduction
- Describe the general solution, design and/or pseudo-code
## Trail
| No. | Task Description | Related Issue(URL) |
| --- | ---------------- | ------------------ |
| 1 | | |
| 2 | | |
---
name: Bug Report
about: Use this template for reporting a bug
labels: kind/bug
---
<!-- Thanks for sending an issue! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md
-->
## Environment
### Hardware Environment(`Ascend`/`GPU`/`CPU`):
> Uncomment only one ` /device <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> `/device ascend`</br>
> `/device gpu`</br>
> `/device cpu`</br>
### Software Environment:
- **MindSpore version (source or binary)**:
- **Python version (e.g., Python 3.7.5)**:
- **OS platform and distribution (e.g., Linux Ubuntu 16.04)**:
- **GCC/Compiler version (if compiled from source)**:
## Describe the current behavior
## Describe the expected behavior
## Steps to reproduce the issue
1.
2.
3.
## Related log / screenshot
## Special notes for this issue
---
name: Task
about: Use this template for task tracking
labels: kind/task
---
## Task Description
## Task Goal
## Sub Task
| No. | Task Description | Issue ID |
| --- | ---------------- | -------- |
| 1 | | |
| 2 | | |
<!-- Thanks for sending a pull request! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md
-->
**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> `/kind bug`</br>
> `/kind task`</br>
> `/kind feature`</br>
**What does this PR do / why do we need it**:
**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #
**Special notes for your reviewers**:
......@@ -18,19 +18,21 @@ This file is used to define the basic graph.
import copy
import time
from enum import Enum
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions
from .node import NodeTypeEnum
from .node import Node
class EdgeTypeEnum:
class EdgeTypeEnum(Enum):
"""Node edge type enum."""
control = 'control'
data = 'data'
CONTROL = 'control'
DATA = 'data'
class DataTypeEnum:
class DataTypeEnum(Enum):
"""Data type enum."""
DT_TENSOR = 13
......@@ -292,69 +294,64 @@ class Graph:
output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_output({dst_name: output_attr})
def _calc_polymeric_input_output(self):
def _update_polymeric_input_output(self):
"""Calc polymeric input and output after build polymeric node."""
for name, node in self._normal_nodes.items():
for node in self._normal_nodes.values():
polymeric_input = self._calc_polymeric_attr(node, 'input')
node.update_polymeric_input(polymeric_input)
polymeric_output = self._calc_polymeric_attr(node, 'output')
node.update_polymeric_output(polymeric_output)
for name, node in self._polymeric_nodes.items():
polymeric_input = {}
for src_name in node.input:
src_node = self._polymeric_nodes.get(src_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
src_name = src_name if not src_node else src_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not src_node:
continue
if not node.name_scope and src_node.name_scope:
# if current node is in first layer, and the src node is not in
# the first layer, the src node will not be the polymeric input of current node.
continue
if node.name_scope == src_node.name_scope \
or node.name_scope.startswith(src_node.name_scope):
polymeric_input.update(
{src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
node.update_polymeric_input(polymeric_input)
polymeric_output = {}
for dst_name in node.output:
dst_node = self._polymeric_nodes.get(dst_name)
polymeric_output = {}
output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
node.update_polymeric_output(polymeric_output)
def _calc_polymeric_attr(self, node, attr):
"""
Calc polymeric input or polymeric output after build polymeric node.
Args:
node (Node): Computes the polymeric input for a given node.
attr (str): The polymeric attr, optional value is `input` or `output`.
Returns:
dict, return polymeric input or polymeric output of the given node.
"""
polymeric_attr = {}
for node_name in getattr(node, attr):
polymeric_node = self._polymeric_nodes.get(node_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}})
node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name
dummy_node_name = self._calc_dummy_node_name(node.name, node_name)
polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}})
continue
if not dst_node:
if not polymeric_node:
continue
if not node.name_scope and dst_node.name_scope:
if not node.name_scope and polymeric_node.name_scope:
# If current node is in top-level layer, and the polymeric_node node is not in
# the top-level layer, the polymeric node will not be the polymeric input
# or polymeric output of current node.
continue
if node.name_scope == dst_node.name_scope \
or node.name_scope.startswith(dst_node.name_scope):
polymeric_output.update(
{dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_output(polymeric_output)
for name, node in self._polymeric_nodes.items():
polymeric_input = {}
for src_name in node.input:
output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_input(polymeric_input)
if node.name_scope == polymeric_node.name_scope \
or node.name_scope.startswith(polymeric_node.name_scope + '/'):
polymeric_attr.update(
{polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}})
polymeric_output = {}
for dst_name in node.output:
polymeric_output = {}
output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_output(polymeric_output)
return polymeric_attr
def _calc_dummy_node_name(self, current_node_name, other_node_name):
"""
......
......@@ -39,7 +39,7 @@ class MSGraph(Graph):
self._build_leaf_nodes(graph_proto)
self._build_polymeric_nodes()
self._build_name_scope_nodes()
self._calc_polymeric_input_output()
self._update_polymeric_input_output()
logger.info("Build graph end, normal node count: %s, polymeric node "
"count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))
......@@ -90,9 +90,9 @@ class MSGraph(Graph):
node_name = leaf_node_id_map_name[node_def.name]
node = self._leaf_nodes[node_name]
for input_def in node_def.input:
edge_type = EdgeTypeEnum.data
edge_type = EdgeTypeEnum.DATA.value
if input_def.type == "CONTROL_EDGE":
edge_type = EdgeTypeEnum.control
edge_type = EdgeTypeEnum.CONTROL.value
if const_nodes_map.get(input_def.name):
const_node = copy.deepcopy(const_nodes_map[input_def.name])
......@@ -218,7 +218,7 @@ class MSGraph(Graph):
node = Node(name=const.key, node_id=const_node_id)
node.node_type = NodeTypeEnum.CONST.value
node.update_attr({const.key: str(const.value)})
if const.value.dtype == DataTypeEnum.DT_TENSOR:
if const.value.dtype == DataTypeEnum.DT_TENSOR.value:
shape = []
for dim in const.value.tensor_val.dims:
shape.append(dim)
......
......@@ -172,7 +172,7 @@ class Node:
Args:
polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name:
{'edge_type': EdgeTypeEnum.data}}).
{'edge_type': EdgeTypeEnum.DATA.value}}).
"""
self._polymeric_output.update(polymeric_output)
......
......@@ -18,13 +18,10 @@ Description: This file is used for some common util.
import os
import shutil
from unittest.mock import Mock
import pytest
from flask import Response
from tests.st.func.datavisual import constants
from tests.st.func.datavisual.utils.log_operations import LogOperations
from tests.st.func.datavisual.utils.utils import check_loading_done
from tests.st.func.datavisual.utils import globals as gbl
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
......@@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.utils import tools
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done
from . import constants
from . import globals as gbl
summaries_metadata = None
mock_data_manager = None
summary_base_dir = constants.SUMMARY_BASE_DIR
......@@ -55,17 +57,21 @@ def init_summary_logs():
os.mkdir(summary_base_dir, mode=mode)
global summaries_metadata, mock_data_manager
log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST)
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager)
summaries_metadata.update(log_operations.create_summary_logs(
summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST))
summaries_metadata.update(log_operations.create_multiple_logs(
summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM))
summaries_metadata.update(log_operations.create_reservoir_log(
summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM))
summaries_metadata.update(
log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND,
constants.SUMMARY_DIR_NUM_FIRST))
summaries_metadata.update(
log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME,
constants.MULTIPLE_LOG_NUM))
summaries_metadata.update(
log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME,
constants.RESERVOIR_STEP_NUM))
mock_data_manager.start_load_data(reload_interval=0)
# Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING.
......@@ -73,7 +79,7 @@ def init_summary_logs():
# Maximum number of loads is `MAX_DATA_LOADER_SIZE`.
for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE):
summaries_metadata.pop("./%s%d" % (constants.SUMMARY_PREFIX, i))
summaries_metadata.pop("./%s%d" % (constants.SUMMARY_DIR_PREFIX, i))
yield
finally:
......
......@@ -16,7 +16,7 @@
import tempfile
SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name
SUMMARY_PREFIX = "summary"
SUMMARY_DIR_PREFIX = "summary"
SUMMARY_DIR_NUM_FIRST = 5
SUMMARY_DIR_NUM_SECOND = 11
......
......@@ -19,11 +19,11 @@ Usage:
pytest tests/st/func/datavisual
"""
import os
import json
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from .. import globals as gbl
from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes'
......@@ -33,12 +33,6 @@ class TestQueryNodes:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
......@@ -65,4 +59,5 @@ class TestQueryNodes:
url = get_url(BASE_URL, params)
response = client.get(url)
assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file)
file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
......@@ -19,12 +19,11 @@ Usage:
pytest tests/st/func/datavisual
"""
import os
import json
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from .. import globals as gbl
from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node'
......@@ -34,12 +33,6 @@ class TestQuerySingleNode:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
......@@ -59,4 +52,5 @@ class TestQuerySingleNode:
url = get_url(BASE_URL, params)
response = client.get(url)
assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file)
file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
......@@ -19,25 +19,20 @@ Usage:
pytest tests/st/func/datavisual
"""
import os
import json
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from .. import globals as gbl
from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names'
class TestSearchNodes:
"""Test search nodes restful APIs."""
"""Test searching nodes restful APIs."""
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
......@@ -58,4 +53,5 @@ class TestSearchNodes:
url = get_url(BASE_URL, params)
response = client.get(url)
assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file)
file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
......@@ -20,13 +20,13 @@ Usage:
"""
import pytest
from tests.st.func.datavisual.constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.conf import settings
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
BASE_URL = '/v1/mindinsight/datavisual/image/metadata'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/image/single-image'
......
......@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual
"""
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/plugins'
......
......@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual
"""
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/single-job'
......
......@@ -20,8 +20,8 @@ Usage:
"""
import pytest
from tests.st.func.datavisual.constants import SUMMARY_DIR_NUM
from tests.st.func.datavisual.utils.utils import get_url
from ..constants import SUMMARY_DIR_NUM
from .....utils.tools import get_url
BASE_URL = '/v1/mindinsight/datavisual/train-jobs'
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for graph."""
import json
import os
import time
from google.protobuf import json_format
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class GraphLogGenerator(LogGenerator):
"""
Log generator for graph.
This is a log generator writing graph. User can use it to generate fake
summary logs about graph.
"""
def generate_log(self, file_path, graph_dict):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
graph_dict (dict): A dict consists of graph node information.
Returns:
dict, generated scalar metadata.
"""
graph_event = self.generate_event(dict(graph=graph_dict))
self._write_log_from_event(file_path, graph_event)
return graph_dict
def generate_event(self, values):
"""
Method for generating graph event.
Args:
values (dict): Graph values. e.g. {'graph': graph_dict}.
Returns:
summary_pb2.Event.
"""
graph_json = {
'wall_time': time.time(),
'graph_def': values.get('graph'),
}
graph_event = json_format.Parse(json.dumps(graph_json), summary_pb2.Event())
return graph_event
if __name__ == "__main__":
graph_log_generator = GraphLogGenerator()
test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time()))
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json")
with open(graph_base_path, 'r') as load_f:
graph = json.load(load_f)
graph_log_generator.generate_log(test_file_name, graph)
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
......@@ -20,11 +20,11 @@ Usage:
"""
import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
......@@ -26,11 +26,101 @@ from unittest import TestCase
import pytest
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \
LineageSearchConditionParamError, LineageFileNotFoundError
from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH
from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError,
LineageParamTypeError, LineageParamValueError,
LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'metric': {
'accuracy': 0.78
},
'hyper_parameters': {
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14,
'parallel_mode': 'stand_alone',
'device_num': 2,
'batch_size': 32
},
'algorithm': {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
},
'valid_dataset': {
'valid_dataset_size': 10240
},
'model': {
'path': '{"ckpt": "'
+ BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}',
'size': 64
},
'dataset_graph': DATASET_GRAPH
}
LINEAGE_FILTRATION_EXCEPT_RUN = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN2 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
}
@pytest.mark.usefixtures("create_summary_dir")
......@@ -67,36 +157,7 @@ class TestModelApi(TestCase):
total_res = get_summary_lineage(SUMMARY_DIR)
partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters'])
partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm'])
expect_total_res = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'metric': {
'accuracy': 0.78
},
'hyper_parameters': {
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14,
'parallel_mode': 'stand_alone',
'device_num': 2,
'batch_size': 32
},
'algorithm': {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
},
'valid_dataset': {
'valid_dataset_size': 10240
},
'model': {
'path': '{"ckpt": "'
+ BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}',
'size': 64
},
'dataset_graph': DATASET_GRAPH
}
expect_total_res = LINEAGE_INFO_RUN1
expect_partial_res1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'hyper_parameters': {
......@@ -139,7 +200,7 @@ class TestModelApi(TestCase):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_get_summary_lineage_exception(self):
def test_get_summary_lineage_exception_1(self):
"""Test the interface of get_summary_lineage with exception."""
# summary path does not exist
self.assertRaisesRegex(
......@@ -183,6 +244,14 @@ class TestModelApi(TestCase):
keys=None
)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_get_summary_lineage_exception_2(self):
"""Test the interface of get_summary_lineage with exception."""
# keys is invalid
self.assertRaisesRegex(
LineageParamValueError,
......@@ -250,64 +319,9 @@ class TestModelApi(TestCase):
"""Test the interface of filter_summary_lineage."""
expect_result = {
'object': [
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
}
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,
LINEAGE_FILTRATION_RUN2
],
'count': 3
}
......@@ -357,46 +371,8 @@ class TestModelApi(TestCase):
}
expect_result = {
'object': [
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1
],
'count': 2
}
......@@ -432,46 +408,8 @@ class TestModelApi(TestCase):
}
expect_result = {
'object': [
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1
],
'count': 2
}
......@@ -498,44 +436,8 @@ class TestModelApi(TestCase):
}
expect_result = {
'object': [
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1
],
'count': 2
}
......@@ -674,6 +576,14 @@ class TestModelApi(TestCase):
search_condition
)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_3(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the condition of offset is invalid
search_condition = {
'offset': 1.0
......@@ -712,6 +622,14 @@ class TestModelApi(TestCase):
search_condition
)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_4(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the sorted_type not supported
search_condition = {
'sorted_name': 'summary_dir',
......@@ -753,6 +671,14 @@ class TestModelApi(TestCase):
search_condition
)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_5(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the summary dir is invalid in search condition
search_condition = {
'summary_dir': {
......@@ -811,7 +737,7 @@ class TestModelApi(TestCase):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_3(self):
def test_filter_summary_lineage_exception_6(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# gt > lt
search_condition1 = {
......
......@@ -21,7 +21,8 @@ import tempfile
import pytest
from .collection.model import mindspore
from ....utils import mindspore
from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
sys.modules['mindspore'] = mindspore
......@@ -32,52 +33,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run')
COLLECTION_MODULE = 'TestModelLineage'
API_MODULE = 'TestModelApi'
DATASET_GRAPH = {
'op_type': 'BatchDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'drop_remainder': True,
'batch_size': 10,
'children': [
{
'op_type': 'MapDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'input_columns': [
'label'
],
'output_columns': [
None
],
'operations': [
{
'tensor_op_module': 'minddata.transforms.c_transforms',
'tensor_op_name': 'OneHot',
'num_classes': 10
}
],
'children': [
{
'op_type': 'MnistDataset',
'shard_id': None,
'num_shards': None,
'op_module': 'minddata.dataengine.datasets',
'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData',
'num_parallel_workers': None,
'shuffle': None,
'num_samples': 100,
'sampler': {
'sampler_module': 'minddata.dataengine.samplers',
'sampler_name': 'RandomSampler',
'replacement': True,
'num_samples': 100
},
'children': []
}
]
}
]
}
DATASET_GRAPH = SERIALIZED_PIPELINE
def get_module_name(nodeid):
"""Get the module name from nodeid."""
......
......@@ -14,6 +14,6 @@
# ============================================================================
"""Import the mocked mindspore."""
import sys
from .lineagemgr.collection.model import mindspore
from ..utils import mindspore
sys.modules['mindspore'] = mindspore
......@@ -21,14 +21,15 @@ Usage:
from unittest.mock import patch
import pytest
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES
from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from tests.ut.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainTask:
"""Test train task api."""
......@@ -36,9 +37,7 @@ class TestTrainTask:
_scalar_log_generator = ScalarsLogGenerator()
_image_log_generator = ImagesLogGenerator()
@pytest.mark.parametrize(
"plugin_name",
['no_plugin_name', 'not_exist_plugin_name'])
@pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name'])
def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name):
"""
Parsing unavailable plugin name to single train task.
......
......@@ -21,14 +21,15 @@ Usage:
from unittest.mock import Mock, patch
import pytest
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES
from tests.ut.datavisual.utils.utils import get_url
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainVisual:
"""Test Train Visual APIs."""
......@@ -95,14 +96,7 @@ class TestTrainVisual:
assert response.status_code == 200
response = response.get_json()
expected_response = {
"metadatas": [{
"height": 224,
"step": 1,
"wall_time": 1572058058.1175,
"width": 448
}]
}
expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]}
assert expected_response == response
def test_single_image_with_params_miss(self, client):
......@@ -254,8 +248,10 @@ class TestTrainVisual:
@patch.object(GraphProcessor, 'get_nodes')
def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client):
"""Test getting graph nodes successfully."""
def mock_get_nodes(name, node_type):
return dict(name=name, node_type=node_type)
mock_graph_processor.side_effect = mock_get_nodes
mock_init = Mock(return_value=None)
......@@ -327,10 +323,7 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. 'offset' should " \
"be greater than or equal to 0."
@pytest.mark.parametrize(
"limit",
[-1, 0, 1001]
)
@pytest.mark.parametrize("limit", [-1, 0, 1001])
@patch.object(GraphProcessor, '__init__')
def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit):
"""Test getting graph node names with invalid limit."""
......@@ -348,14 +341,10 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. " \
"'limit' should in [1, 1000]."
@pytest.mark.parametrize(
" offset, limit",
[(0, 100), (1, 1), (0, 1000)]
)
@pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)])
@patch.object(GraphProcessor, '__init__')
@patch.object(GraphProcessor, 'search_node_names')
def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client,
offset, limit):
def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit):
"""
Parsing unavailable params to get image metadata.
......@@ -367,8 +356,10 @@ class TestTrainVisual:
response status code: 200.
response json: dict, contains search_content, offset, and limit.
"""
def mock_search_node_names(search_content, offset, limit):
return dict(search_content=search_content, offset=int(offset), limit=int(limit))
mock_graph_processor.side_effect = mock_search_node_names
mock_init = Mock(return_value=None)
......@@ -376,15 +367,12 @@ class TestTrainVisual:
test_train_id = "aaa"
test_search_content = "bbb"
params = dict(train_id=test_train_id, search=test_search_content,
offset=offset, limit=limit)
params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit)
url = get_url(TRAIN_ROUTES['graph_nodes_names'], params)
response = client.get(url)
assert response.status_code == 200
results = response.get_json()
assert results == dict(search_content=test_search_content,
offset=int(offset),
limit=int(limit))
assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit))
def test_graph_search_single_node_with_params_is_wrong(self, client):
"""Test searching graph single node with params is wrong."""
......@@ -427,8 +415,10 @@ class TestTrainVisual:
response status code: 200.
response json: name.
"""
def mock_search_single_node(name):
return name
mock_graph_processor.side_effect = mock_search_single_node
mock_init = Mock(return_value=None)
......
......@@ -20,28 +20,9 @@ from unittest import TestCase, mock
from flask import Response
from mindinsight.backend.application import APP
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageQuerySummaryDataError
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError
class TestSearchModel(TestCase):
"""Test the restful api of search_model."""
def setUp(self):
"""Test init."""
APP.response_class = Response
self.app_client = APP.test_client()
self.url = '/v1/mindinsight/models/model_lineage'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
def test_search_model_success(self, *args):
"""Test the success of model_success."""
base_dir = '/path/to/test_lineage_summary_dir_base'
args[0].return_value = {
'object': [
{
'summary_dir': base_dir,
LINEAGE_FILTRATION_BASE = {
'accuracy': None,
'mae': None,
'mse': None,
......@@ -57,9 +38,8 @@ class TestSearchModel(TestCase):
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
},
{
'summary_dir': os.path.join(base_dir, 'run1'),
}
LINEAGE_FILTRATION_RUN1 = {
'accuracy': 0.78,
'mae': None,
'mse': None,
......@@ -75,6 +55,32 @@ class TestSearchModel(TestCase):
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
}
class TestSearchModel(TestCase):
"""Test the restful api of search_model."""
def setUp(self):
"""Test init."""
APP.response_class = Response
self.app_client = APP.test_client()
self.url = '/v1/mindinsight/models/model_lineage'
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
@mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
def test_search_model_success(self, *args):
"""Test the success of model_success."""
base_dir = '/path/to/test_lineage_summary_dir_base'
args[0].return_value = {
'object': [
{
'summary_dir': base_dir,
**LINEAGE_FILTRATION_BASE
},
{
'summary_dir': os.path.join(base_dir, 'run1'),
**LINEAGE_FILTRATION_RUN1
}
],
'count': 2
......@@ -93,39 +99,11 @@ class TestSearchModel(TestCase):
'object': [
{
'summary_dir': './',
'accuracy': None,
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 12,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
**LINEAGE_FILTRATION_BASE
},
{
'summary_dir': './run1',
'accuracy': 0.78,
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': 64,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
**LINEAGE_FILTRATION_RUN1
}
],
'count': 2
......
......@@ -19,15 +19,16 @@ Usage:
pytest tests/ut/datavisual
"""
from unittest.mock import patch
from werkzeug.exceptions import MethodNotAllowed, NotFound
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.utils import get_url
from werkzeug.exceptions import MethodNotAllowed, NotFound
from mindinsight.datavisual.processors import scalars_processor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from ...backend.datavisual.conftest import TRAIN_ROUTES
from ..mock import MockLogger
class TestErrorHandler:
"""Test train visual api."""
......
......@@ -14,7 +14,7 @@
# ============================================================================
"""
Function:
Test mindinsight.datavisual.data_transform.log_generators.data_loader_generator
Test mindinsight.datavisual.data_transform.loader_generators.data_loader_generator
Usage:
pytest tests/ut/datavisual
"""
......@@ -22,18 +22,19 @@ import datetime
import os
import shutil
import tempfile
from unittest.mock import patch
import pytest
from tests.ut.datavisual.mock import MockLogger
import pytest
from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator
from mindinsight.utils.exceptions import ParamValueError
from ...mock import MockLogger
class TestDataLoaderGenerator:
"""Test data_loader_generator."""
@classmethod
def setup_class(cls):
data_loader_generator.logger = MockLogger
......@@ -88,8 +89,9 @@ class TestDataLoaderGenerator:
mock_data_loader.return_value = True
loader_dict = generator.generate_loaders(loader_pool=dict())
expected_ids = [summary.get('relative_path')
for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]]
expected_ids = [
summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]
]
assert sorted(loader_dict.keys()) == sorted(expected_ids)
shutil.rmtree(summary_base_dir)
......
......@@ -23,12 +23,13 @@ import shutil
import tempfile
import pytest
from tests.ut.datavisual.mock import MockLogger
from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid
from mindinsight.datavisual.data_transform import data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from ..mock import MockLogger
class TestDataLoader:
"""Test data_loader."""
......@@ -37,13 +38,13 @@ class TestDataLoader:
def setup_class(cls):
data_loader.logger = MockLogger
def setup_method(self, method):
def setup_method(self):
self._summary_dir = tempfile.mkdtemp()
if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir)
os.mkdir(self._summary_dir)
def teardown_method(self, method):
def teardown_method(self):
if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir)
......
......@@ -18,32 +18,29 @@ Function:
Usage:
pytest tests/ut/datavisual
"""
import time
import os
import shutil
import tempfile
import time
from unittest import mock
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.utils import check_loading_done
from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \
DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \
MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \
LoaderStruct
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from mindinsight.utils.exceptions import ParamValueError
from ....utils.tools import check_loading_done
from ..mock import MockLogger
class TestDataManager:
"""Test data_manager."""
......@@ -101,11 +98,17 @@ class TestDataManager:
"and loader pool size is '3'."
shutil.rmtree(summary_base_dir)
@pytest.mark.parametrize('params',
[{'reload_interval': '30'},
{'reload_interval': -1},
{'reload_interval': 30, 'max_threads_count': '20'},
{'reload_interval': 30, 'max_threads_count': 0}])
@pytest.mark.parametrize('params', [{
'reload_interval': '30'
}, {
'reload_interval': -1
}, {
'reload_interval': 30,
'max_threads_count': '20'
}, {
'reload_interval': 30,
'max_threads_count': 0
}])
def test_start_load_data_with_invalid_params(self, params):
"""Test start_load_data with invalid reload_interval or invalid max_threads_count."""
summary_base_dir = tempfile.mkdtemp()
......
......@@ -22,20 +22,24 @@ import threading
from collections import namedtuple
import pytest
from tests.ut.datavisual.mock import MockLogger
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import events_data
from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor
from ..mock import MockLogger
class MockReservoir:
"""Use this class to replace reservoir.Reservoir in test."""
def __init__(self, size):
self.size = size
self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'),
_Tensor('wall_time3', 3, 'value3')]
self._samples = [
_Tensor('wall_time1', 1, 'value1'),
_Tensor('wall_time2', 2, 'value2'),
_Tensor('wall_time3', 3, 'value3')
]
def samples(self):
"""Replace the samples function."""
......@@ -63,11 +67,12 @@ class TestEventsData:
def setup_method(self):
"""Mock original logger, init a EventsData object for use."""
self._ev_data = EventsData()
self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)],
'plugin_name2': [f'tag{i}' for i in range(20, 30)]}
self._ev_data._tags_by_plugin = {
'plugin_name1': [f'tag{i}' for i in range(10)],
'plugin_name2': [f'tag{i}' for i in range(20, 30)]
}
self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()})
self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500),
'new_tag': MockReservoir(500)}
self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)}
self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)]
def get_ev_data(self):
......@@ -102,8 +107,7 @@ class TestEventsData:
"""Test add_tensor_event success."""
ev_data = self.get_ev_data()
t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1',
value='value1')
t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1')
ev_data.add_tensor_event(t_event)
assert 'tag0' not in ev_data._tags
......@@ -111,6 +115,5 @@ class TestEventsData:
assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1']
assert 'tag0' not in ev_data._reservoir_by_tag
assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1']
assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time,
t_event.step,
assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step,
t_event.value)
......@@ -19,16 +19,17 @@ Usage:
pytest tests/ut/datavisual
"""
import os
import tempfile
import shutil
import tempfile
from unittest.mock import Mock
import pytest
from tests.ut.datavisual.mock import MockLogger
from mindinsight.datavisual.data_transform import ms_data_loader
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from ..mock import MockLogger
# bytes of 3 scalar events
SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*'
b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00'
......@@ -74,7 +75,8 @@ class TestMsDataLoader:
"we will reload all files in path {}.".format(summary_dir)
shutil.rmtree(summary_dir)
def test_load_success_with_crc_pass(self, crc_pass):
@pytest.mark.usefixtures('crc_pass')
def test_load_success_with_crc_pass(self):
"""Test load success."""
summary_dir = tempfile.mkdtemp()
file1 = os.path.join(summary_dir, 'summary.01')
......@@ -88,7 +90,8 @@ class TestMsDataLoader:
tensors = ms_loader.get_events_data().tensors(tag[0])
assert len(tensors) == 3
def test_load_with_crc_fail(self, crc_fail):
@pytest.mark.usefixtures('crc_fail')
def test_load_with_crc_fail(self):
"""Test when crc_fail and will not go to func _event_parse."""
summary_dir = tempfile.mkdtemp()
file2 = os.path.join(summary_dir, 'summary.02')
......@@ -100,8 +103,10 @@ class TestMsDataLoader:
def test_filter_event_files(self):
"""Test filter_event_files function ok."""
file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678',
'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567']
file_list = [
'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786',
'mysummary.123abce', 'summay.4567'
]
summary_dir = tempfile.mkdtemp()
for file in file_list:
with open(os.path.join(summary_dir, file), 'w'):
......@@ -113,6 +118,7 @@ class TestMsDataLoader:
shutil.rmtree(summary_dir)
def write_file(filename, record):
"""Write bytes strings to file."""
with open(filename, 'wb') as file:
......
......@@ -19,18 +19,11 @@ Usage:
pytest tests/ut/datavisual
"""
import os
import json
import tempfile
from unittest.mock import Mock
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs
from ..mock import MockLogger
class TestGraphProcessor:
"""Test Graph Processor api."""
......@@ -70,18 +67,13 @@ class TestGraphProcessor:
"""Load graph record."""
summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".")
graph_base_path = os.path.join(os.path.dirname(__file__),
os.pardir, "utils", "log_generators", "graph_base.json")
self._temp_path, self._graph_dict = LogOperations.generate_log(
PluginNameEnum.GRAPH.value, log_dir, dict(graph_base_path=graph_base_path))
log_operation = LogOperations()
self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......@@ -94,33 +86,29 @@ class TestGraphProcessor:
log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, _, _ = LogOperations.generate_log(
PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image"))
log_operation = LogOperations()
self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir,
dict(steps=self._steps_list, tag="image"))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5)
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
def test_get_nodes_with_not_exist_train_id(self, load_graph_record):
@pytest.mark.usefixtures('load_graph_record')
def test_get_nodes_with_not_exist_train_id(self):
"""Test getting nodes with not exist train id."""
test_train_id = "not_exist_train_id"
with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(test_train_id, self._mock_data_manager)
assert "Can not find the train job in data manager." in exc_info.value.message
@pytest.mark.usefixtures('load_graph_record')
@patch.object(DataManager, 'get_train_job_by_plugin')
def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record):
def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin):
"""Test get nodes with loader is None."""
mock_get_train_job_by_plugin.return_value = None
with pytest.raises(exceptions.SummaryLogPathInvalid):
......@@ -128,15 +116,12 @@ class TestGraphProcessor:
assert mock_get_train_job_by_plugin.called
@pytest.mark.parametrize("name, node_type", [
("not_exist_name", "name_scope"),
("", "polymeric_scope")
])
def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type):
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")])
def test_get_nodes_with_not_exist_name(self, name, node_type):
"""Test getting nodes with not exist name."""
with pytest.raises(ParamValueError) as exc_info:
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
graph_processor.get_nodes(name, node_type)
if name:
......@@ -144,105 +129,99 @@ class TestGraphProcessor:
else:
assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message
@pytest.mark.parametrize("name, node_type, result_file", [
(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'),
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize(
"name, node_type, result_file",
[(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'),
('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'),
('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')
])
def test_get_nodes_success(self, load_graph_record, name, node_type, result_file):
('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')])
def test_get_nodes_success(self, name, node_type, result_file):
"""Test getting nodes successfully."""
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.get_nodes(name, node_type)
self.compare_result_with_file(results, result_file)
@pytest.mark.parametrize("search_content, result_file", [
(None, 'test_search_node_names_with_search_content_expected_results1.json'),
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("search_content, result_file",
[(None, 'test_search_node_names_with_search_content_expected_results1.json'),
('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'),
('not_exist_search_content', None)
])
def test_search_node_names_with_search_content(self, load_graph_record,
search_content,
result_file):
('not_exist_search_content', None)])
def test_search_node_names_with_search_content(self, search_content, result_file):
"""Test search node names with search content."""
test_offset = 0
test_limit = 1000
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
results = graph_processor.search_node_names(search_content,
test_offset,
test_limit)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.search_node_names(search_content, test_offset, test_limit)
if search_content == 'not_exist_search_content':
expected_results = {'names': []}
assert results == expected_results
else:
self.compare_result_with_file(results, result_file)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("offset", [-100, -1])
def test_search_node_names_with_negative_offset(self, load_graph_record, offset):
def test_search_node_names_with_negative_offset(self, offset):
"""Test search node names with negative offset."""
test_search_content = ""
test_limit = 3
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, offset, test_limit)
assert "'offset' should be greater than or equal to 0." in exc_info.value.message
@pytest.mark.parametrize("offset, result_file", [
(1, 'test_search_node_names_with_offset_expected_results1.json')
])
def test_search_node_names_with_offset(self, load_graph_record, offset, result_file):
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')])
def test_search_node_names_with_offset(self, offset, result_file):
"""Test search node names with offset."""
test_search_content = "Default/bn1"
test_offset = offset
test_limit = 3
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
results = graph_processor.search_node_names(test_search_content,
test_offset,
test_limit)
self.compare_result_with_file(results, result_file)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.search_node_names(test_search_content, test_offset, test_limit)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_node_names_with_wrong_limit(self, load_graph_record):
@pytest.mark.usefixtures('load_graph_record')
def test_search_node_names_with_wrong_limit(self):
"""Test search node names with wrong limit."""
test_search_content = ""
test_offset = 0
test_limit = 0
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, test_offset,
test_limit)
graph_processor.search_node_names(test_search_content, test_offset, test_limit)
assert "'limit' should in [1, 1000]." in exc_info.value.message
@pytest.mark.parametrize("name, result_file", [
('Default/bn1', 'test_search_single_node_success_expected_results1.json')
])
def test_search_single_node_success(self, load_graph_record, name, result_file):
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("name, result_file",
[('Default/bn1', 'test_search_single_node_success_expected_results1.json')])
def test_search_single_node_success(self, name, result_file):
"""Test searching single node successfully."""
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
results = graph_processor.search_single_node(name)
self.compare_result_with_file(results, result_file)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_single_node_with_not_exist_name(self, load_graph_record):
@pytest.mark.usefixtures('load_graph_record')
def test_search_single_node_with_not_exist_name(self):
"""Test searching single node with not exist name."""
test_name = "not_exist_name"
with pytest.raises(exceptions.NodeNotInGraphError):
graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager)
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
graph_processor.search_single_node(test_name)
def test_check_graph_status_no_graph(self, load_no_graph_record):
@pytest.mark.usefixtures('load_no_graph_record')
def test_check_graph_status_no_graph(self):
"""Test checking graph status no graph."""
with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(self._train_id, self._mock_data_manager)
......
......@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock
import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from ..mock import MockLogger
class TestImagesProcessor:
"""Test images processor api."""
......@@ -73,12 +74,11 @@ class TestImagesProcessor:
"""
summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, self._images_metadata, self._images_values = LogOperations.generate_log(
log_operation = LogOperations()
self._temp_path, self._images_metadata, self._images_values = log_operation.generate_log(
PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
......@@ -102,7 +102,8 @@ class TestImagesProcessor:
"""Load image record."""
self._init_data_manager(self._cross_steps_list)
def test_get_metadata_list_with_not_exist_id(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Test getting metadata list with not exist id."""
test_train_id = 'not_exist_id'
image_processor = ImageProcessor(self._mock_data_manager)
......@@ -112,7 +113,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Test get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name'
......@@ -124,7 +126,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_success(self):
"""Test getting metadata list success."""
test_tag_name = self._complete_tag_name
......@@ -133,7 +136,8 @@ class TestImagesProcessor:
assert results == self._images_metadata
def test_get_single_image_with_not_exist_id(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_id(self):
"""Test getting single image with not exist id."""
test_train_id = 'not_exist_id'
test_tag_name = self._complete_tag_name
......@@ -146,7 +150,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_single_image_with_not_exist_tag(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_tag(self):
"""Test getting single image with not exist tag."""
test_tag_name = 'not_exist_tag_name'
test_step = self._steps_list[0]
......@@ -159,7 +164,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_single_image_with_not_exist_step(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_step(self):
"""Test getting single image with not exist step."""
test_tag_name = self._complete_tag_name
test_step = 10000
......@@ -172,24 +178,22 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find the step with given train job id and tag." in exc_info.value.message
def test_get_single_image_success(self, load_image_record):
@pytest.mark.usefixtures('load_image_record')
def test_get_single_image_success(self):
"""Test getting single image successfully."""
test_tag_name = self._complete_tag_name
test_step_index = 0
test_step = self._steps_list[test_step_index]
expected_image_tensor = self._images_values.get(test_step)
image_processor = ImageProcessor(self._mock_data_manager)
results = image_processor.get_single_image(self._train_id, test_tag_name, test_step)
expected_image_tensor = self._images_values.get(test_step)
image_generator = LogOperations.get_log_generator(PluginNameEnum.IMAGE.value)
recv_image_tensor = image_generator.get_image_tensor_from_bytes(results)
recv_image_tensor = get_image_tensor_from_bytes(results)
assert recv_image_tensor.any() == expected_image_tensor.any()
def test_reservoir_add_sample(self, load_more_than_limit_image_record):
@pytest.mark.usefixtures('load_more_than_limit_image_record')
def test_reservoir_add_sample(self):
"""Test adding sample in reservoir."""
test_tag_name = self._complete_tag_name
......@@ -206,7 +210,8 @@ class TestImagesProcessor:
cnt += 1
assert len(self._more_steps_list) - cnt == 10
def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record):
@pytest.mark.usefixtures('load_reservoir_remove_sample_image_record')
def test_reservoir_remove_sample(self):
"""
Test removing sample in reservoir.
......
......@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock
import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestScalarsProcessor:
"""Test scalar processor api."""
......@@ -65,12 +66,11 @@ class TestScalarsProcessor:
"""Load scalar record."""
summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, self._scalars_metadata, self._scalars_values = LogOperations.generate_log(
log_operation = LogOperations()
self._temp_path, self._scalars_metadata, self._scalars_values = log_operation.generate_log(
PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
......@@ -79,7 +79,8 @@ class TestScalarsProcessor:
# wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5)
def test_get_metadata_list_with_not_exist_id(self, load_scalar_record):
@pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Get metadata list with not exist id."""
test_train_id = 'not_exist_id'
scalar_processor = ScalarsProcessor(self._mock_data_manager)
......@@ -89,7 +90,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record):
@pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name'
......@@ -101,7 +103,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_scalar_record):
@pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_success(self):
"""Get metadata list success."""
test_tag_name = self._complete_tag_name
......
......@@ -18,15 +18,11 @@ Function:
Usage:
pytest tests/ut/datavisual
"""
import os
import tempfile
import time
from unittest.mock import Mock
import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager
......@@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestTrainTaskManager:
"""Test train task manager."""
......@@ -70,39 +70,30 @@ class TestTrainTaskManager:
@pytest.fixture(scope='function')
def load_data(self):
"""Load data."""
log_operation = LogOperations()
self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []}
self._events_names = []
self._train_id_list = []
graph_base_path = os.path.join(os.path.dirname(__file__),
os.pardir, "utils", "log_generators", "graph_base.json")
self._root_dir = tempfile.mkdtemp()
for i in range(self._dir_num):
dir_path = tempfile.mkdtemp(dir=self._root_dir)
tmp_tag_name = self._tag_name + '_' + str(i)
event_name = str(i) + "_name"
train_id = dir_path.replace(self._root_dir, ".")
# Pass timestamp to write to the same file.
log_settings = dict(
steps=self._steps_list,
tag=tmp_tag_name,
graph_base_path=graph_base_path,
time=time.time())
log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time())
if i % 3 != 0:
LogOperations.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings)
log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings)
self._plugins_id_map['image'].append(train_id)
if i % 3 != 1:
LogOperations.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings)
log_operation.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings)
self._plugins_id_map['scalar'].append(train_id)
if i % 3 != 2:
LogOperations.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings)
log_operation.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings)
self._plugins_id_map['graph'].append(train_id)
self._events_names.append(event_name)
self._train_id_list.append(train_id)
self._generated_path.append(self._root_dir)
......@@ -112,7 +103,8 @@ class TestTrainTaskManager:
check_loading_done(self._mock_data_manager, time_limit=30)
def test_get_single_train_task_with_not_exists_train_id(self, load_data):
@pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_not_exists_train_id(self):
"""Test getting single train task with not exists train_id."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members():
......@@ -124,7 +116,8 @@ class TestTrainTaskManager:
"the train job in data manager."
assert exc_info.value.error_code == '50540002'
def test_get_single_train_task_with_params(self, load_data):
@pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_params(self):
"""Test getting single train task with params."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members():
......@@ -138,7 +131,8 @@ class TestTrainTaskManager:
else:
assert test_train_id not in self._plugins_id_map.get(plugin_name)
def test_get_plugins_with_train_id(self, load_data):
@pytest.mark.usefixtures('load_data')
def test_get_plugins_with_train_id(self):
"""Test getting plugins with train id."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mask string to crc32."""
CRC_TABLE_32 = (
0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, 0x8AD958CF,
0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, 0x105EC76F, 0xE235446C,
0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57,
0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A,
0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E,
0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD,
0x1642AE59, 0xE4292D5A, 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696,
0x6EF07595, 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957,
0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, 0x5125DAD3,
0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, 0xDBFC821C, 0x2997011F,
0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, 0x61C69362, 0x93AD1061, 0x80FDE395,
0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859,
0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312,
0x44694011, 0x5739B3E5, 0xA55230E6, 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE,
0xDDE0EB2A, 0x2F8B6829, 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90,
0x563C5F93, 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C,
0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, 0x1871A4D8,
0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, 0xA24BB5A6, 0x502036A5,
0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E,
0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D,
0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19,
0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8,
0xE52CC12C, 0x1747422F, 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3,
0x9D9E1AE0, 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540,
0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, 0xE330A81A,
0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, 0x69E9F0D5, 0x9B8273D6,
0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, 0xF36E6F75, 0x0105EC76, 0x12551F82,
0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E,
0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351
)
_CRC = 0
_MASK = 0xFFFFFFFF
def _uint32(x):
"""Transform x's type to uint32."""
return x & 0xFFFFFFFF
def _get_crc_checksum(crc, data):
"""Get crc checksum."""
crc ^= _MASK
for d in data:
crc_table_index = (crc ^ d) & 0xFF
crc = (CRC_TABLE_32[crc_table_index] ^ (crc >> 8)) & _MASK
crc ^= _MASK
return crc
def get_mask_from_string(data):
"""
Get masked crc from data.
Args:
data (byte): Byte string of data.
Returns:
uint32, masked crc.
"""
crc = _get_crc_checksum(_CRC, data)
crc = _uint32(crc & _MASK)
crc = _uint32(((crc >> 15) | _uint32(crc << 17)) + 0xA282EAD8)
return crc
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for images."""
import io
import time
import numpy as np
from PIL import Image
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class ImagesLogGenerator(LogGenerator):
"""
Log generator for images.
This is a log generator writing images. User can use it to generate fake
summary logs about images.
"""
def generate_event(self, values):
"""
Method for generating image event.
Args:
values (dict): A dict contains:
{
wall_time (float): Timestamp.
step (int): Train step.
image (np.array): Pixels tensor.
tag (str): Tag name.
}
Returns:
summary_pb2.Event.
"""
image_event = summary_pb2.Event()
image_event.wall_time = values.get('wall_time')
image_event.step = values.get('step')
height, width, channel, image_string = self._get_image_string(values.get('image'))
value = image_event.summary.value.add()
value.tag = values.get('tag')
value.image.height = height
value.image.width = width
value.image.colorspace = channel
value.image.encoded_image = image_string
return image_event
def _get_image_string(self, image_tensor):
"""
Generate image string from tensor.
Args:
image_tensor (np.array): Pixels tensor.
Returns:
int, height.
int, width.
int, channel.
bytes, image_string.
"""
height, width, channel = image_tensor.shape
scaled_height = int(height)
scaled_width = int(width)
image = Image.fromarray(image_tensor)
image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
output = io.BytesIO()
image.save(output, format='PNG')
image_string = output.getvalue()
output.close()
return height, width, channel, image_string
def _make_image_tensor(self, shape):
"""
Make image tensor according to shape.
Args:
shape (list): Shape of image, consists of height, width, channel.
Returns:
np.array, image tensor.
"""
image = np.prod(shape)
image_tensor = (np.arange(image, dtype=float)).reshape(shape)
image_tensor = image_tensor / np.max(image_tensor) * 255
image_tensor = image_tensor.astype(np.uint8)
return image_tensor
def generate_log(self, file_path, steps_list, tag_name):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
steps_list (list): A list consists of step.
tag_name (str): Tag name.
Returns:
list[dict], generated image metadata.
dict, generated image tensors.
"""
images_values = dict()
images_metadata = []
for step in steps_list:
wall_time = time.time()
# height, width, channel
image_tensor = self._make_image_tensor([5, 5, 3])
image_metadata = dict()
image_metadata.update({'wall_time': wall_time})
image_metadata.update({'step': step})
image_metadata.update({'height': image_tensor.shape[0]})
image_metadata.update({'width': image_tensor.shape[1]})
images_metadata.append(image_metadata)
images_values.update({step: image_tensor})
values = dict(
wall_time=wall_time,
step=step,
image=image_tensor,
tag=tag_name
)
self._write_log_one_step(file_path, values)
return images_metadata, images_values
def get_image_tensor_from_bytes(self, image_string):
"""Get image tensor from bytes."""
img = Image.open(io.BytesIO(image_string))
image_tensor = np.array(img)
return image_tensor
if __name__ == "__main__":
images_log_generator = ImagesLogGenerator()
test_file_name = '%s.%s.%s' % ('image', 'summary', str(time.time()))
test_steps = [1, 3, 5]
test_tags = "test_image_tag_name"
images_log_generator.generate_log(test_file_name, test_steps, test_tags)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base log Generator."""
import struct
from abc import abstractmethod
from tests.ut.datavisual.utils import crc32
class LogGenerator:
"""
Base log generator.
This is a base class for log generators. User can use it to generate fake
summary logs.
"""
@abstractmethod
def generate_event(self, values):
"""
Abstract method for generating event.
Args:
values (dict): Values.
Returns:
summary_pb2.Event.
"""
def _write_log_one_step(self, file_path, values):
"""
Write log one step.
Args:
file_path (str): File path to write.
values (dict): Values.
"""
event = self.generate_event(values)
self._write_log_from_event(file_path, event)
@staticmethod
def _write_log_from_event(file_path, event):
"""
Write log by event.
Args:
file_path (str): File path to write.
event (summary_pb2.Event): Event object in proto.
"""
send_msg = event.SerializeToString()
header = struct.pack('<Q', len(send_msg))
header_crc = struct.pack('<I', crc32.get_mask_from_string(header))
footer_crc = struct.pack('<I', crc32.get_mask_from_string(send_msg))
write_event = header + header_crc + send_msg + footer_crc
with open(file_path, "ab") as f:
f.write(write_event)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for scalars."""
import time
import numpy as np
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class ScalarsLogGenerator(LogGenerator):
"""
Log generator for scalars.
This is a log generator writing scalars. User can use it to generate fake
summary logs about scalar.
"""
def generate_event(self, values):
"""
Method for generating scalar event.
Args:
values (dict): A dict contains:
{
wall_time (float): Timestamp.
step (int): Train step.
value (float): Scalar value.
tag (str): Tag name.
}
Returns:
summary_pb2.Event.
"""
scalar_event = summary_pb2.Event()
scalar_event.wall_time = values.get('wall_time')
scalar_event.step = values.get('step')
value = scalar_event.summary.value.add()
value.tag = values.get('tag')
value.scalar_value = values.get('value')
return scalar_event
def generate_log(self, file_path, steps_list, tag_name):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
steps_list (list): A list consists of step.
tag_name (str): Tag name.
Returns:
list[dict], generated scalar metadata.
None, to be consistent with return value of ImageGenerator.
"""
scalars_metadata = []
for step in steps_list:
scalar_metadata = dict()
wall_time = time.time()
value = np.random.rand()
scalar_metadata.update({'wall_time': wall_time})
scalar_metadata.update({'step': step})
scalar_metadata.update({'value': value})
scalars_metadata.append(scalar_metadata)
scalar_metadata.update({"tag": tag_name})
self._write_log_one_step(file_path, scalar_metadata)
return scalars_metadata, None
if __name__ == "__main__":
scalars_log_generator = ScalarsLogGenerator()
test_file_name = '%s.%s.%s' % ('scalar', 'summary', str(time.time()))
test_steps = [1, 3, 5]
test_tag = "test_scalar_tag_name"
scalars_log_generator.generate_log(test_file_name, test_steps, test_tag)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""log operations module."""
import json
import os
import time
from tests.ut.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator
from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum
log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
PluginNameEnum.SCALAR.value: ScalarsLogGenerator()
}
class LogOperations:
"""Log Operations class."""
@staticmethod
def generate_log(plugin_name, log_dir, log_settings, valid=True):
"""
Generate log.
Args:
plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'.
log_dir (str): Log path to write log.
log_settings (dict): Info about the log, e.g.:
{
current_time (int): Timestamp in summary file name, not necessary.
graph_base_path (str): Path of graph_bas.json, necessary for `graph`.
steps (list[int]): Steps for `image` and `scalar`, default is [1].
tag (str): Tag name, default is 'default_tag'.
}
valid (bool): If true, summary name will be valid.
Returns:
str, Summary log path.
"""
current_time = log_settings.get('time', int(time.time()))
current_time = int(current_time)
log_generator = log_generators.get(plugin_name)
if valid:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time)))
else:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time)))
if plugin_name == PluginNameEnum.GRAPH.value:
graph_base_path = log_settings.get('graph_base_path')
with open(graph_base_path, 'r') as load_f:
graph_dict = json.load(load_f)
graph_dict = log_generator.generate_log(temp_path, graph_dict)
return temp_path, graph_dict
steps_list = log_settings.get('steps', [1])
tag_name = log_settings.get('tag', 'default_tag')
metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name)
return temp_path, metadata, values
@staticmethod
def get_log_generator(plugin_name):
"""Get log generator."""
return log_generators.get(plugin_name)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Description: This file is used for some common util.
"""
import os
import shutil
import time
from urllib.parse import urlencode
from mindinsight.datavisual.common.enums import DataManagerStatus
def get_url(url, params):
"""
Concatenate the URL and params.
Args:
url (str): A link requested. For example, http://example.com.
params (dict): A dict consists of params. For example, {'offset': 1, 'limit':'100}.
Returns:
str, like http://example.com?offset=1&limit=100
"""
return url + '?' + urlencode(params)
def delete_files_or_dirs(path_list):
"""Delete files or dirs in path_list."""
for path in path_list:
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)
def check_loading_done(data_manager, time_limit=15):
"""If loading data for more than `time_limit` seconds, exit."""
start_time = time.time()
while data_manager.status != DataManagerStatus.DONE.value:
time_used = time.time() - start_time
if time_used > time_limit:
break
time.sleep(0.1)
continue
......@@ -14,6 +14,6 @@
# ============================================================================
"""Import the mocked mindspore."""
import sys
from .collection.model import mindspore
from ...utils import mindspore
sys.modules['mindspore'] = mindspore
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore Interface."""
from .application.model_zoo.resnet import ResNet
from .common.tensor import Tensor
from .dataset import MindDataset
from .nn import *
from .train.callback import _ListCallback, Callback, RunContext, ModelCheckpoint, SummaryStep
from .train.summary import SummaryRecord
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore ResNet class."""
from ...nn.cell import Cell
class ResNet(Cell):
"""Mocked ResNet."""
def __init__(self):
super(ResNet, self).__init__()
self._cells = {}
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/common/tensor.py."""
import numpy as np
class Tensor:
"""Mock the MindSpore Tensor class."""
def __init__(self, value=0):
self._value = value
def asnumpy(self):
"""Get value in numpy format."""
return np.array(self._value)
def __repr__(self):
return str(self.asnumpy())
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
def get_context(key):
"""Get key in context."""
context = {"device_id": 1}
return context.get(key)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset."""
from .engine import MindDataset
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset.engine."""
from .datasets import MindDataset, Dataset
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/dataset/engine/datasets.py."""
class Dataset:
"""Mock the MindSpore Dataset class."""
def __init__(self, dataset_size=None, dataset_path=None):
self.dataset_size = dataset_size
self.dataset_path = dataset_path
self.input = []
def get_dataset_size(self):
"""Mocked get_dataset_size."""
return self.dataset_size
class MindDataset(Dataset):
"""Mock the MindSpore MindDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(MindDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the mindspore.nn package."""
from .optim import Optimizer, Momentum
from .loss.loss import SoftmaxCrossEntropyWithLogits, _Loss
from .cell import Cell, WithLossCell, TrainOneStepWithLossScaleCell
__all__ = ['Optimizer', 'Momentum', 'SoftmaxCrossEntropyWithLogits',
'_Loss', 'Cell', 'WithLossCell',
'TrainOneStepWithLossScaleCell']
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/callback.py."""
class Cell:
"""Mock the Cell class."""
def __init__(self, auto_prefix=True, pips=None):
if pips is None:
pips = dict()
self._auto_prefix = auto_prefix
self._pips = pips
@property
def auto_prefix(self):
"""The property of auto_prefix."""
return self._auto_prefix
@property
def pips(self):
"""The property of pips."""
return self._pips
class WithLossCell(Cell):
"""Mocked WithLossCell class."""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False, pips=backbone.pips)
self._backbone = backbone
self._loss_fn = loss_fn
class TrainOneStepWithLossScaleCell(Cell):
"""Mocked TrainOneStepWithLossScaleCell."""
def __init__(self):
super(TrainOneStepWithLossScaleCell, self).__init__()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore SoftmaxCrossEntropyWithLogits class."""
from ..cell import Cell
class _Loss(Cell):
"""Mocked _Loss."""
def __init__(self, reduction='mean'):
super(_Loss, self).__init__()
self.reduction = reduction
def construct(self, base, target):
"""Mocked construct function."""
raise NotImplementedError
class SoftmaxCrossEntropyWithLogits(_Loss):
"""Mocked SoftmaxCrossEntropyWithLogits."""
def __init__(self, weight=None):
super(SoftmaxCrossEntropyWithLogits, self).__init__(weight)
def construct(self, base, target):
"""Mocked construct."""
return 1
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/nn/optim.py."""
from .cell import Cell
class Parameter:
"""Mock the MindSpore Parameter class."""
def __init__(self, learning_rate):
self._name = "Parameter"
self.default_input = learning_rate
@property
def name(self):
"""The property of name."""
return self._name
def __repr__(self):
format_str = 'Parameter (name={name})'
return format_str.format(name=self._name)
class Optimizer(Cell):
"""Mock the MindSpore Optimizer class."""
def __init__(self, learning_rate):
super(Optimizer, self).__init__()
self.learning_rate = Parameter(learning_rate)
class Momentum(Optimizer):
"""Mock the MindSpore Momentum class."""
def __init__(self, learning_rate):
super(Momentum, self).__init__(learning_rate)
self.dynamic_lr = False
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore wrap package."""
from .loss_scale import TrainOneStepWithLossScaleCell
from .cell_wrapper import WithLossCell
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore cell_wrapper.py."""
from ..cell import Cell
class WithLossCell(Cell):
"""Mock the WithLossCell class."""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__()
self._backbone = backbone
self._loss_fn = loss_fn
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore loss_scale.py."""
from ..cell import Cell
class TrainOneStepWithLossScaleCell(Cell):
"""Mock the TrainOneStepWithLossScaleCell class."""
def __init__(self, network, optimizer):
super(TrainOneStepWithLossScaleCell, self).__init__()
self.network = network
self.optimizer = optimizer
def construct(self, data, label):
"""Mock the construct method."""
raise NotImplementedError
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/callback.py."""
import os
class RunContext:
"""Mock the RunContext class."""
def __init__(self, original_args=None):
self._original_args = original_args
self._stop_requested = False
def original_args(self):
"""Mock original_args."""
return self._original_args
def stop_requested(self):
"""Mock stop_requested method."""
return self._stop_requested
class Callback:
"""Mock the Callback class."""
def __init__(self):
pass
def begin(self, run_context):
"""Called once before network training."""
def epoch_begin(self, run_context):
"""Called before each epoch begin."""
class _ListCallback(Callback):
"""Mock the _ListCallabck class."""
def __init__(self, callbacks):
super(_ListCallback, self).__init__()
self._callbacks = callbacks
class ModelCheckpoint(Callback):
"""Mock the ModelCheckpoint class."""
def __init__(self, prefix='CKP', directory=None, config=None):
super(ModelCheckpoint, self).__init__()
self._prefix = prefix
self._directory = directory
self._config = config
self._latest_ckpt_file_name = os.path.join(directory, prefix + 'test_model.ckpt')
@property
def model_file_name(self):
"""Get the file name of model."""
return self._model_file_name
@property
def latest_ckpt_file_name(self):
"""Get the latest file name fo checkpoint."""
return self._latest_ckpt_file_name
class SummaryStep(Callback):
"""Mock the SummaryStep class."""
def __init__(self, summary, flush_step=10):
super(SummaryStep, self).__init__()
self._sumamry = summary
self._flush_step = flush_step
self.summary_file_name = summary.full_file_name
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
from .summary_record import SummaryRecord
__all__ = ["SummaryRecord"]
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
import os
import time
class SummaryRecord:
"""Mock the MindSpore SummaryRecord class."""
def __init__(self,
log_dir: str,
file_prefix: str = "events.",
file_suffix: str = ".MS",
create_time=int(time.time())):
self.log_dir = log_dir
self.prefix = file_prefix
self.suffix = file_suffix
file_name = file_prefix + 'summary.' + str(create_time) + file_suffix
self.full_file_name = os.path.join(log_dir, file_name)
def flush(self):
"""Mock flush method."""
def close(self):
"""Mock close method."""
......@@ -20,13 +20,7 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.lineagemgr.common.path_parser import SummaryPathParser
class TestSummaryPathParser(TestCase):
"""Test the class of SummaryPathParser."""
@mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_summary_dirs(self, *args):
"""Test the function of get_summary_dirs."""
args[0].return_value = [
MOCK_SUMMARY_DIRS = [
{
'relative_path': './relative_path0'
},
......@@ -36,25 +30,8 @@ class TestSummaryPathParser(TestCase):
{
'relative_path': './relative_path1'
}
]
expected_result = [
'/path/to/base/relative_path0',
'/path/to/base',
'/path/to/base/relative_path1'
]
base_dir = '/path/to/base'
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual(expected_result, result)
args[0].return_value = []
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual([], result)
@mock.patch.object(SummaryWatcher, 'list_summaries')
def test_get_latest_lineage_summary(self, *args):
"""Test the function of get_latest_lineage_summary."""
args[0].return_value = [
]
MOCK_SUMMARIES = [
{
'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970)
......@@ -71,7 +48,34 @@ class TestSummaryPathParser(TestCase):
'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971)
}
]
class TestSummaryPathParser(TestCase):
"""Test the class of SummaryPathParser."""
@mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_summary_dirs(self, *args):
"""Test the function of get_summary_dirs."""
args[0].return_value = MOCK_SUMMARY_DIRS
expected_result = [
'/path/to/base/relative_path0',
'/path/to/base',
'/path/to/base/relative_path1'
]
base_dir = '/path/to/base'
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual(expected_result, result)
args[0].return_value = []
result = SummaryPathParser.get_summary_dirs(base_dir)
self.assertListEqual([], result)
@mock.patch.object(SummaryWatcher, 'list_summaries')
def test_get_latest_lineage_summary(self, *args):
"""Test the function of get_latest_lineage_summary."""
args[0].return_value = MOCK_SUMMARIES
summary_dir = '/path/to/summary_dir'
result = SummaryPathParser.get_latest_lineage_summary(summary_dir)
self.assertEqual('/path/to/summary_dir/file1_lineage', result)
......@@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase):
@mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_latest_lineage_summaries(self, *args):
"""Test the function of get_latest_lineage_summaries."""
args[0].return_value = [
{
'relative_path': './relative_path0'
},
{
'relative_path': './'
},
{
'relative_path': './relative_path1'
}
]
args[1].return_value = [
{
'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file0_lineage',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file1',
'create_time': datetime.fromtimestamp(1582031971)
},
{
'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971)
}
]
args[0].return_value = MOCK_SUMMARY_DIRS
args[1].return_value = MOCK_SUMMARIES
expected_result = [
'/path/to/base/relative_path0/file1_lineage',
......
......@@ -15,11 +15,12 @@
"""Test the query_model module."""
from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \
LineageEventNotExistException, LineageEventFieldNotExistException
from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException,
LineageEventNotExistException)
from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data
from .test_querier import create_lineage_info, create_filtration_result
from .test_querier import create_filtration_result, create_lineage_info
class TestLineageObj(TestCase):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
......@@ -19,10 +19,10 @@ import time
from google.protobuf import json_format
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class GraphLogGenerator(LogGenerator):
"""
......@@ -74,7 +74,7 @@ class GraphLogGenerator(LogGenerator):
if __name__ == "__main__":
graph_log_generator = GraphLogGenerator()
test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time()))
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json")
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators--", "graph_base.json")
with open(graph_base_path, 'r') as load_f:
graph = json.load(load_f)
graph_log_generator.generate_log(test_file_name, graph)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册