Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
988aad75
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
988aad75
编写于
4月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!58 support resampling buckets
Merge pull request !58 from wenkai/wk0422
上级
09cd2808
84a39a4e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
262 addition
and
38 deletion
+262
-38
mindinsight/datavisual/data_transform/histogram_container.py
mindinsight/datavisual/data_transform/histogram_container.py
+154
-10
mindinsight/datavisual/data_transform/reservoir.py
mindinsight/datavisual/data_transform/reservoir.py
+8
-26
mindinsight/datavisual/utils/utils.py
mindinsight/datavisual/utils/utils.py
+47
-0
tests/ut/datavisual/data_transform/test_histogram_container.py
.../ut/datavisual/data_transform/test_histogram_container.py
+53
-2
未找到文件。
mindinsight/datavisual/data_transform/histogram_container.py
浏览文件 @
988aad75
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
import
math
import
math
from
mindinsight.datavisual.proto_files.mindinsight_summary_pb2
import
Summary
from
mindinsight.datavisual.proto_files.mindinsight_summary_pb2
import
Summary
from
mindinsight.utils.exceptions
import
ParamValueError
from
mindinsight.datavisual.utils.utils
import
calc_histogram_bins
def
_mask_invalid_number
(
num
):
def
_mask_invalid_number
(
num
):
...
@@ -26,6 +28,49 @@ def _mask_invalid_number(num):
...
@@ -26,6 +28,49 @@ def _mask_invalid_number(num):
return
num
return
num
class
Bucket
:
"""
Bucket data class.
Args:
left (double): Left edge of the histogram bucket.
width (double): Width of the histogram bucket.
count (int): Count of numbers fallen in the histogram bucket.
"""
def
__init__
(
self
,
left
,
width
,
count
):
self
.
_left
=
left
self
.
_width
=
width
self
.
_count
=
count
@
property
def
left
(
self
):
"""Gets left edge of the histogram bucket."""
return
self
.
_left
@
property
def
count
(
self
):
"""Gets count of numbers fallen in the histogram bucket."""
return
self
.
_count
@
property
def
width
(
self
):
"""Gets width of the histogram bucket."""
return
self
.
_width
@
property
def
right
(
self
):
"""Gets right edge of the histogram bucket."""
return
self
.
_left
+
self
.
_width
def
as_tuple
(
self
):
"""Gets the bucket as tuple."""
return
self
.
_left
,
self
.
_width
,
self
.
_count
def
__repr__
(
self
):
"""Returns repr(self)."""
return
"Bucket(left={}, width={}, count={})"
.
format
(
self
.
_left
,
self
.
_width
,
self
.
_count
)
class
HistogramContainer
:
class
HistogramContainer
:
"""
"""
Histogram data container.
Histogram data container.
...
@@ -35,16 +80,19 @@ class HistogramContainer:
...
@@ -35,16 +80,19 @@ class HistogramContainer:
"""
"""
def
__init__
(
self
,
histogram_message
:
Summary
.
Histogram
):
def
__init__
(
self
,
histogram_message
:
Summary
.
Histogram
):
self
.
_msg
=
histogram_message
self
.
_msg
=
histogram_message
self
.
_original_buckets
=
tuple
((
bucket
.
left
,
bucket
.
width
,
bucket
.
count
)
for
bucket
in
self
.
_msg
.
buckets
)
original_buckets
=
[
Bucket
(
bucket
.
left
,
bucket
.
width
,
bucket
.
count
)
for
bucket
in
self
.
_msg
.
buckets
]
# Ensure buckets are sorted from min to max.
original_buckets
.
sort
(
key
=
lambda
bucket
:
bucket
.
left
)
self
.
_original_buckets
=
tuple
(
original_buckets
)
self
.
_count
=
sum
(
bucket
.
count
for
bucket
in
self
.
_original_buckets
)
self
.
_max
=
_mask_invalid_number
(
histogram_message
.
max
)
self
.
_max
=
_mask_invalid_number
(
histogram_message
.
max
)
self
.
_min
=
_mask_invalid_number
(
histogram_message
.
min
)
self
.
_min
=
_mask_invalid_number
(
histogram_message
.
min
)
self
.
_visual_max
=
self
.
_max
self
.
_visual_max
=
self
.
_max
self
.
_visual_min
=
self
.
_min
self
.
_visual_min
=
self
.
_min
# default bin number
# default bin number
self
.
_visual_bins
=
10
self
.
_visual_bins
=
calc_histogram_bins
(
self
.
_count
)
self
.
_count
=
sum
(
bucket
[
2
]
for
bucket
in
self
.
_original_buckets
)
# Note that tuple is immutable, so sharing tuple is often safe.
# Note that tuple is immutable, so sharing tuple is often safe.
self
.
_re_sampled_buckets
=
self
.
_original_buckets
self
.
_re_sampled_buckets
=
()
@
property
@
property
def
max
(
self
):
def
max
(
self
):
...
@@ -63,7 +111,7 @@ class HistogramContainer:
...
@@ -63,7 +111,7 @@ class HistogramContainer:
@
property
@
property
def
original_msg
(
self
):
def
original_msg
(
self
):
"""Get
original proto message
"""
"""Get
s original proto message.
"""
return
self
.
_msg
return
self
.
_msg
def
set_visual_range
(
self
,
max_val
:
float
,
min_val
:
float
,
bins
:
int
)
->
None
:
def
set_visual_range
(
self
,
max_val
:
float
,
min_val
:
float
,
bins
:
int
)
->
None
:
...
@@ -77,6 +125,13 @@ class HistogramContainer:
...
@@ -77,6 +125,13 @@ class HistogramContainer:
min_val (float): Min value for visual histogram.
min_val (float): Min value for visual histogram.
bins (int): Bins number for visual histogram.
bins (int): Bins number for visual histogram.
"""
"""
if
max_val
<
min_val
:
raise
ParamValueError
(
"Invalid input. max_val({}) is less or equal than min_val({})."
.
format
(
max_val
,
min_val
))
if
bins
<
1
:
raise
ParamValueError
(
"Invalid input bins({}). Must be greater than 0."
.
format
(
bins
))
self
.
_visual_max
=
max_val
self
.
_visual_max
=
max_val
self
.
_visual_min
=
min_val
self
.
_visual_min
=
min_val
self
.
_visual_bins
=
bins
self
.
_visual_bins
=
bins
...
@@ -84,15 +139,104 @@ class HistogramContainer:
...
@@ -84,15 +139,104 @@ class HistogramContainer:
# mark _re_sampled_buckets to empty
# mark _re_sampled_buckets to empty
self
.
_re_sampled_buckets
=
()
self
.
_re_sampled_buckets
=
()
def
_re_sample_buckets
(
self
):
def
_calc_intersection_len
(
self
,
max1
,
min1
,
max2
,
min2
):
# Will call re-sample logic in later PR.
"""Calculates intersection length of [min1, max1] and [min2, max2]."""
self
.
_re_sampled_buckets
=
self
.
_original_buckets
if
max1
<
min1
:
raise
ParamValueError
(
"Invalid input. max1({}) is less than min1({})."
.
format
(
max1
,
min1
))
if
max2
<
min2
:
raise
ParamValueError
(
"Invalid input. max2({}) is less than min2({})."
.
format
(
max2
,
min2
))
if
min1
<=
min2
:
if
max1
<=
min2
:
# return value must be calculated by max1.__sub__
return
max1
-
max1
if
max1
<=
max2
:
return
max1
-
min2
# max1 > max2
return
max2
-
min2
# min1 > min2
if
max2
<=
min1
:
return
max2
-
max2
if
max2
<=
max1
:
return
max2
-
min1
return
max1
-
min1
def
buckets
(
self
):
def
_re_sample_buckets
(
self
):
"""Re-samples buckets according to visual_max, visual_min and visual_bins."""
if
self
.
_visual_max
==
self
.
_visual_min
:
# Adjust visual range if max equals min.
self
.
_visual_max
+=
0.5
self
.
_visual_min
-=
0.5
width
=
(
self
.
_visual_max
-
self
.
_visual_min
)
/
self
.
_visual_bins
if
not
self
.
count
:
self
.
_re_sampled_buckets
=
tuple
(
Bucket
(
self
.
_visual_min
+
width
*
i
,
width
,
0
)
for
i
in
range
(
self
.
_visual_bins
))
return
re_sampled
=
[]
original_pos
=
0
original_bucket
=
self
.
_original_buckets
[
original_pos
]
for
i
in
range
(
self
.
_visual_bins
):
cur_left
=
self
.
_visual_min
+
width
*
i
cur_right
=
cur_left
+
width
cur_estimated_count
=
0.0
# Skip no bucket range.
if
cur_right
<=
original_bucket
.
left
:
re_sampled
.
append
(
Bucket
(
cur_left
,
width
,
math
.
ceil
(
cur_estimated_count
)))
continue
# Skip no intersect range.
while
cur_left
>=
original_bucket
.
right
:
original_pos
+=
1
if
original_pos
>=
len
(
self
.
_original_buckets
):
break
original_bucket
=
self
.
_original_buckets
[
original_pos
]
# entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
while
True
:
if
original_pos
>=
len
(
self
.
_original_buckets
):
break
original_bucket
=
self
.
_original_buckets
[
original_pos
]
intersection
=
self
.
_calc_intersection_len
(
min1
=
cur_left
,
max1
=
cur_right
,
min2
=
original_bucket
.
left
,
max2
=
original_bucket
.
right
)
estimated_count
=
(
intersection
/
original_bucket
.
width
)
*
original_bucket
.
count
cur_estimated_count
+=
estimated_count
if
cur_right
>
original_bucket
.
right
:
# Need to sample next original bucket to this visual bucket.
original_pos
+=
1
else
:
# Current visual bucket has taken all intersect buckets into account.
break
re_sampled
.
append
(
Bucket
(
cur_left
,
width
,
math
.
ceil
(
cur_estimated_count
)))
self
.
_re_sampled_buckets
=
tuple
(
re_sampled
)
def
buckets
(
self
,
convert_to_tuple
=
True
):
"""
"""
Get visual buckets instead of original buckets.
Get visual buckets instead of original buckets.
Args:
convert_to_tuple (bool): Whether convert bucket object to tuple.
Returns:
tuple, contains buckets.
"""
"""
if
not
self
.
_re_sampled_buckets
:
if
not
self
.
_re_sampled_buckets
:
self
.
_re_sample_buckets
()
self
.
_re_sample_buckets
()
return
self
.
_re_sampled_buckets
if
not
convert_to_tuple
:
return
self
.
_re_sampled_buckets
return
tuple
(
bucket
.
as_tuple
()
for
bucket
in
self
.
_re_sampled_buckets
)
mindinsight/datavisual/data_transform/reservoir.py
浏览文件 @
988aad75
...
@@ -16,10 +16,11 @@
...
@@ -16,10 +16,11 @@
import
random
import
random
import
threading
import
threading
import
math
from
mindinsight.datavisual.common.log
import
logger
from
mindinsight.datavisual.common.enums
import
PluginNameEnum
from
mindinsight.datavisual.common.enums
import
PluginNameEnum
from
mindinsight.utils.exceptions
import
ParamValueError
from
mindinsight.utils.exceptions
import
ParamValueError
from
mindinsight.datavisual.utils.utils
import
calc_histogram_bins
class
Reservoir
:
class
Reservoir
:
...
@@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir):
...
@@ -173,39 +174,20 @@ class HistogramReservoir(Reservoir):
max_count
=
max
(
histogram
.
count
,
max_count
)
max_count
=
max
(
histogram
.
count
,
max_count
)
visual_range
.
update
(
histogram
.
max
,
histogram
.
min
)
visual_range
.
update
(
histogram
.
max
,
histogram
.
min
)
bins
=
self
.
_calc
_bins
(
max_count
)
bins
=
calc_histogram
_bins
(
max_count
)
# update visual range
# update visual range
logger
.
info
(
"Visual histogram: min %s, max %s, bins %s, max_count %s."
,
visual_range
.
min
,
visual_range
.
max
,
bins
,
max_count
)
for
sample
in
self
.
_samples
:
for
sample
in
self
.
_samples
:
histogram
=
sample
.
value
histogram
=
sample
.
value
histogram
.
set_visual_range
(
visual_range
.
max
,
visual_range
.
min
,
bins
)
histogram
.
set_visual_range
(
visual_range
.
max
,
visual_range
.
min
,
bins
)
return
list
(
self
.
_samples
)
return
list
(
self
.
_samples
)
def
_calc_bins
(
self
,
count
):
"""
Calculates experience-based optimal bins number.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
"""
number_per_bucket
=
10
max_bins
=
30
if
not
count
:
return
1
if
count
<=
5
:
return
2
if
count
<=
10
:
return
3
if
count
<=
280
:
# note that math.ceil(281/10) + 1 = 30
return
math
.
ceil
(
count
/
number_per_bucket
)
+
1
return
max_bins
class
ReservoirFactory
:
class
ReservoirFactory
:
"""Factory class to get reservoir instances."""
"""Factory class to get reservoir instances."""
...
...
mindinsight/datavisual/utils/utils.py
0 → 100644
浏览文件 @
988aad75
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils."""
import
math
def
calc_histogram_bins
(
count
):
"""
Calculates experience-based optimal bins number for histogram.
To suppress re-sample bias, there should be enough number in each bin. So we calc bin numbers according to
count. For very small count(1 - 10), we assign carefully chosen number. For large count, we tried to make
sure there are 9-10 numbers in each bucket on average. Too many bins will also distract users, so we set max
number of bins to 30.
Args:
count (int): Valid number count for the tensor.
Returns:
int, number of histogram bins.
"""
number_per_bucket
=
10
max_bins
=
30
if
not
count
:
return
1
if
count
<=
5
:
return
2
if
count
<=
10
:
return
3
if
count
<=
280
:
# note that math.ceil(281/10) + 1 equals 30
return
math
.
ceil
(
count
/
number_per_bucket
)
+
1
return
max_bins
tests/ut/datavisual/data_transform/test_histogram_container.py
浏览文件 @
988aad75
...
@@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist
...
@@ -20,8 +20,9 @@ from mindinsight.datavisual.data_transform import histogram_container as hist
class
TestHistogram
:
class
TestHistogram
:
"""Test histogram."""
"""Test histogram."""
def
test_get_buckets
(
self
):
def
test_get_buckets
(
self
):
"""Test get buckets."""
"""Test
s
get buckets."""
mocked_input
=
mock
.
MagicMock
()
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
left
=
0
...
@@ -31,4 +32,54 @@ class TestHistogram:
...
@@ -31,4 +32,54 @@ class TestHistogram:
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
1
,
min_val
=
0
,
bins
=
1
)
histogram
.
set_visual_range
(
max_val
=
1
,
min_val
=
0
,
bins
=
1
)
buckets
=
histogram
.
buckets
()
buckets
=
histogram
.
buckets
()
assert
len
(
buckets
)
==
1
assert
buckets
==
((
0.0
,
1.0
,
1
),)
\ No newline at end of file
def
test_re_sample_buckets_split_original
(
self
):
"""Tests splitting original buckets when re-sampling."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
width
=
1
mocked_bucket
.
count
=
1
mocked_input
.
buckets
=
[
mocked_bucket
]
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
1
,
min_val
=
0
,
bins
=
3
)
buckets
=
histogram
.
buckets
()
assert
buckets
==
((
0.0
,
0.3333333333333333
,
1
),
(
0.3333333333333333
,
0.3333333333333333
,
1
),
(
0.6666666666666666
,
0.3333333333333333
,
1
))
def
test_re_sample_buckets_zero_bucket
(
self
):
"""Tests zero bucket when re-sampling."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
width
=
1
mocked_bucket
.
count
=
1
mocked_bucket2
=
mock
.
MagicMock
()
mocked_bucket2
.
left
=
1
mocked_bucket2
.
width
=
1
mocked_bucket2
.
count
=
2
mocked_input
.
buckets
=
[
mocked_bucket
,
mocked_bucket2
]
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
3
,
min_val
=-
1
,
bins
=
4
)
buckets
=
histogram
.
buckets
()
assert
buckets
==
((
-
1.0
,
1.0
,
0
),
(
0.0
,
1.0
,
1
),
(
1.0
,
1.0
,
2
),
(
2.0
,
1.0
,
0
))
def
test_re_sample_buckets_merge_bucket
(
self
):
"""Tests merging counts from two buckets when re-sampling."""
mocked_input
=
mock
.
MagicMock
()
mocked_bucket
=
mock
.
MagicMock
()
mocked_bucket
.
left
=
0
mocked_bucket
.
width
=
1
mocked_bucket
.
count
=
1
mocked_bucket2
=
mock
.
MagicMock
()
mocked_bucket2
.
left
=
1
mocked_bucket2
.
width
=
1
mocked_bucket2
.
count
=
10
mocked_input
.
buckets
=
[
mocked_bucket
,
mocked_bucket2
]
histogram
=
hist
.
HistogramContainer
(
mocked_input
)
histogram
.
set_visual_range
(
max_val
=
3
,
min_val
=-
1
,
bins
=
5
)
buckets
=
histogram
.
buckets
()
assert
buckets
==
(
(
-
1.0
,
0.8
,
0
),
(
-
0.19999999999999996
,
0.8
,
1
),
(
0.6000000000000001
,
0.8
,
5
),
(
1.4000000000000004
,
0.8
,
6
),
(
2.2
,
0.8
,
0
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录