Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
dc11fa9f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dc11fa9f
编写于
8月 21, 2020
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed CheckTuple issues and error message
上级
04decda0
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
23 addition
and
30 deletion
+23
-30
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+13
-6
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+1
-5
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+4
-0
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+1
-4
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+1
-4
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+1
-5
mindspore/nn/probability/distribution/transformed_distribution.py
...e/nn/probability/distribution/transformed_distribution.py
+1
-1
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+1
-5
未找到文件。
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
dc11fa9f
...
@@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
...
@@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
from
mindspore
import
context
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.nn.probability
as
msp
import
mindspore.nn.probability
as
msp
...
@@ -273,7 +274,8 @@ def check_type(data_type, value_type, name):
...
@@ -273,7 +274,8 @@ def check_type(data_type, value_type, name):
@
constexpr
@
constexpr
def
raise_none_error
(
name
):
def
raise_none_error
(
name
):
raise
ValueError
(
f
"
{
name
}
should be specified. Value cannot be None"
)
raise
TypeError
(
f
"the type
{
name
}
should be subclass of Tensor."
f
" It should not be None since it is not specified during initialization."
)
@
constexpr
@
constexpr
def
raise_not_impl_error
(
name
):
def
raise_not_impl_error
(
name
):
...
@@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer):
...
@@ -298,15 +300,20 @@ class CheckTuple(PrimitiveWithInfer):
def
__infer__
(
self
,
x
,
name
):
def
__infer__
(
self
,
x
,
name
):
if
not
isinstance
(
x
[
'dtype'
],
tuple
):
if
not
isinstance
(
x
[
'dtype'
],
tuple
):
raise
TypeError
(
"Input type should be a tuple: "
+
name
[
"value"
]
)
raise
TypeError
(
f
"For
{
name
[
'value'
]
}
, Input type should b a tuple."
)
out
=
{
'shape'
:
None
,
out
=
{
'shape'
:
None
,
'dtype'
:
None
,
'dtype'
:
None
,
'value'
:
None
}
'value'
:
x
[
"value"
]
}
return
out
return
out
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
x
,
name
):
return
if
context
.
get_context
(
"mode"
)
==
0
:
return
x
[
"value"
]
#Pynative mode
if
isinstance
(
x
,
tuple
):
return
x
raise
TypeError
(
f
"For
{
name
[
'value'
]
}
, Input type should b a tuple."
)
class
CheckTensor
(
PrimitiveWithInfer
):
class
CheckTensor
(
PrimitiveWithInfer
):
"""
"""
...
@@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer):
...
@@ -327,5 +334,5 @@ class CheckTensor(PrimitiveWithInfer):
'value'
:
None
}
'value'
:
None
}
return
out
return
out
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
x
,
name
):
return
return
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
dc11fa9f
...
@@ -18,7 +18,6 @@ from mindspore.ops import operations as P
...
@@ -18,7 +18,6 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
raise_none_error
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
raise_none_error
from
._utils.utils
import
CheckTensor
,
CheckTuple
from
._utils.custom_ops
import
log_by_step
from
._utils.custom_ops
import
log_by_step
class
Bernoulli
(
Distribution
):
class
Bernoulli
(
Distribution
):
...
@@ -125,9 +124,6 @@ class Bernoulli(Distribution):
...
@@ -125,9 +124,6 @@ class Bernoulli(Distribution):
self
.
sqrt
=
P
.
Sqrt
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
uniform
=
C
.
uniform
self
.
uniform
=
C
.
uniform
self
.
checktensor
=
CheckTensor
()
self
.
checktuple
=
CheckTuple
()
def
extend_repr
(
self
):
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
if
self
.
is_scalar_batch
:
str_info
=
f
'probs =
{
self
.
probs
}
'
str_info
=
f
'probs =
{
self
.
probs
}
'
...
@@ -279,7 +275,7 @@ class Bernoulli(Distribution):
...
@@ -279,7 +275,7 @@ class Bernoulli(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
self
.
checktuple
(
shape
,
'shape'
)
s
hape
=
s
elf
.
checktuple
(
shape
,
'shape'
)
probs1
=
self
.
_check_param
(
probs1
)
probs1
=
self
.
_check_param
(
probs1
)
origin_shape
=
shape
+
self
.
shape
(
probs1
)
origin_shape
=
shape
+
self
.
shape
(
probs1
)
if
origin_shape
==
():
if
origin_shape
==
():
...
...
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
dc11fa9f
...
@@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell
...
@@ -17,6 +17,7 @@ from mindspore.nn.cell import Cell
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
mindspore._checkparam
import
Rel
from
._utils.utils
import
calc_broadcast_shape_from_param
,
check_scalar_from_param
from
._utils.utils
import
calc_broadcast_shape_from_param
,
check_scalar_from_param
from
._utils.utils
import
CheckTuple
,
CheckTensor
class
Distribution
(
Cell
):
class
Distribution
(
Cell
):
"""
"""
...
@@ -79,6 +80,9 @@ class Distribution(Cell):
...
@@ -79,6 +80,9 @@ class Distribution(Cell):
self
.
_set_log_survival
()
self
.
_set_log_survival
()
self
.
_set_cross_entropy
()
self
.
_set_cross_entropy
()
self
.
checktuple
=
CheckTuple
()
self
.
checktensor
=
CheckTensor
()
@
property
@
property
def
name
(
self
):
def
name
(
self
):
return
self
.
_name
return
self
.
_name
...
...
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
dc11fa9f
...
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
...
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.utils
import
CheckTensor
,
CheckTuple
from
._utils.custom_ops
import
log_by_step
from
._utils.custom_ops
import
log_by_step
class
Exponential
(
Distribution
):
class
Exponential
(
Distribution
):
...
@@ -127,8 +126,6 @@ class Exponential(Distribution):
...
@@ -127,8 +126,6 @@ class Exponential(Distribution):
self
.
sq
=
P
.
Square
()
self
.
sq
=
P
.
Square
()
self
.
uniform
=
C
.
uniform
self
.
uniform
=
C
.
uniform
self
.
checktensor
=
CheckTensor
()
self
.
checktuple
=
CheckTuple
()
def
extend_repr
(
self
):
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
if
self
.
is_scalar_batch
:
...
@@ -270,7 +267,7 @@ class Exponential(Distribution):
...
@@ -270,7 +267,7 @@ class Exponential(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
self
.
checktuple
(
shape
,
'shape'
)
s
hape
=
s
elf
.
checktuple
(
shape
,
'shape'
)
rate
=
self
.
_check_param
(
rate
)
rate
=
self
.
_check_param
(
rate
)
origin_shape
=
shape
+
self
.
shape
(
rate
)
origin_shape
=
shape
+
self
.
shape
(
rate
)
if
origin_shape
==
():
if
origin_shape
==
():
...
...
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
dc11fa9f
...
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
...
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.utils
import
CheckTensor
,
CheckTuple
from
._utils.custom_ops
import
log_by_step
from
._utils.custom_ops
import
log_by_step
class
Geometric
(
Distribution
):
class
Geometric
(
Distribution
):
...
@@ -131,8 +130,6 @@ class Geometric(Distribution):
...
@@ -131,8 +130,6 @@ class Geometric(Distribution):
self
.
sqrt
=
P
.
Sqrt
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
uniform
=
C
.
uniform
self
.
uniform
=
C
.
uniform
self
.
checktensor
=
CheckTensor
()
self
.
checktuple
=
CheckTuple
()
def
extend_repr
(
self
):
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
if
self
.
is_scalar_batch
:
...
@@ -278,7 +275,7 @@ class Geometric(Distribution):
...
@@ -278,7 +275,7 @@ class Geometric(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
self
.
checktuple
(
shape
,
'shape'
)
s
hape
=
s
elf
.
checktuple
(
shape
,
'shape'
)
probs1
=
self
.
_check_param
(
probs1
)
probs1
=
self
.
_check_param
(
probs1
)
origin_shape
=
shape
+
self
.
shape
(
probs1
)
origin_shape
=
shape
+
self
.
shape
(
probs1
)
if
origin_shape
==
():
if
origin_shape
==
():
...
...
mindspore/nn/probability/distribution/normal.py
浏览文件 @
dc11fa9f
...
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
...
@@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
convert_to_batch
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.utils
import
CheckTensor
,
CheckTuple
from
._utils.custom_ops
import
log_by_step
,
expm1_by_step
from
._utils.custom_ops
import
log_by_step
,
expm1_by_step
class
Normal
(
Distribution
):
class
Normal
(
Distribution
):
...
@@ -128,9 +127,6 @@ class Normal(Distribution):
...
@@ -128,9 +127,6 @@ class Normal(Distribution):
self
.
sqrt
=
P
.
Sqrt
()
self
.
sqrt
=
P
.
Sqrt
()
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
checktensor
=
CheckTensor
()
self
.
checktuple
=
CheckTuple
()
def
extend_repr
(
self
):
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
if
self
.
is_scalar_batch
:
str_info
=
f
'mean =
{
self
.
_mean_value
}
, standard deviation =
{
self
.
_sd_value
}
'
str_info
=
f
'mean =
{
self
.
_mean_value
}
, standard deviation =
{
self
.
_sd_value
}
'
...
@@ -277,7 +273,7 @@ class Normal(Distribution):
...
@@ -277,7 +273,7 @@ class Normal(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
self
.
checktuple
(
shape
,
'shape'
)
s
hape
=
s
elf
.
checktuple
(
shape
,
'shape'
)
mean
,
sd
=
self
.
_check_param
(
mean
,
sd
)
mean
,
sd
=
self
.
_check_param
(
mean
,
sd
)
batch_shape
=
self
.
shape
(
mean
+
sd
)
batch_shape
=
self
.
shape
(
mean
+
sd
)
origin_shape
=
shape
+
batch_shape
origin_shape
=
shape
+
batch_shape
...
...
mindspore/nn/probability/distribution/transformed_distribution.py
浏览文件 @
dc11fa9f
...
@@ -116,4 +116,4 @@ class TransformedDistribution(Distribution):
...
@@ -116,4 +116,4 @@ class TransformedDistribution(Distribution):
if
not
self
.
is_linear_transformation
:
if
not
self
.
is_linear_transformation
:
raise_not_impl_error
(
"mean"
)
raise_not_impl_error
(
"mean"
)
return
self
.
bijector
(
"forward"
,
self
.
distribution
(
"mean"
))
return
self
.
bijector
(
"forward"
,
self
.
distribution
(
"mean"
,
*
args
,
**
kwargs
))
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
dc11fa9f
...
@@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype
...
@@ -19,7 +19,6 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
convert_to_batch
,
check_greater
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.utils
import
CheckTensor
,
CheckTuple
from
._utils.custom_ops
import
log_by_step
from
._utils.custom_ops
import
log_by_step
class
Uniform
(
Distribution
):
class
Uniform
(
Distribution
):
...
@@ -131,9 +130,6 @@ class Uniform(Distribution):
...
@@ -131,9 +130,6 @@ class Uniform(Distribution):
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
zeroslike
=
P
.
ZerosLike
()
self
.
uniform
=
C
.
uniform
self
.
uniform
=
C
.
uniform
self
.
checktensor
=
CheckTensor
()
self
.
checktuple
=
CheckTuple
()
def
extend_repr
(
self
):
def
extend_repr
(
self
):
if
self
.
is_scalar_batch
:
if
self
.
is_scalar_batch
:
str_info
=
f
'low =
{
self
.
low
}
, high =
{
self
.
high
}
'
str_info
=
f
'low =
{
self
.
low
}
, high =
{
self
.
high
}
'
...
@@ -306,7 +302,7 @@ class Uniform(Distribution):
...
@@ -306,7 +302,7 @@ class Uniform(Distribution):
Returns:
Returns:
Tensor, shape is shape + batch_shape.
Tensor, shape is shape + batch_shape.
"""
"""
self
.
checktuple
(
shape
,
'shape'
)
s
hape
=
s
elf
.
checktuple
(
shape
,
'shape'
)
low
,
high
=
self
.
_check_param
(
low
,
high
)
low
,
high
=
self
.
_check_param
(
low
,
high
)
broadcast_shape
=
self
.
shape
(
low
+
high
)
broadcast_shape
=
self
.
shape
(
low
+
high
)
origin_shape
=
shape
+
broadcast_shape
origin_shape
=
shape
+
broadcast_shape
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录