Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
af85b2ce
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
af85b2ce
编写于
6月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1986 fixed validator for CumSum
Merge pull request !1986 from jiangjinsheng/issue_fix2
上级
961af9fe
91183182
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
6 addition
and
10 deletion
+6
-10
mindspore/ops/_grad/grad_math_ops.py
mindspore/ops/_grad/grad_math_ops.py
+4
-3
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+2
-4
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+0
-3
未找到文件。
mindspore/ops/_grad/grad_math_ops.py
浏览文件 @
af85b2ce
...
...
@@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self):
reciprocal
=
P
.
Reciprocal
()
cast
=
P
.
Cast
()
dtype
=
P
.
DType
()
abs_ops
=
P
.
Abs
()
def
bprop
(
x
,
out
,
dout
):
zeros
=
zeros_like
(
x
)
np_eps
=
const_utils
.
get_np_eps
(
dtype
(
x
))
eps
=
cast
(
np_eps
,
dtype
(
x
))
x_is_valid
=
less
(
eps
,
x
)
x_is_valid
=
less
(
eps
,
abs_ops
(
x
)
)
x_safe
=
select
(
x_is_valid
,
x
,
eps
+
zeros
)
tmp
=
bessel_i0e
(
x_safe
)
-
out
*
(
sign
(
x
)
+
reciprocal
(
x_safe
))
dx
=
select
(
x_is_valid
,
tmp
,
0.5
+
zeros
)
tmp
=
bessel_i0e
(
x_safe
)
-
out
*
(
sign
(
x
_safe
)
+
reciprocal
(
x_safe
))
dx
=
select
(
x_is_valid
,
tmp
,
cast
(
0.5
,
dtype
(
x
))
+
zeros
)
*
dout
return
(
dx
,)
return
bprop
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
af85b2ce
...
...
@@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer):
def
__infer__
(
self
,
x
,
axis
):
cls_name
=
self
.
name
x_shp
=
x
[
'shape'
]
if
axis
[
'value'
]
is
None
:
raise
ValueError
(
f
"For
{
self
.
name
}
, axis must be const."
)
validator
.
check_value_type
(
'axis'
,
axis
[
'value'
],
[
int
],
cls_name
)
valid_types
=
[
mstype
.
uint8
,
mstype
.
int8
,
mstype
.
int32
,
mstype
.
float16
,
mstype
.
float32
]
validator
.
check_tensor_type_same
({
'x'
:
x
[
'dtype'
]},
valid_types
,
cls_name
)
...
...
@@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer):
'dtype'
:
x
[
'dtype'
],
'value'
:
None
}
def
infer_value
(
self
,
x
,
axis
):
if
axis
is
None
:
raise
ValueError
(
f
"For
{
self
.
name
}
, axis must be const."
)
class
AddN
(
PrimitiveWithInfer
):
"""
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
af85b2ce
...
...
@@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
def
infer_value
(
self
,
var
,
mean_square
,
moment
,
learning_rate
,
grad
,
decay
,
momentum
,
epsilon
):
if
decay
is
None
or
momentum
is
None
or
epsilon
is
None
:
raise
ValueError
(
f
"For
{
self
.
name
}
, decay, momentum, epsilon must be const."
)
if
not
self
.
is_ge
and
self
.
is_d
:
return
None
,
None
,
None
return
None
class
ApplyCenteredRMSProp
(
PrimitiveWithInfer
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录