提交 988aad75 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!58 support resampling buckets

Merge pull request !58 from wenkai/wk0422
......@@ -16,6 +16,8 @@
import math
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.utils.utils import calc_histogram_bins
def _mask_invalid_number(num):
......@@ -26,6 +28,49 @@ def _mask_invalid_number(num):
return num
class Bucket:
"""
Bucket data class.
Args:
left (double): Left edge of the histogram bucket.
width (double): Width of the histogram bucket.
count (int): Count of numbers fallen in the histogram bucket.
"""
def __init__(self, left, width, count):
self._left = left
self._width = width
self._count = count
@property
def left(self):
"""Gets left edge of the histogram bucket."""
return self._left
@property
def count(self):
"""Gets count of numbers fallen in the histogram bucket."""
return self._count
@property
def width(self):
"""Gets width of the histogram bucket."""
return self._width
@property
def right(self):
"""Gets right edge of the histogram bucket."""
return self._left + self._width
def as_tuple(self):
"""Gets the bucket as tuple."""
return self._left, self._width, self._count
def __repr__(self):
"""Returns repr(self)."""
return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count)
class HistogramContainer:
"""
Histogram data container.
......@@ -35,16 +80,19 @@ class HistogramContainer:
"""
def __init__(self, histogram_message: Summary.Histogram):
self._msg = histogram_message
self._original_buckets = tuple((bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets)
original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets]
# Ensure buckets are sorted from min to max.
original_buckets.sort(key=lambda bucket: bucket.left)
self._original_buckets = tuple(original_buckets)
self._count = sum(bucket.count for bucket in self._original_buckets)
self._max = _mask_invalid_number(histogram_message.max)
self._min = _mask_invalid_number(histogram_message.min)
self._visual_max = self._max
self._visual_min = self._min
# default bin number
self._visual_bins = 10
self._count = sum(bucket[2] for bucket in self._original_buckets)
self._visual_bins = calc_histogram_bins(self._count)
# Note that tuple is immutable, so sharing tuple is often safe.
self._re_sampled_buckets = self._original_buckets
self._re_sampled_buckets = ()
@property
def max(self):
......@@ -63,7 +111,7 @@ class HistogramContainer:
@property
def original_msg(self):
"""Get original proto message"""
"""Gets original proto message."""
return self._msg
def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
......@@ -77,6 +125,13 @@ class HistogramContainer:
min_val (float): Min value for visual histogram.
bins (int): Bins number for visual histogram.
"""
if max_val < min_val:
raise ParamValueError(
"Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val))
if bins < 1:
raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins))
self._visual_max = max_val
self._visual_min = min_val
self._visual_bins = bins
......@@ -84,15 +139,104 @@ class HistogramContainer:
# mark _re_sampled_buckets to empty
self._re_sampled_buckets = ()
def _re_sample_buckets(self):
# Will call re-sample logic in later PR.
self._re_sampled_buckets = self._original_buckets
def _calc_intersection_len(self, max1, min1, max2, min2):
"""Calculates intersection length of [min1, max1] and [min2, max2]."""
if max1 < min1:
raise ParamValueError(
"Invalid input. max1({}) is less than min1({}).".format(max1, min1))
if max2 < min2:
raise ParamValueError(
"Invalid input. max2({}) is less than min2({}).".format(max2, min2))
if min1 <= min2:
if max1 <= min2:
# return value must be calculated by max1.__sub__
return max1 - max1
if max1 <= max2:
return max1 - min2
# max1 > max2
return max2 - min2
# min1 > min2
if max2 <= min1:
return max2 - max2
if max2 <= max1:
return max2 - min1
return max1 - min1
def buckets(self):
def _re_sample_buckets(self):
"""Re-samples buckets according to visual_max, visual_min and visual_bins."""
if self._visual_max == self._visual_min:
# Adjust visual range if max equals min.
self._visual_max += 0.5
self._visual_min -= 0.5
width = (self._visual_max - self._visual_min) / self._visual_bins
if not self.count:
self._re_sampled_buckets = tuple(
Bucket(self._visual_min + width * i, width, 0)
for i in range(self._visual_bins))
return
re_sampled = []
original_pos = 0
original_bucket = self._original_buckets[original_pos]
for i in range(self._visual_bins):
cur_left = self._visual_min + width * i
cur_right = cur_left + width
cur_estimated_count = 0.0
# Skip no bucket range.
if cur_right <= original_bucket.left:
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
continue
# Skip no intersect range.
while cur_left >= original_bucket.right:
original_pos += 1
if original_pos >= len(self._original_buckets):
break
original_bucket = self._original_buckets[original_pos]
# entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
while True:
if original_pos >= len(self._original_buckets):
break
original_bucket = self._original_buckets[original_pos]
intersection = self._calc_intersection_len(
min1=cur_left, max1=cur_right,
min2=original_bucket.left, max2=original_bucket.right)
estimated_count = (intersection / original_bucket.width) * original_bucket.count
cur_estimated_count += estimated_count
if cur_right > original_bucket.right:
# Need to sample next original bucket to this visual bucket.
original_pos += 1
else:
# Current visual bucket has taken all intersect buckets into account.
break
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
self._re_sampled_buckets = tuple(re_sampled)
def buckets(self, convert_to_tuple=True):
"""
Get visual buckets instead of original buckets.
Args:
convert_to_tuple (bool): Whether convert bucket object to tuple.
Returns:
tuple, contains buckets.
"""
if not self._re_sampled_buckets:
self._re_sample_buckets()
return self._re_sampled_buckets
if not convert_to_tuple:
return self._re_sampled_buckets
return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets)
......@@ -16,10 +16,11 @@
import random
import threading
import math
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.utils.utils import calc_histogram_bins
class Reservoir:
......@@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir):
max_count = max(histogram.count, max_count)
visual_range.update(histogram.max, histogram.min)
bins = self._calc_bins(max_count)
bins = calc_histogram_bins(max_count)
# update visual range
logger.info("Visual histogram: min %s, max %s, bins %s, max_count %s.",
visual_range.min,
visual_range.max,
bins,
max_count)
for sample in self._samples:
histogram = sample.value
histogram.set_visual_range(visual_range.max, visual_range.min, bins)
return list(self._samples)
def _calc_bins(self, count):
"""
Calculates experience-based optimal bins number.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
"""
number_per_bucket = 10
max_bins = 30
if not count:
return 1
if count <= 5:
return 2
if count <= 10:
return 3
if count <= 280:
# note that math.ceil(281/10) + 1 = 30
return math.ceil(count / number_per_bucket) + 1
return max_bins
class ReservoirFactory:
"""Factory class to get reservoir instances."""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
import math
def calc_histogram_bins(count):
"""
Calculates experience-based optimal bins number for histogram.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
Args:
count (int): Valid number count for the tensor.
Returns:
int, number of histogram bins.
"""
number_per_bucket = 10
max_bins = 30
if not count:
return 1
if count <= 5:
return 2
if count <= 10:
return 3
if count <= 280:
# note that math.ceil(281/10) + 1 equals 30
return math.ceil(count / number_per_bucket) + 1
return max_bins
......@@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist
class TestHistogram:
"""Test histogram."""
def test_get_buckets(self):
"""Test get buckets."""
"""Tests get buckets."""
mocked_input = mock.MagicMock()
mocked_bucket = mock.MagicMock()
mocked_bucket.left = 0
......@@ -31,4 +32,54 @@ class TestHistogram:
histogram = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=1, min_val=0, bins=1)
buckets = histogram.buckets()
assert len(buckets) == 1
\ No newline at end of file
assert buckets == ((0.0, 1.0, 1),)
def test_re_sample_buckets_split_original(self):
"""Tests splitting original buckets when re-sampling."""
mocked_input = mock.MagicMock()
mocked_bucket = mock.MagicMock()
mocked_bucket.left = 0
mocked_bucket.width = 1
mocked_bucket.count = 1
mocked_input.buckets = [mocked_bucket]
histogram = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=1, min_val=0, bins=3)
buckets = histogram.buckets()
assert buckets == ((0.0, 0.3333333333333333, 1), (0.3333333333333333, 0.3333333333333333, 1),
(0.6666666666666666, 0.3333333333333333, 1))
def test_re_sample_buckets_zero_bucket(self):
"""Tests zero bucket when re-sampling."""
mocked_input = mock.MagicMock()
mocked_bucket = mock.MagicMock()
mocked_bucket.left = 0
mocked_bucket.width = 1
mocked_bucket.count = 1
mocked_bucket2 = mock.MagicMock()
mocked_bucket2.left = 1
mocked_bucket2.width = 1
mocked_bucket2.count = 2
mocked_input.buckets = [mocked_bucket, mocked_bucket2]
histogram = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=3, min_val=-1, bins=4)
buckets = histogram.buckets()
assert buckets == ((-1.0, 1.0, 0), (0.0, 1.0, 1), (1.0, 1.0, 2), (2.0, 1.0, 0))
def test_re_sample_buckets_merge_bucket(self):
"""Tests merging counts from two buckets when re-sampling."""
mocked_input = mock.MagicMock()
mocked_bucket = mock.MagicMock()
mocked_bucket.left = 0
mocked_bucket.width = 1
mocked_bucket.count = 1
mocked_bucket2 = mock.MagicMock()
mocked_bucket2.left = 1
mocked_bucket2.width = 1
mocked_bucket2.count = 10
mocked_input.buckets = [mocked_bucket, mocked_bucket2]
histogram = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=3, min_val=-1, bins=5)
buckets = histogram.buckets()
assert buckets == (
(-1.0, 0.8, 0), (-0.19999999999999996, 0.8, 1), (0.6000000000000001, 0.8, 5), (1.4000000000000004, 0.8, 6),
(2.2, 0.8, 0))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册