Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
93d9cc3c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93d9cc3c
编写于
8月 24, 2020
作者:
P
peixu_ren
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add erf and erfc as generic functions for all the backend and fix notation in power_transform.
上级
c165a6d0
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
203 addition
and
44 deletion
+203
-44
mindspore/nn/probability/bijector/power_transform.py
mindspore/nn/probability/bijector/power_transform.py
+6
-6
mindspore/nn/probability/bijector/scalar_affine.py
mindspore/nn/probability/bijector/scalar_affine.py
+2
-2
mindspore/nn/probability/bijector/softplus.py
mindspore/nn/probability/bijector/softplus.py
+4
-4
mindspore/nn/probability/distribution/_utils/__init__.py
mindspore/nn/probability/distribution/_utils/__init__.py
+6
-4
mindspore/nn/probability/distribution/_utils/custom_ops.py
mindspore/nn/probability/distribution/_utils/custom_ops.py
+164
-7
mindspore/nn/probability/distribution/bernoulli.py
mindspore/nn/probability/distribution/bernoulli.py
+4
-4
mindspore/nn/probability/distribution/exponential.py
mindspore/nn/probability/distribution/exponential.py
+3
-3
mindspore/nn/probability/distribution/geometric.py
mindspore/nn/probability/distribution/geometric.py
+3
-3
mindspore/nn/probability/distribution/normal.py
mindspore/nn/probability/distribution/normal.py
+5
-5
mindspore/nn/probability/distribution/transformed_distribution.py
...e/nn/probability/distribution/transformed_distribution.py
+3
-3
mindspore/nn/probability/distribution/uniform.py
mindspore/nn/probability/distribution/uniform.py
+3
-3
未找到文件。
mindspore/nn/probability/bijector/power_transform.py
浏览文件 @
93d9cc3c
...
@@ -17,14 +17,14 @@ from mindspore.ops import operations as P
...
@@ -17,14 +17,14 @@ from mindspore.ops import operations as P
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Rel
from
mindspore._checkparam
import
Rel
from
..distribution._utils.utils
import
CheckTensor
from
..distribution._utils.utils
import
CheckTensor
from
..distribution._utils.custom_ops
import
exp_
by_step
,
expm1_by_step
,
log_by_step
,
log1p_by_step
from
..distribution._utils.custom_ops
import
exp_
generic
,
expm1_generic
,
log_generic
,
log1p_generic
from
.bijector
import
Bijector
from
.bijector
import
Bijector
class
PowerTransform
(
Bijector
):
class
PowerTransform
(
Bijector
):
r
"""
r
"""
Power Bijector.
Power Bijector.
This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c
is
power.
This Bijector performs the operation: Y = g(X) = (1 + X * c)^(1 / c), X >= -1 / c, where c
>= 0 is the
power.
The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
The power transform maps inputs from `[-1/c, inf]` to `[0, inf]`.
...
@@ -61,10 +61,10 @@ class PowerTransform(Bijector):
...
@@ -61,10 +61,10 @@ class PowerTransform(Bijector):
validator
.
check_number
(
"power"
,
power
,
0
,
Rel
.
GE
,
self
.
name
)
validator
.
check_number
(
"power"
,
power
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
_power
=
power
self
.
_power
=
power
self
.
pow
=
P
.
Pow
()
self
.
pow
=
P
.
Pow
()
self
.
exp
=
exp_
by_step
self
.
exp
=
exp_
generic
self
.
expm1
=
expm1_
by_step
self
.
expm1
=
expm1_
generic
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
self
.
log1p
=
log1p_
by_step
self
.
log1p
=
log1p_
generic
self
.
checktensor
=
CheckTensor
()
self
.
checktensor
=
CheckTensor
()
...
...
mindspore/nn/probability/bijector/scalar_affine.py
浏览文件 @
93d9cc3c
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.custom_ops
import
log_
by_step
from
..distribution._utils.custom_ops
import
log_
generic
from
.bijector
import
Bijector
from
.bijector
import
Bijector
...
@@ -69,7 +69,7 @@ class ScalarAffine(Bijector):
...
@@ -69,7 +69,7 @@ class ScalarAffine(Bijector):
param
=
param
)
param
=
param
)
self
.
abs
=
P
.
Abs
()
self
.
abs
=
P
.
Abs
()
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
self
.
checktensor
=
CheckTensor
()
self
.
checktensor
=
CheckTensor
()
...
...
mindspore/nn/probability/bijector/softplus.py
浏览文件 @
93d9cc3c
...
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
...
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from
mindspore.nn.layer.activation
import
LogSigmoid
from
mindspore.nn.layer.activation
import
LogSigmoid
from
mindspore._checkparam
import
Validator
as
validator
from
mindspore._checkparam
import
Validator
as
validator
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.utils
import
cast_to_tensor
,
CheckTensor
from
..distribution._utils.custom_ops
import
exp_
by_step
,
expm1_by_step
,
log_by_step
from
..distribution._utils.custom_ops
import
exp_
generic
,
expm1_generic
,
log_generic
from
.bijector
import
Bijector
from
.bijector
import
Bijector
...
@@ -61,9 +61,9 @@ class Softplus(Bijector):
...
@@ -61,9 +61,9 @@ class Softplus(Bijector):
super
(
Softplus
,
self
).
__init__
(
name
=
name
,
param
=
param
)
super
(
Softplus
,
self
).
__init__
(
name
=
name
,
param
=
param
)
self
.
_sharpness
=
cast_to_tensor
(
sharpness
)
self
.
_sharpness
=
cast_to_tensor
(
sharpness
)
self
.
exp
=
exp_
by_step
self
.
exp
=
exp_
generic
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
self
.
expm1
=
expm1_
by_step
self
.
expm1
=
expm1_
generic
self
.
abs
=
P
.
Abs
()
self
.
abs
=
P
.
Abs
()
self
.
fill
=
P
.
Fill
()
self
.
fill
=
P
.
Fill
()
self
.
greater
=
P
.
Greater
()
self
.
greater
=
P
.
Greater
()
...
...
mindspore/nn/probability/distribution/_utils/__init__.py
浏览文件 @
93d9cc3c
...
@@ -28,8 +28,10 @@ __all__ = [
...
@@ -28,8 +28,10 @@ __all__ = [
'check_scalar_from_param'
,
'check_scalar_from_param'
,
'check_prob'
,
'check_prob'
,
'check_type'
,
'check_type'
,
'exp_by_step'
,
'exp_generic'
,
'expm1_by_step'
,
'expm1_generic'
,
'log_by_step'
,
'log_generic'
,
'log1p_by_step'
,
'log1p_generic'
,
'erf_generic'
,
'erfc_generic'
,
]
]
mindspore/nn/probability/distribution/_utils/custom_ops.py
浏览文件 @
93d9cc3c
...
@@ -17,8 +17,7 @@ import numpy as np
...
@@ -17,8 +17,7 @@ import numpy as np
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
def
exp_generic
(
input_x
):
def
exp_by_step
(
input_x
):
"""
"""
Log op on Ascend doesn't supprot int types.
Log op on Ascend doesn't supprot int types.
Fix this with casting the type.
Fix this with casting the type.
...
@@ -30,14 +29,14 @@ def exp_by_step(input_x):
...
@@ -30,14 +29,14 @@ def exp_by_step(input_x):
return
exp
(
input_x
)
return
exp
(
input_x
)
def
expm1_
by_step
(
input_x
):
def
expm1_
generic
(
input_x
):
"""
"""
Expm1 ops under GPU context.
Expm1 ops under GPU context.
"""
"""
return
exp_
by_step
(
input_x
)
-
1.0
return
exp_
generic
(
input_x
)
-
1.0
def
log_
by_step
(
input_x
):
def
log_
generic
(
input_x
):
"""
"""
Log op on Ascend is calculated as log(abs(x)).
Log op on Ascend is calculated as log(abs(x)).
Fix this with putting negative values as nan.
Fix this with putting negative values as nan.
...
@@ -63,8 +62,166 @@ def log_by_step(input_x):
...
@@ -63,8 +62,166 @@ def log_by_step(input_x):
return
select
(
neg_x
,
nan
,
result
)
return
select
(
neg_x
,
nan
,
result
)
def
log1p_
by_step
(
x
):
def
log1p_
generic
(
x
):
"""
"""
Log1p ops on GPU device or when device_target == GPU.
Log1p ops on GPU device or when device_target == GPU.
"""
"""
return
log_by_step
(
x
+
1.0
)
return
log_generic
(
x
+
1.0
)
def
_evaluate_polynomial
(
x
,
coefficients
):
poly
=
0
for
co
in
coefficients
:
poly
=
poly
*
x
+
co
return
poly
def
erf_f32_generic
(
x
):
"""
Calculate erf for dtype of f32
"""
k_erf_tcoefficient
=
[
+
7.853861353153693e-5
,
-
8.010193625184903e-4
,
+
5.188327685732524e-3
,
-
2.685381193529856e-2
,
+
1.128358514861418e-1
,
-
3.761262582423300e-1
,
+
1.128379165726710e+0
]
poly
=
_evaluate_polynomial
(
x
*
x
,
k_erf_tcoefficient
)
return
x
*
poly
def
erf_f64_generic
(
x
):
"""
Calculate erf for dtype of f64
"""
k_erf_tcoefficient
=
[
9.60497373987051638749e0
,
9.00260197203842689217e1
,
2.23200534594684319226e3
,
7.00332514112805075473e3
,
5.55923013010394962768e4
]
k_erf_ucoefficient
=
[
1.00000000000000000000e0
,
3.35617141647503099647e1
,
5.21357949780152679795e2
,
4.59432382970980127987e3
,
2.26290000613890934246e4
,
4.92673942608635921086e4
]
z
=
x
*
x
poly1
=
_evaluate_polynomial
(
z
,
k_erf_tcoefficient
)
poly2
=
_evaluate_polynomial
(
z
,
k_erf_ucoefficient
)
return
x
*
poly1
/
poly2
def
erfc_f32_generic
(
x
):
"""
Calculate erfc for dtype of f32
"""
k_maxlog
=
88.72283905206835
k_erfc_pcoefficient
=
[
+
2.326819970068386e-2
,
-
1.387039388740657e-1
,
+
3.687424674597105e-1
,
-
5.824733027278666e-1
,
+
6.210004621745983e-1
,
-
4.944515323274145e-1
,
+
3.404879937665872e-1
,
-
2.741127028184656e-1
,
+
5.638259427386472e-1
]
k_erfc_rcoefficient
=
[
-
1.047766399936249e+1
,
+
1.297719955372516e+1
,
-
7.495518717768503e+0
,
+
2.921019019210786e+0
,
-
1.015265279202700e+0
,
+
4.218463358204948e-1
,
-
2.820767439740514e-1
,
+
5.641895067754075e-1
]
abs_cal
=
P
.
Abs
()
select
=
P
.
Select
()
less
=
P
.
Less
()
fill
=
P
.
Fill
()
dtype
=
P
.
DType
()
shape
=
P
.
Shape
()
abs_x
=
abs_cal
(
x
)
z
=
exp_generic
(
-
x
*
x
)
q
=
1
/
abs_x
y
=
q
*
q
poly1
=
_evaluate_polynomial
(
y
,
k_erfc_pcoefficient
)
poly2
=
_evaluate_polynomial
(
y
,
k_erfc_rcoefficient
)
p
=
select
(
less
(
abs_x
,
2.0
),
poly1
,
poly2
)
y
=
z
*
q
*
p
zeros
=
fill
(
dtype
(
x
),
shape
(
x
),
0
)
y_clamp
=
select
(
less
(
z
,
-
k_maxlog
),
zeros
,
y
)
return
select
(
less
(
x
,
0
),
2.0
-
y_clamp
,
y_clamp
)
def
erfc_f64_generic
(
x
):
"""
Calculate erfc for dtype of f64
"""
k_maxlog
=
7.09782712893383996843e2
k_erfc_pcoefficient
=
[
2.46196981473530512524e-10
,
5.64189564831068821977e-1
,
7.46321056442269912687e0
,
4.86371970985681366614e1
,
1.96520832956077098242e2
,
5.26445194995477358631e2
,
9.34528527171957607540e2
,
1.02755188689515710272e3
,
5.57535335369399327526e2
]
k_erfc_qcoefficient
=
[
1.00000000000000000000e0
,
1.32281951154744992508e1
,
8.67072140885989742329e1
,
3.54937778887819891062e2
,
9.75708501743205489753e2
,
1.82390916687909736289e3
,
2.24633760818710981792e3
,
1.65666309194161350182e3
,
5.57535340817727675546e2
]
k_erfc_rcoefficient
=
[
5.64189583547755073984e-1
,
1.27536670759978104416e0
,
5.01905042251180477414e0
,
6.16021097993053585195e0
,
7.40974269950448939160e0
,
2.97886665372100240670e0
]
k_erfc_scoefficient
=
[
1.00000000000000000000e0
,
2.26052863220117276590e0
,
9.39603524938001434673e0
,
1.20489539808096656605e1
,
1.70814450747565897222e1
,
9.60896809063285878198e0
,
3.36907645100081516050e02
]
abs_cal
=
P
.
Abs
()
select
=
P
.
Select
()
less
=
P
.
Less
()
fill
=
P
.
Fill
()
dtype
=
P
.
DType
()
shape
=
P
.
Shape
()
abs_x
=
abs_cal
(
x
)
z
=
-
x
*
x
exp_z
=
exp_generic
(
z
)
temp1
=
exp_z
*
_evaluate_polynomial
(
abs_x
,
k_erfc_pcoefficient
)
/
_evaluate_polynomial
(
abs_x
,
k_erfc_qcoefficient
)
temp2
=
exp_z
*
_evaluate_polynomial
(
abs_x
,
k_erfc_rcoefficient
)
/
_evaluate_polynomial
(
abs_x
,
k_erfc_scoefficient
)
y
=
select
(
less
(
abs_x
,
8.0
),
temp1
,
temp2
)
zeros
=
fill
(
dtype
(
x
),
shape
(
x
),
0
)
y_clamp
=
select
(
less
(
z
,
k_maxlog
),
zeros
,
y
)
poly2
=
_evaluate_polynomial
(
y
,
k_erfc_rcoefficient
)
p
=
select
(
less
(
abs_x
,
2.0
),
poly1
,
poly2
)
y
=
z
*
q
*
p
zeros
=
fill
(
dtype
(
x
),
shape
(
x
),
0
)
y_clamp
=
select
(
less
(
z
,
-
k_maxlog
),
zeros
,
y
)
return
select
(
less
(
x
,
0
),
2.0
-
y_clamp
,
y_clamp
)
def
erfc_generic
(
x
):
select
=
P
.
Select
()
greater
=
P
.
Greater
()
abs_cal
=
P
.
Abs
()
return
select
(
greater
(
abs_cal
(
x
),
1
),
erfc_f32_generic
(
x
),
1
-
erf_f32_generic
(
x
))
def
erf_generic
(
x
):
select
=
P
.
Select
()
less
=
P
.
Less
()
abs_cal
=
P
.
Abs
()
return
select
(
less
(
abs_cal
(
x
),
1
),
erf_f32_generic
(
x
),
1
-
erfc_f32_generic
(
x
))
mindspore/nn/probability/distribution/bernoulli.py
浏览文件 @
93d9cc3c
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
...
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
raise_none_error
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
raise_none_error
from
._utils.custom_ops
import
exp_
by_step
,
log_by_step
from
._utils.custom_ops
import
exp_
generic
,
log_generic
,
erf_generic
class
Bernoulli
(
Distribution
):
class
Bernoulli
(
Distribution
):
...
@@ -109,13 +109,13 @@ class Bernoulli(Distribution):
...
@@ -109,13 +109,13 @@ class Bernoulli(Distribution):
self
.
_probs
=
probs
self
.
_probs
=
probs
# ops needed for the class
# ops needed for the class
self
.
exp
=
exp_by_step
self
.
exp
=
exp_generic
self
.
log
=
log_by_step
self
.
log
=
log_generic
self
.
erf
=
erf_generic
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
const
=
P
.
ScalarToArray
()
self
.
dtypeop
=
P
.
DType
()
self
.
dtypeop
=
P
.
DType
()
self
.
erf
=
P
.
Erf
()
self
.
floor
=
P
.
Floor
()
self
.
floor
=
P
.
Floor
()
self
.
fill
=
P
.
Fill
()
self
.
fill
=
P
.
Fill
()
self
.
less
=
P
.
Less
()
self
.
less
=
P
.
Less
()
...
...
mindspore/nn/probability/distribution/exponential.py
浏览文件 @
93d9cc3c
...
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
...
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
cast_to_tensor
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.custom_ops
import
exp_
by_step
,
log_by_step
from
._utils.custom_ops
import
exp_
generic
,
log_generic
class
Exponential
(
Distribution
):
class
Exponential
(
Distribution
):
"""
"""
...
@@ -112,8 +112,8 @@ class Exponential(Distribution):
...
@@ -112,8 +112,8 @@ class Exponential(Distribution):
self
.
minval
=
np
.
finfo
(
np
.
float
).
tiny
self
.
minval
=
np
.
finfo
(
np
.
float
).
tiny
# ops needed for the class
# ops needed for the class
self
.
exp
=
exp_
by_step
self
.
exp
=
exp_
generic
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
const
=
P
.
ScalarToArray
()
...
...
mindspore/nn/probability/distribution/geometric.py
浏览文件 @
93d9cc3c
...
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
...
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
cast_to_tensor
,
check_prob
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.custom_ops
import
exp_
by_step
,
log_by_step
from
._utils.custom_ops
import
exp_
generic
,
log_generic
class
Geometric
(
Distribution
):
class
Geometric
(
Distribution
):
...
@@ -114,8 +114,8 @@ class Geometric(Distribution):
...
@@ -114,8 +114,8 @@ class Geometric(Distribution):
self
.
minval
=
np
.
finfo
(
np
.
float
).
tiny
self
.
minval
=
np
.
finfo
(
np
.
float
).
tiny
# ops needed for the class
# ops needed for the class
self
.
exp
=
exp_
by_step
self
.
exp
=
exp_
generic
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
const
=
P
.
ScalarToArray
()
...
...
mindspore/nn/probability/distribution/normal.py
浏览文件 @
93d9cc3c
...
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
...
@@ -20,7 +20,7 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
convert_to_batch
,
check_greater_zero
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.custom_ops
import
exp_
by_step
,
expm1_by_step
,
log_by_step
from
._utils.custom_ops
import
exp_
generic
,
expm1_generic
,
log_generic
,
erf_generic
class
Normal
(
Distribution
):
class
Normal
(
Distribution
):
"""
"""
...
@@ -114,13 +114,13 @@ class Normal(Distribution):
...
@@ -114,13 +114,13 @@ class Normal(Distribution):
self
.
_sd_value
=
sd
self
.
_sd_value
=
sd
#ops needed for the class
#ops needed for the class
self
.
exp
=
exp_by_step
self
.
exp
=
exp_generic
self
.
expm1
=
expm1_by_step
self
.
expm1
=
expm1_generic
self
.
log
=
log_by_step
self
.
log
=
log_generic
self
.
erf
=
erf_generic
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
const
=
P
.
ScalarToArray
()
self
.
erf
=
P
.
Erf
()
self
.
fill
=
P
.
Fill
()
self
.
fill
=
P
.
Fill
()
self
.
shape
=
P
.
Shape
()
self
.
shape
=
P
.
Shape
()
self
.
sq
=
P
.
Square
()
self
.
sq
=
P
.
Square
()
...
...
mindspore/nn/probability/distribution/transformed_distribution.py
浏览文件 @
93d9cc3c
...
@@ -18,7 +18,7 @@ from mindspore.common import dtype as mstype
...
@@ -18,7 +18,7 @@ from mindspore.common import dtype as mstype
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
check_type
,
raise_not_impl_error
from
._utils.utils
import
check_type
,
raise_not_impl_error
from
._utils.custom_ops
import
exp_
by_step
,
log_by_step
from
._utils.custom_ops
import
exp_
generic
,
log_generic
class
TransformedDistribution
(
Distribution
):
class
TransformedDistribution
(
Distribution
):
"""
"""
...
@@ -55,8 +55,8 @@ class TransformedDistribution(Distribution):
...
@@ -55,8 +55,8 @@ class TransformedDistribution(Distribution):
self
.
_bijector
=
bijector
self
.
_bijector
=
bijector
self
.
_distribution
=
distribution
self
.
_distribution
=
distribution
self
.
_is_linear_transformation
=
bijector
.
is_constant_jacobian
self
.
_is_linear_transformation
=
bijector
.
is_constant_jacobian
self
.
exp
=
exp_
by_step
self
.
exp
=
exp_
generic
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
@
property
@
property
def
bijector
(
self
):
def
bijector
(
self
):
...
...
mindspore/nn/probability/distribution/uniform.py
浏览文件 @
93d9cc3c
...
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
...
@@ -19,7 +19,7 @@ from mindspore.common import dtype as mstype
from
.distribution
import
Distribution
from
.distribution
import
Distribution
from
._utils.utils
import
convert_to_batch
,
check_greater
,
check_type
,
check_distribution_name
,
\
from
._utils.utils
import
convert_to_batch
,
check_greater
,
check_type
,
check_distribution_name
,
\
raise_none_error
raise_none_error
from
._utils.custom_ops
import
exp_
by_step
,
log_by_step
from
._utils.custom_ops
import
exp_
generic
,
log_generic
class
Uniform
(
Distribution
):
class
Uniform
(
Distribution
):
"""
"""
...
@@ -113,8 +113,8 @@ class Uniform(Distribution):
...
@@ -113,8 +113,8 @@ class Uniform(Distribution):
self
.
_high
=
high
self
.
_high
=
high
# ops needed for the class
# ops needed for the class
self
.
exp
=
exp_
by_step
self
.
exp
=
exp_
generic
self
.
log
=
log_
by_step
self
.
log
=
log_
generic
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
squeeze
=
P
.
Squeeze
(
0
)
self
.
cast
=
P
.
Cast
()
self
.
cast
=
P
.
Cast
()
self
.
const
=
P
.
ScalarToArray
()
self
.
const
=
P
.
ScalarToArray
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录