Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_44025039
mindspore
提交
517fed55
M
mindspore
项目概览
weixin_44025039
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
517fed55
编写于
8月 04, 2020
作者:
P
peixu_ren
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added LGamma op
上级
1694c882
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
138 addition
and
2 deletion
+138
-2
mindspore/nn/layer/math.py
mindspore/nn/layer/math.py
+134
-1
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+4
-1
未找到文件。
mindspore/nn/layer/math.py
浏览文件 @
517fed55
...
...
@@ -14,16 +14,18 @@
# ============================================================================
"""math"""
import
math
import
numpy
as
np
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.operations
import
_inner_ops
as
inner
from
mindspore.common.tensor
import
Tensor
from
mindspore.ops.primitive
import
constexpr
from
..cell
import
Cell
from
...common
import
dtype
as
mstype
from
..._checkparam
import
Validator
as
validator
from
..._checkparam
import
Rel
__all__
=
[
'ReduceLogSumExp'
,
'Range'
,
'LinSpace'
]
__all__
=
[
'ReduceLogSumExp'
,
'Range'
,
'LinSpace'
,
'LGamma'
]
class
ReduceLogSumExp
(
Cell
):
...
...
@@ -169,3 +171,134 @@ class LinSpace(Cell):
lin_space_out
=
self
.
lin_space
(
self
.
assist
,
self
.
start
,
self
.
stop
,
self
.
num
)
return
lin_space_out
@
constexpr
def
check_tensors_dtype_same
(
data_dtype
,
value_dtype
,
op_name
):
"""Check tensors data type same."""
if
data_dtype
in
value_dtype
:
return
True
raise
TypeError
(
f
"For '
{
op_name
}
', the value data type '
{
value_dtype
}
' "
f
"is not consistent with assigned tensor data type
{
data_dtype
}
."
)
class
LGamma
(
Cell
):
r
"""
Calculate LGamma using Lanczos' approximation refering to "A Precision Approximationof the Gamma Function".
The algorithm is:
.. math::
lgamma(z + 1) = \frac{(\log(2) + \log(pi))}{2} + (z + 1/2) * log(t(z)) - t(z) + A(z)
t(z) = z + kLanczosGamma + 1/2
A(z) = kBaseLanczosCoeff + \sum_{k=1}^n \frac{kLanczosCoefficients[i]}{z + k}
However, if the input is less than 0.5 use Euler's reflection formula:
.. math::
lgamma(x) = \log(pi) - lgamma(1-x) - \log(abs(sin(pi * x)))
And please note that
.. math::
lgamma(+/-inf) = +inf
Thus, the behaviour of LGamma follows:
when x > 0.5, return log(Gamma(x))
when x < 0.5 and is not an interger, return the real part of Log(Gamma(x)) where Log is the complex logarithm
when x is an integer less or equal to 0, return +inf
when x = +/- inf, return +inf
Inputs:
- **input_x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.
Outputs:
Tensor, has the same shape and dtype as the 'input_x'.
Examples:
>>> input_x = Tensor(np.array(2, 3, 4).astype(np.float32))
>>> op = nn.LGamma()
>>> output = op(input_x)
"""
def
__init__
(
self
):
super
(
LGamma
,
self
).
__init__
()
# const numbers
self
.
k_lanczos_gamma
=
7
self
.
k_base_lanczos_coeff
=
0.99999999999980993227684700473478
self
.
k_lanczos_coefficients
=
[
676.520368121885098567009190444019
,
-
1259.13921672240287047156078755283
,
771.3234287776530788486528258894
,
-
176.61502916214059906584551354
,
12.507343278686904814458936853
,
-
0.13857109526572011689554707
,
9.984369578019570859563e-6
,
1.50563273514931155834e-7
]
self
.
one_half
=
0.5
self
.
one
=
1
self
.
two
=
2
self
.
inf
=
np
.
inf
self
.
pi
=
np
.
pi
self
.
log_2
=
np
.
log
(
self
.
two
)
self
.
log_pi
=
np
.
log
(
np
.
pi
)
self
.
log_sqrt_two_pi
=
(
self
.
log_2
+
self
.
log_pi
)
/
self
.
two
self
.
lanczos_gamma_plus_one_half
=
self
.
k_lanczos_gamma
+
0.5
self
.
log_lanczos_gamma_plus_one_half
=
np
.
log
(
self
.
lanczos_gamma_plus_one_half
)
# operations
self
.
log
=
P
.
Log
()
self
.
log1p
=
P
.
Log1p
()
self
.
abs
=
P
.
Abs
()
self
.
shape
=
P
.
Shape
()
self
.
dtype
=
P
.
DType
()
self
.
fill
=
P
.
Fill
()
self
.
floor
=
P
.
Floor
()
self
.
equal
=
P
.
Equal
()
self
.
greater
=
P
.
Greater
()
self
.
less
=
P
.
Less
()
self
.
lessequal
=
P
.
LessEqual
()
self
.
select
=
P
.
Select
()
self
.
sin
=
P
.
Sin
()
self
.
isfinite
=
P
.
IsFinite
()
def
construct
(
self
,
input_x
):
input_dtype
=
self
.
dtype
(
input_x
)
check_tensors_dtype_same
(
input_dtype
,
[
mstype
.
float16
,
mstype
.
float32
],
"LGamma"
)
infinity
=
self
.
fill
(
input_dtype
,
self
.
shape
(
input_x
),
self
.
inf
)
need_to_reflect
=
self
.
less
(
input_x
,
0.5
)
neg_input
=
-
input_x
z
=
self
.
select
(
need_to_reflect
,
neg_input
,
input_x
-
1
)
@
constexpr
def
_calculate_x
(
z
,
k_base_lanczos_coeff
,
k_lanczos_coefficients
):
x
=
k_base_lanczos_coeff
for
i
in
range
(
8
):
product_
=
k_lanczos_coefficients
[
i
]
/
(
z
+
i
+
1
)
x
=
product_
+
x
return
x
x
=
_calculate_x
(
z
,
self
.
k_base_lanczos_coeff
,
self
.
k_lanczos_coefficients
)
t
=
z
+
self
.
lanczos_gamma_plus_one_half
log_t
=
self
.
log1p
(
z
/
self
.
lanczos_gamma_plus_one_half
)
+
self
.
log_lanczos_gamma_plus_one_half
log_y
=
self
.
log
(
x
)
+
(
z
+
self
.
one_half
-
t
/
log_t
)
*
log_t
+
self
.
log_sqrt_two_pi
abs_input
=
self
.
abs
(
input_x
)
abs_frac_input
=
abs_input
-
self
.
floor
(
abs_input
)
input_x
=
self
.
select
(
self
.
lessequal
(
input_x
,
0.0
),
self
.
select
(
self
.
equal
(
abs_frac_input
,
0.0
),
infinity
,
input_x
),
input_x
)
reduced_frac_input
=
self
.
select
(
self
.
greater
(
abs_frac_input
,
0.5
),
1
-
abs_frac_input
,
abs_frac_input
)
reflection_denom
=
self
.
log
(
self
.
sin
(
self
.
pi
*
reduced_frac_input
))
reflection
=
self
.
select
(
self
.
isfinite
(
reflection_denom
),
-
reflection_denom
-
log_y
+
self
.
log_pi
,
-
reflection_denom
)
result
=
self
.
select
(
need_to_reflect
,
reflection
,
log_y
)
return
self
.
select
(
self
.
isfinite
(
input_x
),
result
,
infinity
)
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
517fed55
...
...
@@ -584,7 +584,10 @@ test_cases = [
(
'ReduceLogSumExp'
,
{
'block'
:
nn
.
ReduceLogSumExp
((
0
,),
False
),
'desc_inputs'
:
[
Tensor
(
np
.
array
([
3
,
4
,
5
,
6
]).
astype
(
np
.
float32
))],
'desc_bprop'
:
[
Tensor
(
np
.
array
([
1
,
2
,
3
,
4
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
(
'LGamma'
,
{
'block'
:
nn
.
LGamma
(),
'desc_inputs'
:
[
Tensor
(
np
.
array
([
3
,
4
,
5
,
6
]).
astype
(
np
.
float32
))],
'skip'
:
[
'backward'
]}),
(
'FlattenNet'
,
{
'block'
:
FlattenNet
(),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录