Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e87e1fc6
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e87e1fc6
编写于
7月 29, 2020
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
changed distribution api
上级
6945eb28
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
753 addition
and
802 deletion
+753
-802
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+73
-91
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+49
-57
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+66
-81
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+80
-97
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+69
-88
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+80
-97
tests/st/ops/ascend/test_distribution/test_bernoulli.py
tests/st/ops/ascend/test_distribution/test_bernoulli.py
+13
-26
tests/st/ops/ascend/test_distribution/test_exponential.py
tests/st/ops/ascend/test_distribution/test_exponential.py
+13
-25
tests/st/ops/ascend/test_distribution/test_geometric.py
tests/st/ops/ascend/test_distribution/test_geometric.py
+13
-25
tests/st/ops/ascend/test_distribution/test_normal.py
tests/st/ops/ascend/test_distribution/test_normal.py
+50
-25
tests/st/ops/ascend/test_distribution/test_normal_new_api.py
tests/st/ops/ascend/test_distribution/test_normal_new_api.py
+0
-62
tests/st/ops/ascend/test_distribution/test_uniform.py
tests/st/ops/ascend/test_distribution/test_uniform.py
+13
-25
tests/ut/python/nn/distribution/test_bernoulli.py
tests/ut/python/nn/distribution/test_bernoulli.py
+46
-21
tests/ut/python/nn/distribution/test_exponential.py
tests/ut/python/nn/distribution/test_exponential.py
+47
-21
tests/ut/python/nn/distribution/test_geometric.py
tests/ut/python/nn/distribution/test_geometric.py
+47
-21
tests/ut/python/nn/distribution/test_normal.py
tests/ut/python/nn/distribution/test_normal.py
+47
-20
tests/ut/python/nn/distribution/test_uniform.py
tests/ut/python/nn/distribution/test_uniform.py
+47
-20
未找到文件。
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
e87e1fc6
...
@@ -34,55 +34,56 @@ class Bernoulli(Distribution):
...
@@ -34,55 +34,56 @@ class Bernoulli(Distribution):
Examples:
Examples:
>>> # To initialize a Bernoulli distribution of prob 0.5
>>> # To initialize a Bernoulli distribution of prob 0.5
>>> n = nn.Bernoulli(0.5, dtype=mstype.int32)
>>> import mindspore.nn.probability.distribution as msd
>>> b = msd.Bernoulli(0.5, dtype=mstype.int32)
>>>
>>>
>>> # The following creates two independent Bernoulli distributions
>>> # The following creates two independent Bernoulli distributions
>>>
n = nn
.Bernoulli([0.5, 0.5], dtype=mstype.int32)
>>>
b = msd
.Bernoulli([0.5, 0.5], dtype=mstype.int32)
>>>
>>>
>>> # A Bernoulli distribution can be initilized without arguments
>>> # A Bernoulli distribution can be initilized without arguments
>>> # In this case, probs must be passed in through
construct
.
>>> # In this case, probs must be passed in through
args during function calls
.
>>>
n = nn
.Bernoulli(dtype=mstype.int32)
>>>
b = msd
.Bernoulli(dtype=mstype.int32)
>>>
>>>
>>> # To use Bernoulli
distribution
in a network
>>> # To use Bernoulli in a network
>>> class net(Cell):
>>> class net(Cell):
>>> def __init__(self):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> super(net, self).__init__():
>>> self.b1 =
nn
.Bernoulli(0.5, dtype=mstype.int32)
>>> self.b1 =
msd
.Bernoulli(0.5, dtype=mstype.int32)
>>> self.b2 =
nn
.Bernoulli(dtype=mstype.int32)
>>> self.b2 =
msd
.Bernoulli(dtype=mstype.int32)
>>>
>>>
>>> # All the following calls in construct are valid
>>> # All the following calls in construct are valid
>>> def construct(self, value, probs_b, probs_a):
>>> def construct(self, value, probs_b, probs_a):
>>>
>>>
>>> # Similar calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> # by replacing 'prob' with the name of the function
>>> ans = self.b1
('prob',
value)
>>> ans = self.b1
.prob(
value)
>>> # Evaluate with the respect to distribution b
>>> # Evaluate with the respect to distribution b
>>> ans = self.b1
('prob',
value, probs_b)
>>> ans = self.b1
.prob(
value, probs_b)
>>>
>>>
>>> # probs must be passed in
through construct
>>> # probs must be passed in
during function calls
>>> ans = self.b2
('prob',
value, probs_a)
>>> ans = self.b2
.prob(
value, probs_a)
>>>
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage
like
'mean'
>>> # Functions 'sd', 'var', 'entropy' have the same usage
as
'mean'
>>> # Will return
[0.0]
>>> # Will return
0.5
>>> ans = self.b1
('mean'
)
>>> ans = self.b1
.mean(
)
>>> # Will return
mean
_b
>>> # Will return
probs
_b
>>> ans = self.b1
('mean',
probs_b)
>>> ans = self.b1
.mean(
probs_b)
>>>
>>>
>>> # probs must be passed in
through construct
>>> # probs must be passed in
during function calls
>>> ans = self.b2
('mean',
probs_a)
>>> ans = self.b2
.mean(
probs_a)
>>>
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.b1
('kl_loss',
'Bernoulli', probs_b)
>>> ans = self.b1
.kl_loss(
'Bernoulli', probs_b)
>>> ans = self.b1
('kl_loss',
'Bernoulli', probs_b, probs_a)
>>> ans = self.b1
.kl_loss(
'Bernoulli', probs_b, probs_a)
>>>
>>>
>>> # Additional probs_a must be passed in through
construct
>>> # Additional probs_a must be passed in through
>>> ans = self.b2
('kl_loss',
'Bernoulli', probs_b, probs_a)
>>> ans = self.b2
.kl_loss(
'Bernoulli', probs_b, probs_a)
>>>
>>>
>>> # Sample
Usage
>>> # Sample
>>> ans = self.b1
('sample'
)
>>> ans = self.b1
.sample(
)
>>> ans = self.b1
('sample',
(2,3))
>>> ans = self.b1
.sample(
(2,3))
>>> ans = self.b1
('sample',
(2,3), probs_b)
>>> ans = self.b1
.sample(
(2,3), probs_b)
>>> ans = self.b2
('sample',
(2,3), probs_a)
>>> ans = self.b2
.sample(
(2,3), probs_a)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -130,71 +131,61 @@ class Bernoulli(Distribution):
...
@@ -130,71 +131,61 @@ class Bernoulli(Distribution):
"""
"""
return
self
.
_probs
return
self
.
_probs
def
_mean
(
self
,
name
=
'mean'
,
probs1
=
None
):
def
_mean
(
self
,
probs1
=
None
):
r
"""
r
"""
.. math::
.. math::
MEAN(B) = probs1
MEAN(B) = probs1
"""
"""
if
name
==
'mean'
:
return
self
.
probs
if
probs1
is
None
else
probs1
return
self
.
probs
if
probs1
is
None
else
probs1
return
None
def
_mode
(
self
,
name
=
'mode'
,
probs1
=
None
):
def
_mode
(
self
,
probs1
=
None
):
r
"""
r
"""
.. math::
.. math::
MODE(B) = 1 if probs1 > 0.5 else = 0
MODE(B) = 1 if probs1 > 0.5 else = 0
"""
"""
if
name
==
'mode'
:
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
prob_type
=
self
.
dtypeop
(
probs1
)
prob_type
=
self
.
dtypeop
(
probs1
)
zeros
=
self
.
fill
(
prob_type
,
self
.
shape
(
probs1
),
0.0
)
zeros
=
self
.
fill
(
prob_type
,
self
.
shape
(
probs1
),
0.0
)
ones
=
self
.
fill
(
prob_type
,
self
.
shape
(
probs1
),
1.0
)
ones
=
self
.
fill
(
prob_type
,
self
.
shape
(
probs1
),
1.0
)
comp
=
self
.
less
(
0.5
,
probs1
)
comp
=
self
.
less
(
0.5
,
probs1
)
return
self
.
select
(
comp
,
ones
,
zeros
)
return
self
.
select
(
comp
,
ones
,
zeros
)
return
None
def
_var
(
self
,
name
=
'var'
,
probs1
=
None
):
def
_var
(
self
,
probs1
=
None
):
r
"""
r
"""
.. math::
.. math::
VAR(B) = probs1 * probs0
VAR(B) = probs1 * probs0
"""
"""
if
name
in
self
.
_variance_functions
:
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
probs0
=
1.0
-
probs1
probs0
=
1.0
-
probs1
return
probs0
*
probs1
return
probs0
*
probs1
return
None
def
_entropy
(
self
,
name
=
'entropy'
,
probs
=
None
):
def
_entropy
(
self
,
probs
=
None
):
r
"""
r
"""
.. math::
.. math::
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
H(B) = -probs0 * \log(probs0) - probs1 * \log(probs1)
"""
"""
if
name
==
'entropy'
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs0
=
1
-
probs1
probs0
=
1
-
probs1
return
-
1
*
(
probs0
*
self
.
log
(
probs0
))
-
(
probs1
*
self
.
log
(
probs1
))
return
-
1
*
(
probs0
*
self
.
log
(
probs0
))
-
(
probs1
*
self
.
log
(
probs1
))
return
None
def
_cross_entropy
(
self
,
name
,
dist
,
probs1_b
,
probs1_a
=
None
):
def
_cross_entropy
(
self
,
dist
,
probs1_b
,
probs1_a
=
None
):
"""
"""
Evaluate cross_entropy between Bernoulli distributions.
Evaluate cross_entropy between Bernoulli distributions.
Args:
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
"""
"""
if
name
==
'cross_entropy'
and
dist
==
'Bernoulli'
:
if
dist
==
'Bernoulli'
:
return
self
.
_entropy
(
probs
=
probs1_a
)
+
self
.
_kl_loss
(
name
,
dist
,
probs1_b
,
probs1_a
)
return
self
.
_entropy
(
probs
=
probs1_a
)
+
self
.
_kl_loss
(
dist
,
probs1_b
,
probs1_a
)
return
None
return
None
def
_prob
(
self
,
name
,
value
,
probs
=
None
):
def
_prob
(
self
,
value
,
probs
=
None
):
r
"""
r
"""
pmf of Bernoulli distribution.
pmf of Bernoulli distribution.
Args:
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default: self.probs.
probs (Tensor): probability of outcome is 1. Default: self.probs.
...
@@ -202,18 +193,15 @@ class Bernoulli(Distribution):
...
@@ -202,18 +193,15 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1;
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
pmf(k) = probs0 if k = 0;
"""
"""
if
name
in
self
.
_prob_functions
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs0
=
1.0
-
probs1
probs0
=
1.0
-
probs1
return
(
probs1
*
value
)
+
(
probs0
*
(
1.0
-
value
))
return
(
probs1
*
value
)
+
(
probs0
*
(
1.0
-
value
))
return
None
def
_cdf
(
self
,
name
,
value
,
probs
=
None
):
def
_cdf
(
self
,
value
,
probs
=
None
):
r
"""
r
"""
cdf of Bernoulli distribution.
cdf of Bernoulli distribution.
Args:
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
probs (Tensor): probability of outcome is 1. Default: self.probs.
probs (Tensor): probability of outcome is 1. Default: self.probs.
...
@@ -222,25 +210,22 @@ class Bernoulli(Distribution):
...
@@ -222,25 +210,22 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1;
cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1;
cdf(k) = 1 if k >=1;
"""
"""
if
name
in
self
.
_cdf_survival_functions
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
prob_type
=
self
.
dtypeop
(
probs1
)
prob_type
=
self
.
dtypeop
(
probs1
)
value
=
value
*
self
.
fill
(
prob_type
,
self
.
shape
(
probs1
),
1.0
)
value
=
value
*
self
.
fill
(
prob_type
,
self
.
shape
(
probs1
),
1.0
)
probs0
=
1.0
-
probs1
*
self
.
fill
(
prob_type
,
self
.
shape
(
value
),
1.0
)
probs0
=
1.0
-
probs1
*
self
.
fill
(
prob_type
,
self
.
shape
(
value
),
1.0
)
comp_zero
=
self
.
less
(
value
,
0.0
)
comp_zero
=
self
.
less
(
value
,
0.0
)
comp_one
=
self
.
less
(
value
,
1.0
)
comp_one
=
self
.
less
(
value
,
1.0
)
zeros
=
self
.
fill
(
prob_type
,
self
.
shape
(
value
),
0.0
)
zeros
=
self
.
fill
(
prob_type
,
self
.
shape
(
value
),
0.0
)
ones
=
self
.
fill
(
prob_type
,
self
.
shape
(
value
),
1.0
)
ones
=
self
.
fill
(
prob_type
,
self
.
shape
(
value
),
1.0
)
less_than_zero
=
self
.
select
(
comp_zero
,
zeros
,
probs0
)
less_than_zero
=
self
.
select
(
comp_zero
,
zeros
,
probs0
)
return
self
.
select
(
comp_one
,
less_than_zero
,
ones
)
return
self
.
select
(
comp_one
,
less_than_zero
,
ones
)
return
None
def
_kl_loss
(
self
,
name
,
dist
,
probs1_b
,
probs1_a
=
None
):
def
_kl_loss
(
self
,
dist
,
probs1_b
,
probs1_a
=
None
):
r
"""
r
"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
Args:
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
...
@@ -249,31 +234,28 @@ class Bernoulli(Distribution):
...
@@ -249,31 +234,28 @@ class Bernoulli(Distribution):
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
probs0_a * \log(\fract{probs0_a}{probs0_b})
"""
"""
if
name
in
self
.
_divergence_functions
and
dist
==
'Bernoulli'
:
if
dist
==
'Bernoulli'
:
probs1_a
=
self
.
probs
if
probs1_a
is
None
else
probs1_a
probs1_a
=
self
.
probs
if
probs1_a
is
None
else
probs1_a
probs0_a
=
1.0
-
probs1_a
probs0_a
=
1.0
-
probs1_a
probs0_b
=
1.0
-
probs1_b
probs0_b
=
1.0
-
probs1_b
return
probs1_a
*
self
.
log
(
probs1_a
/
probs1_b
)
+
probs0_a
*
self
.
log
(
probs0_a
/
probs0_b
)
return
probs1_a
*
self
.
log
(
probs1_a
/
probs1_b
)
+
probs0_a
*
self
.
log
(
probs0_a
/
probs0_b
)
return
None
return
None
def
_sample
(
self
,
name
,
shape
=
(),
probs
=
None
):
def
_sample
(
self
,
shape
=
(),
probs
=
None
):
"""
"""
Sampling.
Sampling.
Args:
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self.probs.
probs (Tensor): probs1 of the samples. Default: self.probs.
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
if
name
==
'sample'
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
l_zero
=
self
.
const
(
0.0
)
l_zero
=
self
.
const
(
0.0
)
h_one
=
self
.
const
(
1.0
)
h_one
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
shape
+
self
.
shape
(
probs1
),
l_zero
,
h_one
)
sample_uniform
=
self
.
uniform
(
shape
+
self
.
shape
(
probs1
),
l_zero
,
h_one
)
sample
=
self
.
less
(
sample_uniform
,
probs1
)
sample
=
self
.
less
(
sample_uniform
,
probs1
)
sample
=
self
.
cast
(
sample
,
self
.
dtype
)
sample
=
self
.
cast
(
sample
,
self
.
dtype
)
return
sample
return
sample
return
None
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
e87e1fc6
...
@@ -27,11 +27,7 @@ class Distribution(Cell):
...
@@ -27,11 +27,7 @@ class Distribution(Cell):
Note:
Note:
Derived class should override operations such as ,_mean, _prob,
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Functions should be called through construct when
and _log_prob. Arguments should be passed in through *args.
used inside a network. Arguments should be passed in through *args
in the form of function name followed by additional arguments.
Functions such as cdf and prob, require a value to be passed in while
functions such as mean and sd do not require arguments other than name.
Dist_spec_args are unique for each type of distribution. For example, mean and sd
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.
are the dist_spec_args for a Normal distribution.
...
@@ -73,11 +69,6 @@ class Distribution(Cell):
...
@@ -73,11 +69,6 @@ class Distribution(Cell):
self
.
_set_log_survival
()
self
.
_set_log_survival
()
self
.
_set_cross_entropy
()
self
.
_set_cross_entropy
()
self
.
_prob_functions
=
(
'prob'
,
'log_prob'
)
self
.
_cdf_survival_functions
=
(
'cdf'
,
'log_cdf'
,
'survival_function'
,
'log_survival'
)
self
.
_variance_functions
=
(
'var'
,
'sd'
)
self
.
_divergence_functions
=
(
'kl_loss'
,
'cross_entropy'
)
@
property
@
property
def
name
(
self
):
def
name
(
self
):
return
self
.
_name
return
self
.
_name
...
@@ -185,7 +176,7 @@ class Distribution(Cell):
...
@@ -185,7 +176,7 @@ class Distribution(Cell):
Evaluate the log probability(pdf or pmf) at the given value.
Evaluate the log probability(pdf or pmf) at the given value.
Note:
Note:
Args must include
name of the function and
value.
Args must include value.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_log_prob
(
*
args
)
return
self
.
_call_log_prob
(
*
args
)
...
@@ -204,7 +195,7 @@ class Distribution(Cell):
...
@@ -204,7 +195,7 @@ class Distribution(Cell):
Evaluate the probability (pdf or pmf) at given value.
Evaluate the probability (pdf or pmf) at given value.
Note:
Note:
Args must include
name of the function and
value.
Args must include value.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_prob
(
*
args
)
return
self
.
_call_prob
(
*
args
)
...
@@ -223,7 +214,7 @@ class Distribution(Cell):
...
@@ -223,7 +214,7 @@ class Distribution(Cell):
Evaluate the cdf at given value.
Evaluate the cdf at given value.
Note:
Note:
Args must include
name of the function and
value.
Args must include value.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_cdf
(
*
args
)
return
self
.
_call_cdf
(
*
args
)
...
@@ -260,7 +251,7 @@ class Distribution(Cell):
...
@@ -260,7 +251,7 @@ class Distribution(Cell):
Evaluate the log cdf at given value.
Evaluate the log cdf at given value.
Note:
Note:
Args must include
name of the function and
value.
Args must include value.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_log_cdf
(
*
args
)
return
self
.
_call_log_cdf
(
*
args
)
...
@@ -279,7 +270,7 @@ class Distribution(Cell):
...
@@ -279,7 +270,7 @@ class Distribution(Cell):
Evaluate the survival function at given value.
Evaluate the survival function at given value.
Note:
Note:
Args must include
name of the function and
value.
Args must include value.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_survival
(
*
args
)
return
self
.
_call_survival
(
*
args
)
...
@@ -307,7 +298,7 @@ class Distribution(Cell):
...
@@ -307,7 +298,7 @@ class Distribution(Cell):
Evaluate the log survival function at given value.
Evaluate the log survival function at given value.
Note:
Note:
Args must include
name of the function and
value.
Args must include value.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_log_survival
(
*
args
)
return
self
.
_call_log_survival
(
*
args
)
...
@@ -326,7 +317,7 @@ class Distribution(Cell):
...
@@ -326,7 +317,7 @@ class Distribution(Cell):
Evaluate the KL divergence, i.e. KL(a||b).
Evaluate the KL divergence, i.e. KL(a||b).
Note:
Note:
Args must include
name of the function,
type of the distribution, parameters of distribution b.
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
Parameters for distribution a are optional.
"""
"""
return
self
.
_kl_loss
(
*
args
)
return
self
.
_kl_loss
(
*
args
)
...
@@ -336,7 +327,7 @@ class Distribution(Cell):
...
@@ -336,7 +327,7 @@ class Distribution(Cell):
Evaluate the mean.
Evaluate the mean.
Note:
Note:
Args must include the name of function.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_mean
(
*
args
)
return
self
.
_mean
(
*
args
)
...
@@ -345,7 +336,7 @@ class Distribution(Cell):
...
@@ -345,7 +336,7 @@ class Distribution(Cell):
Evaluate the mode.
Evaluate the mode.
Note:
Note:
Args must include the name of function.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_mode
(
*
args
)
return
self
.
_mode
(
*
args
)
...
@@ -354,7 +345,7 @@ class Distribution(Cell):
...
@@ -354,7 +345,7 @@ class Distribution(Cell):
Evaluate the standard deviation.
Evaluate the standard deviation.
Note:
Note:
Args must include the name of function.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_sd
(
*
args
)
return
self
.
_call_sd
(
*
args
)
...
@@ -363,7 +354,7 @@ class Distribution(Cell):
...
@@ -363,7 +354,7 @@ class Distribution(Cell):
Evaluate the variance.
Evaluate the variance.
Note:
Note:
Args must include the name of function.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_call_var
(
*
args
)
return
self
.
_call_var
(
*
args
)
...
@@ -390,7 +381,7 @@ class Distribution(Cell):
...
@@ -390,7 +381,7 @@ class Distribution(Cell):
Evaluate the entropy.
Evaluate the entropy.
Note:
Note:
Args must include the name of function.
Dist_spec_args are optional.
Dist_spec_args are optional.
"""
"""
return
self
.
_entropy
(
*
args
)
return
self
.
_entropy
(
*
args
)
...
@@ -399,7 +390,7 @@ class Distribution(Cell):
...
@@ -399,7 +390,7 @@ class Distribution(Cell):
Evaluate the cross_entropy between distribution a and b.
Evaluate the cross_entropy between distribution a and b.
Note:
Note:
Args must include
name of the function,
type of the distribution, parameters of distribution b.
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
Parameters for distribution a are optional.
"""
"""
return
self
.
_call_cross_entropy
(
*
args
)
return
self
.
_call_cross_entropy
(
*
args
)
...
@@ -421,13 +412,13 @@ class Distribution(Cell):
...
@@ -421,13 +412,13 @@ class Distribution(Cell):
*args (list): arguments passed in through construct.
*args (list): arguments passed in through construct.
Note:
Note:
Args must include name of the function
.
Shape of the sample is default to ()
.
Shape of the sample and d
ist_spec_args are optional.
D
ist_spec_args are optional.
"""
"""
return
self
.
_sample
(
*
args
)
return
self
.
_sample
(
*
args
)
def
construct
(
self
,
*
input
s
):
def
construct
(
self
,
name
,
*
arg
s
):
"""
"""
Override construct in Cell.
Override construct in Cell.
...
@@ -437,35 +428,36 @@ class Distribution(Cell):
...
@@ -437,35 +428,36 @@ class Distribution(Cell):
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
'var', 'sd', 'entropy', 'kl_loss', 'cross_entropy', 'sample'.
Args:
Args:
*inputs (list): inputs[0] is always the name of the function.
name (str): name of the function.
"""
*args (list): list of arguments needed for the function.
"""
if
inputs
[
0
]
==
'log_prob'
:
return
self
.
_call_log_prob
(
*
inputs
)
if
name
==
'log_prob'
:
if
inputs
[
0
]
==
'prob'
:
return
self
.
_call_log_prob
(
*
args
)
return
self
.
_call_prob
(
*
inputs
)
if
name
==
'prob'
:
if
inputs
[
0
]
==
'cdf'
:
return
self
.
_call_prob
(
*
args
)
return
self
.
_call_cdf
(
*
inputs
)
if
name
==
'cdf'
:
if
inputs
[
0
]
==
'log_cdf'
:
return
self
.
_call_cdf
(
*
args
)
return
self
.
_call_log_cdf
(
*
inputs
)
if
name
==
'log_cdf'
:
if
inputs
[
0
]
==
'survival_function'
:
return
self
.
_call_log_cdf
(
*
args
)
return
self
.
_call_survival
(
*
inputs
)
if
name
==
'survival_function'
:
if
inputs
[
0
]
==
'log_survival'
:
return
self
.
_call_survival
(
*
args
)
return
self
.
_call_log_survival
(
*
inputs
)
if
name
==
'log_survival'
:
if
inputs
[
0
]
==
'kl_loss'
:
return
self
.
_call_log_survival
(
*
args
)
return
self
.
_kl_loss
(
*
inputs
)
if
name
==
'kl_loss'
:
if
inputs
[
0
]
==
'mean'
:
return
self
.
_kl_loss
(
*
args
)
return
self
.
_mean
(
*
inputs
)
if
name
==
'mean'
:
if
inputs
[
0
]
==
'mode'
:
return
self
.
_mean
(
*
args
)
return
self
.
_mode
(
*
inputs
)
if
name
==
'mode'
:
if
inputs
[
0
]
==
'sd'
:
return
self
.
_mode
(
*
args
)
return
self
.
_call_sd
(
*
inputs
)
if
name
==
'sd'
:
if
inputs
[
0
]
==
'var'
:
return
self
.
_call_sd
(
*
args
)
return
self
.
_call_var
(
*
inputs
)
if
name
==
'var'
:
if
inputs
[
0
]
==
'entropy'
:
return
self
.
_call_var
(
*
args
)
return
self
.
_entropy
(
*
inputs
)
if
name
==
'entropy'
:
if
inputs
[
0
]
==
'cross_entropy'
:
return
self
.
_entropy
(
*
args
)
return
self
.
_call_cross_entropy
(
*
inputs
)
if
name
==
'cross_entropy'
:
if
inputs
[
0
]
==
'sample'
:
return
self
.
_call_cross_entropy
(
*
args
)
return
self
.
_sample
(
*
inputs
)
if
name
==
'sample'
:
return
self
.
_sample
(
*
args
)
return
None
return
None
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
e87e1fc6
...
@@ -35,55 +35,56 @@ class Exponential(Distribution):
...
@@ -35,55 +35,56 @@ class Exponential(Distribution):
Examples:
Examples:
>>> # To initialize an Exponential distribution of rate 0.5
>>> # To initialize an Exponential distribution of rate 0.5
>>> n = nn.Exponential(0.5, dtype=mstype.float32)
>>> import mindspore.nn.probability.distribution as msd
>>> e = msd.Exponential(0.5, dtype=mstype.float32)
>>>
>>>
>>> # The following creates two independent Exponential distributions
>>> # The following creates two independent Exponential distributions
>>>
n = nn
.Exponential([0.5, 0.5], dtype=mstype.float32)
>>>
e = msd
.Exponential([0.5, 0.5], dtype=mstype.float32)
>>>
>>>
>>> # A Exponential distribution can be initilized without arguments
>>> # A
n
Exponential distribution can be initilized without arguments
>>> # In this case, rate must be passed in through
construct.
>>> # In this case, rate must be passed in through
args during function calls
>>>
n = nn
.Exponential(dtype=mstype.float32)
>>>
e = msd
.Exponential(dtype=mstype.float32)
>>>
>>>
>>> # To use Exponential
distribution
in a network
>>> # To use Exponential in a network
>>> class net(Cell):
>>> class net(Cell):
>>> def __init__(self):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> super(net, self).__init__():
>>> self.e1 =
nn
.Exponential(0.5, dtype=mstype.float32)
>>> self.e1 =
msd
.Exponential(0.5, dtype=mstype.float32)
>>> self.e2 =
nn
.Exponential(dtype=mstype.float32)
>>> self.e2 =
msd
.Exponential(dtype=mstype.float32)
>>>
>>>
>>> # All the following calls in construct are valid
>>> # All the following calls in construct are valid
>>> def construct(self, value, rate_b, rate_a):
>>> def construct(self, value, rate_b, rate_a):
>>>
>>>
>>> # Similar calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> # by replacing 'prob' with the name of the function
>>> ans = self.e1
('prob',
value)
>>> ans = self.e1
.prob(
value)
>>> # Evaluate with the respect to distribution b
>>> # Evaluate with the respect to distribution b
>>> ans = self.e1
('prob',
value, rate_b)
>>> ans = self.e1
.prob(
value, rate_b)
>>>
>>>
>>> # Rate must be passed in
through construct
>>> # Rate must be passed in
during function calls
>>> ans = self.e2
('prob',
value, rate_a)
>>> ans = self.e2
.prob(
value, rate_a)
>>>
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage
with
'mean'
>>> # Functions 'sd', 'var', 'entropy' have the same usage
as
'mean'
>>> # Will return
[0.0]
>>> # Will return
2
>>> ans = self.e1
('mean'
)
>>> ans = self.e1
.mean(
)
>>> # Will return
mean
_b
>>> # Will return
1 / rate
_b
>>> ans = self.e1
('mean',
rate_b)
>>> ans = self.e1
.mean(
rate_b)
>>>
>>>
>>> # Rate must be passed in
through construct
>>> # Rate must be passed in
during function calls
>>> ans = self.e2
('mean',
rate_a)
>>> ans = self.e2
.mean(
rate_a)
>>>
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.e1
('kl_loss',
'Exponential', rate_b)
>>> ans = self.e1
.kl_loss(
'Exponential', rate_b)
>>> ans = self.e1
('kl_loss',
'Exponential', rate_b, rate_a)
>>> ans = self.e1
.kl_loss(
'Exponential', rate_b, rate_a)
>>>
>>>
>>> # Additional rate must be passed in
through construct
>>> # Additional rate must be passed in
>>> ans = self.e2
('kl_loss',
'Exponential', rate_b, rate_a)
>>> ans = self.e2
.kl_loss(
'Exponential', rate_b, rate_a)
>>>
>>>
>>> # Sample
Usage
>>> # Sample
>>> ans = self.e1
('sample'
)
>>> ans = self.e1
.sample(
)
>>> ans = self.e1
('sample',
(2,3))
>>> ans = self.e1
.sample(
(2,3))
>>> ans = self.e1
('sample',
(2,3), rate_b)
>>> ans = self.e1
.sample(
(2,3), rate_b)
>>> ans = self.e2
('sample',
(2,3), rate_a)
>>> ans = self.e2
.sample(
(2,3), rate_a)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -131,67 +132,59 @@ class Exponential(Distribution):
...
@@ -131,67 +132,59 @@ class Exponential(Distribution):
"""
"""
return
self
.
_rate
return
self
.
_rate
def
_mean
(
self
,
name
=
'mean'
,
rate
=
None
):
def
_mean
(
self
,
rate
=
None
):
r
"""
r
"""
.. math::
.. math::
MEAN(EXP) = \fract{1.0}{\lambda}.
MEAN(EXP) = \fract{1.0}{\lambda}.
"""
"""
if
name
==
'mean'
:
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
return
1.0
/
rate
return
1.0
/
rate
return
None
def
_mode
(
self
,
name
=
'mode'
,
rate
=
None
):
def
_mode
(
self
,
rate
=
None
):
r
"""
r
"""
.. math::
.. math::
MODE(EXP) = 0.
MODE(EXP) = 0.
"""
"""
if
name
==
'mode'
:
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
return
self
.
fill
(
self
.
dtype
,
self
.
shape
(
rate
),
0.
)
return
self
.
fill
(
self
.
dtype
,
self
.
shape
(
rate
),
0.
)
return
None
def
_sd
(
self
,
name
=
'sd'
,
rate
=
None
):
def
_sd
(
self
,
rate
=
None
):
r
"""
r
"""
.. math::
.. math::
sd(EXP) = \fract{1.0}{\lambda}.
sd(EXP) = \fract{1.0}{\lambda}.
"""
"""
if
name
in
self
.
_variance_functions
:
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
return
1.0
/
rate
return
1.0
/
rate
return
None
def
_entropy
(
self
,
name
=
'entropy'
,
rate
=
None
):
def
_entropy
(
self
,
rate
=
None
):
r
"""
r
"""
.. math::
.. math::
H(Exp) = 1 - \log(\lambda).
H(Exp) = 1 - \log(\lambda).
"""
"""
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
if
name
==
'entropy'
:
return
1.0
-
self
.
log
(
rate
)
return
1.0
-
self
.
log
(
rate
)
return
None
def
_cross_entropy
(
self
,
name
,
dist
,
rate_b
,
rate_a
=
None
):
def
_cross_entropy
(
self
,
dist
,
rate_b
,
rate_a
=
None
):
"""
"""
Evaluate cross_entropy between Exponential distributions.
Evaluate cross_entropy between Exponential distributions.
Args:
Args:
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Exponential" in this case.
dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b.
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
"""
if
name
==
'cross_entropy'
and
dist
==
'Exponential'
:
if
dist
==
'Exponential'
:
return
self
.
_entropy
(
rate
=
rate_a
)
+
self
.
_kl_loss
(
name
,
dist
,
rate_b
,
rate_a
)
return
self
.
_entropy
(
rate
=
rate_a
)
+
self
.
_kl_loss
(
dist
,
rate_b
,
rate_a
)
return
None
return
None
def
_prob
(
self
,
name
,
value
,
rate
=
None
):
def
_prob
(
self
,
value
,
rate
=
None
):
r
"""
r
"""
pdf of Exponential distribution.
pdf of Exponential distribution.
Args:
Args:
Args:
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate.
rate (Tensor): rate of the distribution. Default: self.rate.
...
@@ -201,20 +194,17 @@ class Exponential(Distribution):
...
@@ -201,20 +194,17 @@ class Exponential(Distribution):
.. math::
.. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
"""
"""
if
name
in
self
.
_prob_functions
:
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
prob
=
rate
*
self
.
exp
(
-
1.
*
rate
*
value
)
prob
=
rate
*
self
.
exp
(
-
1.
*
rate
*
value
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
prob
),
self
.
shape
(
prob
),
0.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
prob
),
self
.
shape
(
prob
),
0.0
)
comp
=
self
.
less
(
value
,
zeros
)
comp
=
self
.
less
(
value
,
zeros
)
return
self
.
select
(
comp
,
zeros
,
prob
)
return
self
.
select
(
comp
,
zeros
,
prob
)
return
None
def
_cdf
(
self
,
name
,
value
,
rate
=
None
):
def
_cdf
(
self
,
value
,
rate
=
None
):
r
"""
r
"""
cdf of Exponential distribution.
cdf of Exponential distribution.
Args:
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
rate (Tensor): rate of the distribution. Default: self.rate.
rate (Tensor): rate of the distribution. Default: self.rate.
...
@@ -224,45 +214,40 @@ class Exponential(Distribution):
...
@@ -224,45 +214,40 @@ class Exponential(Distribution):
.. math::
.. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
"""
"""
if
name
in
self
.
_cdf_survival_functions
:
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
cdf
=
1.0
-
self
.
exp
(
-
1.
*
rate
*
value
)
cdf
=
1.0
-
self
.
exp
(
-
1.
*
rate
*
value
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
cdf
),
self
.
shape
(
cdf
),
0.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
cdf
),
self
.
shape
(
cdf
),
0.0
)
comp
=
self
.
less
(
value
,
zeros
)
comp
=
self
.
less
(
value
,
zeros
)
return
self
.
select
(
comp
,
zeros
,
cdf
)
return
self
.
select
(
comp
,
zeros
,
cdf
)
return
None
def
_kl_loss
(
self
,
name
,
dist
,
rate_b
,
rate_a
=
None
):
def
_kl_loss
(
self
,
dist
,
rate_b
,
rate_a
=
None
):
"""
"""
Evaluate exp-exp kl divergence, i.e. KL(a||b).
Evaluate exp-exp kl divergence, i.e. KL(a||b).
Args:
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Exponential" in this case.
dist (str): type of the distributions. Should be "Exponential" in this case.
rate_b (Tensor): rate of distribution b.
rate_b (Tensor): rate of distribution b.
rate_a (Tensor): rate of distribution a. Default: self.rate.
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
"""
if
name
in
self
.
_divergence_functions
and
dist
==
'Exponential'
:
if
dist
==
'Exponential'
:
rate_a
=
self
.
rate
if
rate_a
is
None
else
rate_a
rate_a
=
self
.
rate
if
rate_a
is
None
else
rate_a
return
self
.
log
(
rate_a
)
-
self
.
log
(
rate_b
)
+
rate_b
/
rate_a
-
1.0
return
self
.
log
(
rate_a
)
-
self
.
log
(
rate_b
)
+
rate_b
/
rate_a
-
1.0
return
None
return
None
def
_sample
(
self
,
name
,
shape
=
(),
rate
=
None
):
def
_sample
(
self
,
shape
=
(),
rate
=
None
):
"""
"""
Sampling.
Sampling.
Args:
Args:
name (str): name of the function.
shape (tuple): shape of the sample. Default: ().
shape (tuple): shape of the sample. Default: ().
rate (Tensor): rate of the distribution. Default: self.rate.
rate (Tensor): rate of the distribution. Default: self.rate.
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
if
name
==
'sample'
:
rate
=
self
.
rate
if
rate
is
None
else
rate
rate
=
self
.
rate
if
rate
is
None
else
rate
minval
=
self
.
const
(
self
.
minval
)
minval
=
self
.
const
(
self
.
minval
)
maxval
=
self
.
const
(
1.0
)
maxval
=
self
.
const
(
1.0
)
sample
=
self
.
uniform
(
shape
+
self
.
shape
(
rate
),
minval
,
maxval
)
sample
=
self
.
uniform
(
shape
+
self
.
shape
(
rate
),
minval
,
maxval
)
return
-
self
.
log
(
sample
)
/
rate
return
-
self
.
log
(
sample
)
/
rate
return
None
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
e87e1fc6
...
@@ -36,55 +36,56 @@ class Geometric(Distribution):
...
@@ -36,55 +36,56 @@ class Geometric(Distribution):
Examples:
Examples:
>>> # To initialize a Geometric distribution of prob 0.5
>>> # To initialize a Geometric distribution of prob 0.5
>>> n = nn.Geometric(0.5, dtype=mstype.int32)
>>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Geometric(0.5, dtype=mstype.int32)
>>>
>>>
>>> # The following creates two independent Geometric distributions
>>> # The following creates two independent Geometric distributions
>>> n =
nn
.Geometric([0.5, 0.5], dtype=mstype.int32)
>>> n =
msd
.Geometric([0.5, 0.5], dtype=mstype.int32)
>>>
>>>
>>> # A Geometric distribution can be initilized without arguments
>>> # A Geometric distribution can be initilized without arguments
>>> # In this case, probs must be passed in through
construct
.
>>> # In this case, probs must be passed in through
args during function calls
.
>>> n =
nn
.Geometric(dtype=mstype.int32)
>>> n =
msd
.Geometric(dtype=mstype.int32)
>>>
>>>
>>> # To use Geometric
distribution
in a network
>>> # To use Geometric in a network
>>> class net(Cell):
>>> class net(Cell):
>>> def __init__(self):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> super(net, self).__init__():
>>> self.g1 =
nn
.Geometric(0.5, dtype=mstype.int32)
>>> self.g1 =
msd
.Geometric(0.5, dtype=mstype.int32)
>>> self.g2 =
nn
.Geometric(dtype=mstype.int32)
>>> self.g2 =
msd
.Geometric(dtype=mstype.int32)
>>>
>>>
>>> # Tthe following calls are valid in construct
>>> # Tthe following calls are valid in construct
>>> def construct(self, value, probs_b, probs_a):
>>> def construct(self, value, probs_b, probs_a):
>>>
>>>
>>> # Similar calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> # by replacing 'prob' with the name of the function
>>> ans = self.g1
('prob',
value)
>>> ans = self.g1
.prob(
value)
>>> # Evaluate with the respect to distribution b
>>> # Evaluate with the respect to distribution b
>>> ans = self.g1
('prob',
value, probs_b)
>>> ans = self.g1
.prob(
value, probs_b)
>>>
>>>
>>> # Probs must be passed in
through construct
>>> # Probs must be passed in
during function calls
>>> ans = self.g2
('prob',
value, probs_a)
>>> ans = self.g2
.prob(
value, probs_a)
>>>
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage
with
'mean'
>>> # Functions 'sd', 'var', 'entropy' have the same usage
as
'mean'
>>> # Will return
[0.0]
>>> # Will return
1.0
>>> ans = self.g1
('mean'
)
>>> ans = self.g1
.mean(
)
>>> #
Will return mean_b
>>> #
Another possible usage
>>> ans = self.g1
('mean',
probs_b)
>>> ans = self.g1
.mean(
probs_b)
>>>
>>>
>>> # Probs must be passed in
through construct
>>> # Probs must be passed in
during function calls
>>> ans = self.g2
('mean',
probs_a)
>>> ans = self.g2
.mean(
probs_a)
>>>
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.g1
('kl_loss',
'Geometric', probs_b)
>>> ans = self.g1
.kl_loss(
'Geometric', probs_b)
>>> ans = self.g1
('kl_loss',
'Geometric', probs_b, probs_a)
>>> ans = self.g1
.kl_loss(
'Geometric', probs_b, probs_a)
>>>
>>>
>>> # Additional probs must be passed in
through construct
>>> # Additional probs must be passed in
>>> ans = self.g2
('kl_loss',
'Geometric', probs_b, probs_a)
>>> ans = self.g2
.kl_loss(
'Geometric', probs_b, probs_a)
>>>
>>>
>>> # Sample
Usage
>>> # Sample
>>> ans = self.g1
('sample'
)
>>> ans = self.g1
.sample(
)
>>> ans = self.g1
('sample',
(2,3))
>>> ans = self.g1
.sample(
(2,3))
>>> ans = self.g1
('sample',
(2,3), probs_b)
>>> ans = self.g1
.sample(
(2,3), probs_b)
>>> ans = self.g2
('sample',
(2,3), probs_a)
>>> ans = self.g2
.sample(
(2,3), probs_a)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -134,67 +135,57 @@ class Geometric(Distribution):
...
@@ -134,67 +135,57 @@ class Geometric(Distribution):
"""
"""
return
self
.
_probs
return
self
.
_probs
def
_mean
(
self
,
name
=
'mean'
,
probs1
=
None
):
def
_mean
(
self
,
probs1
=
None
):
r
"""
r
"""
.. math::
.. math::
MEAN(Geo) = \fratc{1 - probs1}{probs1}
MEAN(Geo) = \fratc{1 - probs1}{probs1}
"""
"""
if
name
==
'mean'
:
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
return
(
1.
-
probs1
)
/
probs1
return
(
1.
-
probs1
)
/
probs1
return
None
def
_mode
(
self
,
name
=
'mode'
,
probs1
=
None
):
def
_mode
(
self
,
probs1
=
None
):
r
"""
r
"""
.. math::
.. math::
MODE(Geo) = 0
MODE(Geo) = 0
"""
"""
if
name
==
'mode'
:
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
return
self
.
fill
(
self
.
dtypeop
(
probs1
),
self
.
shape
(
probs1
),
0.
)
return
self
.
fill
(
self
.
dtypeop
(
probs1
),
self
.
shape
(
probs1
),
0.
)
return
None
def
_var
(
self
,
name
=
'var'
,
probs1
=
None
):
def
_var
(
self
,
probs1
=
None
):
r
"""
r
"""
.. math::
.. math::
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
VAR(Geo) = \fract{1 - probs1}{probs1 ^ {2}}
"""
"""
if
name
in
self
.
_variance_functions
:
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
return
(
1.0
-
probs1
)
/
self
.
sq
(
probs1
)
return
(
1.0
-
probs1
)
/
self
.
sq
(
probs1
)
return
None
def
_entropy
(
self
,
name
=
'entropy'
,
probs
=
None
):
def
_entropy
(
self
,
probs
=
None
):
r
"""
r
"""
.. math::
.. math::
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
H(Geo) = \fract{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
"""
"""
if
name
==
'entropy'
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs0
=
1.0
-
probs1
probs0
=
1.0
-
probs1
return
(
-
probs0
*
self
.
log
(
probs0
)
-
probs1
*
self
.
log
(
probs1
))
/
probs1
return
(
-
probs0
*
self
.
log
(
probs0
)
-
probs1
*
self
.
log
(
probs1
))
/
probs1
return
None
def
_cross_entropy
(
self
,
name
,
dist
,
probs1_b
,
probs1_a
=
None
):
def
_cross_entropy
(
self
,
dist
,
probs1_b
,
probs1_a
=
None
):
r
"""
r
"""
Evaluate cross_entropy between Geometric distributions.
Evaluate cross_entropy between Geometric distributions.
Args:
Args:
name (str): name of the funtion. Should always be "cross_entropy" when passed in from construct.
dist (str): type of the distributions. Should be "Geometric" in this case.
dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b.
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
"""
"""
if
name
==
'cross_entropy'
and
dist
==
'Geometric'
:
if
dist
==
'Geometric'
:
return
self
.
_entropy
(
probs
=
probs1_a
)
+
self
.
_kl_loss
(
name
,
dist
,
probs1_b
,
probs1_a
)
return
self
.
_entropy
(
probs
=
probs1_a
)
+
self
.
_kl_loss
(
dist
,
probs1_b
,
probs1_a
)
return
None
return
None
def
_prob
(
self
,
name
,
value
,
probs
=
None
):
def
_prob
(
self
,
value
,
probs
=
None
):
r
"""
r
"""
pmf of Geometric distribution.
pmf of Geometric distribution.
Args:
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only natural numbers.
value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self.probs.
probs (Tensor): probability of success. Default: self.probs.
...
@@ -202,27 +193,24 @@ class Geometric(Distribution):
...
@@ -202,27 +193,24 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
pmf(k) = 0 if k < 0.
"""
"""
if
name
in
self
.
_prob_functions
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
dtype
=
self
.
dtypeop
(
value
)
dtype
=
self
.
dtypeop
(
value
)
if
self
.
issubclass
(
dtype
,
mstype
.
int_
):
if
self
.
issubclass
(
dtype
,
mstype
.
int_
):
pass
pass
elif
self
.
issubclass
(
dtype
,
mstype
.
float_
):
elif
self
.
issubclass
(
dtype
,
mstype
.
float_
):
value
=
self
.
floor
(
value
)
value
=
self
.
floor
(
value
)
else
:
else
:
return
None
return
None
pmf
=
self
.
pow
((
1.0
-
probs1
),
value
)
*
probs1
pmf
=
self
.
pow
((
1.0
-
probs1
),
value
)
*
probs1
zeros
=
self
.
fill
(
self
.
dtypeop
(
probs1
),
self
.
shape
(
pmf
),
0.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
probs1
),
self
.
shape
(
pmf
),
0.0
)
comp
=
self
.
less
(
value
,
zeros
)
comp
=
self
.
less
(
value
,
zeros
)
return
self
.
select
(
comp
,
zeros
,
pmf
)
return
self
.
select
(
comp
,
zeros
,
pmf
)
return
None
def
_cdf
(
self
,
name
,
value
,
probs
=
None
):
def
_cdf
(
self
,
value
,
probs
=
None
):
r
"""
r
"""
cdf of Geometric distribution.
cdf of Geometric distribution.
Args:
Args:
name (str): name of the function.
value (Tensor): a Tensor composed of only natural numbers.
value (Tensor): a Tensor composed of only natural numbers.
probs (Tensor): probability of success. Default: self.probs.
probs (Tensor): probability of success. Default: self.probs.
...
@@ -231,28 +219,26 @@ class Geometric(Distribution):
...
@@ -231,28 +219,26 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0.
cdf(k) = 0 if k < 0.
"""
"""
if
name
in
self
.
_cdf_survival_functions
:
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs0
=
1.0
-
probs1
probs0
=
1.0
-
probs1
dtype
=
self
.
dtypeop
(
value
)
dtype
=
self
.
dtypeop
(
value
)
if
self
.
issubclass
(
dtype
,
mstype
.
int_
):
if
self
.
issubclass
(
dtype
,
mstype
.
int_
):
pass
pass
elif
self
.
issubclass
(
dtype
,
mstype
.
float_
):
elif
self
.
issubclass
(
dtype
,
mstype
.
float_
):
value
=
self
.
floor
(
value
)
value
=
self
.
floor
(
value
)
else
:
else
:
return
None
return
None
cdf
=
1.0
-
self
.
pow
(
probs0
,
value
+
1.0
)
cdf
=
1.0
-
self
.
pow
(
probs0
,
value
+
1.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
probs1
),
self
.
shape
(
cdf
),
0.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
probs1
),
self
.
shape
(
cdf
),
0.0
)
comp
=
self
.
less
(
value
,
zeros
)
comp
=
self
.
less
(
value
,
zeros
)
return
self
.
select
(
comp
,
zeros
,
cdf
)
return
self
.
select
(
comp
,
zeros
,
cdf
)
return
None
def
_kl_loss
(
self
,
name
,
dist
,
probs1_b
,
probs1_a
=
None
):
def
_kl_loss
(
self
,
dist
,
probs1_b
,
probs1_a
=
None
):
r
"""
r
"""
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
Evaluate Geometric-Geometric kl divergence, i.e. KL(a||b).
Args:
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Geometric" in this case.
dist (str): type of the distributions. Should be "Geometric" in this case.
probs1_b (Tensor): probability of success of distribution b.
probs1_b (Tensor): probability of success of distribution b.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
...
@@ -260,29 +246,26 @@ class Geometric(Distribution):
...
@@ -260,29 +246,26 @@ class Geometric(Distribution):
.. math::
.. math::
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
KL(a||b) = \log(\fract{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract{probs0_a}{probs0_b})
"""
"""
if
name
in
self
.
_divergence_functions
and
dist
==
'Geometric'
:
if
dist
==
'Geometric'
:
probs1_a
=
self
.
probs
if
probs1_a
is
None
else
probs1_a
probs1_a
=
self
.
probs
if
probs1_a
is
None
else
probs1_a
probs0_a
=
1.0
-
probs1_a
probs0_a
=
1.0
-
probs1_a
probs0_b
=
1.0
-
probs1_b
probs0_b
=
1.0
-
probs1_b
return
self
.
log
(
probs1_a
/
probs1_b
)
+
(
probs0_a
/
probs1_a
)
*
self
.
log
(
probs0_a
/
probs0_b
)
return
self
.
log
(
probs1_a
/
probs1_b
)
+
(
probs0_a
/
probs1_a
)
*
self
.
log
(
probs0_a
/
probs0_b
)
return
None
return
None
def
_sample
(
self
,
name
,
shape
=
(),
probs
=
None
):
def
_sample
(
self
,
shape
=
(),
probs
=
None
):
"""
"""
Sampling.
Sampling.
Args:
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probability of success. Default: self.probs.
probs (Tensor): probability of success. Default: self.probs.
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
if
name
==
'sample'
:
probs
=
self
.
probs
if
probs
is
None
else
probs
probs
=
self
.
probs
if
probs
is
None
else
probs
minval
=
self
.
const
(
self
.
minval
)
minval
=
self
.
const
(
self
.
minval
)
maxval
=
self
.
const
(
1.0
)
maxval
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
shape
+
self
.
shape
(
probs
),
minval
,
maxval
)
sample_uniform
=
self
.
uniform
(
shape
+
self
.
shape
(
probs
),
minval
,
maxval
)
return
self
.
floor
(
self
.
log
(
sample_uniform
)
/
self
.
log
(
1.0
-
probs
))
return
self
.
floor
(
self
.
log
(
sample_uniform
)
/
self
.
log
(
1.0
-
probs
))
return
None
mindspore/nn/probability/distribution/normal.py
浏览文件 @
e87e1fc6
...
@@ -17,7 +17,6 @@ import numpy as np
...
@@ -17,7 +17,6 @@ import numpy as np
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
from
mindspore.context
import
get_context
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater_equal_zero
from
._utils.utils
import
convert_to_batch
,
check_greater_equal_zero
...
@@ -39,55 +38,56 @@ class Normal(Distribution):
...
@@ -39,55 +38,56 @@ class Normal(Distribution):
Examples:
Examples:
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
>>> # To initialize a Normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32)
>>> import mindspore.nn.probability.distribution as msd
>>> n = msd.Normal(3.0, 4.0, dtype=mstype.float32)
>>>
>>>
>>> # The following creates two independent Normal distributions
>>> # The following creates two independent Normal distributions
>>> n =
nn
.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>> n =
msd
.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32)
>>>
>>>
>>> # A
n
ormal distribution can be initilize without arguments
>>> # A
N
ormal distribution can be initilize without arguments
>>> # In this case, mean and sd must be passed in through
construct
.
>>> # In this case, mean and sd must be passed in through
args
.
>>> n =
nn
.Normal(dtype=mstype.float32)
>>> n =
msd
.Normal(dtype=mstype.float32)
>>>
>>>
>>> # To use
n
ormal in a network
>>> # To use
N
ormal in a network
>>> class net(Cell):
>>> class net(Cell):
>>> def __init__(self):
>>> def __init__(self):
>>> super(net, self).__init__():
>>> super(net, self).__init__():
>>> self.n1 =
nn.Norm
al(0.0, 1.0, dtype=mstype.float32)
>>> self.n1 =
msd.Nomr
al(0.0, 1.0, dtype=mstype.float32)
>>> self.n2 =
nn
.Normal(dtype=mstype.float32)
>>> self.n2 =
msd
.Normal(dtype=mstype.float32)
>>>
>>>
>>> # The following calls are valid in construct
>>> # The following calls are valid in construct
>>> def construct(self, value, mean_b, sd_b, mean_a, sd_a):
>>> def construct(self, value, mean_b, sd_b, mean_a, sd_a):
>>>
>>>
>>> # Similar calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> # by replacing 'prob' with the name of the function
>>> ans = self.n1
('prob',
value)
>>> ans = self.n1
.prob(
value)
>>> # Evaluate with the respect to distribution b
>>> # Evaluate with the respect to distribution b
>>> ans = self.n1
('prob',
value, mean_b, sd_b)
>>> ans = self.n1
.prob(
value, mean_b, sd_b)
>>>
>>>
>>> # mean and sd must be passed in
through construct
>>> # mean and sd must be passed in
during function calls
>>> ans = self.n2
('prob',
value, mean_a, sd_a)
>>> ans = self.n2
.prob(
value, mean_a, sd_a)
>>>
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage
with
'mean'
>>> # Functions 'sd', 'var', 'entropy' have the same usage
as
'mean'
>>> #
W
ill return [0.0]
>>> #
w
ill return [0.0]
>>> ans = self.n1
('mean'
)
>>> ans = self.n1
.mean(
)
>>> #
W
ill return mean_b
>>> #
w
ill return mean_b
>>> ans = self.n1
('mean',
mean_b, sd_b)
>>> ans = self.n1
.mean(
mean_b, sd_b)
>>>
>>>
>>> # mean and sd must be passed
in through construct
>>> # mean and sd must be passed
during function calls
>>> ans = self.n2
('mean',
mean_a, sd_a)
>>> ans = self.n2
.mean(
mean_a, sd_a)
>>>
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.n1
('kl_loss',
'Normal', mean_b, sd_b)
>>> ans = self.n1
.kl_loss(
'Normal', mean_b, sd_b)
>>> ans = self.n1
('kl_loss',
'Normal', mean_b, sd_b, mean_a, sd_a)
>>> ans = self.n1
.kl_loss(
'Normal', mean_b, sd_b, mean_a, sd_a)
>>>
>>>
>>> # Additional mean and sd must be passed
in through construct
>>> # Additional mean and sd must be passed
>>> ans = self.n2
('kl_loss',
'Normal', mean_b, sd_b, mean_a, sd_a)
>>> ans = self.n2
.kl_loss(
'Normal', mean_b, sd_b, mean_a, sd_a)
>>>
>>>
>>> # Sample
Usage
>>> # Sample
>>> ans = self.n1
('sample'
)
>>> ans = self.n1
.sample(
)
>>> ans = self.n1
('sample',
(2,3))
>>> ans = self.n1
.sample(
(2,3))
>>> ans = self.n1
('sample',
(2,3), mean_b, sd_b)
>>> ans = self.n1
.sample(
(2,3), mean_b, sd_b)
>>> ans = self.n2
('sample',
(2,3), mean_a, sd_a)
>>> ans = self.n2
.sample(
(2,3), mean_a, sd_a)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -114,7 +114,7 @@ class Normal(Distribution):
...
@@ -114,7 +114,7 @@ class Normal(Distribution):
self
.
const
=
P
.
ScalarToArray
()
self
.
const
=
P
.
ScalarToArray
()
self
.
erf
=
P
.
Erf
()
self
.
erf
=
P
.
Erf
()
self
.
exp
=
P
.
Exp
()
self
.
exp
=
P
.
Exp
()
self
.
expm1
=
P
.
Expm1
()
if
get_context
(
'device_target'
)
==
'Ascend'
else
self
.
_expm1_by_step
self
.
expm1
=
self
.
_expm1_by_step
self
.
fill
=
P
.
Fill
()
self
.
fill
=
P
.
Fill
()
self
.
log
=
P
.
Log
()
self
.
log
=
P
.
Log
()
self
.
shape
=
P
.
Shape
()
self
.
shape
=
P
.
Shape
()
...
@@ -135,67 +135,57 @@ class Normal(Distribution):
...
@@ -135,67 +135,57 @@ class Normal(Distribution):
"""
"""
return
self
.
exp
(
x
)
-
1.0
return
self
.
exp
(
x
)
-
1.0
def
_mean
(
self
,
name
=
'mean'
,
mean
=
None
,
sd
=
None
):
def
_mean
(
self
,
mean
=
None
,
sd
=
None
):
"""
"""
Mean of the distribution.
Mean of the distribution.
"""
"""
if
name
==
'mean'
:
mean
=
self
.
_mean_value
if
mean
is
None
or
sd
is
None
else
mean
mean
=
self
.
_mean_value
if
mean
is
None
or
sd
is
None
else
mean
return
mean
return
mean
return
None
def
_mode
(
self
,
name
=
'mode'
,
mean
=
None
,
sd
=
None
):
def
_mode
(
self
,
mean
=
None
,
sd
=
None
):
"""
"""
Mode of the distribution.
Mode of the distribution.
"""
"""
if
name
==
'mode'
:
mean
=
self
.
_mean_value
if
mean
is
None
or
sd
is
None
else
mean
mean
=
self
.
_mean_value
if
mean
is
None
or
sd
is
None
else
mean
return
mean
return
mean
return
None
def
_sd
(
self
,
name
=
'sd'
,
mean
=
None
,
sd
=
None
):
def
_sd
(
self
,
mean
=
None
,
sd
=
None
):
"""
"""
Standard deviation of the distribution.
Standard deviation of the distribution.
"""
"""
if
name
in
self
.
_variance_functions
:
sd
=
self
.
_sd_value
if
mean
is
None
or
sd
is
None
else
sd
sd
=
self
.
_sd_value
if
mean
is
None
or
sd
is
None
else
sd
return
sd
return
sd
return
None
def
_entropy
(
self
,
name
=
'entropy'
,
sd
=
None
):
def
_entropy
(
self
,
sd
=
None
):
r
"""
r
"""
Evaluate entropy.
Evaluate entropy.
.. math::
.. math::
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
H(X) = \log(\sqrt(numpy.e * 2. * numpy.pi * \sq(\sigma)))
"""
"""
if
name
==
'entropy'
:
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
return
self
.
log
(
self
.
sqrt
(
np
.
e
*
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
return
self
.
log
(
self
.
sqrt
(
np
.
e
*
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
return
None
def
_cross_entropy
(
self
,
name
,
dist
,
mean_b
,
sd_b
,
mean_a
=
None
,
sd_a
=
None
):
def
_cross_entropy
(
self
,
dist
,
mean_b
,
sd_b
,
mean_a
=
None
,
sd_a
=
None
):
r
"""
r
"""
Evaluate cross_entropy between normal distributions.
Evaluate cross_entropy between normal distributions.
Args:
Args:
name (str): name of the funtion passed in from construct. Should always be "cross_entropy".
dist (str): type of the distributions. Should be "Normal" in this case.
dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b.
sd_b (Tensor): standard deviation distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
"""
"""
if
name
==
'cross_entropy'
and
dist
==
'Normal'
:
if
dist
==
'Normal'
:
return
self
.
_entropy
(
sd
=
sd_a
)
+
self
.
_kl_loss
(
name
,
dist
,
mean_b
,
sd_b
,
mean_a
,
sd_a
)
return
self
.
_entropy
(
sd
=
sd_a
)
+
self
.
_kl_loss
(
dist
,
mean_b
,
sd_b
,
mean_a
,
sd_a
)
return
None
return
None
def
_log_prob
(
self
,
name
,
value
,
mean
=
None
,
sd
=
None
):
def
_log_prob
(
self
,
value
,
mean
=
None
,
sd
=
None
):
r
"""
r
"""
Evaluate log probability.
Evaluate log probability.
Args:
Args:
name (str): name of the funtion passed in from construct.
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value.
mean (Tensor): mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
...
@@ -203,20 +193,17 @@ class Normal(Distribution):
...
@@ -203,20 +193,17 @@ class Normal(Distribution):
.. math::
.. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
"""
if
name
in
self
.
_prob_functions
:
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
unnormalized_log_prob
=
-
1.
*
(
self
.
sq
(
value
-
mean
))
/
(
2.
*
self
.
sq
(
sd
))
unnormalized_log_prob
=
-
1.
*
(
self
.
sq
(
value
-
mean
))
/
(
2.
*
self
.
sq
(
sd
))
neg_normalization
=
-
1.
*
self
.
log
(
self
.
sqrt
(
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
neg_normalization
=
-
1.
*
self
.
log
(
self
.
sqrt
(
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
return
unnormalized_log_prob
+
neg_normalization
return
unnormalized_log_prob
+
neg_normalization
return
None
def
_cdf
(
self
,
name
,
value
,
mean
=
None
,
sd
=
None
):
def
_cdf
(
self
,
value
,
mean
=
None
,
sd
=
None
):
r
"""
r
"""
Evaluate cdf of given value.
Evaluate cdf of given value.
Args:
Args:
name (str): name of the funtion passed in from construct. Should always be "cdf".
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
mean (Tensor): mean of the distribution. Default: self._mean_value.
mean (Tensor): mean of the distribution. Default: self._mean_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
...
@@ -224,20 +211,17 @@ class Normal(Distribution):
...
@@ -224,20 +211,17 @@ class Normal(Distribution):
.. math::
.. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
"""
"""
if
name
in
self
.
_cdf_survival_functions
:
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
sqrt2
=
self
.
sqrt
(
self
.
const
(
2.0
))
sqrt2
=
self
.
sqrt
(
self
.
const
(
2.0
))
adjusted
=
(
value
-
mean
)
/
(
sd
*
sqrt2
)
adjusted
=
(
value
-
mean
)
/
(
sd
*
sqrt2
)
return
0.5
*
(
1.0
+
self
.
erf
(
adjusted
))
return
0.5
*
(
1.0
+
self
.
erf
(
adjusted
))
return
None
def
_kl_loss
(
self
,
name
,
dist
,
mean_b
,
sd_b
,
mean_a
=
None
,
sd_a
=
None
):
def
_kl_loss
(
self
,
dist
,
mean_b
,
sd_b
,
mean_a
=
None
,
sd_a
=
None
):
r
"""
r
"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args:
Args:
name (str): name of the funtion passed in from construct.
dist (str): type of the distributions. Should be "Normal" in this case.
dist (str): type of the distributions. Should be "Normal" in this case.
mean_b (Tensor): mean of distribution b.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b.
sd_b (Tensor): standard deviation distribution b.
...
@@ -248,7 +232,7 @@ class Normal(Distribution):
...
@@ -248,7 +232,7 @@ class Normal(Distribution):
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
"""
if
name
in
self
.
_divergence_functions
and
dist
==
'Normal'
:
if
dist
==
'Normal'
:
mean_a
=
self
.
_mean_value
if
mean_a
is
None
else
mean_a
mean_a
=
self
.
_mean_value
if
mean_a
is
None
else
mean_a
sd_a
=
self
.
_sd_value
if
sd_a
is
None
else
sd_a
sd_a
=
self
.
_sd_value
if
sd_a
is
None
else
sd_a
diff_log_scale
=
self
.
log
(
sd_a
)
-
self
.
log
(
sd_b
)
diff_log_scale
=
self
.
log
(
sd_a
)
-
self
.
log
(
sd_b
)
...
@@ -256,12 +240,11 @@ class Normal(Distribution):
...
@@ -256,12 +240,11 @@ class Normal(Distribution):
return
0.5
*
squared_diff
+
0.5
*
self
.
expm1
(
2
*
diff_log_scale
)
-
diff_log_scale
return
0.5
*
squared_diff
+
0.5
*
self
.
expm1
(
2
*
diff_log_scale
)
-
diff_log_scale
return
None
return
None
def
_sample
(
self
,
name
,
shape
=
(),
mean
=
None
,
sd
=
None
):
def
_sample
(
self
,
shape
=
(),
mean
=
None
,
sd
=
None
):
"""
"""
Sampling.
Sampling.
Args:
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
shape (tuple): shape of the sample. Default: ().
mean (Tensor): mean of the samples. Default: self._mean_value.
mean (Tensor): mean of the samples. Default: self._mean_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
...
@@ -269,14 +252,12 @@ class Normal(Distribution):
...
@@ -269,14 +252,12 @@ class Normal(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
if
name
==
'sample'
:
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
batch_shape
=
self
.
shape
(
self
.
zeroslike
(
mean
)
+
self
.
zeroslike
(
sd
))
batch_shape
=
self
.
shape
(
self
.
zeroslike
(
mean
)
+
self
.
zeroslike
(
sd
))
sample_shape
=
shape
+
batch_shape
sample_shape
=
shape
+
batch_shape
mean_zero
=
self
.
const
(
0.0
)
mean_zero
=
self
.
const
(
0.0
)
sd_one
=
self
.
const
(
1.0
)
sd_one
=
self
.
const
(
1.0
)
sample_norm
=
C
.
normal
(
sample_shape
,
mean_zero
,
sd_one
,
self
.
seed
)
sample_norm
=
C
.
normal
(
sample_shape
,
mean_zero
,
sd_one
,
self
.
seed
)
sample
=
mean
+
sample_norm
*
sd
sample
=
mean
+
sample_norm
*
sd
return
sample
return
sample
return
None
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
e87e1fc6
...
@@ -35,55 +35,56 @@ class Uniform(Distribution):
...
@@ -35,55 +35,56 @@ class Uniform(Distribution):
Examples:
Examples:
>>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
>>> # To initialize a Uniform distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> import mindspore.nn.probability.distribution as msd
>>> u = msd.Uniform(0.0, 1.0, dtype=mstype.float32)
>>>
>>>
>>> # The following creates two independent Uniform distributions
>>> # The following creates two independent Uniform distributions
>>>
n = nn
.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
>>>
u = msd
.Uniform([0.0, 0.0], [1.0, 2.0], dtype=mstype.float32)
>>>
>>>
>>> # A Uniform distribution can be initilized without arguments
>>> # A Uniform distribution can be initilized without arguments
>>> # In this case, high and low must be passed in through
construct
.
>>> # In this case, high and low must be passed in through
args during function calls
.
>>>
n = nn
.Uniform(dtype=mstype.float32)
>>>
u = msd
.Uniform(dtype=mstype.float32)
>>>
>>>
>>> # To use Uniform in a network
>>> # To use Uniform in a network
>>> class net(Cell):
>>> class net(Cell):
>>> def __init__(self)
>>> def __init__(self)
>>> super(net, self).__init__():
>>> super(net, self).__init__():
>>> self.u1 =
nn
.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> self.u1 =
msd
.Uniform(0.0, 1.0, dtype=mstype.float32)
>>> self.u2 =
nn
.Uniform(dtype=mstype.float32)
>>> self.u2 =
msd
.Uniform(dtype=mstype.float32)
>>>
>>>
>>> # All the following calls in construct are valid
>>> # All the following calls in construct are valid
>>> def construct(self, value, low_b, high_b, low_a, high_a):
>>> def construct(self, value, low_b, high_b, low_a, high_a):
>>>
>>>
>>> # Similar calls can be made to other probability functions
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'prob' with the name of the function
>>> # by replacing 'prob' with the name of the function
>>> ans = self.u1
('prob',
value)
>>> ans = self.u1
.prob(
value)
>>> # Evaluate with the respect to distribution b
>>> # Evaluate with the respect to distribution b
>>> ans = self.u1
('prob',
value, low_b, high_b)
>>> ans = self.u1
.prob(
value, low_b, high_b)
>>>
>>>
>>> # High and low must be passed in
through construct
>>> # High and low must be passed in
during function calls
>>> ans = self.u2
('prob',
value, low_a, high_a)
>>> ans = self.u2
.prob(
value, low_a, high_a)
>>>
>>>
>>> # Functions 'sd', 'var', 'entropy' have the same usage
with
'mean'
>>> # Functions 'sd', 'var', 'entropy' have the same usage
as
'mean'
>>> # Will return
[0.0]
>>> # Will return
0.5
>>> ans = self.u1
('mean'
)
>>> ans = self.u1
.mean(
)
>>> # Will return
low_b
>>> # Will return
(low_b + high_b) / 2
>>> ans = self.u1
('mean',
low_b, high_b)
>>> ans = self.u1
.mean(
low_b, high_b)
>>>
>>>
>>> # High and low must be passed in
through construct
>>> # High and low must be passed in
during function calls
>>> ans = self.u2
('mean',
low_a, high_a)
>>> ans = self.u2
.mean(
low_a, high_a)
>>>
>>>
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> # Usage of 'kl_loss' and 'cross_entropy' are similar
>>> ans = self.u1
('kl_loss',
'Uniform', low_b, high_b)
>>> ans = self.u1
.kl_loss(
'Uniform', low_b, high_b)
>>> ans = self.u1
('kl_loss',
'Uniform', low_b, high_b, low_a, high_a)
>>> ans = self.u1
.kl_loss(
'Uniform', low_b, high_b, low_a, high_a)
>>>
>>>
>>> # Additional high and low must be passed
in through construct
>>> # Additional high and low must be passed
>>> ans = self.u2
('kl_loss',
'Uniform', low_b, high_b, low_a, high_a)
>>> ans = self.u2
.kl_loss(
'Uniform', low_b, high_b, low_a, high_a)
>>>
>>>
>>> # Sample
Usage
>>> # Sample
>>> ans = self.u1
('sample'
)
>>> ans = self.u1
.sample(
)
>>> ans = self.u1
('sample',
(2,3))
>>> ans = self.u1
.sample(
(2,3))
>>> ans = self.u1
('sample',
(2,3), low_b, high_b)
>>> ans = self.u1
.sample(
(2,3), low_b, high_b)
>>> ans = self.u2
('sample',
(2,3), low_a, high_a)
>>> ans = self.u2
.sample(
(2,3), low_a, high_a)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -142,73 +143,64 @@ class Uniform(Distribution):
...
@@ -142,73 +143,64 @@ class Uniform(Distribution):
"""
"""
return
self
.
_high
return
self
.
_high
def
_range
(
self
,
name
=
'range'
,
low
=
None
,
high
=
None
):
def
_range
(
self
,
low
=
None
,
high
=
None
):
r
"""
r
"""
Return the range of the distribution.
Return the range of the distribution.
.. math::
.. math::
range(U) = high -low
range(U) = high -low
"""
"""
if
name
==
'range'
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
return
high
-
low
return
high
-
low
return
None
def
_mean
(
self
,
name
=
'mean'
,
low
=
None
,
high
=
None
):
def
_mean
(
self
,
low
=
None
,
high
=
None
):
r
"""
r
"""
.. math::
.. math::
MEAN(U) = \fract{low + high}{2}.
MEAN(U) = \fract{low + high}{2}.
"""
"""
if
name
==
'mean'
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
return
(
low
+
high
)
/
2.
return
(
low
+
high
)
/
2.
return
None
def
_var
(
self
,
name
=
'var'
,
low
=
None
,
high
=
None
):
def
_var
(
self
,
low
=
None
,
high
=
None
):
r
"""
r
"""
.. math::
.. math::
VAR(U) = \fract{(high -low) ^ 2}{12}.
VAR(U) = \fract{(high -low) ^ 2}{12}.
"""
"""
if
name
in
self
.
_variance_functions
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
return
self
.
sq
(
high
-
low
)
/
12.0
return
self
.
sq
(
high
-
low
)
/
12.0
return
None
def
_entropy
(
self
,
name
=
'entropy'
,
low
=
None
,
high
=
None
):
def
_entropy
(
self
,
low
=
None
,
high
=
None
):
r
"""
r
"""
.. math::
.. math::
H(U) = \log(high - low).
H(U) = \log(high - low).
"""
"""
if
name
==
'entropy'
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
return
self
.
log
(
high
-
low
)
return
self
.
log
(
high
-
low
)
return
None
def
_cross_entropy
(
self
,
name
,
dist
,
low_b
,
high_b
,
low_a
=
None
,
high_a
=
None
):
def
_cross_entropy
(
self
,
dist
,
low_b
,
high_b
,
low_a
=
None
,
high_a
=
None
):
"""
"""
Evaluate cross_entropy between Uniform distributoins.
Evaluate cross_entropy between Uniform distributoins.
Args:
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Uniform" in this case.
dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b.
low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b.
high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self.low.
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
"""
if
name
==
'cross_entropy'
and
dist
==
'Uniform'
:
if
dist
==
'Uniform'
:
return
self
.
_entropy
(
low
=
low_a
,
high
=
high_a
)
+
self
.
_kl_loss
(
name
,
dist
,
low_b
,
high_b
,
low_a
,
high_a
)
return
self
.
_entropy
(
low
=
low_a
,
high
=
high_a
)
+
self
.
_kl_loss
(
dist
,
low_b
,
high_b
,
low_a
,
high_a
)
return
None
return
None
def
_prob
(
self
,
name
,
value
,
low
=
None
,
high
=
None
):
def
_prob
(
self
,
value
,
low
=
None
,
high
=
None
):
r
"""
r
"""
pdf of Uniform distribution.
pdf of Uniform distribution.
Args:
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self.low.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.
high (Tensor): upper bound of the distribution. Default: self.high.
...
@@ -218,32 +210,29 @@ class Uniform(Distribution):
...
@@ -218,32 +210,29 @@ class Uniform(Distribution):
pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
pdf(x) = \fract{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
pdf(x) = 0 if x > high;
"""
"""
if
name
in
self
.
_prob_functions
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
ones
=
self
.
fill
(
self
.
dtype
,
self
.
shape
(
value
),
1.0
)
ones
=
self
.
fill
(
self
.
dtype
,
self
.
shape
(
value
),
1.0
)
prob
=
ones
/
(
high
-
low
)
prob
=
ones
/
(
high
-
low
)
broadcast_shape
=
self
.
shape
(
prob
)
broadcast_shape
=
self
.
shape
(
prob
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
prob
),
broadcast_shape
,
0.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
prob
),
broadcast_shape
,
0.0
)
comp_lo
=
self
.
less
(
value
,
low
)
comp_lo
=
self
.
less
(
value
,
low
)
comp_hi
=
self
.
lessequal
(
value
,
high
)
comp_hi
=
self
.
lessequal
(
value
,
high
)
less_than_low
=
self
.
select
(
comp_lo
,
zeros
,
prob
)
less_than_low
=
self
.
select
(
comp_lo
,
zeros
,
prob
)
return
self
.
select
(
comp_hi
,
less_than_low
,
zeros
)
return
self
.
select
(
comp_hi
,
less_than_low
,
zeros
)
return
None
def
_kl_loss
(
self
,
name
,
dist
,
low_b
,
high_b
,
low_a
=
None
,
high_a
=
None
):
def
_kl_loss
(
self
,
dist
,
low_b
,
high_b
,
low_a
=
None
,
high_a
=
None
):
"""
"""
Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
Evaluate uniform-uniform kl divergence, i.e. KL(a||b).
Args:
Args:
name (str): name of the funtion.
dist (str): type of the distributions. Should be "Uniform" in this case.
dist (str): type of the distributions. Should be "Uniform" in this case.
low_b (Tensor): lower bound of distribution b.
low_b (Tensor): lower bound of distribution b.
high_b (Tensor): upper bound of distribution b.
high_b (Tensor): upper bound of distribution b.
low_a (Tensor): lower bound of distribution a. Default: self.low.
low_a (Tensor): lower bound of distribution a. Default: self.low.
high_a (Tensor): upper bound of distribution a. Default: self.high.
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
"""
if
name
in
self
.
_divergence_functions
and
dist
==
'Uniform'
:
if
dist
==
'Uniform'
:
low_a
=
self
.
low
if
low_a
is
None
else
low_a
low_a
=
self
.
low
if
low_a
is
None
else
low_a
high_a
=
self
.
high
if
high_a
is
None
else
high_a
high_a
=
self
.
high
if
high_a
is
None
else
high_a
kl
=
self
.
log
(
high_b
-
low_b
)
/
self
.
log
(
high_a
-
low_a
)
kl
=
self
.
log
(
high_b
-
low_b
)
/
self
.
log
(
high_a
-
low_a
)
...
@@ -251,12 +240,11 @@ class Uniform(Distribution):
...
@@ -251,12 +240,11 @@ class Uniform(Distribution):
return
self
.
select
(
comp
,
kl
,
self
.
log
(
self
.
zeroslike
(
kl
)))
return
self
.
select
(
comp
,
kl
,
self
.
log
(
self
.
zeroslike
(
kl
)))
return
None
return
None
def
_cdf
(
self
,
name
,
value
,
low
=
None
,
high
=
None
):
def
_cdf
(
self
,
value
,
low
=
None
,
high
=
None
):
r
"""
r
"""
cdf of Uniform distribution.
cdf of Uniform distribution.
Args:
Args:
name (str): name of the function.
value (Tensor): value to be evaluated.
value (Tensor): value to be evaluated.
low (Tensor): lower bound of the distribution. Default: self.low.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.
high (Tensor): upper bound of the distribution. Default: self.high.
...
@@ -266,25 +254,22 @@ class Uniform(Distribution):
...
@@ -266,25 +254,22 @@ class Uniform(Distribution):
cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
cdf(x) = \fract{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
cdf(x) = 1 if x > high;
"""
"""
if
name
in
self
.
_cdf_survival_functions
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
prob
=
(
value
-
low
)
/
(
high
-
low
)
prob
=
(
value
-
low
)
/
(
high
-
low
)
broadcast_shape
=
self
.
shape
(
prob
)
broadcast_shape
=
self
.
shape
(
prob
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
prob
),
broadcast_shape
,
0.0
)
zeros
=
self
.
fill
(
self
.
dtypeop
(
prob
),
broadcast_shape
,
0.0
)
ones
=
self
.
fill
(
self
.
dtypeop
(
prob
),
broadcast_shape
,
1.0
)
ones
=
self
.
fill
(
self
.
dtypeop
(
prob
),
broadcast_shape
,
1.0
)
comp_lo
=
self
.
less
(
value
,
low
)
comp_lo
=
self
.
less
(
value
,
low
)
comp_hi
=
self
.
less
(
value
,
high
)
comp_hi
=
self
.
less
(
value
,
high
)
less_than_low
=
self
.
select
(
comp_lo
,
zeros
,
prob
)
less_than_low
=
self
.
select
(
comp_lo
,
zeros
,
prob
)
return
self
.
select
(
comp_hi
,
less_than_low
,
ones
)
return
self
.
select
(
comp_hi
,
less_than_low
,
ones
)
return
None
def
_sample
(
self
,
name
,
shape
=
(),
low
=
None
,
high
=
None
):
def
_sample
(
self
,
shape
=
(),
low
=
None
,
high
=
None
):
"""
"""
Sampling.
Sampling.
Args:
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
shape (tuple): shape of the sample. Default: ().
low (Tensor): lower bound of the distribution. Default: self.low.
low (Tensor): lower bound of the distribution. Default: self.low.
high (Tensor): upper bound of the distribution. Default: self.high.
high (Tensor): upper bound of the distribution. Default: self.high.
...
@@ -292,13 +277,11 @@ class Uniform(Distribution):
...
@@ -292,13 +277,11 @@ class Uniform(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
if
name
==
'sample'
:
low
=
self
.
low
if
low
is
None
else
low
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
high
=
self
.
high
if
high
is
None
else
high
broadcast_shape
=
self
.
shape
(
low
+
high
)
broadcast_shape
=
self
.
shape
(
low
+
high
)
l_zero
=
self
.
const
(
0.0
)
l_zero
=
self
.
const
(
0.0
)
h_one
=
self
.
const
(
1.0
)
h_one
=
self
.
const
(
1.0
)
sample_uniform
=
self
.
uniform
(
shape
+
broadcast_shape
,
l_zero
,
h_one
)
sample_uniform
=
self
.
uniform
(
shape
+
broadcast_shape
,
l_zero
,
h_one
)
sample
=
(
high
-
low
)
*
sample_uniform
+
low
sample
=
(
high
-
low
)
*
sample_uniform
+
low
return
sample
return
sample
return
None
tests/st/ops/ascend/test_distribution/test_bernoulli.py
浏览文件 @
e87e1fc6
...
@@ -19,7 +19,6 @@ import mindspore.context as context
...
@@ -19,7 +19,6 @@ import mindspore.context as context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.nn.probability.distribution
as
msd
import
mindspore.nn.probability.distribution
as
msd
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super
(
Prob
,
self
).
__init__
()
super
(
Prob
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'prob'
,
x_
)
return
self
.
b
.
prob
(
x_
)
def
test_pmf
():
def
test_pmf
():
"""
"""
...
@@ -57,9 +55,8 @@ class LogProb(nn.Cell):
...
@@ -57,9 +55,8 @@ class LogProb(nn.Cell):
super
(
LogProb
,
self
).
__init__
()
super
(
LogProb
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'log_prob'
,
x_
)
return
self
.
b
.
log_prob
(
x_
)
def
test_log_likelihood
():
def
test_log_likelihood
():
"""
"""
...
@@ -81,9 +78,8 @@ class KL(nn.Cell):
...
@@ -81,9 +78,8 @@ class KL(nn.Cell):
super
(
KL
,
self
).
__init__
()
super
(
KL
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'kl_loss'
,
'Bernoulli'
,
x_
)
return
self
.
b
.
kl_loss
(
'Bernoulli'
,
x_
)
def
test_kl_loss
():
def
test_kl_loss
():
"""
"""
...
@@ -107,9 +103,8 @@ class Basics(nn.Cell):
...
@@ -107,9 +103,8 @@ class Basics(nn.Cell):
super
(
Basics
,
self
).
__init__
()
super
(
Basics
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
,
0.7
],
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
,
0.7
],
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
b
(
'mean'
),
self
.
b
(
'sd'
),
self
.
b
(
'mode'
)
return
self
.
b
.
mean
(),
self
.
b
.
sd
(),
self
.
b
.
mode
(
)
def
test_basics
():
def
test_basics
():
"""
"""
...
@@ -134,9 +129,8 @@ class Sampling(nn.Cell):
...
@@ -134,9 +129,8 @@ class Sampling(nn.Cell):
self
.
b
=
msd
.
Bernoulli
([
0.7
,
0.5
],
seed
=
seed
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
([
0.7
,
0.5
],
seed
=
seed
,
dtype
=
dtype
.
int32
)
self
.
shape
=
shape
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
probs
=
None
):
def
construct
(
self
,
probs
=
None
):
return
self
.
b
(
'sample'
,
self
.
shape
,
probs
)
return
self
.
b
.
sample
(
self
.
shape
,
probs
)
def
test_sample
():
def
test_sample
():
"""
"""
...
@@ -155,9 +149,8 @@ class CDF(nn.Cell):
...
@@ -155,9 +149,8 @@ class CDF(nn.Cell):
super
(
CDF
,
self
).
__init__
()
super
(
CDF
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'cdf'
,
x_
)
return
self
.
b
.
cdf
(
x_
)
def
test_cdf
():
def
test_cdf
():
"""
"""
...
@@ -171,7 +164,6 @@ def test_cdf():
...
@@ -171,7 +164,6 @@ def test_cdf():
tol
=
1e-6
tol
=
1e-6
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_cdf
)
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_cdf
)
<
tol
).
all
()
class
LogCDF
(
nn
.
Cell
):
class
LogCDF
(
nn
.
Cell
):
"""
"""
Test class: log cdf of bernoulli distributions.
Test class: log cdf of bernoulli distributions.
...
@@ -180,9 +172,8 @@ class LogCDF(nn.Cell):
...
@@ -180,9 +172,8 @@ class LogCDF(nn.Cell):
super
(
LogCDF
,
self
).
__init__
()
super
(
LogCDF
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'log_cdf'
,
x_
)
return
self
.
b
.
log_cdf
(
x_
)
def
test_logcdf
():
def
test_logcdf
():
"""
"""
...
@@ -205,9 +196,8 @@ class SF(nn.Cell):
...
@@ -205,9 +196,8 @@ class SF(nn.Cell):
super
(
SF
,
self
).
__init__
()
super
(
SF
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'survival_function'
,
x_
)
return
self
.
b
.
survival_function
(
x_
)
def
test_survival
():
def
test_survival
():
"""
"""
...
@@ -230,9 +220,8 @@ class LogSF(nn.Cell):
...
@@ -230,9 +220,8 @@ class LogSF(nn.Cell):
super
(
LogSF
,
self
).
__init__
()
super
(
LogSF
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
b
(
'log_survival'
,
x_
)
return
self
.
b
.
log_survival
(
x_
)
def
test_log_survival
():
def
test_log_survival
():
"""
"""
...
@@ -254,9 +243,8 @@ class EntropyH(nn.Cell):
...
@@ -254,9 +243,8 @@ class EntropyH(nn.Cell):
super
(
EntropyH
,
self
).
__init__
()
super
(
EntropyH
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
b
(
'entropy'
)
return
self
.
b
.
entropy
(
)
def
test_entropy
():
def
test_entropy
():
"""
"""
...
@@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell):
...
@@ -277,12 +265,11 @@ class CrossEntropy(nn.Cell):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
entropy
=
self
.
b
(
'entropy'
)
entropy
=
self
.
b
.
entropy
(
)
kl_loss
=
self
.
b
(
'kl_loss'
,
'Bernoulli'
,
x_
)
kl_loss
=
self
.
b
.
kl_loss
(
'Bernoulli'
,
x_
)
h_sum_kl
=
entropy
+
kl_loss
h_sum_kl
=
entropy
+
kl_loss
cross_entropy
=
self
.
b
(
'cross_entropy'
,
'Bernoulli'
,
x_
)
cross_entropy
=
self
.
b
.
cross_entropy
(
'Bernoulli'
,
x_
)
return
h_sum_kl
-
cross_entropy
return
h_sum_kl
-
cross_entropy
def
test_cross_entropy
():
def
test_cross_entropy
():
...
...
tests/st/ops/ascend/test_distribution/test_exponential.py
浏览文件 @
e87e1fc6
...
@@ -19,7 +19,6 @@ import mindspore.context as context
...
@@ -19,7 +19,6 @@ import mindspore.context as context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.nn.probability.distribution
as
msd
import
mindspore.nn.probability.distribution
as
msd
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super
(
Prob
,
self
).
__init__
()
super
(
Prob
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'prob'
,
x_
)
return
self
.
e
.
prob
(
x_
)
def
test_pdf
():
def
test_pdf
():
"""
"""
...
@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
...
@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super
(
LogProb
,
self
).
__init__
()
super
(
LogProb
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'log_prob'
,
x_
)
return
self
.
e
.
log_prob
(
x_
)
def
test_log_likelihood
():
def
test_log_likelihood
():
"""
"""
...
@@ -80,9 +77,8 @@ class KL(nn.Cell):
...
@@ -80,9 +77,8 @@ class KL(nn.Cell):
super
(
KL
,
self
).
__init__
()
super
(
KL
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([
1.5
],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([
1.5
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'kl_loss'
,
'Exponential'
,
x_
)
return
self
.
e
.
kl_loss
(
'Exponential'
,
x_
)
def
test_kl_loss
():
def
test_kl_loss
():
"""
"""
...
@@ -104,9 +100,8 @@ class Basics(nn.Cell):
...
@@ -104,9 +100,8 @@ class Basics(nn.Cell):
super
(
Basics
,
self
).
__init__
()
super
(
Basics
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([
0.5
],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([
0.5
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
e
(
'mean'
),
self
.
e
(
'sd'
),
self
.
e
(
'mode'
)
return
self
.
e
.
mean
(),
self
.
e
.
sd
(),
self
.
e
.
mode
(
)
def
test_basics
():
def
test_basics
():
"""
"""
...
@@ -131,9 +126,8 @@ class Sampling(nn.Cell):
...
@@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
shape
=
shape
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
rate
=
None
):
def
construct
(
self
,
rate
=
None
):
return
self
.
e
(
'sample'
,
self
.
shape
,
rate
)
return
self
.
e
.
sample
(
self
.
shape
,
rate
)
def
test_sample
():
def
test_sample
():
"""
"""
...
@@ -154,9 +148,8 @@ class CDF(nn.Cell):
...
@@ -154,9 +148,8 @@ class CDF(nn.Cell):
super
(
CDF
,
self
).
__init__
()
super
(
CDF
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'cdf'
,
x_
)
return
self
.
e
.
cdf
(
x_
)
def
test_cdf
():
def
test_cdf
():
"""
"""
...
@@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
...
@@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super
(
LogCDF
,
self
).
__init__
()
super
(
LogCDF
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'log_cdf'
,
x_
)
return
self
.
e
.
log_cdf
(
x_
)
def
test_log_cdf
():
def
test_log_cdf
():
"""
"""
...
@@ -202,9 +194,8 @@ class SF(nn.Cell):
...
@@ -202,9 +194,8 @@ class SF(nn.Cell):
super
(
SF
,
self
).
__init__
()
super
(
SF
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'survival_function'
,
x_
)
return
self
.
e
.
survival_function
(
x_
)
def
test_survival
():
def
test_survival
():
"""
"""
...
@@ -226,9 +217,8 @@ class LogSF(nn.Cell):
...
@@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super
(
LogSF
,
self
).
__init__
()
super
(
LogSF
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
e
(
'log_survival'
,
x_
)
return
self
.
e
.
log_survival
(
x_
)
def
test_log_survival
():
def
test_log_survival
():
"""
"""
...
@@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
...
@@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super
(
EntropyH
,
self
).
__init__
()
super
(
EntropyH
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([[
1.0
],
[
0.5
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
e
(
'entropy'
)
return
self
.
e
.
entropy
(
)
def
test_entropy
():
def
test_entropy
():
"""
"""
...
@@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
...
@@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
([
1.0
],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([
1.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
entropy
=
self
.
e
(
'entropy'
)
entropy
=
self
.
e
.
entropy
(
)
kl_loss
=
self
.
e
(
'kl_loss'
,
'Exponential'
,
x_
)
kl_loss
=
self
.
e
.
kl_loss
(
'Exponential'
,
x_
)
h_sum_kl
=
entropy
+
kl_loss
h_sum_kl
=
entropy
+
kl_loss
cross_entropy
=
self
.
e
(
'cross_entropy'
,
'Exponential'
,
x_
)
cross_entropy
=
self
.
e
.
cross_entropy
(
'Exponential'
,
x_
)
return
h_sum_kl
-
cross_entropy
return
h_sum_kl
-
cross_entropy
def
test_cross_entropy
():
def
test_cross_entropy
():
...
...
tests/st/ops/ascend/test_distribution/test_geometric.py
浏览文件 @
e87e1fc6
...
@@ -19,7 +19,6 @@ import mindspore.context as context
...
@@ -19,7 +19,6 @@ import mindspore.context as context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.nn.probability.distribution
as
msd
import
mindspore.nn.probability.distribution
as
msd
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super
(
Prob
,
self
).
__init__
()
super
(
Prob
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'prob'
,
x_
)
return
self
.
g
.
prob
(
x_
)
def
test_pmf
():
def
test_pmf
():
"""
"""
...
@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
...
@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super
(
LogProb
,
self
).
__init__
()
super
(
LogProb
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'log_prob'
,
x_
)
return
self
.
g
.
log_prob
(
x_
)
def
test_log_likelihood
():
def
test_log_likelihood
():
"""
"""
...
@@ -80,9 +77,8 @@ class KL(nn.Cell):
...
@@ -80,9 +77,8 @@ class KL(nn.Cell):
super
(
KL
,
self
).
__init__
()
super
(
KL
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'kl_loss'
,
'Geometric'
,
x_
)
return
self
.
g
.
kl_loss
(
'Geometric'
,
x_
)
def
test_kl_loss
():
def
test_kl_loss
():
"""
"""
...
@@ -106,9 +102,8 @@ class Basics(nn.Cell):
...
@@ -106,9 +102,8 @@ class Basics(nn.Cell):
super
(
Basics
,
self
).
__init__
()
super
(
Basics
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
([
0.5
,
0.5
],
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
([
0.5
,
0.5
],
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
g
(
'mean'
),
self
.
g
(
'sd'
),
self
.
g
(
'mode'
)
return
self
.
g
.
mean
(),
self
.
g
.
sd
(),
self
.
g
.
mode
(
)
def
test_basics
():
def
test_basics
():
"""
"""
...
@@ -133,9 +128,8 @@ class Sampling(nn.Cell):
...
@@ -133,9 +128,8 @@ class Sampling(nn.Cell):
self
.
g
=
msd
.
Geometric
([
0.7
,
0.5
],
seed
=
seed
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
([
0.7
,
0.5
],
seed
=
seed
,
dtype
=
dtype
.
int32
)
self
.
shape
=
shape
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
probs
=
None
):
def
construct
(
self
,
probs
=
None
):
return
self
.
g
(
'sample'
,
self
.
shape
,
probs
)
return
self
.
g
.
sample
(
self
.
shape
,
probs
)
def
test_sample
():
def
test_sample
():
"""
"""
...
@@ -154,9 +148,8 @@ class CDF(nn.Cell):
...
@@ -154,9 +148,8 @@ class CDF(nn.Cell):
super
(
CDF
,
self
).
__init__
()
super
(
CDF
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'cdf'
,
x_
)
return
self
.
g
.
cdf
(
x_
)
def
test_cdf
():
def
test_cdf
():
"""
"""
...
@@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
...
@@ -178,9 +171,8 @@ class LogCDF(nn.Cell):
super
(
LogCDF
,
self
).
__init__
()
super
(
LogCDF
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'log_cdf'
,
x_
)
return
self
.
g
.
log_cdf
(
x_
)
def
test_logcdf
():
def
test_logcdf
():
"""
"""
...
@@ -202,9 +194,8 @@ class SF(nn.Cell):
...
@@ -202,9 +194,8 @@ class SF(nn.Cell):
super
(
SF
,
self
).
__init__
()
super
(
SF
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'survival_function'
,
x_
)
return
self
.
g
.
survival_function
(
x_
)
def
test_survival
():
def
test_survival
():
"""
"""
...
@@ -226,9 +217,8 @@ class LogSF(nn.Cell):
...
@@ -226,9 +217,8 @@ class LogSF(nn.Cell):
super
(
LogSF
,
self
).
__init__
()
super
(
LogSF
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
g
(
'log_survival'
,
x_
)
return
self
.
g
.
log_survival
(
x_
)
def
test_log_survival
():
def
test_log_survival
():
"""
"""
...
@@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
...
@@ -250,9 +240,8 @@ class EntropyH(nn.Cell):
super
(
EntropyH
,
self
).
__init__
()
super
(
EntropyH
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
g
(
'entropy'
)
return
self
.
g
.
entropy
(
)
def
test_entropy
():
def
test_entropy
():
"""
"""
...
@@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
...
@@ -273,12 +262,11 @@ class CrossEntropy(nn.Cell):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
entropy
=
self
.
g
(
'entropy'
)
entropy
=
self
.
g
.
entropy
(
)
kl_loss
=
self
.
g
(
'kl_loss'
,
'Geometric'
,
x_
)
kl_loss
=
self
.
g
.
kl_loss
(
'Geometric'
,
x_
)
h_sum_kl
=
entropy
+
kl_loss
h_sum_kl
=
entropy
+
kl_loss
ans
=
self
.
g
(
'cross_entropy'
,
'Geometric'
,
x_
)
ans
=
self
.
g
.
cross_entropy
(
'Geometric'
,
x_
)
return
h_sum_kl
-
ans
return
h_sum_kl
-
ans
def
test_cross_entropy
():
def
test_cross_entropy
():
...
...
tests/st/ops/ascend/test_distribution/test_normal.py
浏览文件 @
e87e1fc6
...
@@ -19,7 +19,6 @@ import mindspore.context as context
...
@@ -19,7 +19,6 @@ import mindspore.context as context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.nn.probability.distribution
as
msd
import
mindspore.nn.probability.distribution
as
msd
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super
(
Prob
,
self
).
__init__
()
super
(
Prob
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
n
(
'prob'
,
x_
)
return
self
.
n
.
prob
(
x_
)
def
test_pdf
():
def
test_pdf
():
"""
"""
...
@@ -55,9 +53,8 @@ class LogProb(nn.Cell):
...
@@ -55,9 +53,8 @@ class LogProb(nn.Cell):
super
(
LogProb
,
self
).
__init__
()
super
(
LogProb
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
n
(
'log_prob'
,
x_
)
return
self
.
n
.
log_prob
(
x_
)
def
test_log_likelihood
():
def
test_log_likelihood
():
"""
"""
...
@@ -79,9 +76,8 @@ class KL(nn.Cell):
...
@@ -79,9 +76,8 @@ class KL(nn.Cell):
super
(
KL
,
self
).
__init__
()
super
(
KL
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
4.0
]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
4.0
]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
,
y_
):
def
construct
(
self
,
x_
,
y_
):
return
self
.
n
(
'kl_loss'
,
'Normal'
,
x_
,
y_
)
return
self
.
n
.
kl_loss
(
'Normal'
,
x_
,
y_
)
def
test_kl_loss
():
def
test_kl_loss
():
...
@@ -113,9 +109,8 @@ class Basics(nn.Cell):
...
@@ -113,9 +109,8 @@ class Basics(nn.Cell):
super
(
Basics
,
self
).
__init__
()
super
(
Basics
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
2.0
,
4.0
]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
2.0
,
4.0
]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
n
(
'mean'
),
self
.
n
(
'sd'
),
self
.
n
(
'mode'
)
return
self
.
n
.
mean
(),
self
.
n
.
sd
(),
self
.
n
.
mode
(
)
def
test_basics
():
def
test_basics
():
"""
"""
...
@@ -139,9 +134,8 @@ class Sampling(nn.Cell):
...
@@ -139,9 +134,8 @@ class Sampling(nn.Cell):
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
shape
=
shape
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
mean
=
None
,
sd
=
None
):
def
construct
(
self
,
mean
=
None
,
sd
=
None
):
return
self
.
n
(
'sample'
,
self
.
shape
,
mean
,
sd
)
return
self
.
n
.
sample
(
self
.
shape
,
mean
,
sd
)
def
test_sample
():
def
test_sample
():
"""
"""
...
@@ -163,9 +157,8 @@ class CDF(nn.Cell):
...
@@ -163,9 +157,8 @@ class CDF(nn.Cell):
super
(
CDF
,
self
).
__init__
()
super
(
CDF
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
n
(
'cdf'
,
x_
)
return
self
.
n
.
cdf
(
x_
)
def
test_cdf
():
def
test_cdf
():
...
@@ -187,9 +180,8 @@ class LogCDF(nn.Cell):
...
@@ -187,9 +180,8 @@ class LogCDF(nn.Cell):
super
(
LogCDF
,
self
).
__init__
()
super
(
LogCDF
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
n
(
'log_cdf'
,
x_
)
return
self
.
n
.
log_cdf
(
x_
)
def
test_log_cdf
():
def
test_log_cdf
():
"""
"""
...
@@ -210,9 +202,8 @@ class SF(nn.Cell):
...
@@ -210,9 +202,8 @@ class SF(nn.Cell):
super
(
SF
,
self
).
__init__
()
super
(
SF
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
n
(
'survival_function'
,
x_
)
return
self
.
n
.
survival_function
(
x_
)
def
test_survival
():
def
test_survival
():
"""
"""
...
@@ -233,9 +224,8 @@ class LogSF(nn.Cell):
...
@@ -233,9 +224,8 @@ class LogSF(nn.Cell):
super
(
LogSF
,
self
).
__init__
()
super
(
LogSF
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
n
(
'log_survival'
,
x_
)
return
self
.
n
.
log_survival
(
x_
)
def
test_log_survival
():
def
test_log_survival
():
"""
"""
...
@@ -256,9 +246,8 @@ class EntropyH(nn.Cell):
...
@@ -256,9 +246,8 @@ class EntropyH(nn.Cell):
super
(
EntropyH
,
self
).
__init__
()
super
(
EntropyH
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
n
(
'entropy'
)
return
self
.
n
.
entropy
(
)
def
test_entropy
():
def
test_entropy
():
"""
"""
...
@@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell):
...
@@ -279,12 +268,11 @@ class CrossEntropy(nn.Cell):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
()
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
4.0
]),
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
4.0
]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
,
y_
):
def
construct
(
self
,
x_
,
y_
):
entropy
=
self
.
n
(
'entropy'
)
entropy
=
self
.
n
.
entropy
(
)
kl_loss
=
self
.
n
(
'kl_loss'
,
'Normal'
,
x_
,
y_
)
kl_loss
=
self
.
n
.
kl_loss
(
'Normal'
,
x_
,
y_
)
h_sum_kl
=
entropy
+
kl_loss
h_sum_kl
=
entropy
+
kl_loss
cross_entropy
=
self
.
n
(
'cross_entropy'
,
'Normal'
,
x_
,
y_
)
cross_entropy
=
self
.
n
.
cross_entropy
(
'Normal'
,
x_
,
y_
)
return
h_sum_kl
-
cross_entropy
return
h_sum_kl
-
cross_entropy
def
test_cross_entropy
():
def
test_cross_entropy
():
...
@@ -297,3 +285,40 @@ def test_cross_entropy():
...
@@ -297,3 +285,40 @@ def test_cross_entropy():
diff
=
cross_entropy
(
mean
,
sd
)
diff
=
cross_entropy
(
mean
,
sd
)
tol
=
1e-6
tol
=
1e-6
assert
(
np
.
abs
(
diff
.
asnumpy
()
-
np
.
zeros
(
diff
.
shape
))
<
tol
).
all
()
assert
(
np
.
abs
(
diff
.
asnumpy
()
-
np
.
zeros
(
diff
.
shape
))
<
tol
).
all
()
class
Net
(
nn
.
Cell
):
"""
Test class: expand single distribution instance to multiple graphs
by specifying the attributes.
"""
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
normal
=
msd
.
Normal
(
0.
,
1.
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
x_
,
y_
):
kl
=
self
.
normal
.
kl_loss
(
'Normal'
,
x_
,
y_
)
prob
=
self
.
normal
.
prob
(
kl
)
return
prob
def
test_multiple_graphs
():
"""
Test multiple graphs case.
"""
prob
=
Net
()
mean_a
=
np
.
array
([
0.0
]).
astype
(
np
.
float32
)
sd_a
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
mean_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
sd_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
ans
=
prob
(
Tensor
(
mean_b
),
Tensor
(
sd_b
))
diff_log_scale
=
np
.
log
(
sd_a
)
-
np
.
log
(
sd_b
)
squared_diff
=
np
.
square
(
mean_a
/
sd_b
-
mean_b
/
sd_b
)
expect_kl_loss
=
0.5
*
squared_diff
+
0.5
*
\
np
.
expm1
(
2
*
diff_log_scale
)
-
diff_log_scale
norm_benchmark
=
stats
.
norm
(
np
.
array
([
0.0
]),
np
.
array
([
1.0
]))
expect_prob
=
norm_benchmark
.
pdf
(
expect_kl_loss
).
astype
(
np
.
float32
)
tol
=
1e-6
assert
(
np
.
abs
(
ans
.
asnumpy
()
-
expect_prob
)
<
tol
).
all
()
tests/st/ops/ascend/test_distribution/test_normal_new_api.py
已删除
100644 → 0
浏览文件 @
6945eb28
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for new api of normal distribution"""
import
numpy
as
np
from
scipy
import
stats
import
mindspore.nn
as
nn
import
mindspore.nn.probability.distribution
as
msd
from
mindspore
import
dtype
from
mindspore
import
Tensor
import
mindspore.context
as
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Net
(
nn
.
Cell
):
"""
Test class: new api of normal distribution.
"""
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
normal
=
msd
.
Normal
(
0.
,
1.
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
x_
,
y_
):
kl
=
self
.
normal
.
kl_loss
(
'kl_loss'
,
'Normal'
,
x_
,
y_
)
prob
=
self
.
normal
.
prob
(
'prob'
,
kl
)
return
prob
def
test_new_api
():
"""
Test new api of normal distribution.
"""
prob
=
Net
()
mean_a
=
np
.
array
([
0.0
]).
astype
(
np
.
float32
)
sd_a
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
mean_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
sd_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
ans
=
prob
(
Tensor
(
mean_b
),
Tensor
(
sd_b
))
diff_log_scale
=
np
.
log
(
sd_a
)
-
np
.
log
(
sd_b
)
squared_diff
=
np
.
square
(
mean_a
/
sd_b
-
mean_b
/
sd_b
)
expect_kl_loss
=
0.5
*
squared_diff
+
0.5
*
\
np
.
expm1
(
2
*
diff_log_scale
)
-
diff_log_scale
norm_benchmark
=
stats
.
norm
(
np
.
array
([
0.0
]),
np
.
array
([
1.0
]))
expect_prob
=
norm_benchmark
.
pdf
(
expect_kl_loss
).
astype
(
np
.
float32
)
tol
=
1e-6
assert
(
np
.
abs
(
ans
.
asnumpy
()
-
expect_prob
)
<
tol
).
all
()
tests/st/ops/ascend/test_distribution/test_uniform.py
浏览文件 @
e87e1fc6
...
@@ -19,7 +19,6 @@ import mindspore.context as context
...
@@ -19,7 +19,6 @@ import mindspore.context as context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.nn.probability.distribution
as
msd
import
mindspore.nn.probability.distribution
as
msd
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
...
@@ -32,9 +31,8 @@ class Prob(nn.Cell):
super
(
Prob
,
self
).
__init__
()
super
(
Prob
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[[
1.0
],
[
2.0
]],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[[
1.0
],
[
2.0
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
u
(
'prob'
,
x_
)
return
self
.
u
.
prob
(
x_
)
def
test_pdf
():
def
test_pdf
():
"""
"""
...
@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
...
@@ -56,9 +54,8 @@ class LogProb(nn.Cell):
super
(
LogProb
,
self
).
__init__
()
super
(
LogProb
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[[
1.0
],
[
2.0
]],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[[
1.0
],
[
2.0
]],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
u
(
'log_prob'
,
x_
)
return
self
.
u
.
log_prob
(
x_
)
def
test_log_likelihood
():
def
test_log_likelihood
():
"""
"""
...
@@ -80,9 +77,8 @@ class KL(nn.Cell):
...
@@ -80,9 +77,8 @@ class KL(nn.Cell):
super
(
KL
,
self
).
__init__
()
super
(
KL
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.5
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.5
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
,
y_
):
def
construct
(
self
,
x_
,
y_
):
return
self
.
u
(
'kl_loss'
,
'Uniform'
,
x_
,
y_
)
return
self
.
u
.
kl_loss
(
'Uniform'
,
x_
,
y_
)
def
test_kl_loss
():
def
test_kl_loss
():
"""
"""
...
@@ -106,9 +102,8 @@ class Basics(nn.Cell):
...
@@ -106,9 +102,8 @@ class Basics(nn.Cell):
super
(
Basics
,
self
).
__init__
()
super
(
Basics
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
3.0
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
3.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
u
(
'mean'
),
self
.
u
(
'sd'
)
return
self
.
u
.
mean
(),
self
.
u
.
sd
(
)
def
test_basics
():
def
test_basics
():
"""
"""
...
@@ -131,9 +126,8 @@ class Sampling(nn.Cell):
...
@@ -131,9 +126,8 @@ class Sampling(nn.Cell):
self
.
u
=
msd
.
Uniform
([
0.0
],
[[
1.0
],
[
2.0
]],
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[[
1.0
],
[
2.0
]],
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
shape
=
shape
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
low
=
None
,
high
=
None
):
def
construct
(
self
,
low
=
None
,
high
=
None
):
return
self
.
u
(
'sample'
,
self
.
shape
,
low
,
high
)
return
self
.
u
.
sample
(
self
.
shape
,
low
,
high
)
def
test_sample
():
def
test_sample
():
"""
"""
...
@@ -155,9 +149,8 @@ class CDF(nn.Cell):
...
@@ -155,9 +149,8 @@ class CDF(nn.Cell):
super
(
CDF
,
self
).
__init__
()
super
(
CDF
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
u
(
'cdf'
,
x_
)
return
self
.
u
.
cdf
(
x_
)
def
test_cdf
():
def
test_cdf
():
"""
"""
...
@@ -179,9 +172,8 @@ class LogCDF(nn.Cell):
...
@@ -179,9 +172,8 @@ class LogCDF(nn.Cell):
super
(
LogCDF
,
self
).
__init__
()
super
(
LogCDF
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
u
(
'log_cdf'
,
x_
)
return
self
.
u
.
log_cdf
(
x_
)
class
SF
(
nn
.
Cell
):
class
SF
(
nn
.
Cell
):
"""
"""
...
@@ -191,9 +183,8 @@ class SF(nn.Cell):
...
@@ -191,9 +183,8 @@ class SF(nn.Cell):
super
(
SF
,
self
).
__init__
()
super
(
SF
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
u
(
'survival_function'
,
x_
)
return
self
.
u
.
survival_function
(
x_
)
class
LogSF
(
nn
.
Cell
):
class
LogSF
(
nn
.
Cell
):
"""
"""
...
@@ -203,9 +194,8 @@ class LogSF(nn.Cell):
...
@@ -203,9 +194,8 @@ class LogSF(nn.Cell):
super
(
LogSF
,
self
).
__init__
()
super
(
LogSF
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
def
construct
(
self
,
x_
):
return
self
.
u
(
'log_survival'
,
x_
)
return
self
.
u
.
log_survival
(
x_
)
class
EntropyH
(
nn
.
Cell
):
class
EntropyH
(
nn
.
Cell
):
"""
"""
...
@@ -215,9 +205,8 @@ class EntropyH(nn.Cell):
...
@@ -215,9 +205,8 @@ class EntropyH(nn.Cell):
super
(
EntropyH
,
self
).
__init__
()
super
(
EntropyH
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
,
2.0
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.0
,
2.0
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
self
.
u
(
'entropy'
)
return
self
.
u
.
entropy
(
)
def
test_entropy
():
def
test_entropy
():
"""
"""
...
@@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell):
...
@@ -238,12 +227,11 @@ class CrossEntropy(nn.Cell):
super
(
CrossEntropy
,
self
).
__init__
()
super
(
CrossEntropy
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.5
],
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
([
0.0
],
[
1.5
],
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
,
y_
):
def
construct
(
self
,
x_
,
y_
):
entropy
=
self
.
u
(
'entropy'
)
entropy
=
self
.
u
.
entropy
(
)
kl_loss
=
self
.
u
(
'kl_loss'
,
'Uniform'
,
x_
,
y_
)
kl_loss
=
self
.
u
.
kl_loss
(
'Uniform'
,
x_
,
y_
)
h_sum_kl
=
entropy
+
kl_loss
h_sum_kl
=
entropy
+
kl_loss
cross_entropy
=
self
.
u
(
'cross_entropy'
,
'Uniform'
,
x_
,
y_
)
cross_entropy
=
self
.
u
.
cross_entropy
(
'Uniform'
,
x_
,
y_
)
return
h_sum_kl
-
cross_entropy
return
h_sum_kl
-
cross_entropy
def
test_log_cdf
():
def
test_log_cdf
():
...
...
tests/ut/python/nn/distribution/test_bernoulli.py
浏览文件 @
e87e1fc6
...
@@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell):
...
@@ -49,12 +49,12 @@ class BernoulliProb(nn.Cell):
self
.
b
=
msd
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
):
def
construct
(
self
,
value
):
prob
=
self
.
b
(
'prob'
,
value
)
prob
=
self
.
b
.
prob
(
value
)
log_prob
=
self
.
b
(
'log_prob'
,
value
)
log_prob
=
self
.
b
.
log_prob
(
value
)
cdf
=
self
.
b
(
'cdf'
,
value
)
cdf
=
self
.
b
.
cdf
(
value
)
log_cdf
=
self
.
b
(
'log_cdf'
,
value
)
log_cdf
=
self
.
b
.
log_cdf
(
value
)
sf
=
self
.
b
(
'survival_function'
,
value
)
sf
=
self
.
b
.
survival_function
(
value
)
log_sf
=
self
.
b
(
'log_survival'
,
value
)
log_sf
=
self
.
b
.
log_survival
(
value
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_bernoulli_prob
():
def
test_bernoulli_prob
():
...
@@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell):
...
@@ -75,12 +75,12 @@ class BernoulliProb1(nn.Cell):
self
.
b
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
,
probs
):
def
construct
(
self
,
value
,
probs
):
prob
=
self
.
b
(
'prob'
,
value
,
probs
)
prob
=
self
.
b
.
prob
(
value
,
probs
)
log_prob
=
self
.
b
(
'log_prob'
,
value
,
probs
)
log_prob
=
self
.
b
.
log_prob
(
value
,
probs
)
cdf
=
self
.
b
(
'cdf'
,
value
,
probs
)
cdf
=
self
.
b
.
cdf
(
value
,
probs
)
log_cdf
=
self
.
b
(
'log_cdf'
,
value
,
probs
)
log_cdf
=
self
.
b
.
log_cdf
(
value
,
probs
)
sf
=
self
.
b
(
'survival_function'
,
value
,
probs
)
sf
=
self
.
b
.
survival_function
(
value
,
probs
)
log_sf
=
self
.
b
(
'log_survival'
,
value
,
probs
)
log_sf
=
self
.
b
.
log_survival
(
value
,
probs
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_bernoulli_prob1
():
def
test_bernoulli_prob1
():
...
@@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell):
...
@@ -103,8 +103,8 @@ class BernoulliKl(nn.Cell):
self
.
b2
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
self
.
b2
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
probs_b
,
probs_a
):
def
construct
(
self
,
probs_b
,
probs_a
):
kl1
=
self
.
b1
(
'kl_loss'
,
'Bernoulli'
,
probs_b
)
kl1
=
self
.
b1
.
kl_loss
(
'Bernoulli'
,
probs_b
)
kl2
=
self
.
b2
(
'kl_loss'
,
'Bernoulli'
,
probs_b
,
probs_a
)
kl2
=
self
.
b2
.
kl_loss
(
'Bernoulli'
,
probs_b
,
probs_a
)
return
kl1
+
kl2
return
kl1
+
kl2
def
test_kl
():
def
test_kl
():
...
@@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell):
...
@@ -127,8 +127,8 @@ class BernoulliCrossEntropy(nn.Cell):
self
.
b2
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
self
.
b2
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
probs_b
,
probs_a
):
def
construct
(
self
,
probs_b
,
probs_a
):
h1
=
self
.
b1
(
'cross_entropy'
,
'Bernoulli'
,
probs_b
)
h1
=
self
.
b1
.
cross_entropy
(
'Bernoulli'
,
probs_b
)
h2
=
self
.
b2
(
'cross_entropy'
,
'Bernoulli'
,
probs_b
,
probs_a
)
h2
=
self
.
b2
.
cross_entropy
(
'Bernoulli'
,
probs_b
,
probs_a
)
return
h1
+
h2
return
h1
+
h2
def
test_cross_entropy
():
def
test_cross_entropy
():
...
@@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell):
...
@@ -150,11 +150,11 @@ class BernoulliBasics(nn.Cell):
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
self
.
b
=
msd
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
):
def
construct
(
self
):
mean
=
self
.
b
(
'mean'
)
mean
=
self
.
b
.
mean
(
)
sd
=
self
.
b
(
'sd'
)
sd
=
self
.
b
.
sd
(
)
var
=
self
.
b
(
'var'
)
var
=
self
.
b
.
var
(
)
mode
=
self
.
b
(
'mode'
)
mode
=
self
.
b
.
mode
(
)
entropy
=
self
.
b
(
'entropy'
)
entropy
=
self
.
b
.
entropy
(
)
return
mean
+
sd
+
var
+
mode
+
entropy
return
mean
+
sd
+
var
+
mode
+
entropy
def
test_bascis
():
def
test_bascis
():
...
@@ -164,3 +164,28 @@ def test_bascis():
...
@@ -164,3 +164,28 @@ def test_bascis():
net
=
BernoulliBasics
()
net
=
BernoulliBasics
()
ans
=
net
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
assert
isinstance
(
ans
,
Tensor
)
class
BernoulliConstruct
(
nn
.
Cell
):
"""
Bernoulli distribution: going through construct.
"""
def
__init__
(
self
):
super
(
BernoulliConstruct
,
self
).
__init__
()
self
.
b
=
msd
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
self
.
b1
=
msd
.
Bernoulli
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
,
probs
):
prob
=
self
.
b
(
'prob'
,
value
)
prob1
=
self
.
b
(
'prob'
,
value
,
probs
)
prob2
=
self
.
b1
(
'prob'
,
value
,
probs
)
return
prob
+
prob1
+
prob2
def
test_bernoulli_construct
():
"""
Test probability function going through construct.
"""
net
=
BernoulliConstruct
()
value
=
Tensor
([
0
,
0
,
0
,
0
,
0
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.5
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
assert
isinstance
(
ans
,
Tensor
)
tests/ut/python/nn/distribution/test_exponential.py
浏览文件 @
e87e1fc6
...
@@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell):
...
@@ -50,12 +50,12 @@ class ExponentialProb(nn.Cell):
self
.
e
=
msd
.
Exponential
(
0.5
,
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
(
0.5
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
):
def
construct
(
self
,
value
):
prob
=
self
.
e
(
'prob'
,
value
)
prob
=
self
.
e
.
prob
(
value
)
log_prob
=
self
.
e
(
'log_prob'
,
value
)
log_prob
=
self
.
e
.
log_prob
(
value
)
cdf
=
self
.
e
(
'cdf'
,
value
)
cdf
=
self
.
e
.
cdf
(
value
)
log_cdf
=
self
.
e
(
'log_cdf'
,
value
)
log_cdf
=
self
.
e
.
log_cdf
(
value
)
sf
=
self
.
e
(
'survival_function'
,
value
)
sf
=
self
.
e
.
survival_function
(
value
)
log_sf
=
self
.
e
(
'log_survival'
,
value
)
log_sf
=
self
.
e
.
log_survival
(
value
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_exponential_prob
():
def
test_exponential_prob
():
...
@@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell):
...
@@ -76,12 +76,12 @@ class ExponentialProb1(nn.Cell):
self
.
e
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
,
rate
):
def
construct
(
self
,
value
,
rate
):
prob
=
self
.
e
(
'prob'
,
value
,
rate
)
prob
=
self
.
e
.
prob
(
value
,
rate
)
log_prob
=
self
.
e
(
'log_prob'
,
value
,
rate
)
log_prob
=
self
.
e
.
log_prob
(
value
,
rate
)
cdf
=
self
.
e
(
'cdf'
,
value
,
rate
)
cdf
=
self
.
e
.
cdf
(
value
,
rate
)
log_cdf
=
self
.
e
(
'log_cdf'
,
value
,
rate
)
log_cdf
=
self
.
e
.
log_cdf
(
value
,
rate
)
sf
=
self
.
e
(
'survival_function'
,
value
,
rate
)
sf
=
self
.
e
.
survival_function
(
value
,
rate
)
log_sf
=
self
.
e
(
'log_survival'
,
value
,
rate
)
log_sf
=
self
.
e
.
log_survival
(
value
,
rate
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_exponential_prob1
():
def
test_exponential_prob1
():
...
@@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell):
...
@@ -104,8 +104,8 @@ class ExponentialKl(nn.Cell):
self
.
e2
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
self
.
e2
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
rate_b
,
rate_a
):
def
construct
(
self
,
rate_b
,
rate_a
):
kl1
=
self
.
e1
(
'kl_loss'
,
'Exponential'
,
rate_b
)
kl1
=
self
.
e1
.
kl_loss
(
'Exponential'
,
rate_b
)
kl2
=
self
.
e2
(
'kl_loss'
,
'Exponential'
,
rate_b
,
rate_a
)
kl2
=
self
.
e2
.
kl_loss
(
'Exponential'
,
rate_b
,
rate_a
)
return
kl1
+
kl2
return
kl1
+
kl2
def
test_kl
():
def
test_kl
():
...
@@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell):
...
@@ -128,8 +128,8 @@ class ExponentialCrossEntropy(nn.Cell):
self
.
e2
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
self
.
e2
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
rate_b
,
rate_a
):
def
construct
(
self
,
rate_b
,
rate_a
):
h1
=
self
.
e1
(
'cross_entropy'
,
'Exponential'
,
rate_b
)
h1
=
self
.
e1
.
cross_entropy
(
'Exponential'
,
rate_b
)
h2
=
self
.
e2
(
'cross_entropy'
,
'Exponential'
,
rate_b
,
rate_a
)
h2
=
self
.
e2
.
cross_entropy
(
'Exponential'
,
rate_b
,
rate_a
)
return
h1
+
h2
return
h1
+
h2
def
test_cross_entropy
():
def
test_cross_entropy
():
...
@@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell):
...
@@ -151,11 +151,11 @@ class ExponentialBasics(nn.Cell):
self
.
e
=
msd
.
Exponential
([
0.3
,
0.5
],
dtype
=
dtype
.
float32
)
self
.
e
=
msd
.
Exponential
([
0.3
,
0.5
],
dtype
=
dtype
.
float32
)
def
construct
(
self
):
def
construct
(
self
):
mean
=
self
.
e
(
'mean'
)
mean
=
self
.
e
.
mean
(
)
sd
=
self
.
e
(
'sd'
)
sd
=
self
.
e
.
sd
(
)
var
=
self
.
e
(
'var'
)
var
=
self
.
e
.
var
(
)
mode
=
self
.
e
(
'mode'
)
mode
=
self
.
e
.
mode
(
)
entropy
=
self
.
e
(
'entropy'
)
entropy
=
self
.
e
.
entropy
(
)
return
mean
+
sd
+
var
+
mode
+
entropy
return
mean
+
sd
+
var
+
mode
+
entropy
def
test_bascis
():
def
test_bascis
():
...
@@ -165,3 +165,29 @@ def test_bascis():
...
@@ -165,3 +165,29 @@ def test_bascis():
net
=
ExponentialBasics
()
net
=
ExponentialBasics
()
ans
=
net
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
assert
isinstance
(
ans
,
Tensor
)
class
ExpConstruct
(
nn
.
Cell
):
"""
Exponential distribution: going through construct.
"""
def
__init__
(
self
):
super
(
ExpConstruct
,
self
).
__init__
()
self
.
e
=
msd
.
Exponential
(
0.5
,
dtype
=
dtype
.
float32
)
self
.
e1
=
msd
.
Exponential
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
,
rate
):
prob
=
self
.
e
(
'prob'
,
value
)
prob1
=
self
.
e
(
'prob'
,
value
,
rate
)
prob2
=
self
.
e1
(
'prob'
,
value
,
rate
)
return
prob
+
prob1
+
prob2
def
test_exp_construct
():
"""
Test probability function going through construct.
"""
net
=
ExpConstruct
()
value
=
Tensor
([
0
,
0
,
0
,
0
,
0
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.5
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
assert
isinstance
(
ans
,
Tensor
)
tests/ut/python/nn/distribution/test_geometric.py
浏览文件 @
e87e1fc6
...
@@ -50,12 +50,12 @@ class GeometricProb(nn.Cell):
...
@@ -50,12 +50,12 @@ class GeometricProb(nn.Cell):
self
.
g
=
msd
.
Geometric
(
0.5
,
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
):
def
construct
(
self
,
value
):
prob
=
self
.
g
(
'prob'
,
value
)
prob
=
self
.
g
.
prob
(
value
)
log_prob
=
self
.
g
(
'log_prob'
,
value
)
log_prob
=
self
.
g
.
log_prob
(
value
)
cdf
=
self
.
g
(
'cdf'
,
value
)
cdf
=
self
.
g
.
cdf
(
value
)
log_cdf
=
self
.
g
(
'log_cdf'
,
value
)
log_cdf
=
self
.
g
.
log_cdf
(
value
)
sf
=
self
.
g
(
'survival_function'
,
value
)
sf
=
self
.
g
.
survival_function
(
value
)
log_sf
=
self
.
g
(
'log_survival'
,
value
)
log_sf
=
self
.
g
.
log_survival
(
value
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_geometric_prob
():
def
test_geometric_prob
():
...
@@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell):
...
@@ -76,12 +76,12 @@ class GeometricProb1(nn.Cell):
self
.
g
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
,
probs
):
def
construct
(
self
,
value
,
probs
):
prob
=
self
.
g
(
'prob'
,
value
,
probs
)
prob
=
self
.
g
.
prob
(
value
,
probs
)
log_prob
=
self
.
g
(
'log_prob'
,
value
,
probs
)
log_prob
=
self
.
g
.
log_prob
(
value
,
probs
)
cdf
=
self
.
g
(
'cdf'
,
value
,
probs
)
cdf
=
self
.
g
.
cdf
(
value
,
probs
)
log_cdf
=
self
.
g
(
'log_cdf'
,
value
,
probs
)
log_cdf
=
self
.
g
.
log_cdf
(
value
,
probs
)
sf
=
self
.
g
(
'survival_function'
,
value
,
probs
)
sf
=
self
.
g
.
survival_function
(
value
,
probs
)
log_sf
=
self
.
g
(
'log_survival'
,
value
,
probs
)
log_sf
=
self
.
g
.
log_survival
(
value
,
probs
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_geometric_prob1
():
def
test_geometric_prob1
():
...
@@ -105,8 +105,8 @@ class GeometricKl(nn.Cell):
...
@@ -105,8 +105,8 @@ class GeometricKl(nn.Cell):
self
.
g2
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
self
.
g2
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
probs_b
,
probs_a
):
def
construct
(
self
,
probs_b
,
probs_a
):
kl1
=
self
.
g1
(
'kl_loss'
,
'Geometric'
,
probs_b
)
kl1
=
self
.
g1
.
kl_loss
(
'Geometric'
,
probs_b
)
kl2
=
self
.
g2
(
'kl_loss'
,
'Geometric'
,
probs_b
,
probs_a
)
kl2
=
self
.
g2
.
kl_loss
(
'Geometric'
,
probs_b
,
probs_a
)
return
kl1
+
kl2
return
kl1
+
kl2
def
test_kl
():
def
test_kl
():
...
@@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell):
...
@@ -129,8 +129,8 @@ class GeometricCrossEntropy(nn.Cell):
self
.
g2
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
self
.
g2
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
probs_b
,
probs_a
):
def
construct
(
self
,
probs_b
,
probs_a
):
h1
=
self
.
g1
(
'cross_entropy'
,
'Geometric'
,
probs_b
)
h1
=
self
.
g1
.
cross_entropy
(
'Geometric'
,
probs_b
)
h2
=
self
.
g2
(
'cross_entropy'
,
'Geometric'
,
probs_b
,
probs_a
)
h2
=
self
.
g2
.
cross_entropy
(
'Geometric'
,
probs_b
,
probs_a
)
return
h1
+
h2
return
h1
+
h2
def
test_cross_entropy
():
def
test_cross_entropy
():
...
@@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell):
...
@@ -152,11 +152,11 @@ class GeometricBasics(nn.Cell):
self
.
g
=
msd
.
Geometric
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
self
.
g
=
msd
.
Geometric
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
def
construct
(
self
):
def
construct
(
self
):
mean
=
self
.
g
(
'mean'
)
mean
=
self
.
g
.
mean
(
)
sd
=
self
.
g
(
'sd'
)
sd
=
self
.
g
.
sd
(
)
var
=
self
.
g
(
'var'
)
var
=
self
.
g
.
var
(
)
mode
=
self
.
g
(
'mode'
)
mode
=
self
.
g
.
mode
(
)
entropy
=
self
.
g
(
'entropy'
)
entropy
=
self
.
g
.
entropy
(
)
return
mean
+
sd
+
var
+
mode
+
entropy
return
mean
+
sd
+
var
+
mode
+
entropy
def
test_bascis
():
def
test_bascis
():
...
@@ -166,3 +166,29 @@ def test_bascis():
...
@@ -166,3 +166,29 @@ def test_bascis():
net
=
GeometricBasics
()
net
=
GeometricBasics
()
ans
=
net
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
assert
isinstance
(
ans
,
Tensor
)
class
GeoConstruct
(
nn
.
Cell
):
"""
Bernoulli distribution: going through construct.
"""
def
__init__
(
self
):
super
(
GeoConstruct
,
self
).
__init__
()
self
.
g
=
msd
.
Geometric
(
0.5
,
dtype
=
dtype
.
int32
)
self
.
g1
=
msd
.
Geometric
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
,
probs
):
prob
=
self
.
g
(
'prob'
,
value
)
prob1
=
self
.
g
(
'prob'
,
value
,
probs
)
prob2
=
self
.
g1
(
'prob'
,
value
,
probs
)
return
prob
+
prob1
+
prob2
def
test_geo_construct
():
"""
Test probability function going through construct.
"""
net
=
GeoConstruct
()
value
=
Tensor
([
0
,
0
,
0
,
0
,
0
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.5
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
assert
isinstance
(
ans
,
Tensor
)
tests/ut/python/nn/distribution/test_normal.py
浏览文件 @
e87e1fc6
...
@@ -50,12 +50,12 @@ class NormalProb(nn.Cell):
...
@@ -50,12 +50,12 @@ class NormalProb(nn.Cell):
self
.
normal
=
msd
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
self
.
normal
=
msd
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
):
def
construct
(
self
,
value
):
prob
=
self
.
normal
(
'prob'
,
value
)
prob
=
self
.
normal
.
prob
(
value
)
log_prob
=
self
.
normal
(
'log_prob'
,
value
)
log_prob
=
self
.
normal
.
log_prob
(
value
)
cdf
=
self
.
normal
(
'cdf'
,
value
)
cdf
=
self
.
normal
.
cdf
(
value
)
log_cdf
=
self
.
normal
(
'log_cdf'
,
value
)
log_cdf
=
self
.
normal
.
log_cdf
(
value
)
sf
=
self
.
normal
(
'survival_function'
,
value
)
sf
=
self
.
normal
.
survival_function
(
value
)
log_sf
=
self
.
normal
(
'log_survival'
,
value
)
log_sf
=
self
.
normal
.
log_survival
(
value
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_normal_prob
():
def
test_normal_prob
():
...
@@ -77,12 +77,12 @@ class NormalProb1(nn.Cell):
...
@@ -77,12 +77,12 @@ class NormalProb1(nn.Cell):
self
.
normal
=
msd
.
Normal
()
self
.
normal
=
msd
.
Normal
()
def
construct
(
self
,
value
,
mean
,
sd
):
def
construct
(
self
,
value
,
mean
,
sd
):
prob
=
self
.
normal
(
'prob'
,
value
,
mean
,
sd
)
prob
=
self
.
normal
.
prob
(
value
,
mean
,
sd
)
log_prob
=
self
.
normal
(
'log_prob'
,
value
,
mean
,
sd
)
log_prob
=
self
.
normal
.
log_prob
(
value
,
mean
,
sd
)
cdf
=
self
.
normal
(
'cdf'
,
value
,
mean
,
sd
)
cdf
=
self
.
normal
.
cdf
(
value
,
mean
,
sd
)
log_cdf
=
self
.
normal
(
'log_cdf'
,
value
,
mean
,
sd
)
log_cdf
=
self
.
normal
.
log_cdf
(
value
,
mean
,
sd
)
sf
=
self
.
normal
(
'survival_function'
,
value
,
mean
,
sd
)
sf
=
self
.
normal
.
survival_function
(
value
,
mean
,
sd
)
log_sf
=
self
.
normal
(
'log_survival'
,
value
,
mean
,
sd
)
log_sf
=
self
.
normal
.
log_survival
(
value
,
mean
,
sd
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_normal_prob1
():
def
test_normal_prob1
():
...
@@ -106,8 +106,8 @@ class NormalKl(nn.Cell):
...
@@ -106,8 +106,8 @@ class NormalKl(nn.Cell):
self
.
n2
=
msd
.
Normal
(
dtype
=
dtype
.
float32
)
self
.
n2
=
msd
.
Normal
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
mean_b
,
sd_b
,
mean_a
,
sd_a
):
def
construct
(
self
,
mean_b
,
sd_b
,
mean_a
,
sd_a
):
kl1
=
self
.
n1
(
'kl_loss'
,
'Normal'
,
mean_b
,
sd_b
)
kl1
=
self
.
n1
.
kl_loss
(
'Normal'
,
mean_b
,
sd_b
)
kl2
=
self
.
n2
(
'kl_loss'
,
'Normal'
,
mean_b
,
sd_b
,
mean_a
,
sd_a
)
kl2
=
self
.
n2
.
kl_loss
(
'Normal'
,
mean_b
,
sd_b
,
mean_a
,
sd_a
)
return
kl1
+
kl2
return
kl1
+
kl2
def
test_kl
():
def
test_kl
():
...
@@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell):
...
@@ -132,8 +132,8 @@ class NormalCrossEntropy(nn.Cell):
self
.
n2
=
msd
.
Normal
(
dtype
=
dtype
.
float32
)
self
.
n2
=
msd
.
Normal
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
mean_b
,
sd_b
,
mean_a
,
sd_a
):
def
construct
(
self
,
mean_b
,
sd_b
,
mean_a
,
sd_a
):
h1
=
self
.
n1
(
'cross_entropy'
,
'Normal'
,
mean_b
,
sd_b
)
h1
=
self
.
n1
.
cross_entropy
(
'Normal'
,
mean_b
,
sd_b
)
h2
=
self
.
n2
(
'cross_entropy'
,
'Normal'
,
mean_b
,
sd_b
,
mean_a
,
sd_a
)
h2
=
self
.
n2
.
cross_entropy
(
'Normal'
,
mean_b
,
sd_b
,
mean_a
,
sd_a
)
return
h1
+
h2
return
h1
+
h2
def
test_cross_entropy
():
def
test_cross_entropy
():
...
@@ -157,10 +157,10 @@ class NormalBasics(nn.Cell):
...
@@ -157,10 +157,10 @@ class NormalBasics(nn.Cell):
self
.
n
=
msd
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
self
.
n
=
msd
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
def
construct
(
self
):
def
construct
(
self
):
mean
=
self
.
n
(
'mean'
)
mean
=
self
.
n
.
mean
(
)
sd
=
self
.
n
(
'sd'
)
sd
=
self
.
n
.
sd
(
)
mode
=
self
.
n
(
'mode'
)
mode
=
self
.
n
.
mode
(
)
entropy
=
self
.
n
(
'entropy'
)
entropy
=
self
.
n
.
entropy
(
)
return
mean
+
sd
+
mode
+
entropy
return
mean
+
sd
+
mode
+
entropy
def
test_bascis
():
def
test_bascis
():
...
@@ -170,3 +170,30 @@ def test_bascis():
...
@@ -170,3 +170,30 @@ def test_bascis():
net
=
NormalBasics
()
net
=
NormalBasics
()
ans
=
net
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
assert
isinstance
(
ans
,
Tensor
)
class
NormalConstruct
(
nn
.
Cell
):
"""
Normal distribution: going through construct.
"""
def
__init__
(
self
):
super
(
NormalConstruct
,
self
).
__init__
()
self
.
normal
=
msd
.
Normal
(
3.0
,
4.0
)
self
.
normal1
=
msd
.
Normal
()
def
construct
(
self
,
value
,
mean
,
sd
):
prob
=
self
.
normal
(
'prob'
,
value
)
prob1
=
self
.
normal
(
'prob'
,
value
,
mean
,
sd
)
prob2
=
self
.
normal1
(
'prob'
,
value
,
mean
,
sd
)
return
prob
+
prob1
+
prob2
def
test_normal_construct
():
"""
Test probability function going through construct.
"""
net
=
NormalConstruct
()
value
=
Tensor
([
0.5
,
1.0
],
dtype
=
dtype
.
float32
)
mean
=
Tensor
([
0.0
],
dtype
=
dtype
.
float32
)
sd
=
Tensor
([
1.0
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
mean
,
sd
)
assert
isinstance
(
ans
,
Tensor
)
tests/ut/python/nn/distribution/test_uniform.py
浏览文件 @
e87e1fc6
...
@@ -60,12 +60,12 @@ class UniformProb(nn.Cell):
...
@@ -60,12 +60,12 @@ class UniformProb(nn.Cell):
self
.
u
=
msd
.
Uniform
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
):
def
construct
(
self
,
value
):
prob
=
self
.
u
(
'prob'
,
value
)
prob
=
self
.
u
.
prob
(
value
)
log_prob
=
self
.
u
(
'log_prob'
,
value
)
log_prob
=
self
.
u
.
log_prob
(
value
)
cdf
=
self
.
u
(
'cdf'
,
value
)
cdf
=
self
.
u
.
cdf
(
value
)
log_cdf
=
self
.
u
(
'log_cdf'
,
value
)
log_cdf
=
self
.
u
.
log_cdf
(
value
)
sf
=
self
.
u
(
'survival_function'
,
value
)
sf
=
self
.
u
.
survival_function
(
value
)
log_sf
=
self
.
u
(
'log_survival'
,
value
)
log_sf
=
self
.
u
.
log_survival
(
value
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_uniform_prob
():
def
test_uniform_prob
():
...
@@ -86,12 +86,12 @@ class UniformProb1(nn.Cell):
...
@@ -86,12 +86,12 @@ class UniformProb1(nn.Cell):
self
.
u
=
msd
.
Uniform
(
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
,
low
,
high
):
def
construct
(
self
,
value
,
low
,
high
):
prob
=
self
.
u
(
'prob'
,
value
,
low
,
high
)
prob
=
self
.
u
.
prob
(
value
,
low
,
high
)
log_prob
=
self
.
u
(
'log_prob'
,
value
,
low
,
high
)
log_prob
=
self
.
u
.
log_prob
(
value
,
low
,
high
)
cdf
=
self
.
u
(
'cdf'
,
value
,
low
,
high
)
cdf
=
self
.
u
.
cdf
(
value
,
low
,
high
)
log_cdf
=
self
.
u
(
'log_cdf'
,
value
,
low
,
high
)
log_cdf
=
self
.
u
.
log_cdf
(
value
,
low
,
high
)
sf
=
self
.
u
(
'survival_function'
,
value
,
low
,
high
)
sf
=
self
.
u
.
survival_function
(
value
,
low
,
high
)
log_sf
=
self
.
u
(
'log_survival'
,
value
,
low
,
high
)
log_sf
=
self
.
u
.
log_survival
(
value
,
low
,
high
)
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
return
prob
+
log_prob
+
cdf
+
log_cdf
+
sf
+
log_sf
def
test_uniform_prob1
():
def
test_uniform_prob1
():
...
@@ -115,8 +115,8 @@ class UniformKl(nn.Cell):
...
@@ -115,8 +115,8 @@ class UniformKl(nn.Cell):
self
.
u2
=
msd
.
Uniform
(
dtype
=
dtype
.
float32
)
self
.
u2
=
msd
.
Uniform
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
low_b
,
high_b
,
low_a
,
high_a
):
def
construct
(
self
,
low_b
,
high_b
,
low_a
,
high_a
):
kl1
=
self
.
u1
(
'kl_loss'
,
'Uniform'
,
low_b
,
high_b
)
kl1
=
self
.
u1
.
kl_loss
(
'Uniform'
,
low_b
,
high_b
)
kl2
=
self
.
u2
(
'kl_loss'
,
'Uniform'
,
low_b
,
high_b
,
low_a
,
high_a
)
kl2
=
self
.
u2
.
kl_loss
(
'Uniform'
,
low_b
,
high_b
,
low_a
,
high_a
)
return
kl1
+
kl2
return
kl1
+
kl2
def
test_kl
():
def
test_kl
():
...
@@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell):
...
@@ -141,8 +141,8 @@ class UniformCrossEntropy(nn.Cell):
self
.
u2
=
msd
.
Uniform
(
dtype
=
dtype
.
float32
)
self
.
u2
=
msd
.
Uniform
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
low_b
,
high_b
,
low_a
,
high_a
):
def
construct
(
self
,
low_b
,
high_b
,
low_a
,
high_a
):
h1
=
self
.
u1
(
'cross_entropy'
,
'Uniform'
,
low_b
,
high_b
)
h1
=
self
.
u1
.
cross_entropy
(
'Uniform'
,
low_b
,
high_b
)
h2
=
self
.
u2
(
'cross_entropy'
,
'Uniform'
,
low_b
,
high_b
,
low_a
,
high_a
)
h2
=
self
.
u2
.
cross_entropy
(
'Uniform'
,
low_b
,
high_b
,
low_a
,
high_a
)
return
h1
+
h2
return
h1
+
h2
def
test_cross_entropy
():
def
test_cross_entropy
():
...
@@ -166,10 +166,10 @@ class UniformBasics(nn.Cell):
...
@@ -166,10 +166,10 @@ class UniformBasics(nn.Cell):
self
.
u
=
msd
.
Uniform
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
self
.
u
=
msd
.
Uniform
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
def
construct
(
self
):
def
construct
(
self
):
mean
=
self
.
u
(
'mean'
)
mean
=
self
.
u
.
mean
(
)
sd
=
self
.
u
(
'sd'
)
sd
=
self
.
u
.
sd
(
)
var
=
self
.
u
(
'var'
)
var
=
self
.
u
.
var
(
)
entropy
=
self
.
u
(
'entropy'
)
entropy
=
self
.
u
.
entropy
(
)
return
mean
+
sd
+
var
+
entropy
return
mean
+
sd
+
var
+
entropy
def
test_bascis
():
def
test_bascis
():
...
@@ -179,3 +179,30 @@ def test_bascis():
...
@@ -179,3 +179,30 @@ def test_bascis():
net
=
UniformBasics
()
net
=
UniformBasics
()
ans
=
net
()
ans
=
net
()
assert
isinstance
(
ans
,
Tensor
)
assert
isinstance
(
ans
,
Tensor
)
class
UniConstruct
(
nn
.
Cell
):
"""
Unifrom distribution: going through construct.
"""
def
__init__
(
self
):
super
(
UniConstruct
,
self
).
__init__
()
self
.
u
=
msd
.
Uniform
(
-
4.0
,
4.0
)
self
.
u1
=
msd
.
Uniform
()
def
construct
(
self
,
value
,
low
,
high
):
prob
=
self
.
u
(
'prob'
,
value
)
prob1
=
self
.
u
(
'prob'
,
value
,
low
,
high
)
prob2
=
self
.
u1
(
'prob'
,
value
,
low
,
high
)
return
prob
+
prob1
+
prob2
def
test_uniform_construct
():
"""
Test probability function going through construct.
"""
net
=
UniConstruct
()
value
=
Tensor
([
-
5.0
,
0.0
,
1.0
,
5.0
],
dtype
=
dtype
.
float32
)
low
=
Tensor
([
-
1.0
],
dtype
=
dtype
.
float32
)
high
=
Tensor
([
1.0
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
low
,
high
)
assert
isinstance
(
ans
,
Tensor
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录