Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
48735c25
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
48735c25
编写于
5月 05, 2020
作者:
C
changzherui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify applr momentum output 2
上级
e01692ad
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
16 addition
and
9 deletion
+16
-9
mindspore/ccsrc/device/kernel_runtime.cc
mindspore/ccsrc/device/kernel_runtime.cc
+1
-0
mindspore/ops/_op_impl/tbe/apply_momentum.py
mindspore/ops/_op_impl/tbe/apply_momentum.py
+10
-9
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+5
-0
未找到文件。
mindspore/ccsrc/device/kernel_runtime.cc
浏览文件 @
48735c25
...
...
@@ -201,6 +201,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
if
(
AnfAlgo
::
GetCNodeName
(
kernel
)
==
"ApplyMomentum"
)
{
auto
device_address
=
AnfAlgo
::
GetPrevNodeMutableOutputAddr
(
kernel
,
0
);
AnfAlgo
::
SetOutputAddr
(
device_address
,
0
,
kernel
.
get
());
AnfAlgo
::
SetOutputAddr
(
device_address
,
1
,
kernel
.
get
());
return
;
}
...
...
mindspore/ops/_op_impl/tbe/apply_momentum.py
浏览文件 @
48735c25
...
...
@@ -29,23 +29,24 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
.
input
(
2
,
"lr"
,
False
,
"required"
,
"all"
)
\
.
input
(
3
,
"grad"
,
False
,
"required"
,
"all"
)
\
.
input
(
4
,
"momentum"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"out"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"accum"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_Default
,
DataType
.
F16_5HD
,
DataType
.
F16_Default
,
DataType
.
F16_5HD
)
\
DataType
.
F16_Default
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_Default
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_Default
,
DataType
.
F16_C1HWNCoC0
)
\
DataType
.
F16_Default
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_Default
,
DataType
.
F16_FracZ
,
DataType
.
F16_Default
,
DataType
.
F16_FracZ
)
\
DataType
.
F16_Default
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_Default
,
DataType
.
F32_5HD
,
DataType
.
F32_Default
,
DataType
.
F32_5HD
)
\
DataType
.
F32_Default
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_Default
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_Default
,
DataType
.
F32_C1HWNCoC0
)
\
DataType
.
F32_Default
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_Default
,
DataType
.
F32_FracZ
,
DataType
.
F32_Default
,
DataType
.
F32_FracZ
)
\
DataType
.
F32_Default
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
)
\
.
get_op_info
()
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
48735c25
...
...
@@ -1427,8 +1427,11 @@ class ApplyMomentum(PrimitiveWithInfer):
def
__init__
(
self
,
use_nesterov
=
False
,
use_locking
=
False
,
gradient_scale
=
1.0
):
self
.
init_prim_io_names
(
inputs
=
[
'variable'
,
'accumulation'
,
'learning_rate'
,
'gradient'
,
'momentum'
],
outputs
=
[
'output'
])
self
.
is_tbe
=
context
.
get_context
(
"device_target"
)
==
"Ascend"
def
infer_shape
(
self
,
v_shape
,
a_shape
,
l_shape
,
g_shape
,
m_shape
):
if
self
.
is_tbe
:
return
v_shape
,
v_shape
return
v_shape
def
infer_dtype
(
self
,
v_dtype
,
a_dtype
,
l_dtype
,
g_dtype
,
m_dtype
):
...
...
@@ -1439,6 +1442,8 @@ class ApplyMomentum(PrimitiveWithInfer):
validator
.
check_scalar_or_tensor_type_same
({
"l_dtype"
:
l_dtype
},
valid_types
,
self
.
name
)
validator
.
check_scalar_or_tensor_type_same
({
"g_dtype"
:
g_dtype
},
valid_types
,
self
.
name
)
validator
.
check_scalar_or_tensor_type_same
({
"m_dtype"
:
m_dtype
},
valid_types
,
self
.
name
)
if
self
.
is_tbe
:
return
g_dtype
,
g_dtype
return
g_dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录