提交 fd6bb5e4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!53 fix parsing pb file bug and abstracting pb parsing

Merge pull request !53 from ougongchang/fixbug_pb_file
...@@ -96,9 +96,16 @@ class MSGraph(Graph): ...@@ -96,9 +96,16 @@ class MSGraph(Graph):
""" """
logger.debug("Start to calc input.") logger.debug("Start to calc input.")
for node_def in graph_proto.node: for node_def in graph_proto.node:
if not node_def.name:
logger.debug("The node name is empty, ignore it.")
continue
node_name = leaf_node_id_map_name[node_def.name] node_name = leaf_node_id_map_name[node_def.name]
node = self._leaf_nodes[node_name] node = self._leaf_nodes[node_name]
for input_def in node_def.input: for input_def in node_def.input:
if not input_def.name:
logger.warning("The input node name is empty, ignore it. node name: %s.", node_name)
continue
edge_type = EdgeTypeEnum.DATA.value edge_type = EdgeTypeEnum.DATA.value
if input_def.type == "CONTROL_EDGE": if input_def.type == "CONTROL_EDGE":
edge_type = EdgeTypeEnum.CONTROL.value edge_type = EdgeTypeEnum.CONTROL.value
......
...@@ -60,7 +60,7 @@ class MSDataLoader: ...@@ -60,7 +60,7 @@ class MSDataLoader:
self._latest_summary_filename = '' self._latest_summary_filename = ''
self._latest_summary_file_size = 0 self._latest_summary_file_size = 0
self._summary_file_handler = None self._summary_file_handler = None
self._latest_pb_file_mtime = 0 self._pb_parser = _PbParser(summary_dir)
def get_events_data(self): def get_events_data(self):
"""Return events data read from log file.""" """Return events data read from log file."""
...@@ -348,14 +348,58 @@ class MSDataLoader: ...@@ -348,14 +348,58 @@ class MSDataLoader:
list[str], filename list. list[str], filename list.
""" """
pb_filenames = self._filter_pb_files(filenames) pb_filenames = self._filter_pb_files(filenames)
pb_filenames = sorted(pb_filenames, key=lambda file: FileHandler.file_stat( pb_filenames = self._pb_parser.sort_pb_files(pb_filenames)
FileHandler.join(self._summary_dir, file)).mtime)
for filename in pb_filenames: for filename in pb_filenames:
mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime tensor_event = self._pb_parser.parse_pb_file(filename)
if mtime <= self._latest_pb_file_mtime: if tensor_event is None:
continue continue
self._latest_pb_file_mtime = mtime self._events_data.add_tensor_event(tensor_event)
self._parse_pb_file(filename)
class _PbParser:
"""This class is used to parse pb file."""
def __init__(self, summary_dir):
self._latest_filename = ''
self._latest_mtime = 0
self._summary_dir = summary_dir
def parse_pb_file(self, filename):
    """
    Parse a single pb file.

    Args:
        filename (str): The file path of pb file.

    Returns:
        TensorEvent, the parsed tensor event when loading the pb file and
        building the graph succeed, otherwise None.
    """
    if not self._is_parse_pb_file(filename):
        return None
    try:
        return self._parse_pb_file(filename)
    except UnknownError:
        # Parsing the pb file failed; report nothing for this file.
        return None
def sort_pb_files(self, filenames):
    """Sort pb files by modification time ascending, then by filename ascending."""
    def sort_key(name):
        # Stat each file under the summary dir; ties on mtime fall back to the name.
        file_path = FileHandler.join(self._summary_dir, name)
        return FileHandler.file_stat(file_path).mtime, name

    return sorted(filenames, key=sort_key)
def _is_parse_pb_file(self, filename):
    """
    Determine whether the file should be loaded.

    A file is loaded only when it is strictly newer than the last loaded
    file under the (mtime, filename) ordering; the latest marker is then
    advanced to this file.
    """
    mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime
    # Tuple comparison: newer mtime wins; equal mtimes compare by filename.
    if (mtime, filename) <= (self._latest_mtime, self._latest_filename):
        return False
    self._latest_mtime = mtime
    self._latest_filename = filename
    return True
def _parse_pb_file(self, filename): def _parse_pb_file(self, filename):
""" """
...@@ -363,6 +407,9 @@ class MSDataLoader: ...@@ -363,6 +407,9 @@ class MSDataLoader:
Args: Args:
filename (str): The file path of pb file. filename (str): The file path of pb file.
Returns:
TensorEvent, if load pb file and build graph success, will return tensor event, else return None.
""" """
file_path = FileHandler.join(self._summary_dir, filename) file_path = FileHandler.join(self._summary_dir, filename)
logger.info("Start to load graph from pb file, file path: %s.", file_path) logger.info("Start to load graph from pb file, file path: %s.", file_path)
...@@ -372,13 +419,24 @@ class MSDataLoader: ...@@ -372,13 +419,24 @@ class MSDataLoader:
model_proto.ParseFromString(filehandler.read()) model_proto.ParseFromString(filehandler.read())
except ParseError: except ParseError:
logger.warning("The given file is not a valid pb file, file path: %s.", file_path) logger.warning("The given file is not a valid pb file, file path: %s.", file_path)
return return None
graph = MSGraph() graph = MSGraph()
try:
graph.build_graph(model_proto.graph) graph.build_graph(model_proto.graph)
except Exception as ex:
# Normally, there are no exceptions, and it is only possible for users on the MindSpore side
# to dump other non-default graphs.
logger.error("Build graph failed, file path: %s.", file_path)
logger.exception(ex)
raise UnknownError(str(ex))
tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path), tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path),
step=0, step=0,
tag=filename, tag=filename,
plugin_name=PluginNameEnum.GRAPH.value, plugin_name=PluginNameEnum.GRAPH.value,
value=graph) value=graph)
self._events_data.add_tensor_event(tensor_event)
logger.info("Build graph success, file path: %s.", file_path)
return tensor_event
...@@ -27,8 +27,12 @@ import pytest ...@@ -27,8 +27,12 @@ import pytest
from mindinsight.datavisual.data_transform import ms_data_loader from mindinsight.datavisual.data_transform import ms_data_loader
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
from mindinsight.datavisual.data_transform.ms_data_loader import _PbParser
from mindinsight.datavisual.data_transform.events_data import TensorEvent
from mindinsight.datavisual.common.enums import PluginNameEnum
from ..mock import MockLogger from ..mock import MockLogger
from ....utils.log_generators.graph_pb_generator import create_graph_pb_file
# bytes of 3 scalar events # bytes of 3 scalar events
SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*' SCALAR_RECORD = (b'\x1e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x96\xe1\xeb)>}\xd7A\x10\x01*'
...@@ -69,9 +73,9 @@ class TestMsDataLoader: ...@@ -69,9 +73,9 @@ class TestMsDataLoader:
summary_dir = tempfile.mkdtemp() summary_dir = tempfile.mkdtemp()
ms_loader = MSDataLoader(summary_dir) ms_loader = MSDataLoader(summary_dir)
ms_loader._check_files_deleted(new_file_list, old_file_list) ms_loader._check_files_deleted(new_file_list, old_file_list)
shutil.rmtree(summary_dir)
assert MockLogger.log_msg['warning'] == "There are some files has been deleted, " \ assert MockLogger.log_msg['warning'] == "There are some files has been deleted, " \
"we will reload all files in path {}.".format(summary_dir) "we will reload all files in path {}.".format(summary_dir)
shutil.rmtree(summary_dir)
@pytest.mark.usefixtures('crc_pass') @pytest.mark.usefixtures('crc_pass')
def test_load_success_with_crc_pass(self): def test_load_success_with_crc_pass(self):
...@@ -96,8 +100,8 @@ class TestMsDataLoader: ...@@ -96,8 +100,8 @@ class TestMsDataLoader:
write_file(file2, SCALAR_RECORD) write_file(file2, SCALAR_RECORD)
ms_loader = MSDataLoader(summary_dir) ms_loader = MSDataLoader(summary_dir)
ms_loader.load() ms_loader.load()
assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning'])
shutil.rmtree(summary_dir) shutil.rmtree(summary_dir)
assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning'])
def test_filter_event_files(self): def test_filter_event_files(self):
"""Test filter_event_files function ok.""" """Test filter_event_files function ok."""
...@@ -112,9 +116,83 @@ class TestMsDataLoader: ...@@ -112,9 +116,83 @@ class TestMsDataLoader:
ms_loader = MSDataLoader(summary_dir) ms_loader = MSDataLoader(summary_dir)
res = ms_loader.filter_valid_files() res = ms_loader.filter_valid_files()
expected = sorted(['aaasummary.5678', 'summary.0012', 'hellosummary.98786', 'mysummary.123abce']) expected = sorted(['aaasummary.5678', 'summary.0012', 'hellosummary.98786', 'mysummary.123abce'])
shutil.rmtree(summary_dir)
assert sorted(res) == expected assert sorted(res) == expected
def test_load_single_pb_file(self):
    """Test load pb file success."""
    filename = 'ms_output.pb'
    summary_dir = tempfile.mkdtemp()
    create_graph_pb_file(output_dir=summary_dir, filename=filename)

    loader = MSDataLoader(summary_dir)
    loader.load()
    graph_tags = loader.get_events_data().list_tags_by_plugin(PluginNameEnum.GRAPH.value)
    shutil.rmtree(summary_dir)

    # Exactly one graph should be registered, tagged with the pb file name.
    assert len(graph_tags) == 1
    assert graph_tags[0] == filename
class TestPbParser:
    """Tests for the _PbParser class."""

    _summary_dir = ''

    def setup_method(self):
        """Create a fresh summary directory for each test."""
        self._summary_dir = tempfile.mkdtemp()

    def teardown_method(self):
        """Remove the summary directory created in setup_method."""
        shutil.rmtree(self._summary_dir)

    def test_parse_pb_file(self):
        """Test parse pb file success."""
        filename = 'ms_output.pb'
        create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
        tensor_event = _PbParser(self._summary_dir).parse_pb_file(filename)
        assert isinstance(tensor_event, TensorEvent)

    def test_is_parse_pb_file(self):
        """Test parse an older file."""
        create_graph_pb_file(output_dir=self._summary_dir, filename='ms_output.pb')
        parser = _PbParser(self._summary_dir)
        assert parser._is_parse_pb_file('ms_output.pb')

        # A file with an older mtime than the latest parsed one must be skipped.
        older_name = 'ms_output_older.pb'
        older_path = create_graph_pb_file(output_dir=self._summary_dir, filename=older_name)
        os.utime(older_path, (1, 1))
        assert not parser._is_parse_pb_file(older_name)

    def test_sort_pb_file_by_mtime(self):
        """Test sort pb files."""
        filenames = ['abc.pb', 'bbc.pb']
        for name in filenames:
            create_graph_pb_file(output_dir=self._summary_dir, filename=name)
        sorted_filenames = _PbParser(self._summary_dir).sort_pb_files(filenames)
        assert filenames == sorted_filenames

    def test_sort_pb_file_by_filename(self):
        """Test sort pb file by file name."""
        filenames = ['aaa.pb', 'bbb.pb', 'ccc.pb']
        for name in filenames:
            create_graph_pb_file(output_dir=self._summary_dir, filename=name)
        # Make 'aaa.pb' the newest; 'bbb.pb' and 'ccc.pb' share an older mtime,
        # so they must be ordered between themselves by filename.
        os.utime(os.path.realpath(os.path.join(self._summary_dir, 'aaa.pb')), (3, 3))
        os.utime(os.path.realpath(os.path.join(self._summary_dir, 'bbb.pb')), (1, 1))
        os.utime(os.path.realpath(os.path.join(self._summary_dir, 'ccc.pb')), (1, 1))

        sorted_filenames = _PbParser(self._summary_dir).sort_pb_files(filenames)
        assert ['bbb.pb', 'ccc.pb', 'aaa.pb'] == sorted_filenames
def write_file(filename, record): def write_file(filename, record):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for graph pb file."""
import os
import json
from google.protobuf import json_format
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
def create_graph_pb_file(output_dir='./', filename='ms_output.pb'):
    """Create graph pb file, and return file path."""
    # The graph definition ships next to this module as JSON.
    base_json_path = os.path.join(os.path.dirname(__file__), "graph_base.json")
    with open(base_json_path, 'r') as json_file:
        graph_data = json.load(json_file)

    # Wrap the graph in a ModelProto and serialize it to bytes.
    model_proto = json_format.Parse(json.dumps({'graph': graph_data}), anf_ir_pb2.ModelProto())
    serialized = model_proto.SerializeToString()

    output_path = os.path.realpath(os.path.join(output_dir, filename))
    with open(output_path, 'wb') as pb_file:
        pb_file.write(serialized)
    return output_path
# Allow generating the pb file directly from the command line for manual checks.
if __name__ == '__main__':
    create_graph_pb_file()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册