提交 f2586441 编写于 作者: P ph

Merge branch 'master' of gitee.com:mindspore/mindinsight into phdev

<!-- Thanks for sending a pull request! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md
-->
**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> /kind bug
> /kind task
> /kind feature
**What does this PR do / why do we need it**:
**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #
**Special notes for your reviewers**:
---
name: RFC
about: Use this template for the new feature or enhancement
labels: kind/feature or kind/enhancement
---
## Background
- Describe the status of the problem you wish to solve
- Attach the relevant issue if have
## Introduction
- Describe the general solution, design and/or pseudo-code
## Trail
| No. | Task Description | Related Issue(URL) |
| --- | ---------------- | ------------------ |
| 1 | | |
| 2 | | |
---
name: Bug Report
about: Use this template for reporting a bug
labels: kind/bug
---
<!-- Thanks for sending an issue! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md
-->
## Environment
### Hardware Environment(`Ascend`/`GPU`/`CPU`):
> Uncomment only one ` /device <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> `/device ascend`</br>
> `/device gpu`</br>
> `/device cpu`</br>
### Software Environment:
- **MindSpore version (source or binary)**:
- **Python version (e.g., Python 3.7.5)**:
- **OS platform and distribution (e.g., Linux Ubuntu 16.04)**:
- **GCC/Compiler version (if compiled from source)**:
## Describe the current behavior
## Describe the expected behavior
## Steps to reproduce the issue
1.
2.
3.
## Related log / screenshot
## Special notes for this issue
---
name: Task
about: Use this template for task tracking
labels: kind/task
---
## Task Description
## Task Goal
## Sub Task
| No. | Task Description | Issue ID |
| --- | ---------------- | -------- |
| 1 | | |
| 2 | | |
<!-- Thanks for sending a pull request! Here are some tips for you:
If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md
-->
**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> `/kind bug`</br>
> `/kind task`</br>
> `/kind feature`</br>
**What does this PR do / why do we need it**:
**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #
**Special notes for your reviewers**:
...@@ -18,19 +18,21 @@ This file is used to define the basic graph. ...@@ -18,19 +18,21 @@ This file is used to define the basic graph.
import copy import copy
import time import time
from enum import Enum
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from .node import NodeTypeEnum from .node import NodeTypeEnum
from .node import Node from .node import Node
class EdgeTypeEnum: class EdgeTypeEnum(Enum):
"""Node edge type enum.""" """Node edge type enum."""
control = 'control' CONTROL = 'control'
data = 'data' DATA = 'data'
class DataTypeEnum: class DataTypeEnum(Enum):
"""Data type enum.""" """Data type enum."""
DT_TENSOR = 13 DT_TENSOR = 13
...@@ -292,70 +294,65 @@ class Graph: ...@@ -292,70 +294,65 @@ class Graph:
output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_output({dst_name: output_attr}) node.update_output({dst_name: output_attr})
def _calc_polymeric_input_output(self): def _update_polymeric_input_output(self):
"""Calc polymeric input and output after build polymeric node.""" """Calc polymeric input and output after build polymeric node."""
for name, node in self._normal_nodes.items(): for node in self._normal_nodes.values():
polymeric_input = {} polymeric_input = self._calc_polymeric_attr(node, 'input')
for src_name in node.input:
src_node = self._polymeric_nodes.get(src_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
src_name = src_name if not src_node else src_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not src_node:
continue
if not node.name_scope and src_node.name_scope:
# if current node is in first layer, and the src node is not in
# the first layer, the src node will not be the polymeric input of current node.
continue
if node.name_scope == src_node.name_scope \
or node.name_scope.startswith(src_node.name_scope):
polymeric_input.update(
{src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_input(polymeric_input) node.update_polymeric_input(polymeric_input)
polymeric_output = {} polymeric_output = self._calc_polymeric_attr(node, 'output')
for dst_name in node.output:
dst_node = self._polymeric_nodes.get(dst_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not dst_node:
continue
if not node.name_scope and dst_node.name_scope:
continue
if node.name_scope == dst_node.name_scope \
or node.name_scope.startswith(dst_node.name_scope):
polymeric_output.update(
{dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_output(polymeric_output) node.update_polymeric_output(polymeric_output)
for name, node in self._polymeric_nodes.items(): for name, node in self._polymeric_nodes.items():
polymeric_input = {} polymeric_input = {}
for src_name in node.input: for src_name in node.input:
output_name = self._calc_dummy_node_name(name, src_name) output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
node.update_polymeric_input(polymeric_input) node.update_polymeric_input(polymeric_input)
polymeric_output = {} polymeric_output = {}
for dst_name in node.output: for dst_name in node.output:
polymeric_output = {} polymeric_output = {}
output_name = self._calc_dummy_node_name(name, dst_name) output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
node.update_polymeric_output(polymeric_output) node.update_polymeric_output(polymeric_output)
def _calc_polymeric_attr(self, node, attr):
"""
Calc polymeric input or polymeric output after build polymeric node.
Args:
node (Node): Computes the polymeric input for a given node.
attr (str): The polymeric attr, optional value is `input` or `output`.
Returns:
dict, return polymeric input or polymeric output of the given node.
"""
polymeric_attr = {}
for node_name in getattr(node, attr):
polymeric_node = self._polymeric_nodes.get(node_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name
dummy_node_name = self._calc_dummy_node_name(node.name, node_name)
polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}})
continue
if not polymeric_node:
continue
if not node.name_scope and polymeric_node.name_scope:
# If current node is in top-level layer, and the polymeric_node node is not in
# the top-level layer, the polymeric node will not be the polymeric input
# or polymeric output of current node.
continue
if node.name_scope == polymeric_node.name_scope \
or node.name_scope.startswith(polymeric_node.name_scope + '/'):
polymeric_attr.update(
{polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}})
return polymeric_attr
def _calc_dummy_node_name(self, current_node_name, other_node_name): def _calc_dummy_node_name(self, current_node_name, other_node_name):
""" """
Calc dummy node name. Calc dummy node name.
......
...@@ -39,7 +39,7 @@ class MSGraph(Graph): ...@@ -39,7 +39,7 @@ class MSGraph(Graph):
self._build_leaf_nodes(graph_proto) self._build_leaf_nodes(graph_proto)
self._build_polymeric_nodes() self._build_polymeric_nodes()
self._build_name_scope_nodes() self._build_name_scope_nodes()
self._calc_polymeric_input_output() self._update_polymeric_input_output()
logger.info("Build graph end, normal node count: %s, polymeric node " logger.info("Build graph end, normal node count: %s, polymeric node "
"count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))
...@@ -90,9 +90,9 @@ class MSGraph(Graph): ...@@ -90,9 +90,9 @@ class MSGraph(Graph):
node_name = leaf_node_id_map_name[node_def.name] node_name = leaf_node_id_map_name[node_def.name]
node = self._leaf_nodes[node_name] node = self._leaf_nodes[node_name]
for input_def in node_def.input: for input_def in node_def.input:
edge_type = EdgeTypeEnum.data edge_type = EdgeTypeEnum.DATA.value
if input_def.type == "CONTROL_EDGE": if input_def.type == "CONTROL_EDGE":
edge_type = EdgeTypeEnum.control edge_type = EdgeTypeEnum.CONTROL.value
if const_nodes_map.get(input_def.name): if const_nodes_map.get(input_def.name):
const_node = copy.deepcopy(const_nodes_map[input_def.name]) const_node = copy.deepcopy(const_nodes_map[input_def.name])
...@@ -218,7 +218,7 @@ class MSGraph(Graph): ...@@ -218,7 +218,7 @@ class MSGraph(Graph):
node = Node(name=const.key, node_id=const_node_id) node = Node(name=const.key, node_id=const_node_id)
node.node_type = NodeTypeEnum.CONST.value node.node_type = NodeTypeEnum.CONST.value
node.update_attr({const.key: str(const.value)}) node.update_attr({const.key: str(const.value)})
if const.value.dtype == DataTypeEnum.DT_TENSOR: if const.value.dtype == DataTypeEnum.DT_TENSOR.value:
shape = [] shape = []
for dim in const.value.tensor_val.dims: for dim in const.value.tensor_val.dims:
shape.append(dim) shape.append(dim)
......
...@@ -172,7 +172,7 @@ class Node: ...@@ -172,7 +172,7 @@ class Node:
Args: Args:
polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name:
{'edge_type': EdgeTypeEnum.data}}). {'edge_type': EdgeTypeEnum.DATA.value}}).
""" """
self._polymeric_output.update(polymeric_output) self._polymeric_output.update(polymeric_output)
......
...@@ -168,7 +168,7 @@ class TrainLineage(Callback): ...@@ -168,7 +168,7 @@ class TrainLineage(Callback):
train_lineage = AnalyzeObject.get_network_args( train_lineage = AnalyzeObject.get_network_args(
run_context_args, train_lineage run_context_args, train_lineage
) )
train_dataset = run_context_args.get('train_dataset') train_dataset = run_context_args.get('train_dataset')
callbacks = run_context_args.get('list_callback') callbacks = run_context_args.get('list_callback')
list_callback = getattr(callbacks, '_callbacks', []) list_callback = getattr(callbacks, '_callbacks', [])
...@@ -601,7 +601,7 @@ class AnalyzeObject: ...@@ -601,7 +601,7 @@ class AnalyzeObject:
loss = None loss = None
else: else:
loss = run_context_args.get('net_outputs') loss = run_context_args.get('net_outputs')
if loss: if loss:
log.info('Calculating loss...') log.info('Calculating loss...')
loss_numpy = loss.asnumpy() loss_numpy = loss.asnumpy()
...@@ -610,7 +610,7 @@ class AnalyzeObject: ...@@ -610,7 +610,7 @@ class AnalyzeObject:
train_lineage[Metadata.loss] = loss train_lineage[Metadata.loss] = loss
else: else:
train_lineage[Metadata.loss] = None train_lineage[Metadata.loss] = None
# Analyze classname of optimizer, loss function and training network. # Analyze classname of optimizer, loss function and training network.
train_lineage[Metadata.optimizer] = type(optimizer).__name__ \ train_lineage[Metadata.optimizer] = type(optimizer).__name__ \
if optimizer else None if optimizer else None
......
...@@ -18,13 +18,10 @@ Description: This file is used for some common util. ...@@ -18,13 +18,10 @@ Description: This file is used for some common util.
import os import os
import shutil import shutil
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from flask import Response from flask import Response
from tests.st.func.datavisual import constants
from tests.st.func.datavisual.utils.log_operations import LogOperations
from tests.st.func.datavisual.utils.utils import check_loading_done
from tests.st.func.datavisual.utils import globals as gbl
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.data_manager import DataManager
...@@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat ...@@ -32,6 +29,11 @@ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generat
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.utils import tools from mindinsight.datavisual.utils import tools
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done
from . import constants
from . import globals as gbl
summaries_metadata = None summaries_metadata = None
mock_data_manager = None mock_data_manager = None
summary_base_dir = constants.SUMMARY_BASE_DIR summary_base_dir = constants.SUMMARY_BASE_DIR
...@@ -55,17 +57,21 @@ def init_summary_logs(): ...@@ -55,17 +57,21 @@ def init_summary_logs():
os.mkdir(summary_base_dir, mode=mode) os.mkdir(summary_base_dir, mode=mode)
global summaries_metadata, mock_data_manager global summaries_metadata, mock_data_manager
log_operations = LogOperations() log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST) summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)]) mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
mock_data_manager.start_load_data(reload_interval=0) mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager) check_loading_done(mock_data_manager)
summaries_metadata.update(log_operations.create_summary_logs( summaries_metadata.update(
summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, constants.SUMMARY_DIR_NUM_FIRST)) log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND,
summaries_metadata.update(log_operations.create_multiple_logs( constants.SUMMARY_DIR_NUM_FIRST))
summary_base_dir, constants.MULTIPLE_DIR_NAME, constants.MULTIPLE_LOG_NUM)) summaries_metadata.update(
summaries_metadata.update(log_operations.create_reservoir_log( log_operations.create_multiple_logs(summary_base_dir, constants.MULTIPLE_DIR_NAME,
summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM)) constants.MULTIPLE_LOG_NUM))
summaries_metadata.update(
log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME,
constants.RESERVOIR_STEP_NUM))
mock_data_manager.start_load_data(reload_interval=0) mock_data_manager.start_load_data(reload_interval=0)
# Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING.
...@@ -73,7 +79,7 @@ def init_summary_logs(): ...@@ -73,7 +79,7 @@ def init_summary_logs():
# Maximum number of loads is `MAX_DATA_LOADER_SIZE`. # Maximum number of loads is `MAX_DATA_LOADER_SIZE`.
for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE): for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE):
summaries_metadata.pop("./%s%d" % (constants.SUMMARY_PREFIX, i)) summaries_metadata.pop("./%s%d" % (constants.SUMMARY_DIR_PREFIX, i))
yield yield
finally: finally:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import tempfile import tempfile
SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name SUMMARY_BASE_DIR = tempfile.NamedTemporaryFile().name
SUMMARY_PREFIX = "summary" SUMMARY_DIR_PREFIX = "summary"
SUMMARY_DIR_NUM_FIRST = 5 SUMMARY_DIR_NUM_FIRST = 5
SUMMARY_DIR_NUM_SECOND = 11 SUMMARY_DIR_NUM_SECOND = 11
......
...@@ -19,11 +19,11 @@ Usage: ...@@ -19,11 +19,11 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import os import os
import json
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes'
...@@ -33,12 +33,6 @@ class TestQueryNodes: ...@@ -33,12 +33,6 @@ class TestQueryNodes:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
...@@ -65,4 +59,5 @@ class TestQueryNodes: ...@@ -65,4 +59,5 @@ class TestQueryNodes:
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file) file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
...@@ -19,12 +19,11 @@ Usage: ...@@ -19,12 +19,11 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import os import os
import json
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl from .. import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node'
...@@ -34,12 +33,6 @@ class TestQuerySingleNode: ...@@ -34,12 +33,6 @@ class TestQuerySingleNode:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
...@@ -59,4 +52,5 @@ class TestQuerySingleNode: ...@@ -59,4 +52,5 @@ class TestQuerySingleNode:
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file) file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
...@@ -19,25 +19,20 @@ Usage: ...@@ -19,25 +19,20 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import os import os
import json
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url from .. import globals as gbl
from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names'
class TestSearchNodes: class TestSearchNodes:
"""Test search nodes restful APIs.""" """Test searching nodes restful APIs."""
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
...@@ -58,4 +53,5 @@ class TestSearchNodes: ...@@ -58,4 +53,5 @@ class TestSearchNodes:
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file) file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
...@@ -20,13 +20,13 @@ Usage: ...@@ -20,13 +20,13 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
from ..constants import MULTIPLE_TRAIN_ID, RESERVOIR_TRAIN_ID
BASE_URL = '/v1/mindinsight/datavisual/image/metadata' BASE_URL = '/v1/mindinsight/datavisual/image/metadata'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/image/single-image' BASE_URL = '/v1/mindinsight/datavisual/image/single-image'
......
...@@ -19,11 +19,12 @@ Usage: ...@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata' BASE_URL = '/v1/mindinsight/datavisual/scalar/metadata'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/plugins' BASE_URL = '/v1/mindinsight/datavisual/plugins'
......
...@@ -19,11 +19,12 @@ Usage: ...@@ -19,11 +19,12 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
BASE_URL = '/v1/mindinsight/datavisual/single-job' BASE_URL = '/v1/mindinsight/datavisual/single-job'
......
...@@ -20,8 +20,8 @@ Usage: ...@@ -20,8 +20,8 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.constants import SUMMARY_DIR_NUM from ..constants import SUMMARY_DIR_NUM
from tests.st.func.datavisual.utils.utils import get_url from .....utils.tools import get_url
BASE_URL = '/v1/mindinsight/datavisual/train-jobs' BASE_URL = '/v1/mindinsight/datavisual/train-jobs'
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for graph."""
import json
import os
import time
from google.protobuf import json_format
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class GraphLogGenerator(LogGenerator):
"""
Log generator for graph.
This is a log generator writing graph. User can use it to generate fake
summary logs about graph.
"""
def generate_log(self, file_path, graph_dict):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
graph_dict (dict): A dict consists of graph node information.
Returns:
dict, generated scalar metadata.
"""
graph_event = self.generate_event(dict(graph=graph_dict))
self._write_log_from_event(file_path, graph_event)
return graph_dict
def generate_event(self, values):
"""
Method for generating graph event.
Args:
values (dict): Graph values. e.g. {'graph': graph_dict}.
Returns:
summary_pb2.Event.
"""
graph_json = {
'wall_time': time.time(),
'graph_def': values.get('graph'),
}
graph_event = json_format.Parse(json.dumps(graph_json), summary_pb2.Event())
return graph_event
if __name__ == "__main__":
graph_log_generator = GraphLogGenerator()
test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time()))
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json")
with open(graph_base_path, 'r') as load_f:
graph = json.load(load_f)
graph_log_generator.generate_log(test_file_name, graph)
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
...@@ -20,11 +20,11 @@ Usage: ...@@ -20,11 +20,11 @@ Usage:
""" """
import pytest import pytest
from tests.st.func.datavisual.utils import globals as gbl
from tests.st.func.datavisual.utils.utils import get_url, get_image_tensor_from_bytes
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .....utils.tools import get_image_tensor_from_bytes, get_url
from .. import globals as gbl
TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs' TRAIN_JOB_URL = '/v1/mindinsight/datavisual/train-jobs'
PLUGIN_URL = '/v1/mindinsight/datavisual/plugins' PLUGIN_URL = '/v1/mindinsight/datavisual/plugins'
METADATA_URL = '/v1/mindinsight/datavisual/image/metadata' METADATA_URL = '/v1/mindinsight/datavisual/image/metadata'
......
...@@ -26,11 +26,101 @@ from unittest import TestCase ...@@ -26,11 +26,101 @@ from unittest import TestCase
import pytest import pytest
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage from mindinsight.lineagemgr import filter_summary_lineage, get_summary_lineage
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotFoundError, LineageParamSummaryPathError,
LineageParamSummaryPathError, LineageParamValueError, LineageParamTypeError, \ LineageParamTypeError, LineageParamValueError,
LineageSearchConditionParamError, LineageFileNotFoundError LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, SUMMARY_DIR, SUMMARY_DIR_2, DATASET_GRAPH
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'metric': {
'accuracy': 0.78
},
'hyper_parameters': {
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14,
'parallel_mode': 'stand_alone',
'device_num': 2,
'batch_size': 32
},
'algorithm': {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
},
'valid_dataset': {
'valid_dataset_size': 10240
},
'model': {
'path': '{"ckpt": "'
+ BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}',
'size': 64
},
'dataset_graph': DATASET_GRAPH
}
LINEAGE_FILTRATION_EXCEPT_RUN = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
LINEAGE_FILTRATION_RUN2 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
}
@pytest.mark.usefixtures("create_summary_dir") @pytest.mark.usefixtures("create_summary_dir")
...@@ -67,36 +157,7 @@ class TestModelApi(TestCase): ...@@ -67,36 +157,7 @@ class TestModelApi(TestCase):
total_res = get_summary_lineage(SUMMARY_DIR) total_res = get_summary_lineage(SUMMARY_DIR)
partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters']) partial_res1 = get_summary_lineage(SUMMARY_DIR, ['hyper_parameters'])
partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm']) partial_res2 = get_summary_lineage(SUMMARY_DIR, ['metric', 'algorithm'])
expect_total_res = { expect_total_res = LINEAGE_INFO_RUN1
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'metric': {
'accuracy': 0.78
},
'hyper_parameters': {
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14,
'parallel_mode': 'stand_alone',
'device_num': 2,
'batch_size': 32
},
'algorithm': {
'network': 'ResNet'
},
'train_dataset': {
'train_dataset_size': 731
},
'valid_dataset': {
'valid_dataset_size': 10240
},
'model': {
'path': '{"ckpt": "'
+ BASE_SUMMARY_DIR + '/run1/CKPtest_model.ckpt"}',
'size': 64
},
'dataset_graph': DATASET_GRAPH
}
expect_partial_res1 = { expect_partial_res1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'hyper_parameters': { 'hyper_parameters': {
...@@ -139,7 +200,7 @@ class TestModelApi(TestCase): ...@@ -139,7 +200,7 @@ class TestModelApi(TestCase):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_single @pytest.mark.env_single
def test_get_summary_lineage_exception(self): def test_get_summary_lineage_exception_1(self):
"""Test the interface of get_summary_lineage with exception.""" """Test the interface of get_summary_lineage with exception."""
# summary path does not exist # summary path does not exist
self.assertRaisesRegex( self.assertRaisesRegex(
...@@ -183,6 +244,14 @@ class TestModelApi(TestCase): ...@@ -183,6 +244,14 @@ class TestModelApi(TestCase):
keys=None keys=None
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_get_summary_lineage_exception_2(self):
"""Test the interface of get_summary_lineage with exception."""
# keys is invalid # keys is invalid
self.assertRaisesRegex( self.assertRaisesRegex(
LineageParamValueError, LineageParamValueError,
...@@ -250,64 +319,9 @@ class TestModelApi(TestCase): ...@@ -250,64 +319,9 @@ class TestModelApi(TestCase):
"""Test the interface of filter_summary_lineage.""" """Test the interface of filter_summary_lineage."""
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_EXCEPT_RUN,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), LINEAGE_FILTRATION_RUN1,
'loss_function': 'SoftmaxCrossEntropyWithLogits', LINEAGE_FILTRATION_RUN2
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
}
], ],
'count': 3 'count': 3
} }
...@@ -357,46 +371,8 @@ class TestModelApi(TestCase): ...@@ -357,46 +371,8 @@ class TestModelApi(TestCase):
} }
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_RUN2,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), LINEAGE_FILTRATION_RUN1
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
], ],
'count': 2 'count': 2
} }
...@@ -432,46 +408,8 @@ class TestModelApi(TestCase): ...@@ -432,46 +408,8 @@ class TestModelApi(TestCase):
} }
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_RUN2,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'), LINEAGE_FILTRATION_RUN1
'loss_function': None,
'train_dataset_path': None,
'train_dataset_count': None,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': None,
'optimizer': None,
'learning_rate': None,
'epoch': None,
'batch_size': None,
'loss': None,
'model_size': None,
'metric': {
'accuracy': 2.7800000000000002
},
'dataset_graph': {},
'dataset_mark': 3
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
], ],
'count': 2 'count': 2
} }
...@@ -498,44 +436,8 @@ class TestModelApi(TestCase): ...@@ -498,44 +436,8 @@ class TestModelApi(TestCase):
} }
expect_result = { expect_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_EXCEPT_RUN,
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'), LINEAGE_FILTRATION_RUN1
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 10,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 64,
'metric': {},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
},
{
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 731,
'test_dataset_path': None,
'test_dataset_count': 10240,
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': None,
'model_size': 64,
'metric': {
'accuracy': 0.78
},
'dataset_graph': DATASET_GRAPH,
'dataset_mark': 2
}
], ],
'count': 2 'count': 2
} }
...@@ -674,6 +576,14 @@ class TestModelApi(TestCase): ...@@ -674,6 +576,14 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_3(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the condition of offset is invalid # the condition of offset is invalid
search_condition = { search_condition = {
'offset': 1.0 'offset': 1.0
...@@ -712,6 +622,14 @@ class TestModelApi(TestCase): ...@@ -712,6 +622,14 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_4(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the sorted_type not supported # the sorted_type not supported
search_condition = { search_condition = {
'sorted_name': 'summary_dir', 'sorted_name': 'summary_dir',
...@@ -753,6 +671,14 @@ class TestModelApi(TestCase): ...@@ -753,6 +671,14 @@ class TestModelApi(TestCase):
search_condition search_condition
) )
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage_exception_5(self):
"""Test the abnormal execution of the filter_summary_lineage interface."""
# the summary dir is invalid in search condition # the summary dir is invalid in search condition
search_condition = { search_condition = {
'summary_dir': { 'summary_dir': {
...@@ -811,7 +737,7 @@ class TestModelApi(TestCase): ...@@ -811,7 +737,7 @@ class TestModelApi(TestCase):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
@pytest.mark.env_single @pytest.mark.env_single
def test_filter_summary_lineage_exception_3(self): def test_filter_summary_lineage_exception_6(self):
"""Test the abnormal execution of the filter_summary_lineage interface.""" """Test the abnormal execution of the filter_summary_lineage interface."""
# gt > lt # gt > lt
search_condition1 = { search_condition1 = {
......
...@@ -21,7 +21,8 @@ import tempfile ...@@ -21,7 +21,8 @@ import tempfile
import pytest import pytest
from .collection.model import mindspore from ....utils import mindspore
from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
sys.modules['mindspore'] = mindspore sys.modules['mindspore'] = mindspore
...@@ -32,52 +33,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run') ...@@ -32,52 +33,7 @@ SUMMARY_DIR_3 = os.path.join(BASE_SUMMARY_DIR, 'except_run')
COLLECTION_MODULE = 'TestModelLineage' COLLECTION_MODULE = 'TestModelLineage'
API_MODULE = 'TestModelApi' API_MODULE = 'TestModelApi'
DATASET_GRAPH = { DATASET_GRAPH = SERIALIZED_PIPELINE
'op_type': 'BatchDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'drop_remainder': True,
'batch_size': 10,
'children': [
{
'op_type': 'MapDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'input_columns': [
'label'
],
'output_columns': [
None
],
'operations': [
{
'tensor_op_module': 'minddata.transforms.c_transforms',
'tensor_op_name': 'OneHot',
'num_classes': 10
}
],
'children': [
{
'op_type': 'MnistDataset',
'shard_id': None,
'num_shards': None,
'op_module': 'minddata.dataengine.datasets',
'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData',
'num_parallel_workers': None,
'shuffle': None,
'num_samples': 100,
'sampler': {
'sampler_module': 'minddata.dataengine.samplers',
'sampler_name': 'RandomSampler',
'replacement': True,
'num_samples': 100
},
'children': []
}
]
}
]
}
def get_module_name(nodeid): def get_module_name(nodeid):
"""Get the module name from nodeid.""" """Get the module name from nodeid."""
......
...@@ -14,6 +14,6 @@ ...@@ -14,6 +14,6 @@
# ============================================================================ # ============================================================================
"""Import the mocked mindspore.""" """Import the mocked mindspore."""
import sys import sys
from .lineagemgr.collection.model import mindspore from ..utils import mindspore
sys.modules['mindspore'] = mindspore sys.modules['mindspore'] = mindspore
...@@ -21,14 +21,15 @@ Usage: ...@@ -21,14 +21,15 @@ Usage:
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES
from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from tests.ut.datavisual.utils.utils import get_url
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from ....utils.log_generators.images_log_generator import ImagesLogGenerator
from ....utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainTask: class TestTrainTask:
"""Test train task api.""" """Test train task api."""
...@@ -36,9 +37,7 @@ class TestTrainTask: ...@@ -36,9 +37,7 @@ class TestTrainTask:
_scalar_log_generator = ScalarsLogGenerator() _scalar_log_generator = ScalarsLogGenerator()
_image_log_generator = ImagesLogGenerator() _image_log_generator = ImagesLogGenerator()
@pytest.mark.parametrize( @pytest.mark.parametrize("plugin_name", ['no_plugin_name', 'not_exist_plugin_name'])
"plugin_name",
['no_plugin_name', 'not_exist_plugin_name'])
def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name): def test_query_single_train_task_with_plugin_name_not_exist(self, client, plugin_name):
""" """
Parsing unavailable plugin name to single train task. Parsing unavailable plugin name to single train task.
......
...@@ -21,14 +21,15 @@ Usage: ...@@ -21,14 +21,15 @@ Usage:
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES
from tests.ut.datavisual.utils.utils import get_url
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.graph_processor import GraphProcessor from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.processors.images_processor import ImageProcessor from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from .conftest import TRAIN_ROUTES
class TestTrainVisual: class TestTrainVisual:
"""Test Train Visual APIs.""" """Test Train Visual APIs."""
...@@ -95,14 +96,7 @@ class TestTrainVisual: ...@@ -95,14 +96,7 @@ class TestTrainVisual:
assert response.status_code == 200 assert response.status_code == 200
response = response.get_json() response = response.get_json()
expected_response = { expected_response = {"metadatas": [{"height": 224, "step": 1, "wall_time": 1572058058.1175, "width": 448}]}
"metadatas": [{
"height": 224,
"step": 1,
"wall_time": 1572058058.1175,
"width": 448
}]
}
assert expected_response == response assert expected_response == response
def test_single_image_with_params_miss(self, client): def test_single_image_with_params_miss(self, client):
...@@ -254,8 +248,10 @@ class TestTrainVisual: ...@@ -254,8 +248,10 @@ class TestTrainVisual:
@patch.object(GraphProcessor, 'get_nodes') @patch.object(GraphProcessor, 'get_nodes')
def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client): def test_graph_nodes_success(self, mock_graph_processor, mock_graph_processor_1, client):
"""Test getting graph nodes successfully.""" """Test getting graph nodes successfully."""
def mock_get_nodes(name, node_type): def mock_get_nodes(name, node_type):
return dict(name=name, node_type=node_type) return dict(name=name, node_type=node_type)
mock_graph_processor.side_effect = mock_get_nodes mock_graph_processor.side_effect = mock_get_nodes
mock_init = Mock(return_value=None) mock_init = Mock(return_value=None)
...@@ -327,10 +323,7 @@ class TestTrainVisual: ...@@ -327,10 +323,7 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. 'offset' should " \ assert results['error_msg'] == "Invalid parameter value. 'offset' should " \
"be greater than or equal to 0." "be greater than or equal to 0."
@pytest.mark.parametrize( @pytest.mark.parametrize("limit", [-1, 0, 1001])
"limit",
[-1, 0, 1001]
)
@patch.object(GraphProcessor, '__init__') @patch.object(GraphProcessor, '__init__')
def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit): def test_graph_node_names_with_invalid_limit(self, mock_graph_processor, client, limit):
"""Test getting graph node names with invalid limit.""" """Test getting graph node names with invalid limit."""
...@@ -348,14 +341,10 @@ class TestTrainVisual: ...@@ -348,14 +341,10 @@ class TestTrainVisual:
assert results['error_msg'] == "Invalid parameter value. " \ assert results['error_msg'] == "Invalid parameter value. " \
"'limit' should in [1, 1000]." "'limit' should in [1, 1000]."
@pytest.mark.parametrize( @pytest.mark.parametrize(" offset, limit", [(0, 100), (1, 1), (0, 1000)])
" offset, limit",
[(0, 100), (1, 1), (0, 1000)]
)
@patch.object(GraphProcessor, '__init__') @patch.object(GraphProcessor, '__init__')
@patch.object(GraphProcessor, 'search_node_names') @patch.object(GraphProcessor, 'search_node_names')
def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, def test_graph_node_names_success(self, mock_graph_processor, mock_graph_processor_1, client, offset, limit):
offset, limit):
""" """
Parsing unavailable params to get image metadata. Parsing unavailable params to get image metadata.
...@@ -367,8 +356,10 @@ class TestTrainVisual: ...@@ -367,8 +356,10 @@ class TestTrainVisual:
response status code: 200. response status code: 200.
response json: dict, contains search_content, offset, and limit. response json: dict, contains search_content, offset, and limit.
""" """
def mock_search_node_names(search_content, offset, limit): def mock_search_node_names(search_content, offset, limit):
return dict(search_content=search_content, offset=int(offset), limit=int(limit)) return dict(search_content=search_content, offset=int(offset), limit=int(limit))
mock_graph_processor.side_effect = mock_search_node_names mock_graph_processor.side_effect = mock_search_node_names
mock_init = Mock(return_value=None) mock_init = Mock(return_value=None)
...@@ -376,15 +367,12 @@ class TestTrainVisual: ...@@ -376,15 +367,12 @@ class TestTrainVisual:
test_train_id = "aaa" test_train_id = "aaa"
test_search_content = "bbb" test_search_content = "bbb"
params = dict(train_id=test_train_id, search=test_search_content, params = dict(train_id=test_train_id, search=test_search_content, offset=offset, limit=limit)
offset=offset, limit=limit)
url = get_url(TRAIN_ROUTES['graph_nodes_names'], params) url = get_url(TRAIN_ROUTES['graph_nodes_names'], params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
results = response.get_json() results = response.get_json()
assert results == dict(search_content=test_search_content, assert results == dict(search_content=test_search_content, offset=int(offset), limit=int(limit))
offset=int(offset),
limit=int(limit))
def test_graph_search_single_node_with_params_is_wrong(self, client): def test_graph_search_single_node_with_params_is_wrong(self, client):
"""Test searching graph single node with params is wrong.""" """Test searching graph single node with params is wrong."""
...@@ -427,8 +415,10 @@ class TestTrainVisual: ...@@ -427,8 +415,10 @@ class TestTrainVisual:
response status code: 200. response status code: 200.
response json: name. response json: name.
""" """
def mock_search_single_node(name): def mock_search_single_node(name):
return name return name
mock_graph_processor.side_effect = mock_search_single_node mock_graph_processor.side_effect = mock_search_single_node
mock_init = Mock(return_value=None) mock_init = Mock(return_value=None)
......
...@@ -20,8 +20,42 @@ from unittest import TestCase, mock ...@@ -20,8 +20,42 @@ from unittest import TestCase, mock
from flask import Response from flask import Response
from mindinsight.backend.application import APP from mindinsight.backend.application import APP
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageQuerySummaryDataError
LineageQuerySummaryDataError
LINEAGE_FILTRATION_BASE = {
'accuracy': None,
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 12,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
}
LINEAGE_FILTRATION_RUN1 = {
'accuracy': 0.78,
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': 64,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
}
class TestSearchModel(TestCase): class TestSearchModel(TestCase):
...@@ -42,39 +76,11 @@ class TestSearchModel(TestCase): ...@@ -42,39 +76,11 @@ class TestSearchModel(TestCase):
'object': [ 'object': [
{ {
'summary_dir': base_dir, 'summary_dir': base_dir,
'accuracy': None, **LINEAGE_FILTRATION_BASE
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 12,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
}, },
{ {
'summary_dir': os.path.join(base_dir, 'run1'), 'summary_dir': os.path.join(base_dir, 'run1'),
'accuracy': 0.78, **LINEAGE_FILTRATION_RUN1
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': 64,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
} }
], ],
'count': 2 'count': 2
...@@ -93,39 +99,11 @@ class TestSearchModel(TestCase): ...@@ -93,39 +99,11 @@ class TestSearchModel(TestCase):
'object': [ 'object': [
{ {
'summary_dir': './', 'summary_dir': './',
'accuracy': None, **LINEAGE_FILTRATION_BASE
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': None,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 12,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
}, },
{ {
'summary_dir': './run1', 'summary_dir': './run1',
'accuracy': 0.78, **LINEAGE_FILTRATION_RUN1
'mae': None,
'mse': None,
'loss_function': 'SoftmaxCrossEntropyWithLogits',
'train_dataset_path': None,
'train_dataset_count': 64,
'test_dataset_path': None,
'test_dataset_count': 64,
'network': 'str',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'epoch': 14,
'batch_size': 32,
'loss': 0.029999999329447746,
'model_size': 128
} }
], ],
'count': 2 'count': 2
......
...@@ -19,15 +19,16 @@ Usage: ...@@ -19,15 +19,16 @@ Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
from unittest.mock import patch from unittest.mock import patch
from werkzeug.exceptions import MethodNotAllowed, NotFound
from tests.ut.backend.datavisual.conftest import TRAIN_ROUTES from werkzeug.exceptions import MethodNotAllowed, NotFound
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.utils import get_url
from mindinsight.datavisual.processors import scalars_processor from mindinsight.datavisual.processors import scalars_processor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from ....utils.tools import get_url
from ...backend.datavisual.conftest import TRAIN_ROUTES
from ..mock import MockLogger
class TestErrorHandler: class TestErrorHandler:
"""Test train visual api.""" """Test train visual api."""
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
""" """
Function: Function:
Test mindinsight.datavisual.data_transform.log_generators.data_loader_generator Test mindinsight.datavisual.data_transform.loader_generators.data_loader_generator
Usage: Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
...@@ -22,18 +22,19 @@ import datetime ...@@ -22,18 +22,19 @@ import datetime
import os import os
import shutil import shutil
import tempfile import tempfile
from unittest.mock import patch from unittest.mock import patch
import pytest
from tests.ut.datavisual.mock import MockLogger import pytest
from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator from mindinsight.datavisual.data_transform.loader_generators import data_loader_generator
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ...mock import MockLogger
class TestDataLoaderGenerator: class TestDataLoaderGenerator:
"""Test data_loader_generator.""" """Test data_loader_generator."""
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
data_loader_generator.logger = MockLogger data_loader_generator.logger = MockLogger
...@@ -88,8 +89,9 @@ class TestDataLoaderGenerator: ...@@ -88,8 +89,9 @@ class TestDataLoaderGenerator:
mock_data_loader.return_value = True mock_data_loader.return_value = True
loader_dict = generator.generate_loaders(loader_pool=dict()) loader_dict = generator.generate_loaders(loader_pool=dict())
expected_ids = [summary.get('relative_path') expected_ids = [
for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]] summary.get('relative_path') for summary in summaries[-data_loader_generator.MAX_DATA_LOADER_SIZE:]
]
assert sorted(loader_dict.keys()) == sorted(expected_ids) assert sorted(loader_dict.keys()) == sorted(expected_ids)
shutil.rmtree(summary_base_dir) shutil.rmtree(summary_base_dir)
......
...@@ -23,12 +23,13 @@ import shutil ...@@ -23,12 +23,13 @@ import shutil
import tempfile import tempfile
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid
from mindinsight.datavisual.data_transform import data_loader from mindinsight.datavisual.data_transform import data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader from mindinsight.datavisual.data_transform.data_loader import DataLoader
from ..mock import MockLogger
class TestDataLoader: class TestDataLoader:
"""Test data_loader.""" """Test data_loader."""
...@@ -37,13 +38,13 @@ class TestDataLoader: ...@@ -37,13 +38,13 @@ class TestDataLoader:
def setup_class(cls): def setup_class(cls):
data_loader.logger = MockLogger data_loader.logger = MockLogger
def setup_method(self, method): def setup_method(self):
self._summary_dir = tempfile.mkdtemp() self._summary_dir = tempfile.mkdtemp()
if os.path.exists(self._summary_dir): if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir) shutil.rmtree(self._summary_dir)
os.mkdir(self._summary_dir) os.mkdir(self._summary_dir)
def teardown_method(self, method): def teardown_method(self):
if os.path.exists(self._summary_dir): if os.path.exists(self._summary_dir):
shutil.rmtree(self._summary_dir) shutil.rmtree(self._summary_dir)
......
...@@ -18,32 +18,29 @@ Function: ...@@ -18,32 +18,29 @@ Function:
Usage: Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import time
import os import os
import shutil import shutil
import tempfile import tempfile
import time
from unittest import mock from unittest import mock
from unittest.mock import Mock from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.utils import check_loading_done
from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum from mindinsight.datavisual.common.enums import DataManagerStatus, PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager, ms_data_loader from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.events_data import EventsData from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import \ from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
DataLoaderGenerator from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import \ from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import \
LoaderStruct
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.tools import check_loading_done
from ..mock import MockLogger
class TestDataManager: class TestDataManager:
"""Test data_manager.""" """Test data_manager."""
...@@ -101,11 +98,17 @@ class TestDataManager: ...@@ -101,11 +98,17 @@ class TestDataManager:
"and loader pool size is '3'." "and loader pool size is '3'."
shutil.rmtree(summary_base_dir) shutil.rmtree(summary_base_dir)
@pytest.mark.parametrize('params', @pytest.mark.parametrize('params', [{
[{'reload_interval': '30'}, 'reload_interval': '30'
{'reload_interval': -1}, }, {
{'reload_interval': 30, 'max_threads_count': '20'}, 'reload_interval': -1
{'reload_interval': 30, 'max_threads_count': 0}]) }, {
'reload_interval': 30,
'max_threads_count': '20'
}, {
'reload_interval': 30,
'max_threads_count': 0
}])
def test_start_load_data_with_invalid_params(self, params): def test_start_load_data_with_invalid_params(self, params):
"""Test start_load_data with invalid reload_interval or invalid max_threads_count.""" """Test start_load_data with invalid reload_interval or invalid max_threads_count."""
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
......
...@@ -22,20 +22,24 @@ import threading ...@@ -22,20 +22,24 @@ import threading
from collections import namedtuple from collections import namedtuple
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import events_data from mindinsight.datavisual.data_transform import events_data
from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor from mindinsight.datavisual.data_transform.events_data import EventsData, TensorEvent, _Tensor
from ..mock import MockLogger
class MockReservoir: class MockReservoir:
"""Use this class to replace reservoir.Reservoir in test.""" """Use this class to replace reservoir.Reservoir in test."""
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
self._samples = [_Tensor('wall_time1', 1, 'value1'), _Tensor('wall_time2', 2, 'value2'), self._samples = [
_Tensor('wall_time3', 3, 'value3')] _Tensor('wall_time1', 1, 'value1'),
_Tensor('wall_time2', 2, 'value2'),
_Tensor('wall_time3', 3, 'value3')
]
def samples(self): def samples(self):
"""Replace the samples function.""" """Replace the samples function."""
...@@ -63,11 +67,12 @@ class TestEventsData: ...@@ -63,11 +67,12 @@ class TestEventsData:
def setup_method(self): def setup_method(self):
"""Mock original logger, init a EventsData object for use.""" """Mock original logger, init a EventsData object for use."""
self._ev_data = EventsData() self._ev_data = EventsData()
self._ev_data._tags_by_plugin = {'plugin_name1': [f'tag{i}' for i in range(10)], self._ev_data._tags_by_plugin = {
'plugin_name2': [f'tag{i}' for i in range(20, 30)]} 'plugin_name1': [f'tag{i}' for i in range(10)],
'plugin_name2': [f'tag{i}' for i in range(20, 30)]
}
self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()}) self._ev_data._tags_by_plugin_mutex_lock.update({'plugin_name1': threading.Lock()})
self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), self._ev_data._reservoir_by_tag = {'tag0': MockReservoir(500), 'new_tag': MockReservoir(500)}
'new_tag': MockReservoir(500)}
self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)] self._ev_data._tags = [f'tag{i}' for i in range(settings.MAX_TAG_SIZE_PER_EVENTS_DATA)]
def get_ev_data(self): def get_ev_data(self):
...@@ -102,8 +107,7 @@ class TestEventsData: ...@@ -102,8 +107,7 @@ class TestEventsData:
"""Test add_tensor_event success.""" """Test add_tensor_event success."""
ev_data = self.get_ev_data() ev_data = self.get_ev_data()
t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', t_event = TensorEvent(wall_time=1, step=4, tag='new_tag', plugin_name='plugin_name1', value='value1')
value='value1')
ev_data.add_tensor_event(t_event) ev_data.add_tensor_event(t_event)
assert 'tag0' not in ev_data._tags assert 'tag0' not in ev_data._tags
...@@ -111,6 +115,5 @@ class TestEventsData: ...@@ -111,6 +115,5 @@ class TestEventsData:
assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1'] assert 'tag0' not in ev_data._tags_by_plugin['plugin_name1']
assert 'tag0' not in ev_data._reservoir_by_tag assert 'tag0' not in ev_data._reservoir_by_tag
assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1'] assert 'new_tag' in ev_data._tags_by_plugin['plugin_name1']
assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, assert ev_data._reservoir_by_tag['new_tag'].samples()[-1] == _Tensor(t_event.wall_time, t_event.step,
t_event.step,
t_event.value) t_event.value)
...@@ -19,16 +19,17 @@ Usage: ...@@ -19,16 +19,17 @@ Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import os import os
import tempfile
import shutil import shutil
import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from mindinsight.datavisual.data_transform import ms_data_loader from mindinsight.datavisual.data_transform import ms_data_loader
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from ..mock import MockLogger
# bytes of 3 scalar events # bytes of 3 scalar events
SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*'
b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00' b'\x11\n\x0f\n\x08tag_name\x1d\r\x06V>\x00\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x00'
...@@ -74,7 +75,8 @@ class TestMsDataLoader: ...@@ -74,7 +75,8 @@ class TestMsDataLoader:
"we will reload all files in path {}.".format(summary_dir) "we will reload all files in path {}.".format(summary_dir)
shutil.rmtree(summary_dir) shutil.rmtree(summary_dir)
def test_load_success_with_crc_pass(self, crc_pass): @pytest.mark.usefixtures('crc_pass')
def test_load_success_with_crc_pass(self):
"""Test load success.""" """Test load success."""
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
file1 = os.path.join(summary_dir, 'summary.01') file1 = os.path.join(summary_dir, 'summary.01')
...@@ -88,7 +90,8 @@ class TestMsDataLoader: ...@@ -88,7 +90,8 @@ class TestMsDataLoader:
tensors = ms_loader.get_events_data().tensors(tag[0]) tensors = ms_loader.get_events_data().tensors(tag[0])
assert len(tensors) == 3 assert len(tensors) == 3
def test_load_with_crc_fail(self, crc_fail): @pytest.mark.usefixtures('crc_fail')
def test_load_with_crc_fail(self):
"""Test when crc_fail and will not go to func _event_parse.""" """Test when crc_fail and will not go to func _event_parse."""
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
file2 = os.path.join(summary_dir, 'summary.02') file2 = os.path.join(summary_dir, 'summary.02')
...@@ -100,8 +103,10 @@ class TestMsDataLoader: ...@@ -100,8 +103,10 @@ class TestMsDataLoader:
def test_filter_event_files(self): def test_filter_event_files(self):
"""Test filter_event_files function ok.""" """Test filter_event_files function ok."""
file_list = ['abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', file_list = [
'summary.0012', 'hellosummary.98786', 'mysummary.123abce', 'summay.4567'] 'abc.summary', '123sumary0009abc', 'summary1234', 'aaasummary.5678', 'summary.0012', 'hellosummary.98786',
'mysummary.123abce', 'summay.4567'
]
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
for file in file_list: for file in file_list:
with open(os.path.join(summary_dir, file), 'w'): with open(os.path.join(summary_dir, file), 'w'):
...@@ -113,6 +118,7 @@ class TestMsDataLoader: ...@@ -113,6 +118,7 @@ class TestMsDataLoader:
shutil.rmtree(summary_dir) shutil.rmtree(summary_dir)
def write_file(filename, record): def write_file(filename, record):
"""Write bytes strings to file.""" """Write bytes strings to file."""
with open(filename, 'wb') as file: with open(filename, 'wb') as file:
......
...@@ -19,18 +19,11 @@ Usage: ...@@ -19,18 +19,11 @@ Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import os import os
import json
import tempfile import tempfile
from unittest.mock import Mock, patch
from unittest.mock import Mock
from unittest.mock import patch
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor ...@@ -40,6 +33,10 @@ from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs
from ..mock import MockLogger
class TestGraphProcessor: class TestGraphProcessor:
"""Test Graph Processor api.""" """Test Graph Processor api."""
...@@ -70,18 +67,13 @@ class TestGraphProcessor: ...@@ -70,18 +67,13 @@ class TestGraphProcessor:
"""Load graph record.""" """Load graph record."""
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
graph_base_path = os.path.join(os.path.dirname(__file__), log_operation = LogOperations()
os.pardir, "utils", "log_generators", "graph_base.json") self._temp_path, self._graph_dict = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._temp_path, self._graph_dict = LogOperations.generate_log(
PluginNameEnum.GRAPH.value, log_dir, dict(graph_base_path=graph_base_path))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager( self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0) self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done # wait for loading done
...@@ -94,33 +86,29 @@ class TestGraphProcessor: ...@@ -94,33 +86,29 @@ class TestGraphProcessor:
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, _, _ = LogOperations.generate_log( log_operation = LogOperations()
PluginNameEnum.IMAGE.value, log_dir, dict(steps=self._steps_list, tag="image")) self._temp_path, _, _ = log_operation.generate_log(PluginNameEnum.IMAGE.value, log_dir,
dict(steps=self._steps_list, tag="image"))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager( self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
[DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager.start_load_data(reload_interval=0) self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done # wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5) check_loading_done(self._mock_data_manager, time_limit=5)
def compare_result_with_file(self, result, filename): @pytest.mark.usefixtures('load_graph_record')
"""Compare result with file which contain the expected results.""" def test_get_nodes_with_not_exist_train_id(self):
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
def test_get_nodes_with_not_exist_train_id(self, load_graph_record):
"""Test getting nodes with not exist train id.""" """Test getting nodes with not exist train id."""
test_train_id = "not_exist_train_id" test_train_id = "not_exist_train_id"
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(test_train_id, self._mock_data_manager) GraphProcessor(test_train_id, self._mock_data_manager)
assert "Can not find the train job in data manager." in exc_info.value.message assert "Can not find the train job in data manager." in exc_info.value.message
@pytest.mark.usefixtures('load_graph_record')
@patch.object(DataManager, 'get_train_job_by_plugin') @patch.object(DataManager, 'get_train_job_by_plugin')
def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin, load_graph_record): def test_get_nodes_with_loader_is_none(self, mock_get_train_job_by_plugin):
"""Test get nodes with loader is None.""" """Test get nodes with loader is None."""
mock_get_train_job_by_plugin.return_value = None mock_get_train_job_by_plugin.return_value = None
with pytest.raises(exceptions.SummaryLogPathInvalid): with pytest.raises(exceptions.SummaryLogPathInvalid):
...@@ -128,15 +116,12 @@ class TestGraphProcessor: ...@@ -128,15 +116,12 @@ class TestGraphProcessor:
assert mock_get_train_job_by_plugin.called assert mock_get_train_job_by_plugin.called
@pytest.mark.parametrize("name, node_type", [ @pytest.mark.usefixtures('load_graph_record')
("not_exist_name", "name_scope"), @pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")])
("", "polymeric_scope") def test_get_nodes_with_not_exist_name(self, name, node_type):
])
def test_get_nodes_with_not_exist_name(self, load_graph_record, name, node_type):
"""Test getting nodes with not exist name.""" """Test getting nodes with not exist name."""
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
graph_processor.get_nodes(name, node_type) graph_processor.get_nodes(name, node_type)
if name: if name:
...@@ -144,105 +129,99 @@ class TestGraphProcessor: ...@@ -144,105 +129,99 @@ class TestGraphProcessor:
else: else:
assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message
@pytest.mark.parametrize("name, node_type, result_file", [ @pytest.mark.usefixtures('load_graph_record')
(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'), @pytest.mark.parametrize(
('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'), "name, node_type, result_file",
('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json') [(None, 'name_scope', 'test_get_nodes_success_expected_results1.json'),
]) ('Default/conv1-Conv2d', 'name_scope', 'test_get_nodes_success_expected_results2.json'),
def test_get_nodes_success(self, load_graph_record, name, node_type, result_file): ('Default/bn1/Reshape_1_[12]', 'polymeric_scope', 'test_get_nodes_success_expected_results3.json')])
def test_get_nodes_success(self, name, node_type, result_file):
"""Test getting nodes successfully.""" """Test getting nodes successfully."""
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
results = graph_processor.get_nodes(name, node_type) results = graph_processor.get_nodes(name, node_type)
self.compare_result_with_file(results, result_file)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
@pytest.mark.parametrize("search_content, result_file", [ compare_result_with_file(results, expected_file_path)
(None, 'test_search_node_names_with_search_content_expected_results1.json'),
('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'), @pytest.mark.usefixtures('load_graph_record')
('not_exist_search_content', None) @pytest.mark.parametrize("search_content, result_file",
]) [(None, 'test_search_node_names_with_search_content_expected_results1.json'),
def test_search_node_names_with_search_content(self, load_graph_record, ('Default/bn1', 'test_search_node_names_with_search_content_expected_results2.json'),
search_content, ('not_exist_search_content', None)])
result_file): def test_search_node_names_with_search_content(self, search_content, result_file):
"""Test search node names with search content.""" """Test search node names with search content."""
test_offset = 0 test_offset = 0
test_limit = 1000 test_limit = 1000
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager) results = graph_processor.search_node_names(search_content, test_offset, test_limit)
results = graph_processor.search_node_names(search_content,
test_offset,
test_limit)
if search_content == 'not_exist_search_content': if search_content == 'not_exist_search_content':
expected_results = {'names': []} expected_results = {'names': []}
assert results == expected_results assert results == expected_results
else: else:
self.compare_result_with_file(results, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize("offset", [-100, -1]) @pytest.mark.parametrize("offset", [-100, -1])
def test_search_node_names_with_negative_offset(self, load_graph_record, offset): def test_search_node_names_with_negative_offset(self, offset):
"""Test search node names with negative offset.""" """Test search node names with negative offset."""
test_search_content = "" test_search_content = ""
test_limit = 3 test_limit = 3
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, offset, test_limit) graph_processor.search_node_names(test_search_content, offset, test_limit)
assert "'offset' should be greater than or equal to 0." in exc_info.value.message assert "'offset' should be greater than or equal to 0." in exc_info.value.message
@pytest.mark.parametrize("offset, result_file", [ @pytest.mark.usefixtures('load_graph_record')
(1, 'test_search_node_names_with_offset_expected_results1.json') @pytest.mark.parametrize("offset, result_file", [(1, 'test_search_node_names_with_offset_expected_results1.json')])
]) def test_search_node_names_with_offset(self, offset, result_file):
def test_search_node_names_with_offset(self, load_graph_record, offset, result_file):
"""Test search node names with offset.""" """Test search node names with offset."""
test_search_content = "Default/bn1" test_search_content = "Default/bn1"
test_offset = offset test_offset = offset
test_limit = 3 test_limit = 3
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager) results = graph_processor.search_node_names(test_search_content, test_offset, test_limit)
results = graph_processor.search_node_names(test_search_content, expected_file_path = os.path.join(self.graph_results_dir, result_file)
test_offset, compare_result_with_file(results, expected_file_path)
test_limit)
self.compare_result_with_file(results, result_file)
def test_search_node_names_with_wrong_limit(self, load_graph_record): @pytest.mark.usefixtures('load_graph_record')
def test_search_node_names_with_wrong_limit(self):
"""Test search node names with wrong limit.""" """Test search node names with wrong limit."""
test_search_content = "" test_search_content = ""
test_offset = 0 test_offset = 0
test_limit = 0 test_limit = 0
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
graph_processor.search_node_names(test_search_content, test_offset, graph_processor.search_node_names(test_search_content, test_offset, test_limit)
test_limit)
assert "'limit' should in [1, 1000]." in exc_info.value.message assert "'limit' should in [1, 1000]." in exc_info.value.message
@pytest.mark.parametrize("name, result_file", [ @pytest.mark.usefixtures('load_graph_record')
('Default/bn1', 'test_search_single_node_success_expected_results1.json') @pytest.mark.parametrize("name, result_file",
]) [('Default/bn1', 'test_search_single_node_success_expected_results1.json')])
def test_search_single_node_success(self, load_graph_record, name, result_file): def test_search_single_node_success(self, name, result_file):
"""Test searching single node successfully.""" """Test searching single node successfully."""
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
results = graph_processor.search_single_node(name) results = graph_processor.search_single_node(name)
self.compare_result_with_file(results, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_single_node_with_not_exist_name(self, load_graph_record): @pytest.mark.usefixtures('load_graph_record')
def test_search_single_node_with_not_exist_name(self):
"""Test searching single node with not exist name.""" """Test searching single node with not exist name."""
test_name = "not_exist_name" test_name = "not_exist_name"
with pytest.raises(exceptions.NodeNotInGraphError): with pytest.raises(exceptions.NodeNotInGraphError):
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
self._mock_data_manager)
graph_processor.search_single_node(test_name) graph_processor.search_single_node(test_name)
def test_check_graph_status_no_graph(self, load_no_graph_record): @pytest.mark.usefixtures('load_no_graph_record')
def test_check_graph_status_no_graph(self):
"""Test checking graph status no graph.""" """Test checking graph status no graph."""
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(ParamValueError) as exc_info:
GraphProcessor(self._train_id, self._mock_data_manager) GraphProcessor(self._train_id, self._mock_data_manager)
......
...@@ -22,9 +22,6 @@ import tempfile ...@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor ...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes
from ..mock import MockLogger
class TestImagesProcessor: class TestImagesProcessor:
"""Test images processor api.""" """Test images processor api."""
...@@ -73,12 +74,11 @@ class TestImagesProcessor: ...@@ -73,12 +74,11 @@ class TestImagesProcessor:
""" """
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, self._images_metadata, self._images_values = LogOperations.generate_log( log_operation = LogOperations()
self._temp_path, self._images_metadata, self._images_values = log_operation.generate_log(
PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name)) PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
...@@ -102,7 +102,8 @@ class TestImagesProcessor: ...@@ -102,7 +102,8 @@ class TestImagesProcessor:
"""Load image record.""" """Load image record."""
self._init_data_manager(self._cross_steps_list) self._init_data_manager(self._cross_steps_list)
def test_get_metadata_list_with_not_exist_id(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Test getting metadata list with not exist id.""" """Test getting metadata list with not exist id."""
test_train_id = 'not_exist_id' test_train_id = 'not_exist_id'
image_processor = ImageProcessor(self._mock_data_manager) image_processor = ImageProcessor(self._mock_data_manager)
...@@ -112,7 +113,8 @@ class TestImagesProcessor: ...@@ -112,7 +113,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Test get metadata list with not exist tag.""" """Test get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name' test_tag_name = 'not_exist_tag_name'
...@@ -124,7 +126,8 @@ class TestImagesProcessor: ...@@ -124,7 +126,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_metadata_list_success(self):
"""Test getting metadata list success.""" """Test getting metadata list success."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
...@@ -133,7 +136,8 @@ class TestImagesProcessor: ...@@ -133,7 +136,8 @@ class TestImagesProcessor:
assert results == self._images_metadata assert results == self._images_metadata
def test_get_single_image_with_not_exist_id(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_id(self):
"""Test getting single image with not exist id.""" """Test getting single image with not exist id."""
test_train_id = 'not_exist_id' test_train_id = 'not_exist_id'
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
...@@ -146,7 +150,8 @@ class TestImagesProcessor: ...@@ -146,7 +150,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_single_image_with_not_exist_tag(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_tag(self):
"""Test getting single image with not exist tag.""" """Test getting single image with not exist tag."""
test_tag_name = 'not_exist_tag_name' test_tag_name = 'not_exist_tag_name'
test_step = self._steps_list[0] test_step = self._steps_list[0]
...@@ -159,7 +164,8 @@ class TestImagesProcessor: ...@@ -159,7 +164,8 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_single_image_with_not_exist_step(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_with_not_exist_step(self):
"""Test getting single image with not exist step.""" """Test getting single image with not exist step."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
test_step = 10000 test_step = 10000
...@@ -172,24 +178,22 @@ class TestImagesProcessor: ...@@ -172,24 +178,22 @@ class TestImagesProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find the step with given train job id and tag." in exc_info.value.message assert "Can not find the step with given train job id and tag." in exc_info.value.message
def test_get_single_image_success(self, load_image_record): @pytest.mark.usefixtures('load_image_record')
def test_get_single_image_success(self):
"""Test getting single image successfully.""" """Test getting single image successfully."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
test_step_index = 0 test_step_index = 0
test_step = self._steps_list[test_step_index] test_step = self._steps_list[test_step_index]
expected_image_tensor = self._images_values.get(test_step)
image_processor = ImageProcessor(self._mock_data_manager) image_processor = ImageProcessor(self._mock_data_manager)
results = image_processor.get_single_image(self._train_id, test_tag_name, test_step) results = image_processor.get_single_image(self._train_id, test_tag_name, test_step)
recv_image_tensor = get_image_tensor_from_bytes(results)
expected_image_tensor = self._images_values.get(test_step)
image_generator = LogOperations.get_log_generator(PluginNameEnum.IMAGE.value)
recv_image_tensor = image_generator.get_image_tensor_from_bytes(results)
assert recv_image_tensor.any() == expected_image_tensor.any() assert recv_image_tensor.any() == expected_image_tensor.any()
def test_reservoir_add_sample(self, load_more_than_limit_image_record): @pytest.mark.usefixtures('load_more_than_limit_image_record')
def test_reservoir_add_sample(self):
"""Test adding sample in reservoir.""" """Test adding sample in reservoir."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
...@@ -206,7 +210,8 @@ class TestImagesProcessor: ...@@ -206,7 +210,8 @@ class TestImagesProcessor:
cnt += 1 cnt += 1
assert len(self._more_steps_list) - cnt == 10 assert len(self._more_steps_list) - cnt == 10
def test_reservoir_remove_sample(self, load_reservoir_remove_sample_image_record): @pytest.mark.usefixtures('load_reservoir_remove_sample_image_record')
def test_reservoir_remove_sample(self):
""" """
Test removing sample in reservoir. Test removing sample in reservoir.
......
...@@ -22,9 +22,6 @@ import tempfile ...@@ -22,9 +22,6 @@ import tempfile
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor ...@@ -33,6 +30,10 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestScalarsProcessor: class TestScalarsProcessor:
"""Test scalar processor api.""" """Test scalar processor api."""
...@@ -65,12 +66,11 @@ class TestScalarsProcessor: ...@@ -65,12 +66,11 @@ class TestScalarsProcessor:
"""Load scalar record.""" """Load scalar record."""
summary_base_dir = tempfile.mkdtemp() summary_base_dir = tempfile.mkdtemp()
log_dir = tempfile.mkdtemp(dir=summary_base_dir) log_dir = tempfile.mkdtemp(dir=summary_base_dir)
self._train_id = log_dir.replace(summary_base_dir, ".") self._train_id = log_dir.replace(summary_base_dir, ".")
self._temp_path, self._scalars_metadata, self._scalars_values = LogOperations.generate_log( log_operation = LogOperations()
self._temp_path, self._scalars_metadata, self._scalars_values = log_operation.generate_log(
PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name)) PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir) self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
...@@ -79,7 +79,8 @@ class TestScalarsProcessor: ...@@ -79,7 +79,8 @@ class TestScalarsProcessor:
# wait for loading done # wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5) check_loading_done(self._mock_data_manager, time_limit=5)
def test_get_metadata_list_with_not_exist_id(self, load_scalar_record): @pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_id(self):
"""Get metadata list with not exist id.""" """Get metadata list with not exist id."""
test_train_id = 'not_exist_id' test_train_id = 'not_exist_id'
scalar_processor = ScalarsProcessor(self._mock_data_manager) scalar_processor = ScalarsProcessor(self._mock_data_manager)
...@@ -89,7 +90,8 @@ class TestScalarsProcessor: ...@@ -89,7 +90,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in loader pool about the train job." in exc_info.value.message assert "Can not find any data in loader pool about the train job." in exc_info.value.message
def test_get_metadata_list_with_not_exist_tag(self, load_scalar_record): @pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_with_not_exist_tag(self):
"""Get metadata list with not exist tag.""" """Get metadata list with not exist tag."""
test_tag_name = 'not_exist_tag_name' test_tag_name = 'not_exist_tag_name'
...@@ -101,7 +103,8 @@ class TestScalarsProcessor: ...@@ -101,7 +103,8 @@ class TestScalarsProcessor:
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
assert "Can not find any data in this train job by given tag." in exc_info.value.message assert "Can not find any data in this train job by given tag." in exc_info.value.message
def test_get_metadata_list_success(self, load_scalar_record): @pytest.mark.usefixtures('load_scalar_record')
def test_get_metadata_list_success(self):
"""Get metadata list success.""" """Get metadata list success."""
test_tag_name = self._complete_tag_name test_tag_name = self._complete_tag_name
......
...@@ -18,15 +18,11 @@ Function: ...@@ -18,15 +18,11 @@ Function:
Usage: Usage:
pytest tests/ut/datavisual pytest tests/ut/datavisual
""" """
import os
import tempfile import tempfile
import time import time
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from tests.ut.datavisual.mock import MockLogger
from tests.ut.datavisual.utils.log_operations import LogOperations
from tests.ut.datavisual.utils.utils import check_loading_done, delete_files_or_dirs
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
...@@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage ...@@ -35,6 +31,10 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs
from ..mock import MockLogger
class TestTrainTaskManager: class TestTrainTaskManager:
"""Test train task manager.""" """Test train task manager."""
...@@ -70,39 +70,30 @@ class TestTrainTaskManager: ...@@ -70,39 +70,30 @@ class TestTrainTaskManager:
@pytest.fixture(scope='function') @pytest.fixture(scope='function')
def load_data(self): def load_data(self):
"""Load data.""" """Load data."""
log_operation = LogOperations()
self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []} self._plugins_id_map = {'image': [], 'scalar': [], 'graph': []}
self._events_names = [] self._events_names = []
self._train_id_list = [] self._train_id_list = []
graph_base_path = os.path.join(os.path.dirname(__file__),
os.pardir, "utils", "log_generators", "graph_base.json")
self._root_dir = tempfile.mkdtemp() self._root_dir = tempfile.mkdtemp()
for i in range(self._dir_num): for i in range(self._dir_num):
dir_path = tempfile.mkdtemp(dir=self._root_dir) dir_path = tempfile.mkdtemp(dir=self._root_dir)
tmp_tag_name = self._tag_name + '_' + str(i) tmp_tag_name = self._tag_name + '_' + str(i)
event_name = str(i) + "_name" event_name = str(i) + "_name"
train_id = dir_path.replace(self._root_dir, ".") train_id = dir_path.replace(self._root_dir, ".")
# Pass timestamp to write to the same file. # Pass timestamp to write to the same file.
log_settings = dict( log_settings = dict(steps=self._steps_list, tag=tmp_tag_name, time=time.time())
steps=self._steps_list,
tag=tmp_tag_name,
graph_base_path=graph_base_path,
time=time.time())
if i % 3 != 0: if i % 3 != 0:
LogOperations.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.IMAGE.value, dir_path, log_settings)
self._plugins_id_map['image'].append(train_id) self._plugins_id_map['image'].append(train_id)
if i % 3 != 1: if i % 3 != 1:
LogOperations.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.SCALAR.value, dir_path, log_settings)
self._plugins_id_map['scalar'].append(train_id) self._plugins_id_map['scalar'].append(train_id)
if i % 3 != 2: if i % 3 != 2:
LogOperations.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings) log_operation.generate_log(PluginNameEnum.GRAPH.value, dir_path, log_settings)
self._plugins_id_map['graph'].append(train_id) self._plugins_id_map['graph'].append(train_id)
self._events_names.append(event_name) self._events_names.append(event_name)
self._train_id_list.append(train_id) self._train_id_list.append(train_id)
self._generated_path.append(self._root_dir) self._generated_path.append(self._root_dir)
...@@ -112,7 +103,8 @@ class TestTrainTaskManager: ...@@ -112,7 +103,8 @@ class TestTrainTaskManager:
check_loading_done(self._mock_data_manager, time_limit=30) check_loading_done(self._mock_data_manager, time_limit=30)
def test_get_single_train_task_with_not_exists_train_id(self, load_data): @pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_not_exists_train_id(self):
"""Test getting single train task with not exists train_id.""" """Test getting single train task with not exists train_id."""
train_task_manager = TrainTaskManager(self._mock_data_manager) train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members(): for plugin_name in PluginNameEnum.list_members():
...@@ -124,7 +116,8 @@ class TestTrainTaskManager: ...@@ -124,7 +116,8 @@ class TestTrainTaskManager:
"the train job in data manager." "the train job in data manager."
assert exc_info.value.error_code == '50540002' assert exc_info.value.error_code == '50540002'
def test_get_single_train_task_with_params(self, load_data): @pytest.mark.usefixtures('load_data')
def test_get_single_train_task_with_params(self):
"""Test getting single train task with params.""" """Test getting single train task with params."""
train_task_manager = TrainTaskManager(self._mock_data_manager) train_task_manager = TrainTaskManager(self._mock_data_manager)
for plugin_name in PluginNameEnum.list_members(): for plugin_name in PluginNameEnum.list_members():
...@@ -138,7 +131,8 @@ class TestTrainTaskManager: ...@@ -138,7 +131,8 @@ class TestTrainTaskManager:
else: else:
assert test_train_id not in self._plugins_id_map.get(plugin_name) assert test_train_id not in self._plugins_id_map.get(plugin_name)
def test_get_plugins_with_train_id(self, load_data): @pytest.mark.usefixtures('load_data')
def test_get_plugins_with_train_id(self):
"""Test getting plugins with train id.""" """Test getting plugins with train id."""
train_task_manager = TrainTaskManager(self._mock_data_manager) train_task_manager = TrainTaskManager(self._mock_data_manager)
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compute masked CRC-32C checksums over byte strings."""
CRC_TABLE_32 = (
0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, 0x8AD958CF,
0x78B2DBCC, 0x6BE22838, 0x9989AB3B, 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, 0x105EC76F, 0xE235446C,
0xF165B798, 0x030E349B, 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57,
0x89D76C54, 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A,
0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, 0x6DFE410E,
0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, 0xF779DEAE, 0x05125DAD,
0x1642AE59, 0xE4292D5A, 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, 0x7DA08661, 0x8FCB0562, 0x9C9BF696,
0x6EF07595, 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957,
0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, 0x5125DAD3,
0xA34E59D0, 0xB01EAA24, 0x42752927, 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, 0xDBFC821C, 0x2997011F,
0x3AC7F2EB, 0xC8AC71E8, 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, 0x61C69362, 0x93AD1061, 0x80FDE395,
0x72966096, 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859,
0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, 0xB602C312,
0x44694011, 0x5739B3E5, 0xA55230E6, 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, 0x3CDB9BDD, 0xCEB018DE,
0xDDE0EB2A, 0x2F8B6829, 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, 0x456CAC67, 0xB7072F64, 0xA457DC90,
0x563C5F93, 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C,
0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, 0x1871A4D8,
0xEA1A27DB, 0xF94AD42F, 0x0B21572C, 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, 0xA24BB5A6, 0x502036A5,
0x4370C551, 0xB11B4652, 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E,
0x3BC21E9D, 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D,
0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, 0xFF56BD19,
0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, 0x0417B1DB, 0xF67C32D8,
0xE52CC12C, 0x1747422F, 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, 0x8ECEE914, 0x7CA56A17, 0x6FF599E3,
0x9D9E1AE0, 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540,
0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, 0xE330A81A,
0x115B2B19, 0x020BD8ED, 0xF0605BEE, 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, 0x69E9F0D5, 0x9B8273D6,
0x88D28022, 0x7AB90321, 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, 0xF36E6F75, 0x0105EC76, 0x12551F82,
0xE03E9C81, 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E,
0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351
)
_CRC = 0
_MASK = 0xFFFFFFFF
def _uint32(x):
"""Transform x's type to uint32."""
return x & 0xFFFFFFFF
def _get_crc_checksum(crc, data):
    """
    Fold the bytes of *data* into *crc* using the CRC-32C lookup table.

    Args:
        crc (int): Initial CRC state.
        data (bytes): Bytes to checksum.

    Returns:
        int, updated CRC value.
    """
    state = crc ^ _MASK  # pre-invert, per the CRC-32C algorithm
    for byte in data:
        table_index = (state ^ byte) & 0xFF
        state = (CRC_TABLE_32[table_index] ^ (state >> 8)) & _MASK
    # post-invert to finalize
    return state ^ _MASK
def get_mask_from_string(data):
    """
    Get masked CRC-32C checksum of *data*.

    Args:
        data (bytes): Byte string of data.

    Returns:
        uint32, masked crc.
    """
    checksum = _uint32(_get_crc_checksum(_CRC, data) & _MASK)
    # Rotate right by 15 bits, then add the masking constant
    # (same masking scheme used by the TFRecord format).
    rotated = (checksum >> 15) | _uint32(checksum << 17)
    return _uint32(rotated + 0xA282EAD8)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
{
"node": [
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
},
{
"name": "x1",
"type": "DATA_EDGE"
},
{
"name": "x2",
"type": "DATA_EDGE"
},
{
"name": "x3",
"type": "DATA_EDGE"
},
{
"name": "x4",
"type": "DATA_EDGE"
},
{
"name": "x5",
"type": "DATA_EDGE"
},
{
"name": "x6",
"type": "DATA_EDGE"
},
{
"name": "x7",
"type": "DATA_EDGE"
},
{
"name": "x8",
"type": "DATA_EDGE"
},
{
"name": "x9",
"type": "DATA_EDGE"
},
{
"name": "x10",
"type": "DATA_EDGE"
},
{
"name": "conv1.weight",
"type": "DATA_EDGE"
}
],
"name": "1",
"opType": "Conv2D",
"scope": "Default/conv1-Conv2d",
"attribute": [
{
"name": "output_names",
"value": {
"dtype": "DT_GRAPHS",
"values": [
{
"dtype": "DT_FLOAT64",
"strVal": "output"
}
]
}
},
{
"name": "pad_mode",
"value": {
"dtype": "DT_FLOAT64",
"strVal": "same"
}
}
],
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "64"
},
{
"size": "112"
},
{
"size": "112"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
},
{
"name": "x1",
"type": "DATA_EDGE"
},
{
"name": "x2",
"type": "DATA_EDGE"
},
{
"name": "x3",
"type": "DATA_EDGE"
},
{
"name": "x4",
"type": "DATA_EDGE"
},
{
"name": "x5",
"type": "DATA_EDGE"
},
{
"name": "x6",
"type": "DATA_EDGE"
},
{
"name": "x7",
"type": "DATA_EDGE"
},
{
"name": "x8",
"type": "DATA_EDGE"
},
{
"name": "x9",
"type": "DATA_EDGE"
},
{
"name": "x10",
"type": "DATA_EDGE"
},
{
"name": "cst13",
"type": "DATA_EDGE"
}
],
"name": "53",
"opType": "tuple_getitem",
"scope": "Default/bn1-BatchNorm2d",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x11",
"type": "DATA_EDGE"
},
{
"name": "x12",
"type": "DATA_EDGE"
},
{
"name": "x13",
"type": "DATA_EDGE"
},
{
"name": "x14",
"type": "DATA_EDGE"
},
{
"name": "x15",
"type": "DATA_EDGE"
},
{
"name": "x16",
"type": "DATA_EDGE"
},
{
"name": "x17",
"type": "DATA_EDGE"
},
{
"name": "x18",
"type": "DATA_EDGE"
},
{
"name": "x19",
"type": "DATA_EDGE"
},
{
"name": "x20",
"type": "DATA_EDGE"
},
{
"name": "conv1.weight",
"type": "DATA_EDGE"
},
{
"name": "cst25",
"type": "DATA_EDGE"
}
],
"name": "105",
"opType": "tuple_getitem",
"scope": "Default/bn1-BatchNorm2d",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "1024"
},
{
"size": "14"
},
{
"size": "14"
}
]
}
}
}
},
{
"input": [
{
"name": "x11",
"type": "DATA_EDGE"
}
],
"name": "50",
"opType": "Add",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "1024"
},
{
"size": "14"
},
{
"size": "14"
}
]
}
}
}
},
{
"input": [
{
"name": "x11",
"type": "DATA_EDGE"
}
],
"name": "51",
"opType": "Add",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "1024"
},
{
"size": "14"
},
{
"size": "14"
}
]
}
}
}
},
{
"input": [
{
"name": "x11",
"type": "DATA_EDGE"
}
],
"name": "52",
"opType": "Add",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "1024"
},
{
"size": "14"
},
{
"size": "14"
}
]
}
}
}
},
{
"input": [
{
"name": "x11",
"type": "DATA_EDGE"
}
],
"name": "53",
"opType": "Add",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "1024"
},
{
"size": "14"
},
{
"size": "14"
}
]
}
}
}
},
{
"input": [
{
"name": "x11",
"type": "DATA_EDGE"
}
],
"name": "54",
"opType": "Add",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "1024"
},
{
"size": "14"
},
{
"size": "14"
}
]
}
}
}
},
{
"input": [
{
"name": "50",
"type": "DATA_EDGE"
}
],
"name": "1",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "51",
"type": "DATA_EDGE"
}
],
"name": "2",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "52",
"type": "DATA_EDGE"
}
],
"name": "3",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "53",
"type": "DATA_EDGE"
}
],
"name": "4",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "54",
"type": "DATA_EDGE"
}
],
"name": "5",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "6",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "7",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "8",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "9",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "10",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "11",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
},
{
"input": [
{
"name": "x",
"type": "DATA_EDGE"
}
],
"name": "12",
"opType": "Reshape",
"scope": "Default/bn1",
"outputType": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "128"
},
{
"size": "28"
},
{
"size": "28"
}
]
}
}
}
}
],
"name": "849_848_847_424_1_construct",
"parameters": [
{
"name": "x",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "conv1.weight",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "64"
},
{
"size": "3"
},
{
"size": "7"
},
{
"size": "7"
}
]
}
}
}
},
{
"name": "x1",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x2",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x3",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x4",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x5",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x6",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x7",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x8",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x9",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x10",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x11",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x12",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x13",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x14",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x15",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x16",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x17",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x18",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x19",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
},
{
"name": "x20",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "3"
},
{
"size": "224"
},
{
"size": "224"
}
]
}
}
}
}
],
"outputs": [
{
"name": "228",
"type": {
"dataType": "DT_STRING",
"tensorType": {
"elemType": "DT_FLOAT16",
"shape": {
"dim": [
{
"size": "1"
},
{
"size": "10"
}
]
}
}
}
}
],
"constVals": [
{
"key": "cst12",
"value": {
"dtype": "DT_INT32",
"intVal": "0"
}
},
{
"key": "cst13",
"value": {
"dtype": "DT_INT32",
"intVal": "0"
}
},
{
"key": "cst25",
"value": {
"dtype": "DT_INT32",
"intVal": "0"
}
}
]
}
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for images."""
import io
import time
import numpy as np
from PIL import Image
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class ImagesLogGenerator(LogGenerator):
"""
Log generator for images.
This is a log generator writing images. User can use it to generate fake
summary logs about images.
"""
def generate_event(self, values):
"""
Method for generating image event.
Args:
values (dict): A dict contains:
{
wall_time (float): Timestamp.
step (int): Train step.
image (np.array): Pixels tensor.
tag (str): Tag name.
}
Returns:
summary_pb2.Event.
"""
image_event = summary_pb2.Event()
image_event.wall_time = values.get('wall_time')
image_event.step = values.get('step')
height, width, channel, image_string = self._get_image_string(values.get('image'))
value = image_event.summary.value.add()
value.tag = values.get('tag')
value.image.height = height
value.image.width = width
value.image.colorspace = channel
value.image.encoded_image = image_string
return image_event
def _get_image_string(self, image_tensor):
"""
Generate image string from tensor.
Args:
image_tensor (np.array): Pixels tensor.
Returns:
int, height.
int, width.
int, channel.
bytes, image_string.
"""
height, width, channel = image_tensor.shape
scaled_height = int(height)
scaled_width = int(width)
image = Image.fromarray(image_tensor)
image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
output = io.BytesIO()
image.save(output, format='PNG')
image_string = output.getvalue()
output.close()
return height, width, channel, image_string
def _make_image_tensor(self, shape):
"""
Make image tensor according to shape.
Args:
shape (list): Shape of image, consists of height, width, channel.
Returns:
np.array, image tensor.
"""
image = np.prod(shape)
image_tensor = (np.arange(image, dtype=float)).reshape(shape)
image_tensor = image_tensor / np.max(image_tensor) * 255
image_tensor = image_tensor.astype(np.uint8)
return image_tensor
def generate_log(self, file_path, steps_list, tag_name):
"""
Generate log for external calls.
Args:
file_path (str): Path to write logs.
steps_list (list): A list consists of step.
tag_name (str): Tag name.
Returns:
list[dict], generated image metadata.
dict, generated image tensors.
"""
images_values = dict()
images_metadata = []
for step in steps_list:
wall_time = time.time()
# height, width, channel
image_tensor = self._make_image_tensor([5, 5, 3])
image_metadata = dict()
image_metadata.update({'wall_time': wall_time})
image_metadata.update({'step': step})
image_metadata.update({'height': image_tensor.shape[0]})
image_metadata.update({'width': image_tensor.shape[1]})
images_metadata.append(image_metadata)
images_values.update({step: image_tensor})
values = dict(
wall_time=wall_time,
step=step,
image=image_tensor,
tag=tag_name
)
self._write_log_one_step(file_path, values)
return images_metadata, images_values
def get_image_tensor_from_bytes(self, image_string):
    """Decode an encoded image byte string into a numpy pixel tensor."""
    buffered = io.BytesIO(image_string)
    return np.array(Image.open(buffered))
if __name__ == "__main__":
    # Demo entry point: write a small image summary file to the current dir.
    generator = ImagesLogGenerator()
    file_name = '%s.%s.%s' % ('image', 'summary', str(time.time()))
    generator.generate_log(file_name, [1, 3, 5], "test_image_tag_name")
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base log Generator."""
import struct
from abc import abstractmethod
from tests.ut.datavisual.utils import crc32
class LogGenerator:
    """
    Base class for fake summary-log generators.

    Subclasses implement `generate_event` to build a summary event from a
    dict of values; this base class serializes events to disk in the
    summary-record format (length header plus CRC32 masks).
    """

    @abstractmethod
    def generate_event(self, values):
        """
        Build a summary event from `values`.

        Args:
            values (dict): Values.

        Returns:
            summary_pb2.Event.
        """

    def _write_log_one_step(self, file_path, values):
        """
        Serialize one step's values and append them to `file_path`.

        Args:
            file_path (str): File path to write.
            values (dict): Values.
        """
        self._write_log_from_event(file_path, self.generate_event(values))

    @staticmethod
    def _write_log_from_event(file_path, event):
        """
        Append one serialized event record to `file_path`.

        Args:
            file_path (str): File path to write.
            event (summary_pb2.Event): Event object in proto.
        """
        body = event.SerializeToString()
        # Record layout: <8-byte length><header crc><payload><payload crc>.
        header = struct.pack('<Q', len(body))
        record = (header
                  + struct.pack('<I', crc32.get_mask_from_string(header))
                  + body
                  + struct.pack('<I', crc32.get_mask_from_string(body)))
        with open(file_path, "ab") as log_file:
            log_file.write(record)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for scalars."""
import time
import numpy as np
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
class ScalarsLogGenerator(LogGenerator):
    """
    Log generator for scalars.

    Writes fake summary logs containing random scalar values.
    """

    def generate_event(self, values):
        """
        Build a scalar summary event.

        Args:
            values (dict): A dict contains:
                {
                    wall_time (float): Timestamp.
                    step (int): Train step.
                    value (float): Scalar value.
                    tag (str): Tag name.
                }

        Returns:
            summary_pb2.Event.
        """
        event = summary_pb2.Event()
        event.wall_time = values.get('wall_time')
        event.step = values.get('step')
        summary_value = event.summary.value.add()
        summary_value.tag = values.get('tag')
        summary_value.scalar_value = values.get('value')
        return event

    def generate_log(self, file_path, steps_list, tag_name):
        """
        Generate log for external calls.

        Args:
            file_path (str): Path to write logs.
            steps_list (list): A list consists of step.
            tag_name (str): Tag name.

        Returns:
            list[dict], generated scalar metadata.
            None, to be consistent with return value of ImageGenerator.
        """
        scalars_metadata = []
        for step in steps_list:
            metadata = {
                'wall_time': time.time(),
                'step': step,
                'value': np.random.rand(),
            }
            scalars_metadata.append(metadata)
            # NOTE: the appended dict is the same object, so the returned
            # metadata entries also carry the 'tag' key added below
            # (matching the original implementation's ordering).
            metadata['tag'] = tag_name
            self._write_log_one_step(file_path, metadata)
        return scalars_metadata, None
if __name__ == "__main__":
    # Demo entry point: write a small scalar summary file to the current dir.
    generator = ScalarsLogGenerator()
    file_name = '%s.%s.%s' % ('scalar', 'summary', str(time.time()))
    generator.generate_log(file_name, [1, 3, 5], "test_scalar_tag_name")
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""log operations module."""
import json
import os
import time
from tests.ut.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator
from tests.ut.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.ut.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum
# Singleton generator instance for each supported plugin type; reused by
# LogOperations for every generated log file.
log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
PluginNameEnum.SCALAR.value: ScalarsLogGenerator()
}
class LogOperations:
    """Log Operations class."""

    @staticmethod
    def generate_log(plugin_name, log_dir, log_settings, valid=True):
        """
        Generate log.

        Args:
            plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'.
            log_dir (str): Log path to write log.
            log_settings (dict): Info about the log, e.g.:
                {
                    time (int): Timestamp in summary file name, not necessary.
                    graph_base_path (str): Path of graph_base.json, necessary for `graph`.
                    steps (list[int]): Steps for `image` and `scalar`, default is [1].
                    tag (str): Tag name, default is 'default_tag'.
                }
            valid (bool): If true, summary name will be valid.

        Returns:
            tuple, (summary path, graph_dict) for `graph`, otherwise
            (summary path, metadata, values).
        """
        current_time = int(log_settings.get('time', int(time.time())))
        log_generator = log_generators.get(plugin_name)
        name_stub = 'test.summary' if valid else 'test.invalid'
        temp_path = os.path.join(log_dir, '%s.%s' % (name_stub, str(current_time)))
        if plugin_name == PluginNameEnum.GRAPH.value:
            with open(log_settings.get('graph_base_path'), 'r') as load_f:
                graph_dict = json.load(load_f)
            graph_dict = log_generator.generate_log(temp_path, graph_dict)
            return temp_path, graph_dict
        metadata, values = log_generator.generate_log(
            temp_path,
            log_settings.get('steps', [1]),
            log_settings.get('tag', 'default_tag'))
        return temp_path, metadata, values

    @staticmethod
    def get_log_generator(plugin_name):
        """Return the log generator registered for `plugin_name`."""
        return log_generators.get(plugin_name)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Description: This file is used for some common util.
"""
import os
import shutil
import time
from urllib.parse import urlencode
from mindinsight.datavisual.common.enums import DataManagerStatus
def get_url(url, params):
    """
    Concatenate the URL and params.

    Args:
        url (str): A link requested. For example, http://example.com.
        params (dict): A dict consists of params. For example, {'offset': 1, 'limit': 100}.

    Returns:
        str, like http://example.com?offset=1&limit=100
    """
    query_string = urlencode(params)
    return '{}?{}'.format(url, query_string)
def delete_files_or_dirs(path_list):
    """Remove every file or directory tree listed in `path_list`."""
    for target in path_list:
        # Directories need a recursive delete; everything else is a plain file.
        remover = shutil.rmtree if os.path.isdir(target) else os.remove
        remover(target)
def check_loading_done(data_manager, time_limit=15):
    """Poll until `data_manager` finishes loading or `time_limit` seconds pass."""
    deadline = time.time() + time_limit
    while data_manager.status != DataManagerStatus.DONE.value:
        if time.time() > deadline:
            break
        time.sleep(0.1)
...@@ -14,6 +14,6 @@ ...@@ -14,6 +14,6 @@
# ============================================================================ # ============================================================================
"""Import the mocked mindspore.""" """Import the mocked mindspore."""
import sys import sys
from .collection.model import mindspore from ...utils import mindspore
sys.modules['mindspore'] = mindspore sys.modules['mindspore'] = mindspore
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore Interface."""
from .application.model_zoo.resnet import ResNet
from .common.tensor import Tensor
from .dataset import MindDataset
from .nn import *
from .train.callback import _ListCallback, Callback, RunContext, ModelCheckpoint, SummaryStep
from .train.summary import SummaryRecord
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore ResNet class."""
from ...nn.cell import Cell
class ResNet(Cell):
    """Mocked ResNet."""

    def __init__(self):
        super().__init__()
        # Mapping of child cell name -> cell; the mock keeps none.
        self._cells = dict()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/common/tensor.py."""
import numpy as np
class Tensor:
    """Mock the MindSpore Tensor class wrapping an arbitrary value."""

    def __init__(self, value=0):
        self._value = value

    def asnumpy(self):
        """Return the wrapped value as a (fresh) numpy array."""
        return np.array(self._value)

    def __repr__(self):
        return '{}'.format(self.asnumpy())
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
def get_context(key):
    """Get key in context; returns None for unknown keys."""
    mocked_context = {"device_id": 1}
    return mocked_context.get(key)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset."""
from .engine import MindDataset
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset.engine."""
from .datasets import MindDataset, Dataset
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/dataset/engine/datasets.py."""
class Dataset:
    """Mock the MindSpore Dataset class."""

    def __init__(self, dataset_size=None, dataset_path=None):
        # Attribute names mirror the real Dataset API read by the lineage code.
        self.dataset_size = dataset_size
        self.dataset_path = dataset_path
        self.input = []

    def get_dataset_size(self):
        """Mocked get_dataset_size."""
        return self.dataset_size
class MindDataset(Dataset):
    """Mock the MindSpore MindDataset class."""

    def __init__(self, dataset_size=None, dataset_file=None):
        super().__init__(dataset_size)
        self.dataset_file = dataset_file
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the mindspore.nn package."""
from .optim import Optimizer, Momentum
from .loss.loss import SoftmaxCrossEntropyWithLogits, _Loss
from .cell import Cell, WithLossCell, TrainOneStepWithLossScaleCell
__all__ = ['Optimizer', 'Momentum', 'SoftmaxCrossEntropyWithLogits',
'_Loss', 'Cell', 'WithLossCell',
'TrainOneStepWithLossScaleCell']
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/callback.py."""
class Cell:
    """Mock the Cell class."""

    def __init__(self, auto_prefix=True, pips=None):
        # None sentinel avoids sharing one mutable default dict across cells.
        self._auto_prefix = auto_prefix
        self._pips = dict() if pips is None else pips

    @property
    def auto_prefix(self):
        """The property of auto_prefix."""
        return self._auto_prefix

    @property
    def pips(self):
        """The property of pips."""
        return self._pips
class WithLossCell(Cell):
    """Mocked WithLossCell: pairs a backbone network with a loss function."""

    def __init__(self, backbone, loss_fn):
        # Propagate the backbone's pips; wrapped cells never auto-prefix.
        super().__init__(auto_prefix=False, pips=backbone.pips)
        self._backbone = backbone
        self._loss_fn = loss_fn
class TrainOneStepWithLossScaleCell(Cell):
    """Mocked TrainOneStepWithLossScaleCell."""

    def __init__(self):
        super().__init__()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore SoftmaxCrossEntropyWithLogits class."""
from ..cell import Cell
class _Loss(Cell):
    """Mocked _Loss, base class of the mocked loss cells."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def construct(self, base, target):
        """Mocked construct function; subclasses must override."""
        raise NotImplementedError
class SoftmaxCrossEntropyWithLogits(_Loss):
    """Mocked SoftmaxCrossEntropyWithLogits."""

    def __init__(self, weight=None):
        # Bug fix: `weight` used to be passed positionally into
        # _Loss.__init__(reduction=...), silently replacing the default
        # reduction ('mean') with the weight value. Keep the weight on the
        # instance instead.
        super(SoftmaxCrossEntropyWithLogits, self).__init__()
        self._weight = weight

    def construct(self, base, target):
        """Mocked construct."""
        return 1
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/nn/optim.py."""
from .cell import Cell
class Parameter:
    """Mock the MindSpore Parameter class."""

    def __init__(self, learning_rate):
        self._name = "Parameter"
        self.default_input = learning_rate

    @property
    def name(self):
        """The property of name."""
        return self._name

    def __repr__(self):
        return 'Parameter (name={name})'.format(name=self._name)
class Optimizer(Cell):
    """Mock the MindSpore Optimizer class."""

    def __init__(self, learning_rate):
        super().__init__()
        # Wrap the raw learning rate in a mocked Parameter, as MindSpore does.
        self.learning_rate = Parameter(learning_rate)
class Momentum(Optimizer):
    """Mock the MindSpore Momentum class."""

    def __init__(self, learning_rate):
        super().__init__(learning_rate)
        # The mock always reports a static (non-dynamic) learning rate.
        self.dynamic_lr = False
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore wrap package."""
from .loss_scale import TrainOneStepWithLossScaleCell
from .cell_wrapper import WithLossCell
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore cell_wrapper.py."""
from ..cell import Cell
class WithLossCell(Cell):
    """Mock the WithLossCell class."""

    def __init__(self, backbone, loss_fn):
        super().__init__()
        self._backbone = backbone
        self._loss_fn = loss_fn
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore loss_scale.py."""
from ..cell import Cell
class TrainOneStepWithLossScaleCell(Cell):
    """Mock the TrainOneStepWithLossScaleCell class."""

    def __init__(self, network, optimizer):
        super().__init__()
        self.network = network
        self.optimizer = optimizer

    def construct(self, data, label):
        """Mock the construct method; not meant to be called."""
        raise NotImplementedError
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/callback.py."""
import os
class RunContext:
    """Mock the RunContext class."""

    def __init__(self, original_args=None):
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self):
        """Return the args captured at construction."""
        return self._original_args

    def stop_requested(self):
        """Return whether a stop has been requested (False after init)."""
        return self._stop_requested
class Callback:
    """Mock the Callback class; hooks are intentional no-ops."""

    def begin(self, run_context):
        """Called once before network training."""

    def epoch_begin(self, run_context):
        """Called before each epoch begin."""
class _ListCallback(Callback):
    """Mock the _ListCallback class wrapping a list of callbacks."""

    def __init__(self, callbacks):
        super().__init__()
        self._callbacks = callbacks
class ModelCheckpoint(Callback):
    """Mock the ModelCheckpoint class."""

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._prefix = prefix
        self._directory = directory
        self._config = config
        # Bug fix: `_model_file_name` was never assigned, so reading the
        # `model_file_name` property raised AttributeError. Use the same
        # mocked path as the latest checkpoint.
        self._model_file_name = os.path.join(directory, prefix + 'test_model.ckpt')
        self._latest_ckpt_file_name = self._model_file_name

    @property
    def model_file_name(self):
        """Get the file name of model."""
        return self._model_file_name

    @property
    def latest_ckpt_file_name(self):
        """Get the latest file name of checkpoint."""
        return self._latest_ckpt_file_name
class SummaryStep(Callback):
    """Mock the SummaryStep class."""

    def __init__(self, summary, flush_step=10):
        super(SummaryStep, self).__init__()
        # Fix private-attribute typo: was `self._sumamry`.
        self._summary = summary
        self._flush_step = flush_step
        # `summary` is expected to expose `full_file_name` (see SummaryRecord).
        self.summary_file_name = summary.full_file_name
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
from .summary_record import SummaryRecord
__all__ = ["SummaryRecord"]
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
import os
import time
class SummaryRecord:
    """Mock the MindSpore SummaryRecord class."""

    def __init__(self,
                 log_dir: str,
                 file_prefix: str = "events.",
                 file_suffix: str = ".MS",
                 create_time=None):
        """
        Args:
            log_dir (str): Directory of the summary file.
            file_prefix (str): Prefix of the summary file name.
            file_suffix (str): Suffix of the summary file name.
            create_time (int): Timestamp embedded in the file name; defaults
                to the current time at construction.
        """
        # Bug fix: the old default `create_time=int(time.time())` was
        # evaluated once at class-definition time, so every instance created
        # without an explicit timestamp shared the import-time value.
        if create_time is None:
            create_time = int(time.time())
        self.log_dir = log_dir
        self.prefix = file_prefix
        self.suffix = file_suffix
        file_name = file_prefix + 'summary.' + str(create_time) + file_suffix
        self.full_file_name = os.path.join(log_dir, file_name)

    def flush(self):
        """Mock flush method."""

    def close(self):
        """Mock close method."""
...@@ -16,21 +16,21 @@ ...@@ -16,21 +16,21 @@
import os import os
import shutil import shutil
import unittest import unittest
from unittest import mock, TestCase from unittest import TestCase, mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ from mindinsight.lineagemgr.collection.model.model_lineage import AnalyzeObject, EvalLineage, TrainLineage
AnalyzeObject from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageGetModelFileError, LineageLogError,
from mindinsight.lineagemgr.common.exceptions.exceptions import \ MindInsightException)
LineageLogError, LineageGetModelFileError, MindInsightException
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.dataset.engine import MindDataset, Dataset from mindspore.dataset.engine import Dataset, MindDataset
from mindspore.nn import Optimizer, WithLossCell, TrainOneStepWithLossScaleCell, \ from mindspore.nn import Optimizer, SoftmaxCrossEntropyWithLogits, TrainOneStepWithLossScaleCell, WithLossCell
SoftmaxCrossEntropyWithLogits from mindspore.train.callback import ModelCheckpoint, RunContext, SummaryStep
from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep
from mindspore.train.summary import SummaryRecord from mindspore.train.summary import SummaryRecord
@mock.patch('builtins.open')
@mock.patch('os.makedirs')
class TestModelLineage(TestCase): class TestModelLineage(TestCase):
"""Test TrainLineage and EvalLineage class in model_lineage.py.""" """Test TrainLineage and EvalLineage class in model_lineage.py."""
...@@ -51,23 +51,19 @@ class TestModelLineage(TestCase): ...@@ -51,23 +51,19 @@ class TestModelLineage(TestCase):
cls.summary_log_path = '/path/to/summary_log' cls.summary_log_path = '/path/to/summary_log'
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_summary_record_exception(self, mock_validate_summary): def test_summary_record_exception(self, *args):
"""Test SummaryRecord with exception.""" """Test SummaryRecord with exception."""
mock_validate_summary.return_value = None args[0].return_value = None
summary_record = self.my_summary_record(self.summary_log_path) summary_record = self.my_summary_record(self.summary_log_path)
with self.assertRaises(MindInsightException) as context: with self.assertRaises(MindInsightException) as context:
self.my_train_module(summary_record=summary_record, raise_exception=1) self.my_train_module(summary_record=summary_record, raise_exception=1)
self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception)) self.assertTrue(f'Invalid value for raise_exception.' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph')
'LineageSummary.record_dataset_graph') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network')
'validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_optimizer_by_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
def test_begin(self, *args): def test_begin(self, *args):
"""Test TrainLineage.begin method.""" """Test TrainLineage.begin method."""
...@@ -82,14 +78,10 @@ class TestModelLineage(TestCase): ...@@ -82,14 +78,10 @@ class TestModelLineage(TestCase):
args[4].assert_called() args[4].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.ds')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_dataset_graph')
'LineageSummary.record_dataset_graph') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.' @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_optimizer_by_network')
'validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.get_optimizer_by_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.'
'AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
def test_begin_error(self, *args): def test_begin_error(self, *args):
"""Test TrainLineage.begin method.""" """Test TrainLineage.begin method."""
...@@ -122,15 +114,11 @@ class TestModelLineage(TestCase): ...@@ -122,15 +114,11 @@ class TestModelLineage(TestCase):
train_lineage.begin(self.my_run_context(run_context)) train_lineage.begin(self.my_run_context(run_context))
self.assertTrue('The parameter optimizer is invalid.' in str(context.exception)) self.assertTrue('The parameter optimizer is invalid.' in str(context.exception))
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float') @mock.patch('builtins.float')
...@@ -150,23 +138,19 @@ class TestModelLineage(TestCase): ...@@ -150,23 +138,19 @@ class TestModelLineage(TestCase):
args[6].assert_called() args[6].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_train_end_exception(self, mock_validate_summary): def test_train_end_exception(self, *args):
"""Test TrainLineage.end method when exception.""" """Test TrainLineage.end method when exception."""
mock_validate_summary.return_value = True args[0].return_value = True
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True) train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context: with self.assertRaises(Exception) as context:
train_lineage.end(self.run_context) train_lineage.end(self.run_context)
self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception)) self.assertTrue('Invalid TrainLineage run_context.' in str(context.exception))
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float') @mock.patch('builtins.float')
...@@ -186,15 +170,11 @@ class TestModelLineage(TestCase): ...@@ -186,15 +170,11 @@ class TestModelLineage(TestCase):
train_lineage.end(self.my_run_context(self.run_context)) train_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in TrainLineage:' in str(context.exception)) self.assertTrue('End error in TrainLineage:' in str(context.exception))
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_model_size')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_path')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage')
'mindinsight.lineagemgr.collection.model.model_lineage.LineageSummary.record_train_lineage') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch( @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_dataset')
@mock.patch(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.analyze_optimizer')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_network')
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_train_run_context')
@mock.patch('builtins.float') @mock.patch('builtins.float')
...@@ -218,9 +198,9 @@ class TestModelLineage(TestCase): ...@@ -218,9 +198,9 @@ class TestModelLineage(TestCase):
self.assertTrue('End error in TrainLineage:' in str(context.exception)) self.assertTrue('End error in TrainLineage:' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_exception_train_id_none(self, mock_validate_summary): def test_eval_exception_train_id_none(self, *args):
"""Test EvalLineage.end method with initialization error.""" """Test EvalLineage.end method with initialization error."""
mock_validate_summary.return_value = True args[0].return_value = True
with self.assertRaises(MindInsightException) as context: with self.assertRaises(MindInsightException) as context:
self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2) self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2)
self.assertTrue('Invalid value for raise_exception.' in str(context.exception)) self.assertTrue('Invalid value for raise_exception.' in str(context.exception))
...@@ -242,9 +222,9 @@ class TestModelLineage(TestCase): ...@@ -242,9 +222,9 @@ class TestModelLineage(TestCase):
args[0].assert_called() args[0].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record') @mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_end_except_run_context(self, mock_validate_summary): def test_eval_end_except_run_context(self, *args):
"""Test EvalLineage.end method when run_context is invalid..""" """Test EvalLineage.end method when run_context is invalid.."""
mock_validate_summary.return_value = True args[0].return_value = True
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True) eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context: with self.assertRaises(Exception) as context:
eval_lineage.end(self.run_context) eval_lineage.end(self.run_context)
...@@ -284,8 +264,9 @@ class TestModelLineage(TestCase): ...@@ -284,8 +264,9 @@ class TestModelLineage(TestCase):
eval_lineage.end(self.my_run_context(self.run_context)) eval_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in EvalLineage' in str(context.exception)) self.assertTrue('End error in EvalLineage' in str(context.exception))
def test_epoch_is_zero(self): def test_epoch_is_zero(self, *args):
"""Test TrainLineage.end method.""" """Test TrainLineage.end method."""
args[0].return_value = None
run_context = self.run_context run_context = self.run_context
run_context['epoch_num'] = 0 run_context['epoch_num'] = 0
with self.assertRaises(MindInsightException): with self.assertRaises(MindInsightException):
...@@ -345,7 +326,7 @@ class TestAnalyzer(TestCase): ...@@ -345,7 +326,7 @@ class TestAnalyzer(TestCase):
) )
res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train')
res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid')
assert res1 == {'step_num': 10, assert res1 == {'step_num': 10,
'train_dataset_path': '/path/to', 'train_dataset_path': '/path/to',
'train_dataset_size': 50, 'train_dataset_size': 50,
'epoch': 2} 'epoch': 2}
......
...@@ -20,23 +20,44 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher ...@@ -20,23 +20,44 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.lineagemgr.common.path_parser import SummaryPathParser from mindinsight.lineagemgr.common.path_parser import SummaryPathParser
MOCK_SUMMARY_DIRS = [
{
'relative_path': './relative_path0'
},
{
'relative_path': './'
},
{
'relative_path': './relative_path1'
}
]
MOCK_SUMMARIES = [
{
'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file0_lineage',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file1',
'create_time': datetime.fromtimestamp(1582031971)
},
{
'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971)
}
]
class TestSummaryPathParser(TestCase): class TestSummaryPathParser(TestCase):
"""Test the class of SummaryPathParser.""" """Test the class of SummaryPathParser."""
@mock.patch.object(SummaryWatcher, 'list_summary_directories') @mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_summary_dirs(self, *args): def test_get_summary_dirs(self, *args):
"""Test the function of get_summary_dirs.""" """Test the function of get_summary_dirs."""
args[0].return_value = [ args[0].return_value = MOCK_SUMMARY_DIRS
{
'relative_path': './relative_path0'
},
{
'relative_path': './'
},
{
'relative_path': './relative_path1'
}
]
expected_result = [ expected_result = [
'/path/to/base/relative_path0', '/path/to/base/relative_path0',
...@@ -54,24 +75,7 @@ class TestSummaryPathParser(TestCase): ...@@ -54,24 +75,7 @@ class TestSummaryPathParser(TestCase):
@mock.patch.object(SummaryWatcher, 'list_summaries') @mock.patch.object(SummaryWatcher, 'list_summaries')
def test_get_latest_lineage_summary(self, *args): def test_get_latest_lineage_summary(self, *args):
"""Test the function of get_latest_lineage_summary.""" """Test the function of get_latest_lineage_summary."""
args[0].return_value = [ args[0].return_value = MOCK_SUMMARIES
{
'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file0_lineage',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file1',
'create_time': datetime.fromtimestamp(1582031971)
},
{
'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971)
}
]
summary_dir = '/path/to/summary_dir' summary_dir = '/path/to/summary_dir'
result = SummaryPathParser.get_latest_lineage_summary(summary_dir) result = SummaryPathParser.get_latest_lineage_summary(summary_dir)
self.assertEqual('/path/to/summary_dir/file1_lineage', result) self.assertEqual('/path/to/summary_dir/file1_lineage', result)
...@@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase): ...@@ -119,35 +123,8 @@ class TestSummaryPathParser(TestCase):
@mock.patch.object(SummaryWatcher, 'list_summary_directories') @mock.patch.object(SummaryWatcher, 'list_summary_directories')
def test_get_latest_lineage_summaries(self, *args): def test_get_latest_lineage_summaries(self, *args):
"""Test the function of get_latest_lineage_summaries.""" """Test the function of get_latest_lineage_summaries."""
args[0].return_value = [ args[0].return_value = MOCK_SUMMARY_DIRS
{ args[1].return_value = MOCK_SUMMARIES
'relative_path': './relative_path0'
},
{
'relative_path': './'
},
{
'relative_path': './relative_path1'
}
]
args[1].return_value = [
{
'file_name': 'file0',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file0_lineage',
'create_time': datetime.fromtimestamp(1582031970)
},
{
'file_name': 'file1',
'create_time': datetime.fromtimestamp(1582031971)
},
{
'file_name': 'file1_lineage',
'create_time': datetime.fromtimestamp(1582031971)
}
]
expected_result = [ expected_result = [
'/path/to/base/relative_path0/file1_lineage', '/path/to/base/relative_path0/file1_lineage',
......
...@@ -15,38 +15,31 @@ ...@@ -15,38 +15,31 @@
"""Test the validate module.""" """Test the validate module."""
from unittest import TestCase from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError
LineageParamValueError, LineageParamTypeError from mindinsight.lineagemgr.common.validator.model_parameter import SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.model_parameter import \ from mindinsight.lineagemgr.common.validator.validate import validate_search_model_condition
SearchModelConditionParameter
from mindinsight.lineagemgr.common.validator.validate import \
validate_search_model_condition
from mindinsight.utils.exceptions import MindInsightException from mindinsight.utils.exceptions import MindInsightException
class TestValidateSearchModelCondition(TestCase): class TestValidateSearchModelCondition(TestCase):
"""Test the mothod of validate_search_model_condition.""" """Test the mothod of validate_search_model_condition."""
def test_validate_search_model_condition(self): def test_validate_search_model_condition_param_type_error(self):
"""Test the mothod of validate_search_model_condition.""" """Test the mothod of validate_search_model_condition with LineageParamTypeError."""
condition = { condition = {
'summary_dir': 'xxx' 'summary_dir': 'xxx'
} }
self.assertRaisesRegex( self._assert_raise_of_lineage_param_type_error(
LineageParamTypeError,
'The search_condition element summary_dir should be dict.', 'The search_condition element summary_dir should be dict.',
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_param_value_error(self):
"""Test the mothod of validate_search_model_condition with LineageParamValueError."""
condition = { condition = {
'xxx': 'xxx' 'xxx': 'xxx'
} }
self.assertRaisesRegex( self._assert_raise_of_lineage_param_value_error(
LineageParamValueError,
'The search attribute not supported.', 'The search attribute not supported.',
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -55,22 +48,38 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -55,22 +48,38 @@ class TestValidateSearchModelCondition(TestCase):
'xxx': 'xxx' 'xxx': 'xxx'
} }
} }
self.assertRaisesRegex( self._assert_raise_of_lineage_param_value_error(
LineageParamValueError,
"The compare condition should be in", "The compare condition should be in",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
condition = {
1: {
"ge": 8.0
}
}
self._assert_raise_of_lineage_param_value_error(
"The search attribute not supported.",
condition
)
condition = {
'metric_': {
"ge": 8.0
}
}
self._assert_raise_of_lineage_param_value_error(
"The search attribute not supported.",
condition
)
def test_validate_search_model_condition_mindinsight_exception_1(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
"offset": 100001 "offset": 100001
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"Invalid input offset. 0 <= offset <= 100000", "Invalid input offset. 0 <= offset <= 100000",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -80,11 +89,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -80,11 +89,9 @@ class TestValidateSearchModelCondition(TestCase):
}, },
'limit': 10 'limit': 10
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter summary_dir is invalid. It should be a dict and "
"The parameter summary_dir is invalid. It should be a dict and the value should be a string", "the value should be a string",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -93,11 +100,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -93,11 +100,9 @@ class TestValidateSearchModelCondition(TestCase):
'in': 1.0 'in': 1.0
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -106,24 +111,22 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -106,24 +111,22 @@ class TestValidateSearchModelCondition(TestCase):
'lt': True 'lt': True
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_mindinsight_exception_2(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
'learning_rate': { 'learning_rate': {
'gt': [1.0] 'gt': [1.0]
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -132,11 +135,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -132,11 +135,9 @@ class TestValidateSearchModelCondition(TestCase):
'ge': 1 'ge': 1
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter loss_function is invalid. It should be a dict and "
"The parameter loss_function is invalid. It should be a dict and the value should be a string", "the value should be a string",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -145,12 +146,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -145,12 +146,9 @@ class TestValidateSearchModelCondition(TestCase):
'in': 2 'in': 2
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter train_dataset_count is invalid. It should be a dict " "The parameter train_dataset_count is invalid. It should be a dict "
"and the value should be a integer between 0", "and the value should be a integer between 0",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -162,14 +160,14 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -162,14 +160,14 @@ class TestValidateSearchModelCondition(TestCase):
'eq': 'xxx' 'eq': 'xxx'
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter network is invalid. It should be a dict and "
"The parameter network is invalid. It should be a dict and the value should be a string", "the value should be a string",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_mindinsight_exception_3(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
'batch_size': { 'batch_size': {
'lt': 2, 'lt': 2,
...@@ -179,11 +177,8 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -179,11 +177,8 @@ class TestValidateSearchModelCondition(TestCase):
'eq': 222 'eq': 222
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter batch_size is invalid. It should be a non-negative integer.", "The parameter batch_size is invalid. It should be a non-negative integer.",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -192,12 +187,9 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -192,12 +187,9 @@ class TestValidateSearchModelCondition(TestCase):
'lt': -2 'lt': -2
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter test_dataset_count is invalid. It should be a dict " "The parameter test_dataset_count is invalid. It should be a dict "
"and the value should be a integer between 0", "and the value should be a integer between 0",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -206,11 +198,8 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -206,11 +198,8 @@ class TestValidateSearchModelCondition(TestCase):
'lt': False 'lt': False
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter epoch is invalid. It should be a positive integer.", "The parameter epoch is invalid. It should be a positive integer.",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
...@@ -219,65 +208,79 @@ class TestValidateSearchModelCondition(TestCase): ...@@ -219,65 +208,79 @@ class TestValidateSearchModelCondition(TestCase):
"ge": "" "ge": ""
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException, "The parameter learning_rate is invalid. It should be a dict and "
"The parameter learning_rate is invalid. It should be a dict and the value should be a float or a integer", "the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
def test_validate_search_model_condition_mindinsight_exception_4(self):
"""Test the mothod of validate_search_model_condition with MindinsightException."""
condition = { condition = {
"train_dataset_count": { "train_dataset_count": {
"ge": 8.0 "ge": 8.0
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
MindInsightException,
"The parameter train_dataset_count is invalid. It should be a dict " "The parameter train_dataset_count is invalid. It should be a dict "
"and the value should be a integer between 0", "and the value should be a integer between 0",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
condition = { condition = {
1: { 'metric_attribute': {
"ge": 8.0 'ge': 'xxx'
} }
} }
self.assertRaisesRegex( self._assert_raise_of_mindinsight_exception(
LineageParamValueError, "The parameter metric_attribute is invalid. "
"The search attribute not supported.", "It should be a dict and the value should be a float or a integer",
validate_search_model_condition,
SearchModelConditionParameter,
condition condition
) )
condition = { def _assert_raise(self, exception, msg, condition):
'metric_': { """
"ge": 8.0 Assert raise by unittest.
}
}
LineageParamValueError('The search attribute not supported.')
self.assertRaisesRegex(
LineageParamValueError,
"The search attribute not supported.",
validate_search_model_condition,
SearchModelConditionParameter,
condition
)
condition = { Args:
'metric_attribute': { exception (Type): Exception class expected to be raised.
'ge': 'xxx' msg (msg): Expected error message.
} condition (dict): The parameter of search condition.
} """
self.assertRaisesRegex( self.assertRaisesRegex(
MindInsightException, exception,
"The parameter metric_attribute is invalid. " msg,
"It should be a dict and the value should be a float or a integer",
validate_search_model_condition, validate_search_model_condition,
SearchModelConditionParameter, SearchModelConditionParameter,
condition condition
) )
def _assert_raise_of_mindinsight_exception(self, msg, condition):
"""
Assert raise of MindinsightException by unittest.
Args:
msg (msg): Expected error message.
condition (dict): The parameter of search condition.
"""
self._assert_raise(MindInsightException, msg, condition)
def _assert_raise_of_lineage_param_value_error(self, msg, condition):
"""
Assert raise of LineageParamValueError by unittest.
Args:
msg (msg): Expected error message.
condition (dict): The parameter of search condition.
"""
self._assert_raise(LineageParamValueError, msg, condition)
def _assert_raise_of_lineage_param_type_error(self, msg, condition):
"""
Assert raise of LineageParamTypeError by unittest.
Args:
msg (msg): Expected error message.
condition (dict): The parameter of search condition.
"""
self._assert_raise(LineageParamTypeError, msg, condition)
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
"""The event data in querier test.""" """The event data in querier test."""
import json import json
from ....utils.mindspore.dataset.engine.serializer_deserializer import SERIALIZED_PIPELINE
EVENT_TRAIN_DICT_0 = { EVENT_TRAIN_DICT_0 = {
'wall_time': 1581499557.7017336, 'wall_time': 1581499557.7017336,
'train_lineage': { 'train_lineage': {
...@@ -373,49 +375,4 @@ EVENT_DATASET_DICT_0 = { ...@@ -373,49 +375,4 @@ EVENT_DATASET_DICT_0 = {
} }
} }
DATASET_DICT_0 = { DATASET_DICT_0 = SERIALIZED_PIPELINE
'op_type': 'BatchDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'drop_remainder': True,
'batch_size': 10,
'children': [
{
'op_type': 'MapDataset',
'op_module': 'minddata.dataengine.datasets',
'num_parallel_workers': None,
'input_columns': [
'label'
],
'output_columns': [
None
],
'operations': [
{
'tensor_op_module': 'minddata.transforms.c_transforms',
'tensor_op_name': 'OneHot',
'num_classes': 10
}
],
'children': [
{
'op_type': 'MnistDataset',
'shard_id': None,
'num_shards': None,
'op_module': 'minddata.dataengine.datasets',
'dataset_dir': '/home/anthony/MindData/tests/dataset/data/testMnistData',
'num_parallel_workers': None,
'shuffle': None,
'num_samples': 100,
'sampler': {
'sampler_module': 'minddata.dataengine.samplers',
'sampler_name': 'RandomSampler',
'replacement': True,
'num_samples': 100
},
'children': []
}
]
}
]
}
...@@ -18,12 +18,12 @@ from unittest import TestCase, mock ...@@ -18,12 +18,12 @@ from unittest import TestCase, mock
from google.protobuf.json_format import ParseDict from google.protobuf.json_format import ParseDict
import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2 import mindinsight.datavisual.proto_files.mindinsight_summary_pb2 as summary_pb2
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageParamTypeError, LineageQuerierParamException,
LineageQuerierParamException, LineageParamTypeError, \ LineageSummaryAnalyzeException,
LineageSummaryAnalyzeException, LineageSummaryParseException LineageSummaryParseException)
from mindinsight.lineagemgr.querier.querier import Querier from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \ from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo
LineageInfo
from . import event_data from . import event_data
...@@ -140,6 +140,98 @@ def get_lineage_infos(): ...@@ -140,6 +140,98 @@ def get_lineage_infos():
return lineage_infos return lineage_infos
LINEAGE_INFO_0 = {
'summary_dir': '/path/to/summary0',
**event_data.EVENT_TRAIN_DICT_0['train_lineage'],
'metric': event_data.METRIC_0,
'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
LINEAGE_INFO_1 = {
'summary_dir': '/path/to/summary1',
**event_data.EVENT_TRAIN_DICT_1['train_lineage'],
'metric': event_data.METRIC_1,
'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
LINEAGE_FILTRATION_0 = create_filtration_result(
'/path/to/summary0',
event_data.EVENT_TRAIN_DICT_0,
event_data.EVENT_EVAL_DICT_0,
event_data.METRIC_0,
event_data.DATASET_DICT_0
)
LINEAGE_FILTRATION_1 = create_filtration_result(
'/path/to/summary1',
event_data.EVENT_TRAIN_DICT_1,
event_data.EVENT_EVAL_DICT_1,
event_data.METRIC_1,
event_data.DATASET_DICT_0
)
LINEAGE_FILTRATION_2 = create_filtration_result(
'/path/to/summary2',
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
)
LINEAGE_FILTRATION_3 = create_filtration_result(
'/path/to/summary3',
event_data.EVENT_TRAIN_DICT_3,
event_data.EVENT_EVAL_DICT_3,
event_data.METRIC_3,
event_data.DATASET_DICT_0
)
LINEAGE_FILTRATION_4 = create_filtration_result(
'/path/to/summary4',
event_data.EVENT_TRAIN_DICT_4,
event_data.EVENT_EVAL_DICT_4,
event_data.METRIC_4,
event_data.DATASET_DICT_0
)
LINEAGE_FILTRATION_5 = {
"summary_dir": '/path/to/summary5',
"loss_function":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": None,
"train_dataset_count":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": None,
"test_dataset_count": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"metric": {},
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
}
LINEAGE_FILTRATION_6 = {
"summary_dir": '/path/to/summary6',
"loss_function": None,
"train_dataset_path": None,
"train_dataset_count": None,
"test_dataset_path":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": None,
"optimizer": None,
"learning_rate": None,
"epoch": None,
"batch_size": None,
"loss": None,
"model_size": None,
"metric": event_data.METRIC_5,
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
}
class TestQuerier(TestCase): class TestQuerier(TestCase):
"""Test the class of `Querier`.""" """Test the class of `Querier`."""
@mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos') @mock.patch('mindinsight.lineagemgr.querier.querier.LineageSummaryAnalyzer.get_summary_infos')
...@@ -169,31 +261,13 @@ class TestQuerier(TestCase): ...@@ -169,31 +261,13 @@ class TestQuerier(TestCase):
def test_get_summary_lineage_success_1(self): def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [ expected_result = [LINEAGE_INFO_0]
{
'summary_dir': '/path/to/summary0',
**event_data.EVENT_TRAIN_DICT_0['train_lineage'],
'metric': event_data.METRIC_0,
'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result) self.assertListEqual(expected_result, result)
def test_get_summary_lineage_success_2(self): def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [ expected_result = [LINEAGE_INFO_0]
{
'summary_dir': '/path/to/summary0',
**event_data.EVENT_TRAIN_DICT_0['train_lineage'],
'metric': event_data.METRIC_0,
'valid_dataset':
event_data.EVENT_EVAL_DICT_0['evaluation_lineage'][
'valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
]
result = self.single_querier.get_summary_lineage( result = self.single_querier.get_summary_lineage(
summary_dir='/path/to/summary0' summary_dir='/path/to/summary0'
) )
...@@ -216,20 +290,8 @@ class TestQuerier(TestCase): ...@@ -216,20 +290,8 @@ class TestQuerier(TestCase):
def test_get_summary_lineage_success_4(self): def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [ expected_result = [
{ LINEAGE_INFO_0,
'summary_dir': '/path/to/summary0', LINEAGE_INFO_1,
**event_data.EVENT_TRAIN_DICT_0['train_lineage'],
'metric': event_data.METRIC_0,
'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
},
{
'summary_dir': '/path/to/summary1',
**event_data.EVENT_TRAIN_DICT_1['train_lineage'],
'metric': event_data.METRIC_1,
'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
},
{ {
'summary_dir': '/path/to/summary2', 'summary_dir': '/path/to/summary2',
**event_data.EVENT_TRAIN_DICT_2['train_lineage'], **event_data.EVENT_TRAIN_DICT_2['train_lineage'],
...@@ -274,15 +336,7 @@ class TestQuerier(TestCase): ...@@ -274,15 +336,7 @@ class TestQuerier(TestCase):
def test_get_summary_lineage_success_5(self): def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [ expected_result = [LINEAGE_INFO_1]
{
'summary_dir': '/path/to/summary1',
**event_data.EVENT_TRAIN_DICT_1['train_lineage'],
'metric': event_data.METRIC_1,
'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
]
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1' summary_dir='/path/to/summary1'
) )
...@@ -341,20 +395,8 @@ class TestQuerier(TestCase): ...@@ -341,20 +395,8 @@ class TestQuerier(TestCase):
} }
expected_result = { expected_result = {
'object': [ 'object': [
create_filtration_result( LINEAGE_FILTRATION_1,
'/path/to/summary1', LINEAGE_FILTRATION_2
event_data.EVENT_TRAIN_DICT_1,
event_data.EVENT_EVAL_DICT_1,
event_data.METRIC_1,
event_data.DATASET_DICT_0,
),
create_filtration_result(
'/path/to/summary2',
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
)
], ],
'count': 2, 'count': 2,
} }
...@@ -377,20 +419,8 @@ class TestQuerier(TestCase): ...@@ -377,20 +419,8 @@ class TestQuerier(TestCase):
} }
expected_result = { expected_result = {
'object': [ 'object': [
create_filtration_result( LINEAGE_FILTRATION_2,
'/path/to/summary2', LINEAGE_FILTRATION_3
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary3',
event_data.EVENT_TRAIN_DICT_3,
event_data.EVENT_EVAL_DICT_3,
event_data.METRIC_3,
event_data.DATASET_DICT_0
)
], ],
'count': 2, 'count': 2,
} }
...@@ -405,20 +435,8 @@ class TestQuerier(TestCase): ...@@ -405,20 +435,8 @@ class TestQuerier(TestCase):
} }
expected_result = { expected_result = {
'object': [ 'object': [
create_filtration_result( LINEAGE_FILTRATION_2,
'/path/to/summary2', LINEAGE_FILTRATION_3
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary3',
event_data.EVENT_TRAIN_DICT_3,
event_data.EVENT_EVAL_DICT_3,
event_data.METRIC_3,
event_data.DATASET_DICT_0
)
], ],
'count': 7, 'count': 7,
} }
...@@ -429,82 +447,13 @@ class TestQuerier(TestCase): ...@@ -429,82 +447,13 @@ class TestQuerier(TestCase):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
expected_result = { expected_result = {
'object': [ 'object': [
create_filtration_result( LINEAGE_FILTRATION_0,
'/path/to/summary0', LINEAGE_FILTRATION_1,
event_data.EVENT_TRAIN_DICT_0, LINEAGE_FILTRATION_2,
event_data.EVENT_EVAL_DICT_0, LINEAGE_FILTRATION_3,
event_data.METRIC_0, LINEAGE_FILTRATION_4,
event_data.DATASET_DICT_0 LINEAGE_FILTRATION_5,
), LINEAGE_FILTRATION_6
create_filtration_result(
'/path/to/summary1',
event_data.EVENT_TRAIN_DICT_1,
event_data.EVENT_EVAL_DICT_1,
event_data.METRIC_1,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary2',
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary3',
event_data.EVENT_TRAIN_DICT_3,
event_data.EVENT_EVAL_DICT_3,
event_data.METRIC_3,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary4',
event_data.EVENT_TRAIN_DICT_4,
event_data.EVENT_EVAL_DICT_4,
event_data.METRIC_4,
event_data.DATASET_DICT_0
),
{
"summary_dir": '/path/to/summary5',
"loss_function":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": None,
"train_dataset_count":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": None,
"test_dataset_count": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"metric": {},
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
},
{
"summary_dir": '/path/to/summary6',
"loss_function": None,
"train_dataset_path": None,
"train_dataset_count": None,
"test_dataset_path":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": None,
"optimizer": None,
"learning_rate": None,
"epoch": None,
"batch_size": None,
"loss": None,
"model_size": None,
"metric": event_data.METRIC_5,
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
}
], ],
'count': 7, 'count': 7,
} }
...@@ -519,15 +468,7 @@ class TestQuerier(TestCase): ...@@ -519,15 +468,7 @@ class TestQuerier(TestCase):
} }
} }
expected_result = { expected_result = {
'object': [ 'object': [LINEAGE_FILTRATION_4],
create_filtration_result(
'/path/to/summary4',
event_data.EVENT_TRAIN_DICT_4,
event_data.EVENT_EVAL_DICT_4,
event_data.METRIC_4,
event_data.DATASET_DICT_0
),
],
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
...@@ -541,82 +482,13 @@ class TestQuerier(TestCase): ...@@ -541,82 +482,13 @@ class TestQuerier(TestCase):
} }
expected_result = { expected_result = {
'object': [ 'object': [
create_filtration_result( LINEAGE_FILTRATION_0,
'/path/to/summary0', LINEAGE_FILTRATION_5,
event_data.EVENT_TRAIN_DICT_0, LINEAGE_FILTRATION_1,
event_data.EVENT_EVAL_DICT_0, LINEAGE_FILTRATION_2,
event_data.METRIC_0, LINEAGE_FILTRATION_3,
event_data.DATASET_DICT_0 LINEAGE_FILTRATION_4,
), LINEAGE_FILTRATION_6
{
"summary_dir": '/path/to/summary5',
"loss_function":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": None,
"train_dataset_count":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": None,
"test_dataset_count": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"metric": {},
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
},
create_filtration_result(
'/path/to/summary1',
event_data.EVENT_TRAIN_DICT_1,
event_data.EVENT_EVAL_DICT_1,
event_data.METRIC_1,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary2',
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary3',
event_data.EVENT_TRAIN_DICT_3,
event_data.EVENT_EVAL_DICT_3,
event_data.METRIC_3,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary4',
event_data.EVENT_TRAIN_DICT_4,
event_data.EVENT_EVAL_DICT_4,
event_data.METRIC_4,
event_data.DATASET_DICT_0
),
{
"summary_dir": '/path/to/summary6',
"loss_function": None,
"train_dataset_path": None,
"train_dataset_count": None,
"test_dataset_path":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
"test_dataset_count":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": None,
"optimizer": None,
"learning_rate": None,
"epoch": None,
"batch_size": None,
"loss": None,
"model_size": None,
"metric": event_data.METRIC_5,
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
}
], ],
'count': 7, 'count': 7,
} }
...@@ -631,82 +503,13 @@ class TestQuerier(TestCase): ...@@ -631,82 +503,13 @@ class TestQuerier(TestCase):
} }
expected_result = { expected_result = {
'object': [ 'object': [
{ LINEAGE_FILTRATION_6,
"summary_dir": '/path/to/summary6', LINEAGE_FILTRATION_4,
"loss_function": None, LINEAGE_FILTRATION_3,
"train_dataset_path": None, LINEAGE_FILTRATION_2,
"train_dataset_count": None, LINEAGE_FILTRATION_1,
"test_dataset_path": LINEAGE_FILTRATION_0,
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'], LINEAGE_FILTRATION_5
"test_dataset_count":
event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
"network": None,
"optimizer": None,
"learning_rate": None,
"epoch": None,
"batch_size": None,
"loss": None,
"model_size": None,
"metric": event_data.METRIC_5,
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
},
create_filtration_result(
'/path/to/summary4',
event_data.EVENT_TRAIN_DICT_4,
event_data.EVENT_EVAL_DICT_4,
event_data.METRIC_4,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary3',
event_data.EVENT_TRAIN_DICT_3,
event_data.EVENT_EVAL_DICT_3,
event_data.METRIC_3,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary2',
event_data.EVENT_TRAIN_DICT_2,
event_data.EVENT_EVAL_DICT_2,
event_data.METRIC_2,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary1',
event_data.EVENT_TRAIN_DICT_1,
event_data.EVENT_EVAL_DICT_1,
event_data.METRIC_1,
event_data.DATASET_DICT_0
),
create_filtration_result(
'/path/to/summary0',
event_data.EVENT_TRAIN_DICT_0,
event_data.EVENT_EVAL_DICT_0,
event_data.METRIC_0,
event_data.DATASET_DICT_0
),
{
"summary_dir": '/path/to/summary5',
"loss_function":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
"train_dataset_path": None,
"train_dataset_count":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
"test_dataset_path": None,
"test_dataset_count": None,
"network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
"optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
"learning_rate":
event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
"epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
"batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
"loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
"model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
"metric": {},
"dataset_graph": event_data.DATASET_DICT_0,
"dataset_mark": '2'
}
], ],
'count': 7, 'count': 7,
} }
...@@ -722,15 +525,7 @@ class TestQuerier(TestCase): ...@@ -722,15 +525,7 @@ class TestQuerier(TestCase):
} }
} }
expected_result = { expected_result = {
'object': [ 'object': [LINEAGE_FILTRATION_4],
create_filtration_result(
'/path/to/summary4',
event_data.EVENT_TRAIN_DICT_4,
event_data.EVENT_EVAL_DICT_4,
event_data.METRIC_4,
event_data.DATASET_DICT_0
),
],
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
...@@ -809,20 +604,8 @@ class TestQuerier(TestCase): ...@@ -809,20 +604,8 @@ class TestQuerier(TestCase):
querier = Querier(summary_path) querier = Querier(summary_path)
querier._parse_failed_paths.append('/path/to/summary1/log1') querier._parse_failed_paths.append('/path/to/summary1/log1')
expected_result = [ expected_result = [
{ LINEAGE_INFO_0,
'summary_dir': '/path/to/summary0', LINEAGE_INFO_1
**event_data.EVENT_TRAIN_DICT_0['train_lineage'],
'metric': event_data.METRIC_0,
'valid_dataset': event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
},
{
'summary_dir': '/path/to/summary1',
**event_data.EVENT_TRAIN_DICT_1['train_lineage'],
'metric': event_data.METRIC_1,
'valid_dataset': event_data.EVENT_EVAL_DICT_1['evaluation_lineage']['valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
] ]
result = querier.get_summary_lineage() result = querier.get_summary_lineage()
self.assertListEqual(expected_result, result) self.assertListEqual(expected_result, result)
...@@ -842,17 +625,7 @@ class TestQuerier(TestCase): ...@@ -842,17 +625,7 @@ class TestQuerier(TestCase):
querier._parse_failed_paths.append('/path/to/summary1/log1') querier._parse_failed_paths.append('/path/to/summary1/log1')
args[0].return_value = create_lineage_info(None, None, None) args[0].return_value = create_lineage_info(None, None, None)
expected_result = [ expected_result = [LINEAGE_INFO_0]
{
'summary_dir': '/path/to/summary0',
**event_data.EVENT_TRAIN_DICT_0['train_lineage'],
'metric': event_data.METRIC_0,
'valid_dataset':
event_data.EVENT_EVAL_DICT_0['evaluation_lineage'][
'valid_dataset'],
'dataset_graph': event_data.DATASET_DICT_0
}
]
result = querier.get_summary_lineage() result = querier.get_summary_lineage()
self.assertListEqual(expected_result, result) self.assertListEqual(expected_result, result)
self.assertListEqual( self.assertListEqual(
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
"""Test the query_model module.""" """Test the query_model module."""
from unittest import TestCase from unittest import TestCase
from mindinsight.lineagemgr.common.exceptions.exceptions import \ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageEventFieldNotExistException,
LineageEventNotExistException, LineageEventFieldNotExistException LineageEventNotExistException)
from mindinsight.lineagemgr.querier.query_model import LineageObj from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data from . import event_data
from .test_querier import create_lineage_info, create_filtration_result from .test_querier import create_filtration_result, create_lineage_info
class TestLineageObj(TestCase): class TestLineageObj(TestCase):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
...@@ -19,10 +19,10 @@ import time ...@@ -19,10 +19,10 @@ import time
from google.protobuf import json_format from google.protobuf import json_format
from tests.ut.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class GraphLogGenerator(LogGenerator): class GraphLogGenerator(LogGenerator):
""" """
...@@ -74,7 +74,7 @@ class GraphLogGenerator(LogGenerator): ...@@ -74,7 +74,7 @@ class GraphLogGenerator(LogGenerator):
if __name__ == "__main__": if __name__ == "__main__":
graph_log_generator = GraphLogGenerator() graph_log_generator = GraphLogGenerator()
test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time())) test_file_name = '%s.%s.%s' % ('graph', 'summary', str(time.time()))
graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators", "graph_base.json") graph_base_path = os.path.join(os.path.dirname(__file__), os.pardir, "log_generators--", "graph_base.json")
with open(graph_base_path, 'r') as load_f: with open(graph_base_path, 'r') as load_f:
graph = json.load(load_f) graph = json.load(load_f)
graph_log_generator.generate_log(test_file_name, graph) graph_log_generator.generate_log(test_file_name, graph)
...@@ -18,10 +18,11 @@ import time ...@@ -18,10 +18,11 @@ import time
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class ImagesLogGenerator(LogGenerator): class ImagesLogGenerator(LogGenerator):
""" """
...@@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator): ...@@ -138,12 +139,7 @@ class ImagesLogGenerator(LogGenerator):
images_metadata.append(image_metadata) images_metadata.append(image_metadata)
images_values.update({step: image_tensor}) images_values.update({step: image_tensor})
values = dict( values = dict(wall_time=wall_time, step=step, image=image_tensor, tag=tag_name)
wall_time=wall_time,
step=step,
image=image_tensor,
tag=tag_name
)
self._write_log_one_step(file_path, values) self._write_log_one_step(file_path, values)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import struct import struct
from abc import abstractmethod from abc import abstractmethod
from tests.st.func.datavisual.utils import crc32 from ...utils import crc32
class LogGenerator: class LogGenerator:
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
import time import time
import numpy as np import numpy as np
from tests.st.func.datavisual.utils.log_generators.log_generator import LogGenerator
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class ScalarsLogGenerator(LogGenerator): class ScalarsLogGenerator(LogGenerator):
""" """
......
...@@ -19,13 +19,12 @@ import json ...@@ -19,13 +19,12 @@ import json
import os import os
import time import time
from tests.st.func.datavisual.constants import SUMMARY_PREFIX
from tests.st.func.datavisual.utils.log_generators.graph_log_generator import GraphLogGenerator
from tests.st.func.datavisual.utils.log_generators.images_log_generator import ImagesLogGenerator
from tests.st.func.datavisual.utils.log_generators.scalars_log_generator import ScalarsLogGenerator
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from .log_generators.graph_log_generator import GraphLogGenerator
from .log_generators.images_log_generator import ImagesLogGenerator
from .log_generators.scalars_log_generator import ScalarsLogGenerator
log_generators = { log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(), PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(), PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
...@@ -35,10 +34,12 @@ log_generators = { ...@@ -35,10 +34,12 @@ log_generators = {
class LogOperations: class LogOperations:
"""Log Operations.""" """Log Operations."""
def __init__(self): def __init__(self):
self._step_num = 3 self._step_num = 3
self._tag_num = 2 self._tag_num = 2
self._time_count = 0 self._time_count = 0
self._graph_base_path = os.path.join(os.path.dirname(__file__), "log_generators", "graph_base.json")
def _get_steps(self): def _get_steps(self):
"""Get steps.""" """Get steps."""
...@@ -61,9 +62,7 @@ class LogOperations: ...@@ -61,9 +62,7 @@ class LogOperations:
metadata_dict["plugins"].update({plugin_name: list()}) metadata_dict["plugins"].update({plugin_name: list()})
log_generator = log_generators.get(plugin_name) log_generator = log_generators.get(plugin_name)
if plugin_name == PluginNameEnum.GRAPH.value: if plugin_name == PluginNameEnum.GRAPH.value:
graph_base_path = os.path.join(os.path.dirname(__file__), with open(self._graph_base_path, 'r') as load_f:
os.pardir, "utils", "log_generators", "graph_base.json")
with open(graph_base_path, 'r') as load_f:
graph_dict = json.load(load_f) graph_dict = json.load(load_f)
values = log_generator.generate_log(file_path, graph_dict) values = log_generator.generate_log(file_path, graph_dict)
metadata_dict["actual_values"].update({plugin_name: values}) metadata_dict["actual_values"].update({plugin_name: values})
...@@ -82,13 +81,13 @@ class LogOperations: ...@@ -82,13 +81,13 @@ class LogOperations:
self._time_count += 1 self._time_count += 1
return metadata_dict return metadata_dict
def create_summary_logs(self, summary_base_dir, summary_dir_num, start_index=0): def create_summary_logs(self, summary_base_dir, summary_dir_num, dir_prefix, start_index=0):
"""Create summary logs in summary_base_dir.""" """Create summary logs in summary_base_dir."""
summary_metadata = dict() summary_metadata = dict()
steps_list = self._get_steps() steps_list = self._get_steps()
tag_name_list = self._get_tags() tag_name_list = self._get_tags()
for i in range(start_index, summary_dir_num + start_index): for i in range(start_index, summary_dir_num + start_index):
log_dir = os.path.join(summary_base_dir, f'{SUMMARY_PREFIX}{i}') log_dir = os.path.join(summary_base_dir, f'{dir_prefix}{i}')
os.makedirs(log_dir) os.makedirs(log_dir)
train_id = log_dir.replace(summary_base_dir, ".") train_id = log_dir.replace(summary_base_dir, ".")
...@@ -120,3 +119,47 @@ class LogOperations: ...@@ -120,3 +119,47 @@ class LogOperations:
metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list) metadata_dict = self.create_summary(log_dir, steps_list, tag_name_list)
return {train_id: metadata_dict} return {train_id: metadata_dict}
def generate_log(self, plugin_name, log_dir, log_settings=None, valid=True):
"""
Generate log for ut.
Args:
plugin_name (str): Plugin name, contains 'graph', 'image', and 'scalar'.
log_dir (str): Log path to write log.
log_settings (dict): Info about the log, e.g.:
{
current_time (int): Timestamp in summary file name, not necessary.
graph_base_path (str): Path of graph_bas.json, necessary for `graph`.
steps (list[int]): Steps for `image` and `scalar`, default is [1].
tag (str): Tag name, default is 'default_tag'.
}
valid (bool): If true, summary name will be valid.
Returns:
str, Summary log path.
"""
if log_settings is None:
log_settings = dict()
current_time = log_settings.get('time', int(time.time()))
current_time = int(current_time)
log_generator = log_generators.get(plugin_name)
if valid:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.summary', str(current_time)))
else:
temp_path = os.path.join(log_dir, '%s.%s' % ('test.invalid', str(current_time)))
if plugin_name == PluginNameEnum.GRAPH.value:
with open(self._graph_base_path, 'r') as load_f:
graph_dict = json.load(load_f)
graph_dict = log_generator.generate_log(temp_path, graph_dict)
return temp_path, graph_dict
steps_list = log_settings.get('steps', [1])
tag_name = log_settings.get('tag', 'default_tag')
metadata, values = log_generator.generate_log(temp_path, steps_list, tag_name)
return temp_path, metadata, values
...@@ -48,7 +48,7 @@ class WithLossCell(Cell): ...@@ -48,7 +48,7 @@ class WithLossCell(Cell):
class TrainOneStepWithLossScaleCell(Cell): class TrainOneStepWithLossScaleCell(Cell):
"""Mocked TrainOneStepWithLossScaleCell.""" """Mocked TrainOneStepWithLossScaleCell."""
def __init__(self, network, optimizer): def __init__(self, network=None, optimizer=None):
super(TrainOneStepWithLossScaleCell, self).__init__() super(TrainOneStepWithLossScaleCell, self).__init__()
self.network = network self.network = network
self.optimizer = optimizer self.optimizer = optimizer
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,9 +19,13 @@ import io ...@@ -19,9 +19,13 @@ import io
import os import os
import shutil import shutil
import time import time
import json
from urllib.parse import urlencode from urllib.parse import urlencode
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from mindinsight.datavisual.common.enums import DataManagerStatus from mindinsight.datavisual.common.enums import DataManagerStatus
...@@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string): ...@@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string):
image_tensor = np.array(img) image_tensor = np.array(img)
return image_tensor return image_tensor
def compare_result_with_file(result, expected_file_path):
"""Compare result with file which contain the expected results."""
with open(expected_file_path, 'r') as file:
expected_results = json.load(file)
assert result == expected_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册