Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4fe8b3d3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4fe8b3d3
编写于
4年前
作者:
X
Xun Deng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix checktensor in pynative mode
上级
b8da525f
master
无相关合并请求
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
107 addition
and
65 deletion
+107
-65
mindspore/nn/probability/bijector/bijector.py
mindspore/nn/probability/bijector/bijector.py
+15
-1
mindspore/nn/probability/bijector/power_transform.py
mindspore/nn/probability/bijector/power_transform.py
+4
-7
mindspore/nn/probability/bijector/scalar_affine.py
mindspore/nn/probability/bijector/scalar_affine.py
+7
-9
mindspore/nn/probability/bijector/softplus.py
mindspore/nn/probability/bijector/softplus.py
+6
-7
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+4
-2
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+8
-5
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+12
-1
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+8
-5
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+8
-5
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+16
-10
mindspore/nn/probability/distribution/transformed_distribution.py
...e/nn/probability/distribution/transformed_distribution.py
+3
-3
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+16
-10
未找到文件。
mindspore/nn/probability/bijector/bijector.py
浏览文件 @
4fe8b3d3
...
...
@@ -13,8 +13,10 @@
# limitations under the License.
# ============================================================================
"""Bijector"""
from
mindspore
import
context
from
mindspore.nn.cell
import
Cell
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
CheckTensor
from
..distribution
import
Distribution
from
..distribution
import
TransformedDistribution
...
...
@@ -40,7 +42,7 @@ class Bijector(Cell):
Constructor of bijector class.
"""
super
(
Bijector
,
self
).
__init__
()
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
'Bijector'
)
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
type
(
self
).
__name__
)
validator
.
check_value_type
(
'is_constant_jacobian'
,
is_constant_jacobian
,
[
bool
],
name
)
validator
.
check_value_type
(
'is_injective'
,
is_injective
,
[
bool
],
name
)
self
.
_name
=
name
...
...
@@ -53,6 +55,9 @@ class Bijector(Cell):
self
.
_is_constant_jacobian
=
is_constant_jacobian
self
.
_is_injective
=
is_injective
self
.
context_mode
=
context
.
get_context
(
'mode'
)
self
.
checktensor
=
CheckTensor
()
@
property
def
name
(
self
):
return
self
.
_name
...
...
@@ -73,6 +78,15 @@ class Bijector(Cell):
def
is_injective
(
self
):
return
self
.
_is_injective
def
_check_value
(
self
,
value
,
name
):
"""
Check availability fo value as a Tensor.
"""
if
self
.
context_mode
==
0
:
self
.
checktensor
(
value
,
name
)
return
value
return
self
.
checktensor
(
value
,
name
)
def
forward
(
self
,
*
args
,
**
kwargs
):
"""
Forward transformation: transform the input value to another distribution.
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/bijector/power_transform.py
浏览文件 @
4fe8b3d3
...
...
@@ -16,7 +16,6 @@
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
..distribution._utils.utils
import
CheckTensor
from
..distribution._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
,
log1p_generic
from
.bijector
import
Bijector
...
...
@@ -66,8 +65,6 @@ class PowerTransform(Bijector):
self
.
log
=
log_generic
self
.
log1p
=
log1p_generic
self
.
checktensor
=
CheckTensor
()
@
property
def
power
(
self
):
return
self
.
_power
...
...
@@ -80,13 +77,13 @@ class PowerTransform(Bijector):
return
shape
def
_forward
(
self
,
x
):
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
if
self
.
power
==
0
:
return
self
.
exp
(
x
)
return
self
.
exp
(
self
.
log1p
(
x
*
self
.
power
)
/
self
.
power
)
def
_inverse
(
self
,
y
):
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
if
self
.
power
==
0
:
return
self
.
log
(
y
)
return
self
.
expm1
(
self
.
log
(
y
)
*
self
.
power
)
/
self
.
power
...
...
@@ -103,7 +100,7 @@ class PowerTransform(Bijector):
f'(x) = e^\frac{\log(xc + 1)}{c} * \frac{1}{xc + 1}
\log(f'(x)) = (\frac{1}{c} - 1) * \log(xc + 1)
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
if
self
.
power
==
0
:
return
x
return
(
1.
/
self
.
power
-
1
)
*
self
.
log1p
(
x
*
self
.
power
)
...
...
@@ -120,5 +117,5 @@ class PowerTransform(Bijector):
f'(x) = \frac{e^c\log(y)}{y}
\log(f'(x)) = \log(\frac{e^c\log(y)}{y}) = (c-1) * \log(y)
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
return
(
self
.
power
-
1
)
*
self
.
log
(
y
)
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/bijector/scalar_affine.py
浏览文件 @
4fe8b3d3
...
...
@@ -15,7 +15,7 @@
"""Scalar Affine Bijector"""
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.utils
import
cast_to_tensor
from
..distribution._utils.custom_ops
import
log_generic
from
.bijector
import
Bijector
...
...
@@ -57,8 +57,8 @@ class ScalarAffine(Bijector):
Constructor of scalar affine bijector.
"""
param
=
dict
(
locals
())
validator
.
check_value_type
(
'scale'
,
scale
,
[
int
,
float
],
name
)
validator
.
check_value_type
(
'shift'
,
shift
,
[
int
,
float
],
name
)
validator
.
check_value_type
(
'scale'
,
scale
,
[
int
,
float
],
type
(
self
).
__name__
)
validator
.
check_value_type
(
'shift'
,
shift
,
[
int
,
float
],
type
(
self
).
__name__
)
self
.
_scale
=
cast_to_tensor
(
scale
)
self
.
_shift
=
cast_to_tensor
(
shift
)
super
(
ScalarAffine
,
self
).
__init__
(
...
...
@@ -71,8 +71,6 @@ class ScalarAffine(Bijector):
self
.
abs
=
P
.
Abs
()
self
.
log
=
log_generic
self
.
checktensor
=
CheckTensor
()
@
property
def
scale
(
self
):
return
self
.
_scale
...
...
@@ -93,7 +91,7 @@ class ScalarAffine(Bijector):
.. math::
f(x) = a * x + b
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
return
self
.
scale
*
x
+
self
.
shift
def
_inverse
(
self
,
y
):
...
...
@@ -101,7 +99,7 @@ class ScalarAffine(Bijector):
.. math::
f(y) = \frac{y - b}{a}
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
return
(
y
-
self
.
shift
)
/
self
.
scale
def
_forward_log_jacobian
(
self
,
x
):
...
...
@@ -111,7 +109,7 @@ class ScalarAffine(Bijector):
f'(x) = a
\log(f'(x)) = \log(a)
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
return
self
.
log
(
self
.
abs
(
self
.
scale
))
def
_inverse_log_jacobian
(
self
,
y
):
...
...
@@ -121,5 +119,5 @@ class ScalarAffine(Bijector):
f'(x) = \frac{1.0}{a}
\log(f'(x)) = - \log(a)
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
return
-
1.
*
self
.
log
(
self
.
abs
(
self
.
scale
))
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/bijector/softplus.py
浏览文件 @
4fe8b3d3
...
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from
mindspore.common
import
dtype
as
mstype
from
mindspore.nn.layer.activation
import
LogSigmoid
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.utils
import
cast_to_tensor
from
..distribution._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
from
.bijector
import
Bijector
...
...
@@ -57,7 +57,7 @@ class Softplus(Bijector):
sharpness
=
1.0
,
name
=
'Softplus'
):
param
=
dict
(
locals
())
validator
.
check_value_type
(
'sharpness'
,
sharpness
,
[
int
,
float
],
name
)
validator
.
check_value_type
(
'sharpness'
,
sharpness
,
[
int
,
float
],
type
(
self
).
__name__
)
super
(
Softplus
,
self
).
__init__
(
name
=
name
,
param
=
param
)
self
.
_sharpness
=
cast_to_tensor
(
sharpness
)
...
...
@@ -76,7 +76,6 @@ class Softplus(Bijector):
self
.
softplus
=
self
.
_softplus
self
.
inverse_softplus
=
self
.
_inverse_softplus
self
.
checktensor
=
CheckTensor
()
self
.
threshold
=
np
.
log
(
np
.
finfo
(
np
.
float32
).
eps
)
+
1
self
.
tiny
=
np
.
exp
(
self
.
threshold
)
...
...
@@ -119,7 +118,7 @@ class Softplus(Bijector):
return
shape
def
_forward
(
self
,
x
):
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
scaled_value
=
self
.
sharpness
*
x
return
self
.
softplus
(
scaled_value
)
/
self
.
sharpness
...
...
@@ -129,7 +128,7 @@ class Softplus(Bijector):
f(x) = \frac{\log(1 + e^{kx}))}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
scaled_value
=
self
.
sharpness
*
y
return
self
.
inverse_softplus
(
scaled_value
)
/
self
.
sharpness
...
...
@@ -140,7 +139,7 @@ class Softplus(Bijector):
f'(x) = \frac{e^{kx}}{ 1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = kx - f(kx)
"""
self
.
checktensor
(
x
,
'value'
)
x
=
self
.
_check_value
(
x
,
'value'
)
scaled_value
=
self
.
sharpness
*
x
return
self
.
log_sigmoid
(
scaled_value
)
...
...
@@ -151,6 +150,6 @@ class Softplus(Bijector):
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1) = ky - f(ky)
"""
self
.
checktensor
(
y
,
'value'
)
y
=
self
.
_check_value
(
y
,
'value'
)
scaled_value
=
self
.
sharpness
*
y
return
scaled_value
-
self
.
inverse_softplus
(
scaled_value
)
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
4fe8b3d3
...
...
@@ -342,7 +342,7 @@ class CheckTuple(PrimitiveWithInfer):
# Pynative mode
if
isinstance
(
x
,
tuple
):
return
x
raise
TypeError
(
f
"For
{
name
[
'value'
]
}
, Input type should b
a tuple."
)
raise
TypeError
(
f
"For
{
name
}
, input type should be
a tuple."
)
class
CheckTensor
(
PrimitiveWithInfer
):
...
...
@@ -365,4 +365,6 @@ class CheckTensor(PrimitiveWithInfer):
return
out
def
__call__
(
self
,
x
,
name
):
return
if
isinstance
(
x
,
Tensor
):
return
x
raise
TypeError
(
f
"For
{
name
}
, input type should be a Tensor."
)
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
4fe8b3d3
...
...
@@ -99,7 +99,7 @@ class Bernoulli(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
int_type
+
mstype
.
uint_type
+
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Bernoulli"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Bernoulli
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
mstype
.
float32
if
probs
is
not
None
:
...
...
@@ -144,7 +144,10 @@ class Bernoulli(Distribution):
Check availablity of distribution specific args probs1.
"""
if
probs1
is
not
None
:
self
.
checktensor
(
probs1
,
'probs1'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
probs1
,
'probs1'
)
else
:
probs1
=
self
.
checktensor
(
probs1
,
'probs1'
)
return
self
.
cast
(
probs1
,
self
.
parameter_type
)
return
self
.
probs
if
self
.
probs
is
not
None
else
raise_none_error
(
'probs1'
)
...
...
@@ -210,7 +213,7 @@ class Bernoulli(Distribution):
pmf(k) = probs1 if k = 1;
pmf(k) = probs0 if k = 0;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
probs1
=
self
.
_check_param
(
probs1
)
probs0
=
1.0
-
probs1
...
...
@@ -229,7 +232,7 @@ class Bernoulli(Distribution):
cdf(k) = probs0 if 0 <= k <1;
cdf(k) = 1 if k >=1;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
value
=
self
.
floor
(
value
)
probs1
=
self
.
_check_param
(
probs1
)
...
...
@@ -257,7 +260,7 @@ class Bernoulli(Distribution):
probs0_a * \log(\frac{probs0_a}{probs0_b})
"""
check_distribution_name
(
dist
,
'Bernoulli'
)
self
.
checktensor
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
_check_value
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
cast
(
probs1_b
,
self
.
parameter_type
)
probs1_a
=
self
.
_check_param
(
probs1
)
probs0_a
=
1.0
-
probs1_a
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
4fe8b3d3
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""basic"""
from
mindspore
import
context
from
mindspore.nn.cell
import
Cell
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
...
...
@@ -54,7 +55,7 @@ class Distribution(Cell):
Constructor of distribution class.
"""
super
(
Distribution
,
self
).
__init__
()
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
'distribution_name'
)
validator
.
check_value_type
(
'name'
,
name
,
[
str
],
type
(
self
).
__name__
)
validator
.
check_integer
(
'seed'
,
seed
,
0
,
Rel
.
GE
,
name
)
self
.
_name
=
name
...
...
@@ -81,6 +82,7 @@ class Distribution(Cell):
self
.
_set_log_survival
()
self
.
_set_cross_entropy
()
self
.
context_mode
=
context
.
get_context
(
'mode'
)
self
.
checktuple
=
CheckTuple
()
self
.
checktensor
=
CheckTensor
()
...
...
@@ -108,6 +110,15 @@ class Distribution(Cell):
def
broadcast_shape
(
self
):
return
self
.
_broadcast_shape
def
_check_value
(
self
,
value
,
name
):
"""
Check availability fo value as a Tensor.
"""
if
self
.
context_mode
==
0
:
self
.
checktensor
(
value
,
name
)
return
value
return
self
.
checktensor
(
value
,
name
)
def
_set_prob
(
self
):
"""
Set probability funtion based on the availability of _prob and _log_likehood.
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
4fe8b3d3
...
...
@@ -100,7 +100,7 @@ class Exponential(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Exponential"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Exponential
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
if
rate
is
not
None
:
...
...
@@ -146,7 +146,10 @@ class Exponential(Distribution):
Check availablity of distribution specific args rate.
"""
if
rate
is
not
None
:
self
.
checktensor
(
rate
,
'rate'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
rate
,
'rate'
)
else
:
rate
=
self
.
checktensor
(
rate
,
'rate'
)
return
self
.
cast
(
rate
,
self
.
parameter_type
)
return
self
.
rate
if
self
.
rate
is
not
None
else
raise_none_error
(
'rate'
)
...
...
@@ -210,7 +213,7 @@ class Exponential(Distribution):
.. math::
pdf(x) = rate * \exp(-1 * \lambda * x) if x >= 0 else 0
"""
self
.
checktensor
(
value
,
"value"
)
value
=
self
.
_check_value
(
value
,
"value"
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
rate
=
self
.
_check_param
(
rate
)
prob
=
self
.
exp
(
self
.
log
(
rate
)
-
rate
*
value
)
...
...
@@ -232,7 +235,7 @@ class Exponential(Distribution):
.. math::
cdf(x) = 1.0 - \exp(-1 * \lambda * x) if x >= 0 else 0
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
rate
=
self
.
_check_param
(
rate
)
cdf
=
1.0
-
self
.
exp
(
-
1.
*
rate
*
value
)
...
...
@@ -251,7 +254,7 @@ class Exponential(Distribution):
rate_a (Tensor): rate of distribution a. Default: self.rate.
"""
check_distribution_name
(
dist
,
'Exponential'
)
self
.
checktensor
(
rate_b
,
'rate_b'
)
rate_b
=
self
.
_check_value
(
rate_b
,
'rate_b'
)
rate_b
=
self
.
cast
(
rate_b
,
self
.
parameter_type
)
rate_a
=
self
.
_check_param
(
rate
)
return
self
.
log
(
rate_a
)
-
self
.
log
(
rate_b
)
+
rate_b
/
rate_a
-
1.0
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
4fe8b3d3
...
...
@@ -102,7 +102,7 @@ class Geometric(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
int_type
+
mstype
.
uint_type
+
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Geometric"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Geometric
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
mstype
.
float32
if
probs
is
not
None
:
...
...
@@ -150,7 +150,10 @@ class Geometric(Distribution):
Check availablity of distribution specific args probs1.
"""
if
probs1
is
not
None
:
self
.
checktensor
(
probs1
,
'probs1'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
probs1
,
'probs1'
)
else
:
probs1
=
self
.
checktensor
(
probs1
,
'probs1'
)
return
self
.
cast
(
probs1
,
self
.
parameter_type
)
return
self
.
probs
if
self
.
probs
is
not
None
else
raise_none_error
(
'probs1'
)
...
...
@@ -211,7 +214,7 @@ class Geometric(Distribution):
pmf(k) = probs0 ^k * probs1 if k >= 0;
pmf(k) = 0 if k < 0.
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
value
=
self
.
floor
(
value
)
probs1
=
self
.
_check_param
(
probs1
)
...
...
@@ -233,7 +236,7 @@ class Geometric(Distribution):
cdf(k) = 0 if k < 0.
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
mstype
.
float32
)
value
=
self
.
floor
(
value
)
probs1
=
self
.
_check_param
(
probs1
)
...
...
@@ -256,7 +259,7 @@ class Geometric(Distribution):
KL(a||b) = \log(\frac{probs1_a}{probs1_b}) + \frac{probs0_a}{probs1_a} * \log(\frac{probs0_a}{probs0_b})
"""
check_distribution_name
(
dist
,
'Geometric'
)
self
.
checktensor
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
_check_value
(
probs1_b
,
'probs1_b'
)
probs1_b
=
self
.
cast
(
probs1_b
,
self
.
parameter_type
)
probs1_a
=
self
.
_check_param
(
probs1
)
probs0_a
=
1.0
-
probs1_a
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/normal.py
浏览文件 @
4fe8b3d3
...
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
c
onvert_to_batch
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
c
ast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
from
._utils.custom_ops
import
exp_generic
,
expm1_generic
,
log_generic
,
erf_generic
...
...
@@ -102,12 +102,12 @@ class Normal(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Normal"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Normal
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
if
mean
is
not
None
and
sd
is
not
None
:
self
.
_mean_value
=
c
onvert_to_batch
(
mean
,
self
.
broadcast_shape
,
self
.
parameter_type
)
self
.
_sd_value
=
c
onvert_to_batch
(
sd
,
self
.
broadcast_shape
,
self
.
parameter_type
)
self
.
_mean_value
=
c
ast_to_tensor
(
mean
,
self
.
parameter_type
)
self
.
_sd_value
=
c
ast_to_tensor
(
sd
,
self
.
parameter_type
)
check_greater_zero
(
self
.
_sd_value
,
"Standard deviation"
)
else
:
self
.
_mean_value
=
mean
...
...
@@ -139,12 +139,18 @@ class Normal(Distribution):
Check availablity of distribution specific args mean and sd.
"""
if
mean
is
not
None
:
self
.
checktensor
(
mean
,
'mean'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
mean
,
'mean'
)
else
:
mean
=
self
.
checktensor
(
mean
,
'mean'
)
mean
=
self
.
cast
(
mean
,
self
.
parameter_type
)
else
:
mean
=
self
.
_mean_value
if
self
.
_mean_value
is
not
None
else
raise_none_error
(
'mean'
)
if
sd
is
not
None
:
self
.
checktensor
(
sd
,
'sd'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
sd
,
'sd'
)
else
:
sd
=
self
.
checktensor
(
sd
,
'sd'
)
sd
=
self
.
cast
(
sd
,
self
.
parameter_type
)
else
:
sd
=
self
.
_sd_value
if
self
.
_sd_value
is
not
None
else
raise_none_error
(
'sd'
)
...
...
@@ -210,7 +216,7 @@ class Normal(Distribution):
.. math::
L(x) = -1* \frac{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2))
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
mean
,
sd
=
self
.
_check_param
(
mean
,
sd
)
unnormalized_log_prob
=
-
1.
*
(
self
.
sq
(
value
-
mean
))
/
(
2.
*
self
.
sq
(
sd
))
...
...
@@ -229,7 +235,7 @@ class Normal(Distribution):
.. math::
cdf(x) = 0.5 * (1+ Erf((x - \mu) / ( \sigma * \sqrt(2))))
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
mean
,
sd
=
self
.
_check_param
(
mean
,
sd
)
sqrt2
=
self
.
sqrt
(
self
.
const
(
2.0
))
...
...
@@ -252,8 +258,8 @@ class Normal(Distribution):
0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b)))
"""
check_distribution_name
(
dist
,
'Normal'
)
self
.
checktensor
(
mean_b
,
'mean_b'
)
s
elf
.
checktensor
(
sd_b
,
'sd_b'
)
mean_b
=
self
.
_check_value
(
mean_b
,
'mean_b'
)
s
d_b
=
self
.
_check_value
(
sd_b
,
'sd_b'
)
mean_b
=
self
.
cast
(
mean_b
,
self
.
parameter_type
)
sd_b
=
self
.
cast
(
sd_b
,
self
.
parameter_type
)
mean_a
,
sd_a
=
self
.
_check_param
(
mean
,
sd
)
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/transformed_distribution.py
浏览文件 @
4fe8b3d3
...
...
@@ -46,10 +46,10 @@ class TransformedDistribution(Distribution):
Constructor of transformed_distribution class.
"""
param
=
dict
(
locals
())
validator
.
check_value_type
(
'bijector'
,
bijector
,
[
nn
.
probability
.
bijector
.
Bijector
],
name
)
validator
.
check_value_type
(
'distribution'
,
distribution
,
[
Distribution
],
name
)
validator
.
check_value_type
(
'bijector'
,
bijector
,
[
nn
.
probability
.
bijector
.
Bijector
],
type
(
self
).
__name__
)
validator
.
check_value_type
(
'distribution'
,
distribution
,
[
Distribution
],
type
(
self
).
__name__
)
valid_dtype
=
mstype
.
number_type
check_type
(
dtype
,
valid_dtype
,
"transformed_distribution"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
TransformedDistribution
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
_bijector
=
bijector
...
...
This diff is collapsed.
Click to expand it.
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
4fe8b3d3
...
...
@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.common
import
dtype
as
mstype
from
.distribution
import
Distribution
from
._utils.utils
import
c
onvert_to_batch
,
check_greater
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
c
ast_to_tensor
,
check_greater
,
check_type
,
check_distribution_name
,
\
raise_none_error
from
._utils.custom_ops
import
exp_generic
,
log_generic
...
...
@@ -101,12 +101,12 @@ class Uniform(Distribution):
"""
param
=
dict
(
locals
())
valid_dtype
=
mstype
.
float_type
check_type
(
dtype
,
valid_dtype
,
"Uniform"
)
check_type
(
dtype
,
valid_dtype
,
type
(
self
).
__name__
)
super
(
Uniform
,
self
).
__init__
(
seed
,
dtype
,
name
,
param
)
self
.
parameter_type
=
dtype
if
low
is
not
None
and
high
is
not
None
:
self
.
_low
=
c
onvert_to_batch
(
low
,
self
.
broadcast_shape
,
dtype
)
self
.
_high
=
c
onvert_to_batch
(
high
,
self
.
broadcast_shape
,
dtype
)
self
.
_low
=
c
ast_to_tensor
(
low
,
dtype
)
self
.
_high
=
c
ast_to_tensor
(
high
,
dtype
)
check_greater
(
self
.
low
,
self
.
high
,
"low value"
,
"high value"
)
else
:
self
.
_low
=
low
...
...
@@ -142,12 +142,18 @@ class Uniform(Distribution):
Check availablity of distribution specific args low and high.
"""
if
low
is
not
None
:
self
.
checktensor
(
low
,
'low'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
low
,
'low'
)
else
:
low
=
self
.
checktensor
(
low
,
'low'
)
low
=
self
.
cast
(
low
,
self
.
parameter_type
)
else
:
low
=
self
.
low
if
self
.
low
is
not
None
else
raise_none_error
(
'low'
)
if
high
is
not
None
:
self
.
checktensor
(
high
,
'high'
)
if
self
.
context_mode
==
0
:
self
.
checktensor
(
high
,
'high'
)
else
:
high
=
self
.
checktensor
(
high
,
'high'
)
high
=
self
.
cast
(
high
,
self
.
parameter_type
)
else
:
high
=
self
.
high
if
self
.
high
is
not
None
else
raise_none_error
(
'high'
)
...
...
@@ -231,7 +237,7 @@ class Uniform(Distribution):
pdf(x) = \frac{1.0}{high -low} if low <= x <= high;
pdf(x) = 0 if x > high;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
low
,
high
=
self
.
_check_param
(
low
,
high
)
neg_ones
=
self
.
fill
(
self
.
dtype
,
self
.
shape
(
value
),
-
1.0
)
...
...
@@ -255,9 +261,9 @@ class Uniform(Distribution):
high_a (Tensor): upper bound of distribution a. Default: self.high.
"""
check_distribution_name
(
dist
,
'Uniform'
)
self
.
checktensor
(
low_b
,
'low_b'
)
low_b
=
self
.
_check_value
(
low_b
,
'low_b'
)
low_b
=
self
.
cast
(
low_b
,
self
.
parameter_type
)
self
.
checktensor
(
high_b
,
'high_b'
)
high_b
=
self
.
_check_value
(
high_b
,
'high_b'
)
high_b
=
self
.
cast
(
high_b
,
self
.
parameter_type
)
low_a
,
high_a
=
self
.
_check_param
(
low
,
high
)
kl
=
self
.
log
(
high_b
-
low_b
)
-
self
.
log
(
high_a
-
low_a
)
...
...
@@ -278,7 +284,7 @@ class Uniform(Distribution):
cdf(x) = \frac{x - low}{high -low} if low <= x <= high;
cdf(x) = 1 if x > high;
"""
self
.
checktensor
(
value
,
'value'
)
value
=
self
.
_check_value
(
value
,
'value'
)
value
=
self
.
cast
(
value
,
self
.
dtype
)
low
,
high
=
self
.
_check_param
(
low
,
high
)
prob
=
(
value
-
low
)
/
(
high
-
low
)
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部