提交 22816869 编写于 作者: O ougongchang

Modify the MS data loader and abstract the parser class

上级 8301a7fb
......@@ -51,16 +51,13 @@ class MSDataLoader:
"""
def __init__(self, summary_dir):
self._init_instance(summary_dir)
def _init_instance(self, summary_dir):
self._summary_dir = summary_dir
self._valid_filenames = []
self._events_data = EventsData()
self._latest_summary_filename = ''
self._latest_summary_file_size = 0
self._summary_file_handler = None
self._pb_parser = _PbParser(summary_dir)
self._parser_list = []
self._parser_list.append(_SummaryParser(summary_dir))
self._parser_list.append(_PbParser(summary_dir))
def get_events_data(self):
"""Return events data read from log file."""
......@@ -78,7 +75,7 @@ class MSDataLoader:
if deleted_files:
logger.warning("There are some files has been deleted, "
"we will reload all files in path %s.", self._summary_dir)
self._init_instance(self._summary_dir)
self.__init__(self._summary_dir)
def load(self):
"""
......@@ -95,42 +92,216 @@ class MSDataLoader:
self._valid_filenames = filenames
self._check_files_deleted(filenames, old_filenames)
self._load_summary_files(self._valid_filenames)
self._load_pb_files(self._valid_filenames)
for parser in self._parser_list:
parser.parse_files(filenames, events_data=self._events_data)
def filter_valid_files(self):
"""
Gets a list of valid files from the given file path.
Returns:
list[str], file name list.
"""
filenames = []
for filename in FileHandler.list_dir(self._summary_dir):
if FileHandler.is_file(FileHandler.join(self._summary_dir, filename)):
filenames.append(filename)
valid_filenames = []
for parser in self._parser_list:
valid_filenames.extend(parser.filter_files(filenames))
return list(set(valid_filenames))
class _Parser:
"""Parsed base class."""
def __init__(self, summary_dir):
self._latest_filename = ''
self._latest_mtime = 0
self._summary_dir = summary_dir
def parse_files(self, filenames, events_data):
"""
Load files and parse files content.
Args:
filenames (list[str]): File name list.
events_data (EventsData): The container of event data.
"""
raise NotImplementedError
def sort_files(self, filenames):
"""Sort by modify time increments and filenames increments."""
filenames = sorted(filenames, key=lambda file: (
FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file))
return filenames
def filter_files(self, filenames):
"""
Gets a list of files that this parsing class can parse.
Args:
filenames (list[str]): File name list, like [filename1, filename2].
Returns:
list[str], filename list.
"""
raise NotImplementedError
def _set_latest_file(self, filename):
"""
Check if the file's modification time is newer than the last time it was loaded, and if so, set the time.
Args:
filename (str): The file name that needs to be checked and set.
Returns:
bool, Returns True if the file was modified earlier than the last time it was loaded, or False.
"""
mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime
if mtime < self._latest_mtime or \
(mtime == self._latest_mtime and filename <= self._latest_filename):
return False
self._latest_mtime = mtime
self._latest_filename = filename
return True
class _PbParser(_Parser):
    """This class is used to parse pb file."""

    def parse_files(self, filenames, events_data):
        """
        Load pb files and parse file content.

        Args:
            filenames (list[str]): File name list.
            events_data (EventsData): The container of event data.
        """
        pb_filenames = self.filter_files(filenames)
        pb_filenames = self.sort_files(pb_filenames)
        for filename in pb_filenames:
            # Skip files that are not newer than the latest one already parsed.
            if not self._set_latest_file(filename):
                continue

            try:
                tensor_event = self._parse_pb_file(filename)
            except UnknownError:
                # Parse pb file failed, skip this file and continue.
                continue

            # _parse_pb_file returns None when the file is not a valid pb file;
            # without this guard a None would be passed to add_tensor_event.
            if tensor_event is None:
                continue

            events_data.add_tensor_event(tensor_event)

    def filter_files(self, filenames):
        """
        Get a list of pb files.

        Args:
            filenames (list[str]): File name list, like [filename1, filename2].

        Returns:
            list[str], filename list.
        """
        return list(filter(lambda filename: re.search(r'\.pb$', filename), filenames))

    def _parse_pb_file(self, filename):
        """
        Parse pb file and write content to `EventsData`.

        Args:
            filename (str): The file path of pb file.

        Returns:
            TensorEvent, if load pb file and build graph success, will return tensor event,
                else return None.
        """
        file_path = FileHandler.join(self._summary_dir, filename)
        logger.info("Start to load graph from pb file, file path: %s.", file_path)
        filehandler = FileHandler(file_path)
        model_proto = anf_ir_pb2.ModelProto()
        try:
            model_proto.ParseFromString(filehandler.read())
        except ParseError:
            logger.warning("The given file is not a valid pb file, file path: %s.", file_path)
            return None

        graph = MSGraph()

        try:
            graph.build_graph(model_proto.graph)
        except Exception as ex:
            # Normally, there are no exceptions, and it is only possible for users on the MindSpore side
            # to dump other non-default graphs.
            logger.error("Build graph failed, file path: %s.", file_path)
            logger.exception(ex)
            # Chain the original exception so the root cause is preserved.
            raise UnknownError(str(ex)) from ex

        # NOTE(review): wall_time is assigned the whole file_stat() result here,
        # while other TensorEvent sites use a numeric wall time — confirm whether
        # this should be FileHandler.file_stat(file_path).mtime.
        tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path),
                                   step=0,
                                   tag=filename,
                                   plugin_name=PluginNameEnum.GRAPH.value,
                                   value=graph,
                                   filename=filename)
        logger.info("Build graph success, file path: %s.", file_path)
        return tensor_event
def _load_summary_files(self, filenames):
class _SummaryParser(_Parser):
"""The summary file parser."""
def __init__(self, summary_dir):
super(_SummaryParser, self).__init__(summary_dir)
self._latest_file_size = 0
self._summary_file_handler = None
self._events_data = None
def parse_files(self, filenames, events_data):
"""
Load summary file and parse file content.
Args:
filenames (list[str]): File name list.
events_data (EventsData): The container of event data.
"""
summary_files = self._filter_summary_files(filenames)
summary_files = self._sorted_summary_files(summary_files)
self._events_data = events_data
summary_files = self.filter_files(filenames)
summary_files = self.sort_files(summary_files)
for filename in summary_files:
if self._latest_summary_filename and \
(self._compare_summary_file(self._latest_summary_filename, filename)):
if self._latest_filename and \
(self._compare_summary_file(self._latest_filename, filename)):
continue
file_path = FileHandler.join(self._summary_dir, filename)
if filename != self._latest_summary_filename:
if filename != self._latest_filename:
self._summary_file_handler = FileHandler(file_path, 'rb')
self._latest_summary_filename = filename
self._latest_summary_file_size = 0
self._latest_filename = filename
self._latest_file_size = 0
new_size = FileHandler.file_stat(file_path).size
if new_size == self._latest_summary_file_size:
if new_size == self._latest_file_size:
continue
self._latest_summary_file_size = new_size
self._latest_file_size = new_size
try:
self._load_single_file(self._summary_file_handler)
except UnknownError as ex:
logger.warning("Parse summary file failed, detail: %r,"
"file path: %s.", str(ex), file_path)
def filter_files(self, filenames):
"""
Gets a list of summary files.
Args:
filenames (list[str]): File name list, like [filename1, filename2].
Returns:
list[str], filename list.
"""
return list(filter(
lambda filename: (re.search(r'summary\.\d+', filename)
and not filename.endswith("_lineage")), filenames))
def _load_single_file(self, file_handler):
"""
Load a log file data.
......@@ -224,7 +395,7 @@ class MSDataLoader:
tag=tag,
plugin_name=PluginNameEnum.SCALAR.value,
value=value.scalar_value,
filename=self._latest_summary_filename)
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
if value.HasField('image'):
......@@ -234,7 +405,7 @@ class MSDataLoader:
tag=tag,
plugin_name=PluginNameEnum.IMAGE.value,
value=value.image,
filename=self._latest_summary_filename)
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
if value.HasField('histogram'):
......@@ -250,7 +421,7 @@ class MSDataLoader:
tag=tag,
plugin_name=PluginNameEnum.HISTOGRAM.value,
value=histogram_msg,
filename=self._latest_summary_filename)
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
if event.HasField('graph_def'):
......@@ -259,54 +430,21 @@ class MSDataLoader:
graph.build_graph(graph_proto)
tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step,
tag=self._latest_summary_filename,
tag=self._latest_filename,
plugin_name=PluginNameEnum.GRAPH.value,
value=graph,
filename=self._latest_summary_filename)
filename=self._latest_filename)
try:
graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
except KeyError:
graph_tags = []
summary_tags = self._filter_summary_files(graph_tags)
summary_tags = self.filter_files(graph_tags)
for tag in summary_tags:
self._events_data.delete_tensor_event(tag)
self._events_data.add_tensor_event(tensor_event)
def filter_valid_files(self):
    """
    Gets a list of valid files from the given file path.

    Returns:
        list[str], file name list.
    """
    # Keep only regular files directly under the summary dir.
    filenames = [
        name for name in FileHandler.list_dir(self._summary_dir)
        if FileHandler.is_file(FileHandler.join(self._summary_dir, name))
    ]

    # Combine summary files and pb files; a set removes any duplicates.
    valid_filenames = (self._filter_summary_files(filenames)
                       + self._filter_pb_files(filenames))

    return list(set(valid_filenames))
@staticmethod
def _filter_summary_files(filenames):
"""
Gets a list of summary files.
Args:
filenames (list[str]): File name list, like [filename1, filename2].
Returns:
list[str], filename list.
"""
return list(filter(
lambda filename: (re.search(r'summary\.\d+', filename)
and not filename.endswith("_lineage")), filenames))
@staticmethod
def _compare_summary_file(current_file, dst_file):
"""
......@@ -325,128 +463,9 @@ class MSDataLoader:
return True
return False
@staticmethod
def _sorted_summary_files(summary_files):
def sort_files(self, filenames):
"""Sort by creating time increments and filenames decrement."""
filenames = sorted(summary_files,
filenames = sorted(filenames,
key=lambda filename: (-int(re.search(r'summary\.(\d+)', filename)[1]), filename),
reverse=True)
return filenames
@staticmethod
def _filter_pb_files(filenames):
"""
Get a list of pb files.
Args:
filenames (list[str]): File name list, like [filename1, filename2].
Returns:
list[str], filename list.
"""
return list(filter(lambda filename: re.search(r'\.pb$', filename), filenames))
def _load_pb_files(self, filenames):
    """
    Load and parse the pb files.

    Args:
        filenames (list[str]): File name list, like [filename1, filename2].

    Returns:
        list[str], filename list.
    """
    pb_filenames = self._pb_parser.sort_pb_files(self._filter_pb_files(filenames))
    for filename in pb_filenames:
        tensor_event = self._pb_parser.parse_pb_file(filename)
        # parse_pb_file returns None for files that were skipped or invalid.
        if tensor_event is not None:
            self._events_data.add_tensor_event(tensor_event)
class _PbParser:
    """This class is used to parse pb file."""

    def __init__(self, summary_dir):
        # Name and mtime of the newest pb file processed so far.
        self._latest_filename = ''
        self._latest_mtime = 0
        self._summary_dir = summary_dir

    def parse_pb_file(self, filename):
        """
        Parse single pb file.

        Args:
            filename (str): The file path of pb file.

        Returns:
            TensorEvent, if load pb file and build graph success, will return tensor event,
                else return None.
        """
        # Skip files that are not newer than the latest one already parsed.
        if not self._is_parse_pb_file(filename):
            return None
        try:
            tensor_event = self._parse_pb_file(filename)
            return tensor_event
        except UnknownError:
            # Parse pb file failed, so return None.
            return None

    def sort_pb_files(self, filenames):
        """Sort by creating time increments and filenames increments."""
        filenames = sorted(filenames, key=lambda file: (
            FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file))
        return filenames

    def _is_parse_pb_file(self, filename):
        """Determine whether the file should be loaded (newer than the latest parsed file)."""
        mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime
        # Reject files older than the latest, or equal-mtime files that do not
        # sort after the latest filename (stable tie-break).
        if mtime < self._latest_mtime or \
                (mtime == self._latest_mtime and filename <= self._latest_filename):
            return False

        self._latest_mtime = mtime
        self._latest_filename = filename
        return True

    def _parse_pb_file(self, filename):
        """
        Parse pb file and write content to `EventsData`.

        Args:
            filename (str): The file path of pb file.

        Returns:
            TensorEvent, if load pb file and build graph success, will return tensor event,
                else return None.
        """
        file_path = FileHandler.join(self._summary_dir, filename)
        logger.info("Start to load graph from pb file, file path: %s.", file_path)
        filehandler = FileHandler(file_path)
        model_proto = anf_ir_pb2.ModelProto()
        try:
            model_proto.ParseFromString(filehandler.read())
        except ParseError:
            logger.warning("The given file is not a valid pb file, file path: %s.", file_path)
            return None

        graph = MSGraph()
        try:
            graph.build_graph(model_proto.graph)
        except Exception as ex:
            # Normally, there are no exceptions, and it is only possible for users on the MindSpore side
            # to dump other non-default graphs.
            logger.error("Build graph failed, file path: %s.", file_path)
            logger.exception(ex)
            raise UnknownError(str(ex))

        # NOTE(review): wall_time receives the whole file_stat() result here,
        # while other TensorEvent sites use a numeric wall time — confirm whether
        # this should be `.mtime`.
        tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path),
                                   step=0,
                                   tag=filename,
                                   plugin_name=PluginNameEnum.GRAPH.value,
                                   value=graph,
                                   filename=filename)
        logger.info("Build graph success, file path: %s.", file_path)
        return tensor_event
......@@ -87,7 +87,6 @@ class TestMsDataLoader:
ms_loader._latest_summary_filename = 'summary.00'
ms_loader.load()
shutil.rmtree(summary_dir)
assert ms_loader._latest_summary_file_size == RECORD_LEN
tag = ms_loader.get_events_data().list_tags_by_plugin('scalar')
tensors = ms_loader.get_events_data().tensors(tag[0])
assert len(tensors) == 3
......@@ -138,9 +137,11 @@ class TestPbParser:
_summary_dir = ''
def setup_method(self):
"""Run before method."""
self._summary_dir = tempfile.mkdtemp()
def teardown_method(self):
"""Run after method."""
shutil.rmtree(self._summary_dir)
def test_parse_pb_file(self):
......@@ -148,24 +149,24 @@ class TestPbParser:
filename = 'ms_output.pb'
create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
parser = _PbParser(self._summary_dir)
tensor_event = parser.parse_pb_file(filename)
tensor_event = parser._parse_pb_file(filename)
assert isinstance(tensor_event, TensorEvent)
def test_is_parse_pb_file(self):
"""Test parse an older file."""
def test_set_latest_file(self):
"""Test set latest file."""
filename = 'ms_output.pb'
create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
parser = _PbParser(self._summary_dir)
result = parser._is_parse_pb_file(filename)
assert result
is_latest = parser._set_latest_file(filename)
assert is_latest
filename = 'ms_output_older.pb'
file_path = create_graph_pb_file(output_dir=self._summary_dir, filename=filename)
atime = 1
mtime = 1
os.utime(file_path, (atime, mtime))
result = parser._is_parse_pb_file(filename)
assert not result
is_latest = parser._set_latest_file(filename)
assert not is_latest
def test_sort_pb_file_by_mtime(self):
"""Test sort pb files."""
......@@ -174,7 +175,7 @@ class TestPbParser:
create_graph_pb_file(output_dir=self._summary_dir, filename=file)
parser = _PbParser(self._summary_dir)
sorted_filenames = parser.sort_pb_files(filenames)
sorted_filenames = parser.sort_files(filenames)
assert filenames == sorted_filenames
def test_sort_pb_file_by_filename(self):
......@@ -191,7 +192,7 @@ class TestPbParser:
expected_filenames = ['bbb.pb', 'ccc.pb', 'aaa.pb']
parser = _PbParser(self._summary_dir)
sorted_filenames = parser.sort_pb_files(filenames)
sorted_filenames = parser.sort_files(filenames)
assert expected_filenames == sorted_filenames
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册