提交 8ea4d18c 编写于 作者: O ougongchang

Extract the common function methods and reduced cyclomatic complexity of functions

上级 b91233a9
...@@ -18,19 +18,21 @@ This file is used to define the basic graph. ...@@ -18,19 +18,21 @@ This file is used to define the basic graph.
import copy import copy
import time import time
from enum import Enum
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from .node import NodeTypeEnum from .node import NodeTypeEnum
from .node import Node from .node import Node
class EdgeTypeEnum: class EdgeTypeEnum(Enum):
"""Node edge type enum.""" """Node edge type enum."""
control = 'control' CONTROL = 'control'
data = 'data' DATA = 'data'
class DataTypeEnum: class DataTypeEnum(Enum):
"""Data type enum.""" """Data type enum."""
DT_TENSOR = 13 DT_TENSOR = 13
...@@ -292,70 +294,65 @@ class Graph: ...@@ -292,70 +294,65 @@ class Graph:
output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value output_attr['scope'] = NodeTypeEnum.POLYMERIC_SCOPE.value
node.update_output({dst_name: output_attr}) node.update_output({dst_name: output_attr})
def _calc_polymeric_input_output(self): def _update_polymeric_input_output(self):
"""Calc polymeric input and output after build polymeric node.""" """Calc polymeric input and output after build polymeric node."""
for name, node in self._normal_nodes.items(): for node in self._normal_nodes.values():
polymeric_input = {} polymeric_input = self._calc_polymeric_attr(node, 'input')
for src_name in node.input:
src_node = self._polymeric_nodes.get(src_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
src_name = src_name if not src_node else src_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not src_node:
continue
if not node.name_scope and src_node.name_scope:
# if current node is in first layer, and the src node is not in
# the first layer, the src node will not be the polymeric input of current node.
continue
if node.name_scope == src_node.name_scope \
or node.name_scope.startswith(src_node.name_scope):
polymeric_input.update(
{src_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_input(polymeric_input) node.update_polymeric_input(polymeric_input)
polymeric_output = {} polymeric_output = self._calc_polymeric_attr(node, 'output')
for dst_name in node.output:
dst_node = self._polymeric_nodes.get(dst_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
dst_name = dst_name if not dst_node else dst_node.polymeric_scope_name
output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}})
continue
if not dst_node:
continue
if not node.name_scope and dst_node.name_scope:
continue
if node.name_scope == dst_node.name_scope \
or node.name_scope.startswith(dst_node.name_scope):
polymeric_output.update(
{dst_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.data}})
node.update_polymeric_output(polymeric_output) node.update_polymeric_output(polymeric_output)
for name, node in self._polymeric_nodes.items(): for name, node in self._polymeric_nodes.items():
polymeric_input = {} polymeric_input = {}
for src_name in node.input: for src_name in node.input:
output_name = self._calc_dummy_node_name(name, src_name) output_name = self._calc_dummy_node_name(name, src_name)
polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.data}}) polymeric_input.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
node.update_polymeric_input(polymeric_input) node.update_polymeric_input(polymeric_input)
polymeric_output = {} polymeric_output = {}
for dst_name in node.output: for dst_name in node.output:
polymeric_output = {} polymeric_output = {}
output_name = self._calc_dummy_node_name(name, dst_name) output_name = self._calc_dummy_node_name(name, dst_name)
polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.data}}) polymeric_output.update({output_name: {'edge_type': EdgeTypeEnum.DATA.value}})
node.update_polymeric_output(polymeric_output) node.update_polymeric_output(polymeric_output)
def _calc_polymeric_attr(self, node, attr):
"""
Calc polymeric input or polymeric output after build polymeric node.
Args:
node (Node): Computes the polymeric input for a given node.
attr (str): The polymeric attr, optional value is `input` or `output`.
Returns:
dict, return polymeric input or polymeric output of the given node.
"""
polymeric_attr = {}
for node_name in getattr(node, attr):
polymeric_node = self._polymeric_nodes.get(node_name)
if node.node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
node_name = node_name if not polymeric_node else polymeric_node.polymeric_scope_name
dummy_node_name = self._calc_dummy_node_name(node.name, node_name)
polymeric_attr.update({dummy_node_name: {'edge_type': EdgeTypeEnum.DATA.value}})
continue
if not polymeric_node:
continue
if not node.name_scope and polymeric_node.name_scope:
# If current node is in top-level layer, and the polymeric_node node is not in
# the top-level layer, the polymeric node will not be the polymeric input
# or polymeric output of current node.
continue
if node.name_scope == polymeric_node.name_scope \
or node.name_scope.startswith(polymeric_node.name_scope + '/'):
polymeric_attr.update(
{polymeric_node.polymeric_scope_name: {'edge_type': EdgeTypeEnum.DATA.value}})
return polymeric_attr
def _calc_dummy_node_name(self, current_node_name, other_node_name): def _calc_dummy_node_name(self, current_node_name, other_node_name):
""" """
Calc dummy node name. Calc dummy node name.
......
...@@ -39,7 +39,7 @@ class MSGraph(Graph): ...@@ -39,7 +39,7 @@ class MSGraph(Graph):
self._build_leaf_nodes(graph_proto) self._build_leaf_nodes(graph_proto)
self._build_polymeric_nodes() self._build_polymeric_nodes()
self._build_name_scope_nodes() self._build_name_scope_nodes()
self._calc_polymeric_input_output() self._update_polymeric_input_output()
logger.info("Build graph end, normal node count: %s, polymeric node " logger.info("Build graph end, normal node count: %s, polymeric node "
"count: %s.", len(self._normal_nodes), len(self._polymeric_nodes)) "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))
...@@ -90,9 +90,9 @@ class MSGraph(Graph): ...@@ -90,9 +90,9 @@ class MSGraph(Graph):
node_name = leaf_node_id_map_name[node_def.name] node_name = leaf_node_id_map_name[node_def.name]
node = self._leaf_nodes[node_name] node = self._leaf_nodes[node_name]
for input_def in node_def.input: for input_def in node_def.input:
edge_type = EdgeTypeEnum.data edge_type = EdgeTypeEnum.DATA.value
if input_def.type == "CONTROL_EDGE": if input_def.type == "CONTROL_EDGE":
edge_type = EdgeTypeEnum.control edge_type = EdgeTypeEnum.CONTROL.value
if const_nodes_map.get(input_def.name): if const_nodes_map.get(input_def.name):
const_node = copy.deepcopy(const_nodes_map[input_def.name]) const_node = copy.deepcopy(const_nodes_map[input_def.name])
...@@ -218,7 +218,7 @@ class MSGraph(Graph): ...@@ -218,7 +218,7 @@ class MSGraph(Graph):
node = Node(name=const.key, node_id=const_node_id) node = Node(name=const.key, node_id=const_node_id)
node.node_type = NodeTypeEnum.CONST.value node.node_type = NodeTypeEnum.CONST.value
node.update_attr({const.key: str(const.value)}) node.update_attr({const.key: str(const.value)})
if const.value.dtype == DataTypeEnum.DT_TENSOR: if const.value.dtype == DataTypeEnum.DT_TENSOR.value:
shape = [] shape = []
for dim in const.value.tensor_val.dims: for dim in const.value.tensor_val.dims:
shape.append(dim) shape.append(dim)
......
...@@ -172,7 +172,7 @@ class Node: ...@@ -172,7 +172,7 @@ class Node:
Args: Args:
polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name: polymeric_output (dict[str, dict): Format is {dst_node.polymeric_scope_name:
{'edge_type': EdgeTypeEnum.data}}). {'edge_type': EdgeTypeEnum.DATA.value}}).
""" """
self._polymeric_output.update(polymeric_output) self._polymeric_output.update(polymeric_output)
......
...@@ -19,11 +19,11 @@ Usage: ...@@ -19,11 +19,11 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import os import os
import json
import pytest import pytest
from .. import globals as gbl from .. import globals as gbl
from .....utils.tools import get_url from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes' BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes'
...@@ -33,12 +33,6 @@ class TestQueryNodes: ...@@ -33,12 +33,6 @@ class TestQueryNodes:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
...@@ -65,4 +59,5 @@ class TestQueryNodes: ...@@ -65,4 +59,5 @@ class TestQueryNodes:
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file) file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
...@@ -19,12 +19,11 @@ Usage: ...@@ -19,12 +19,11 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import os import os
import json
import pytest import pytest
from .. import globals as gbl from .. import globals as gbl
from .....utils.tools import get_url from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node' BASE_URL = '/v1/mindinsight/datavisual/graphs/single-node'
...@@ -34,12 +33,6 @@ class TestQuerySingleNode: ...@@ -34,12 +33,6 @@ class TestQuerySingleNode:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
...@@ -59,4 +52,5 @@ class TestQuerySingleNode: ...@@ -59,4 +52,5 @@ class TestQuerySingleNode:
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file) file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
...@@ -19,11 +19,11 @@ Usage: ...@@ -19,11 +19,11 @@ Usage:
pytest tests/st/func/datavisual pytest tests/st/func/datavisual
""" """
import os import os
import json
import pytest import pytest
from .. import globals as gbl from .. import globals as gbl
from .....utils.tools import get_url from .....utils.tools import get_url, compare_result_with_file
BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names' BASE_URL = '/v1/mindinsight/datavisual/graphs/nodes/names'
...@@ -33,12 +33,6 @@ class TestSearchNodes: ...@@ -33,12 +33,6 @@ class TestSearchNodes:
graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results') graph_results_dir = os.path.join(os.path.dirname(__file__), 'graph_results')
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.env_single @pytest.mark.env_single
@pytest.mark.platform_x86_cpu @pytest.mark.platform_x86_cpu
...@@ -59,4 +53,5 @@ class TestSearchNodes: ...@@ -59,4 +53,5 @@ class TestSearchNodes:
url = get_url(BASE_URL, params) url = get_url(BASE_URL, params)
response = client.get(url) response = client.get(url)
assert response.status_code == 200 assert response.status_code == 200
self.compare_result_with_file(response.get_json(), result_file) file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(response.get_json(), file_path)
...@@ -29,7 +29,7 @@ import pytest ...@@ -29,7 +29,7 @@ import pytest
from ..mock import MockLogger from ..mock import MockLogger
from ....utils.log_operations import LogOperations from ....utils.log_operations import LogOperations
from ....utils.tools import check_loading_done, delete_files_or_dirs from ....utils.tools import check_loading_done, delete_files_or_dirs, compare_result_with_file
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
...@@ -103,12 +103,6 @@ class TestGraphProcessor: ...@@ -103,12 +103,6 @@ class TestGraphProcessor:
# wait for loading done # wait for loading done
check_loading_done(self._mock_data_manager, time_limit=5) check_loading_done(self._mock_data_manager, time_limit=5)
def compare_result_with_file(self, result, filename):
"""Compare result with file which contain the expected results."""
with open(os.path.join(self.graph_results_dir, filename), 'r') as fp:
expected_results = json.load(fp)
assert result == expected_results
def test_get_nodes_with_not_exist_train_id(self, load_graph_record): def test_get_nodes_with_not_exist_train_id(self, load_graph_record):
"""Test getting nodes with not exist train id.""" """Test getting nodes with not exist train id."""
test_train_id = "not_exist_train_id" test_train_id = "not_exist_train_id"
...@@ -152,7 +146,9 @@ class TestGraphProcessor: ...@@ -152,7 +146,9 @@ class TestGraphProcessor:
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager) self._mock_data_manager)
results = graph_processor.get_nodes(name, node_type) results = graph_processor.get_nodes(name, node_type)
self.compare_result_with_file(results, result_file)
expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.parametrize("search_content, result_file", [ @pytest.mark.parametrize("search_content, result_file", [
(None, 'test_search_node_names_with_search_content_expected_results1.json'), (None, 'test_search_node_names_with_search_content_expected_results1.json'),
...@@ -175,7 +171,8 @@ class TestGraphProcessor: ...@@ -175,7 +171,8 @@ class TestGraphProcessor:
expected_results = {'names': []} expected_results = {'names': []}
assert results == expected_results assert results == expected_results
else: else:
self.compare_result_with_file(results, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
@pytest.mark.parametrize("offset", [-100, -1]) @pytest.mark.parametrize("offset", [-100, -1])
def test_search_node_names_with_negative_offset(self, load_graph_record, offset): def test_search_node_names_with_negative_offset(self, load_graph_record, offset):
...@@ -203,7 +200,8 @@ class TestGraphProcessor: ...@@ -203,7 +200,8 @@ class TestGraphProcessor:
results = graph_processor.search_node_names(test_search_content, results = graph_processor.search_node_names(test_search_content,
test_offset, test_offset,
test_limit) test_limit)
self.compare_result_with_file(results, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_node_names_with_wrong_limit(self, load_graph_record): def test_search_node_names_with_wrong_limit(self, load_graph_record):
"""Test search node names with wrong limit.""" """Test search node names with wrong limit."""
...@@ -227,7 +225,8 @@ class TestGraphProcessor: ...@@ -227,7 +225,8 @@ class TestGraphProcessor:
graph_processor = GraphProcessor(self._train_id, graph_processor = GraphProcessor(self._train_id,
self._mock_data_manager) self._mock_data_manager)
results = graph_processor.search_single_node(name) results = graph_processor.search_single_node(name)
self.compare_result_with_file(results, result_file) expected_file_path = os.path.join(self.graph_results_dir, result_file)
compare_result_with_file(results, expected_file_path)
def test_search_single_node_with_not_exist_name(self, load_graph_record): def test_search_single_node_with_not_exist_name(self, load_graph_record):
......
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,9 +19,13 @@ import io ...@@ -19,9 +19,13 @@ import io
import os import os
import shutil import shutil
import time import time
import json
from urllib.parse import urlencode from urllib.parse import urlencode
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from mindinsight.datavisual.common.enums import DataManagerStatus from mindinsight.datavisual.common.enums import DataManagerStatus
...@@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string): ...@@ -69,3 +73,10 @@ def get_image_tensor_from_bytes(image_string):
image_tensor = np.array(img) image_tensor = np.array(img)
return image_tensor return image_tensor
def compare_result_with_file(result, expected_file_path):
"""Compare result with file which contain the expected results."""
with open(expected_file_path, 'r') as file:
expected_results = json.load(file)
assert result == expected_results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册