提交 c60e9624 编写于 作者: O ougongchang

fixing can not find node exception

fix ut for changes
上级 ecab5e89
......@@ -76,8 +76,11 @@ class SummaryLogIsLoading(MindInsightException):
class NodeNotInGraphError(MindInsightException):
"""Can not find node in graph error."""
def __init__(self):
error_msg = "Can not find node in graph by given node name."
def __init__(self, node_name, node_type=None):
if node_type is not None:
error_msg = f"Can not find node in graph by the given node name. node name: {node_name}, type: {node_type}."
else:
error_msg = f"Can not find node in graph by the given node name. node name: {node_name}."
super(NodeNotInGraphError, self).__init__(DataVisualErrors.NODE_NOT_IN_GRAPH_ERROR,
error_msg,
http_code=400)
......
......@@ -21,7 +21,7 @@ import time
from enum import Enum
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
from .node import NodeTypeEnum
from .node import Node
......@@ -151,7 +151,7 @@ class Graph:
"""
if node_name and self._polymeric_nodes.get(node_name) is None \
and self._normal_nodes.get(node_name) is None:
raise exceptions.NodeNotInGraphError()
raise NodeNotInGraphError(node_name=node_name)
response = {}
nodes = self.get_normal_nodes()
......
......@@ -82,7 +82,7 @@ class MSGraph(Graph):
self._calc_output()
logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, "
"left node count: %s.", len(self._normal_nodes), len(self._node_groups),
"leaf nodes count: %s.", len(self._normal_nodes), len(self._node_groups),
len(self._leaf_nodes))
def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map):
......
......@@ -23,6 +23,7 @@ from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
class GraphProcessor(BaseProcessor):
......@@ -95,15 +96,15 @@ class GraphProcessor(BaseProcessor):
'' % (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.POLYMERIC_SCOPE.value))
if name and not self._graph.exist_node(name):
raise ParamValueError("The node name is not in graph.")
raise NodeNotInGraphError(node_name=name, node_type=node_type)
nodes = []
if node_type == NodeTypeEnum.NAME_SCOPE.value:
nodes = self._graph.get_normal_nodes(name)
if node_type == NodeTypeEnum.POLYMERIC_SCOPE.value:
if not name:
raise ParamValueError('The node name "%s" not in graph, node type is %s.' %
(name, node_type))
raise NodeNotInGraphError(node_name=name, node_type=node_type)
polymeric_scope_name = name
nodes = self._graph.get_polymeric_nodes(polymeric_scope_name)
......
......@@ -27,6 +27,7 @@ import pytest
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import GraphNotExistError
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
......@@ -120,14 +121,11 @@ class TestGraphProcessor:
@pytest.mark.parametrize("name, node_type", [("not_exist_name", "name_scope"), ("", "polymeric_scope")])
def test_get_nodes_with_not_exist_name(self, name, node_type):
"""Test getting nodes with not exist name."""
with pytest.raises(ParamValueError) as exc_info:
with pytest.raises(NodeNotInGraphError) as exc_info:
graph_processor = GraphProcessor(self._train_id, self._mock_data_manager)
graph_processor.get_nodes(name, node_type)
if name:
assert "The node name is not in graph." in exc_info.value.message
else:
assert f'The node name "{name}" not in graph, node type is {node_type}.' in exc_info.value.message
assert 'Can not find node in graph by the given node name' in exc_info.value.message
@pytest.mark.usefixtures('load_graph_record')
@pytest.mark.parametrize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册