提交 4f3d0793 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!167 optimize redundant funtion codes

Merge pull request !167 from liangyongxiong/redundant-codes
......@@ -386,48 +386,40 @@ class _SummaryParser(_Parser):
Args:
event (Event): Message event in summary proto, data read from file handler.
"""
plugins = {
'scalar_value': PluginNameEnum.SCALAR,
'image': PluginNameEnum.IMAGE,
'histogram': PluginNameEnum.HISTOGRAM,
}
if event.HasField('summary'):
for value in event.summary.value:
if value.HasField('scalar_value'):
tag = '{}/{}'.format(value.tag, PluginNameEnum.SCALAR.value)
tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step,
tag=tag,
plugin_name=PluginNameEnum.SCALAR.value,
value=value.scalar_value,
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
for plugin in plugins:
if not value.HasField(plugin):
continue
plugin_name_enum = plugins[plugin]
tensor_event_value = getattr(value, plugin)
if plugin == 'histogram':
tensor_event_value = HistogramContainer(tensor_event_value)
# Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT
# to avoid time-consuming re-sample process.
if tensor_event_value.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT:
logger.warning('original_buckets_count exceeds '
'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT')
continue
if value.HasField('image'):
tag = '{}/{}'.format(value.tag, PluginNameEnum.IMAGE.value)
tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step,
tag=tag,
plugin_name=PluginNameEnum.IMAGE.value,
value=value.image,
tag='{}/{}'.format(value.tag, plugin_name_enum.value),
plugin_name=plugin_name_enum.value,
value=tensor_event_value,
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
if value.HasField('histogram'):
histogram_msg = HistogramContainer(value.histogram)
# Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT
# to avoid time-consuming re-sample process.
if histogram_msg.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT:
logger.warning('original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT')
else:
tag = '{}/{}'.format(value.tag, PluginNameEnum.HISTOGRAM.value)
tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step,
tag=tag,
plugin_name=PluginNameEnum.HISTOGRAM.value,
value=histogram_msg,
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
if event.HasField('graph_def'):
graph_proto = event.graph_def
elif event.HasField('graph_def'):
graph = MSGraph()
graph.build_graph(graph_proto)
graph.build_graph(event.graph_def)
tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step,
tag=self._latest_filename,
......@@ -439,6 +431,7 @@ class _SummaryParser(_Parser):
graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
except KeyError:
graph_tags = []
summary_tags = self.filter_files(graph_tags)
for tag in summary_tags:
self._events_data.delete_tensor_event(tag)
......
......@@ -64,20 +64,12 @@ class SummaryWatcher:
if self._contains_null_byte(summary_base_dir=summary_base_dir):
return []
if not os.path.exists(summary_base_dir):
logger.warning('Path of summary base directory not exists.')
return []
if not os.path.isdir(summary_base_dir):
logger.warning('Path of summary base directory is not a valid directory.')
relative_path = os.path.join('.', '')
if not self._is_valid_summary_directory(summary_base_dir, relative_path):
return []
summary_dict = {}
if not overall:
counter = Counter(max_count=self.MAX_SCAN_COUNT)
else:
counter = Counter()
counter = Counter(max_count=None if overall else self.MAX_SCAN_COUNT)
try:
entries = os.scandir(summary_base_dir)
......@@ -94,19 +86,13 @@ class SummaryWatcher:
logger.info('Stop further scanning due to overall is False and '
'number of scanned files exceeds upper limit.')
break
relative_path = os.path.join('.', '')
if entry.is_symlink():
pass
elif entry.is_file():
self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry)
elif entry.is_dir():
full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
try:
subdir_entries = os.scandir(full_path)
except PermissionError:
logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
continue
self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter)
entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter)
directories = []
for key, value in summary_dict.items():
......@@ -130,18 +116,24 @@ class SummaryWatcher:
return directories
def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter):
def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter):
"""
Scan subdir entries.
Args:
summary_dict (dict): Temporary data structure to hold summary directory info.
summary_base_dir (str): Path of summary base directory.
entry_path(str): Path entry.
entry_name (str): Name of entry.
subdir_entries(DirEntry): Directory entry instance.
counter (Counter): An instance of CountLimiter.
"""
try:
subdir_entries = os.scandir(entry_path)
except PermissionError:
logger.warning('Path of %s under summary base directory is not accessible.', entry_name)
return
for subdir_entry in subdir_entries:
if len(summary_dict) == self.MAX_SUMMARY_DIR_COUNT:
break
......@@ -189,8 +181,6 @@ class SummaryWatcher:
"""
summary_base_dir = os.path.realpath(summary_base_dir)
summary_directory = os.path.realpath(os.path.join(summary_base_dir, relative_path))
if summary_base_dir == summary_directory:
return True
if not os.path.exists(summary_directory):
logger.warning('Path of summary directory not exists.')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册