Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
bef1fc7f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bef1fc7f
编写于
6月 26, 2020
作者:
P
peixu_ren
提交者:
Xun Deng
7月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sample functions in normal and bermoulli distributions
上级
0aa26c18
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
409 addition
and
202 deletion
+409
-202
mindspore/nn/distribution/_utils/utils.py
mindspore/nn/distribution/_utils/utils.py
+47
-38
mindspore/nn/distribution/bernoulli.py
mindspore/nn/distribution/bernoulli.py
+63
-22
mindspore/nn/distribution/distribution.py
mindspore/nn/distribution/distribution.py
+21
-53
mindspore/nn/distribution/normal.py
mindspore/nn/distribution/normal.py
+71
-26
tests/st/ops/ascend/test_distribution/test_bernoulli.py
tests/st/ops/ascend/test_distribution/test_bernoulli.py
+32
-13
tests/st/ops/ascend/test_distribution/test_normal.py
tests/st/ops/ascend/test_distribution/test_normal.py
+34
-12
tests/ut/python/nn/test_distribution.py
tests/ut/python/nn/test_distribution.py
+141
-38
未找到文件。
mindspore/nn/distribution/_utils/utils.py
浏览文件 @
bef1fc7f
...
...
@@ -15,9 +15,9 @@
# ============================================================================
"""Utitly functions to help distribution class."""
import
numpy
as
np
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
_utils
as
utils
from
....common.tensor
import
Tensor
from
....common.tensor
import
Tensor
,
MetaTensor
from
....common.parameter
import
Parameter
from
....common
import
dtype
as
mstype
...
...
@@ -33,15 +33,17 @@ def cast_to_tensor(t, dtype=mstype.float32):
Cast an user input value into a Tensor of dtype.
Args:
t (int
/float/list/numpy.ndarray/Tensor)
.
dtype (mindspore.dtype).
t (int
, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor
.
dtype (mindspore.dtype)
: dtype of the Tensor. Default: mstype.float32
.
Raises:
RuntimeError: if t cannot be cast to Tensor.
Output
s:
Return
s:
Tensor.
"""
if
isinstance
(
t
,
Parameter
):
return
t
if
isinstance
(
t
,
Tensor
):
#check if the Tensor in shape of Tensor(4)
if
t
.
dim
()
==
0
:
...
...
@@ -61,9 +63,9 @@ def calc_batch_size(batch_shape):
Calculate the size of a given batch_shape.
Args:
batch_shape (tuple)
batch_shape (tuple)
: batch shape to be calculated.
Output
s:
Return
s:
int.
"""
return
int
(
np
.
prod
(
batch_shape
))
...
...
@@ -73,23 +75,26 @@ def convert_to_batch(t, batch_shape, dtype):
Convert a Tensor to a given batch shape.
Args:
t (Tensor)
batch_shape (tuple)
dtype (mindspore.dtype)
t (Tensor, Parameter): Tensor to be converted.
batch_shape (tuple): desired batch shape.
dtype (mindspore.dtype): desired dtype.
Raises:
RuntimeError: if the converison cannot be done.
Output
s:
Return
s:
Tensor, with shape of batch_shape.
"""
if
isinstance
(
t
,
Parameter
):
return
t
t
=
cast_to_tensor
(
t
,
dtype
)
reshape
=
P
.
Reshape
()
if
t
.
shape
!=
batch_shape
:
mul
=
calc_batch_size
(
batch_shape
)
//
t
.
size
()
if
(
calc_batch_size
(
batch_shape
)
%
t
.
size
())
!=
0
:
raise
RuntimeError
(
"Cannot cast the tensor to the given batch shape."
)
temp
=
list
(
t
.
asnumpy
())
*
mul
return
reshape
(
Tensor
(
temp
),
batch_shape
)
temp
=
np
.
reshape
(
temp
,
batch_shape
)
return
Tensor
(
temp
,
dtype
)
return
t
def
check_scalar_from_param
(
params
):
...
...
@@ -97,7 +102,7 @@ def check_scalar_from_param(params):
Check if params are all scalars.
Args:
params (dict): parameters used to initialize
d
distribution.
params (dict): parameters used to initialize distribution.
Notes: String parameters are excluded.
"""
...
...
@@ -116,9 +121,9 @@ def calc_broadcast_shape_from_param(params):
Calculate the broadcast shape from params.
Args:
params (dict): parameters used to initialize
d
distribution.
params (dict): parameters used to initialize distribution.
Output
s:
Return
s:
tuple.
"""
broadcast_shape
=
[]
...
...
@@ -127,7 +132,10 @@ def calc_broadcast_shape_from_param(params):
continue
if
value
is
None
:
return
None
value_t
=
cast_to_tensor
(
value
,
params
[
'dtype'
])
if
isinstance
(
value
,
Parameter
):
value_t
=
value
.
default_input
else
:
value_t
=
cast_to_tensor
(
value
,
params
[
'dtype'
])
broadcast_shape
=
utils
.
get_broadcast_shape
(
broadcast_shape
,
list
(
value_t
.
shape
),
params
[
'name'
])
return
tuple
(
broadcast_shape
)
...
...
@@ -136,36 +144,37 @@ def check_greater_equal_zero(value, name):
Check if the given Tensor is greater zero.
Args:
value (Tensor
)
value (Tensor
, Parameter): value to be checked.
name (str) : name of the value.
Raises:
ValueError: if the input value is less than zero.
"""
less
=
P
.
Less
()
zeros
=
Tensor
([
0.0
],
dtype
=
value
.
dtype
)
value
=
less
(
value
,
zeros
)
if
value
.
asnumpy
().
any
():
raise
ValueError
(
'{} should be greater than zero.'
.
format
(
name
))
if
isinstance
(
value
,
Parameter
):
if
isinstance
(
value
.
default_input
,
MetaTensor
):
return
value
=
value
.
default_input
comp
=
np
.
less
(
value
.
asnumpy
(),
np
.
zeros
(
value
.
shape
))
if
comp
.
any
():
raise
ValueError
(
f
'
{
name
}
should be greater than zero.'
)
def
check_greater
(
a
,
b
,
name_a
,
name_b
):
"""
Check if Tensor b is strictly greater than Tensor a.
Args:
a (Tensor)
b (Tensor)
a (Tensor)
: input tensor a.
b (Tensor)
: input tensor b.
name_a (str): name of Tensor_a.
name_b (str): name of Tensor_b.
Raises:
ValueError: if b is less than or equal to a
"""
less
=
P
.
Less
()
value
=
less
(
a
,
b
)
if
not
value
.
asnumpy
().
all
():
raise
ValueError
(
'{} should be less than {}'
.
format
(
name_a
,
name_b
))
comp
=
np
.
less
(
a
.
asnumpy
(),
b
.
asnumpy
())
if
not
comp
.
all
():
raise
ValueError
(
f
'
{
name_a
}
should be less than
{
name_b
}
'
)
def
check_prob
(
p
):
...
...
@@ -173,18 +182,18 @@ def check_prob(p):
Check if p is a proper probability, i.e. 0 <= p <=1.
Args:
p (Tensor
): value to check
.
p (Tensor
, Parameter): value to be checked
.
Raises:
ValueError: if p is not a proper probability.
"""
less
=
P
.
Less
()
greater
=
P
.
Greater
()
zeros
=
Tensor
([
0.0
],
dtype
=
p
.
dtype
)
ones
=
Tensor
([
1.0
],
dtype
=
p
.
dtype
)
comp
=
less
(
p
,
zeros
)
if
comp
.
a
snumpy
().
a
ny
():
if
isinstance
(
p
,
Parameter
):
if
isinstance
(
p
.
default_input
,
MetaTensor
):
return
p
=
p
.
default_input
comp
=
np
.
less
(
p
.
asnumpy
(),
np
.
zeros
(
p
.
shape
)
)
if
comp
.
any
():
raise
ValueError
(
'Probabilities should be greater than or equal to zero'
)
comp
=
greater
(
p
,
ones
)
if
comp
.
a
snumpy
().
a
ny
():
comp
=
np
.
greater
(
p
.
asnumpy
(),
np
.
ones
(
p
.
shape
)
)
if
comp
.
any
():
raise
ValueError
(
'Probabilities should be less than or equal to one'
)
mindspore/nn/distribution/bernoulli.py
浏览文件 @
bef1fc7f
...
...
@@ -23,21 +23,24 @@ class Bernoulli(Distribution):
Example class: Bernoulli Distribution.
Args:
probs (int/float/list/numpy.ndarray/Tensor): probability of 1 as outcome.
dtype (mindspore.dtype): type of the distribution, default to int32.
probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.int32.
name (str): name of the distribution. Default: Bernoulli.
Note:
probs should be proper probabilities (0 <= p <= 1).
Examples:
>>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0
>>> b = nn.Bernoulli(0.5, dtype =
d
type.int32)
>>> b = nn.Bernoulli(0.5, dtype =
ms
type.int32)
>>> # The following create two independent Bernoulli distributions
>>> b = nn.Bernoulli([0.7, 0.2], dtype =
d
type.int32)
>>> b = nn.Bernoulli([0.7, 0.2], dtype =
ms
type.int32)
"""
def
__init__
(
self
,
probs
=
None
,
seed
=
0
,
dtype
=
mstype
.
int32
,
name
=
"Bernoulli"
):
"""
...
...
@@ -47,7 +50,6 @@ class Bernoulli(Distribution):
super
(
Bernoulli
,
self
).
__init__
(
dtype
,
name
,
param
)
if
probs
is
not
None
:
self
.
_probs
=
cast_to_tensor
(
probs
)
# check if the input probability is valid
check_prob
(
self
.
_probs
)
else
:
self
.
_probs
=
probs
...
...
@@ -58,7 +60,17 @@ class Bernoulli(Distribution):
self
.
mul
=
P
.
Mul
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
realdiv
=
P
.
RealDiv
()
self
.
shape
=
P
.
Shape
()
self
.
const
=
P
.
ScalarToArray
()
self
.
less
=
P
.
Less
()
self
.
cast
=
P
.
Cast
()
self
.
normal
=
P
.
Normal
(
seed
=
seed
)
self
.
erf
=
P
.
Erf
()
self
.
sqrt
=
P
.
Sqrt
()
def
extend_repr
(
self
):
str_info
=
f
'probs =
{
self
.
_probs
}
'
return
str_info
def
probs
(
self
):
"""
...
...
@@ -66,21 +78,25 @@ class Bernoulli(Distribution):
"""
return
self
.
_probs
def
_mean
(
self
):
def
_mean
(
self
,
name
=
'mean'
,
probs1
=
None
):
r
"""
.. math::
MEAN(B) = probs1
"""
if
name
==
'mean'
:
return
self
.
_probs
if
probs1
is
None
else
probs1
return
None
return
self
.
_probs
def
_var
(
self
):
def
_var
(
self
,
name
=
'var'
,
probs1
=
None
):
r
"""
.. math::
VAR(B) = probs1 * probs0
"""
probs0
=
self
.
add
(
1
,
-
1
*
self
.
_probs
)
return
self
.
mul
(
probs0
,
self
.
_probs
)
if
name
in
(
'sd'
,
'var'
):
probs1
=
self
.
_probs
if
probs1
is
None
else
probs1
probs0
=
self
.
add
(
1
,
-
1
*
probs1
)
return
self
.
mul
(
probs0
,
probs1
)
return
None
def
_prob
(
self
,
name
,
value
,
probs
=
None
):
r
"""
...
...
@@ -89,18 +105,20 @@ class Bernoulli(Distribution):
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default
to
self._probs.
probs (Tensor): probability of outcome is 1. Default
:
self._probs.
.. math::
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
probs1
=
self
.
_probs
if
probs
is
None
else
probs
probs0
=
self
.
add
(
1
,
-
1
*
probs1
)
return
self
.
add
(
self
.
mul
(
probs1
,
value
),
self
.
mul
(
probs0
,
self
.
add
(
1
,
-
1
*
value
)))
if
name
in
(
'prob'
,
'log_prob'
):
probs1
=
self
.
_probs
if
probs
is
None
else
probs
probs0
=
self
.
add
(
1
,
-
1
*
probs1
)
return
self
.
add
(
self
.
mul
(
probs1
,
value
),
self
.
mul
(
probs0
,
self
.
add
(
1
,
-
1
*
value
)))
return
None
def
_kl_loss
(
self
,
name
,
dist
,
probs1_b
):
def
_kl_loss
(
self
,
name
,
dist
,
probs1_b
,
probs1_a
=
None
):
r
"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
...
...
@@ -108,19 +126,42 @@ class Bernoulli(Distribution):
name (str): name of the funtion. Should always be "kl_loss" when passed in from construct.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
probs1_a (Tensor): probs1 of distribution a. Default: self._probs.
.. math::
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
"""
if
dist
==
'Bernoulli'
:
probs1_a
=
self
.
_probs
if
name
==
'kl_loss'
and
dist
==
'Bernoulli'
:
probs1_a
=
self
.
_probs
if
probs1_a
is
None
else
probs1_a
probs0_a
=
self
.
add
(
1
,
-
1
*
probs1_a
)
probs0_b
=
self
.
add
(
1
,
-
1
*
probs1_b
)
return
self
.
add
(
probs1_a
*
self
.
log
(
self
.
realdiv
(
probs1_a
,
probs1_b
)),
probs0_a
*
self
.
log
(
self
.
realdiv
(
probs0_a
,
probs0_b
)))
return
None
def
extend_repr
(
self
):
str_info
=
'probs={}'
.
format
(
self
.
_probs
)
return
str_info
def
_sample
(
self
,
name
,
shape
=
(),
probs
=
None
):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
probs (Tensor): probs1 of the samples. Default: self._probs.
Returns:
Tensor, shape is shape + batch_shape.
"""
if
name
==
'sample'
:
probs1
=
self
.
_probs
if
probs
is
None
else
probs
batch_shape
=
self
.
shape
(
probs1
)
sample_shape
=
shape
+
batch_shape
mean_zero
=
self
.
const
(
0.0
)
sd_one
=
self
.
const
(
1.0
)
sqrt_two
=
self
.
sqrt
(
self
.
const
(
2.0
))
sample_norm
=
self
.
normal
(
sample_shape
,
mean_zero
,
sd_one
)
sample_uniform
=
0.5
*
(
1
+
self
.
erf
(
self
.
realdiv
(
sample_norm
,
sqrt_two
)))
sample
=
self
.
less
(
sample_uniform
,
probs1
)
sample
=
self
.
cast
(
sample
,
self
.
_dtype
)
return
sample
return
None
mindspore/nn/distribution/distribution.py
浏览文件 @
bef1fc7f
...
...
@@ -21,6 +21,11 @@ class Distribution(Cell):
"""
Base class for all mathematical distributions.
Args:
dtype (mindspore.dtype): type of the distribution.
name (str): name of the distribution.
param (dict): parameters used to initialize the distribution.
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Functions should be called through construct when
...
...
@@ -97,14 +102,8 @@ class Distribution(Cell):
Note:
value is casted to Tensor for further calculation.
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distirbution. Default: self.mean.
sd (Tensor): standard deviation of the distribution. Default: self.sd.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return
self
.
_call_log_prob
(
*
args
)
...
...
@@ -114,36 +113,9 @@ class Distribution(Cell):
.. math::
probability(x) = \exp(log_likehood(x))
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distribution. Default: self.mean.
sd (Tensor): standard deviation of the distritbuion. Default: self.sd.
"""
return
self
.
exp
(
self
.
_log_likelihood
(
*
args
))
def
_call_prob
(
self
,
*
args
):
"""
Raises:
NotImplementedError when derived class didn't override _prob or _log_likelihood.
"""
raise
NotImplementedError
(
'pdf/pmf is not implemented: {}'
.
format
(
type
(
self
).
__name__
))
def
_call_log_prob
(
self
,
*
args
):
"""
Raises:
NotImplementedError when derived class didn't override _prob or _log_likelihood.
"""
raise
NotImplementedError
(
'log_probability is not implemented: {}'
.
format
(
type
(
self
).
__name__
))
def
_call_sd
(
self
):
"""
Raises:
NotImplementedError when derived class didn't override _sd or _var.
"""
raise
NotImplementedError
(
'standard deviation is not implemented: {}'
.
format
(
type
(
self
).
__name__
))
def
prob
(
self
,
*
args
):
"""
Evaluate the prob (pdf or pmf) at given value.
...
...
@@ -151,14 +123,8 @@ class Distribution(Cell):
Note:
value is casted to Tensor for further calculation.
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distribution.
sd (Tensor): standard deviation of the distritbuion.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
Returns:
Tensor, shape is the broadcast_shape of the distribution.
"""
return
self
.
_call_prob
(
*
args
)
...
...
@@ -176,8 +142,8 @@ class Distribution(Cell):
Evaluate the KL divergence. Parameters of the second distribution should be
passed in through **kwargs.
Output
s:
Tensor, shape
:
broadcast_shape of the distribution and input distribution.
Return
s:
Tensor, shape
is the
broadcast_shape of the distribution and input distribution.
"""
return
self
.
_kl_loss
(
**
kwargs
)
...
...
@@ -185,8 +151,8 @@ class Distribution(Cell):
"""
Evaluate the mean.
Output
s:
Tensor, shape
:
broadcast_shape of the distribution.
Return
s:
Tensor, shape
is the
broadcast_shape of the distribution.
"""
return
self
.
_mean
(
**
kwargs
)
...
...
@@ -194,19 +160,19 @@ class Distribution(Cell):
"""
Evaluate the standard deviation.
Output
s:
Tensor,
with shape of
broadcast_shape of the distribution.
Return
s:
Tensor,
shape is the
broadcast_shape of the distribution.
"""
return
self
.
_call_sd
(
**
kwargs
)
def
_calc_sd_from_var
(
self
,
*
*
kw
args
):
def
_calc_sd_from_var
(
self
,
*
args
):
r
"""
Evaluate log probability from probability.
.. math::
STD(x) = \sqrt(VAR(x))
"""
return
self
.
sqrt
(
self
.
_var
(
*
*
kw
args
))
return
self
.
sqrt
(
self
.
_var
(
*
args
))
def
construct
(
self
,
*
inputs
):
"""
...
...
@@ -226,7 +192,9 @@ class Distribution(Cell):
if
inputs
[
0
]
==
'kl_loss'
:
return
self
.
_kl_loss
(
*
inputs
)
if
inputs
[
0
]
==
'mean'
:
return
self
.
_mean
()
return
self
.
_mean
(
*
inputs
)
if
inputs
[
0
]
==
'sd'
:
return
self
.
_call_sd
()
return
self
.
_call_sd
(
*
inputs
)
if
inputs
[
0
]
==
'sample'
:
return
self
.
_sample
(
*
inputs
)
return
None
mindspore/nn/distribution/normal.py
浏览文件 @
bef1fc7f
...
...
@@ -25,23 +25,27 @@ class Normal(Distribution):
Example class: Normal distribution.
Args:
mean (int/float/list/numpy.ndarray/Tensor): mean of the Gaussian distribution
standard deviation (int/float/list/numpy.ndarray/Tensor): vairance of the Gaussian distribution
dtype (mindspore.dtype): type of the distribution
mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution.
sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution.
seed (int): seed to use in sampling. Default: 0.
dtype (mindspore.dtype): type of the distribution. Default: mstype.float32.
name (str): name of the distribution. Default: Normal.
Note:
Standard deviation should be greater than zero.
Examples:
>>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=
d
type.float32)
>>> n = nn.Normal(3.0, 4.0, dtype=
ms
type.float32)
>>> # The following create two independent normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=
d
type.float32)
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=
ms
type.float32)
"""
def
__init__
(
self
,
mean
=
None
,
sd
=
None
,
seed
=
0
,
dtype
=
mstype
.
float32
,
name
=
"Normal"
):
"""
...
...
@@ -52,7 +56,6 @@ class Normal(Distribution):
if
mean
is
not
None
and
sd
is
not
None
:
self
.
_mean_value
=
convert_to_batch
(
mean
,
self
.
_broadcast_shape
,
dtype
)
self
.
_sd_value
=
convert_to_batch
(
sd
,
self
.
_broadcast_shape
,
dtype
)
#check validity of standard deviation
check_greater_equal_zero
(
self
.
_sd_value
,
"Standard deviation"
)
else
:
self
.
_mean_value
=
mean
...
...
@@ -61,11 +64,20 @@ class Normal(Distribution):
#ops needed for the class
self
.
exp
=
P
.
Exp
()
self
.
add
=
P
.
TensorAdd
()
self
.
mul
=
P
.
Mul
()
self
.
sq
=
P
.
Square
()
self
.
log
=
P
.
Log
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
realdiv
=
P
.
RealDiv
()
self
.
expm1
=
P
.
Expm1
()
if
get_context
(
'device_target'
)
==
'Ascend'
else
self
.
_expm1_by_step
self
.
normal
=
P
.
Normal
(
seed
=
seed
)
self
.
shape
=
P
.
Shape
()
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
const
=
P
.
ScalarToArray
()
def
extend_repr
(
self
):
str_info
=
f
'mean =
{
self
.
_mean_value
}
, standard deviation =
{
self
.
_sd_value
}
'
return
str_info
def
_expm1_by_step
(
self
,
x
):
"""
...
...
@@ -73,17 +85,23 @@ class Normal(Distribution):
"""
return
self
.
add
(
self
.
exp
(
x
),
-
1
)
def
_mean
(
self
):
def
_mean
(
self
,
name
=
'mean'
,
mean
=
None
,
sd
=
None
):
"""
Mean of the distribution.
"""
return
self
.
_mean_value
if
name
==
'mean'
:
mean
=
self
.
_mean_value
if
mean
is
None
or
sd
is
None
else
mean
return
mean
return
None
def
_sd
(
self
):
def
_sd
(
self
,
name
=
'sd'
,
mean
=
None
,
sd
=
None
):
"""
Standard deviation of the distribution.
"""
return
self
.
_sd_value
if
name
in
(
'sd'
,
'var'
):
sd
=
self
.
_sd_value
if
mean
is
None
or
sd
is
None
else
sd
return
sd
return
None
def
_log_likelihood
(
self
,
name
,
value
,
mean
=
None
,
sd
=
None
):
r
"""
...
...
@@ -92,33 +110,60 @@ class Normal(Distribution):
.. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
unnormalized_log_prob
=
-
1.
*
self
.
realdiv
(
self
.
sq
(
self
.
add
(
value
,
-
1.
*
mean
)),
2.
*
self
.
sq
(
sd
))
neg_normalization
=
-
1.
*
self
.
log
(
self
.
sqrt
(
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
return
self
.
add
(
unnormalized_log_prob
,
neg_normalization
)
def
_kl_loss
(
self
,
name
,
dist
,
mean
,
sd
):
if
name
in
(
'prob'
,
'log_prob'
):
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
unnormalized_log_prob
=
-
1.
*
self
.
realdiv
(
self
.
sq
(
self
.
add
(
value
,
-
1.
*
mean
)),
2.
*
self
.
sq
(
sd
))
neg_normalization
=
-
1.
*
self
.
log
(
self
.
sqrt
(
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
return
self
.
add
(
unnormalized_log_prob
,
neg_normalization
)
return
None
def
_kl_loss
(
self
,
name
,
dist
,
mean_b
,
sd_b
,
mean_a
=
None
,
sd_a
=
None
):
r
"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion passed in from construct. Should always be "kl_loss".
dist (str): type of the distributions. Should be "Normal" in this case.
mean (Tensor): mean of distribution b.
sd (Tensor): standard deviation distribution b.
mean_b (Tensor): mean of distribution b.
sd_b (Tensor): standard deviation distribution b.
mean_a (Tensor): mean of distribution a. Default: self._mean_value.
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
.. math::
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
if
dist
==
'Normal'
:
diff_log_scale
=
self
.
add
(
self
.
log
(
self
.
_sd_value
),
-
self
.
log
(
sd
))
squared_diff
=
self
.
sq
(
self
.
add
(
self
.
realdiv
(
self
.
_mean_value
,
sd
),
-
self
.
realdiv
(
mean
,
sd
)))
if
name
==
'kl_loss'
and
dist
==
'Normal'
:
mean_a
=
self
.
_mean_value
if
mean_a
is
None
else
mean_a
sd_a
=
self
.
_sd_value
if
sd_a
is
None
else
sd_a
diff_log_scale
=
self
.
add
(
self
.
log
(
sd_a
),
-
self
.
log
(
sd_b
))
squared_diff
=
self
.
sq
(
self
.
add
(
self
.
realdiv
(
mean_a
,
sd_b
),
-
self
.
realdiv
(
mean_b
,
sd_b
)))
return
self
.
add
(
self
.
add
(
0.5
*
squared_diff
,
0.5
*
self
.
expm1
(
2
*
diff_log_scale
)),
-
diff_log_scale
)
return
None
def
extend_repr
(
self
):
str_info
=
'mean={}, standard deviation={}'
.
format
(
self
.
_mean_value
,
self
.
_sd_value
)
return
str_info
def
_sample
(
self
,
name
,
shape
=
(),
mean
=
None
,
sd
=
None
):
"""
Sampling.
Args:
name (str): name of the function. Should always be 'sample' when passed in from construct.
shape (tuple): shape of the sample. Default: ().
mean (Tensor): mean of the samples. Default: self._mean_value.
sd (Tensor): standard deviation of the samples. Default: self._sd_value.
Returns:
Tensor, shape is shape + batch_shape.
"""
if
name
==
'sample'
:
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
batch_shape
=
self
.
shape
(
self
.
add
(
self
.
zeroslike
(
mean
),
self
.
zeroslike
(
sd
)))
sample_shape
=
shape
+
batch_shape
mean_zero
=
self
.
const
(
0.0
)
sd_one
=
self
.
const
(
1.0
)
sample_norm
=
self
.
normal
(
sample_shape
,
mean_zero
,
sd_one
)
sample
=
self
.
add
(
mean
,
self
.
mul
(
sample_norm
,
sd
))
return
sample
return
None
tests/st/ops/ascend/test_distribution/test_bernoulli.py
浏览文件 @
bef1fc7f
...
...
@@ -65,12 +65,25 @@ class Net3(nn.Cell):
"""
def
__init__
(
self
):
super
(
Net3
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
([
0.
7
,
0.5
],
dtype
=
dtype
.
int32
)
self
.
b
=
nn
.
Bernoulli
([
0.
5
,
0.5
],
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
):
return
self
.
b
(
'mean'
),
self
.
b
(
'sd'
)
class
Net4
(
nn
.
Cell
):
"""
Test class: log probability of bernoulli distribution.
"""
def
__init__
(
self
,
shape
,
seed
=
0
):
super
(
Net4
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
([
0.7
,
0.5
],
seed
=
seed
,
dtype
=
dtype
.
int32
)
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
probs
=
None
):
return
self
.
b
(
'sample'
,
self
.
shape
,
probs
)
def
test_pmf
():
"""
Test pmf.
...
...
@@ -80,10 +93,8 @@ def test_pmf():
pdf
=
Net
()
x_
=
Tensor
(
np
.
array
([
0
,
1
,
0
,
1
,
1
]).
astype
(
np
.
int32
),
dtype
=
dtype
.
float32
)
output
=
pdf
(
x_
)
print
(
"expected_pmf: "
,
expect_pmf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_pmf
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_pmf
)
<
tol
).
all
()
def
test_log_likelihood
():
"""
...
...
@@ -94,10 +105,8 @@ def test_log_likelihood():
logprob
=
Net1
()
x_
=
Tensor
(
np
.
array
([
0
,
1
,
0
,
1
,
1
]).
astype
(
np
.
int32
),
dtype
=
dtype
.
float32
)
output
=
logprob
(
x_
)
print
(
"expected_log_probability: "
,
expect_logpmf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_logpmf
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_logpmf
)
<
tol
).
all
()
def
test_kl_loss
():
"""
...
...
@@ -110,10 +119,8 @@ def test_kl_loss():
expect_kl_loss
=
probs1_a
*
np
.
log
(
probs1_a
/
probs1_b
)
+
probs0_a
*
np
.
log
(
probs0_a
/
probs0_b
)
kl_loss
=
Net2
()
output
=
kl_loss
(
Tensor
([
probs1_b
],
dtype
=
dtype
.
float32
))
print
(
"expected_kl_loss: "
,
expect_kl_loss
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_kl_loss
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_kl_loss
)
<
tol
).
all
()
def
test_basics
():
"""
...
...
@@ -121,8 +128,20 @@ def test_basics():
"""
basics
=
Net3
()
mean
,
sd
=
basics
()
print
(
"mean : "
,
mean
)
print
(
"sd : "
,
sd
)
expect_mean
=
[
0.5
,
0.5
]
assert
(
mean
.
asnumpy
()
==
expect_mean
).
all
()
assert
(
sd
.
asnumpy
()
==
expect_mean
).
all
()
b
=
nn
.
Bernoulli
([
0.7
,
0.5
],
dtype
=
dtype
.
int32
)
probs
=
b
.
probs
()
print
(
"probs is "
,
probs
)
expect_probs
=
[
0.7
,
0.5
]
tol
=
1e-6
assert
(
np
.
abs
(
probs
.
asnumpy
()
-
expect_probs
)
<
tol
).
all
()
def
test_sample
():
"""
Test sample.
"""
shape
=
(
2
,
3
)
sample
=
Net4
(
shape
)
output
=
sample
()
assert
output
.
shape
==
(
2
,
3
,
2
)
tests/st/ops/ascend/test_distribution/test_normal.py
浏览文件 @
bef1fc7f
...
...
@@ -65,12 +65,25 @@ class Net3(nn.Cell):
"""
def
__init__
(
self
):
super
(
Net3
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
[
2.0
],
[
4.0
]
]),
dtype
=
dtype
.
float32
)
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
2.0
,
4.0
]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
return
self
.
n
(
'mean'
),
self
.
n
(
'sd'
)
class
Net4
(
nn
.
Cell
):
"""
Test class: mean/sd of normal distribution.
"""
def
__init__
(
self
,
shape
,
seed
=
0
):
super
(
Net4
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
seed
=
seed
,
dtype
=
dtype
.
float32
)
self
.
shape
=
shape
@
ms_function
def
construct
(
self
,
mean
=
None
,
sd
=
None
):
return
self
.
n
(
'sample'
,
self
.
shape
,
mean
,
sd
)
def
test_pdf
():
"""
Test pdf.
...
...
@@ -79,10 +92,8 @@ def test_pdf():
expect_pdf
=
norm_benchmark
.
pdf
([
1.0
,
2.0
]).
astype
(
np
.
float32
)
pdf
=
Net
()
output
=
pdf
(
Tensor
([
1.0
,
2.0
],
dtype
=
dtype
.
float32
))
print
(
"expected_pdf: "
,
expect_pdf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_pdf
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_pdf
)
<
tol
).
all
()
def
test_log_likelihood
():
"""
...
...
@@ -92,10 +103,8 @@ def test_log_likelihood():
expect_logpdf
=
norm_benchmark
.
logpdf
([
1.0
,
2.0
]).
astype
(
np
.
float32
)
logprob
=
Net1
()
output
=
logprob
(
Tensor
([
1.0
,
2.0
],
dtype
=
dtype
.
float32
))
print
(
"expected_log_probability: "
,
expect_logpdf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_logpdf
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_logpdf
)
<
tol
).
all
()
def
test_kl_loss
():
"""
...
...
@@ -115,10 +124,8 @@ def test_kl_loss():
mean
=
Tensor
(
mean_b
,
dtype
=
dtype
.
float32
)
sd
=
Tensor
(
sd_b
,
dtype
=
dtype
.
float32
)
output
=
kl_loss
(
mean
,
sd
)
print
(
"expected_kl_loss: "
,
expect_kl_loss
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_kl_loss
<
tol
).
all
()
assert
(
np
.
abs
(
output
.
asnumpy
()
-
expect_kl_loss
)
<
tol
).
all
()
def
test_basics
():
"""
...
...
@@ -126,5 +133,20 @@ def test_basics():
"""
basics
=
Net3
()
mean
,
sd
=
basics
()
print
(
"mean is "
,
mean
)
print
(
"sd is "
,
sd
)
expect_mean
=
[
3.0
,
3.0
]
expect_sd
=
[
2.0
,
4.0
]
tol
=
1e-6
assert
(
np
.
abs
(
mean
.
asnumpy
()
-
expect_mean
)
<
tol
).
all
()
assert
(
np
.
abs
(
sd
.
asnumpy
()
-
expect_sd
)
<
tol
).
all
()
def
test_sample
():
"""
Test sample.
"""
shape
=
(
2
,
3
)
seed
=
10
mean
=
Tensor
([
2.0
],
dtype
=
dtype
.
float32
)
sd
=
Tensor
([
2.0
,
2.0
,
2.0
],
dtype
=
dtype
.
float32
)
sample
=
Net4
(
shape
,
seed
=
seed
)
output
=
sample
(
mean
,
sd
)
assert
output
.
shape
==
(
2
,
3
,
3
)
tests/ut/python/nn/test_distribution.py
浏览文件 @
bef1fc7f
...
...
@@ -36,18 +36,18 @@ def test_no_arguments():
No args passed in during initialization.
"""
n
=
nn
.
Normal
()
assert
isinstance
(
n
,
nn
.
Distribution
)
b
=
nn
.
Bernoulli
()
print
(
n
)
print
(
b
)
assert
isinstance
(
b
,
nn
.
Distribution
)
def
test_with_arguments
():
"""
Args passed in during initialization.
"""
n
=
nn
.
Normal
([
3.0
],
[
4.0
],
dtype
=
dtype
.
float32
)
assert
isinstance
(
n
,
nn
.
Distribution
)
b
=
nn
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
print
(
n
)
print
(
b
)
assert
isinstance
(
b
,
nn
.
Distribution
)
class
NormalProb
(
nn
.
Cell
):
"""
...
...
@@ -69,8 +69,8 @@ def test_normal_prob():
net
=
NormalProb
()
value
=
Tensor
([
0.5
,
1.0
],
dtype
=
dtype
.
float32
)
pdf
,
log_pdf
=
net
(
value
)
print
(
"pdf: "
,
pdf
)
print
(
"log_pdf: "
,
log_pdf
)
assert
isinstance
(
pdf
,
Tensor
)
assert
isinstance
(
log_pdf
,
Tensor
)
class
NormalProb1
(
nn
.
Cell
):
"""
...
...
@@ -94,9 +94,8 @@ def test_normal_prob1():
mean
=
Tensor
([
0.0
],
dtype
=
dtype
.
float32
)
sd
=
Tensor
([
1.0
],
dtype
=
dtype
.
float32
)
pdf
,
log_pdf
=
net
(
value
,
mean
,
sd
)
print
(
"pdf: "
,
pdf
)
print
(
"log_pdf: "
,
log_pdf
)
assert
isinstance
(
pdf
,
Tensor
)
assert
isinstance
(
log_pdf
,
Tensor
)
class
NormalProb2
(
nn
.
Cell
):
"""
...
...
@@ -121,8 +120,8 @@ def test_normal_prob2():
mean
=
Tensor
([
0.0
],
dtype
=
dtype
.
float32
)
sd
=
Tensor
([
1.0
],
dtype
=
dtype
.
float32
)
pdf
,
log_pdf
=
net
(
value
,
mean
,
sd
)
print
(
"pdf: "
,
pdf
)
print
(
"log_pdf: "
,
log_pdf
)
assert
isinstance
(
pdf
,
Tensor
)
assert
isinstance
(
log_pdf
,
Tensor
)
class
BernoulliProb
(
nn
.
Cell
):
"""
...
...
@@ -133,9 +132,19 @@ class BernoulliProb(nn.Cell):
self
.
bernoulli
=
nn
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
):
x
=
self
.
bernoulli
(
'prob'
,
value
)
y
=
self
.
bernoulli
(
'log_prob'
,
value
)
return
x
,
y
return
self
.
bernoulli
(
'prob'
,
value
)
class
BernoulliLogProb
(
nn
.
Cell
):
"""
Bernoulli distribution: initialize with probs.
"""
def
__init__
(
self
):
super
(
BernoulliLogProb
,
self
).
__init__
()
self
.
bernoulli
=
nn
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
):
return
self
.
bernoulli
(
'log_prob'
,
value
)
def
test_bernoulli_prob
():
"""
...
...
@@ -143,10 +152,17 @@ def test_bernoulli_prob():
"""
net
=
BernoulliProb
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
)
print
(
"pmf: "
,
ans
)
print
(
"log_pmf: "
,
ans
)
pmf
=
net
(
value
)
assert
isinstance
(
pmf
,
Tensor
)
def
test_bernoulli_log_prob
():
"""
Test pmf/log_pmf: passing value through construct.
"""
net
=
BernoulliLogProb
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
log_pmf
=
net
(
value
)
assert
isinstance
(
log_pmf
,
Tensor
)
class
BernoulliProb1
(
nn
.
Cell
):
"""
...
...
@@ -157,9 +173,19 @@ class BernoulliProb1(nn.Cell):
self
.
bernoulli
=
nn
.
Bernoulli
()
def
construct
(
self
,
value
,
probs
):
x
=
self
.
bernoulli
(
'prob'
,
value
,
probs
)
y
=
self
.
bernoulli
(
'log_prob'
,
value
,
probs
)
return
x
,
y
return
self
.
bernoulli
(
'prob'
,
value
,
probs
)
class
BernoulliLogProb1
(
nn
.
Cell
):
"""
Bernoulli distribution: initialize without probs.
"""
def
__init__
(
self
):
super
(
BernoulliLogProb1
,
self
).
__init__
()
self
.
bernoulli
=
nn
.
Bernoulli
()
def
construct
(
self
,
value
,
probs
):
return
self
.
bernoulli
(
'log_prob'
,
value
,
probs
)
def
test_bernoulli_prob1
():
"""
...
...
@@ -168,10 +194,18 @@ def test_bernoulli_prob1():
net
=
BernoulliProb1
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
print
(
"pmf: "
,
ans
)
print
(
"log_pmf: "
,
ans
)
pmf
=
net
(
value
,
probs
)
assert
isinstance
(
pmf
,
Tensor
)
def
test_bernoulli_log_prob1
():
"""
Test pmf/log_pmf: passing probs through construct.
"""
net
=
BernoulliLogProb1
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
log_pmf
=
net
(
value
,
probs
)
assert
isinstance
(
log_pmf
,
Tensor
)
class
BernoulliProb2
(
nn
.
Cell
):
"""
...
...
@@ -182,9 +216,19 @@ class BernoulliProb2(nn.Cell):
self
.
bernoulli
=
nn
.
Bernoulli
(
0.5
)
def
construct
(
self
,
value
,
probs
):
x
=
self
.
bernoulli
(
'prob'
,
value
,
probs
)
y
=
self
.
bernoulli
(
'log_prob'
,
value
,
probs
)
return
x
,
y
return
self
.
bernoulli
(
'prob'
,
value
,
probs
)
class
BernoulliLogProb2
(
nn
.
Cell
):
"""
Bernoulli distribution: initialize with probs.
"""
def
__init__
(
self
):
super
(
BernoulliLogProb2
,
self
).
__init__
()
self
.
bernoulli
=
nn
.
Bernoulli
(
0.5
)
def
construct
(
self
,
value
,
probs
):
return
self
.
bernoulli
(
'log_prob'
,
value
,
probs
)
def
test_bernoulli_prob2
():
"""
...
...
@@ -194,9 +238,20 @@ def test_bernoulli_prob2():
net
=
BernoulliProb2
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
print
(
"pmf: "
,
ans
)
print
(
"log_pmf: "
,
ans
)
pmf
=
net
(
value
,
probs
)
assert
isinstance
(
pmf
,
Tensor
)
def
test_bernoulli_log_prob2
():
"""
Test pmf/log_pmf: passing probs/value through construct.
Overwrite original probs.
"""
net
=
BernoulliLogProb2
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
log_pmf
=
net
(
value
,
probs
)
assert
isinstance
(
log_pmf
,
Tensor
)
class
NormalKl
(
nn
.
Cell
):
"""
...
...
@@ -229,13 +284,61 @@ def test_kl():
sd_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
mean
=
Tensor
(
mean_b
,
dtype
=
dtype
.
float32
)
sd
=
Tensor
(
sd_b
,
dtype
=
dtype
.
float32
)
output
=
nor_net
(
mean
,
sd
)
print
(
"normal-normal kl loss: "
,
output
)
loss
=
nor_net
(
mean
,
sd
)
assert
isinstance
(
loss
,
Tensor
)
ber_net
=
BernoulliKl
()
probs_b
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
output
=
ber_net
(
probs_b
)
print
(
"bernoulli-bernoulli kl loss: "
,
output
)
loss
=
ber_net
(
probs_b
)
assert
isinstance
(
loss
,
Tensor
)
class
NormalKlNoArgs
(
nn
.
Cell
):
"""
Test class: kl_loss of Normal distribution.
No args during initialization.
"""
def
__init__
(
self
):
super
(
NormalKlNoArgs
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
dtype
=
dtype
.
float32
)
def
construct
(
self
,
x_
,
y_
,
w_
,
v_
):
return
self
.
n
(
'kl_loss'
,
'Normal'
,
x_
,
y_
,
w_
,
v_
)
class
BernoulliKlNoArgs
(
nn
.
Cell
):
"""
Test class: kl_loss between Bernoulli distributions.
No args during initialization.
"""
def
__init__
(
self
):
super
(
BernoulliKlNoArgs
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
(
dtype
=
dtype
.
int32
)
def
construct
(
self
,
x_
,
y_
):
return
self
.
b
(
'kl_loss'
,
'Bernoulli'
,
x_
,
y_
)
def
test_kl_no_args
():
"""
Test kl_loss function.
"""
nor_net
=
NormalKlNoArgs
()
mean_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
sd_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
mean_a
=
np
.
array
([
2.0
]).
astype
(
np
.
float32
)
sd_a
=
np
.
array
([
3.0
]).
astype
(
np
.
float32
)
mean_b
=
Tensor
(
mean_b
,
dtype
=
dtype
.
float32
)
sd_b
=
Tensor
(
sd_b
,
dtype
=
dtype
.
float32
)
mean_a
=
Tensor
(
mean_a
,
dtype
=
dtype
.
float32
)
sd_a
=
Tensor
(
sd_a
,
dtype
=
dtype
.
float32
)
loss
=
nor_net
(
mean_b
,
sd_b
,
mean_a
,
sd_a
)
assert
isinstance
(
loss
,
Tensor
)
ber_net
=
BernoulliKlNoArgs
()
probs_b
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
probs_a
=
Tensor
([
0.7
],
dtype
=
dtype
.
float32
)
loss
=
ber_net
(
probs_b
,
probs_a
)
assert
isinstance
(
loss
,
Tensor
)
class
NormalBernoulli
(
nn
.
Cell
):
...
...
@@ -244,7 +347,7 @@ class NormalBernoulli(nn.Cell):
"""
def
__init__
(
self
):
super
(
NormalBernoulli
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
in
t32
)
self
.
n
=
nn
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
floa
t32
)
self
.
b
=
nn
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
):
...
...
@@ -260,7 +363,7 @@ def test_bascis():
"""
net
=
NormalBernoulli
()
normal_mean
,
normal_sd
,
bernoulli_mean
,
bernoulli_sd
=
net
()
print
(
"Mean of Normal distribution: "
,
normal_mean
)
print
(
"Standard deviation of Normal distribution: "
,
normal_sd
)
print
(
"Mean of Bernoulli distribution: "
,
bernoulli_mean
)
print
(
"Standard deviation of Bernoulli distribution: "
,
bernoulli_sd
)
assert
isinstance
(
normal_mean
,
Tensor
)
assert
isinstance
(
normal_sd
,
Tensor
)
assert
isinstance
(
bernoulli_mean
,
Tensor
)
assert
isinstance
(
bernoulli_sd
,
Tensor
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录