提交 ca5f81af 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!748 fix np.histogram sometimes calc very large bucket number

Merge pull request !748 from wenkai/wkr02
......@@ -15,6 +15,7 @@
"""Generate the summary event which conform to proto format."""
import time
import socket
import math
from enum import Enum, unique
import numpy as np
from PIL import Image
......@@ -292,6 +293,36 @@ def _get_tensor_summary(tag: str, np_value, summary_tensor):
return summary_tensor
def _calc_histogram_bins(count):
"""
Calculates experience-based optimal bins number for histogram.
There should be enough number in each bin. So we calc bin numbers according to count. For very small count(1 -
10), we assign carefully chosen number. For large count, we tried to make sure there are 9-10 numbers in each
bucket on average. Too many bins will slow down performance, so we set max number of bins to 90.
Args:
count (int): Valid number count for the tensor.
Returns:
int, number of histogram bins.
"""
number_per_bucket = 10
max_bins = 90
if not count:
return 1
if count <= 5:
return 2
if count <= 10:
return 3
if count <= 880:
# note that math.ceil(881/10) + 1 equals 90
return int(math.ceil(count / number_per_bucket) + 1)
return max_bins
def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
"""
Package the histogram summary.
......@@ -347,7 +378,8 @@ def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) ->
return
counts, edges = np.histogram(np_value, bins='auto', range=(tensor_min, tensor_max))
bin_number = _calc_histogram_bins(masked_value.count())
counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
for ind, count in enumerate(counts):
bucket = summary_histogram.buckets.add()
......
......@@ -22,6 +22,7 @@ import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
from .summary_reader import SummaryReader
CUR_DIR = os.getcwd()
......@@ -139,7 +140,7 @@ def test_histogram_summary_same_value():
event = reader.read_event()
LOG.debug(event)
assert len(event.summary.value[0].histogram.buckets) == 1
assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
def test_histogram_summary_high_dims():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册