diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py index 613ec393db8041d434bf8a98dc0334dd660275e3..3d1544b95ea4306165a682f423629e07959e0712 100644 --- a/mindinsight/datavisual/data_transform/ms_data_loader.py +++ b/mindinsight/datavisual/data_transform/ms_data_loader.py @@ -386,48 +386,40 @@ class _SummaryParser(_Parser): Args: event (Event): Message event in summary proto, data read from file handler. """ + plugins = { + 'scalar_value': PluginNameEnum.SCALAR, + 'image': PluginNameEnum.IMAGE, + 'histogram': PluginNameEnum.HISTOGRAM, + } + if event.HasField('summary'): for value in event.summary.value: - if value.HasField('scalar_value'): - tag = '{}/{}'.format(value.tag, PluginNameEnum.SCALAR.value) - tensor_event = TensorEvent(wall_time=event.wall_time, - step=event.step, - tag=tag, - plugin_name=PluginNameEnum.SCALAR.value, - value=value.scalar_value, - filename=self._latest_filename) - self._events_data.add_tensor_event(tensor_event) + for plugin in plugins: + if not value.HasField(plugin): + continue + plugin_name_enum = plugins[plugin] + tensor_event_value = getattr(value, plugin) + + if plugin == 'histogram': + tensor_event_value = HistogramContainer(tensor_event_value) + # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT + # to avoid time-consuming re-sample process. 
+ if tensor_event_value.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT: + logger.warning('original_buckets_count exceeds ' + 'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') + continue - if value.HasField('image'): - tag = '{}/{}'.format(value.tag, PluginNameEnum.IMAGE.value) tensor_event = TensorEvent(wall_time=event.wall_time, step=event.step, - tag=tag, - plugin_name=PluginNameEnum.IMAGE.value, - value=value.image, + tag='{}/{}'.format(value.tag, plugin_name_enum.value), + plugin_name=plugin_name_enum.value, + value=tensor_event_value, filename=self._latest_filename) self._events_data.add_tensor_event(tensor_event) - if value.HasField('histogram'): - histogram_msg = HistogramContainer(value.histogram) - # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT - # to avoid time-consuming re-sample process. - if histogram_msg.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT: - logger.warning('original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') - else: - tag = '{}/{}'.format(value.tag, PluginNameEnum.HISTOGRAM.value) - tensor_event = TensorEvent(wall_time=event.wall_time, - step=event.step, - tag=tag, - plugin_name=PluginNameEnum.HISTOGRAM.value, - value=histogram_msg, - filename=self._latest_filename) - self._events_data.add_tensor_event(tensor_event) - - if event.HasField('graph_def'): - graph_proto = event.graph_def + elif event.HasField('graph_def'): graph = MSGraph() - graph.build_graph(graph_proto) + graph.build_graph(event.graph_def) tensor_event = TensorEvent(wall_time=event.wall_time, step=event.step, tag=self._latest_filename, @@ -439,6 +431,7 @@ class _SummaryParser(_Parser): graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value) except KeyError: graph_tags = [] + summary_tags = self.filter_files(graph_tags) for tag in summary_tags: self._events_data.delete_tensor_event(tag) diff --git 
a/mindinsight/datavisual/data_transform/summary_watcher.py b/mindinsight/datavisual/data_transform/summary_watcher.py index 8a4fff434d89f08fba774cb28475bfb7f0a54d97..6ecba170a217bf8ede1a8ec9231b457bca6efe7e 100644 --- a/mindinsight/datavisual/data_transform/summary_watcher.py +++ b/mindinsight/datavisual/data_transform/summary_watcher.py @@ -64,20 +64,12 @@ class SummaryWatcher: if self._contains_null_byte(summary_base_dir=summary_base_dir): return [] - if not os.path.exists(summary_base_dir): - logger.warning('Path of summary base directory not exists.') - return [] - - if not os.path.isdir(summary_base_dir): - logger.warning('Path of summary base directory is not a valid directory.') + relative_path = os.path.join('.', '') + if not self._is_valid_summary_directory(summary_base_dir, relative_path): return [] summary_dict = {} - - if not overall: - counter = Counter(max_count=self.MAX_SCAN_COUNT) - else: - counter = Counter() + counter = Counter(max_count=None if overall else self.MAX_SCAN_COUNT) try: entries = os.scandir(summary_base_dir) @@ -94,19 +86,13 @@ class SummaryWatcher: logger.info('Stop further scanning due to overall is False and ' 'number of scanned files exceeds upper limit.') break - relative_path = os.path.join('.', '') if entry.is_symlink(): pass elif entry.is_file(): self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry) elif entry.is_dir(): - full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) - try: - subdir_entries = os.scandir(full_path) - except PermissionError: - logger.warning('Path of %s under summary base directory is not accessible.', entry.name) - continue - self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter) + entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) + self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter) directories = [] for key, value in summary_dict.items(): @@ -130,18 
+116,24 @@ class SummaryWatcher: return directories - def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter): + def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter): """ Scan subdir entries. Args: summary_dict (dict): Temporary data structure to hold summary directory info. summary_base_dir (str): Path of summary base directory. + entry_path (str): Path of entry. entry_name (str): Name of entry. - subdir_entries(DirEntry): Directory entry instance. counter (Counter): An instance of CountLimiter. """ + try: + subdir_entries = os.scandir(entry_path) + except PermissionError: + logger.warning('Path of %s under summary base directory is not accessible.', entry_name) + return + for subdir_entry in subdir_entries: if len(summary_dict) == self.MAX_SUMMARY_DIR_COUNT: break @@ -189,8 +181,6 @@ class SummaryWatcher: """ summary_base_dir = os.path.realpath(summary_base_dir) summary_directory = os.path.realpath(os.path.join(summary_base_dir, relative_path)) - if summary_base_dir == summary_directory: - return True if not os.path.exists(summary_directory): logger.warning('Path of summary directory not exists.')