diff --git a/mindinsight/datavisual/data_transform/graph/msgraph.py b/mindinsight/datavisual/data_transform/graph/msgraph.py
index e2f01698437b7aa7ff4812dfad1555b39d952aba..0f17084d246614a9b0848a2c1f0641d9a77ed805 100644
--- a/mindinsight/datavisual/data_transform/graph/msgraph.py
+++ b/mindinsight/datavisual/data_transform/graph/msgraph.py
@@ -96,9 +96,16 @@ class MSGraph(Graph):
         """
         logger.debug("Start to calc input.")
         for node_def in graph_proto.node:
+            if not node_def.name:
+                logger.debug("The node name is empty, ignore it.")
+                continue
             node_name = leaf_node_id_map_name[node_def.name]
             node = self._leaf_nodes[node_name]
             for input_def in node_def.input:
+                if not input_def.name:
+                    logger.warning("The input node name is empty, ignore it. node name: %s.", node_name)
+                    continue
+
                 edge_type = EdgeTypeEnum.DATA.value
                 if input_def.type == "CONTROL_EDGE":
                     edge_type = EdgeTypeEnum.CONTROL.value
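Note on the msgraph.py change: the two guards above skip proto nodes and inputs whose name field is empty, so a malformed graph proto no longer raises KeyError while edges are being calculated. Below is a minimal, self-contained sketch of the same pattern using plain dataclasses instead of the real protobuf messages; FakeNode, FakeInput and calc_input are illustrative names only, not MindInsight APIs.

import logging
from dataclasses import dataclass, field
from typing import List

logger = logging.getLogger(__name__)


@dataclass
class FakeInput:
    name: str = ''
    type: str = ''


@dataclass
class FakeNode:
    name: str = ''
    input: List[FakeInput] = field(default_factory=list)


def calc_input(nodes, leaf_node_id_map_name):
    """Collect (src, dst, edge_type) tuples, skipping nodes and inputs with empty names."""
    edges = []
    for node_def in nodes:
        if not node_def.name:
            logger.debug("The node name is empty, ignore it.")
            continue
        node_name = leaf_node_id_map_name[node_def.name]
        for input_def in node_def.input:
            if not input_def.name:
                logger.warning("The input node name is empty, ignore it. node name: %s.", node_name)
                continue
            edge_type = 'control' if input_def.type == "CONTROL_EDGE" else 'data'
            edges.append((input_def.name, node_name, edge_type))
    return edges


# The unnamed node and the unnamed input are skipped instead of raising KeyError.
nodes = [FakeNode(name='1', input=[FakeInput(name='2'), FakeInput()]), FakeNode()]
print(calc_input(nodes, {'1': 'Conv2D-op1', '2': 'Parameter-op2'}))  # [('2', 'Conv2D-op1', 'data')]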
diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py
index 668616e5522547f62229ebe32a9ac93a49cf969d..c041357e09ab07881ec0443df4a6092119e15488 100644
--- a/mindinsight/datavisual/data_transform/ms_data_loader.py
+++ b/mindinsight/datavisual/data_transform/ms_data_loader.py
@@ -60,7 +60,7 @@ class MSDataLoader:
         self._latest_summary_filename = ''
         self._latest_summary_file_size = 0
         self._summary_file_handler = None
-        self._latest_pb_file_mtime = 0
+        self._pb_parser = _PbParser(summary_dir)
 
     def get_events_data(self):
         """Return events data read from log file."""
@@ -348,14 +348,58 @@ class MSDataLoader:
             list[str], filename list.
         """
         pb_filenames = self._filter_pb_files(filenames)
-        pb_filenames = sorted(pb_filenames, key=lambda file: FileHandler.file_stat(
-            FileHandler.join(self._summary_dir, file)).mtime)
+        pb_filenames = self._pb_parser.sort_pb_files(pb_filenames)
         for filename in pb_filenames:
-            mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime
-            if mtime <= self._latest_pb_file_mtime:
+            tensor_event = self._pb_parser.parse_pb_file(filename)
+            if tensor_event is None:
                 continue
-            self._latest_pb_file_mtime = mtime
-            self._parse_pb_file(filename)
+            self._events_data.add_tensor_event(tensor_event)
+
+
+class _PbParser:
+    """This class is used to parse pb files."""
+
+    def __init__(self, summary_dir):
+        self._latest_filename = ''
+        self._latest_mtime = 0
+        self._summary_dir = summary_dir
+
+    def parse_pb_file(self, filename):
+        """
+        Parse single pb file.
+
+        Args:
+            filename (str): The file path of pb file.
+        Returns:
+            TensorEvent, the tensor event built from the pb file if parsing and graph building succeed, otherwise None.
+        """
+        if not self._is_parse_pb_file(filename):
+            return None
+
+        try:
+            tensor_event = self._parse_pb_file(filename)
+            return tensor_event
+        except UnknownError:
+            # Parse pb file failed, so return None.
+            return None
+
+    def sort_pb_files(self, filenames):
+        """Sort pb files by modified time and file name, both in ascending order."""
+        filenames = sorted(filenames, key=lambda file: (
+            FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file))
+        return filenames
+
+    def _is_parse_pb_file(self, filename):
+        """Determine whether the pb file should be parsed."""
+        mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime
+
+        if mtime < self._latest_mtime or \
+                (mtime == self._latest_mtime and filename <= self._latest_filename):
+            return False
+
+        self._latest_mtime = mtime
+        self._latest_filename = filename
+        return True
 
     def _parse_pb_file(self, filename):
         """
@@ -363,6 +407,9 @@ class MSDataLoader:
 
         Args:
             filename (str): The file path of pb file.
+
+        Returns:
+            TensorEvent, the tensor event built from the pb file if parsing and graph building succeed, otherwise None.
         """
         file_path = FileHandler.join(self._summary_dir, filename)
         logger.info("Start to load graph from pb file, file path: %s.", file_path)
@@ -372,13 +419,24 @@ class MSDataLoader:
             model_proto.ParseFromString(filehandler.read())
         except ParseError:
             logger.warning("The given file is not a valid pb file, file path: %s.", file_path)
-            return
+            return None
 
         graph = MSGraph()
-        graph.build_graph(model_proto.graph)
+
+        try:
+            graph.build_graph(model_proto.graph)
+        except Exception as ex:
+            # Normally build_graph does not raise; this can only happen when users on the MindSpore side
+            # dump other, non-default graphs.
+            logger.error("Build graph failed, file path: %s.", file_path)
+            logger.exception(ex)
+            raise UnknownError(str(ex))
+
         tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path),
                                    step=0,
                                    tag=filename,
                                    plugin_name=PluginNameEnum.GRAPH.value,
                                    value=graph)
-        self._events_data.add_tensor_event(tensor_event)
+
+        logger.info("Build graph success, file path: %s.", file_path)
+        return tensor_event
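The behavioral core of the ms_data_loader.py change is that _PbParser keeps an (mtime, filename) pair instead of the old single-mtime watermark, so two pb files written within the same second are each parsed exactly once and no file is parsed twice. A standalone sketch of that bookkeeping, decoupled from FileHandler; IncrementalFileCursor and stat_mtime are illustrative stand-ins, not MindInsight APIs.

class IncrementalFileCursor:
    """Accept a file only if it sorts after the last accepted (mtime, filename) pair."""

    def __init__(self, stat_mtime):
        self._stat_mtime = stat_mtime  # callable: filename -> modified time
        self._latest_mtime = 0
        self._latest_filename = ''

    def should_parse(self, filename):
        mtime = self._stat_mtime(filename)
        if mtime < self._latest_mtime or \
                (mtime == self._latest_mtime and filename <= self._latest_filename):
            return False
        self._latest_mtime = mtime
        self._latest_filename = filename
        return True


# Two files share an mtime: both are accepted on the first pass, neither on a repeat.
mtimes = {'a.pb': 10, 'b.pb': 10, 'c.pb': 5}
cursor = IncrementalFileCursor(mtimes.get)
for name in sorted(mtimes, key=lambda f: (mtimes[f], f)):
    print(name, cursor.should_parse(name))   # c.pb True, a.pb True, b.pb True
print('a.pb', cursor.should_parse('a.pb'))   # False: already behind the cursor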
diff --git a/tests/ut/datavisual/data_transform/test_ms_data_loader.py b/tests/ut/datavisual/data_transform/test_ms_data_loader.py
index bcbe329c4e5d6794fbedcc85e498e3f14d849081..bd53c757a961cf61fce0d181ba469867efd57a66 100644
--- a/tests/ut/datavisual/data_transform/test_ms_data_loader.py
+++ b/tests/ut/datavisual/data_transform/test_ms_data_loader.py
@@ -27,8 +27,12 @@ import pytest
 
 from mindinsight.datavisual.data_transform import ms_data_loader
 from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
+from mindinsight.datavisual.data_transform.ms_data_loader import _PbParser
+from mindinsight.datavisual.data_transform.events_data import TensorEvent
+from mindinsight.datavisual.common.enums import PluginNameEnum
 
 from ..mock import MockLogger
+from ....utils.log_generators.graph_pb_generator import create_graph_pb_file
 
 # bytes of 3 scalar events
 SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*'
@@ -69,9 +73,9 @@ class TestMsDataLoader:
         summary_dir = tempfile.mkdtemp()
         ms_loader = MSDataLoader(summary_dir)
         ms_loader._check_files_deleted(new_file_list, old_file_list)
+        shutil.rmtree(summary_dir)
         assert MockLogger.log_msg['warning'] == "There are some files has been deleted, " \
                                                 "we will reload all files in path {}.".format(summary_dir)
-        shutil.rmtree(summary_dir)
 
     @pytest.mark.usefixtures('crc_pass')
     def test_load_success_with_crc_pass(self):
@@ -96,8 +100,8 @@ class TestMsDataLoader:
         write_file(file2, SCALAR_RECORD)
         ms_loader = MSDataLoader(summary_dir)
         ms_loader.load()
-        assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning'])
         shutil.rmtree(summary_dir)
+        assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning'])
 
     def test_filter_event_files(self):
         """Test filter_event_files function ok."""
@@ -112,9 +116,83 @@ class TestMsDataLoader:
         ms_loader = MSDataLoader(summary_dir)
         res = ms_loader.filter_valid_files()
         expected = sorted(['aaasummary.5678', 'summary.0012', 'hellosummary.98786', 'mysummary.123abce'])
+        shutil.rmtree(summary_dir)
         assert sorted(res) == expected
 
+    def test_load_single_pb_file(self):
+        """Test load pb file success."""
+        filename = 'ms_output.pb'
+        summary_dir = tempfile.mkdtemp()
+        create_graph_pb_file(output_dir=summary_dir, filename=filename)
+        ms_loader = MSDataLoader(summary_dir)
+        ms_loader.load()
+        events_data = ms_loader.get_events_data()
+        plugins = events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
         shutil.rmtree(summary_dir)
+        assert len(plugins) == 1
+        assert plugins[0] == filename
+
+
+class TestPbParser:
+    """Test pb parser"""
+    _summary_dir = ''
+
+    def setup_method(self):
+        self._summary_dir = tempfile.mkdtemp()
+
+    def teardown_method(self):
+        shutil.rmtree(self._summary_dir)
+
+    def test_parse_pb_file(self):
+        """Test parse pb file success."""
+        filename = 'ms_output.pb'
+        create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
+        parser = _PbParser(self._summary_dir)
+        tensor_event = parser.parse_pb_file(filename)
+        assert isinstance(tensor_event, TensorEvent)
+
+    def test_is_parse_pb_file(self):
+        """Test parse an older file."""
+        filename = 'ms_output.pb'
+        create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
+        parser = _PbParser(self._summary_dir)
+        result = parser._is_parse_pb_file(filename)
+        assert result
+
+        filename = 'ms_output_older.pb'
+        file_path = create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
+        atime = 1
+        mtime = 1
+        os.utime(file_path, (atime, mtime))
+        result = parser._is_parse_pb_file(filename)
+        assert not result
+
+    def test_sort_pb_file_by_mtime(self):
+        """Test sort pb files."""
+        filenames = ['abc.pb', 'bbc.pb']
+        for file in filenames:
+            create_graph_pb_file(output_dir=self._summary_dir, filename=file)
+        parser = _PbParser(self._summary_dir)
+
+        sorted_filenames = parser.sort_pb_files(filenames)
+        assert filenames == sorted_filenames
+
+    def test_sort_pb_file_by_filename(self):
+        """Test sort pb file by file name."""
+        filenames = ['aaa.pb', 'bbb.pb', 'ccc.pb']
+        for file in filenames:
+            create_graph_pb_file(output_dir=self._summary_dir, filename=file)
+
+        atime, mtime = (3, 3)
+        os.utime(os.path.realpath(os.path.join(self._summary_dir, 'aaa.pb')), (atime, mtime))
+        atime, mtime = (1, 1)
+        os.utime(os.path.realpath(os.path.join(self._summary_dir, 'bbb.pb')), (atime, mtime))
+        os.utime(os.path.realpath(os.path.join(self._summary_dir, 'ccc.pb')), (atime, mtime))
+
+        expected_filenames = ['bbb.pb', 'ccc.pb', 'aaa.pb']
+        parser = _PbParser(self._summary_dir)
+        sorted_filenames = parser.sort_pb_files(filenames)
+        assert expected_filenames == sorted_filenames
 
 
 def write_file(filename, record):
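The two sort tests above pin file mtimes with os.utime so the expected order does not depend on how quickly the pb files happen to be created. For reference, the same (mtime, filename) key can be reproduced with plain os.stat; the snippet below is a self-contained illustration with arbitrary file names in a temporary directory.

import os
import shutil
import tempfile

tmp_dir = tempfile.mkdtemp()
try:
    for name, mtime in (('aaa.pb', 3), ('bbb.pb', 1), ('ccc.pb', 1)):
        path = os.path.join(tmp_dir, name)
        with open(path, 'w'):
            pass
        os.utime(path, (mtime, mtime))  # (atime, mtime)

    def sort_key(filename):
        # Same shape of key as sort_pb_files: modified time first, then file name.
        return os.stat(os.path.join(tmp_dir, filename)).st_mtime, filename

    print(sorted(['aaa.pb', 'bbb.pb', 'ccc.pb'], key=sort_key))  # ['bbb.pb', 'ccc.pb', 'aaa.pb']
finally:
    shutil.rmtree(tmp_dir)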
diff --git a/tests/utils/log_generators/graph_pb_generator.py b/tests/utils/log_generators/graph_pb_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7019d87704d720bb71926575c7652ba3f75b9984
--- /dev/null
+++ b/tests/utils/log_generators/graph_pb_generator.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Log generator for graph pb file."""
+import os
+import json
+
+from google.protobuf import json_format
+
+from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
+
+
+def create_graph_pb_file(output_dir='./', filename='ms_output.pb'):
+    """Create graph pb file, and return file path."""
+    graph_base = os.path.join(os.path.dirname(__file__), "graph_base.json")
+    with open(graph_base, 'r') as fp:
+        data = json.load(fp)
+    model_def = dict(graph=data)
+    model_proto = json_format.Parse(json.dumps(model_def), anf_ir_pb2.ModelProto())
+    msg = model_proto.SerializeToString()
+    output_path = os.path.realpath(os.path.join(output_dir, filename))
+    with open(output_path, 'wb') as fp:
+        fp.write(msg)
+
+    return output_path
+
+
+if __name__ == '__main__':
+    create_graph_pb_file()
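For reference, the generator is exercised the same way test_load_single_pb_file does: write a pb file into a temporary directory and point MSDataLoader at it. The usage sketch below assumes it is run from the repository root with both the mindinsight and tests packages importable, and that the graph_base.json fixture read by create_graph_pb_file is present next to the generator.

import shutil
import tempfile

from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from tests.utils.log_generators.graph_pb_generator import create_graph_pb_file

summary_dir = tempfile.mkdtemp()
try:
    create_graph_pb_file(output_dir=summary_dir, filename='ms_output.pb')
    loader = MSDataLoader(summary_dir)
    loader.load()
    # The graph plugin should now expose exactly one tag, named after the pb file.
    print(loader.get_events_data().list_tags_by_plugin(PluginNameEnum.GRAPH.value))
finally:
    shutil.rmtree(summary_dir)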