Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b366608a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b366608a
编写于
8月 22, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 22, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4952 Fix errors in log calculation logics
Merge pull request !4952 from peixu_ren/custom_pp_ops
上级
9b503e4f
1c8eb9b1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
108 addition
and
26 deletion
+108
-26
mindspore/nn/probability/distribution/_utils/custom_ops.py
mindspore/nn/probability/distribution/_utils/custom_ops.py
+10
-4
tests/ut/python/nn/distribution/test_bernoulli.py
tests/ut/python/nn/distribution/test_bernoulli.py
+98
-22
未找到文件。
mindspore/nn/probability/distribution/_utils/custom_ops.py
浏览文件 @
b366608a
...
...
@@ -15,24 +15,30 @@
"""Utitly functions to help distribution class."""
import
numpy
as
np
from
mindspore.ops
import
operations
as
P
from
mindspore.common
import
dtype
as
mstype
def
log_by_step
(
input_x
):
"""
Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan.
"""
select
=
P
.
Select
()
log
=
P
.
Log
()
less
=
P
.
Less
()
lessequal
=
P
.
LessEqual
()
fill
=
P
.
Fill
()
cast
=
P
.
Cast
()
dtype
=
P
.
DType
()
shape
=
P
.
Shape
()
select
=
P
.
Select
()
input_x
=
cast
(
input_x
,
mstype
.
float32
)
nan
=
fill
(
dtype
(
input_x
),
shape
(
input_x
),
np
.
nan
)
inf
=
fill
(
dtype
(
input_x
),
shape
(
input_x
),
np
.
inf
)
neg_x
=
less
(
input_x
,
0.0
)
nonpos_x
=
lessequal
(
input_x
,
0.0
)
log_x
=
log
(
input_x
)
nan
=
fill
(
dtype
(
input_x
),
shape
(
input_x
),
np
.
nan
)
result
=
select
(
nonpos_x
,
nan
,
log_x
)
return
result
result
=
select
(
nonpos_x
,
-
inf
,
log_x
)
return
select
(
neg_x
,
nan
,
result
)
def
log1p_by_step
(
x
):
"""
...
...
tests/ut/python/nn/distribution/test_bernoulli.py
浏览文件 @
b366608a
...
...
@@ -157,51 +157,127 @@ def test_cross_entropy():
ans
=
net
(
probs_b
,
probs_a
)
assert
isinstance
(
ans
,
Tensor
)
class
BernoulliBasics
(
nn
.
Cell
):
class
BernoulliConstruct
(
nn
.
Cell
):
"""
Bernoulli distribution: going through construct.
"""
def
__init__
(
self
):
super
(
BernoulliConstruct
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
self
.
b1
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
,
probs
):
prob
=
self
.
b
(
'prob'
,
value
)
prob1
=
self
.
b
(
'prob'
,
value
,
probs
)
prob2
=
self
.
b1
(
'prob'
,
value
,
probs
)
return
prob
+
prob1
+
prob2
def
test_bernoulli_construct
():
"""
Test probability function going through construct.
"""
net
=
BernoulliConstruct
()
value
=
Tensor
([
0
,
0
,
0
,
0
,
0
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.5
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
assert
isinstance
(
ans
,
Tensor
)
class
BernoulliMean
(
nn
.
Cell
):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def
__init__
(
self
):
super
(
Bernoulli
Basics
,
self
).
__init__
()
super
(
Bernoulli
Mean
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
):
mean
=
self
.
b
.
mean
()
return
mean
def
test_mean
():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net
=
BernoulliMean
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
class
BernoulliSd
(
nn
.
Cell
):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def
__init__
(
self
):
super
(
BernoulliSd
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
):
sd
=
self
.
b
.
sd
()
return
sd
def
test_sd
():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net
=
BernoulliSd
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
class
BernoulliVar
(
nn
.
Cell
):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def
__init__
(
self
):
super
(
BernoulliVar
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
):
var
=
self
.
b
.
var
()
return
var
def
test_var
():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net
=
BernoulliVar
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
class
BernoulliMode
(
nn
.
Cell
):
"""
Test class: basic mean/sd/var/mode/entropy function.
"""
def
__init__
(
self
):
super
(
BernoulliMode
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
):
mode
=
self
.
b
.
mode
()
entropy
=
self
.
b
.
entropy
()
return
mean
+
sd
+
var
+
mode
+
entropy
return
mode
def
test_
bascis
():
def
test_
mode
():
"""
Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
"""
net
=
Bernoulli
Basics
()
net
=
Bernoulli
Mode
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
class
Bernoulli
Construct
(
nn
.
Cell
):
class
Bernoulli
Entropy
(
nn
.
Cell
):
"""
Bernoulli distribution: going through construct
.
Test class: basic mean/sd/var/mode/entropy function
.
"""
def
__init__
(
self
):
super
(
BernoulliConstruct
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
self
.
b1
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
super
(
BernoulliEntropy
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
,
probs
):
prob
=
self
.
b
(
'prob'
,
value
)
prob1
=
self
.
b
(
'prob'
,
value
,
probs
)
prob2
=
self
.
b1
(
'prob'
,
value
,
probs
)
return
prob
+
prob1
+
prob2
def
construct
(
self
):
entropy
=
self
.
b
.
entropy
()
return
entropy
def
test_
bernoulli_construct
():
def
test_
entropy
():
"""
Test
probability function going through construct
.
Test
mean/sd/var/mode/entropy functionality of Bernoulli distribution
.
"""
net
=
BernoulliConstruct
()
value
=
Tensor
([
0
,
0
,
0
,
0
,
0
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.5
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
net
=
BernoulliEntropy
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录