提交 ab04b3dc 编写于 作者: W wenkai

Fix np.histogram(bins='auto') sometimes calculating a very small width and a very large...

Fix np.histogram(bins='auto') sometimes calculating a very small bin width and a very large number of buckets, which leads to errors or long compute times.
上级 41df9c20
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Generate the summary event which conform to proto format.""" """Generate the summary event which conform to proto format."""
import time import time
import socket import socket
import math
from enum import Enum, unique from enum import Enum, unique
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -292,6 +293,36 @@ def _get_tensor_summary(tag: str, np_value, summary_tensor): ...@@ -292,6 +293,36 @@ def _get_tensor_summary(tag: str, np_value, summary_tensor):
return summary_tensor return summary_tensor
def _calc_histogram_bins(count):
"""
Calculates experience-based optimal bins number for histogram.
There should be enough number in each bin. So we calc bin numbers according to count. For very small count(1 -
10), we assign carefully chosen number. For large count, we tried to make sure there are 9-10 numbers in each
bucket on average. Too many bins will slow down performance, so we set max number of bins to 90.
Args:
count (int): Valid number count for the tensor.
Returns:
int, number of histogram bins.
"""
number_per_bucket = 10
max_bins = 90
if not count:
return 1
if count <= 5:
return 2
if count <= 10:
return 3
if count <= 880:
# note that math.ceil(881/10) + 1 equals 90
return int(math.ceil(count / number_per_bucket) + 1)
return max_bins
def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None: def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
""" """
Package the histogram summary. Package the histogram summary.
...@@ -347,7 +378,8 @@ def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> ...@@ -347,7 +378,8 @@ def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) ->
return return
counts, edges = np.histogram(np_value, bins='auto', range=(tensor_min, tensor_max)) bin_number = _calc_histogram_bins(masked_value.count())
counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
for ind, count in enumerate(counts): for ind, count in enumerate(counts):
bucket = summary_histogram.buckets.add() bucket = summary_histogram.buckets.add()
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
from .summary_reader import SummaryReader from .summary_reader import SummaryReader
CUR_DIR = os.getcwd() CUR_DIR = os.getcwd()
...@@ -139,7 +140,7 @@ def test_histogram_summary_same_value(): ...@@ -139,7 +140,7 @@ def test_histogram_summary_same_value():
event = reader.read_event() event = reader.read_event()
LOG.debug(event) LOG.debug(event)
assert len(event.summary.value[0].histogram.buckets) == 1 assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
def test_histogram_summary_high_dims(): def test_histogram_summary_high_dims():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册