Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0aa26c18
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0aa26c18
编写于
6月 26, 2020
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add high level abstract class Distribution and two example class:
Bernoulli and Normal
上级
9ba937b1
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
1252 addition
and
1 deletion
+1252
-1
mindspore/nn/__init__.py
mindspore/nn/__init__.py
+5
-1
mindspore/nn/distribution/__init__.py
mindspore/nn/distribution/__init__.py
+27
-0
mindspore/nn/distribution/_utils/__init__.py
mindspore/nn/distribution/_utils/__init__.py
+24
-0
mindspore/nn/distribution/_utils/utils.py
mindspore/nn/distribution/_utils/utils.py
+190
-0
mindspore/nn/distribution/bernoulli.py
mindspore/nn/distribution/bernoulli.py
+126
-0
mindspore/nn/distribution/distribution.py
mindspore/nn/distribution/distribution.py
+232
-0
mindspore/nn/distribution/normal.py
mindspore/nn/distribution/normal.py
+124
-0
tests/st/ops/ascend/test_distribution/test_bernoulli.py
tests/st/ops/ascend/test_distribution/test_bernoulli.py
+128
-0
tests/st/ops/ascend/test_distribution/test_normal.py
tests/st/ops/ascend/test_distribution/test_normal.py
+130
-0
tests/ut/python/nn/test_distribution.py
tests/ut/python/nn/test_distribution.py
+266
-0
未找到文件。
mindspore/nn/__init__.py
浏览文件 @
0aa26c18
...
...
@@ -17,13 +17,15 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks.
"""
from
.
import
layer
,
loss
,
optim
,
metrics
,
wrap
from
.
import
layer
,
loss
,
optim
,
metrics
,
wrap
,
distribution
from
.cell
import
Cell
,
GraphKernel
from
.layer
import
*
from
.loss
import
*
from
.optim
import
*
from
.metrics
import
*
from
.wrap
import
*
from
.distribution
import
*
__all__
=
[
"Cell"
,
"GraphKernel"
]
__all__
.
extend
(
layer
.
__all__
)
...
...
@@ -31,5 +33,7 @@ __all__.extend(loss.__all__)
__all__
.
extend
(
optim
.
__all__
)
__all__
.
extend
(
metrics
.
__all__
)
__all__
.
extend
(
wrap
.
__all__
)
__all__
.
extend
(
distribution
.
__all__
)
__all__
.
sort
()
mindspore/nn/distribution/__init__.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distribution.
The high-level components(Distributions) used to construct the probabilistic network.
"""
from
.distribution
import
Distribution
from
.normal
import
Normal
from
.bernoulli
import
Bernoulli
__all__
=
[
'Distribution'
,
'Normal'
,
'Bernoulli'
,]
mindspore/nn/distribution/_utils/__init__.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distribution operation utility functions.
"""
from
.utils
import
*
__all__
=
[
'check_scalar'
,
'convert_to_batch'
,
'cast_to_tensor'
,
'calc_batch_size'
,
'check_greater'
,
'check_greater_equal_zero'
,
'calc_broadcast_shape_from_param'
,
'check_scalar_from_param'
,
'check_prob'
]
mindspore/nn/distribution/_utils/utils.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utitly functions to help distribution class."""
import
numpy
as
np
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
_utils
as
utils
from
....common.tensor
import
Tensor
from
....common
import
dtype
as
mstype
def
check_scalar
(
value
):
"""
Check if input value is a scalar.
"""
return
np
.
isscalar
(
value
)
def
cast_to_tensor
(
t
,
dtype
=
mstype
.
float32
):
"""
Cast an user input value into a Tensor of dtype.
Args:
t (int/float/list/numpy.ndarray/Tensor).
dtype (mindspore.dtype).
Raises:
RuntimeError: if t cannot be cast to Tensor.
Outputs:
Tensor.
"""
if
isinstance
(
t
,
Tensor
):
#check if the Tensor in shape of Tensor(4)
if
t
.
dim
()
==
0
:
value
=
t
.
asnumpy
()
return
Tensor
([
t
],
dtype
=
dtype
)
#convert the type of tensor to dtype
t
.
set_dtype
(
dtype
)
return
t
if
isinstance
(
t
,
(
list
,
np
.
ndarray
)):
return
Tensor
(
t
,
dtype
=
dtype
)
if
check_scalar
(
t
):
return
Tensor
([
t
],
dtype
=
dtype
)
raise
RuntimeError
(
"Input type is not supported."
)
def
calc_batch_size
(
batch_shape
):
"""
Calculate the size of a given batch_shape.
Args:
batch_shape (tuple)
Outputs:
int.
"""
return
int
(
np
.
prod
(
batch_shape
))
def
convert_to_batch
(
t
,
batch_shape
,
dtype
):
"""
Convert a Tensor to a given batch shape.
Args:
t (Tensor)
batch_shape (tuple)
dtype (mindspore.dtype)
Raises:
RuntimeError: if the converison cannot be done.
Outputs:
Tensor, with shape of batch_shape.
"""
t
=
cast_to_tensor
(
t
,
dtype
)
reshape
=
P
.
Reshape
()
if
t
.
shape
!=
batch_shape
:
mul
=
calc_batch_size
(
batch_shape
)
//
t
.
size
()
if
(
calc_batch_size
(
batch_shape
)
%
t
.
size
())
!=
0
:
raise
RuntimeError
(
"Cannot cast the tensor to the given batch shape."
)
temp
=
list
(
t
.
asnumpy
())
*
mul
return
reshape
(
Tensor
(
temp
),
batch_shape
)
return
t
def
check_scalar_from_param
(
params
):
"""
Check if params are all scalars.
Args:
params (dict): parameters used to initialized distribution.
Notes: String parameters are excluded.
"""
for
value
in
params
.
values
():
if
isinstance
(
value
,
(
str
,
type
(
params
[
'dtype'
]))):
continue
elif
check_scalar
(
value
):
continue
else
:
return
False
return
True
def
calc_broadcast_shape_from_param
(
params
):
"""
Calculate the broadcast shape from params.
Args:
params (dict): parameters used to initialized distribution.
Outputs:
tuple.
"""
broadcast_shape
=
[]
for
value
in
params
.
values
():
if
isinstance
(
value
,
(
str
,
type
(
params
[
'dtype'
]))):
continue
if
value
is
None
:
return
None
value_t
=
cast_to_tensor
(
value
,
params
[
'dtype'
])
broadcast_shape
=
utils
.
get_broadcast_shape
(
broadcast_shape
,
list
(
value_t
.
shape
),
params
[
'name'
])
return
tuple
(
broadcast_shape
)
def
check_greater_equal_zero
(
value
,
name
):
"""
Check if the given Tensor is greater zero.
Args:
value (Tensor)
name (str) : name of the value.
Raises:
ValueError: if the input value is less than zero.
"""
less
=
P
.
Less
()
zeros
=
Tensor
([
0.0
],
dtype
=
value
.
dtype
)
value
=
less
(
value
,
zeros
)
if
value
.
asnumpy
().
any
():
raise
ValueError
(
'{} should be greater than zero.'
.
format
(
name
))
def
check_greater
(
a
,
b
,
name_a
,
name_b
):
"""
Check if Tensor b is strictly greater than Tensor a.
Args:
a (Tensor)
b (Tensor)
name_a (str): name of Tensor_a.
name_b (str): name of Tensor_b.
Raises:
ValueError: if b is less than or equal to a
"""
less
=
P
.
Less
()
value
=
less
(
a
,
b
)
if
not
value
.
asnumpy
().
all
():
raise
ValueError
(
'{} should be less than {}'
.
format
(
name_a
,
name_b
))
def
check_prob
(
p
):
"""
Check if p is a proper probability, i.e. 0 <= p <=1.
Args:
p (Tensor): value to check.
Raises:
ValueError: if p is not a proper probability.
"""
less
=
P
.
Less
()
greater
=
P
.
Greater
()
zeros
=
Tensor
([
0.0
],
dtype
=
p
.
dtype
)
ones
=
Tensor
([
1.0
],
dtype
=
p
.
dtype
)
comp
=
less
(
p
,
zeros
)
if
comp
.
asnumpy
().
any
():
raise
ValueError
(
'Probabilities should be greater than or equal to zero'
)
comp
=
greater
(
p
,
ones
)
if
comp
.
asnumpy
().
any
():
raise
ValueError
(
'Probabilities should be less than or equal to one'
)
mindspore/nn/distribution/bernoulli.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bernoulli Distribution"""
from
mindspore.ops
import
operations
as
P
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
from
...common
import
dtype
as
mstype
class
Bernoulli
(
Distribution
):
"""
Example class: Bernoulli Distribution.
Args:
probs (int/float/list/numpy.ndarray/Tensor): probability of 1 as outcome.
dtype (mindspore.dtype): type of the distribution, default to int32.
Note:
probs should be proper probabilities (0 <= p <= 1).
Examples:
>>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0
>>> b = nn.Bernoulli(0.5, dtype = dtype.int32)
>>> # The following create two independent Bernoulli distributions
>>> b = nn.Bernoulli([0.7, 0.2], dtype = dtype.int32)
"""
def
__init__
(
self
,
probs
=
None
,
dtype
=
mstype
.
int32
,
name
=
"Bernoulli"
):
"""
Constructor of Bernoulli distribution.
"""
param
=
dict
(
locals
())
super
(
Bernoulli
,
self
).
__init__
(
dtype
,
name
,
param
)
if
probs
is
not
None
:
self
.
_probs
=
cast_to_tensor
(
probs
)
# check if the input probability is valid
check_prob
(
self
.
_probs
)
else
:
self
.
_probs
=
probs
# ops needed for the class
self
.
log
=
P
.
Log
()
self
.
add
=
P
.
TensorAdd
()
self
.
mul
=
P
.
Mul
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
realdiv
=
P
.
RealDiv
()
def
probs
(
self
):
"""
Returns the probability for the outcome is 1.
"""
return
self
.
_probs
def
_mean
(
self
):
r
"""
.. math::
MEAN(B) = probs1
"""
return
self
.
_probs
def
_var
(
self
):
r
"""
.. math::
VAR(B) = probs1 * probs0
"""
probs0
=
self
.
add
(
1
,
-
1
*
self
.
_probs
)
return
self
.
mul
(
probs0
,
self
.
_probs
)
def
_prob
(
self
,
name
,
value
,
probs
=
None
):
r
"""
pmf of Bernoulli distribution.
Args:
name (str): name of the function. Should be "prob" when passed in from construct.
value (Tensor): a Tensor composed of only zeros and ones.
probs (Tensor): probability of outcome is 1. Default to self._probs.
.. math::
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
probs1
=
self
.
_probs
if
probs
is
None
else
probs
probs0
=
self
.
add
(
1
,
-
1
*
probs1
)
return
self
.
add
(
self
.
mul
(
probs1
,
value
),
self
.
mul
(
probs0
,
self
.
add
(
1
,
-
1
*
value
)))
def
_kl_loss
(
self
,
name
,
dist
,
probs1_b
):
r
"""
Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion. Should always be "kl_loss" when passed in from construct.
dist (str): type of the distributions. Should be "Bernoulli" in this case.
probs1_b (Tensor): probs1 of distribution b.
.. math::
KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) +
probs0_a * \log(\fract{probs0_a}{probs0_b})
"""
if
dist
==
'Bernoulli'
:
probs1_a
=
self
.
_probs
probs0_a
=
self
.
add
(
1
,
-
1
*
probs1_a
)
probs0_b
=
self
.
add
(
1
,
-
1
*
probs1_b
)
return
self
.
add
(
probs1_a
*
self
.
log
(
self
.
realdiv
(
probs1_a
,
probs1_b
)),
probs0_a
*
self
.
log
(
self
.
realdiv
(
probs0_a
,
probs0_b
)))
return
None
def
extend_repr
(
self
):
str_info
=
'probs={}'
.
format
(
self
.
_probs
)
return
str_info
mindspore/nn/distribution/distribution.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
from
..cell
import
Cell
from
._utils.utils
import
calc_broadcast_shape_from_param
class
Distribution
(
Cell
):
"""
Base class for all mathematical distributions.
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Functions should be called through construct when
used inside a network in the form of function name followed by
arguments.
Examples:
>>> class MyNormalDistribution(Distribution):
>>> def __init__(self):
>>> super(MyDistribution, self).__init__()
>>> self._mean_value = Tensor([2.0,3.0])
>>> self._sd_value = Tensor([2.0,3.0])
>>>
>>> def _mean(self):
>>> return self._mean_value
"""
def
__init__
(
self
,
dtype
,
name
,
param
):
"""
Constructor of distribution class.
"""
super
(
Distribution
,
self
).
__init__
()
self
.
_name
=
name
self
.
_dtype
=
dtype
self
.
_parameters
=
{}
# parsing parameters
for
k
in
param
.
keys
():
if
not
(
k
==
'self'
or
k
.
startswith
(
'_'
)):
self
.
_parameters
[
k
]
=
param
[
k
]
# some attributes
self
.
_broadcast_shape
=
calc_broadcast_shape_from_param
(
self
.
_parameters
)
# set the function to call according to the derived class's attributes
self
.
_set_prob
()
self
.
_set_log_prob
()
self
.
_set_sd
()
def
_set_prob
(
self
):
"""
Set probability funtion based on the availability of _prob and _log_likehood.
"""
if
hasattr
(
self
,
'_prob'
):
self
.
_call_prob
=
self
.
_prob
elif
hasattr
(
self
,
'_log_likelihood'
):
self
.
_call_prob
=
self
.
_calc_prob_from_log_likelihood
def
_set_sd
(
self
):
"""
Set standard deviation based on the availability of _sd and _var.
"""
if
hasattr
(
self
,
'_sd'
):
self
.
_call_sd
=
self
.
_sd
elif
hasattr
(
self
,
'_var'
):
self
.
_call_sd
=
self
.
_calc_sd_from_var
def
_set_log_prob
(
self
):
"""
Set log probability based on the availability of _prob and _log_likelihood.
"""
if
hasattr
(
self
,
'_log_likelihood'
):
self
.
_call_log_prob
=
self
.
_log_likelihood
if
hasattr
(
self
,
'_prob'
):
self
.
_call_log_prob
=
self
.
_calc_log_prob_from_prob
def
log_likelihood
(
self
,
*
args
):
"""
Evaluate the log probability at the given value.
Note:
value is casted to Tensor for further calculation.
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distirbution. Default: self.mean.
sd (Tensor): standard deviation of the distribution. Default: self.sd.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
"""
return
self
.
_call_log_prob
(
*
args
)
def
_calc_prob_from_log_likelihood
(
self
,
*
args
):
r
"""
Evaluate prob from log probability.
.. math::
probability(x) = \exp(log_likehood(x))
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distribution. Default: self.mean.
sd (Tensor): standard deviation of the distritbuion. Default: self.sd.
"""
return
self
.
exp
(
self
.
_log_likelihood
(
*
args
))
def
_call_prob
(
self
,
*
args
):
"""
Raises:
NotImplementedError when derived class didn't override _prob or _log_likelihood.
"""
raise
NotImplementedError
(
'pdf/pmf is not implemented: {}'
.
format
(
type
(
self
).
__name__
))
def
_call_log_prob
(
self
,
*
args
):
"""
Raises:
NotImplementedError when derived class didn't override _prob or _log_likelihood.
"""
raise
NotImplementedError
(
'log_probability is not implemented: {}'
.
format
(
type
(
self
).
__name__
))
def
_call_sd
(
self
):
"""
Raises:
NotImplementedError when derived class didn't override _sd or _var.
"""
raise
NotImplementedError
(
'standard deviation is not implemented: {}'
.
format
(
type
(
self
).
__name__
))
def
prob
(
self
,
*
args
):
"""
Evaluate the prob (pdf or pmf) at given value.
Note:
value is casted to Tensor for further calculation.
Args:
name (str): name of the calling function.
value (Tensor): values to be evaluated.
mean (Tensor): mean of the distribution.
sd (Tensor): standard deviation of the distritbuion.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
"""
return
self
.
_call_prob
(
*
args
)
def
_calc_log_prob_from_prob
(
self
,
*
args
):
r
"""
Evaluate log probability from probability.
.. math::
log_prob(x) = \log(prob(x))
"""
return
self
.
log
(
self
.
_prob
(
*
args
))
def
kl_loss
(
self
,
**
kwargs
):
"""
Evaluate the KL divergence. Parameters of the second distribution should be
passed in through **kwargs.
Outputs:
Tensor, shape: broadcast_shape of the distribution and input distribution.
"""
return
self
.
_kl_loss
(
**
kwargs
)
def
mean
(
self
,
**
kwargs
):
"""
Evaluate the mean.
Outputs:
Tensor, shape: broadcast_shape of the distribution.
"""
return
self
.
_mean
(
**
kwargs
)
def
sd
(
self
,
**
kwargs
):
"""
Evaluate the standard deviation.
Outputs:
Tensor, with shape of broadcast_shape of the distribution.
"""
return
self
.
_call_sd
(
**
kwargs
)
def
_calc_sd_from_var
(
self
,
**
kwargs
):
r
"""
Evaluate log probability from probability.
.. math::
STD(x) = \sqrt(VAR(x))
"""
return
self
.
sqrt
(
self
.
_var
(
**
kwargs
))
def
construct
(
self
,
*
inputs
):
"""
Override construct in Cell.
Args:
*inputs: inputs[0] is always the name of the function.
Notes:
Always raise RuntimeError as Distribution should not be called directly.
"""
if
inputs
[
0
]
==
'log_prob'
:
return
self
.
_call_log_prob
(
*
inputs
)
if
inputs
[
0
]
==
'prob'
:
return
self
.
_call_prob
(
*
inputs
)
if
inputs
[
0
]
==
'kl_loss'
:
return
self
.
_kl_loss
(
*
inputs
)
if
inputs
[
0
]
==
'mean'
:
return
self
.
_mean
()
if
inputs
[
0
]
==
'sd'
:
return
self
.
_call_sd
()
return
None
mindspore/nn/distribution/normal.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Normal Distribution"""
import
numpy
as
np
from
mindspore.ops
import
operations
as
P
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater_equal_zero
from
...common
import
dtype
as
mstype
from
...context
import
get_context
class
Normal
(
Distribution
):
"""
Example class: Normal distribution.
Args:
mean (int/float/list/numpy.ndarray/Tensor): mean of the Gaussian distribution
standard deviation (int/float/list/numpy.ndarray/Tensor): vairance of the Gaussian distribution
dtype (mindspore.dtype): type of the distribution
Note:
Standard deviation should be greater than zero.
Examples:
>>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0
>>> n = nn.Normal(3.0, 4.0, dtype=dtype.float32)
>>> # The following create two independent normal distributions
>>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=dtype.float32)
"""
def
__init__
(
self
,
mean
=
None
,
sd
=
None
,
dtype
=
mstype
.
float32
,
name
=
"Normal"
):
"""
Constructor of normal distribution.
"""
param
=
dict
(
locals
())
super
(
Normal
,
self
).
__init__
(
dtype
,
name
,
param
)
if
mean
is
not
None
and
sd
is
not
None
:
self
.
_mean_value
=
convert_to_batch
(
mean
,
self
.
_broadcast_shape
,
dtype
)
self
.
_sd_value
=
convert_to_batch
(
sd
,
self
.
_broadcast_shape
,
dtype
)
#check validity of standard deviation
check_greater_equal_zero
(
self
.
_sd_value
,
"Standard deviation"
)
else
:
self
.
_mean_value
=
mean
self
.
_sd_value
=
sd
#ops needed for the class
self
.
exp
=
P
.
Exp
()
self
.
add
=
P
.
TensorAdd
()
self
.
sq
=
P
.
Square
()
self
.
log
=
P
.
Log
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
realdiv
=
P
.
RealDiv
()
self
.
expm1
=
P
.
Expm1
()
if
get_context
(
'device_target'
)
==
'Ascend'
else
self
.
_expm1_by_step
def
_expm1_by_step
(
self
,
x
):
"""
Expm1 ops under GPU context.
"""
return
self
.
add
(
self
.
exp
(
x
),
-
1
)
def
_mean
(
self
):
"""
Mean of the distribution.
"""
return
self
.
_mean_value
def
_sd
(
self
):
"""
Standard deviation of the distribution.
"""
return
self
.
_sd_value
def
_log_likelihood
(
self
,
name
,
value
,
mean
=
None
,
sd
=
None
):
r
"""
Evaluate log probability.
.. math::
L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
unnormalized_log_prob
=
-
1.
*
self
.
realdiv
(
self
.
sq
(
self
.
add
(
value
,
-
1.
*
mean
)),
2.
*
self
.
sq
(
sd
))
neg_normalization
=
-
1.
*
self
.
log
(
self
.
sqrt
(
2.
*
np
.
pi
*
self
.
sq
(
sd
)))
return
self
.
add
(
unnormalized_log_prob
,
neg_normalization
)
def
_kl_loss
(
self
,
name
,
dist
,
mean
,
sd
):
r
"""
Evaluate Normal-Normal kl divergence, i.e. KL(a||b).
Args:
name (str): name of the funtion passed in from construct. Should always be "kl_loss".
dist (str): type of the distributions. Should be "Normal" in this case.
mean (Tensor): mean of distribution b.
sd (Tensor): standard deviation distribution b.
.. math::
KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
if
dist
==
'Normal'
:
diff_log_scale
=
self
.
add
(
self
.
log
(
self
.
_sd_value
),
-
self
.
log
(
sd
))
squared_diff
=
self
.
sq
(
self
.
add
(
self
.
realdiv
(
self
.
_mean_value
,
sd
),
-
self
.
realdiv
(
mean
,
sd
)))
return
self
.
add
(
self
.
add
(
0.5
*
squared_diff
,
0.5
*
self
.
expm1
(
2
*
diff_log_scale
)),
-
diff_log_scale
)
return
None
def
extend_repr
(
self
):
str_info
=
'mean={}, standard deviation={}'
.
format
(
self
.
_mean_value
,
self
.
_sd_value
)
return
str_info
tests/st/ops/ascend/test_distribution/test_bernoulli.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for bernoulli distribution"""
import
numpy
as
np
from
scipy
import
stats
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Net
(
nn
.
Cell
):
"""
Test class: probability of bernoulli distribution.
"""
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
return
self
.
b
(
'prob'
,
x_
)
class
Net1
(
nn
.
Cell
):
"""
Test class: log probability of bernoulli distribution.
"""
def
__init__
(
self
):
super
(
Net1
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
return
self
.
b
(
'log_prob'
,
x_
)
class
Net2
(
nn
.
Cell
):
"""
Test class: kl_loss between bernoulli distributions.
"""
def
__init__
(
self
):
super
(
Net2
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
,
x_
):
return
self
.
b
(
'kl_loss'
,
'Bernoulli'
,
x_
)
class
Net3
(
nn
.
Cell
):
"""
Test class: mean/sd of bernoulli distribution.
"""
def
__init__
(
self
):
super
(
Net3
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
([
0.7
,
0.5
],
dtype
=
dtype
.
int32
)
@
ms_function
def
construct
(
self
):
return
self
.
b
(
'mean'
),
self
.
b
(
'sd'
)
def
test_pmf
():
"""
Test pmf.
"""
bernoulli_benchmark
=
stats
.
bernoulli
(
0.7
)
expect_pmf
=
bernoulli_benchmark
.
pmf
([
0
,
1
,
0
,
1
,
1
]).
astype
(
np
.
float32
)
pdf
=
Net
()
x_
=
Tensor
(
np
.
array
([
0
,
1
,
0
,
1
,
1
]).
astype
(
np
.
int32
),
dtype
=
dtype
.
float32
)
output
=
pdf
(
x_
)
print
(
"expected_pmf: "
,
expect_pmf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_pmf
<
tol
).
all
()
def
test_log_likelihood
():
"""
Test log_pmf.
"""
bernoulli_benchmark
=
stats
.
bernoulli
(
0.7
)
expect_logpmf
=
bernoulli_benchmark
.
logpmf
([
0
,
1
,
0
,
1
,
1
]).
astype
(
np
.
float32
)
logprob
=
Net1
()
x_
=
Tensor
(
np
.
array
([
0
,
1
,
0
,
1
,
1
]).
astype
(
np
.
int32
),
dtype
=
dtype
.
float32
)
output
=
logprob
(
x_
)
print
(
"expected_log_probability: "
,
expect_logpmf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_logpmf
<
tol
).
all
()
def
test_kl_loss
():
"""
Test kl_loss.
"""
probs1_a
=
0.7
probs1_b
=
0.5
probs0_a
=
1
-
probs1_a
probs0_b
=
1
-
probs1_b
expect_kl_loss
=
probs1_a
*
np
.
log
(
probs1_a
/
probs1_b
)
+
probs0_a
*
np
.
log
(
probs0_a
/
probs0_b
)
kl_loss
=
Net2
()
output
=
kl_loss
(
Tensor
([
probs1_b
],
dtype
=
dtype
.
float32
))
print
(
"expected_kl_loss: "
,
expect_kl_loss
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_kl_loss
<
tol
).
all
()
def
test_basics
():
"""
Test mean/standard deviation and probs.
"""
basics
=
Net3
()
mean
,
sd
=
basics
()
print
(
"mean : "
,
mean
)
print
(
"sd : "
,
sd
)
b
=
nn
.
Bernoulli
([
0.7
,
0.5
],
dtype
=
dtype
.
int32
)
probs
=
b
.
probs
()
print
(
"probs is "
,
probs
)
tests/st/ops/ascend/test_distribution/test_normal.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for normal distribution"""
import
numpy
as
np
from
scipy
import
stats
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore
import
dtype
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"Ascend"
)
class
Net
(
nn
.
Cell
):
"""
Test class: probability of normal distribution.
"""
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
return
self
.
n
(
'prob'
,
x_
)
class
Net1
(
nn
.
Cell
):
"""
Test class: log probability of normal distribution.
"""
def
__init__
(
self
):
super
(
Net1
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
):
return
self
.
n
(
'log_prob'
,
x_
)
class
Net2
(
nn
.
Cell
):
"""
Test class: kl_loss of normal distribution.
"""
def
__init__
(
self
):
super
(
Net2
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
4.0
]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
,
x_
,
y_
):
return
self
.
n
(
'kl_loss'
,
'Normal'
,
x_
,
y_
)
class
Net3
(
nn
.
Cell
):
"""
Test class: mean/sd of normal distribution.
"""
def
__init__
(
self
):
super
(
Net3
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]),
dtype
=
dtype
.
float32
)
@
ms_function
def
construct
(
self
):
return
self
.
n
(
'mean'
),
self
.
n
(
'sd'
)
def
test_pdf
():
"""
Test pdf.
"""
norm_benchmark
=
stats
.
norm
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]))
expect_pdf
=
norm_benchmark
.
pdf
([
1.0
,
2.0
]).
astype
(
np
.
float32
)
pdf
=
Net
()
output
=
pdf
(
Tensor
([
1.0
,
2.0
],
dtype
=
dtype
.
float32
))
print
(
"expected_pdf: "
,
expect_pdf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_pdf
<
tol
).
all
()
def
test_log_likelihood
():
"""
Test log_pdf.
"""
norm_benchmark
=
stats
.
norm
(
np
.
array
([
3.0
]),
np
.
array
([[
2.0
],
[
4.0
]]))
expect_logpdf
=
norm_benchmark
.
logpdf
([
1.0
,
2.0
]).
astype
(
np
.
float32
)
logprob
=
Net1
()
output
=
logprob
(
Tensor
([
1.0
,
2.0
],
dtype
=
dtype
.
float32
))
print
(
"expected_log_probability: "
,
expect_logpdf
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_logpdf
<
tol
).
all
()
def
test_kl_loss
():
"""
Test kl_loss.
"""
mean_a
=
np
.
array
([
3.0
]).
astype
(
np
.
float32
)
sd_a
=
np
.
array
([
4.0
]).
astype
(
np
.
float32
)
mean_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
sd_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
diff_log_scale
=
np
.
log
(
sd_a
)
-
np
.
log
(
sd_b
)
squared_diff
=
np
.
square
(
mean_a
/
sd_b
-
mean_b
/
sd_b
)
expect_kl_loss
=
0.5
*
squared_diff
+
0.5
*
np
.
expm1
(
2
*
diff_log_scale
)
-
diff_log_scale
kl_loss
=
Net2
()
mean
=
Tensor
(
mean_b
,
dtype
=
dtype
.
float32
)
sd
=
Tensor
(
sd_b
,
dtype
=
dtype
.
float32
)
output
=
kl_loss
(
mean
,
sd
)
print
(
"expected_kl_loss: "
,
expect_kl_loss
)
print
(
"ans: "
,
output
.
asnumpy
())
tol
=
1e-6
assert
(
output
.
asnumpy
()
-
expect_kl_loss
<
tol
).
all
()
def
test_basics
():
"""
Test mean/standard deviation.
"""
basics
=
Net3
()
mean
,
sd
=
basics
()
print
(
"mean is "
,
mean
)
print
(
"sd is "
,
sd
)
tests/ut/python/nn/test_distribution.py
0 → 100644
浏览文件 @
0aa26c18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.Distribution.
Including Normal Distribution and Bernoulli Distribution.
"""
import
pytest
import
numpy
as
np
import
mindspore.nn
as
nn
from
mindspore
import
dtype
from
mindspore
import
Tensor
def
test_normal_shape_errpr
():
"""
Invalid shapes.
"""
with
pytest
.
raises
(
ValueError
):
nn
.
Normal
([[
2.
],
[
1.
]],
[[
2.
],
[
3.
],
[
4.
]],
dtype
=
dtype
.
float32
)
def
test_no_arguments
():
"""
No args passed in during initialization.
"""
n
=
nn
.
Normal
()
b
=
nn
.
Bernoulli
()
print
(
n
)
print
(
b
)
def
test_with_arguments
():
"""
Args passed in during initialization.
"""
n
=
nn
.
Normal
([
3.0
],
[
4.0
],
dtype
=
dtype
.
float32
)
b
=
nn
.
Bernoulli
([
0.3
,
0.5
],
dtype
=
dtype
.
int32
)
print
(
n
)
print
(
b
)
class
NormalProb
(
nn
.
Cell
):
"""
Normal distribution: initialize with mean/sd.
"""
def
__init__
(
self
):
super
(
NormalProb
,
self
).
__init__
()
self
.
normal
=
nn
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
):
x
=
self
.
normal
(
'prob'
,
value
)
y
=
self
.
normal
(
'log_prob'
,
value
)
return
x
,
y
def
test_normal_prob
():
"""
Test pdf/log_pdf: passing value through construct.
"""
net
=
NormalProb
()
value
=
Tensor
([
0.5
,
1.0
],
dtype
=
dtype
.
float32
)
pdf
,
log_pdf
=
net
(
value
)
print
(
"pdf: "
,
pdf
)
print
(
"log_pdf: "
,
log_pdf
)
class
NormalProb1
(
nn
.
Cell
):
"""
Normal distribution: initialize without mean/sd.
"""
def
__init__
(
self
):
super
(
NormalProb1
,
self
).
__init__
()
self
.
normal
=
nn
.
Normal
()
def
construct
(
self
,
value
,
mean
,
sd
):
x
=
self
.
normal
(
'prob'
,
value
,
mean
,
sd
)
y
=
self
.
normal
(
'log_prob'
,
value
,
mean
,
sd
)
return
x
,
y
def
test_normal_prob1
():
"""
Test pdf/logpdf: passing mean/sd, value through construct.
"""
net
=
NormalProb1
()
value
=
Tensor
([
0.5
,
1.0
],
dtype
=
dtype
.
float32
)
mean
=
Tensor
([
0.0
],
dtype
=
dtype
.
float32
)
sd
=
Tensor
([
1.0
],
dtype
=
dtype
.
float32
)
pdf
,
log_pdf
=
net
(
value
,
mean
,
sd
)
print
(
"pdf: "
,
pdf
)
print
(
"log_pdf: "
,
log_pdf
)
class
NormalProb2
(
nn
.
Cell
):
"""
Normal distribution: initialize with mean/sd.
"""
def
__init__
(
self
):
super
(
NormalProb2
,
self
).
__init__
()
self
.
normal
=
nn
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
float32
)
def
construct
(
self
,
value
,
mean
,
sd
):
x
=
self
.
normal
(
'prob'
,
value
,
mean
,
sd
)
y
=
self
.
normal
(
'log_prob'
,
value
,
mean
,
sd
)
return
x
,
y
def
test_normal_prob2
():
"""
Test pdf/log_pdf: passing mean/sd through construct.
Overwrite original mean/sd.
"""
net
=
NormalProb2
()
value
=
Tensor
([
0.5
,
1.0
],
dtype
=
dtype
.
float32
)
mean
=
Tensor
([
0.0
],
dtype
=
dtype
.
float32
)
sd
=
Tensor
([
1.0
],
dtype
=
dtype
.
float32
)
pdf
,
log_pdf
=
net
(
value
,
mean
,
sd
)
print
(
"pdf: "
,
pdf
)
print
(
"log_pdf: "
,
log_pdf
)
class
BernoulliProb
(
nn
.
Cell
):
"""
Bernoulli distribution: initialize with probs.
"""
def
__init__
(
self
):
super
(
BernoulliProb
,
self
).
__init__
()
self
.
bernoulli
=
nn
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
,
value
):
x
=
self
.
bernoulli
(
'prob'
,
value
)
y
=
self
.
bernoulli
(
'log_prob'
,
value
)
return
x
,
y
def
test_bernoulli_prob
():
"""
Test pmf/log_pmf: passing value through construct.
"""
net
=
BernoulliProb
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
)
print
(
"pmf: "
,
ans
)
print
(
"log_pmf: "
,
ans
)
class
BernoulliProb1
(
nn
.
Cell
):
"""
Bernoulli distribution: initialize without probs.
"""
def
__init__
(
self
):
super
(
BernoulliProb1
,
self
).
__init__
()
self
.
bernoulli
=
nn
.
Bernoulli
()
def
construct
(
self
,
value
,
probs
):
x
=
self
.
bernoulli
(
'prob'
,
value
,
probs
)
y
=
self
.
bernoulli
(
'log_prob'
,
value
,
probs
)
return
x
,
y
def
test_bernoulli_prob1
():
"""
Test pmf/log_pmf: passing probs through construct.
"""
net
=
BernoulliProb1
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
print
(
"pmf: "
,
ans
)
print
(
"log_pmf: "
,
ans
)
class
BernoulliProb2
(
nn
.
Cell
):
"""
Bernoulli distribution: initialize with probs.
"""
def
__init__
(
self
):
super
(
BernoulliProb2
,
self
).
__init__
()
self
.
bernoulli
=
nn
.
Bernoulli
(
0.5
)
def
construct
(
self
,
value
,
probs
):
x
=
self
.
bernoulli
(
'prob'
,
value
,
probs
)
y
=
self
.
bernoulli
(
'log_prob'
,
value
,
probs
)
return
x
,
y
def
test_bernoulli_prob2
():
"""
Test pmf/log_pmf: passing probs/value through construct.
Overwrite original probs.
"""
net
=
BernoulliProb2
()
value
=
Tensor
([
1
,
0
,
1
,
0
,
1
],
dtype
=
dtype
.
float32
)
probs
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
ans
=
net
(
value
,
probs
)
print
(
"pmf: "
,
ans
)
print
(
"log_pmf: "
,
ans
)
class
NormalKl
(
nn
.
Cell
):
"""
Test class: kl_loss of Normal distribution.
"""
def
__init__
(
self
):
super
(
NormalKl
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
np
.
array
([
3.0
]),
np
.
array
([
4.0
]),
dtype
=
dtype
.
float32
)
def
construct
(
self
,
x_
,
y_
):
return
self
.
n
(
'kl_loss'
,
'Normal'
,
x_
,
y_
)
class
BernoulliKl
(
nn
.
Cell
):
"""
Test class: kl_loss between Bernoulli distributions.
"""
def
__init__
(
self
):
super
(
BernoulliKl
,
self
).
__init__
()
self
.
b
=
nn
.
Bernoulli
(
0.7
,
dtype
=
dtype
.
int32
)
def
construct
(
self
,
x_
):
return
self
.
b
(
'kl_loss'
,
'Bernoulli'
,
x_
)
def
test_kl
():
"""
Test kl_loss function.
"""
nor_net
=
NormalKl
()
mean_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
sd_b
=
np
.
array
([
1.0
]).
astype
(
np
.
float32
)
mean
=
Tensor
(
mean_b
,
dtype
=
dtype
.
float32
)
sd
=
Tensor
(
sd_b
,
dtype
=
dtype
.
float32
)
output
=
nor_net
(
mean
,
sd
)
print
(
"normal-normal kl loss: "
,
output
)
ber_net
=
BernoulliKl
()
probs_b
=
Tensor
([
0.3
],
dtype
=
dtype
.
float32
)
output
=
ber_net
(
probs_b
)
print
(
"bernoulli-bernoulli kl loss: "
,
output
)
class
NormalBernoulli
(
nn
.
Cell
):
"""
Test class: basic mean/sd function.
"""
def
__init__
(
self
):
super
(
NormalBernoulli
,
self
).
__init__
()
self
.
n
=
nn
.
Normal
(
3.0
,
4.0
,
dtype
=
dtype
.
int32
)
self
.
b
=
nn
.
Bernoulli
(
0.5
,
dtype
=
dtype
.
int32
)
def
construct
(
self
):
normal_mean
=
self
.
n
(
'mean'
)
normal_sd
=
self
.
n
(
'sd'
)
bernoulli_mean
=
self
.
b
(
'mean'
)
bernoulli_sd
=
self
.
b
(
'sd'
)
return
normal_mean
,
normal_sd
,
bernoulli_mean
,
bernoulli_sd
def
test_bascis
():
"""
Test mean/sd functionality of Normal and Bernoulli.
"""
net
=
NormalBernoulli
()
normal_mean
,
normal_sd
,
bernoulli_mean
,
bernoulli_sd
=
net
()
print
(
"Mean of Normal distribution: "
,
normal_mean
)
print
(
"Standard deviation of Normal distribution: "
,
normal_sd
)
print
(
"Mean of Bernoulli distribution: "
,
bernoulli_mean
)
print
(
"Standard deviation of Bernoulli distribution: "
,
bernoulli_sd
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录