Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
4fe8b3d3
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4fe8b3d3
编写于
8月 25, 2020
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix checktensor in pynative mode
上级
b8da525f
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
107 addition
and
65 deletion
+107
-65
mindspore/nn/probability/bijector/bijector.py
mindspore/nn/probability/bijector/bijector.py
+15
-1
mindspore/nn/probability/bijector/power_transform.py
mindspore/nn/probability/bijector/power_transform.py
+4
-7
mindspore/nn/probability/bijector/scalar_affine.py
mindspore/nn/probability/bijector/scalar_affine.py
+7
-9
mindspore/nn/probability/bijector/softplus.py
mindspore/nn/probability/bijector/softplus.py
+6
-7
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+4
-2
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+8
-5
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+12
-1
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+8
-5
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+8
-5
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+16
-10
mindspore/nn/probability/distribution/transformed_distribution.py
...e/nn/probability/distribution/transformed_distribution.py
+3
-3
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+16
-10
未找到文件。
mindspore/nn/probability/bijector/bijector.py
浏览文件 @
4fe8b3d3
...
...
@@ -13,8 +13,10 @@
# limitations under the License.
# ============================================================================
"""Bijector"""
from
mindspore
import
context
from
mindspore.nn.cell
import
Cell
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
CheckTensor
from
..distribution
import
Distribution
from
..distribution
import
TransformedDistribution
...
...
@@ -40,7 +42,7 @@ class Bijector(Cell):
Constructor of bijector class.
"""
super
(
Bijector
,
self
).
__init__
()
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
'Bijector'
)
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
type
(
self
).
__name__
)
validator
.
check_value_type
(
'is_constant_jacobian'
,
is_constant_jacobian
,
[
bool
],
name
)
validator
.
check_value_type
(
'is_injective'
,
is_injective
,
[
bool
],
name
)
self
.
_name
=
name
...
...
@@ -53,6 +55,9 @@ class Bijector(Cell):
self
.
_is_constant_jacobian
=
is_constant_jacobian
self
.
_is_injective
=
is_injective
self
.
context_mode
=
context
.
get_context
(
'mode'
)
self
.
checktensor
=
CheckTensor
()
@
property
def
name
(
self
):
return
self
.
_name
...
...
@@ -73,6 +78,15 @@ class Bijector(Cell):
def
is_injective
(
self
):
return
self
.
_is_injective
def
_check_value
(
self
,
value
,
name
):
"""
Check availability fo value as a Tensor.
"""
if
self
.
context_mode
==
0
:
self
.
checktensor
(
value
,
name
)
return
value
return
self
.
checktensor
(
value
,
name
)
def
forward
(
self
,
*
args
,
**
kwargs
):
"""
Forward transformation: transform the input value to another distribution.
...
...
mindspore/nn/probability/bijector/power_transform.py
浏览文件 @
4fe8b3d3
...
...
@@ -16,7 +16,6 @@
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
..distribution._utils.utils
import
CheckTensor
from
..distribution._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
,
log1p_generic
from
.bijector
import
Bijector
...
...
@@ -66,8 +65,6 @@ class PowerTransform(Bijector):
self
.
log
=
log_generic
self
.
log1p
=
log1p_generic
self
.
checktensor
=
CheckTensor
()
@
property
def
power
(
self
):
return
self
.
_power
...
...
@@ -80,13 +77,13 @@ class PowerTransform(Bijector):
return
shape
def
_forward
(
self
,
x
):
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
if
self
.
power
==
0
:
return
self
.
exp
(
x
)
return
self
.
exp
(
self
.
log1p
(
x
*
self
.
power
)
/
self
.
power
)
def
_inverse
(
self
,
y
):
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
if
self
.
power
==
0
:
return
self
.
log
(
y
)
return
self
.
expm1
(
self
.
log
(
y
)
*
self
.
power
)
/
self
.
power
...
...
@@ -103,7 +100,7 @@ class PowerTransform(Bijector):
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
if
self
.
power
==
0
:
return
x
return
(
1.
/
self
.
power
-
1
)
*
self
.
log1p
(
x
*
self
.
power
)
...
...
@@ -120,5 +117,5 @@ class PowerTransform(Bijector):
f'(x) = \frac{e^c\log(y)}{y}
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
return
(
self
.
power
-
1
)
*
self
.
log
(
y
)
mindspore/nn/probability/bijector/scalar_affine.py
浏览文件 @
4fe8b3d3
...
...
@@ -15,7 +15,7 @@
"""Scalar Affine Bijector"""
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.utils
import
cast_to_tensor
from
..distribution._utils.custom_ops
import
log_generic
from
.bijector
import
Bijector
...
...
@@ -57,8 +57,8 @@ class ScalarAffine(Bijector):
Constructor of scalar affine bijector.
"""
param
=
dict
(
locals
())
validator
.
check_value_type
(
'scale'
,
scale
,
[
int
,
float
],
name
)
validator
.
check_value_type
(
'shift'
,
shift
,
[
int
,
float
],
name
)
validator
.
check_value_type
(
'scale'
,
scale
,
[
int
,
float
],
type
(
self
).
__name__
)
validator
.
check_value_type
(
'shift'
,
shift
,
[
int
,
float
],
type
(
self
).
__name__
)
self
.
_scale
=
cast_to_tensor
(
scale
)
self
.
_shift
=
cast_to_tensor
(
shift
)
super
(
ScalarAffine
,
self
).
__init__
(
...
...
@@ -71,8 +71,6 @@ class ScalarAffine(Bijector):
self
.
abs
=
P
.
Abs
()
self
.
log
=
log_generic
self
.
checktensor
=
CheckTensor
()
@
property
def
scale
(
self
):
return
self
.
_scale
...
...
@@ -93,7 +91,7 @@ class ScalarAffine(Bijector):
.. math::
f(x) = a * x + b
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
return
self
.
scale
*
x
+
self
.
shift
def
_inverse
(
self
,
y
):
...
...
@@ -101,7 +99,7 @@ class ScalarAffine(Bijector):
.. math::
f(y) = \frac{y - b}{a}
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
return
(
y
-
self
.
shift
)
/
self
.
scale
def
_forward_log_jacobian
(
self
,
x
):
...
...
@@ -111,7 +109,7 @@ class ScalarAffine(Bijector):
f'(x) = a
\log(f'(x)) = \log(a)
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
return
self
.
log
(
self
.
abs
(
self
.
scale
))
def
_inverse_log_jacobian
(
self
,
y
):
...
...
@@ -121,5 +119,5 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a)
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
return
-
1.
*
self
.
log
(
self
.
abs
(
self
.
scale
))
mindspore/nn/probability/bijector/softplus.py
浏览文件 @
4fe8b3d3
...
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from
mindspore.common
import
dtype
as
mstype
from
mindspore.nn.layer.activation
import
LogSigmoid
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.utils
import
cast_to_tensor
from
..distribution._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
from
.bijector
import
Bijector
...
...
@@ -57,7 +57,7 @@ class Softplus(Bijector):
sharpness
=
1.0
,
name
=
'Softplus'
):
param
=
dict
(
locals
())
validator
.
check_value_type
(
'sharpness'
,
sharpness
,
[
int
,
float
],
name
)
validator
.
check_value_type
(
'sharpness'
,
sharpness
,
[
int
,
float
],
type
(
self
).
__name__
)
super
(
Softplus
,
self
).
__init__
(
name
=
name
,
param
=
param
)
self
.
_sharpness
=
cast_to_tensor
(
sharpness
)
...
...
@@ -76,7 +76,6 @@ class Softplus(Bijector):
self
.
softplus
=
self
.
_softplus
self
.
inverse_softplus
=
self
.
_inverse_softplus
self
.
checktensor
=
CheckTensor
()
self
.
threshold
=
np
.
log
(
np
.
finfo
(
np
.
float32
).
eps
)
+
1
self
.
tiny
=
np
.
exp
(
self
.
threshold
)
...
...
@@ -119,7 +118,7 @@ class Softplus(Bijector):
return
shape
def
_forward
(
self
,
x
):
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
scaled_value
=
self
.
sharpness
*
x
return
self
.
softplus
(
scaled_value
)
/
self
.
sharpness
...
...
@@ -129,7 +128,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
scaled_value
=
self
.
sharpness
*
y
return
self
.
inverse_softplus
(
scaled_value
)
/
self
.
sharpness
...
...
@@ -140,7 +139,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
scaled_value
=
self
.
sharpness
*
x
return
self
.
log_sigmoid
(
scaled_value
)
...
...
@@ -151,6 +150,6 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
scaled_value
=
self
.
sharpness
*
y
return
scaled_value
-
self
.
inverse_softplus
(
scaled_value
)
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
4fe8b3d3
...
...
@@ -342,7 +342,7 @@ class CheckTuple(PrimitiveWithInfer):
# Pynative mode
if
isinstance
(
x
,
tuple
):
return
x
raise
TypeError
(
f
"For
{
name
[
'value'
]
}
, Input type should b
a tuple."
)
raise
TypeError
(
f
"For
{
name
}
, input type should be
a tuple."
)
class
CheckTensor
(
PrimitiveWithInfer
):
...
...
@@ -365,4 +365,6 @@ class CheckTensor(PrimitiveWithInfer):
return
out
def
__call__
(
self
,
x
,
name
):
return
if
isinstance
(
x
,
Tensor
):
return
x
raise
TypeError
(
f
"For
{
name
}
, input type should be a Tensor."
)
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
4fe8b3d3
...
...
@@ -99,7 +99,7 @@ class Bernoulli(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
int_type
+
mstype
.
uint_type
+
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Bernoulli"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Bernoulli
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
mstype
.
float32
if
probs
is
not
None
:
...
...
@@ -144,7 +144,10 @@ class Bernoulli(Distribution):
Check availablity of distribution specific args probs1.
"""
if
probs1
is
not
None
:
self
.
checktensor
(
probs1
,
'probs1'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
probs1
,
'probs1'
)
else
:
probs1
=
self
.
checktensor
(
probs1
,
'probs1'
)
return
self
.
cast
(
probs1
,
self
.
parameter_type
)
return
self
.
probs
if
self
.
probs
is
not
None
else
raise_none_error
(
'probs1'
)
...
...
@@ -210,7 +213,7 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
probs1
=
self
.
_check_param
(
probs1
)
probs0
=
1.0
-
probs1
...
...
@@ -229,7 +232,7 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
value
=
self
.
floor
(
value
)
probs1
=
self
.
_check_param
(
probs1
)
...
...
@@ -257,7 +260,7 @@ class Bernoulli(Distribution):
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
check_distribution_name
(
dist
,
'Bernoulli'
)
self
.
checktensor
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
_check_value
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
cast
(
probs1_b
,
self
.
parameter_type
)
probs1_a
=
self
.
_check_param
(
probs1
)
probs0_a
=
1.0
-
probs1_a
...
...
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
4fe8b3d3
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""basic"""
from
mindspore
import
context
from
mindspore.nn.cell
import
Cell
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
...
...
@@ -54,7 +55,7 @@ class Distribution(Cell):
Constructor of distribution class.
"""
super
(
Distribution
,
self
).
__init__
()
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
'distribution_name'
)
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
type
(
self
).
__name__
)
validator
.
check_integer
(
'seed'
,
seed
,
0
,
Rel
.
GE
,
name
)
self
.
_name
=
name
...
...
@@ -81,6 +82,7 @@ class Distribution(Cell):
self
.
_set_log_survival
()
self
.
_set_cross_entropy
()
self
.
context_mode
=
context
.
get_context
(
'mode'
)
self
.
checktuple
=
CheckTuple
()
self
.
checktensor
=
CheckTensor
()
...
...
@@ -108,6 +110,15 @@ class Distribution(Cell):
def
broadcast_shape
(
self
):
return
self
.
_broadcast_shape
def
_check_value
(
self
,
value
,
name
):
"""
Check availability fo value as a Tensor.
"""
if
self
.
context_mode
==
0
:
self
.
checktensor
(
value
,
name
)
return
value
return
self
.
checktensor
(
value
,
name
)
def
_set_prob
(
self
):
"""
Set probability funtion based on the availability of _prob and _log_likehood.
...
...
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
4fe8b3d3
...
...
@@ -100,7 +100,7 @@ class Exponential(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Exponential"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Exponential
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
if
rate
is
not
None
:
...
...
@@ -146,7 +146,10 @@ class Exponential(Distribution):
Check availablity of distribution specific args rate.
"""
if
rate
is
not
None
:
self
.
checktensor
(
rate
,
'rate'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
rate
,
'rate'
)
else
:
rate
=
self
.
checktensor
(
rate
,
'rate'
)
return
self
.
cast
(
rate
,
self
.
parameter_type
)
return
self
.
rate
if
self
.
rate
is
not
None
else
raise_none_error
(
'rate'
)
...
...
@@ -210,7 +213,7 @@ class Exponential(Distribution):
.. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
"""
self
.
checktensor
(
value
,
"value"
)
value
=
self
.
_check_value
(
value
,
"value"
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
rate
=
self
.
_check_param
(
rate
)
prob
=
self
.
exp
(
self
.
log
(
rate
)
-
rate
*
value
)
...
...
@@ -232,7 +235,7 @@ class Exponential(Distribution):
.. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
rate
=
self
.
_check_param
(
rate
)
cdf
=
1.0
-
self
.
exp
(
-
1.
*
rate
*
value
)
...
...
@@ -251,7 +254,7 @@ class Exponential(Distribution):
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
check_distribution_name
(
dist
,
'Exponential'
)
self
.
checktensor
(
rate_b
,
'rate_b'
)
rate_b
=
self
.
_check_value
(
rate_b
,
'rate_b'
)
rate_b
=
self
.
cast
(
rate_b
,
self
.
parameter_type
)
rate_a
=
self
.
_check_param
(
rate
)
return
self
.
log
(
rate_a
)
-
self
.
log
(
rate_b
)
+
rate_b
/
rate_a
-
1.0
...
...
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
4fe8b3d3
...
...
@@ -102,7 +102,7 @@ class Geometric(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
int_type
+
mstype
.
uint_type
+
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Geometric"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Geometric
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
mstype
.
float32
if
probs
is
not
None
:
...
...
@@ -150,7 +150,10 @@ class Geometric(Distribution):
Check availablity of distribution specific args probs1.
"""
if
probs1
is
not
None
:
self
.
checktensor
(
probs1
,
'probs1'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
probs1
,
'probs1'
)
else
:
probs1
=
self
.
checktensor
(
probs1
,
'probs1'
)
return
self
.
cast
(
probs1
,
self
.
parameter_type
)
return
self
.
probs
if
self
.
probs
is
not
None
else
raise_none_error
(
'probs1'
)
...
...
@@ -211,7 +214,7 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
value
=
self
.
floor
(
value
)
probs1
=
self
.
_check_param
(
probs1
)
...
...
@@ -233,7 +236,7 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0.
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
value
=
self
.
floor
(
value
)
probs1
=
self
.
_check_param
(
probs1
)
...
...
@@ -256,7 +259,7 @@ class Geometric(Distribution):
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
"""
check_distribution_name
(
dist
,
'Geometric'
)
self
.
checktensor
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
_check_value
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
cast
(
probs1_b
,
self
.
parameter_type
)
probs1_a
=
self
.
_check_param
(
probs1
)
probs0_a
=
1.0
-
probs1_a
...
...
mindspore/nn/probability/distribution/normal.py
浏览文件 @
4fe8b3d3
...
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
c
onvert_to_batch
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
c
ast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
from
._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
,
erf_generic
...
...
@@ -102,12 +102,12 @@ class Normal(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Normal"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Normal
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
if
mean
is
not
None
and
sd
is
not
None
:
self
.
_mean_value
=
c
onvert_to_batch
(
mean
,
self
.
broadcast_shape
,
self
.
parameter_type
)
self
.
_sd_value
=
c
onvert_to_batch
(
sd
,
self
.
broadcast_shape
,
self
.
parameter_type
)
self
.
_mean_value
=
c
ast_to_tensor
(
mean
,
self
.
parameter_type
)
self
.
_sd_value
=
c
ast_to_tensor
(
sd
,
self
.
parameter_type
)
check_greater_zero
(
self
.
_sd_value
,
"Standard deviation"
)
else
:
self
.
_mean_value
=
mean
...
...
@@ -139,12 +139,18 @@ class Normal(Distribution):
Check availablity of distribution specific args mean and sd.
"""
if
mean
is
not
None
:
self
.
checktensor
(
mean
,
'mean'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
mean
,
'mean'
)
else
:
mean
=
self
.
checktensor
(
mean
,
'mean'
)
mean
=
self
.
cast
(
mean
,
self
.
parameter_type
)
else
:
mean
=
self
.
_mean_value
if
self
.
_mean_value
is
not
None
else
raise_none_error
(
'mean'
)
if
sd
is
not
None
:
self
.
checktensor
(
sd
,
'sd'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
sd
,
'sd'
)
else
:
sd
=
self
.
checktensor
(
sd
,
'sd'
)
sd
=
self
.
cast
(
sd
,
self
.
parameter_type
)
else
:
sd
=
self
.
_sd_value
if
self
.
_sd_value
is
not
None
else
raise_none_error
(
'sd'
)
...
...
@@ -210,7 +216,7 @@ class Normal(Distribution):
.. math::
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
mean
,
sd
=
self
.
_check_param
(
mean
,
sd
)
unnormalized_log_prob
=
-
1.
*
(
self
.
sq
(
value
-
mean
))
/
(
2.
*
self
.
sq
(
sd
))
...
...
@@ -229,7 +235,7 @@ class Normal(Distribution):
.. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
mean
,
sd
=
self
.
_check_param
(
mean
,
sd
)
sqrt2
=
self
.
sqrt
(
self
.
const
(
2.0
))
...
...
@@ -252,8 +258,8 @@ class Normal(Distribution):
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
check_distribution_name
(
dist
,
'Normal'
)
self
.
checktensor
(
mean_b
,
'mean_b'
)
s
elf
.
checktensor
(
sd_b
,
'sd_b'
)
mean_b
=
self
.
_check_value
(
mean_b
,
'mean_b'
)
s
d_b
=
self
.
_check_value
(
sd_b
,
'sd_b'
)
mean_b
=
self
.
cast
(
mean_b
,
self
.
parameter_type
)
sd_b
=
self
.
cast
(
sd_b
,
self
.
parameter_type
)
mean_a
,
sd_a
=
self
.
_check_param
(
mean
,
sd
)
...
...
mindspore/nn/probability/distribution/transformed_distribution.py
浏览文件 @
4fe8b3d3
...
...
@@ -46,10 +46,10 @@ class TransformedDistribution(Distribution):
Constructor of transformed_distribution class.
"""
param
=
dict
(
locals
())
validator
.
check_value_type
(
'bijector'
,
bijector
,
[
nn
.
probability
.
bijector
.
Bijector
],
name
)
validator
.
check_value_type
(
'distribution'
,
distribution
,
[
Distribution
],
name
)
validator
.
check_value_type
(
'bijector'
,
bijector
,
[
nn
.
probability
.
bijector
.
Bijector
],
type
(
self
).
__name__
)
validator
.
check_value_type
(
'distribution'
,
distribution
,
[
Distribution
],
type
(
self
).
__name__
)
valid_dtype
=
mstype
.
number_type
check_type
(
dtype
,
valid_dtype
,
"transformed_distribution"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
TransformedDistribution
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
_bijector
=
bijector
...
...
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
4fe8b3d3
...
...
@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
c
onvert_to_batch
,
check_greater
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
c
ast_to_tensor
,
check_greater
,
check_type
,
check_distribution_name
,
\
raise_none_error
from
._utils.custom_ops
import
exp_generic
,
log_generic
...
...
@@ -101,12 +101,12 @@ class Uniform(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Uniform"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Uniform
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
if
low
is
not
None
and
high
is
not
None
:
self
.
_low
=
c
onvert_to_batch
(
low
,
self
.
broadcast_shape
,
dtype
)
self
.
_high
=
c
onvert_to_batch
(
high
,
self
.
broadcast_shape
,
dtype
)
self
.
_low
=
c
ast_to_tensor
(
low
,
dtype
)
self
.
_high
=
c
ast_to_tensor
(
high
,
dtype
)
check_greater
(
self
.
low
,
self
.
high
,
"low value"
,
"high value"
)
else
:
self
.
_low
=
low
...
...
@@ -142,12 +142,18 @@ class Uniform(Distribution):
Check availablity of distribution specific args low and high.
"""
if
low
is
not
None
:
self
.
checktensor
(
low
,
'low'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
low
,
'low'
)
else
:
low
=
self
.
checktensor
(
low
,
'low'
)
low
=
self
.
cast
(
low
,
self
.
parameter_type
)
else
:
low
=
self
.
low
if
self
.
low
is
not
None
else
raise_none_error
(
'low'
)
if
high
is
not
None
:
self
.
checktensor
(
high
,
'high'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
high
,
'high'
)
else
:
high
=
self
.
checktensor
(
high
,
'high'
)
high
=
self
.
cast
(
high
,
self
.
parameter_type
)
else
:
high
=
self
.
high
if
self
.
high
is
not
None
else
raise_none_error
(
'high'
)
...
...
@@ -231,7 +237,7 @@ class Uniform(Distribution):
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
low
,
high
=
self
.
_check_param
(
low
,
high
)
neg_ones
=
self
.
fill
(
self
.
dtype
,
self
.
shape
(
value
),
-
1.0
)
...
...
@@ -255,9 +261,9 @@ class Uniform(Distribution):
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
check_distribution_name
(
dist
,
'Uniform'
)
self
.
checktensor
(
low_b
,
'low_b'
)
low_b
=
self
.
_check_value
(
low_b
,
'low_b'
)
low_b
=
self
.
cast
(
low_b
,
self
.
parameter_type
)
self
.
checktensor
(
high_b
,
'high_b'
)
high_b
=
self
.
_check_value
(
high_b
,
'high_b'
)
high_b
=
self
.
cast
(
high_b
,
self
.
parameter_type
)
low_a
,
high_a
=
self
.
_check_param
(
low
,
high
)
kl
=
self
.
log
(
high_b
-
low_b
)
-
self
.
log
(
high_a
-
low_a
)
...
...
@@ -278,7 +284,7 @@ class Uniform(Distribution):
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
low
,
high
=
self
.
_check_param
(
low
,
high
)
prob
=
(
value
-
low
)
/
(
high
-
low
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录