Repository: magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 56835aaf
Authored on Aug 24, 2020 by mindspore-ci-bot; committed via Gitee on Aug 24, 2020.
!5005 Fix the bug in the formula in Bernoulli log_probs

Merge pull request !5005 from zichun_ye/fix_bernoulli_probs

Parents: 0bbce936, 9e7d6e23
Showing 8 changed files with 113 additions and 40 deletions (+113, -40)
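For context, the quantity being fixed is the Bernoulli log-probability, log P(X = x) = x log(p) + (1 - x) log(1 - p) for x in {0, 1}. A minimal NumPy sketch of that formula, mirroring the corrected _log_prob (the function name and setup here are illustrative, not MindSpore API):

import numpy as np

def bernoulli_log_prob(value, probs1):
    # log P(X = x) = x * log(p) + (1 - x) * log(1 - p)
    probs0 = 1.0 - probs1
    return np.log(probs1) * value + np.log(probs0) * (1.0 - value)

# value = 1 gives log(0.7); value = 0 gives log(0.3)
print(bernoulli_log_prob(np.array([1.0, 0.0]), 0.7))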
mindspore/nn/probability/bijector/power_transform.py (+6, -4)
mindspore/nn/probability/bijector/scalar_affine.py (+6, -4)
mindspore/nn/probability/bijector/softplus.py (+8, -5)
mindspore/nn/probability/distribution/_utils/utils.py (+31, -19)
mindspore/nn/probability/distribution/bernoulli.py (+2, -2)
mindspore/nn/probability/distribution/geometric.py (+3, -4)
tests/ut/python/nn/distribution/test_bernoulli.py (+36, -1)
tests/ut/python/nn/distribution/test_geometric.py (+21, -1)
mindspore/nn/probability/bijector/power_transform.py

@@ -20,6 +20,7 @@ from ..distribution._utils.utils import CheckTensor
 from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step, log1p_by_step
 from .bijector import Bijector

+
 class PowerTransform(Bijector):
     r"""
     Power Bijector.

@@ -49,6 +50,7 @@ class PowerTransform(Bijector):
     >>> # by replacing 'forward' with the name of the function
     >>> ans = self.p1.forward(value)
     """
+
     def __init__(self,
                  power=0,
                  name='PowerTransform',

@@ -78,13 +80,13 @@ class PowerTransform(Bijector):
         return shape

     def _forward(self, x):
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         if self.power == 0:
             return self.exp(x)
         return self.exp(self.log1p(x * self.power) / self.power)

     def _inverse(self, y):
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         if self.power == 0:
             return self.log(y)
         return self.expm1(self.log(y) * self.power) / self.power

@@ -101,7 +103,7 @@ class PowerTransform(Bijector):
         f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
         \log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
         """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         if self.power == 0:
             return x
         return (1. / self.power - 1) * self.log1p(x * self.power)

@@ -118,5 +120,5 @@ class PowerTransform(Bijector):
         f'(x) = \frac{e^c\log(y)}{y}
         \log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         return (self.power - 1) * self.log(y)
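The forward and inverse maps above are exp(log1p(c*x)/c) = (1 + c*x)^(1/c) and expm1(log(y)*c)/c = (y^c - 1)/c, which invert each other for nonzero power c. A quick NumPy check under illustrative values (a sketch, not the MindSpore primitives):

import numpy as np

c = 2.0                                # the bijector's power; c == 0 falls back to exp/log
x = np.array([0.1, 0.5, 1.0])

y = np.exp(np.log1p(x * c) / c)        # forward: (1 + c*x)^(1/c)
x_back = np.expm1(np.log(y) * c) / c   # inverse: (y^c - 1)/c

print(np.allclose(x_back, x))          # True: the maps are mutual inverses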
mindspore/nn/probability/bijector/scalar_affine.py

@@ -19,6 +19,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from ..distribution._utils.custom_ops import log_by_step
 from .bijector import Bijector

+
 class ScalarAffine(Bijector):
     """
     Scalar Affine Bijector.

@@ -47,6 +48,7 @@ class ScalarAffine(Bijector):
     >>> ans = self.s1.forward_log_jacobian(value)
     >>> ans = self.s1.inverse_log_jacobian(value)
     """
+
     def __init__(self,
                  scale=1.0,
                  shift=0.0,

@@ -91,7 +93,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(x) = a * x + b
         """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         return self.scale * x + self.shift

     def _inverse(self, y):

@@ -99,7 +101,7 @@ class ScalarAffine(Bijector):
         .. math::
             f(y) = \frac{y - b}{a}
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         return (y - self.shift) / self.scale

     def _forward_log_jacobian(self, x):

@@ -109,7 +111,7 @@ class ScalarAffine(Bijector):
         .. math::
             f'(x) = a
             \log(f'(x)) = \log(a)
         """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         return self.log(self.abs(self.scale))

     def _inverse_log_jacobian(self, y):

@@ -119,5 +121,5 @@ class ScalarAffine(Bijector):
         f'(x) = \frac{1.0}{a}
         \log(f'(x)) = - \log(a)
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         return -1. * self.log(self.abs(self.scale))
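Since f(x) = a*x + b has the constant derivative a, the two log-Jacobians reduce to log|a| and -log|a|, which is all the changed methods compute. A plain-NumPy restatement with illustrative values (assumed, not MindSpore code):

import numpy as np

a, b = 2.0, 0.5                      # scale and shift
x = np.array([0.0, 1.0, 3.0])

y = a * x + b                        # _forward
x_back = (y - b) / a                 # _inverse
fwd_log_jac = np.log(np.abs(a))      # log|a|
inv_log_jac = -np.log(np.abs(a))     # -log|a|

print(np.allclose(x_back, x), fwd_log_jac + inv_log_jac == 0.0)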
mindspore/nn/probability/bijector/softplus.py

@@ -22,6 +22,7 @@ from ..distribution._utils.utils import cast_to_tensor, CheckTensor
 from ..distribution._utils.custom_ops import exp_by_step, expm1_by_step, log_by_step
 from .bijector import Bijector

+
 class Softplus(Bijector):
     r"""
     Softplus Bijector.

@@ -51,6 +52,7 @@ class Softplus(Bijector):
     >>> ans = self.sp1.forward_log_jacobian(value)
     >>> ans = self.sp1.inverse_log_jacobian(value)
     """
+
     def __init__(self,
                  sharpness=1.0,
                  name='Softplus'):

@@ -76,6 +78,7 @@ class Softplus(Bijector):
         self.checktensor = CheckTensor()

         self.threshold = np.log(np.finfo(np.float32).eps) + 1
+        self.tiny = np.exp(self.threshold)

     def _softplus(self, x):
         too_small = self.less(x, self.threshold)

@@ -94,7 +97,7 @@ class Softplus(Bijector):
         f(x) = \frac{\log(1 + e^{x}))}
         f^{-1}(y) = \frac{\log(e^{y} - 1)}
         """
-        too_small = self.less(x, self.threshold)
+        too_small = self.less(x, self.tiny)
         too_large = self.greater(x, -self.threshold)
         too_small_value = self.log(x)
         too_large_value = x

@@ -116,7 +119,7 @@ class Softplus(Bijector):
         return shape

     def _forward(self, x):
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         scaled_value = self.sharpness * x
         return self.softplus(scaled_value) / self.sharpness

@@ -126,7 +129,7 @@ class Softplus(Bijector):
         f(x) = \frac{\log(1 + e^{kx}))}{k}
         f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         scaled_value = self.sharpness * y
         return self.inverse_softplus(scaled_value) / self.sharpness

@@ -137,7 +140,7 @@ class Softplus(Bijector):
         f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
         \log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
         """
-        self.checktensor(x, 'x')
+        self.checktensor(x, 'value')
         scaled_value = self.sharpness * x
         return self.log_sigmoid(scaled_value)

@@ -148,6 +151,6 @@ class Softplus(Bijector):
         f'(y) = \frac{e^{ky}}{e^{ky} - 1}
         \log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
         """
-        self.checktensor(y, 'y')
+        self.checktensor(y, 'value')
         scaled_value = self.sharpness * y
         return scaled_value - self.inverse_softplus(scaled_value)
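The substantive change in this file is in the inverse: the too_small branch previously compared against self.threshold (about -15.9, which a positive softplus output can never reach) and now compares against self.tiny = exp(threshold). A plain-NumPy sketch of the piecewise-stable inverse softplus this implements, with sharpness folded out (an assumption-laden sketch, not the MindSpore cell):

import numpy as np

threshold = np.log(np.finfo(np.float32).eps) + 1   # about -15.9
tiny = np.exp(threshold)                           # about 1.2e-7

def inverse_softplus(y):
    # f^{-1}(y) = log(e^y - 1); log(expm1(y)) ~ log(y) as y -> 0 and ~ y for large y
    too_small = y < tiny
    too_large = y > -threshold
    safe_y = np.where(too_small | too_large, 1.0, y)   # keep expm1/log in a safe range
    mid = np.log(np.expm1(safe_y))
    return np.where(too_small, np.log(y), np.where(too_large, y, mid))

print(inverse_softplus(np.array([1e-8, 1.0, 20.0])))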
mindspore/nn/probability/distribution/_utils/utils.py

@@ -26,6 +26,7 @@ from mindspore import context
 import mindspore.nn as nn
 import mindspore.nn.probability as msp

+
 def cast_to_tensor(t, hint_type=mstype.float32):
     """
     Cast an user input value into a Tensor of dtype.

@@ -47,7 +48,7 @@ def cast_to_tensor(t, hint_type=mstype.float32):
         return t
     t_type = hint_type
     if isinstance(t, Tensor):
-        #convert the type of tensor to dtype
+        # convert the type of tensor to dtype
         return Tensor(t.asnumpy(), dtype=t_type)
     if isinstance(t, (list, np.ndarray)):
         return Tensor(t, dtype=t_type)

@@ -56,7 +57,8 @@ def cast_to_tensor(t, hint_type=mstype.float32):
     if isinstance(t, (int, float)):
         return Tensor(t, dtype=t_type)
     invalid_type = type(t)
-    raise TypeError(f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")
+    raise TypeError(
+        f"Unable to convert input of type {invalid_type} to a Tensor of type {t_type}")

 def convert_to_batch(t, batch_shape, required_type):

@@ -79,6 +81,7 @@ def convert_to_batch(t, batch_shape, required_type):
     t = cast_to_tensor(t, required_type)
     return Tensor(np.broadcast_to(t.asnumpy(), batch_shape), dtype=required_type)

+
 def check_scalar_from_param(params):
     """
     Check if params are all scalars.

@@ -93,11 +96,7 @@
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
             return False
-        if isinstance(value, (str, type(params['dtype']))):
-            continue
-        elif isinstance(value, (int, float)):
-            continue
-        else:
-            return False
+        if not isinstance(value, (int, float, str, type(params['dtype']))):
+            return False
     return True

@@ -124,7 +123,8 @@ def calc_broadcast_shape_from_param(params):
             value_t = value.default_input
         else:
             value_t = cast_to_tensor(value, mstype.float32)
-        broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name'])
+        broadcast_shape = utils.get_broadcast_shape(
+            broadcast_shape, list(value_t.shape), params['name'])
     return tuple(broadcast_shape)

@@ -148,6 +148,7 @@ def check_greater_equal_zero(value, name):
     if comp.any():
         raise ValueError(f'{name} should be greater than ot equal to zero.')

+
 def check_greater_zero(value, name):
     """
     Check if the given Tensor is strictly greater than zero.

@@ -251,6 +252,7 @@ def probs_to_logits(probs, is_binary=False):
         return P.Log()(ps_clamped) - P.Log()(1-ps_clamped)
     return P.Log()(ps_clamped)

+
 def check_tensor_type(name, inputs, valid_type):
     """
     Check if inputs is proper.

@@ -268,25 +270,34 @@ def check_tensor_type(name, inputs, valid_type):
     if input_type not in valid_type:
         raise TypeError(f"{name} dtype is invalid")

+
 def check_type(data_type, value_type, name):
     if not data_type in value_type:
-        raise TypeError(f"For {name}, valid type include {value_type}, {data_type} is invalid")
+        raise TypeError(
+            f"For {name}, valid type include {value_type}, {data_type} is invalid")

+
 @constexpr
 def raise_none_error(name):
     raise TypeError(f"the type {name} should be subclass of Tensor."
                     f" It should not be None since it is not specified during initialization.")

+
 @constexpr
 def raise_not_impl_error(name):
-    raise ValueError(f"{name} function should be implemented for non-linear transformation")
+    raise ValueError(
+        f"{name} function should be implemented for non-linear transformation")

+
 @constexpr
 def check_distribution_name(name, expected_name):
     if name is None:
-        raise ValueError(f"Distribution should be a constant which is not None.")
+        raise ValueError(
+            f"Distribution should be a constant which is not None.")
     if name != expected_name:
-        raise ValueError(f"Expected distribution name is {expected_name}, but got {name}.")
+        raise ValueError(
+            f"Expected distribution name is {expected_name}, but got {name}.")

 class CheckTuple(PrimitiveWithInfer):
     """

@@ -294,13 +305,13 @@ class CheckTuple(PrimitiveWithInfer):
     """
     @prim_attr_register
     def __init__(self):
         """init Cast"""
         super(CheckTuple, self).__init__("CheckTuple")
-        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

     def __infer__(self, x, name):
         if not isinstance(x['dtype'], tuple):
-            raise TypeError(f"For {name['value']}, Input type should b a tuple.")
+            raise TypeError(
+                f"For {name['value']}, Input type should b a tuple.")
         out = {'shape': None,
                'dtype': None,

@@ -310,24 +321,25 @@ class CheckTuple(PrimitiveWithInfer):
     def __call__(self, x, name):
         if context.get_context("mode") == 0:
             return x["value"]
-        #Pynative mode
+        # Pynative mode
         if isinstance(x, tuple):
             return x
         raise TypeError(f"For {name['value']}, Input type should b a tuple.")

+
 class CheckTensor(PrimitiveWithInfer):
     """
     Check if input is a Tensor.
     """
     @prim_attr_register
     def __init__(self):
         """init Cast"""
         super(CheckTensor, self).__init__("CheckTensor")
-        self.init_prim_io_names(inputs=['x'], outputs=['dummy_output'])
+        self.init_prim_io_names(inputs=['x', 'name'], outputs=['dummy_output'])

     def __infer__(self, x, name):
         src_type = x['dtype']
-        validator.check_subclass("input", src_type, [mstype.tensor], name["value"])
+        validator.check_subclass(
+            "input", src_type, [mstype.tensor], name["value"])
         out = {'shape': None,
                'dtype': None,
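The check_scalar_from_param rewrite collapses an if/elif/else ladder into a single negated isinstance test; the two are equivalent because isinstance accepts a tuple of types. A standalone restatement of the equivalence (the loop's continue is recast as "value passes", and dtype_cls stands in for type(params['dtype'])):

def passes_old(value, dtype_cls):
    if isinstance(value, (str, dtype_cls)):
        return True
    elif isinstance(value, (int, float)):
        return True
    else:
        return False

def passes_new(value, dtype_cls):
    return isinstance(value, (int, float, str, dtype_cls))

for v in (1, 2.5, "name", [1, 2], None):
    assert passes_old(v, type(None)) == passes_new(v, type(None))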
mindspore/nn/probability/distribution/bernoulli.py

@@ -20,6 +20,7 @@ from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name, raise_none_error
 from ._utils.custom_ops import exp_by_step, log_by_step

+
 class Bernoulli(Distribution):
     """
     Bernoulli Distribution.

@@ -97,7 +98,7 @@ class Bernoulli(Distribution):
         Constructor of Bernoulli distribution.
         """
         param = dict(locals())
-        valid_dtype = mstype.int_type + mstype.uint_type
+        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, "Bernoulli")
         super(Bernoulli, self).__init__(seed, dtype, name, param)
         self.parameter_type = mstype.float32

@@ -211,7 +212,6 @@
         """
         self.checktensor(value, 'value')
         value = self.cast(value, mstype.float32)
         value = self.floor(value)
         probs1 = self._check_param(probs1)
         probs0 = 1.0 - probs1
         return self.log(probs1) * value + self.log(probs0) * (1.0 - value)
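As a sanity check on the corrected log-probability, the same formula can be compared against scipy's reference logpmf (scipy is used here only for verification and plays no part in the commit):

import numpy as np
from scipy.stats import bernoulli

p = 0.7
x = np.array([0, 1, 1, 0])

ours = np.log(p) * x + np.log(1.0 - p) * (1 - x)
print(np.allclose(ours, bernoulli.logpmf(x, p)))  # True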
mindspore/nn/probability/distribution/geometric.py

@@ -19,9 +19,10 @@ from mindspore.ops import composite as C
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
 from ._utils.utils import cast_to_tensor, check_prob, check_type, check_distribution_name,\
-raise_none_error
+    raise_none_error
 from ._utils.custom_ops import exp_by_step, log_by_step

+
 class Geometric(Distribution):
     """
     Geometric Distribution.

@@ -100,7 +101,7 @@ class Geometric(Distribution):
         Constructor of Geometric distribution.
         """
         param = dict(locals())
-        valid_dtype = mstype.int_type + mstype.uint_type
+        valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type
         check_type(dtype, valid_dtype, "Geometric")
         super(Geometric, self).__init__(seed, dtype, name, param)
         self.parameter_type = mstype.float32

@@ -130,7 +131,6 @@ class Geometric(Distribution):
         self.sqrt = P.Sqrt()
         self.uniform = C.uniform

-
     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'

@@ -243,7 +243,6 @@ class Geometric(Distribution):
         comp = self.less(value, zeros)
         return self.select(comp, zeros, cdf)

-
     def _kl_loss(self, dist, probs1_b, probs1=None):
         r"""
         Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
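For reference, the Geometric-Geometric divergence that _kl_loss evaluates has the closed form KL(a||b) = log(pa/pb) + ((1 - pa)/pa) * log((1 - pa)/(1 - pb)) on the support {0, 1, 2, ...}. A quick truncated-series check of that identity (a sketch, independent of MindSpore):

import numpy as np

pa, pb = 0.4, 0.6
closed = np.log(pa / pb) + (1 - pa) / pa * np.log((1 - pa) / (1 - pb))

k = np.arange(500)                             # the tail beyond 500 is negligible here
log_pmf_a = k * np.log(1 - pa) + np.log(pa)    # P_a(X = k) = (1 - pa)^k * pa
log_pmf_b = k * np.log(1 - pb) + np.log(pb)
series = np.sum(np.exp(log_pmf_a) * (log_pmf_a - log_pmf_b))

print(np.allclose(closed, series))  # True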
tests/ut/python/nn/distribution/test_bernoulli.py

@@ -22,6 +22,7 @@ import mindspore.nn.probability.distribution as msd
 from mindspore import dtype
 from mindspore import Tensor

+
 def test_arguments():
     """
     Args passing during initialization.

@@ -31,18 +32,22 @@ def test_arguments():
     b = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
     assert isinstance(b, msd.Distribution)

+
 def test_type():
     with pytest.raises(TypeError):
-        msd.Bernoulli([0.1], dtype=dtype.float32)
+        msd.Bernoulli([0.1], dtype=dtype.bool_)

+
 def test_name():
     with pytest.raises(TypeError):
         msd.Bernoulli([0.1], name=1.0)

+
 def test_seed():
     with pytest.raises(TypeError):
         msd.Bernoulli([0.1], seed='seed')

+
 def test_prob():
     """
     Invalid probability.

@@ -56,10 +61,12 @@ def test_prob():
     with pytest.raises(ValueError):
         msd.Bernoulli([1.0], dtype=dtype.int32)

+
 class BernoulliProb(nn.Cell):
     """
     Bernoulli distribution: initialize with probs.
     """
+
     def __init__(self):
         super(BernoulliProb, self).__init__()
         self.b = msd.Bernoulli(0.5, dtype=dtype.int32)

@@ -73,6 +80,7 @@ class BernoulliProb(nn.Cell):
         log_sf = self.b.log_survival(value)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_bernoulli_prob():
     """
     Test probability functions: passing value through construct.

@@ -82,10 +90,12 @@ def test_bernoulli_prob():
     ans = net(value)
     assert isinstance(ans, Tensor)

+
 class BernoulliProb1(nn.Cell):
     """
     Bernoulli distribution: initialize without probs.
     """
+
     def __init__(self):
         super(BernoulliProb1, self).__init__()
         self.b = msd.Bernoulli(dtype=dtype.int32)

@@ -99,6 +109,7 @@ class BernoulliProb1(nn.Cell):
         log_sf = self.b.log_survival(value, probs)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_bernoulli_prob1():
     """
     Test probability functions: passing value/probs through construct.

@@ -109,10 +120,12 @@ def test_bernoulli_prob1():
     ans = net(value, probs)
     assert isinstance(ans, Tensor)

+
 class BernoulliKl(nn.Cell):
     """
     Test class: kl_loss between Bernoulli distributions.
     """
+
     def __init__(self):
         super(BernoulliKl, self).__init__()
         self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)

@@ -123,6 +136,7 @@ class BernoulliKl(nn.Cell):
         kl2 = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
         return kl1 + kl2

+
 def test_kl():
     """
     Test kl_loss function.

@@ -133,10 +147,12 @@ def test_kl():
     ans = ber_net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class BernoulliCrossEntropy(nn.Cell):
     """
     Test class: cross_entropy of Bernoulli distribution.
     """
+
     def __init__(self):
         super(BernoulliCrossEntropy, self).__init__()
         self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)

@@ -147,6 +163,7 @@ class BernoulliCrossEntropy(nn.Cell):
         h2 = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
         return h1 + h2

+
 def test_cross_entropy():
     """
     Test cross_entropy between Bernoulli distributions.

@@ -157,10 +174,12 @@ def test_cross_entropy():
     ans = net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class BernoulliConstruct(nn.Cell):
     """
     Bernoulli distribution: going through construct.
     """
+
     def __init__(self):
         super(BernoulliConstruct, self).__init__()
         self.b = msd.Bernoulli(0.5, dtype=dtype.int32)

@@ -172,6 +191,7 @@ class BernoulliConstruct(nn.Cell):
         prob2 = self.b1('prob', value, probs)
         return prob + prob1 + prob2

+
 def test_bernoulli_construct():
     """
     Test probability function going through construct.

@@ -182,10 +202,12 @@ def test_bernoulli_construct():
     ans = net(value, probs)
     assert isinstance(ans, Tensor)

+
 class BernoulliMean(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliMean, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

@@ -194,6 +216,7 @@ class BernoulliMean(nn.Cell):
         mean = self.b.mean()
         return mean

+
 def test_mean():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.

@@ -202,10 +225,12 @@ def test_mean():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliSd(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliSd, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

@@ -214,6 +239,7 @@ class BernoulliSd(nn.Cell):
         sd = self.b.sd()
         return sd

+
 def test_sd():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.

@@ -222,10 +248,12 @@ def test_sd():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliVar(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliVar, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

@@ -234,6 +262,7 @@ class BernoulliVar(nn.Cell):
         var = self.b.var()
         return var

+
 def test_var():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.

@@ -242,10 +271,12 @@ def test_var():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliMode(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliMode, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

@@ -254,6 +285,7 @@ class BernoulliMode(nn.Cell):
         mode = self.b.mode()
         return mode

+
 def test_mode():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.

@@ -262,10 +294,12 @@ def test_mode():
     ans = net()
     assert isinstance(ans, Tensor)

+
 class BernoulliEntropy(nn.Cell):
     """
     Test class: basic mean/sd/var/mode/entropy function.
     """
+
     def __init__(self):
         super(BernoulliEntropy, self).__init__()
         self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

@@ -274,6 +308,7 @@ class BernoulliEntropy(nn.Cell):
         entropy = self.b.entropy()
         return entropy

+
 def test_entropy():
     """
     Test mean/sd/var/mode/entropy functionality of Bernoulli distribution.
tests/ut/python/nn/distribution/test_geometric.py

@@ -32,18 +32,22 @@ def test_arguments():
     g = msd.Geometric([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
     assert isinstance(g, msd.Distribution)

+
 def test_type():
     with pytest.raises(TypeError):
-        msd.Geometric([0.1], dtype=dtype.float32)
+        msd.Geometric([0.1], dtype=dtype.bool_)

+
 def test_name():
     with pytest.raises(TypeError):
         msd.Geometric([0.1], name=1.0)

+
 def test_seed():
     with pytest.raises(TypeError):
         msd.Geometric([0.1], seed='seed')

+
 def test_prob():
     """
     Invalid probability.

@@ -57,10 +61,12 @@ def test_prob():
     with pytest.raises(ValueError):
         msd.Geometric([1.0], dtype=dtype.int32)

+
 class GeometricProb(nn.Cell):
     """
     Geometric distribution: initialize with probs.
     """
+
     def __init__(self):
         super(GeometricProb, self).__init__()
         self.g = msd.Geometric(0.5, dtype=dtype.int32)

@@ -74,6 +80,7 @@ class GeometricProb(nn.Cell):
         log_sf = self.g.log_survival(value)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_geometric_prob():
     """
     Test probability functions: passing value through construct.

@@ -83,10 +90,12 @@ def test_geometric_prob():
     ans = net(value)
     assert isinstance(ans, Tensor)

+
 class GeometricProb1(nn.Cell):
     """
     Geometric distribution: initialize without probs.
     """
+
     def __init__(self):
         super(GeometricProb1, self).__init__()
         self.g = msd.Geometric(dtype=dtype.int32)

@@ -100,6 +109,7 @@ class GeometricProb1(nn.Cell):
         log_sf = self.g.log_survival(value, probs)
         return prob + log_prob + cdf + log_cdf + sf + log_sf

+
 def test_geometric_prob1():
     """
     Test probability functions: passing value/probs through construct.

@@ -115,6 +125,7 @@
     """
     Test class: kl_loss between Geometric distributions.
     """
+
     def __init__(self):
         super(GeometricKl, self).__init__()
         self.g1 = msd.Geometric(0.7, dtype=dtype.int32)

@@ -125,6 +136,7 @@ class GeometricKl(nn.Cell):
         kl2 = self.g2.kl_loss('Geometric', probs_b, probs_a)
         return kl1 + kl2

+
 def test_kl():
     """
     Test kl_loss function.

@@ -135,10 +147,12 @@ def test_kl():
     ans = ber_net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class GeometricCrossEntropy(nn.Cell):
     """
     Test class: cross_entropy of Geometric distribution.
     """
+
     def __init__(self):
         super(GeometricCrossEntropy, self).__init__()
         self.g1 = msd.Geometric(0.3, dtype=dtype.int32)

@@ -149,6 +163,7 @@ class GeometricCrossEntropy(nn.Cell):
         h2 = self.g2.cross_entropy('Geometric', probs_b, probs_a)
         return h1 + h2

+
 def test_cross_entropy():
     """
     Test cross_entropy between Geometric distributions.

@@ -159,10 +174,12 @@ def test_cross_entropy():
     ans = net(probs_b, probs_a)
     assert isinstance(ans, Tensor)

+
 class GeometricBasics(nn.Cell):
     """
     Test class: basic mean/sd/mode/entropy function.
     """
+
     def __init__(self):
         super(GeometricBasics, self).__init__()
         self.g = msd.Geometric([0.3, 0.5], dtype=dtype.int32)

@@ -175,6 +192,7 @@ class GeometricBasics(nn.Cell):
         entropy = self.g.entropy()
         return mean + sd + var + mode + entropy

+
 def test_bascis():
     """
     Test mean/sd/mode/entropy functionality of Geometric distribution.

@@ -188,6 +206,7 @@ class GeoConstruct(nn.Cell):
     """
     Bernoulli distribution: going through construct.
     """
+
     def __init__(self):
         super(GeoConstruct, self).__init__()
         self.g = msd.Geometric(0.5, dtype=dtype.int32)

@@ -199,6 +218,7 @@ class GeoConstruct(nn.Cell):
         prob2 = self.g1('prob', value, probs)
         return prob + prob1 + prob2

+
 def test_geo_construct():
     """
     Test probability function going through construct.
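The test_type changes in both test files mirror the constructor change: float dtypes are now accepted, so the expected TypeError is triggered with a genuinely invalid dtype instead. Restated minimally (a sketch of the intended behavior, runnable only with MindSpore installed):

import pytest
import mindspore.nn.probability.distribution as msd
from mindspore import dtype

msd.Bernoulli([0.1], dtype=dtype.float32)     # accepted now that float_type is valid
with pytest.raises(TypeError):
    msd.Bernoulli([0.1], dtype=dtype.bool_)   # still rejected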