diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py index afe42fe6b492d8b5d5346eedd7ed458680c22466..b3d87e34fb1561cc5e056f28d8f7c3eb60ffa86d 100644 --- a/mindinsight/datavisual/data_transform/ms_data_loader.py +++ b/mindinsight/datavisual/data_transform/ms_data_loader.py @@ -416,6 +416,44 @@ class _SummaryParser(_Parser): return event_str + @staticmethod + def _parse_summary_value(value, plugin): + """ + Parse summary value and create corresponding container according to plugin. + + Args: + value (Summary.Value): Value message in summary file. + plugin (str): Plugin value. + + Returns: + Union[Summary.Value, HistogramContainer, TensorContainer, ImageContainer, None], original summary value + or an instance of HistogramContainer or TensorContainer or ImageContainer, or None if the value is dropped. + """ + tensor_event_value = getattr(value, plugin) + if plugin == PluginNameEnum.HISTOGRAM.value: + tensor_event_value = HistogramContainer(tensor_event_value) + # Drop steps if original_buckets_count exceeds Histogram.MAX_ORIGINAL_BUCKETS_COUNT + # to avoid time-consuming re-sample process. 
+ if tensor_event_value.histogram.original_buckets_count > Histogram.MAX_ORIGINAL_BUCKETS_COUNT: + logger.info('original_buckets_count exceeds ' + 'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') + return None + + elif plugin == PluginNameEnum.TENSOR.value: + tensor_event_value = TensorContainer(tensor_event_value) + tensor_count = 1 + for d in tensor_event_value.dims: + tensor_count *= d + if tensor_count > MAX_TENSOR_COUNT: + logger.warning('tag: %s/tensor, dims: %s, tensor count: %d exceeds %d and drop it.', + value.tag, tensor_event_value.dims, tensor_count, MAX_TENSOR_COUNT) + return None + + elif plugin == PluginNameEnum.IMAGE.value: + tensor_event_value = ImageContainer(tensor_event_value) + + return tensor_event_value + @staticmethod def _event_parse(event_str, latest_file_name): """ @@ -424,7 +462,7 @@ class _SummaryParser(_Parser): This method is static to avoid sending unnecessary objects to other processes. Args: - event (str): Message event string in summary proto, data read from file handler. + event_str (str): Message event string in summary proto, data read from file handler. latest_file_name (str): Latest file name. """ @@ -445,30 +483,10 @@ class _SummaryParser(_Parser): if not value.HasField(plugin): continue plugin_name_enum = plugins[plugin] - tensor_event_value = getattr(value, plugin) logger.debug("Processing plugin value: %s.", plugin_name_enum) - - if plugin == PluginNameEnum.HISTOGRAM.value: - tensor_event_value = HistogramContainer(tensor_event_value) - # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT - # to avoid time-consuming re-sample process. 
- if tensor_event_value.histogram.original_buckets_count > Histogram.MAX_ORIGINAL_BUCKETS_COUNT: - logger.info('original_buckets_count exceeds ' - 'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') - continue - - elif plugin == PluginNameEnum.TENSOR.value: - tensor_event_value = TensorContainer(tensor_event_value) - tensor_count = 1 - for d in tensor_event_value.dims: - tensor_count *= d - if tensor_count > MAX_TENSOR_COUNT: - logger.warning('tag: %s/tensor, tensor count: %d exceeds %d and drop it.', - value.tag, tensor_count, MAX_TENSOR_COUNT) - continue - - elif plugin == PluginNameEnum.IMAGE.value: - tensor_event_value = ImageContainer(tensor_event_value) + tensor_event_value = _SummaryParser._parse_summary_value(value, plugin) + if tensor_event_value is None: + continue tensor_event = TensorEvent(wall_time=event.wall_time, step=event.step,