提交 988aad75 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!58 support resampling buckets

Merge pull request !58 from wenkai/wk0422
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
import math import math
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.utils.utils import calc_histogram_bins
def _mask_invalid_number(num): def _mask_invalid_number(num):
...@@ -26,6 +28,49 @@ def _mask_invalid_number(num): ...@@ -26,6 +28,49 @@ def _mask_invalid_number(num):
return num return num
class Bucket:
    """
    Read-only record describing a single histogram bucket.

    Args:
        left (double): Left edge of the histogram bucket.
        width (double): Width of the histogram bucket.
        count (int): Count of numbers fallen in the histogram bucket.
    """

    def __init__(self, left, width, count):
        self._left, self._width, self._count = left, width, count

    @property
    def left(self):
        """Gets left edge of the histogram bucket."""
        return self._left

    @property
    def width(self):
        """Gets width of the histogram bucket."""
        return self._width

    @property
    def right(self):
        """Gets right edge of the histogram bucket, i.e. left edge plus width."""
        return self._left + self._width

    @property
    def count(self):
        """Gets count of numbers fallen in the histogram bucket."""
        return self._count

    def as_tuple(self):
        """Gets the bucket as a (left, width, count) tuple."""
        return (self._left, self._width, self._count)

    def __repr__(self):
        """Returns repr(self)."""
        fields = (self._left, self._width, self._count)
        return "Bucket(left={}, width={}, count={})".format(*fields)
class HistogramContainer: class HistogramContainer:
""" """
Histogram data container. Histogram data container.
...@@ -35,16 +80,19 @@ class HistogramContainer: ...@@ -35,16 +80,19 @@ class HistogramContainer:
""" """
def __init__(self, histogram_message: Summary.Histogram): def __init__(self, histogram_message: Summary.Histogram):
self._msg = histogram_message self._msg = histogram_message
self._original_buckets = tuple((bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets) original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets]
# Ensure buckets are sorted from min to max.
original_buckets.sort(key=lambda bucket: bucket.left)
self._original_buckets = tuple(original_buckets)
self._count = sum(bucket.count for bucket in self._original_buckets)
self._max = _mask_invalid_number(histogram_message.max) self._max = _mask_invalid_number(histogram_message.max)
self._min = _mask_invalid_number(histogram_message.min) self._min = _mask_invalid_number(histogram_message.min)
self._visual_max = self._max self._visual_max = self._max
self._visual_min = self._min self._visual_min = self._min
# default bin number # default bin number
self._visual_bins = 10 self._visual_bins = calc_histogram_bins(self._count)
self._count = sum(bucket[2] for bucket in self._original_buckets)
# Note that tuple is immutable, so sharing tuple is often safe. # Note that tuple is immutable, so sharing tuple is often safe.
self._re_sampled_buckets = self._original_buckets self._re_sampled_buckets = ()
@property @property
def max(self): def max(self):
...@@ -63,7 +111,7 @@ class HistogramContainer: ...@@ -63,7 +111,7 @@ class HistogramContainer:
@property @property
def original_msg(self): def original_msg(self):
"""Get original proto message""" """Gets original proto message."""
return self._msg return self._msg
def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None: def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
...@@ -77,6 +125,13 @@ class HistogramContainer: ...@@ -77,6 +125,13 @@ class HistogramContainer:
min_val (float): Min value for visual histogram. min_val (float): Min value for visual histogram.
bins (int): Bins number for visual histogram. bins (int): Bins number for visual histogram.
""" """
if max_val < min_val:
raise ParamValueError(
"Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val))
if bins < 1:
raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins))
self._visual_max = max_val self._visual_max = max_val
self._visual_min = min_val self._visual_min = min_val
self._visual_bins = bins self._visual_bins = bins
...@@ -84,15 +139,104 @@ class HistogramContainer: ...@@ -84,15 +139,104 @@ class HistogramContainer:
# mark _re_sampled_buckets to empty # mark _re_sampled_buckets to empty
self._re_sampled_buckets = () self._re_sampled_buckets = ()
def _re_sample_buckets(self): def _calc_intersection_len(self, max1, min1, max2, min2):
# Will call re-sample logic in later PR. """Calculates intersection length of [min1, max1] and [min2, max2]."""
self._re_sampled_buckets = self._original_buckets if max1 < min1:
raise ParamValueError(
"Invalid input. max1({}) is less than min1({}).".format(max1, min1))
if max2 < min2:
raise ParamValueError(
"Invalid input. max2({}) is less than min2({}).".format(max2, min2))
if min1 <= min2:
if max1 <= min2:
# return value must be calculated by max1.__sub__
return max1 - max1
if max1 <= max2:
return max1 - min2
# max1 > max2
return max2 - min2
# min1 > min2
if max2 <= min1:
return max2 - max2
if max2 <= max1:
return max2 - min1
return max1 - min1
def buckets(self): def _re_sample_buckets(self):
"""Re-samples buckets according to visual_max, visual_min and visual_bins."""
if self._visual_max == self._visual_min:
# Adjust visual range if max equals min.
self._visual_max += 0.5
self._visual_min -= 0.5
width = (self._visual_max - self._visual_min) / self._visual_bins
if not self.count:
self._re_sampled_buckets = tuple(
Bucket(self._visual_min + width * i, width, 0)
for i in range(self._visual_bins))
return
re_sampled = []
original_pos = 0
original_bucket = self._original_buckets[original_pos]
for i in range(self._visual_bins):
cur_left = self._visual_min + width * i
cur_right = cur_left + width
cur_estimated_count = 0.0
# Skip no bucket range.
if cur_right <= original_bucket.left:
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
continue
# Skip no intersect range.
while cur_left >= original_bucket.right:
original_pos += 1
if original_pos >= len(self._original_buckets):
break
original_bucket = self._original_buckets[original_pos]
# entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
while True:
if original_pos >= len(self._original_buckets):
break
original_bucket = self._original_buckets[original_pos]
intersection = self._calc_intersection_len(
min1=cur_left, max1=cur_right,
min2=original_bucket.left, max2=original_bucket.right)
estimated_count = (intersection / original_bucket.width) * original_bucket.count
cur_estimated_count += estimated_count
if cur_right > original_bucket.right:
# Need to sample next original bucket to this visual bucket.
original_pos += 1
else:
# Current visual bucket has taken all intersect buckets into account.
break
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
self._re_sampled_buckets = tuple(re_sampled)
def buckets(self, convert_to_tuple=True):
""" """
Get visual buckets instead of original buckets. Get visual buckets instead of original buckets.
Args:
convert_to_tuple (bool): Whether convert bucket object to tuple.
Returns:
tuple, contains buckets.
""" """
if not self._re_sampled_buckets: if not self._re_sampled_buckets:
self._re_sample_buckets() self._re_sample_buckets()
if not convert_to_tuple:
return self._re_sampled_buckets return self._re_sampled_buckets
return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets)
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
import random import random
import threading import threading
import math
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.utils.utils import calc_histogram_bins
class Reservoir: class Reservoir:
...@@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir): ...@@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir):
max_count = max(histogram.count, max_count) max_count = max(histogram.count, max_count)
visual_range.update(histogram.max, histogram.min) visual_range.update(histogram.max, histogram.min)
bins = self._calc_bins(max_count) bins = calc_histogram_bins(max_count)
# update visual range # update visual range
logger.info("Visual histogram: min %s, max %s, bins %s, max_count %s.",
visual_range.min,
visual_range.max,
bins,
max_count)
for sample in self._samples: for sample in self._samples:
histogram = sample.value histogram = sample.value
histogram.set_visual_range(visual_range.max, visual_range.min, bins) histogram.set_visual_range(visual_range.max, visual_range.min, bins)
return list(self._samples) return list(self._samples)
def _calc_bins(self, count):
"""
Calculates experience-based optimal bins number.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
"""
number_per_bucket = 10
max_bins = 30
if not count:
return 1
if count <= 5:
return 2
if count <= 10:
return 3
if count <= 280:
# note that math.ceil(281/10) + 1 = 30
return math.ceil(count / number_per_bucket) + 1
return max_bins
class ReservoirFactory: class ReservoirFactory:
"""Factory class to get reservoir instances.""" """Factory class to get reservoir instances."""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
import math
def calc_histogram_bins(count):
    """
    Picks an experience-based number of bins for a histogram.

    Re-sample bias is suppressed by keeping enough numbers in every bin, so the bin number is derived from
    count. Very small counts (1 - 10) get hand-picked bin numbers. Larger counts get roughly one bin per
    ten numbers, and the result never exceeds 30 bins so that the chart stays readable.

    Args:
        count (int): Valid number count for the tensor.

    Returns:
        int, number of histogram bins.
    """
    max_bins = 30
    number_per_bucket = 10

    # Nothing to show: a single bin is enough.
    if not count:
        return 1

    # Hand-picked bin numbers for very small counts: (inclusive upper bound, bins).
    for upper_bound, bins in ((5, 2), (10, 3)):
        if count <= upper_bound:
            return bins

    if count <= 280:
        # math.ceil(280 / 10) + 1 equals 29, so this branch always stays below max_bins.
        return math.ceil(count / number_per_bucket) + 1
    return max_bins
...@@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist ...@@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist
class TestHistogram: class TestHistogram:
"""Test histogram.""" """Test histogram."""
def test_get_buckets(self): def test_get_buckets(self):
"""Test get buckets.""" """Tests get buckets."""
mocked_input = mock.MagicMock() mocked_input = mock.MagicMock()
mocked_bucket = mock.MagicMock() mocked_bucket = mock.MagicMock()
mocked_bucket.left = 0 mocked_bucket.left = 0
...@@ -31,4 +32,54 @@ class TestHistogram: ...@@ -31,4 +32,54 @@ class TestHistogram:
histogram = hist.HistogramContainer(mocked_input) histogram = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=1, min_val=0, bins=1) histogram.set_visual_range(max_val=1, min_val=0, bins=1)
buckets = histogram.buckets() buckets = histogram.buckets()
assert len(buckets) == 1 assert buckets == ((0.0, 1.0, 1),)
\ No newline at end of file
def test_re_sample_buckets_split_original(self):
    """Tests that one original bucket is split over several visual buckets when re-sampling."""
    # One original bucket: left edge 0, width 1, holding a single number.
    source_bucket = mock.MagicMock(left=0, width=1, count=1)
    message = mock.MagicMock(buckets=[source_bucket])

    container = hist.HistogramContainer(message)
    # Same range as the original bucket, but divided into three visual bins.
    container.set_visual_range(max_val=1, min_val=0, bins=3)

    expected = (
        (0.0, 0.3333333333333333, 1),
        (0.3333333333333333, 0.3333333333333333, 1),
        (0.6666666666666666, 0.3333333333333333, 1),
    )
    assert container.buckets() == expected
def test_re_sample_buckets_zero_bucket(self):
    """Tests that visual buckets outside the original buckets get a zero count when re-sampling."""
    # Two adjacent original buckets, given as (left, width, count).
    source_buckets = [
        mock.MagicMock(left=left, width=width, count=count)
        for left, width, count in ((0, 1, 1), (1, 1, 2))
    ]
    message = mock.MagicMock(buckets=source_buckets)

    container = hist.HistogramContainer(message)
    # The visual range extends one bin past each side of the original buckets.
    container.set_visual_range(max_val=3, min_val=-1, bins=4)

    expected = ((-1.0, 1.0, 0), (0.0, 1.0, 1), (1.0, 1.0, 2), (2.0, 1.0, 0))
    assert container.buckets() == expected
def test_re_sample_buckets_merge_bucket(self):
    """Tests that a visual bucket merges counts from two original buckets when re-sampling."""
    # Two adjacent original buckets, given as (left, width, count).
    source_buckets = [
        mock.MagicMock(left=left, width=width, count=count)
        for left, width, count in ((0, 1, 1), (1, 1, 10))
    ]
    message = mock.MagicMock(buckets=source_buckets)

    container = hist.HistogramContainer(message)
    # Five bins over [-1, 3] give a width of 0.8, so visual bins straddle original bucket edges.
    container.set_visual_range(max_val=3, min_val=-1, bins=5)

    expected = (
        (-1.0, 0.8, 0),
        (-0.19999999999999996, 0.8, 1),
        (0.6000000000000001, 0.8, 5),
        (1.4000000000000004, 0.8, 6),
        (2.2, 0.8, 0),
    )
    assert container.buckets() == expected
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册