Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
436b5447
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
436b5447
编写于
8月 26, 2020
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add parameter type check in normal and uniform distribution
上级
2cd99c28
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
40 addition
and
16 deletion
+40
-16
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+17
-0
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+1
-1
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+2
-1
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+2
-2
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+9
-6
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+9
-6
未找到文件。
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
436b5447
...
...
@@ -368,3 +368,20 @@ class CheckTensor(PrimitiveWithInfer):
if
isinstance
(
x
,
Tensor
):
return
x
raise
TypeError
(
f
"For
{
name
}
, input type should be a Tensor."
)
def
common_dtype
(
arg_a
,
name_a
,
arg_b
,
name_b
,
hint_type
):
"""
check if arg_a and arg_b have the same dtype.
"""
if
hasattr
(
arg_a
,
'dtype'
)
and
hasattr
(
arg_b
,
'dtype'
):
if
isinstance
(
arg_a
,
np
.
ndarray
):
a_dtype
=
mstype
.
pytype_to_dtype
(
arg_a
.
dtype
)
if
isinstance
(
arg_a
,
np
.
ndarray
):
b_dtype
=
mstype
.
pytype_to_dtype
(
arg_b
.
dtype
)
if
a_dtype
!=
b_dtype
:
raise
TypeError
(
f
"
{
name_a
}
and
{
name_b
}
should have the same dtype."
)
int_type
=
mstype
.
int_type
+
mstype
.
uint_type
if
a_dtype
in
int_type
or
a_dtype
==
mstype
.
float64
:
return
mstype
.
float32
return
a_dtype
return
hint_type
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
436b5447
...
...
@@ -32,7 +32,7 @@ class Bernoulli(Distribution):
name (str): name of the distribution. Default: Bernoulli.
Note:
probs should be proper probabilities (0 <
= p <=
1).
probs should be proper probabilities (0 <
p <
1).
Dist_spec_args is probs.
Examples:
...
...
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
436b5447
...
...
@@ -26,8 +26,9 @@ class Distribution(Cell):
Base class for all mathematical distributions.
Args:
seed (int): random seed used in sampling.
dtype (mindspore.dtype): type of the distribution.
name (str):
name of the distribution
.
name (str):
Python str name prefixed to Ops created by this class. Default: subclass name
.
param (dict): parameters used to initialize the distribution.
Note:
...
...
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
436b5447
...
...
@@ -35,7 +35,7 @@ class Geometric(Distribution):
name (str): name of the distribution. Default: Geometric.
Note:
probs should be proper probabilities (0 <
= p <=
1).
probs should be proper probabilities (0 <
p <
1).
Dist_spec_args is probs.
Examples:
...
...
@@ -141,7 +141,7 @@ class Geometric(Distribution):
@
property
def
probs
(
self
):
"""
Returns the probability
for the outcome is 1
.
Returns the probability
of success of the Bernoulli trail
.
"""
return
self
.
_probs
...
...
mindspore/nn/probability/distribution/normal.py
浏览文件 @
436b5447
...
...
@@ -19,7 +19,7 @@ from mindspore.ops import composite as C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
,
common_dtype
from
._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
,
erf_generic
class
Normal
(
Distribution
):
...
...
@@ -104,7 +104,7 @@ class Normal(Distribution):
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Normal
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
self
.
parameter_type
=
common_dtype
(
mean
,
'mean'
,
sd
,
'sd'
,
self
.
dtype
)
if
mean
is
not
None
and
sd
is
not
None
:
self
.
_mean_value
=
cast_to_tensor
(
mean
,
self
.
parameter_type
)
self
.
_sd_value
=
cast_to_tensor
(
sd
,
self
.
parameter_type
)
...
...
@@ -126,6 +126,8 @@ class Normal(Distribution):
self
.
sq
=
P
.
Square
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
dtypeop
=
P
.
DType
()
self
.
sametypeshape
=
P
.
SameTypeShape
()
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
...
...
@@ -143,7 +145,6 @@ class Normal(Distribution):
self
.
checktensor
(
mean
,
'mean'
)
else
:
mean
=
self
.
checktensor
(
mean
,
'mean'
)
mean
=
self
.
cast
(
mean
,
self
.
parameter_type
)
else
:
mean
=
self
.
_mean_value
if
self
.
_mean_value
is
not
None
else
raise_none_error
(
'mean'
)
if
sd
is
not
None
:
...
...
@@ -151,12 +152,14 @@ class Normal(Distribution):
self
.
checktensor
(
sd
,
'sd'
)
else
:
sd
=
self
.
checktensor
(
sd
,
'sd'
)
sd
=
self
.
cast
(
sd
,
self
.
parameter_type
)
else
:
sd
=
self
.
_sd_value
if
self
.
_sd_value
is
not
None
else
raise_none_error
(
'sd'
)
batch_shape
=
self
.
shape
(
mean
+
sd
)
mean
=
mean
*
self
.
fill
(
self
.
dtype
,
batch_shape
,
1.0
)
sd
=
sd
*
self
.
fill
(
self
.
dtype
,
batch_shape
,
1.0
)
mean
=
mean
*
self
.
fill
(
self
.
dtypeop
(
mean
),
batch_shape
,
1.0
)
sd
=
sd
*
self
.
fill
(
self
.
dtypeop
(
sd
),
batch_shape
,
1.0
)
self
.
sametypeshape
(
mean
,
sd
)
mean
=
self
.
cast
(
mean
,
self
.
parameter_type
)
sd
=
self
.
cast
(
sd
,
self
.
parameter_type
)
return
mean
,
sd
def
_mean
(
self
,
mean
=
None
,
sd
=
None
):
...
...
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
436b5447
...
...
@@ -18,7 +18,7 @@ from mindspore.ops import composite as C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_greater
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
,
common_dtype
from
._utils.custom_ops
import
exp_generic
,
log_generic
class
Uniform
(
Distribution
):
...
...
@@ -103,7 +103,7 @@ class Uniform(Distribution):
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Uniform
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
self
.
parameter_type
=
common_dtype
(
low
,
'low'
,
high
,
'high'
,
self
.
dtype
)
if
low
is
not
None
and
high
is
not
None
:
self
.
_low
=
cast_to_tensor
(
low
,
dtype
)
self
.
_high
=
cast_to_tensor
(
high
,
dtype
)
...
...
@@ -130,6 +130,8 @@ class Uniform(Distribution):
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
uniform
=
C
.
uniform
self
.
sametypeshape
=
P
.
SameTypeShape
()
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
str_info
=
f
'low =
{
self
.
low
}
, high =
{
self
.
high
}
'
...
...
@@ -146,7 +148,6 @@ class Uniform(Distribution):
self
.
checktensor
(
low
,
'low'
)
else
:
low
=
self
.
checktensor
(
low
,
'low'
)
low
=
self
.
cast
(
low
,
self
.
parameter_type
)
else
:
low
=
self
.
low
if
self
.
low
is
not
None
else
raise_none_error
(
'low'
)
if
high
is
not
None
:
...
...
@@ -154,12 +155,14 @@ class Uniform(Distribution):
self
.
checktensor
(
high
,
'high'
)
else
:
high
=
self
.
checktensor
(
high
,
'high'
)
high
=
self
.
cast
(
high
,
self
.
parameter_type
)
else
:
high
=
self
.
high
if
self
.
high
is
not
None
else
raise_none_error
(
'high'
)
batch_shape
=
self
.
shape
(
high
-
low
)
high
=
high
*
self
.
fill
(
self
.
dtype
,
batch_shape
,
1.0
)
low
=
low
*
self
.
fill
(
self
.
dtype
,
batch_shape
,
1.0
)
high
=
high
*
self
.
fill
(
self
.
dtypeop
(
high
),
batch_shape
,
1.0
)
low
=
low
*
self
.
fill
(
self
.
dtypeop
(
low
),
batch_shape
,
1.0
)
self
.
sametypeshape
(
high
,
low
)
low
=
self
.
cast
(
low
,
self
.
parameter_type
)
high
=
self
.
cast
(
high
,
self
.
parameter_type
)
return
low
,
high
@
property
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录