Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
415dad3a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
415dad3a
编写于
8月 07, 2020
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added some parameter checking
上级
9ad82f79
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
154 addition
and
52 deletion
+154
-52
mindspore/nn/probability/distribution/_utils/__init__.py
mindspore/nn/probability/distribution/_utils/__init__.py
+11
-8
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+18
-11
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+7
-6
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+14
-0
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+7
-5
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+8
-6
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+9
-7
mindspore/nn/probability/distribution/transformed_distribution.py
...e/nn/probability/distribution/transformed_distribution.py
+12
-1
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+10
-8
tests/ut/python/nn/distribution/test_bernoulli.py
tests/ut/python/nn/distribution/test_bernoulli.py
+12
-0
tests/ut/python/nn/distribution/test_exponential.py
tests/ut/python/nn/distribution/test_exponential.py
+12
-0
tests/ut/python/nn/distribution/test_geometric.py
tests/ut/python/nn/distribution/test_geometric.py
+12
-0
tests/ut/python/nn/distribution/test_normal.py
tests/ut/python/nn/distribution/test_normal.py
+11
-0
tests/ut/python/nn/distribution/test_uniform.py
tests/ut/python/nn/distribution/test_uniform.py
+11
-0
未找到文件。
mindspore/nn/probability/distribution/_utils/__init__.py
浏览文件 @
415dad3a
...
...
@@ -17,11 +17,14 @@ Distribution operation utility functions.
"""
from
.utils
import
*
__all__
=
[
'convert_to_batch'
,
__all__
=
[
'convert_to_batch'
,
'cast_to_tensor'
,
'check_greater'
,
'check_greater_equal_zero'
,
'check_greater_zero'
,
'calc_broadcast_shape_from_param'
,
'check_scalar_from_param'
,
'check_prob'
]
'check_prob'
,
'check_type'
,
]
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
415dad3a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -23,7 +22,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
import
mindspore.nn
as
nn
def
cast_to_tensor
(
t
,
dtype
=
mstype
.
float32
):
def
cast_to_tensor
(
t
,
hint_
dtype
=
mstype
.
float32
):
"""
Cast an user input value into a Tensor of dtype.
If the input t is of type Parameter, t is directly returned as a Parameter.
...
...
@@ -41,25 +40,26 @@ def cast_to_tensor(t, dtype=mstype.float32):
if
isinstance
(
t
,
Parameter
):
return
t
if
isinstance
(
t
,
Tensor
):
if
t
.
dtype
!=
hint_dtype
:
raise
TypeError
(
f
"Input tensor should be type
{
hint_dtype
}
."
)
#check if the Tensor in shape of Tensor(4)
if
t
.
dim
()
==
0
:
value
=
t
.
asnumpy
()
return
Tensor
([
t
],
dtype
=
dtype
)
return
Tensor
([
value
],
dtype
=
hint_
dtype
)
#convert the type of tensor to dtype
t
.
set_dtype
(
dtype
)
return
t
if
isinstance
(
t
,
(
list
,
np
.
ndarray
)):
return
Tensor
(
t
,
dtype
=
dtype
)
return
Tensor
(
t
,
dtype
=
hint_
dtype
)
if
np
.
isscalar
(
t
):
return
Tensor
([
t
],
dtype
=
dtype
)
return
Tensor
([
t
],
dtype
=
hint_
dtype
)
raise
RuntimeError
(
"Input type is not supported."
)
def
convert_to_batch
(
t
,
batch_shape
,
dtype
):
def
convert_to_batch
(
t
,
batch_shape
,
hint_
dtype
):
"""
Convert a Tensor to a given batch shape.
Args:
t (Tensor, Parameter): Tensor to be converted.
t (
int, float, list, numpy.ndarray,
Tensor, Parameter): Tensor to be converted.
batch_shape (tuple): desired batch shape.
dtype (mindspore.dtype): desired dtype.
...
...
@@ -71,9 +71,8 @@ def convert_to_batch(t, batch_shape, dtype):
"""
if
isinstance
(
t
,
Parameter
):
return
t
if
isinstance
(
t
,
Tensor
):
return
Tensor
(
np
.
broadcast_to
(
t
.
asnumpy
(),
batch_shape
),
dtype
=
dtype
)
return
Tensor
(
np
.
broadcast_to
(
t
,
batch_shape
),
dtype
=
dtype
)
t
=
cast_to_tensor
(
t
,
hint_dtype
)
return
Tensor
(
np
.
broadcast_to
(
t
.
asnumpy
(),
batch_shape
),
dtype
=
hint_dtype
)
def
check_scalar_from_param
(
params
):
"""
...
...
@@ -85,6 +84,8 @@ def check_scalar_from_param(params):
Notes: String parameters are excluded.
"""
for
value
in
params
.
values
():
if
isinstance
(
value
,
(
nn
.
probability
.
bijector
.
Bijector
,
nn
.
probability
.
distribution
.
Distribution
)):
return
params
[
'distribution'
].
is_scalar_batch
if
isinstance
(
value
,
Parameter
):
return
False
if
isinstance
(
value
,
(
str
,
type
(
params
[
'dtype'
]))):
...
...
@@ -108,6 +109,8 @@ def calc_broadcast_shape_from_param(params):
"""
broadcast_shape
=
[]
for
value
in
params
.
values
():
if
isinstance
(
value
,
(
nn
.
probability
.
bijector
.
Bijector
,
nn
.
probability
.
distribution
.
Distribution
)):
return
params
[
'distribution'
].
broadcast_shape
if
isinstance
(
value
,
(
str
,
type
(
params
[
'dtype'
]))):
continue
if
value
is
None
:
...
...
@@ -251,3 +254,7 @@ def check_tensor_type(name, inputs, valid_type):
inputs
=
P
.
DType
()(
inputs
)
if
inputs
not
in
valid_type
:
raise
TypeError
(
f
"
{
name
}
dtype is invalid"
)
def
check_type
(
data_type
,
value_type
,
name
):
if
not
data_type
in
value_type
:
raise
TypeError
(
f
"For
{
name
}
, valid type include
{
value_type
}
,
{
data_type
}
is invalid"
)
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
415dad3a
...
...
@@ -16,7 +16,7 @@
from
mindspore.common
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
class
Bernoulli
(
Distribution
):
"""
...
...
@@ -95,13 +95,14 @@ class Bernoulli(Distribution):
Constructor of Bernoulli distribution.
"""
param
=
dict
(
locals
())
super
(
Bernoulli
,
self
).
__init__
(
dtype
,
name
,
param
)
valid_dtype
=
mstype
.
int_type
+
mstype
.
uint_type
check_type
(
dtype
,
valid_dtype
,
"Bernoulli"
)
super
(
Bernoulli
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
if
probs
is
not
None
:
self
.
_probs
=
cast_to_tensor
(
probs
,
dtype
=
mstype
.
float32
)
self
.
_probs
=
cast_to_tensor
(
probs
,
hint_
dtype
=
mstype
.
float32
)
check_prob
(
self
.
probs
)
else
:
self
.
_probs
=
probs
self
.
seed
=
seed
# ops needed for the class
self
.
cast
=
P
.
Cast
()
...
...
@@ -231,8 +232,8 @@ class Bernoulli(Distribution):
probs1_a (Tensor): probs1 of distribution a. Default: self.probs.
.. math::
KL(a||b) = probs1_a * \log(\frac
t
{probs1_a}{probs1_b}) +
probs0_a * \log(\frac
t
{probs0_a}{probs0_b})
KL(a||b) = probs1_a * \log(\frac{probs1_a}{probs1_b}) +
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
if
dist
==
'Bernoulli'
:
probs1_a
=
self
.
probs
if
probs1_a
is
None
else
probs1_a
...
...
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
415dad3a
...
...
@@ -14,6 +14,7 @@
# ============================================================================
"""basic"""
from
mindspore.nn.cell
import
Cell
from
mindspore._checkparam
import
Validator
as
validator
from
._utils.utils
import
calc_broadcast_shape_from_param
,
check_scalar_from_param
class
Distribution
(
Cell
):
...
...
@@ -38,6 +39,7 @@ class Distribution(Cell):
original distribuion.
"""
def
__init__
(
self
,
seed
,
dtype
,
name
,
param
):
...
...
@@ -46,7 +48,11 @@ class Distribution(Cell):
Constructor of distribution class.
"""
super
(
Distribution
,
self
).
__init__
()
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
'distribution_name'
)
validator
.
check_value_type
(
'seed'
,
seed
,
[
int
],
name
)
self
.
_name
=
name
self
.
_seed
=
seed
self
.
_dtype
=
dtype
self
.
_parameters
=
{}
# parsing parameters
...
...
@@ -77,6 +83,10 @@ class Distribution(Cell):
def
dtype
(
self
):
return
self
.
_dtype
@
property
def
seed
(
self
):
return
self
.
_seed
@
property
def
parameters
(
self
):
return
self
.
_parameters
...
...
@@ -85,6 +95,10 @@ class Distribution(Cell):
def
is_scalar_batch
(
self
):
return
self
.
_is_scalar_batch
@
property
def
broadcast_shape
(
self
):
return
self
.
_broadcast_shape
def
_set_prob
(
self
):
"""
Set probability funtion based on the availability of _prob and _log_likehood.
...
...
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
415dad3a
...
...
@@ -17,7 +17,7 @@ import numpy as np
from
mindspore.ops
import
operations
as
P
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
,
check_type
class
Exponential
(
Distribution
):
"""
...
...
@@ -96,9 +96,11 @@ class Exponential(Distribution):
Constructor of Exponential distribution.
"""
param
=
dict
(
locals
())
super
(
Exponential
,
self
).
__init__
(
dtype
,
name
,
param
)
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Exponential"
)
super
(
Exponential
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
if
rate
is
not
None
:
self
.
_rate
=
cast_to_tensor
(
rate
,
mstype
.
float32
)
self
.
_rate
=
cast_to_tensor
(
rate
,
dtype
)
check_greater_zero
(
self
.
_rate
,
"rate"
)
else
:
self
.
_rate
=
rate
...
...
@@ -135,7 +137,7 @@ class Exponential(Distribution):
def
_mean
(
self
,
rate
=
None
):
r
"""
.. math::
MEAN(EXP) = \frac
t
{1.0}{\lambda}.
MEAN(EXP) = \frac{1.0}{\lambda}.
"""
rate
=
self
.
rate
if
rate
is
None
else
rate
return
1.0
/
rate
...
...
@@ -152,7 +154,7 @@ class Exponential(Distribution):
def
_sd
(
self
,
rate
=
None
):
r
"""
.. math::
sd(EXP) = \frac
t
{1.0}{\lambda}.
sd(EXP) = \frac{1.0}{\lambda}.
"""
rate
=
self
.
rate
if
rate
is
None
else
rate
return
1.0
/
rate
...
...
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
415dad3a
...
...
@@ -17,7 +17,7 @@ import numpy as np
from
mindspore.ops
import
operations
as
P
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
class
Geometric
(
Distribution
):
"""
...
...
@@ -97,9 +97,11 @@ class Geometric(Distribution):
Constructor of Geometric distribution.
"""
param
=
dict
(
locals
())
super
(
Geometric
,
self
).
__init__
(
dtype
,
name
,
param
)
valid_dtype
=
mstype
.
int_type
+
mstype
.
uint_type
check_type
(
dtype
,
valid_dtype
,
"Geometric"
)
super
(
Geometric
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
if
probs
is
not
None
:
self
.
_probs
=
cast_to_tensor
(
probs
,
dtype
=
mstype
.
float32
)
self
.
_probs
=
cast_to_tensor
(
probs
,
hint_
dtype
=
mstype
.
float32
)
check_prob
(
self
.
_probs
)
else
:
self
.
_probs
=
probs
...
...
@@ -154,7 +156,7 @@ class Geometric(Distribution):
def
_var
(
self
,
probs1
=
None
):
r
"""
.. math::
VAR(Geo) = \frac
t
{1 - probs1}{probs1 ^ {2}}
VAR(Geo) = \frac{1 - probs1}{probs1 ^ {2}}
"""
probs1
=
self
.
probs
if
probs1
is
None
else
probs1
return
(
1.0
-
probs1
)
/
self
.
sq
(
probs1
)
...
...
@@ -162,7 +164,7 @@ class Geometric(Distribution):
def
_entropy
(
self
,
probs
=
None
):
r
"""
.. math::
H(Geo) = \frac
t
{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
H(Geo) = \frac{-1 * probs0 \log_2 (1-probs0)\ - prob1 * \log_2 (1-probs1)\ }{probs1}
"""
probs1
=
self
.
probs
if
probs
is
None
else
probs
probs0
=
1.0
-
probs1
...
...
@@ -244,7 +246,7 @@ class Geometric(Distribution):
probs1_a (Tensor): probability of success of distribution a. Default: self.probs.
.. math::
KL(a||b) = \log(\frac
t{probs1_a}{probs1_b}) + \fract{probs0_a}{probs1_a} * \log(\fract
{probs0_a}{probs0_b})
KL(a||b) = \log(\frac
{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac
{probs0_a}{probs0_b})
"""
if
dist
==
'Geometric'
:
probs1_a
=
self
.
probs
if
probs1_a
is
None
else
probs1_a
...
...
mindspore/nn/probability/distribution/normal.py
浏览文件 @
415dad3a
...
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater_equal_zero
from
._utils.utils
import
convert_to_batch
,
check_greater_equal_zero
,
check_type
class
Normal
(
Distribution
):
...
...
@@ -100,15 +100,17 @@ class Normal(Distribution):
Constructor of normal distribution.
"""
param
=
dict
(
locals
())
super
(
Normal
,
self
).
__init__
(
dtype
,
name
,
param
)
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Normal"
)
super
(
Normal
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
if
mean
is
not
None
and
sd
is
not
None
:
self
.
_mean_value
=
convert_to_batch
(
mean
,
self
.
_
broadcast_shape
,
dtype
)
self
.
_sd_value
=
convert_to_batch
(
sd
,
self
.
_
broadcast_shape
,
dtype
)
self
.
_mean_value
=
convert_to_batch
(
mean
,
self
.
broadcast_shape
,
dtype
)
self
.
_sd_value
=
convert_to_batch
(
sd
,
self
.
broadcast_shape
,
dtype
)
check_greater_equal_zero
(
self
.
_sd_value
,
"Standard deviation"
)
else
:
self
.
_mean_value
=
mean
self
.
_sd_value
=
sd
self
.
seed
=
seed
#ops needed for the class
self
.
const
=
P
.
ScalarToArray
()
...
...
@@ -191,7 +193,7 @@ class Normal(Distribution):
sd (Tensor): standard deviation the distribution. Default: self._sd_value.
.. math::
L(x) = -1* \frac
t
{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
mean
=
self
.
_mean_value
if
mean
is
None
else
mean
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
...
...
@@ -229,7 +231,7 @@ class Normal(Distribution):
sd_a (Tensor): standard deviation distribution a. Default: self._sd_value.
.. math::
KL(a||b) = 0.5 * (\frac
t{MEAN(a)}{STD(b)} - \fract
{MEAN(b)}{STD(b)}) ^ 2 +
KL(a||b) = 0.5 * (\frac
{MEAN(a)}{STD(b)} - \frac
{MEAN(b)}{STD(b)}) ^ 2 +
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
if
dist
==
'Normal'
:
...
...
mindspore/nn/probability/distribution/transformed_distribution.py
浏览文件 @
415dad3a
...
...
@@ -14,7 +14,11 @@
# ============================================================================
"""Transformed Distribution"""
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore.common
import
dtype
as
mstype
import
mindspore.nn
as
nn
from
.distribution
import
Distribution
from
._utils.utils
import
check_type
class
TransformedDistribution
(
Distribution
):
"""
...
...
@@ -35,12 +39,19 @@ class TransformedDistribution(Distribution):
def
__init__
(
self
,
bijector
,
distribution
,
dtype
,
seed
=
0
,
name
=
"transformed_distribution"
):
"""
Constructor of transformed_distribution class.
"""
param
=
dict
(
locals
())
super
(
TransformedDistribution
,
self
).
__init__
(
distribution
.
dtype
,
name
,
param
)
validator
.
check_value_type
(
'bijector'
,
bijector
,
[
nn
.
probability
.
bijector
.
Bijector
],
name
)
validator
.
check_value_type
(
'distribution'
,
distribution
,
[
Distribution
],
name
)
valid_dtype
=
mstype
.
number_type
check_type
(
dtype
,
valid_dtype
,
"transformed_distribution"
)
super
(
TransformedDistribution
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
_bijector
=
bijector
self
.
_distribution
=
distribution
self
.
_is_linear_transformation
=
bijector
.
is_constant_jacobian
...
...
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
415dad3a
...
...
@@ -16,7 +16,7 @@
from
mindspore.ops
import
operations
as
P
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater
from
._utils.utils
import
convert_to_batch
,
check_greater
,
check_type
class
Uniform
(
Distribution
):
"""
...
...
@@ -97,10 +97,12 @@ class Uniform(Distribution):
Constructor of Uniform distribution.
"""
param
=
dict
(
locals
())
super
(
Uniform
,
self
).
__init__
(
dtype
,
name
,
param
)
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Uniform"
)
super
(
Uniform
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
if
low
is
not
None
and
high
is
not
None
:
self
.
_low
=
convert_to_batch
(
low
,
self
.
_
broadcast_shape
,
dtype
)
self
.
_high
=
convert_to_batch
(
high
,
self
.
_
broadcast_shape
,
dtype
)
self
.
_low
=
convert_to_batch
(
low
,
self
.
broadcast_shape
,
dtype
)
self
.
_high
=
convert_to_batch
(
high
,
self
.
broadcast_shape
,
dtype
)
check_greater
(
self
.
low
,
self
.
high
,
"low value"
,
"high value"
)
else
:
self
.
_low
=
low
...
...
@@ -156,7 +158,7 @@ class Uniform(Distribution):
def
_mean
(
self
,
low
=
None
,
high
=
None
):
r
"""
.. math::
MEAN(U) = \frac
t
{low + high}{2}.
MEAN(U) = \frac{low + high}{2}.
"""
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
...
...
@@ -166,7 +168,7 @@ class Uniform(Distribution):
def
_var
(
self
,
low
=
None
,
high
=
None
):
r
"""
.. math::
VAR(U) = \frac
t
{(high -low) ^ 2}{12}.
VAR(U) = \frac{(high -low) ^ 2}{12}.
"""
low
=
self
.
low
if
low
is
None
else
low
high
=
self
.
high
if
high
is
None
else
high
...
...
@@ -207,7 +209,7 @@ class Uniform(Distribution):
.. math::
pdf(x) = 0 if x < low;
pdf(x) = \frac
t
{1.0}{high -low} if low <= x <= high;
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
low
=
self
.
low
if
low
is
None
else
low
...
...
@@ -251,7 +253,7 @@ class Uniform(Distribution):
.. math::
cdf(x) = 0 if x < low;
cdf(x) = \frac
t
{x - low}{high -low} if low <= x <= high;
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
low
=
self
.
low
if
low
is
None
else
low
...
...
tests/ut/python/nn/distribution/test_bernoulli.py
浏览文件 @
415dad3a
...
...
@@ -31,6 +31,18 @@ def test_arguments():
b
=
msd
.
Bernoulli
([
0.0
,
0.3
,
0.5
,
1.0
],
dtype
=
dtype
.
int32
)
assert
isinstance
(
b
,
msd
.
Distribution
)
def
test_type
():
with
pytest
.
raises
(
TypeError
):
msd
.
Bernoulli
([
0.1
],
dtype
=
dtype
.
float32
)
def
test_name
():
with
pytest
.
raises
(
TypeError
):
msd
.
Bernoulli
([
0.1
],
name
=
1.0
)
def
test_seed
():
with
pytest
.
raises
(
TypeError
):
msd
.
Bernoulli
([
0.1
],
seed
=
'seed'
)
def
test_prob
():
"""
Invalid probability.
...
...
tests/ut/python/nn/distribution/test_exponential.py
浏览文件 @
415dad3a
...
...
@@ -32,6 +32,18 @@ def test_arguments():
e
=
msd
.
Exponential
([
0.1
,
0.3
,
0.5
,
1.0
],
dtype
=
dtype
.
float32
)
assert
isinstance
(
e
,
msd
.
Distribution
)
def
test_type
():
with
pytest
.
raises
(
TypeError
):
msd
.
Exponential
([
0.1
],
dtype
=
dtype
.
int32
)
def
test_name
():
with
pytest
.
raises
(
TypeError
):
msd
.
Exponential
([
0.1
],
name
=
1.0
)
def
test_seed
():
with
pytest
.
raises
(
TypeError
):
msd
.
Exponential
([
0.1
],
seed
=
'seed'
)
def
test_rate
():
"""
Invalid rate.
...
...
tests/ut/python/nn/distribution/test_geometric.py
浏览文件 @
415dad3a
...
...
@@ -32,6 +32,18 @@ def test_arguments():
g
=
msd
.
Geometric
([
0.0
,
0.3
,
0.5
,
1.0
],
dtype
=
dtype
.
int32
)
assert
isinstance
(
g
,
msd
.
Distribution
)
def
test_type
():
with
pytest
.
raises
(
TypeError
):
msd
.
Geometric
([
0.1
],
dtype
=
dtype
.
float32
)
def
test_name
():
with
pytest
.
raises
(
TypeError
):
msd
.
Geometric
([
0.1
],
name
=
1.0
)
def
test_seed
():
with
pytest
.
raises
(
TypeError
):
msd
.
Geometric
([
0.1
],
seed
=
'seed'
)
def
test_prob
():
"""
Invalid probability.
...
...
tests/ut/python/nn/distribution/test_normal.py
浏览文件 @
415dad3a
...
...
@@ -30,6 +30,17 @@ def test_normal_shape_errpr():
with
pytest
.
raises
(
ValueError
):
msd
.
Normal
([[
2.
],
[
1.
]],
[[
2.
],
[
3.
],
[
4.
]],
dtype
=
dtype
.
float32
)
def
test_type
():
with
pytest
.
raises
(
TypeError
):
msd
.
Normal
(
0.
,
1.
,
dtype
=
dtype
.
int32
)
def
test_name
():
with
pytest
.
raises
(
TypeError
):
msd
.
Normal
(
0.
,
1.
,
name
=
1.0
)
def
test_seed
():
with
pytest
.
raises
(
TypeError
):
msd
.
Normal
(
0.
,
1.
,
seed
=
'seed'
)
def
test_arguments
():
"""
...
...
tests/ut/python/nn/distribution/test_uniform.py
浏览文件 @
415dad3a
...
...
@@ -30,6 +30,17 @@ def test_uniform_shape_errpr():
with
pytest
.
raises
(
ValueError
):
msd
.
Uniform
([[
2.
],
[
1.
]],
[[
2.
],
[
3.
],
[
4.
]],
dtype
=
dtype
.
float32
)
def
test_type
():
with
pytest
.
raises
(
TypeError
):
msd
.
Uniform
(
0.
,
1.
,
dtype
=
dtype
.
int32
)
def
test_name
():
with
pytest
.
raises
(
TypeError
):
msd
.
Uniform
(
0.
,
1.
,
name
=
1.0
)
def
test_seed
():
with
pytest
.
raises
(
TypeError
):
msd
.
Uniform
(
0.
,
1.
,
seed
=
'seed'
)
def
test_arguments
():
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录