Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
988aad75
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
7
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
988aad75
编写于
4月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!58 support resampling buckets
Merge pull request !58 from wenkai/wk0422
上级
09cd2808
84a39a4e
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
262 addition
and
38 deletion
+262
-38
mindinsight/datavisual/data_transform/histogram_container.py
mindinsight/datavisual/data_transform/histogram_container.py
+154
-10
mindinsight/datavisual/data_transform/reservoir.py
mindinsight/datavisual/data_transform/reservoir.py
+8
-26
mindinsight/datavisual/utils/utils.py
mindinsight/datavisual/utils/utils.py
+47
-0
tests/ut/datavisual/data_transform/test_histogram_container.py
.../ut/datavisual/data_transform/test_histogram_container.py
+53
-2
未找到文件。
mindinsight/datavisual/data_transform/histogram_container.py
浏览文件 @
988aad75
...
...
@@ -16,6 +16,8 @@
import
math
from
mindinsight.datavisual.proto_files.mindinsight_summary_pb2
import
Summary
from
mindinsight.utils.exceptions
import
ParamValueError
from
mindinsight.datavisual.utils.utils
import
calc_histogram_bins
def
_mask_invalid_number
(
num
):
...
...
@@ -26,6 +28,49 @@ def _mask_invalid_number(num):
return
num
class
Bucket
:
"""
Bucket data class.
Args:
left (double): Left edge of the histogram bucket.
width (double): Width of the histogram bucket.
count (int): Count of numbers fallen in the histogram bucket.
"""
def
__init__
(
self
,
left
,
width
,
count
):
self
.
_left
=
left
self
.
_width
=
width
self
.
_count
=
count
@
property
def
left
(
self
):
"""Gets left edge of the histogram bucket."""
return
self
.
_left
@
property
def
count
(
self
):
"""Gets count of numbers fallen in the histogram bucket."""
return
self
.
_count
@
property
def
width
(
self
):
"""Gets width of the histogram bucket."""
return
self
.
_width
@
property
def
right
(
self
):
"""Gets right edge of the histogram bucket."""
return
self
.
_left
+
self
.
_width
def
as_tuple
(
self
):
"""Gets the bucket as tuple."""
return
self
.
_left
,
self
.
_width
,
self
.
_count
def
__repr__
(
self
):
"""Returns repr(self)."""
return
"Bucket(left={}, width={}, count={})"
.
format
(
self
.
_left
,
self
.
_width
,
self
.
_count
)
class
HistogramContainer
:
"""
Histogram data container.
...
...
@@ -35,16 +80,19 @@ class HistogramContainer:
"""
def
__init__
(
self
,
histogram_message
:
Summary
.
Histogram
):
self
.
_msg
=
histogram_message
self
.
_original_buckets
=
tuple
((
bucket
.
left
,
bucket
.
width
,
bucket
.
count
)
for
bucket
in
self
.
_msg
.
buckets
)
original_buckets
=
[
Bucket
(
bucket
.
left
,
bucket
.
width
,
bucket
.
count
)
for
bucket
in
self
.
_msg
.
buckets
]
# Ensure buckets are sorted from min to max.
original_buckets
.
sort
(
key
=
lambda
bucket
:
bucket
.
left
)
self
.
_original_buckets
=
tuple
(
original_buckets
)
self
.
_count
=
sum
(
bucket
.
count
for
bucket
in
self
.
_original_buckets
)
self
.
_max
=
_mask_invalid_number
(
histogram_message
.
max
)
self
.
_min
=
_mask_invalid_number
(
histogram_message
.
min
)
self
.
_visual_max
=
self
.
_max
self
.
_visual_min
=
self
.
_min
# default bin number
self
.
_visual_bins
=
10
self
.
_count
=
sum
(
bucket
[
2
]
for
bucket
in
self
.
_original_buckets
)
self
.
_visual_bins
=
calc_histogram_bins
(
self
.
_count
)
# Note that tuple is immutable, so sharing tuple is often safe.
self
.
_re_sampled_buckets
=
self
.
_original_buckets
self
.
_re_sampled_buckets
=
()
@
property
def
max
(
self
):
...
...
@@ -63,7 +111,7 @@ class HistogramContainer:
@
property
def
original_msg
(
self
):
"""Get
original proto message
"""
"""Get
s original proto message.
"""
return
self
.
_msg
def
set_visual_range
(
self
,
max_val
:
float
,
min_val
:
float
,
bins
:
int
)
->
None
:
...
...
@@ -77,6 +125,13 @@ class HistogramContainer:
min_val (float): Min value for visual histogram.
bins (int): Bins number for visual histogram.
"""
if
max_val
<
min_val
:
raise
ParamValueError
(
"Invalid input. max_val({}) is less or equal than min_val({})."
.
format
(
max_val
,
min_val
))
if
bins
<
1
:
raise
ParamValueError
(
"Invalid input bins({}). Must be greater than 0."
.
format
(
bins
))
self
.
_visual_max
=
max_val
self
.
_visual_min
=
min_val
self
.
_visual_bins
=
bins
...
...
@@ -84,15 +139,104 @@ class HistogramContainer:
# mark _re_sampled_buckets to empty
self
.
_re_sampled_buckets
=
()
def
_re_sample_buckets
(
self
):
# Will call re-sample logic in later PR.
self
.
_re_sampled_buckets
=
self
.
_original_buckets
def
_calc_intersection_len
(
self
,
max1
,
min1
,
max2
,
min2
):
"""Calculates intersection length of [min1, max1] and [min2, max2]."""
if
max1
<
min1
:
raise
ParamValueError
(
"Invalid input. max1({}) is less than min1({})."
.
format
(
max1
,
min1
))
if
max2
<
min2
:
raise
ParamValueError
(
"Invalid input. max2({}) is less than min2({})."
.
format
(
max2
,
min2
))
if
min1
<=
min2
:
if
max1
<=
min2
:
# return value must be calculated by max1.__sub__
return
max1
-
max1
if
max1
<=
max2
:
return
max1
-
min2
# max1 > max2
return
max2
-
min2
# min1 > min2
if
max2
<=
min1
:
return
max2
-
max2
if
max2
<=
max1
:
return
max2
-
min1
return
max1
-
min1
def
buckets
(
self
):
def
_re_sample_buckets
(
self
):
"""Re-samples buckets according to visual_max, visual_min and visual_bins."""
if
self
.
_visual_max
==
self
.
_visual_min
:
# Adjust visual range if max equals min.
self
.
_visual_max
+=
0.5
self
.
_visual_min
-=
0.5
width
=
(
self
.
_visual_max
-
self
.
_visual_min
)
/
self
.
_visual_bins
if
not
self
.
count
:
self
.
_re_sampled_buckets
=
tuple
(
Bucket
(
self
.
_visual_min
+
width
*
i
,
width
,
0
)
for
i
in
range
(
self
.
_visual_bins
))
return
re_sampled
=
[]
original_pos
=
0
original_bucket
=
self
.
_original_buckets
[
original_pos
]
for
i
in
range
(
self
.
_visual_bins
):
cur_left
=
self
.
_visual_min
+
width
*
i
cur_right
=
cur_left
+
width
cur_estimated_count
=
0.0
# Skip no bucket range.
if
cur_right
<=
original_bucket
.
left
:
re_sampled
.
append
(
Bucket
(
cur_left
,
width
,
math
.
ceil
(
cur_estimated_count
)))
continue
# Skip no intersect range.
while
cur_left
>=
original_bucket
.
right
:
original_pos
+=
1
if
original_pos
>=
len
(
self
.
_original_buckets
):
break
original_bucket
=
self
.
_original_buckets
[
original_pos
]
# entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
while
True
:
if
original_pos
>=
len
(
self
.
_original_buckets
):
break
original_bucket
=
self
.
_original_buckets
[
original_pos
]
intersection
=
self
.
_calc_intersection_len
(
min1
=
cur_left
,
max1
=
cur_right
,
min2
=
original_bucket
.
left
,
max2
=
original_bucket
.
right
)
estimated_count
=
(
intersection
/
original_bucket
.
width
)
*
original_bucket
.
count
cur_estimated_count
+=
estimated_count
if
cur_right
>
original_bucket
.
right
:
# Need to sample next original bucket to this visual bucket.
original_pos
+=
1
else
:
# Current visual bucket has taken all intersect buckets into account.
break
re_sampled
.
append
(
Bucket
(
cur_left
,
width
,
math
.
ceil
(
cur_estimated_count
)))
self
.
_re_sampled_buckets
=
tuple
(
re_sampled
)
def
buckets
(
self
,
convert_to_tuple
=
True
):
"""
Get visual buckets instead of original buckets.
Args:
convert_to_tuple (bool): Whether convert bucket object to tuple.
Returns:
tuple, contains buckets.
"""
if
not
self
.
_re_sampled_buckets
:
self
.
_re_sample_buckets
()
if
not
convert_to_tuple
:
return
self
.
_re_sampled_buckets
return
tuple
(
bucket
.
as_tuple
()
for
bucket
in
self
.
_re_sampled_buckets
)
mindinsight/datavisual/data_transform/reservoir.py
浏览文件 @
988aad75
...
...
@@ -16,10 +16,11 @@
import
random
import
threading
import
math
from
mindinsight.datavisual.common.log
import
logger
from
mindinsight.datavisual.common.enums
import
PluginNameEnum
from
mindinsight.utils.exceptions
import
ParamValueError
from
mindinsight.datavisual.utils.utils
import
calc_histogram_bins
class
Reservoir
:
...
...
@@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir):
max_count
=
max
(
histogram
.
count
,
max_count
)
visual_range
.
update
(
histogram
.
max
,
histogram
.
min
)
bins
=
self
.
_calc
_bins
(
max_count
)
bins
=
calc_histogram
_bins
(
max_count
)
# update visual range
logger
.
info
(
"Visual histogram: min %s, max %s, bins %s, max_count %s."
,
visual_range
.
min
,
visual_range
.
max
,
bins
,
max_count
)
for
sample
in
self
.
_samples
:
histogram
=
sample
.
value
histogram
.
set_visual_range
(
visual_range
.
max
,
visual_range
.
min
,
bins
)
return
list
(
self
.
_samples
)
def
_calc_bins
(
self
,
count
):
"""
Calculates experience-based optimal bins number.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
"""
number_per_bucket
=
10
max_bins
=
30
if
not
count
:
return
1
if
count
<=
5
:
return
2
if
count
<=
10
:
return
3
if
count
<=
280
:
# note that math.ceil(281/10) + 1 = 30
return
math
.
ceil
(
count
/
number_per_bucket
)
+
1
return
max_bins
class
ReservoirFactory
:
"""Factory class to get reservoir instances."""
...
...
mindinsight/datavisual/utils/utils.py
0 → 100644
浏览文件 @
988aad75
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
import
math
def
calc_histogram_bins
(
count
):
"""
Calculates experience-based optimal bins number for histogram.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
Args:
count (int): Valid number count for the tensor.
Returns:
int, number of histogram bins.
"""
number_per_bucket
=
10
max_bins
=
30
if
not
count
:
return
1
if
count
<=
5
:
return
2
if
count
<=
10
:
return
3
if
count
<=
280
:
# note that math.ceil(281/10) + 1 equals 30
return
math
.
ceil
(
count
/
number_per_bucket
)
+
1
return
max_bins
tests/ut/datavisual/data_transform/test_histogram_container.py
浏览文件 @
988aad75
...
...
@@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist
class
TestHistogram
:
"""Test histogram."""
def
test_get_buckets
(
self
):
"""Test get buckets."""
"""Test
s
get buckets."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
...
...
@@ -31,4 +32,54 @@ class TestHistogram:
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
1
,
min_val
=
0
,
bins
=
1
)
buckets
=
histogram
.
buckets
()
assert
len
(
buckets
)
==
1
\ No newline at end of file
assert
buckets
==
((
0.0
,
1.0
,
1
),)
def
test_re_sample_buckets_split_original
(
self
):
"""Tests splitting original buckets when re-sampling."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
width
=
1
mocked_bucket
.
count
=
1
mocked_input
.
buckets
=
[
mocked_bucket
]
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
1
,
min_val
=
0
,
bins
=
3
)
buckets
=
histogram
.
buckets
()
assert
buckets
==
((
0.0
,
0.3333333333333333
,
1
),
(
0.3333333333333333
,
0.3333333333333333
,
1
),
(
0.6666666666666666
,
0.3333333333333333
,
1
))
def
test_re_sample_buckets_zero_bucket
(
self
):
"""Tests zero bucket when re-sampling."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
width
=
1
mocked_bucket
.
count
=
1
mocked_bucket2
=
mock
.
MagicMock
()
mocked_bucket2
.
left
=
1
mocked_bucket2
.
width
=
1
mocked_bucket2
.
count
=
2
mocked_input
.
buckets
=
[
mocked_bucket
,
mocked_bucket2
]
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
3
,
min_val
=-
1
,
bins
=
4
)
buckets
=
histogram
.
buckets
()
assert
buckets
==
((
-
1.0
,
1.0
,
0
),
(
0.0
,
1.0
,
1
),
(
1.0
,
1.0
,
2
),
(
2.0
,
1.0
,
0
))
def
test_re_sample_buckets_merge_bucket
(
self
):
"""Tests merging counts from two buckets when re-sampling."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
width
=
1
mocked_bucket
.
count
=
1
mocked_bucket2
=
mock
.
MagicMock
()
mocked_bucket2
.
left
=
1
mocked_bucket2
.
width
=
1
mocked_bucket2
.
count
=
10
mocked_input
.
buckets
=
[
mocked_bucket
,
mocked_bucket2
]
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
3
,
min_val
=-
1
,
bins
=
5
)
buckets
=
histogram
.
buckets
()
assert
buckets
==
(
(
-
1.0
,
0.8
,
0
),
(
-
0.19999999999999996
,
0.8
,
1
),
(
0.6000000000000001
,
0.8
,
5
),
(
1.4000000000000004
,
0.8
,
6
),
(
2.2
,
0.8
,
0
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录