提交 c4c74bd6 编写于 作者: W wenkai

cross-step bucket unify

上级 41034372
...@@ -93,7 +93,9 @@ class EventsData: ...@@ -93,7 +93,9 @@ class EventsData:
with self._reservoir_mutex_lock: with self._reservoir_mutex_lock:
if tag not in self._reservoir_by_tag: if tag not in self._reservoir_by_tag:
reservoir_size = self._get_reservoir_size(tensor_event.plugin_name) reservoir_size = self._get_reservoir_size(tensor_event.plugin_name)
self._reservoir_by_tag[tag] = reservoir.Reservoir(reservoir_size) self._reservoir_by_tag[tag] = reservoir.ReservoirFactory().create_reservoir(
plugin_name, reservoir_size
)
tensor = _Tensor(wall_time=tensor_event.wall_time, tensor = _Tensor(wall_time=tensor_event.wall_time,
step=tensor_event.step, step=tensor_event.step,
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Histogram data container."""
import math
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary
def _mask_invalid_number(num):
"""Mask invalid number to 0."""
if math.isnan(num) or math.isinf(num):
return type(num)(0)
return num
class HistogramContainer:
"""
Histogram data container.
Args:
histogram_message (Summary.Histogram): Histogram message in summary file.
"""
def __init__(self, histogram_message: Summary.Histogram):
self._msg = histogram_message
self._original_buckets = tuple((bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets)
self._max = _mask_invalid_number(histogram_message.max)
self._min = _mask_invalid_number(histogram_message.min)
self._visual_max = self._max
self._visual_min = self._min
# default bin number
self._visual_bins = 10
self._count = self._msg.count
# Note that tuple is immutable, so sharing tuple is often safe.
self._re_sampled_buckets = self._original_buckets
@property
def max(self):
"""Gets max value of the tensor."""
return self._max
@property
def min(self):
"""Gets min value of the tensor."""
return self._min
@property
def count(self):
"""Gets valid number count of the tensor."""
return self._count
@property
def original_msg(self):
"""Get original proto message"""
return self._msg
def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
"""
Sets visual range for later re-sampling.
It's caller's duty to ensure input is valid.
Args:
max_val (float): Max value for visual histogram.
min_val (float): Min value for visual histogram.
bins (int): Bins number for visual histogram.
"""
self._visual_max = max_val
self._visual_min = min_val
self._visual_bins = bins
# mark _re_sampled_buckets to empty
self._re_sampled_buckets = ()
def _re_sample_buckets(self):
# Will call re-sample logic in later PR.
self._re_sampled_buckets = self._original_buckets
def buckets(self):
"""
Get visual buckets instead of original buckets.
"""
if not self._re_sampled_buckets:
self._re_sample_buckets()
return self._re_sampled_buckets
...@@ -36,6 +36,7 @@ from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summar ...@@ -36,6 +36,7 @@ from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summar
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import UnknownError from mindinsight.utils.exceptions import UnknownError
from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer
HEADER_SIZE = 8 HEADER_SIZE = 8
CRC_STR_SIZE = 4 CRC_STR_SIZE = 4
...@@ -235,7 +236,7 @@ class MSDataLoader: ...@@ -235,7 +236,7 @@ class MSDataLoader:
self._events_data.add_tensor_event(tensor_event) self._events_data.add_tensor_event(tensor_event)
if value.HasField('histogram'): if value.HasField('histogram'):
histogram_msg = value.histogram histogram_msg = HistogramContainer(value.histogram)
tag = '{}/{}'.format(value.tag, PluginNameEnum.HISTOGRAM.value) tag = '{}/{}'.format(value.tag, PluginNameEnum.HISTOGRAM.value)
tensor_event = TensorEvent(wall_time=event.wall_time, tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step, step=event.step,
......
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
import random import random
import threading import threading
import math
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
...@@ -106,3 +108,118 @@ class Reservoir: ...@@ -106,3 +108,118 @@ class Reservoir:
round(self._sample_counter * sample_remaining_rate)) round(self._sample_counter * sample_remaining_rate))
return remove_size return remove_size
class _VisualRange:
"""Simple helper class to merge visual ranges."""
def __init__(self):
self._max = 0.0
self._min = 0.0
self._updated = False
def update(self, max_val: float, min_val: float) -> None:
"""
Merge visual range with given range.
Args:
max_val (float): Max value of given range.
min_val (float): Min value of given range.
"""
if not self._updated:
self._max = max_val
self._min = min_val
self._updated = True
return
if max_val > self._max:
self._max = max_val
if min_val < self._min:
self._min = min_val
@property
def max(self):
"""Gets max value of current range."""
return self._max
@property
def min(self):
"""Gets min value of current range."""
return self._min
class HistogramReservoir(Reservoir):
"""
Reservoir for histogram, which needs updating range over all steps.
Args:
size (int): Container Size. If the size is 0, the container is not limited.
"""
def __init__(self, size):
super().__init__(size)
def samples(self):
"""Return all stored samples."""
with self._mutex:
# calc visual range
visual_range = _VisualRange()
max_count = 0
for sample in self._samples:
histogram = sample.value
if histogram.count == 0:
# ignore empty tensor
continue
max_count = max(histogram.count, max_count)
visual_range.update(histogram.max, histogram.min)
bins = self._calc_bins(max_count)
# update visual range
for sample in self._samples:
histogram = sample.value
histogram.set_visual_range(visual_range.max, visual_range.min, bins)
return list(self._samples)
def _calc_bins(self, count):
"""
Calculates experience-based optimal bins number.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
"""
number_per_bucket = 10
max_bins = 30
if not count:
return 1
if count <= 5:
return 2
if count <= 10:
return 3
if count <= 280:
# note that math.ceil(281/10) + 1 = 30
return math.ceil(count / number_per_bucket) + 1
return max_bins
class ReservoirFactory:
"""Factory class to get reservoir instances."""
def create_reservoir(self, plugin_name: str, size: int) -> Reservoir:
"""
Creates reservoir for given plugin name.
Args:
plugin_name (str): Plugin name
size (int): Container Size. If the size is 0, the container is not limited.
Returns:
Reservoir, reservoir instance for given plugin name.
"""
if plugin_name == PluginNameEnum.HISTOGRAM.value:
return HistogramReservoir(size)
return Reservoir(size)
...@@ -53,9 +53,8 @@ class HistogramProcessor(BaseProcessor): ...@@ -53,9 +53,8 @@ class HistogramProcessor(BaseProcessor):
histograms = [] histograms = []
for tensor in tensors: for tensor in tensors:
buckets = [] histogram = tensor.value
for bucket in tensor.value.buckets: buckets = histogram.buckets()
buckets.append([bucket.left, bucket.width, bucket.count])
histograms.append({ histograms.append({
"wall_time": tensor.wall_time, "wall_time": tensor.wall_time,
"step": tensor.step, "step": tensor.step,
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test histogram."""
import unittest.mock as mock
from mindinsight.datavisual.data_transform import histogram_container as hist
class TestHistogram:
"""Test histogram."""
def test_get_buckets(self):
"""Test get buckets."""
mocked_input = mock.MagicMock()
mocked_bucket = mock.MagicMock()
mocked_bucket.left = 0
mocked_bucket.width = 1
mocked_bucket.count = 1
mocked_input.buckets = [mocked_bucket]
histogram = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=1, min_val=0, bins=1)
buckets = histogram.buckets()
assert len(buckets) == 1
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test reservoir."""
import unittest.mock as mock
import mindinsight.datavisual.data_transform.reservoir as reservoir
class TestHistogramReservoir:
"""Test histogram reservoir."""
def test_samples(self):
"""Test get samples."""
my_reservoir = reservoir.ReservoirFactory().create_reservoir(reservoir.PluginNameEnum.HISTOGRAM.value, size=10)
sample1 = mock.MagicMock()
sample1.value.count = 1
sample1.value.max = 102
sample1.value.min = 101
sample2 = mock.MagicMock()
sample2.value.count = 2
sample2.value.max = 102
sample2.value.min = 101
my_reservoir.add_sample(sample1)
my_reservoir.add_sample(sample2)
samples = my_reservoir.samples()
assert len(samples) == 2
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册