@BLUEPRINT.route("/datavisual/tensors", methods=["GET"])
def get_tensors():
    """
    Interface to obtain tensor data.

    Query parameters read from the request:
        train_id: May be given several times; every value is collected into a list.
        tag: May be given several times; every value is collected into a list.
        step: Optional single value, None when absent.
        dims: Optional single value, None when absent.
        detail: Optional single value, None when absent.

    All values are forwarded to `TensorProcessor.get_tensors` exactly as received
    (strings or None); no validation is performed in this view.

    Returns:
        Response, which contains a JSON object.
    """
    train_ids = request.args.getlist('train_id')
    tags = request.args.getlist('tag')
    step = request.args.get("step", default=None)
    dims = request.args.get("dims", default=None)
    detail = request.args.get("detail", default=None)

    processor = TensorProcessor(DATA_MANAGER)
    response = processor.get_tensors(train_ids, tags, step, dims, detail)
    return jsonify(response)
class TensorNotExistError(MindInsightException):
    """Unable to get tensor values based on a given condition."""
    def __init__(self, error_detail):
        # The message text is part of the API response and is kept verbatim.
        error_msg = f'Tensor value is not exist. Detail: {error_detail}'
        super().__init__(DataVisualErrors.TENSOR_NOT_EXIST,
                         error_msg,
                         http_code=400)


class StepTensorDataNotInCacheError(MindInsightException):
    """Tensor data of the requested step is not in cache."""
    def __init__(self, error_detail):
        error_msg = f'Tensor data not in cache. Detail: {error_detail}'
        super().__init__(DataVisualErrors.STEP_TENSOR_DATA_NOT_IN_CACHE,
                         error_msg,
                         http_code=400)


class ResponseDataExceedMaxValueError(MindInsightException):
    """Response data exceeds the max value allowed for a given condition."""
    def __init__(self, error_detail):
        error_msg = f'Response data exceed max value. Detail: {error_detail}'
        super().__init__(DataVisualErrors.MAX_RESPONSE_DATA_EXCEEDED_ERROR,
                         error_msg,
                         http_code=400)
def mask_invalid_number(num):
    """
    Mask invalid number to 0.

    Args:
        num (Union[int, float]): Number to check.

    Returns:
        Zero of the same type as `num` if `num` is NaN or infinite, otherwise `num` itself.
    """
    if math.isnan(num) or math.isinf(num):
        return type(num)(0)

    return num


class Bucket:
    """
    Bucket data class.

    Args:
        left (double): Left edge of the histogram bucket.
        width (double): Width of the histogram bucket.
        count (int): Count of numbers fallen in the histogram bucket.
    """
    def __init__(self, left, width, count):
        self._left = left
        self._width = width
        self._count = count

    @property
    def left(self):
        """Gets left edge of the histogram bucket."""
        return self._left

    @property
    def count(self):
        """Gets count of numbers fallen in the histogram bucket."""
        return self._count

    @property
    def width(self):
        """Gets width of the histogram bucket."""
        return self._width

    @property
    def right(self):
        """Gets right edge of the histogram bucket."""
        return self._left + self._width

    def as_tuple(self):
        """Gets the bucket as tuple of (left, width, count)."""
        return self._left, self._width, self._count

    def __repr__(self):
        """Returns repr(self)."""
        return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count)


class Histogram:
    """
    Histogram data class.

    Holds the original buckets and lazily re-samples them into equally wide "visual" buckets.

    Args:
        buckets (tuple[Bucket]): Original buckets. Re-sampling walks them once from first to
            last, so they should be ordered by left edge.
        max_val (float): Initial upper bound of the visual range.
        min_val (float): Initial lower bound of the visual range.
        count (int): Total count of values; also used to choose the default bin number.
    """

    # Max quantity of original buckets.
    MAX_ORIGINAL_BUCKETS_COUNT = 90

    def __init__(self, buckets, max_val, min_val, count):
        self._visual_max = max_val
        self._visual_min = min_val
        self._count = count
        self._original_buckets = buckets
        # default bin number
        self._visual_bins = calc_histogram_bins(count)
        # Note that tuple is immutable, so sharing tuple is often safe.
        self._re_sampled_buckets = ()

    @property
    def original_buckets_count(self):
        """Gets original buckets quantity."""
        return len(self._original_buckets)

    def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
        """
        Sets visual range for later re-sampling.

        It's caller's duty to ensure input is valid.

        Why do we need a visual range for histograms? Aligned buckets between steps help users see
        the trend of a tensor. Misaligned buckets between steps might mislead users: for a given
        tensor, thinner buckets give a lower count per bucket and thicker buckets give a higher
        count per bucket. When they are displayed together, the user might think the histogram with
        thicker buckets has more values. So buckets need to be unified across steps, and the visual
        range is the mechanism for unifying them.

        Args:
            max_val (float): Max value for visual histogram.
            min_val (float): Min value for visual histogram.
            bins (int): Bins number for visual histogram.

        Raises:
            ParamValueError: If max_val is less than min_val, or bins is less than 1.
        """
        if max_val < min_val:
            # Equal bounds are accepted (re-sampling widens them), so only "less than" is an error.
            raise ParamValueError(
                "Invalid input. max_val({}) is less than min_val({}).".format(max_val, min_val))

        if bins < 1:
            raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins))

        self._visual_max = max_val
        self._visual_min = min_val
        self._visual_bins = bins

        # mark _re_sampled_buckets to empty
        self._re_sampled_buckets = ()

    def _calc_intersection_len(self, max1, min1, max2, min2):
        """
        Calculates intersection length of [min1, max1] and [min2, max2].

        Raises:
            ParamValueError: If either interval has its max below its min.
        """
        if max1 < min1:
            raise ParamValueError(
                "Invalid input. max1({}) is less than min1({}).".format(max1, min1))

        if max2 < min2:
            raise ParamValueError(
                "Invalid input. max2({}) is less than min2({}).".format(max2, min2))

        if min1 <= min2:
            if max1 <= min2:
                # return value must be calculated by max1.__sub__
                return max1 - max1
            if max1 <= max2:
                return max1 - min2
            # max1 > max2
            return max2 - min2

        # min1 > min2
        if max2 <= min1:
            return max2 - max2
        if max2 <= max1:
            return max2 - min1
        return max1 - min1

    def _re_sample_buckets(self):
        """Re-samples buckets according to visual_max, visual_min and visual_bins."""
        if self._visual_max == self._visual_min:
            # Adjust visual range if max equals min.
            self._visual_max += 0.5
            self._visual_min -= 0.5

        width = (self._visual_max - self._visual_min) / self._visual_bins

        # Nothing to distribute: either no values were counted, or there are no original
        # buckets to read from (indexing the first one below would raise IndexError).
        if not self._count or not self._original_buckets:
            self._re_sampled_buckets = tuple(
                Bucket(self._visual_min + width * i, width, 0)
                for i in range(self._visual_bins))
            return

        re_sampled = []
        original_pos = 0
        original_bucket = self._original_buckets[original_pos]
        for i in range(self._visual_bins):
            cur_left = self._visual_min + width * i
            cur_right = cur_left + width
            cur_estimated_count = 0.0

            # Skip no bucket range.
            if cur_right <= original_bucket.left:
                re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
                continue

            # Skip no intersect range.
            while cur_left >= original_bucket.right:
                original_pos += 1
                if original_pos >= len(self._original_buckets):
                    break
                original_bucket = self._original_buckets[original_pos]

            # entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
            while True:
                if original_pos >= len(self._original_buckets):
                    break
                original_bucket = self._original_buckets[original_pos]

                intersection = self._calc_intersection_len(
                    min1=cur_left, max1=cur_right,
                    min2=original_bucket.left, max2=original_bucket.right)
                if not original_bucket.width:
                    # A zero-width bucket cannot be split proportionally; take its whole count.
                    estimated_count = original_bucket.count
                else:
                    estimated_count = (intersection / original_bucket.width) * original_bucket.count

                cur_estimated_count += estimated_count
                if cur_right > original_bucket.right:
                    # Need to sample next original bucket to this visual bucket.
                    original_pos += 1
                else:
                    # Current visual bucket has taken all intersect buckets into account.
                    break

            re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))

        self._re_sampled_buckets = tuple(re_sampled)

    def buckets(self, convert_to_tuple=True):
        """
        Get visual buckets instead of original buckets.

        The visual buckets are computed on first use and again after `set_visual_range`.

        Args:
            convert_to_tuple (bool): Whether convert bucket object to tuple.

        Returns:
            tuple, contains buckets.
        """
        if not self._re_sampled_buckets:
            self._re_sample_buckets()

        if not convert_to_tuple:
            return self._re_sampled_buckets

        return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets)
# ============================================================================ """Histogram data container.""" -import math - +from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket, mask_invalid_number from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary -from mindinsight.utils.exceptions import ParamValueError -from mindinsight.datavisual.utils.utils import calc_histogram_bins - - -def _mask_invalid_number(num): - """Mask invalid number to 0.""" - if math.isnan(num) or math.isinf(num): - return type(num)(0) - - return num - - -class Bucket: - """ - Bucket data class. - - Args: - left (double): Left edge of the histogram bucket. - width (double): Width of the histogram bucket. - count (int): Count of numbers fallen in the histogram bucket. - """ - def __init__(self, left, width, count): - self._left = left - self._width = width - self._count = count - - @property - def left(self): - """Gets left edge of the histogram bucket.""" - return self._left - - @property - def count(self): - """Gets count of numbers fallen in the histogram bucket.""" - return self._count - - @property - def width(self): - """Gets width of the histogram bucket.""" - return self._width - - @property - def right(self): - """Gets right edge of the histogram bucket.""" - return self._left + self._width - - def as_tuple(self): - """Gets the bucket as tuple.""" - return self._left, self._width, self._count - - def __repr__(self): - """Returns repr(self).""" - return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count) class HistogramContainer: """ - Histogram data container. + Histogram data container. - Args: - histogram_message (Summary.Histogram): Histogram message in summary file. + Args: + histogram_message (Summary.Histogram): Histogram message in summary file. """ - # Max quantity of original buckets. 
- MAX_ORIGINAL_BUCKETS_COUNT = 90 - def __init__(self, histogram_message: Summary.Histogram): self._msg = histogram_message original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets] # Ensure buckets are sorted from min to max. original_buckets.sort(key=lambda bucket: bucket.left) - self._original_buckets = tuple(original_buckets) - self._count = sum(bucket.count for bucket in self._original_buckets) - self._max = _mask_invalid_number(histogram_message.max) - self._min = _mask_invalid_number(histogram_message.min) - self._visual_max = self._max - self._visual_min = self._min - # default bin number - self._visual_bins = calc_histogram_bins(self._count) - # Note that tuple is immutable, so sharing tuple is often safe. - self._re_sampled_buckets = () + self._count = sum(bucket.count for bucket in original_buckets) + self._max = mask_invalid_number(histogram_message.max) + self._min = mask_invalid_number(histogram_message.min) + self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count) @property def max(self): @@ -114,148 +51,10 @@ class HistogramContainer: return self._count @property - def original_msg(self): - """Gets original proto message.""" - return self._msg - - @property - def original_buckets_count(self): - """Gets original buckets quantity.""" - return len(self._original_buckets) - - def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None: - """ - Sets visual range for later re-sampling. - - It's caller's duty to ensure input is valid. - - Why we need visual range for histograms? Aligned buckets between steps can help users know about the trend of - tensors. Miss aligned buckets between steps might miss-lead users about the trend of a tensor. Because for - given tensor, if you have thinner buckets, count of every bucket will get lower, however, if you have - thicker buckets, count of every bucket will get higher. 
When they are displayed together, user might think - the histogram with thicker buckets has more values. This is miss-leading. So we need to unify buckets across - steps. Visual range for histogram is a technology for unifying buckets. - - Args: - max_val (float): Max value for visual histogram. - min_val (float): Min value for visual histogram. - bins (int): Bins number for visual histogram. - """ - if max_val < min_val: - raise ParamValueError( - "Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val)) - - if bins < 1: - raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins)) - - self._visual_max = max_val - self._visual_min = min_val - self._visual_bins = bins - - # mark _re_sampled_buckets to empty - self._re_sampled_buckets = () - - def _calc_intersection_len(self, max1, min1, max2, min2): - """Calculates intersection length of [min1, max1] and [min2, max2].""" - if max1 < min1: - raise ParamValueError( - "Invalid input. max1({}) is less than min1({}).".format(max1, min1)) - - if max2 < min2: - raise ParamValueError( - "Invalid input. max2({}) is less than min2({}).".format(max2, min2)) - - if min1 <= min2: - if max1 <= min2: - # return value must be calculated by max1.__sub__ - return max1 - max1 - if max1 <= max2: - return max1 - min2 - # max1 > max2 - return max2 - min2 - - # min1 > min2 - if max2 <= min1: - return max2 - max2 - if max2 <= max1: - return max2 - min1 - return max1 - min1 - - def _re_sample_buckets(self): - """Re-samples buckets according to visual_max, visual_min and visual_bins.""" - if self._visual_max == self._visual_min: - # Adjust visual range if max equals min. 
- self._visual_max += 0.5 - self._visual_min -= 0.5 - - width = (self._visual_max - self._visual_min) / self._visual_bins - - if not self.count: - self._re_sampled_buckets = tuple( - Bucket(self._visual_min + width * i, width, 0) - for i in range(self._visual_bins)) - return - - re_sampled = [] - original_pos = 0 - original_bucket = self._original_buckets[original_pos] - for i in range(self._visual_bins): - cur_left = self._visual_min + width * i - cur_right = cur_left + width - cur_estimated_count = 0.0 - - # Skip no bucket range. - if cur_right <= original_bucket.left: - re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count))) - continue - - # Skip no intersect range. - while cur_left >= original_bucket.right: - original_pos += 1 - if original_pos >= len(self._original_buckets): - break - original_bucket = self._original_buckets[original_pos] - - # entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right - while True: - if original_pos >= len(self._original_buckets): - break - original_bucket = self._original_buckets[original_pos] - - intersection = self._calc_intersection_len( - min1=cur_left, max1=cur_right, - min2=original_bucket.left, max2=original_bucket.right) - if not original_bucket.width: - estimated_count = original_bucket.count - else: - estimated_count = (intersection / original_bucket.width) * original_bucket.count - - cur_estimated_count += estimated_count - if cur_right > original_bucket.right: - # Need to sample next original bucket to this visual bucket. - original_pos += 1 - else: - # Current visual bucket has taken all intersect buckets into account. - break - - re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count))) - - self._re_sampled_buckets = tuple(re_sampled) - - def buckets(self, convert_to_tuple=True): - """ - Get visual buckets instead of original buckets. - - Args: - convert_to_tuple (bool): Whether convert bucket object to tuple. 
- - Returns: - tuple, contains buckets. - """ - if not self._re_sampled_buckets: - self._re_sample_buckets() - - if not convert_to_tuple: - return self._re_sampled_buckets + def histogram(self): + """Gets histogram data""" + return self._histogram - return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets) + def buckets(self): + """Gets histogram buckets""" + return self._histogram.buckets() diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py index 32b7441dd74204d09fbf91dc52ad742ac36d7bf4..f4e062fdd7e79be65433cbcbb2b851ba596fe1e6 100644 --- a/mindinsight/datavisual/data_transform/ms_data_loader.py +++ b/mindinsight/datavisual/data_transform/ms_data_loader.py @@ -36,7 +36,9 @@ from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summar from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.utils import crc32 from mindinsight.utils.exceptions import UnknownError +from mindinsight.datavisual.data_transform.histogram import Histogram from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer +from mindinsight.datavisual.data_transform.tensor_container import TensorContainer HEADER_SIZE = 8 CRC_STR_SIZE = 4 @@ -390,6 +392,7 @@ class _SummaryParser(_Parser): 'scalar_value': PluginNameEnum.SCALAR, 'image': PluginNameEnum.IMAGE, 'histogram': PluginNameEnum.HISTOGRAM, + 'tensor': PluginNameEnum.TENSOR } if event.HasField('summary'): @@ -404,10 +407,12 @@ class _SummaryParser(_Parser): tensor_event_value = HistogramContainer(tensor_event_value) # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT # to avoid time-consuming re-sample process. 
- if tensor_event_value.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT: + if tensor_event_value.histogram.original_buckets_count > Histogram.MAX_ORIGINAL_BUCKETS_COUNT: logger.info('original_buckets_count exceeds ' 'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') continue + elif plugin == 'tensor': + tensor_event_value = TensorContainer(tensor_event_value) tensor_event = TensorEvent(wall_time=event.wall_time, step=event.step, diff --git a/mindinsight/datavisual/data_transform/reservoir.py b/mindinsight/datavisual/data_transform/reservoir.py index 1ae380e3edc6146d67626080c5ed292ac43cbf96..c14e5a7161ff48b9e9a529b7ee3166ef0b8e7f42 100644 --- a/mindinsight/datavisual/data_transform/reservoir.py +++ b/mindinsight/datavisual/data_transform/reservoir.py @@ -205,12 +205,12 @@ class HistogramReservoir(Reservoir): visual_range = _VisualRange() max_count = 0 for sample in self._samples: - histogram = sample.value - if histogram.count == 0: + histogram_container = sample.value + if histogram_container.count == 0: # ignore empty tensor continue - max_count = max(histogram.count, max_count) - visual_range.update(histogram.max, histogram.min) + max_count = max(histogram_container.count, max_count) + visual_range.update(histogram_container.max, histogram_container.min) if visual_range.max == visual_range.min and not max_count: logger.info("Max equals to min. Count is zero.") @@ -225,7 +225,7 @@ class HistogramReservoir(Reservoir): bins, max_count) for sample in self._samples: - histogram = sample.value + histogram = sample.value.histogram histogram.set_visual_range(visual_range.max, visual_range.min, bins) self._visual_range_up_to_date = True @@ -245,6 +245,6 @@ class ReservoirFactory: Returns: Reservoir, reservoir instance for given plugin name. 
F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max


class Statistics:
    """Statistics data class.

    Args:
        max_value (float): max value of tensor data.
        min_value (float): min value of tensor data.
        avg_value (float): avg value of tensor data.
        count (int): total count of tensor data.
        nan_count (int): count of NAN.
        neg_inf_count (int): count of negative INF.
        pos_inf_count (int): count of positive INF.
    """

    def __init__(self, max_value=0, min_value=0, avg_value=0,
                 count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
        self._max = max_value
        self._min = min_value
        self._avg = avg_value
        self._count = count
        self._nan_count = nan_count
        self._neg_inf_count = neg_inf_count
        self._pos_inf_count = pos_inf_count

    @property
    def max(self):
        """Get max value of tensor."""
        return self._max

    @property
    def min(self):
        """Get min value of tensor."""
        return self._min

    @property
    def avg(self):
        """Get avg value of tensor."""
        return self._avg

    @property
    def count(self):
        """Get total count of tensor."""
        return self._count

    @property
    def nan_count(self):
        """Get count of NAN."""
        return self._nan_count

    @property
    def neg_inf_count(self):
        """Get count of negative INF."""
        return self._neg_inf_count

    @property
    def pos_inf_count(self):
        """Get count of positive INF."""
        return self._pos_inf_count


def get_statistics_from_tensor(tensors):
    """
    Calculates statistics data of tensor.

    NaN and infinite elements are excluded from max, min and avg, but they are included in
    `count` and reported through the dedicated counters.

    Args:
        tensors (numpy.ndarray): An numpy.ndarray of tensor data.

    Returns:
        an instance of Statistics.
    """
    ma_value = np.ma.masked_invalid(tensors)
    total, valid = tensors.size, ma_value.count()
    invalids = []
    for isfn in np.isnan, np.isposinf, np.isneginf:
        # Only scan the array while some invalid elements are still unaccounted for.
        if total - valid > sum(invalids):
            count = np.count_nonzero(isfn(tensors))
            invalids.append(count)
        else:
            invalids.append(0)

    nan_count, pos_inf_count, neg_inf_count = invalids
    if not valid:
        logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
        statistics = Statistics(max_value=0,
                                min_value=0,
                                avg_value=0,
                                count=total,
                                nan_count=nan_count,
                                neg_inf_count=neg_inf_count,
                                pos_inf_count=pos_inf_count)
        return statistics

    # BUG: max of a masked array with dtype np.float16 returns inf
    # See numpy issue#15077
    if issubclass(tensors.dtype.type, np.floating):
        # np.PINF / np.NINF were removed in NumPy 2.0; np.inf works on every version.
        tensor_min = ma_value.min(fill_value=np.inf)
        tensor_max = ma_value.max(fill_value=-np.inf)
        if tensor_min < F32_MIN or tensor_max > F32_MAX:
            logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
                           'behaviours hereafter.', tensor_min, tensor_max)
    else:
        tensor_min = ma_value.min()
        tensor_max = ma_value.max()
    # Accumulate in float64 so that low-precision dtypes do not overflow or lose precision.
    tensor_sum = ma_value.sum(dtype=np.float64)
    statistics = Statistics(max_value=tensor_max,
                            min_value=tensor_min,
                            avg_value=tensor_sum / valid,
                            count=total,
                            nan_count=nan_count,
                            neg_inf_count=neg_inf_count,
                            pos_inf_count=pos_inf_count)
    return statistics


def _get_data_from_tensor(tensor):
    """
    Get data from tensor and convert to tuple.

    Only the `float_data` field of the message is read.

    Args:
        tensor (TensorProto): Tensor proto data.

    Returns:
        tuple, the item of tensor value.
    """
    return tuple(tensor.float_data)


def calc_original_buckets(np_value, stats):
    """
    Calculate buckets from tensor data.

    Args:
        np_value (numpy.ndarray): An numpy.ndarray of tensor data.
        stats (Statistics): An instance of Statistics about tensor data.

    Returns:
        list, a list of bucket about tensor data. Empty when the tensor has no valid
        (finite, non-NaN) value.

    Raises:
        ParamValueError, If np_value or stats is None.
    """
    if np_value is None or stats is None:
        raise ParamValueError("Invalid input. np_value or stats is None.")
    valid_count = stats.count - stats.nan_count - stats.neg_inf_count - stats.pos_inf_count
    if not valid_count:
        return []

    bins = calc_histogram_bins(valid_count)
    first_edge, last_edge = stats.min, stats.max

    if not first_edge < last_edge:
        # Degenerate range (all valid values equal): widen it so the bins have a width.
        first_edge -= 0.5
        last_edge += 0.5

    bins = np.linspace(first_edge, last_edge, bins + 1, dtype=np_value.dtype)
    hists, edges = np.histogram(np_value, bins=bins)

    buckets = []
    for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
        bucket = Bucket(edge1, edge2 - edge1, hist)
        buckets.append(bucket)

    return buckets


class TensorContainer:
    """
    Tensor data container.

    On construction the tensor values are converted to a numpy array, and statistics and
    histogram buckets are computed from it.

    Args:
        tensor_message (Summary.TensorProto): Tensor message in summary file.
    """

    def __init__(self, tensor_message):
        # One lock instance per container. Storing the `threading.Lock` factory and calling it
        # inside `with` would create a new lock on every call and protect nothing.
        self._lock = threading.Lock()
        self._msg = tensor_message
        self._dims = tensor_message.dims
        self._data_type = tensor_message.data_type
        self._np_array = None
        self._data = _get_data_from_tensor(tensor_message)
        # NOTE(review): if the reshape in _convert_to_numpy_array fails, get_or_calc_ndarray()
        # returns None and the statistics call below fails on None - confirm the intended handling.
        self._stats = get_statistics_from_tensor(self.get_or_calc_ndarray())
        original_buckets = calc_original_buckets(self.get_or_calc_ndarray(), self._stats)
        self._count = sum(bucket.count for bucket in original_buckets)
        self._max = self._stats.max
        self._min = self._stats.min
        self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count)

    @property
    def dims(self):
        """Get dims of tensor."""
        return self._dims

    @property
    def data_type(self):
        """Get data type of tensor."""
        return self._data_type

    @property
    def max(self):
        """Get max value of tensor."""
        return self._max

    @property
    def min(self):
        """Get min value of tensor."""
        return self._min

    @property
    def stats(self):
        """Get statistics data of tensor."""
        return self._stats

    @property
    def count(self):
        """Get count value of tensor (sum of the original bucket counts)."""
        return self._count

    @property
    def histogram(self):
        """Get histogram data."""
        return self._histogram

    def buckets(self):
        """Get histogram buckets."""
        return self._histogram.buckets()

    def get_or_calc_ndarray(self):
        """
        Get or calculate ndarray.

        Returns:
            numpy.ndarray, the tensor data shaped by dims, or None if the conversion failed.
        """
        with self._lock:
            if self._np_array is None:
                self._convert_to_numpy_array()
            return self._np_array

    def _convert_to_numpy_array(self):
        """Convert a list data to numpy array."""
        try:
            ndarray = np.array(self._data).reshape(self._dims)
        except ValueError as ex:
            logger.error("Reshape array fail, detail: %r", str(ex))
            return

        # Drop the reference to the proto message once the array has been built.
        self._msg = None
        self._np_array = ndarray
def get_specific_dims_data(ndarray, dims, tensor_dims):
    """
    Get specific dims data.

    Each entry of `dims` is either an integer index, which selects one position on that axis,
    or None, which keeps the whole axis.

    Args:
        ndarray (numpy.ndarray): An ndarray of numpy.
        dims (list): A list of specific dims; None means "take the full axis".
        tensor_dims (list): A list of tensor dims.

    Returns:
        numpy.ndarray, an ndarray of specific dims tensor data.

    Raises:
        ParamValueError, If the length of param dims is not equal to the length of tensor dims or
            the index of param dims out of range.
    """
    if len(dims) != len(tensor_dims):
        raise ParamValueError("The length of param dims: {}, is not equal to the "
                              "length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
    indices = []
    for dim, size in zip(dims, tensor_dims):
        if dim is None:
            indices.append(slice(0, size))
            continue
        # Negative indices inside [-size, -1] keep their numpy wrap-around meaning. Anything
        # outside [-size, size) would otherwise escape as a raw IndexError from numpy.
        if not -size <= dim < size:
            raise ParamValueError("The index: {} of param dims out of range: {}.".format(dim, size))
        indices.append(dim)
    return ndarray[tuple(indices)]
class TensorProcessor(BaseProcessor):
    """Tensor Processor."""

    def get_tensors(self, train_ids, tags, step, dims, detail):
        """
        Get tensor data for given train_ids, tags, step, dims and detail.

        Args:
            train_ids (list): Specify list of train job ID.
            tags (list): Specify list of tag.
            step (int): Specify step of tag, it's necessary when detail is equal to 'data'.
            dims (str): Specify dims of step, it's necessary when detail is equal to 'data'.
            detail (str): Specify which data to query, available values: 'stats', 'histogram' and 'data'.

        Returns:
            dict, a dict including the `tensors`.

        Raises:
            UrlDecodeError, If unquote train id error with strict mode.
        """
        Validation.check_param_empty(train_id=train_ids, tag=tags)

        # Decode every id up front so that a malformed id fails the request before any data is
        # read. A new list is built so the caller's list is left untouched.
        decoded_train_ids = []
        for train_id in train_ids:
            try:
                decoded_train_ids.append(unquote(train_id, errors='strict'))
            except UnicodeDecodeError as err:
                raise UrlDecodeError('Unquote train id error with strict mode') from err

        tensors = []
        for train_id in decoded_train_ids:
            tensors += self._get_train_tensors(train_id, tags, step, dims, detail)

        return {"tensors": tensors}

    def _get_train_tensors(self, train_id, tags, step, dims, detail):
        """
        Get tensor data for given train_id, tags, step, dims and detail.

        Args:
            train_id (str): Specify train job ID.
            tags (list): Specify list of tag.
            step (int): Specify step of tensor, it's necessary when detail is set to 'data'.
            dims (str): Specify dims of tensor, it's necessary when detail is set to 'data'.
            detail (str): Specify which data to query, available values: 'stats', 'histogram' and 'data'.

        Returns:
            list[dict], a list of dictionaries containing the `train_id`, `tag`, `values`.

        Raises:
            TensorNotExistError, If tensor with specific train_id and tag is not exist in cache.
            ParamValueError, If the value of detail is not within available values:
                            'stats', 'histogram' and 'data'.
        """
        tensors_response = []
        for tag in tags:
            try:
                tensors = self._data_manager.list_tensors(train_id, tag)
            except ParamValueError as err:
                raise TensorNotExistError(err.message) from err

            # Only the first record is inspected to decide whether the tag holds tensor data.
            if tensors and not isinstance(tensors[0].value, TensorContainer):
                raise TensorNotExistError("there is no tensor data in this tag: {}".format(tag))

            if detail is None or detail == 'stats':
                values = self._get_tensors_summary(detail, tensors)
            elif detail == 'data':
                Validation.check_param_empty(step=step, dims=dims)
                # Use a separate local so the raw `step` argument is parsed the same way for
                # every tag, not re-parsed from an already converted value.
                step_value = to_int(step, "step")
                values = self._get_tensors_data(step_value, dims, tensors)
            elif detail == 'histogram':
                values = self._get_tensors_histogram(tensors)
            else:
                raise ParamValueError('Can not support this value: {} of detail.'.format(detail))

            tensors_response.append({
                "train_id": train_id,
                "tag": tag,
                "values": values
            })

        return tensors_response

    def _get_tensors_summary(self, detail, tensors):
        """
        Builds a JSON-serializable object with information about tensor summary.

        Args:
            detail (str): Specify which data to query, detail value is None or 'stats' at this method.
            tensors (list): The list of _Tensor data.

        Returns:
            list[dict], one dict per tensor including the `wall_time`, `step`, and `value`.
                {
                    "wall_time": 0,
                    "step": 0,
                    "value": {
                        "dims": [1],
                        "data_type": "DT_FLOAT32",
                        "statistics": {
                            "max": 0,
                            "min": 0,
                            "avg": 0,
                            "count": 1,
                            "nan_count": 0,
                            "neg_inf_count": 0,
                            "pos_inf_count": 0
                        }
                    }
                }
                The `statistics` entry is only present when detail is equal to 'stats'.
        """
        values = []
        for tensor in tensors:
            # This value is an instance of TensorContainer
            value = tensor.value
            value_dict = {
                "dims": tuple(value.dims),
                "data_type": anf_ir_pb2.DataType.Name(value.data_type)
            }
            if detail == 'stats':
                value_dict["statistics"] = get_statistics_dict(value, None)

            values.append({
                "wall_time": tensor.wall_time,
                "step": tensor.step,
                "value": value_dict
            })

        return values

    def _get_tensors_data(self, step, dims, tensors):
        """
        Builds a JSON-serializable object with information about tensor dims data.

        Args:
            step (int): Specify step of tensor.
            dims (str): Specify dims of tensor.
            tensors (list): The list of _Tensor data.

        Returns:
            list[dict], a single-element list for the matching step; the dict includes the
            `wall_time`, `step`, and `value`.
                {
                    "wall_time": 0,
                    "step": 0,
                    "value": {
                        "dims": [1],
                        "data_type": "DT_FLOAT32",
                        "data": [[0.1]],
                        "statistics": {
                            "max": 0,
                            "min": 0,
                            "avg": 0,
                            "count": 1,
                            "nan_count": 0,
                            "neg_inf_count": 0,
                            "pos_inf_count": 0
                        }
                    }
                }

        Raises:
            ResponseDataExceedMaxValueError, If the size of response data exceed max value.
            StepTensorDataNotInCacheError, If query step is not in cache.
        """
        values = []
        step_in_cache = False
        # At most two axes may be left flexible (':') in one request.
        dims = convert_array_from_str(dims, limit=2)
        for tensor in tensors:
            if step != tensor.step:
                continue
            step_in_cache = True
            # This value is an instance of TensorContainer
            value = tensor.value
            ndarray = value.get_or_calc_ndarray()
            res_data = get_specific_dims_data(ndarray, dims, list(value.dims))
            # Reject oversized selections before materialising them as a Python list.
            data_size = res_data.size
            if data_size > MAX_TENSOR_RESPONSE_DATA_SIZE:
                raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}."
                                                      .format(data_size, MAX_TENSOR_RESPONSE_DATA_SIZE))
            flatten_data = res_data.flatten().tolist()
            values.append({
                "wall_time": tensor.wall_time,
                "step": tensor.step,
                "value": {
                    "dims": tuple(value.dims),
                    "data_type": anf_ir_pb2.DataType.Name(value.data_type),
                    "data": res_data.tolist(),
                    "statistics": get_statistics_dict(value, flatten_data)
                }
            })
            # Only the first record of the requested step is returned.
            break
        if not step_in_cache:
            raise StepTensorDataNotInCacheError("this step: {} data may has been dropped.".format(step))

        return values

    def _get_tensors_histogram(self, tensors):
        """
        Builds a JSON-serializable object with information about tensor histogram data.

        Args:
            tensors (list): The list of _Tensor data.

        Returns:
            list[dict], one dict per tensor including the `wall_time`, `step`, and `value`.
                {
                    "wall_time": 0,
                    "step": 0,
                    "value": {
                        "dims": [1],
                        "data_type": "DT_FLOAT32",
                        "histogram_buckets": [[0.1, 0.2, 3]],
                        "statistics": {
                            "max": 0,
                            "min": 0,
                            "avg": 0,
                            "count": 1,
                            "nan_count": 0,
                            "neg_inf_count": 0,
                            "pos_inf_count": 0
                        }
                    }
                }
        """
        values = []
        for tensor in tensors:
            # This value is an instance of TensorContainer
            value = tensor.value
            values.append({
                "wall_time": tensor.wall_time,
                "step": tensor.step,
                "value": {
                    "dims": tuple(value.dims),
                    "data_type": anf_ir_pb2.DataType.Name(value.data_type),
                    "histogram_buckets": value.buckets(),
                    "statistics": get_statistics_dict(value, None)
                }
            })

        return values
a/tests/ut/datavisual/data_transform/test_histogram_container.py b/tests/ut/datavisual/data_transform/test_histogram_container.py index 615ff5648c40d102131f55236b89f65bd7df06c1..dab192d695967b3f4da119fe6bd07c9c7fced999 100644 --- a/tests/ut/datavisual/data_transform/test_histogram_container.py +++ b/tests/ut/datavisual/data_transform/test_histogram_container.py @@ -29,9 +29,9 @@ class TestHistogram: mocked_bucket.width = 1 mocked_bucket.count = 1 mocked_input.buckets = [mocked_bucket] - histogram = hist.HistogramContainer(mocked_input) - histogram.set_visual_range(max_val=1, min_val=0, bins=1) - buckets = histogram.buckets() + histogram_container = hist.HistogramContainer(mocked_input) + histogram_container.histogram.set_visual_range(max_val=1, min_val=0, bins=1) + buckets = histogram_container.buckets() assert buckets == ((0.0, 1.0, 1),) def test_re_sample_buckets_split_original(self): @@ -42,9 +42,9 @@ class TestHistogram: mocked_bucket.width = 1 mocked_bucket.count = 1 mocked_input.buckets = [mocked_bucket] - histogram = hist.HistogramContainer(mocked_input) - histogram.set_visual_range(max_val=1, min_val=0, bins=3) - buckets = histogram.buckets() + histogram_container = hist.HistogramContainer(mocked_input) + histogram_container.histogram.set_visual_range(max_val=1, min_val=0, bins=3) + buckets = histogram_container.buckets() assert buckets == ((0.0, 0.3333333333333333, 1), (0.3333333333333333, 0.3333333333333333, 1), (0.6666666666666666, 0.3333333333333333, 1)) @@ -60,9 +60,9 @@ class TestHistogram: mocked_bucket2.width = 1 mocked_bucket2.count = 2 mocked_input.buckets = [mocked_bucket, mocked_bucket2] - histogram = hist.HistogramContainer(mocked_input) - histogram.set_visual_range(max_val=3, min_val=-1, bins=4) - buckets = histogram.buckets() + histogram_container = hist.HistogramContainer(mocked_input) + histogram_container.histogram.set_visual_range(max_val=3, min_val=-1, bins=4) + buckets = histogram_container.buckets() assert buckets == ((-1.0, 1.0, 0), 
(0.0, 1.0, 1), (1.0, 1.0, 2), (2.0, 1.0, 0)) def test_re_sample_buckets_merge_bucket(self): @@ -77,9 +77,9 @@ class TestHistogram: mocked_bucket2.width = 1 mocked_bucket2.count = 10 mocked_input.buckets = [mocked_bucket, mocked_bucket2] - histogram = hist.HistogramContainer(mocked_input) - histogram.set_visual_range(max_val=3, min_val=-1, bins=5) - buckets = histogram.buckets() + histogram_container = hist.HistogramContainer(mocked_input) + histogram_container.histogram.set_visual_range(max_val=3, min_val=-1, bins=5) + buckets = histogram_container.buckets() assert buckets == ( (-1.0, 0.8, 0), (-0.19999999999999996, 0.8, 1), (0.6000000000000001, 0.8, 5), (1.4000000000000004, 0.8, 6), (2.2, 0.8, 0)) @@ -96,9 +96,9 @@ class TestHistogram: mocked_bucket2.width = 0 mocked_bucket2.count = 2 mocked_input.buckets = [mocked_bucket, mocked_bucket2] - histogram = hist.HistogramContainer(mocked_input) - histogram.set_visual_range(max_val=2, min_val=0, bins=3) - buckets = histogram.buckets() + histogram_container = hist.HistogramContainer(mocked_input) + histogram_container.histogram.set_visual_range(max_val=2, min_val=0, bins=3) + buckets = histogram_container.buckets() assert buckets == ( (0.0, 0.6666666666666666, 1), (0.6666666666666666, 0.6666666666666666, 3), diff --git a/tests/ut/datavisual/processors/test_train_task_manager.py b/tests/ut/datavisual/processors/test_train_task_manager.py index 80c4eb989fc1716b0a5f1ac45e6ce91db557db45..fefce9f10093f405424d9bbd9ad583e292bce97d 100644 --- a/tests/ut/datavisual/processors/test_train_task_manager.py +++ b/tests/ut/datavisual/processors/test_train_task_manager.py @@ -69,7 +69,7 @@ class TestTrainTaskManager: def load_data(self): """Load data.""" log_operation = LogOperations() - self._plugins_id_map = {'image': [], 'scalar': [], 'graph': [], 'histogram': []} + self._plugins_id_map = {'image': [], 'scalar': [], 'graph': [], 'histogram': [], 'tensor': []} self._events_names = [] self._train_id_list = [] diff --git 
class TensorLogGenerator(LogGenerator):
    """
    Log generator for tensor data.

    This is a log generator writing tensor data. User can use it to generate fake
    summary logs about tensor.
    """

    def generate_event(self, values):
        """
        Method for generating tensor event.

        Args:
            values (dict): A dict contains:
                {
                    wall_time (float): Timestamp.
                    step (int): Train step.
                    value (dict): Tensor value with the keys `dims`, `data_type` and `float_data`.
                    tag (str): Tag name.
                }

        Returns:
            summary_pb2.Event.

        """
        tensor_event = summary_pb2.Event()
        tensor_event.wall_time = values.get('wall_time')
        tensor_event.step = values.get('step')

        value = tensor_event.summary.value.add()
        value.tag = values.get('tag')
        tensor = values.get('value')

        value.tensor.dims[:] = tensor.get('dims')
        value.tensor.data_type = tensor.get('data_type')
        value.tensor.float_data[:] = tensor.get('float_data')

        return tensor_event

    def generate_log(self, file_path, steps_list, tag_name):
        """
        Generate log for external calls.

        Args:
            file_path (str): Path to write logs.
            steps_list (list): A list consists of step.
            tag_name (str): Tag name.

        Returns:
            list[dict], generated tensor metadata, one entry per step in `steps_list` order.
            dict, generated tensors keyed by step.

        """
        tensor_metadata = []
        tensor_values = {}
        for step in steps_list:
            wall_time = time.time()
            # A random 4-axis shape with every axis in [1, 9], filled with random values.
            dims = list(np.random.randint(1, 10, 4))
            mul_value = reduce(mul, dims)
            tensor = {
                'wall_time': wall_time,
                'step': step,
                'tag': tag_name,
                'value': {
                    "dims": dims,
                    "data_type": anf_ir_pb2.DataType.DT_FLOAT32,
                    "float_data": np.random.randn(mul_value)
                }
            }
            tensor_metadata.append(tensor)
            tensor_values[step] = tensor

            self._write_log_one_step(file_path, tensor)

        return tensor_metadata, tensor_values
.log_generators.images_log_generator import ImagesLogGenerator from .log_generators.scalars_log_generator import ScalarsLogGenerator from .log_generators.histogram_log_generator import HistogramLogGenerator +from .log_generators.tensor_log_generator import TensorLogGenerator log_generators = { PluginNameEnum.GRAPH.value: GraphLogGenerator(), PluginNameEnum.IMAGE.value: ImagesLogGenerator(), PluginNameEnum.SCALAR.value: ScalarsLogGenerator(), - PluginNameEnum.HISTOGRAM.value: HistogramLogGenerator() + PluginNameEnum.HISTOGRAM.value: HistogramLogGenerator(), + PluginNameEnum.TENSOR.value: TensorLogGenerator() }