diff --git a/mindinsight/datavisual/data_transform/histogram_container.py b/mindinsight/datavisual/data_transform/histogram_container.py index a3e55f62339294f05ebf378943cfc6dab671cb79..45ae1df35dc46fc38824b9f4af807b417c48823f 100644 --- a/mindinsight/datavisual/data_transform/histogram_container.py +++ b/mindinsight/datavisual/data_transform/histogram_container.py @@ -16,6 +16,8 @@ import math from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary +from mindinsight.utils.exceptions import ParamValueError +from mindinsight.datavisual.utils.utils import calc_histogram_bins def _mask_invalid_number(num): @@ -26,6 +28,49 @@ def _mask_invalid_number(num): return num +class Bucket: + """ + Bucket data class. + + Args: + left (double): Left edge of the histogram bucket. + width (double): Width of the histogram bucket. + count (int): Count of numbers fallen in the histogram bucket. + """ + def __init__(self, left, width, count): + self._left = left + self._width = width + self._count = count + + @property + def left(self): + """Gets left edge of the histogram bucket.""" + return self._left + + @property + def count(self): + """Gets count of numbers fallen in the histogram bucket.""" + return self._count + + @property + def width(self): + """Gets width of the histogram bucket.""" + return self._width + + @property + def right(self): + """Gets right edge of the histogram bucket.""" + return self._left + self._width + + def as_tuple(self): + """Gets the bucket as tuple.""" + return self._left, self._width, self._count + + def __repr__(self): + """Returns repr(self).""" + return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count) + + class HistogramContainer: """ Histogram data container. @@ -35,16 +80,19 @@ class HistogramContainer: """ def __init__(self, histogram_message: Summary.Histogram): self._msg = histogram_message - self._original_buckets = tuple((bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets) + original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets] + # Ensure buckets are sorted from min to max. + original_buckets.sort(key=lambda bucket: bucket.left) + self._original_buckets = tuple(original_buckets) + self._count = sum(bucket.count for bucket in self._original_buckets) self._max = _mask_invalid_number(histogram_message.max) self._min = _mask_invalid_number(histogram_message.min) self._visual_max = self._max self._visual_min = self._min # default bin number - self._visual_bins = 10 - self._count = sum(bucket[2] for bucket in self._original_buckets) + self._visual_bins = calc_histogram_bins(self._count) # Note that tuple is immutable, so sharing tuple is often safe. - self._re_sampled_buckets = self._original_buckets + self._re_sampled_buckets = () @property def max(self): @@ -63,7 +111,7 @@ class HistogramContainer: @property def original_msg(self): - """Get original proto message""" + """Gets original proto message.""" return self._msg def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None: @@ -77,6 +125,13 @@ class HistogramContainer: min_val (float): Min value for visual histogram. bins (int): Bins number for visual histogram. """ + if max_val < min_val: + raise ParamValueError( + "Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val)) + + if bins < 1: + raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins)) + self._visual_max = max_val self._visual_min = min_val self._visual_bins = bins @@ -84,15 +139,104 @@ class HistogramContainer: # mark _re_sampled_buckets to empty self._re_sampled_buckets = () - def _re_sample_buckets(self): - # Will call re-sample logic in later PR. - self._re_sampled_buckets = self._original_buckets + def _calc_intersection_len(self, max1, min1, max2, min2): + """Calculates intersection length of [min1, max1] and [min2, max2].""" + if max1 < min1: + raise ParamValueError( + "Invalid input. max1({}) is less than min1({}).".format(max1, min1)) + + if max2 < min2: + raise ParamValueError( + "Invalid input. max2({}) is less than min2({}).".format(max2, min2)) + + if min1 <= min2: + if max1 <= min2: + # return value must be calculated by max1.__sub__ + return max1 - max1 + if max1 <= max2: + return max1 - min2 + # max1 > max2 + return max2 - min2 + + # min1 > min2 + if max2 <= min1: + return max2 - max2 + if max2 <= max1: + return max2 - min1 + return max1 - min1 - def buckets(self): + def _re_sample_buckets(self): + """Re-samples buckets according to visual_max, visual_min and visual_bins.""" + if self._visual_max == self._visual_min: + # Adjust visual range if max equals min. + self._visual_max += 0.5 + self._visual_min -= 0.5 + + width = (self._visual_max - self._visual_min) / self._visual_bins + + if not self.count: + self._re_sampled_buckets = tuple( + Bucket(self._visual_min + width * i, width, 0) + for i in range(self._visual_bins)) + return + + re_sampled = [] + original_pos = 0 + original_bucket = self._original_buckets[original_pos] + for i in range(self._visual_bins): + cur_left = self._visual_min + width * i + cur_right = cur_left + width + cur_estimated_count = 0.0 + + # Skip no bucket range. + if cur_right <= original_bucket.left: + re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count))) + continue + + # Skip no intersect range. + while cur_left >= original_bucket.right: + original_pos += 1 + if original_pos >= len(self._original_buckets): + break + original_bucket = self._original_buckets[original_pos] + + # entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right + while True: + if original_pos >= len(self._original_buckets): + break + original_bucket = self._original_buckets[original_pos] + + intersection = self._calc_intersection_len( + min1=cur_left, max1=cur_right, + min2=original_bucket.left, max2=original_bucket.right) + estimated_count = (intersection / original_bucket.width) * original_bucket.count + + cur_estimated_count += estimated_count + if cur_right > original_bucket.right: + # Need to sample next original bucket to this visual bucket. + original_pos += 1 + else: + # Current visual bucket has taken all intersect buckets into account. + break + + re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count))) + + self._re_sampled_buckets = tuple(re_sampled) + + def buckets(self, convert_to_tuple=True): """ Get visual buckets instead of original buckets. + + Args: + convert_to_tuple (bool): Whether convert bucket object to tuple. + + Returns: + tuple, contains buckets. """ if not self._re_sampled_buckets: self._re_sample_buckets() - return self._re_sampled_buckets + if not convert_to_tuple: + return self._re_sampled_buckets + + return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets) diff --git a/mindinsight/datavisual/data_transform/reservoir.py b/mindinsight/datavisual/data_transform/reservoir.py index 0a1444fe32c1a26dbb490a82096e3718dbc1c43a..aef9879a7bfed11aa9045c74619b193ef9ede93f 100644 --- a/mindinsight/datavisual/data_transform/reservoir.py +++ b/mindinsight/datavisual/data_transform/reservoir.py @@ -16,10 +16,11 @@ import random import threading -import math +from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.utils.exceptions import ParamValueError +from mindinsight.datavisual.utils.utils import calc_histogram_bins class Reservoir: @@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir): max_count = max(histogram.count, max_count) visual_range.update(histogram.max, histogram.min) - bins = self._calc_bins(max_count) + bins = calc_histogram_bins(max_count) # update visual range + logger.info("Visual histogram: min %s, max %s, bins %s, max_count %s.", + visual_range.min, + visual_range.max, + bins, + max_count) for sample in self._samples: histogram = sample.value histogram.set_visual_range(visual_range.max, visual_range.min, bins) return list(self._samples) - def _calc_bins(self, count): - """ - Calculates experience-based optimal bins number. - - To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to - count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make - sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max - number of bins to 30. - """ - number_per_bucket = 10 - max_bins = 30 - - if not count: - return 1 - if count <= 5: - return 2 - if count <= 10: - return 3 - if count <= 280: - # note that math.ceil(281/10) + 1 = 30 - return math.ceil(count / number_per_bucket) + 1 - - return max_bins - class ReservoirFactory: """Factory class to get reservoir instances.""" diff --git a/mindinsight/datavisual/utils/utils.py b/mindinsight/datavisual/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2dac02b7612457548e45d034696e407629cfcfc4 --- /dev/null +++ b/mindinsight/datavisual/utils/utils.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils.""" +import math + + +def calc_histogram_bins(count): + """ + Calculates experience-based optimal bins number for histogram. + + To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to + count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make + sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max + number of bins to 30. + + Args: + count (int): Valid number count for the tensor. + + Returns: + int, number of histogram bins. + """ + number_per_bucket = 10 + max_bins = 30 + + if not count: + return 1 + if count <= 5: + return 2 + if count <= 10: + return 3 + if count <= 280: + # note that math.ceil(281/10) + 1 equals 30 + return math.ceil(count / number_per_bucket) + 1 + + return max_bins diff --git a/tests/ut/datavisual/data_transform/test_histogram_container.py b/tests/ut/datavisual/data_transform/test_histogram_container.py index 4263065fce0a1da2d8b61edaac6f26c396f4a626..e68154fb3fcb1c839b5231ad698741333cf3d14b 100644 --- a/tests/ut/datavisual/data_transform/test_histogram_container.py +++ b/tests/ut/datavisual/data_transform/test_histogram_container.py @@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist class TestHistogram: """Test histogram.""" + def test_get_buckets(self): - """Test get buckets.""" + """Tests get buckets.""" mocked_input = mock.MagicMock() mocked_bucket = mock.MagicMock() mocked_bucket.left = 0 @@ -31,4 +32,54 @@ class TestHistogram: histogram = hist.HistogramContainer(mocked_input) histogram.set_visual_range(max_val=1, min_val=0, bins=1) buckets = histogram.buckets() - assert len(buckets) == 1 \ No newline at end of file + assert buckets == ((0.0, 1.0, 1),) + + def test_re_sample_buckets_split_original(self): + """Tests splitting original buckets when re-sampling.""" + mocked_input = mock.MagicMock() + mocked_bucket = mock.MagicMock() + mocked_bucket.left = 0 + mocked_bucket.width = 1 + mocked_bucket.count = 1 + mocked_input.buckets = [mocked_bucket] + histogram = hist.HistogramContainer(mocked_input) + histogram.set_visual_range(max_val=1, min_val=0, bins=3) + buckets = histogram.buckets() + assert buckets == ((0.0, 0.3333333333333333, 1), (0.3333333333333333, 0.3333333333333333, 1), + (0.6666666666666666, 0.3333333333333333, 1)) + + def test_re_sample_buckets_zero_bucket(self): + """Tests zero bucket when re-sampling.""" + mocked_input = mock.MagicMock() + mocked_bucket = mock.MagicMock() + mocked_bucket.left = 0 + mocked_bucket.width = 1 + mocked_bucket.count = 1 + mocked_bucket2 = mock.MagicMock() + mocked_bucket2.left = 1 + mocked_bucket2.width = 1 + mocked_bucket2.count = 2 + mocked_input.buckets = [mocked_bucket, mocked_bucket2] + histogram = hist.HistogramContainer(mocked_input) + histogram.set_visual_range(max_val=3, min_val=-1, bins=4) + buckets = histogram.buckets() + assert buckets == ((-1.0, 1.0, 0), (0.0, 1.0, 1), (1.0, 1.0, 2), (2.0, 1.0, 0)) + + def test_re_sample_buckets_merge_bucket(self): + """Tests merging counts from two buckets when re-sampling.""" + mocked_input = mock.MagicMock() + mocked_bucket = mock.MagicMock() + mocked_bucket.left = 0 + mocked_bucket.width = 1 + mocked_bucket.count = 1 + mocked_bucket2 = mock.MagicMock() + mocked_bucket2.left = 1 + mocked_bucket2.width = 1 + mocked_bucket2.count = 10 + mocked_input.buckets = [mocked_bucket, mocked_bucket2] + histogram = hist.HistogramContainer(mocked_input) + histogram.set_visual_range(max_val=3, min_val=-1, bins=5) + buckets = histogram.buckets() + assert buckets == ( + (-1.0, 0.8, 0), (-0.19999999999999996, 0.8, 1), (0.6000000000000001, 0.8, 5), (1.4000000000000004, 0.8, 6), + (2.2, 0.8, 0))