Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
22598e5c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
22598e5c
编写于
8月 10, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4084 Complement the arg passing conventions
Merge pull request !4084 from peixu_ren/custom_bijector
上级
42047764
60bb6beb
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
85 addition
and
91 deletion
+85
-91
mindspore/nn/probability/bijector/bijector.py
mindspore/nn/probability/bijector/bijector.py
+15
-15
mindspore/nn/probability/distribution/distribution.py
mindspore/nn/probability/distribution/distribution.py
+68
-71
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+2
-5
未找到文件。
mindspore/nn/probability/bijector/bijector.py
浏览文件 @
22598e5c
...
...
@@ -69,31 +69,31 @@ class Bijector(Cell):
def
is_injective
(
self
):
return
self
.
_is_injective
def
forward
(
self
,
*
args
):
def
forward
(
self
,
*
args
,
**
kwargs
):
"""
Forward transformation: transform the input value to another distribution.
"""
return
self
.
_forward
(
*
args
)
return
self
.
_forward
(
*
args
,
**
kwargs
)
def
inverse
(
self
,
*
args
):
def
inverse
(
self
,
*
args
,
**
kwargs
):
"""
Inverse transformation: transform the input value back to the original distribution.
"""
return
self
.
_inverse
(
*
args
)
return
self
.
_inverse
(
*
args
,
**
kwargs
)
def
forward_log_jacobian
(
self
,
*
args
):
def
forward_log_jacobian
(
self
,
*
args
,
**
kwargs
):
"""
Logarithm of the derivative of forward transformation.
"""
return
self
.
_forward_log_jacobian
(
*
args
)
return
self
.
_forward_log_jacobian
(
*
args
,
**
kwargs
)
def
inverse_log_jacobian
(
self
,
*
args
):
def
inverse_log_jacobian
(
self
,
*
args
,
**
kwargs
):
"""
Logarithm of the derivative of forward transformation.
"""
return
self
.
_inverse_log_jacobian
(
*
args
)
return
self
.
_inverse_log_jacobian
(
*
args
,
**
kwargs
)
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
Call Bijector directly.
This __call__ may go into two directions:
...
...
@@ -107,9 +107,9 @@ class Bijector(Cell):
"""
if
isinstance
(
args
[
0
],
Distribution
):
return
TransformedDistribution
(
self
,
args
[
0
])
return
super
(
Bijector
,
self
).
__call__
(
*
args
)
return
super
(
Bijector
,
self
).
__call__
(
*
args
,
**
kwargs
)
def
construct
(
self
,
name
,
*
args
):
def
construct
(
self
,
name
,
*
args
,
**
kwargs
):
"""
Override construct in Cell.
...
...
@@ -120,11 +120,11 @@ class Bijector(Cell):
Always raise RuntimeError as Distribution should not be called directly.
"""
if
name
==
'forward'
:
return
self
.
forward
(
*
args
)
return
self
.
forward
(
*
args
,
**
kwargs
)
if
name
==
'inverse'
:
return
self
.
inverse
(
*
args
)
return
self
.
inverse
(
*
args
,
**
kwargs
)
if
name
==
'forward_log_jacobian'
:
return
self
.
forward_log_jacobian
(
*
args
)
return
self
.
forward_log_jacobian
(
*
args
,
**
kwargs
)
if
name
==
'inverse_log_jacobian'
:
return
self
.
inverse_log_jacobian
(
*
args
)
return
self
.
inverse_log_jacobian
(
*
args
,
**
kwargs
)
return
None
mindspore/nn/probability/distribution/distribution.py
浏览文件 @
22598e5c
...
...
@@ -27,7 +27,7 @@ class Distribution(Cell):
Note:
Derived class should override operations such as ,_mean, _prob,
and _log_prob. Arguments should be passed in through *args.
and _log_prob. Arguments should be passed in through *args
or **kwargs
.
Dist_spec_args are unique for each type of distribution. For example, mean and sd
are the dist_spec_args for a Normal distribution.
...
...
@@ -171,7 +171,7 @@ class Distribution(Cell):
if
hasattr
(
self
,
'_cross_entropy'
):
self
.
_call_cross_entropy
=
self
.
_cross_entropy
def
log_prob
(
self
,
*
args
):
def
log_prob
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the log probability(pdf or pmf) at the given value.
...
...
@@ -179,18 +179,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return
self
.
_call_log_prob
(
*
args
)
return
self
.
_call_log_prob
(
*
args
,
**
kwargs
)
def
_calc_prob_from_log_prob
(
self
,
*
args
):
def
_calc_prob_from_log_prob
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate prob from log probability.
.. math::
probability(x) = \exp(log_likehood(x))
"""
return
self
.
exp
(
self
.
_log_prob
(
*
args
))
return
self
.
exp
(
self
.
_log_prob
(
*
args
,
**
kwargs
))
def
prob
(
self
,
*
args
):
def
prob
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the probability (pdf or pmf) at given value.
...
...
@@ -198,18 +198,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return
self
.
_call_prob
(
*
args
)
return
self
.
_call_prob
(
*
args
,
**
kwargs
)
def
_calc_log_prob_from_prob
(
self
,
*
args
):
def
_calc_log_prob_from_prob
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate log probability from probability.
.. math::
log_prob(x) = \log(prob(x))
"""
return
self
.
log
(
self
.
_prob
(
*
args
))
return
self
.
log
(
self
.
_prob
(
*
args
,
**
kwargs
))
def
cdf
(
self
,
*
args
):
def
cdf
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the cdf at given value.
...
...
@@ -217,36 +217,36 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return
self
.
_call_cdf
(
*
args
)
return
self
.
_call_cdf
(
*
args
,
**
kwargs
)
def
_calc_cdf_from_log_cdf
(
self
,
*
args
):
def
_calc_cdf_from_log_cdf
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate cdf from log_cdf.
.. math::
cdf(x) = \exp(log_cdf(x))
"""
return
self
.
exp
(
self
.
_log_cdf
(
*
args
))
return
self
.
exp
(
self
.
_log_cdf
(
*
args
,
**
kwargs
))
def
_calc_cdf_from_survival
(
self
,
*
args
):
def
_calc_cdf_from_survival
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate cdf from survival function.
.. math::
cdf(x) = 1 - (survival_function(x))
"""
return
1.0
-
self
.
_survival_function
(
*
args
)
return
1.0
-
self
.
_survival_function
(
*
args
,
**
kwargs
)
def
_calc_cdf_from_log_survival
(
self
,
*
args
):
def
_calc_cdf_from_log_survival
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate cdf from log survival function.
.. math::
cdf(x) = 1 - (\exp(log_survival(x)))
"""
return
1.0
-
self
.
exp
(
self
.
_log_survival
(
*
args
))
return
1.0
-
self
.
exp
(
self
.
_log_survival
(
*
args
,
**
kwargs
))
def
log_cdf
(
self
,
*
args
):
def
log_cdf
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the log cdf at given value.
...
...
@@ -254,18 +254,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return
self
.
_call_log_cdf
(
*
args
)
return
self
.
_call_log_cdf
(
*
args
,
**
kwargs
)
def
_calc_log_cdf_from_call_cdf
(
self
,
*
args
):
def
_calc_log_cdf_from_call_cdf
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate log cdf from cdf.
.. math::
log_cdf(x) = \log(cdf(x))
"""
return
self
.
log
(
self
.
_call_cdf
(
*
args
))
return
self
.
log
(
self
.
_call_cdf
(
*
args
,
**
kwargs
))
def
survival_function
(
self
,
*
args
):
def
survival_function
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the survival function at given value.
...
...
@@ -273,27 +273,27 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return
self
.
_call_survival
(
*
args
)
return
self
.
_call_survival
(
*
args
,
**
kwargs
)
def
_calc_survival_from_call_cdf
(
self
,
*
args
):
def
_calc_survival_from_call_cdf
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate survival function from cdf.
.. math::
survival_function(x) = 1 - (cdf(x))
"""
return
1.0
-
self
.
_call_cdf
(
*
args
)
return
1.0
-
self
.
_call_cdf
(
*
args
,
**
kwargs
)
def
_calc_survival_from_log_survival
(
self
,
*
args
):
def
_calc_survival_from_log_survival
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate survival function from log survival function.
.. math::
survival(x) = \exp(survival_function(x))
"""
return
self
.
exp
(
self
.
_log_survival
(
*
args
))
return
self
.
exp
(
self
.
_log_survival
(
*
args
,
**
kwargs
))
def
log_survival
(
self
,
*
args
):
def
log_survival
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the log survival function at given value.
...
...
@@ -301,18 +301,18 @@ class Distribution(Cell):
Args must include value.
Dist_spec_args are optional.
"""
return
self
.
_call_log_survival
(
*
args
)
return
self
.
_call_log_survival
(
*
args
,
**
kwargs
)
def
_calc_log_survival_from_call_survival
(
self
,
*
args
):
def
_calc_log_survival_from_call_survival
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate log survival function from survival function.
.. math::
log_survival(x) = \log(survival_function(x))
"""
return
self
.
log
(
self
.
_call_survival
(
*
args
))
return
self
.
log
(
self
.
_call_survival
(
*
args
,
**
kwargs
))
def
kl_loss
(
self
,
*
args
):
def
kl_loss
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the KL divergence, i.e. KL(a||b).
...
...
@@ -320,72 +320,72 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return
self
.
_kl_loss
(
*
args
)
return
self
.
_kl_loss
(
*
args
,
**
kwargs
)
def
mean
(
self
,
*
args
):
def
mean
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the mean.
Note:
Dist_spec_args are optional.
"""
return
self
.
_mean
(
*
args
)
return
self
.
_mean
(
*
args
,
**
kwargs
)
def
mode
(
self
,
*
args
):
def
mode
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the mode.
Note:
Dist_spec_args are optional.
"""
return
self
.
_mode
(
*
args
)
return
self
.
_mode
(
*
args
,
**
kwargs
)
def
sd
(
self
,
*
args
):
def
sd
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the standard deviation.
Note:
Dist_spec_args are optional.
"""
return
self
.
_call_sd
(
*
args
)
return
self
.
_call_sd
(
*
args
,
**
kwargs
)
def
var
(
self
,
*
args
):
def
var
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the variance.
Note:
Dist_spec_args are optional.
"""
return
self
.
_call_var
(
*
args
)
return
self
.
_call_var
(
*
args
,
**
kwargs
)
def
_calc_sd_from_var
(
self
,
*
args
):
def
_calc_sd_from_var
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate log probability from probability.
.. math::
STD(x) = \sqrt(VAR(x))
"""
return
self
.
sqrt
(
self
.
_var
(
*
args
))
return
self
.
sqrt
(
self
.
_var
(
*
args
,
**
kwargs
))
def
_calc_var_from_sd
(
self
,
*
args
):
def
_calc_var_from_sd
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate log probability from probability.
.. math::
VAR(x) = STD(x) ^ 2
"""
return
self
.
sq
(
self
.
_sd
(
*
args
))
return
self
.
sq
(
self
.
_sd
(
*
args
,
**
kwargs
))
def
entropy
(
self
,
*
args
):
def
entropy
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the entropy.
Note:
Dist_spec_args are optional.
"""
return
self
.
_entropy
(
*
args
)
return
self
.
_entropy
(
*
args
,
**
kwargs
)
def
cross_entropy
(
self
,
*
args
):
def
cross_entropy
(
self
,
*
args
,
**
kwargs
):
"""
Evaluate the cross_entropy between distribution a and b.
...
...
@@ -393,32 +393,29 @@ class Distribution(Cell):
Args must include type of the distribution, parameters of distribution b.
Parameters for distribution a are optional.
"""
return
self
.
_call_cross_entropy
(
*
args
)
return
self
.
_call_cross_entropy
(
*
args
,
**
kwargs
)
def
_calc_cross_entropy
(
self
,
*
args
):
def
_calc_cross_entropy
(
self
,
*
args
,
**
kwargs
):
r
"""
Evaluate cross_entropy from entropy and kl divergence.
.. math::
H(X, Y) = H(X) + KL(X||Y)
"""
return
self
.
_entropy
(
*
args
)
+
self
.
_kl_loss
(
*
args
)
return
self
.
_entropy
(
*
args
,
**
kwargs
)
+
self
.
_kl_loss
(
*
args
,
**
kw
args
)
def
sample
(
self
,
*
args
):
def
sample
(
self
,
*
args
,
**
kwargs
):
"""
Sampling function.
Args:
*args (list): arguments passed in through construct.
Note:
Shape of the sample is default to ().
Dist_spec_args are optional.
"""
return
self
.
_sample
(
*
args
)
return
self
.
_sample
(
*
args
,
**
kwargs
)
def
construct
(
self
,
name
,
*
args
):
def
construct
(
self
,
name
,
*
args
,
**
kwargs
):
"""
Override construct in Cell.
...
...
@@ -433,31 +430,31 @@ class Distribution(Cell):
"""
if
name
==
'log_prob'
:
return
self
.
_call_log_prob
(
*
args
)
return
self
.
_call_log_prob
(
*
args
,
**
kwargs
)
if
name
==
'prob'
:
return
self
.
_call_prob
(
*
args
)
return
self
.
_call_prob
(
*
args
,
**
kwargs
)
if
name
==
'cdf'
:
return
self
.
_call_cdf
(
*
args
)
return
self
.
_call_cdf
(
*
args
,
**
kwargs
)
if
name
==
'log_cdf'
:
return
self
.
_call_log_cdf
(
*
args
)
return
self
.
_call_log_cdf
(
*
args
,
**
kwargs
)
if
name
==
'survival_function'
:
return
self
.
_call_survival
(
*
args
)
return
self
.
_call_survival
(
*
args
,
**
kwargs
)
if
name
==
'log_survival'
:
return
self
.
_call_log_survival
(
*
args
)
return
self
.
_call_log_survival
(
*
args
,
**
kwargs
)
if
name
==
'kl_loss'
:
return
self
.
_kl_loss
(
*
args
)
return
self
.
_kl_loss
(
*
args
,
**
kwargs
)
if
name
==
'mean'
:
return
self
.
_mean
(
*
args
)
return
self
.
_mean
(
*
args
,
**
kwargs
)
if
name
==
'mode'
:
return
self
.
_mode
(
*
args
)
return
self
.
_mode
(
*
args
,
**
kwargs
)
if
name
==
'sd'
:
return
self
.
_call_sd
(
*
args
)
return
self
.
_call_sd
(
*
args
,
**
kwargs
)
if
name
==
'var'
:
return
self
.
_call_var
(
*
args
)
return
self
.
_call_var
(
*
args
,
**
kwargs
)
if
name
==
'entropy'
:
return
self
.
_entropy
(
*
args
)
return
self
.
_entropy
(
*
args
,
**
kwargs
)
if
name
==
'cross_entropy'
:
return
self
.
_call_cross_entropy
(
*
args
)
return
self
.
_call_cross_entropy
(
*
args
,
**
kwargs
)
if
name
==
'sample'
:
return
self
.
_sample
(
*
args
)
return
self
.
_sample
(
*
args
,
**
kwargs
)
return
None
mindspore/nn/probability/distribution/normal.py
浏览文件 @
22598e5c
...
...
@@ -256,8 +256,5 @@ class Normal(Distribution):
sd
=
self
.
_sd_value
if
sd
is
None
else
sd
batch_shape
=
self
.
shape
(
self
.
zeroslike
(
mean
)
+
self
.
zeroslike
(
sd
))
sample_shape
=
shape
+
batch_shape
mean_zero
=
self
.
const
(
0.0
)
sd_one
=
self
.
const
(
1.0
)
sample_norm
=
C
.
normal
(
sample_shape
,
mean_zero
,
sd_one
,
self
.
seed
)
sample
=
mean
+
sample_norm
*
sd
return
sample
sample_norm
=
C
.
normal
(
sample_shape
,
mean
,
sd
,
self
.
seed
)
return
sample_norm
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录