提交 c60e9624 编写于 作者: O ougongchang

fixing can not find node exception

fix ut for changes
上级 ecab5e89
...@@ -76,8 +76,11 @@ class SummaryLogIsLoading(MindInsightException): ...@@ -76,8 +76,11 @@ class SummaryLogIsLoading(MindInsightException):
class NodeNotInGraphError(MindInsightException): class NodeNotInGraphError(MindInsightException):
"""Can not find node in graph error.""" """Can not find node in graph error."""
def __init__(self): def __init__(self, node_name, node_type=None):
error_msg = "Can not find node in graph by given node name." if node_type is not None:
error_msg = f"Can not find node in graph by the given node name. node name: {node_name}, type: {node_type}."
else:
error_msg = f"Can not find node in graph by the given node name. node name: {node_name}."
super(NodeNotInGraphError, self).__init__(DataVisualErrors.NODE_NOT_IN_GRAPH_ERROR, super(NodeNotInGraphError, self).__init__(DataVisualErrors.NODE_NOT_IN_GRAPH_ERROR,
error_msg, error_msg,
http_code=400) http_code=400)
......
...@@ -21,7 +21,7 @@ import time ...@@ -21,7 +21,7 @@ import time
from enum import Enum from enum import Enum
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
from .node import NodeTypeEnum from .node import NodeTypeEnum
from .node import Node from .node import Node
...@@ -151,7 +151,7 @@ class Graph: ...@@ -151,7 +151,7 @@ class Graph:
""" """
if node_name and self._polymeric_nodes.get(node_name) is None \ if node_name and self._polymeric_nodes.get(node_name) is None \
and self._normal_nodes.get(node_name) is None: and self._normal_nodes.get(node_name) is None:
raise exceptions.NodeNotInGraphError() raise NodeNotInGraphError(node_name=node_name)
response = {} response = {}
nodes = self.get_normal_nodes() nodes = self.get_normal_nodes()
......
...@@ -82,7 +82,7 @@ class MSGraph(Graph): ...@@ -82,7 +82,7 @@ class MSGraph(Graph):
self._calc_output() self._calc_output()
logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, " logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, "
"left node count: %s.", len(self._normal_nodes), len(self._node_groups), "leaf nodes count: %s.", len(self._normal_nodes), len(self._node_groups),
len(self._leaf_nodes)) len(self._leaf_nodes))
def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map): def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map):
......
...@@ -23,6 +23,7 @@ from mindinsight.datavisual.common.validation import Validation ...@@ -23,6 +23,7 @@ from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.base_processor import BaseProcessor from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
class GraphProcessor(BaseProcessor): class GraphProcessor(BaseProcessor):
...@@ -95,15 +96,15 @@ class GraphProcessor(BaseProcessor): ...@@ -95,15 +96,15 @@ class GraphProcessor(BaseProcessor):
'' % (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.POLYMERIC_SCOPE.value)) '' % (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.POLYMERIC_SCOPE.value))
if name and not self._graph.exist_node(name): if name and not self._graph.exist_node(name):
raise ParamValueError("The node name is not in graph.") raise NodeNotInGraphError(node_name=name, node_type=node_type)
nodes = [] nodes = []
if node_type == NodeTypeEnum.NAME_SCOPE.value: if node_type == NodeTypeEnum.NAME_SCOPE.value:
nodes = self._graph.get_normal_nodes(name) nodes = self._graph.get_normal_nodes(name)
if node_type == NodeTypeEnum.POLYMERIC_SCOPE.value: if node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
if not name: if not name:
raise ParamValueError('The node name "%s" not in graph, node type is %s.' % raise NodeNotInGraphError(node_name=name, node_type=node_type)
(name, node_type))
polymeric_scope_name = name polymeric_scope_name = name
nodes = self._graph.get_polymeric_nodes(polymeric_scope_name) nodes = self._graph.get_polymeric_nodes(polymeric_scope_name)
......
...@@ -27,6 +27,7 @@ import pytest ...@@ -27,6 +27,7 @@ import pytest
from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import GraphNotExistError from mindinsight.datavisual.common.exceptions import GraphNotExistError
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
from mindinsight.datavisual.data_transform import data_manager from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
...@@ -120,14 +121,11 @@ class TestGraphProcessor: ...@@ -120,14 +121,11 @@ class TestGraphProcessor:
@pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")]) @pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")])
def test_get_nodes_with_not_exist_name(self, name, node_type): def test_get_nodes_with_not_exist_name(self, name, node_type):
"""Test getting nodes with not exist name.""" """Test getting nodes with not exist name."""
with pytest.raises(ParamValueError) as exc_info: with pytest.raises(NodeNotInGraphError) as exc_info:
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager) graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
graph_processor.get_nodes(name, node_type) graph_processor.get_nodes(name, node_type)
if name: assert 'Can not find node in graph by the given node name' in exc_info.value.message
assert "The node name is not in graph." in exc_info.value.message
else:
assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message
@pytest.mark.usefixtures('load_graph_record') @pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册