diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py index eed73f18847b9810b70254cdef67678436511d14..613ec393db8041d434bf8a98dc0334dd660275e3 100644 --- a/mindinsight/datavisual/data_transform/ms_data_loader.py +++ b/mindinsight/datavisual/data_transform/ms_data_loader.py @@ -51,16 +51,13 @@ class MSDataLoader: """ def __init__(self, summary_dir): - self._init_instance(summary_dir) - - def _init_instance(self, summary_dir): self._summary_dir = summary_dir self._valid_filenames = [] self._events_data = EventsData() - self._latest_summary_filename = '' - self._latest_summary_file_size = 0 - self._summary_file_handler = None - self._pb_parser = _PbParser(summary_dir) + + self._parser_list = [] + self._parser_list.append(_SummaryParser(summary_dir)) + self._parser_list.append(_PbParser(summary_dir)) def get_events_data(self): """Return events data read from log file.""" @@ -78,7 +75,7 @@ class MSDataLoader: if deleted_files: logger.warning("There are some files has been deleted, " "we will reload all files in path %s.", self._summary_dir) - self._init_instance(self._summary_dir) + self.__init__(self._summary_dir) def load(self): """ @@ -95,42 +92,216 @@ class MSDataLoader: self._valid_filenames = filenames self._check_files_deleted(filenames, old_filenames) - self._load_summary_files(self._valid_filenames) - self._load_pb_files(self._valid_filenames) + for parser in self._parser_list: + parser.parse_files(filenames, events_data=self._events_data) + + def filter_valid_files(self): + """ + Gets a list of valid files from the given file path. + + Returns: + list[str], file name list. + + """ + filenames = [] + for filename in FileHandler.list_dir(self._summary_dir): + if FileHandler.is_file(FileHandler.join(self._summary_dir, filename)): + filenames.append(filename) + + valid_filenames = [] + for parser in self._parser_list: + valid_filenames.extend(parser.filter_files(filenames)) + + return list(set(valid_filenames)) + + +class _Parser: + """Parsed base class.""" + + def __init__(self, summary_dir): + self._latest_filename = '' + self._latest_mtime = 0 + self._summary_dir = summary_dir + + def parse_files(self, filenames, events_data): + """ + Load files and parse files content. + + Args: + filenames (list[str]): File name list. + events_data (EventsData): The container of event data. + """ + raise NotImplementedError + + def sort_files(self, filenames): + """Sort by modify time increments and filenames increments.""" + filenames = sorted(filenames, key=lambda file: ( + FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file)) + return filenames + + def filter_files(self, filenames): + """ + Gets a list of files that this parsing class can parse. + + Args: + filenames (list[str]): File name list, like [filename1, filename2]. + + Returns: + list[str], filename list. + """ + raise NotImplementedError + + def _set_latest_file(self, filename): + """ + Check if the file's modification time is newer than the last time it was loaded, and if so, set the time. + + Args: + filename (str): The file name that needs to be checked and set. + + Returns: + bool, Returns True if the file was modified earlier than the last time it was loaded, or False. + """ + mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime + if mtime < self._latest_mtime or \ + (mtime == self._latest_mtime and filename <= self._latest_filename): + return False + + self._latest_mtime = mtime + self._latest_filename = filename + + return True + + +class _PbParser(_Parser): + """This class is used to parse pb file.""" + + def parse_files(self, filenames, events_data): + pb_filenames = self.filter_files(filenames) + pb_filenames = self.sort_files(pb_filenames) + for filename in pb_filenames: + if not self._set_latest_file(filename): + continue + + try: + tensor_event = self._parse_pb_file(filename) + except UnknownError: + # Parse pb file failed, so return None. + continue + + events_data.add_tensor_event(tensor_event) + + def filter_files(self, filenames): + """ + Get a list of pb files. + + Args: + filenames (list[str]): File name list, like [filename1, filename2]. + + Returns: + list[str], filename list. + """ + return list(filter(lambda filename: re.search(r'\.pb$', filename), filenames)) + + def _parse_pb_file(self, filename): + """ + Parse pb file and write content to `EventsData`. + + Args: + filename (str): The file path of pb file. + + Returns: + TensorEvent, if load pb file and build graph success, will return tensor event, else return None. + """ + file_path = FileHandler.join(self._summary_dir, filename) + logger.info("Start to load graph from pb file, file path: %s.", file_path) + filehandler = FileHandler(file_path) + model_proto = anf_ir_pb2.ModelProto() + try: + model_proto.ParseFromString(filehandler.read()) + except ParseError: + logger.warning("The given file is not a valid pb file, file path: %s.", file_path) + return None + + graph = MSGraph() + + try: + graph.build_graph(model_proto.graph) + except Exception as ex: + # Normally, there are no exceptions, and it is only possible for users on the MindSpore side + # to dump other non-default graphs. + logger.error("Build graph failed, file path: %s.", file_path) + logger.exception(ex) + raise UnknownError(str(ex)) + + tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path), + step=0, + tag=filename, + plugin_name=PluginNameEnum.GRAPH.value, + value=graph, + filename=filename) + + logger.info("Build graph success, file path: %s.", file_path) + return tensor_event - def _load_summary_files(self, filenames): + +class _SummaryParser(_Parser): + """The summary file parser.""" + + def __init__(self, summary_dir): + super(_SummaryParser, self).__init__(summary_dir) + self._latest_file_size = 0 + self._summary_file_handler = None + self._events_data = None + + def parse_files(self, filenames, events_data): """ Load summary file and parse file content. Args: filenames (list[str]): File name list. + events_data (EventsData): The container of event data. """ - summary_files = self._filter_summary_files(filenames) - summary_files = self._sorted_summary_files(summary_files) + self._events_data = events_data + summary_files = self.filter_files(filenames) + summary_files = self.sort_files(summary_files) for filename in summary_files: - if self._latest_summary_filename and \ - (self._compare_summary_file(self._latest_summary_filename, filename)): + if self._latest_filename and \ + (self._compare_summary_file(self._latest_filename, filename)): continue file_path = FileHandler.join(self._summary_dir, filename) - if filename != self._latest_summary_filename: + if filename != self._latest_filename: self._summary_file_handler = FileHandler(file_path, 'rb') - self._latest_summary_filename = filename - self._latest_summary_file_size = 0 + self._latest_filename = filename + self._latest_file_size = 0 new_size = FileHandler.file_stat(file_path).size - if new_size == self._latest_summary_file_size: + if new_size == self._latest_file_size: continue - self._latest_summary_file_size = new_size + self._latest_file_size = new_size try: self._load_single_file(self._summary_file_handler) except UnknownError as ex: logger.warning("Parse summary file failed, detail: %r," "file path: %s.", str(ex), file_path) + def filter_files(self, filenames): + """ + Gets a list of summary files. + + Args: + filenames (list[str]): File name list, like [filename1, filename2]. + + Returns: + list[str], filename list. + """ + return list(filter( + lambda filename: (re.search(r'summary\.\d+', filename) + and not filename.endswith("_lineage")), filenames)) + def _load_single_file(self, file_handler): """ Load a log file data. @@ -224,7 +395,7 @@ class MSDataLoader: tag=tag, plugin_name=PluginNameEnum.SCALAR.value, value=value.scalar_value, - filename=self._latest_summary_filename) + filename=self._latest_filename) self._events_data.add_tensor_event(tensor_event) if value.HasField('image'): @@ -234,7 +405,7 @@ class MSDataLoader: tag=tag, plugin_name=PluginNameEnum.IMAGE.value, value=value.image, - filename=self._latest_summary_filename) + filename=self._latest_filename) self._events_data.add_tensor_event(tensor_event) if value.HasField('histogram'): @@ -250,7 +421,7 @@ class MSDataLoader: tag=tag, plugin_name=PluginNameEnum.HISTOGRAM.value, value=histogram_msg, - filename=self._latest_summary_filename) + filename=self._latest_filename) self._events_data.add_tensor_event(tensor_event) if event.HasField('graph_def'): @@ -259,54 +430,21 @@ class MSDataLoader: graph.build_graph(graph_proto) tensor_event = TensorEvent(wall_time=event.wall_time, step=event.step, - tag=self._latest_summary_filename, + tag=self._latest_filename, plugin_name=PluginNameEnum.GRAPH.value, value=graph, - filename=self._latest_summary_filename) + filename=self._latest_filename) try: graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value) except KeyError: graph_tags = [] - summary_tags = self._filter_summary_files(graph_tags) + summary_tags = self.filter_files(graph_tags) for tag in summary_tags: self._events_data.delete_tensor_event(tag) self._events_data.add_tensor_event(tensor_event) - def filter_valid_files(self): - """ - Gets a list of valid files from the given file path. - - Returns: - list[str], file name list. - - """ - filenames = [] - for filename in FileHandler.list_dir(self._summary_dir): - if FileHandler.is_file(FileHandler.join(self._summary_dir, filename)): - filenames.append(filename) - - valid_filenames = [] - valid_filenames.extend(self._filter_summary_files(filenames)) - valid_filenames.extend(self._filter_pb_files(filenames)) - return list(set(valid_filenames)) - - @staticmethod - def _filter_summary_files(filenames): - """ - Gets a list of summary files. - - Args: - filenames (list[str]): File name list, like [filename1, filename2]. - - Returns: - list[str], filename list. - """ - return list(filter( - lambda filename: (re.search(r'summary\.\d+', filename) - and not filename.endswith("_lineage")), filenames)) - @staticmethod def _compare_summary_file(current_file, dst_file): """ @@ -325,128 +463,9 @@ class MSDataLoader: return True return False - @staticmethod - def _sorted_summary_files(summary_files): + def sort_files(self, filenames): """Sort by creating time increments and filenames decrement.""" - filenames = sorted(summary_files, + filenames = sorted(filenames, key=lambda filename: (-int(re.search(r'summary\.(\d+)', filename)[1]), filename), reverse=True) return filenames - - @staticmethod - def _filter_pb_files(filenames): - """ - Get a list of pb files. - - Args: - filenames (list[str]): File name list, like [filename1, filename2]. - - Returns: - list[str], filename list. - """ - return list(filter(lambda filename: re.search(r'\.pb$', filename), filenames)) - - def _load_pb_files(self, filenames): - """ - Load and parse the pb files. - - Args: - filenames (list[str]): File name list, like [filename1, filename2]. - - Returns: - list[str], filename list. - """ - pb_filenames = self._filter_pb_files(filenames) - pb_filenames = self._pb_parser.sort_pb_files(pb_filenames) - for filename in pb_filenames: - tensor_event = self._pb_parser.parse_pb_file(filename) - if tensor_event is None: - continue - self._events_data.add_tensor_event(tensor_event) - - -class _PbParser: - """This class is used to parse pb file.""" - - def __init__(self, summary_dir): - self._latest_filename = '' - self._latest_mtime = 0 - self._summary_dir = summary_dir - - def parse_pb_file(self, filename): - """ - Parse single pb file. - - Args: - filename (str): The file path of pb file. - Returns: - TensorEvent, if load pb file and build graph success, will return tensor event, else return None. - """ - if not self._is_parse_pb_file(filename): - return None - - try: - tensor_event = self._parse_pb_file(filename) - return tensor_event - except UnknownError: - # Parse pb file failed, so return None. - return None - - def sort_pb_files(self, filenames): - """Sort by creating time increments and filenames increments.""" - filenames = sorted(filenames, key=lambda file: ( - FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file)) - return filenames - - def _is_parse_pb_file(self, filename): - """Determines whether the file should be loaded。""" - mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime - - if mtime < self._latest_mtime or \ - (mtime == self._latest_mtime and filename <= self._latest_filename): - return False - - self._latest_mtime = mtime - self._latest_filename = filename - return True - - def _parse_pb_file(self, filename): - """ - Parse pb file and write content to `EventsData`. - - Args: - filename (str): The file path of pb file. - - Returns: - TensorEvent, if load pb file and build graph success, will return tensor event, else return None. - """ - file_path = FileHandler.join(self._summary_dir, filename) - logger.info("Start to load graph from pb file, file path: %s.", file_path) - filehandler = FileHandler(file_path) - model_proto = anf_ir_pb2.ModelProto() - try: - model_proto.ParseFromString(filehandler.read()) - except ParseError: - logger.warning("The given file is not a valid pb file, file path: %s.", file_path) - return None - - graph = MSGraph() - - try: - graph.build_graph(model_proto.graph) - except Exception as ex: - # Normally, there are no exceptions, and it is only possible for users on the MindSpore side - # to dump other non-default graphs. - logger.error("Build graph failed, file path: %s.", file_path) - logger.exception(ex) - raise UnknownError(str(ex)) - - tensor_event = TensorEvent(wall_time=FileHandler.file_stat(file_path), - step=0, - tag=filename, - plugin_name=PluginNameEnum.GRAPH.value, - value=graph, - filename=filename) - - logger.info("Build graph success, file path: %s.", file_path) - return tensor_event diff --git a/tests/ut/datavisual/data_transform/test_ms_data_loader.py b/tests/ut/datavisual/data_transform/test_ms_data_loader.py index bd53c757a961cf61fce0d181ba469867efd57a66..c8530615420c52e6390ea3ab6856e054dddabcd2 100644 --- a/tests/ut/datavisual/data_transform/test_ms_data_loader.py +++ b/tests/ut/datavisual/data_transform/test_ms_data_loader.py @@ -87,7 +87,6 @@ class TestMsDataLoader: ms_loader._latest_summary_filename = 'summary.00' ms_loader.load() shutil.rmtree(summary_dir) - assert ms_loader._latest_summary_file_size == RECORD_LEN tag = ms_loader.get_events_data().list_tags_by_plugin('scalar') tensors = ms_loader.get_events_data().tensors(tag[0]) assert len(tensors) == 3 @@ -138,9 +137,11 @@ class TestPbParser: _summary_dir = '' def setup_method(self): + """Run before method.""" self._summary_dir = tempfile.mkdtemp() def teardown_method(self): + """Run after method.""" shutil.rmtree(self._summary_dir) def test_parse_pb_file(self): @@ -148,24 +149,24 @@ class TestPbParser: filename = 'ms_output.pb' create_graph_pb_file(output_dir=self._summary_dir, filename=filename) parser = _PbParser(self._summary_dir) - tensor_event = parser.parse_pb_file(filename) + tensor_event = parser._parse_pb_file(filename) assert isinstance(tensor_event, TensorEvent) - def test_is_parse_pb_file(self): - """Test parse an older file.""" + def test_set_latest_file(self): + """Test set latest file.""" filename = 'ms_output.pb' create_graph_pb_file(output_dir=self._summary_dir, filename=filename) parser = _PbParser(self._summary_dir) - result = parser._is_parse_pb_file(filename) - assert result + is_latest = parser._set_latest_file(filename) + assert is_latest filename = 'ms_output_older.pb' file_path = create_graph_pb_file(output_dir=self._summary_dir, filename=filename) atime = 1 mtime = 1 os.utime(file_path, (atime, mtime)) - result = parser._is_parse_pb_file(filename) - assert not result + is_latest = parser._set_latest_file(filename) + assert not is_latest def test_sort_pb_file_by_mtime(self): """Test sort pb files.""" @@ -174,7 +175,7 @@ class TestPbParser: create_graph_pb_file(output_dir=self._summary_dir, filename=file) parser = _PbParser(self._summary_dir) - sorted_filenames = parser.sort_pb_files(filenames) + sorted_filenames = parser.sort_files(filenames) assert filenames == sorted_filenames def test_sort_pb_file_by_filename(self): @@ -191,7 +192,7 @@ class TestPbParser: expected_filenames = ['bbb.pb', 'ccc.pb', 'aaa.pb'] parser = _PbParser(self._summary_dir) - sorted_filenames = parser.sort_pb_files(filenames) + sorted_filenames = parser.sort_files(filenames) assert expected_filenames == sorted_filenames