From 902972e3eb4d81b7f2516c12da92d2599c851cd8 Mon Sep 17 00:00:00 2001 From: wenkai Date: Tue, 21 Jul 2020 17:40:03 +0800 Subject: [PATCH] fix sometimes deserialized protobuf data cannot be pickled to be sent to another process. 1. delete original message in HistogramContainer 2. Wrap image content in ImageContainer --- .../data_transform/histogram_container.py | 3 +- .../data_transform/image_container.py | 30 +++++++++++++++++++ .../data_transform/ms_data_loader.py | 7 +++-- 3 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 mindinsight/datavisual/data_transform/image_container.py diff --git a/mindinsight/datavisual/data_transform/histogram_container.py b/mindinsight/datavisual/data_transform/histogram_container.py index e60fe86..c066110 100644 --- a/mindinsight/datavisual/data_transform/histogram_container.py +++ b/mindinsight/datavisual/data_transform/histogram_container.py @@ -26,8 +26,7 @@ class HistogramContainer: """ def __init__(self, histogram_message: Summary.Histogram): - self._msg = histogram_message - original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets] + original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in histogram_message.buckets] # Ensure buckets are sorted from min to max. original_buckets.sort(key=lambda bucket: bucket.left) self._count = sum(bucket.count for bucket in original_buckets) diff --git a/mindinsight/datavisual/data_transform/image_container.py b/mindinsight/datavisual/data_transform/image_container.py new file mode 100644 index 0000000..4a09565 --- /dev/null +++ b/mindinsight/datavisual/data_transform/image_container.py @@ -0,0 +1,30 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Image container.""" +from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary + + +class ImageContainer: + """ + Container for image to allow pickling. + + Args: + image_message (Summary.Image): Image proto buffer message. + """ + def __init__(self, image_message: Summary.Image): + self.height = image_message.height + self.width = image_message.width + self.colorspace = image_message.colorspace + self.encoded_image = image_message.encoded_image diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py index e7c3509..2ac1a40 100644 --- a/mindinsight/datavisual/data_transform/ms_data_loader.py +++ b/mindinsight/datavisual/data_transform/ms_data_loader.py @@ -39,6 +39,7 @@ from mindinsight.datavisual.data_transform.events_data import TensorEvent from mindinsight.datavisual.data_transform.graph import MSGraph from mindinsight.datavisual.data_transform.histogram import Histogram from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer +from mindinsight.datavisual.data_transform.image_container import ImageContainer from mindinsight.datavisual.data_transform.tensor_container import TensorContainer from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 @@ -462,7 +463,7 @@ class _SummaryParser(_Parser): tensor_event_value = getattr(value, plugin) logger.debug("Processing plugin value: %s.", plugin_name_enum) - if plugin == 'histogram': + if plugin == PluginNameEnum.HISTOGRAM.value: tensor_event_value = HistogramContainer(tensor_event_value) # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT # to avoid time-consuming re-sample process. @@ -470,8 +471,10 @@ class _SummaryParser(_Parser): logger.info('original_buckets_count exceeds ' 'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') continue - elif plugin == 'tensor': + elif plugin == PluginNameEnum.TENSOR.value: tensor_event_value = TensorContainer(tensor_event_value) + elif plugin == PluginNameEnum.IMAGE.value: + tensor_event_value = ImageContainer(tensor_event_value) tensor_event = TensorEvent(wall_time=event.wall_time, step=event.step, -- GitLab