Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4aa339cb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4aa339cb
编写于
8月 20, 2020
作者:
P
peixu_ren
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change the interfaces in trasformation base class
上级
5a0fe979
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
32 addition
and
15 deletion
+32
-15
mindspore/nn/probability/distribution/_utils/utils.py
mindspore/nn/probability/distribution/_utils/utils.py
+4
-0
mindspore/nn/probability/distribution/transformed_distribution.py
...e/nn/probability/distribution/transformed_distribution.py
+28
-15
未找到文件。
mindspore/nn/probability/distribution/_utils/utils.py
浏览文件 @
4aa339cb
...
...
@@ -272,6 +272,10 @@ def check_type(data_type, value_type, name):
def
raise_none_error
(
name
):
raise
ValueError
(
f
"
{
name
}
should be specified. Value cannot be None"
)
@
constexpr
def
raise_not_impl_error
(
name
):
raise
ValueError
(
f
"
{
name
}
function should be implemented for non-linear transformation"
)
@
constexpr
def
check_distribution_name
(
name
,
expected_name
):
if
name
!=
expected_name
:
...
...
mindspore/nn/probability/distribution/transformed_distribution.py
浏览文件 @
4aa339cb
...
...
@@ -18,7 +18,7 @@ from mindspore._checkparam import Validator as validator
from
mindspore.common
import
dtype
as
mstype
import
mindspore.nn
as
nn
from
.distribution
import
Distribution
from
._utils.utils
import
check_type
from
._utils.utils
import
check_type
,
raise_not_impl_error
class
TransformedDistribution
(
Distribution
):
"""
...
...
@@ -56,6 +56,7 @@ class TransformedDistribution(Distribution):
self
.
_distribution
=
distribution
self
.
_is_linear_transformation
=
bijector
.
is_constant_jacobian
self
.
exp
=
P
.
Exp
()
self
.
log
=
P
.
Log
()
@
property
def
bijector
(
self
):
...
...
@@ -69,37 +70,49 @@ class TransformedDistribution(Distribution):
def
is_linear_transformation
(
self
):
return
self
.
_is_linear_transformation
def
_cdf
(
self
,
value
):
def
_cdf
(
self
,
*
args
,
**
kwargs
):
r
"""
.. math::
Y = g(X)
P(Y <= a) = P(X <= g^{-1}(a))
"""
inverse_value
=
self
.
bijector
.
inverse
(
value
)
return
self
.
distribution
.
cdf
(
inverse_value
)
inverse_value
=
self
.
bijector
(
"inverse"
,
*
args
,
**
kwargs
)
return
self
.
distribution
(
"cdf"
,
inverse_value
)
def
_log_prob
(
self
,
value
):
def
_log_cdf
(
self
,
*
args
,
**
kwargs
):
return
self
.
log
(
self
.
_cdf
(
*
args
,
**
kwargs
))
def
_survival_function
(
self
,
*
args
,
**
kwargs
):
return
1.0
-
self
.
_cdf
(
*
args
,
**
kwargs
)
def
_log_survival
(
self
,
*
args
,
**
kwargs
):
return
self
.
log
(
self
.
_survival_function
(
*
args
,
**
kwargs
))
def
_log_prob
(
self
,
*
args
,
**
kwargs
):
r
"""
.. math::
Y = g(X)
Py(a) = Px(g^{-1}(a)) * (g^{-1})'(a)
\log(Py(a)) = \log(Px(g^{-1}(a))) + \log((g^{-1})'(a))
"""
inverse_value
=
self
.
bijector
.
inverse
(
value
)
unadjust_prob
=
self
.
distribution
.
log_prob
(
inverse_value
)
log_jacobian
=
self
.
bijector
.
inverse_log_jacobian
(
value
)
inverse_value
=
self
.
bijector
(
"inverse"
,
*
args
,
**
kwargs
)
unadjust_prob
=
self
.
distribution
(
"log_prob"
,
inverse_value
)
log_jacobian
=
self
.
bijector
(
"inverse_log_jacobian"
,
*
args
,
**
kwargs
)
return
unadjust_prob
+
log_jacobian
def
_prob
(
self
,
value
):
return
self
.
exp
(
self
.
_log_prob
(
value
))
def
_prob
(
self
,
*
args
,
**
kwargs
):
return
self
.
exp
(
self
.
_log_prob
(
*
args
,
**
kwargs
))
def
_sample
(
self
,
shape
):
org_sample
=
self
.
distribution
.
sample
(
shape
)
return
self
.
bijector
.
forward
(
org_sample
)
def
_sample
(
self
,
*
args
,
**
kwargs
):
org_sample
=
self
.
distribution
(
"sample"
,
shape
)
return
self
.
bijector
(
"forward"
,
org_sample
)
def
_mean
(
self
):
def
_mean
(
self
,
*
args
,
**
kwargs
):
"""
Note:
This function maybe overridden by derived class.
"""
return
self
.
bijector
.
forward
(
self
.
distribution
.
mean
())
if
not
self
.
is_linear_transformation
:
raise_not_impl_error
(
mean
)
return
self
.
bijector
(
"forward"
,
self
.
distribution
(
"mean"
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录