Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
5c0b3f49
A
akg
项目概览
MindSpore
/
akg
通知
58
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5c0b3f49
编写于
7月 24, 2020
作者:
D
dabaiji
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add epsilon to random gaussian to avoid zero division error
上级
8437561d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
11 addition
and
6 deletion
+11
-6
tests/common/gen_random.py
tests/common/gen_random.py
+9
-3
tests/common/test_run/distr_normal_diag_KLdiv_ad_run.py
tests/common/test_run/distr_normal_diag_KLdiv_ad_run.py
+2
-3
未找到文件。
tests/common/gen_random.py
浏览文件 @
5c0b3f49
...
...
@@ -28,6 +28,8 @@ from akg.utils.kernel_exec import get_profiling_mode
RANDOM_SEED_NUM
=
20
PROF_ERROR_CODE
=
9999999999
def
func
(
size_
,
miu_
=
0
,
sigma_
=
8
,
seed_
=
None
):
"""
Select random func according to RANDOM_FUNC_MODE and randint, calculated by the length of the random_func_list.
...
...
@@ -59,7 +61,7 @@ def func(size_, miu_=0, sigma_=8, seed_=None):
@
func_time_required
def
random_gaussian
(
size
,
miu
=
0
,
sigma
=
8
,
seed
=
None
):
def
random_gaussian
(
size
,
miu
=
0
,
sigma
=
8
,
epsilon
=
0
,
seed
=
None
):
"""Generate random array with absolution value obeys gaussian distribution."""
random_data_disk_path
=
None
if
os
.
environ
.
get
(
"RANDOM_DATA_DISK_PATH"
)
is
not
None
:
...
...
@@ -93,7 +95,7 @@ def random_gaussian(size, miu=0, sigma=8, seed=None):
numbers
.
extend
(
func
(
size_c
,
miu
,
sigma
,
s
))
ret
=
np
.
array
(
numbers
)
ret
=
ret
.
flatten
()
return
ret
[:
size_c
].
reshape
(
size
)
return
ret
[:
size_c
].
reshape
(
size
)
+
epsilon
data_len
=
functools
.
reduce
(
lambda
x
,
y
:
x
*
y
,
size
)
data_pool
=
np
.
fromfile
(
random_data_disk_path
)
...
...
@@ -107,4 +109,8 @@ def random_gaussian(size, miu=0, sigma=8, seed=None):
np
.
random
.
shuffle
(
data_copy
)
data_copy_list
.
append
(
data_copy
)
data_pool
=
np
.
concatenate
(
tuple
(
data_copy_list
),
axis
=
0
)
return
data_pool
[
0
:
data_len
].
reshape
(
size
)
\ No newline at end of file
return
data_pool
[
0
:
data_len
].
reshape
(
size
)
+
epsilon
def
gen_epsilon
(
dtype
):
"""Generate suggested epsilon according to data type."""
return
1e-7
if
dtype
==
np
.
float32
else
1e-3
tests/common/test_run/distr_normal_diag_KLdiv_ad_run.py
浏览文件 @
5c0b3f49
...
...
@@ -16,7 +16,7 @@ import numpy as np
from
tensorio
import
compare_tensor
from
akg.utils
import
kernel_exec
as
utils
from
test_op.prob_program
import
distr_normal_diag_KLdiv_ad
from
gen_random
import
random_gaussian
from
gen_random
import
random_gaussian
,
gen_epsilon
from
base
import
get_rtol_atol
...
...
@@ -40,9 +40,8 @@ def gen_data(dtype, shape):
support_list
=
{
"float16"
:
np
.
float16
,
"float32"
:
np
.
float32
}
m
,
k
=
shape
mean
=
random_gaussian
((
m
,
k
),
miu
=
1
,
sigma
=
0.1
).
astype
(
support_list
[
dtype
])
scale
=
random_gaussian
((
m
,
k
),
miu
=
1
,
sigma
=
0.1
).
astype
(
support_list
[
dtype
])
scale
=
random_gaussian
((
m
,
k
),
miu
=
1
,
sigma
=
0.1
,
epsilon
=
gen_epsilon
(
dtype
)
).
astype
(
support_list
[
dtype
])
head
=
random_gaussian
((
m
,
),
miu
=
1
,
sigma
=
0.1
).
astype
(
support_list
[
dtype
])
output1
=
np
.
full
((
m
,
k
),
0.0
,
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录