Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
380a57f3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
380a57f3
编写于
7月 15, 2020
作者:
L
liuxiao93
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adapt ApplyCenteredRmsProp.
上级
d89cedb9
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
39 addition
and
21 deletion
+39
-21
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc
+1
-0
mindspore/ccsrc/transform/graph_ir/convert.cc
mindspore/ccsrc/transform/graph_ir/convert.cc
+1
-1
mindspore/ccsrc/transform/graph_ir/op_declare.cc
mindspore/ccsrc/transform/graph_ir/op_declare.cc
+7
-6
mindspore/ccsrc/transform/graph_ir/op_declare.h
mindspore/ccsrc/transform/graph_ir/op_declare.h
+2
-2
mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py
mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py
+23
-12
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+5
-0
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc
浏览文件 @
380a57f3
...
...
@@ -81,6 +81,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{
"sparse_apply_proximal_adagrad"
,
"sparse_apply_proximal_adagrad_d"
},
{
"apply_add_sign"
,
"apply_add_sign_d"
},
{
"apply_power_sign"
,
"apply_power_sign_d"
},
{
"apply_centered_rms_prop"
,
"apply_centered_rms_prop_d"
},
{
"transpose"
,
"transpose_d"
},
{
"fill"
,
"fill_d"
},
{
"unsorted_segment_sum"
,
"unsorted_segment_sum_d"
},
...
...
mindspore/ccsrc/transform/graph_ir/convert.cc
浏览文件 @
380a57f3
...
...
@@ -409,7 +409,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{
string
(
kNameBatchToSpace
),
ADPT_DESC
(
BatchToSpaceD
)},
{
string
(
kNameAtan2
),
ADPT_DESC
(
Atan2
)},
{
string
(
kNameApplyRMSProp
),
ADPT_DESC
(
ApplyRMSPropD
)},
{
string
(
kNameApplyCenteredRMSProp
),
ADPT_DESC
(
ApplyCenteredRMSProp
)},
{
string
(
kNameApplyCenteredRMSProp
),
ADPT_DESC
(
ApplyCenteredRMSProp
D
)},
{
string
(
kNameL2Loss
),
ADPT_DESC
(
L2Loss
)},
{
string
(
kNameCTCLoss
),
ADPT_DESC
(
CTCLoss
)},
{
string
(
kNameRange
),
ADPT_DESC
(
RangeD
)},
...
...
mindspore/ccsrc/transform/graph_ir/op_declare.cc
浏览文件 @
380a57f3
...
...
@@ -1284,12 +1284,13 @@ INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
ATTR_MAP
(
ApplyRMSPropD
)
=
{{
"use_locking"
,
ATTR_DESC
(
use_locking
,
AnyTraits
<
bool
>
())}};
OUTPUT_MAP
(
ApplyRMSPropD
)
=
{{
0
,
OUTPUT_DESC
(
var
)}};
// ApplyCenteredRMSProp
INPUT_MAP
(
ApplyCenteredRMSProp
)
=
{{
1
,
INPUT_DESC
(
var
)},
{
2
,
INPUT_DESC
(
mg
)},
{
3
,
INPUT_DESC
(
ms
)},
{
4
,
INPUT_DESC
(
mom
)},
{
5
,
INPUT_DESC
(
grad
)},
{
6
,
INPUT_DESC
(
lr
)},
{
7
,
INPUT_DESC
(
rho
)},
{
8
,
INPUT_DESC
(
momentum
)},
{
9
,
INPUT_DESC
(
epsilon
)}};
ATTR_MAP
(
ApplyCenteredRMSProp
)
=
{{
"use_locking"
,
ATTR_DESC
(
use_locking
,
AnyTraits
<
bool
>
())}};
OUTPUT_MAP
(
ApplyCenteredRMSProp
)
=
{{
0
,
OUTPUT_DESC
(
var
)}};
// ApplyCenteredRMSPropD
INPUT_MAP
(
ApplyCenteredRMSPropD
)
=
{{
1
,
INPUT_DESC
(
var
)},
{
2
,
INPUT_DESC
(
mg
)},
{
3
,
INPUT_DESC
(
ms
)},
{
4
,
INPUT_DESC
(
mom
)},
{
5
,
INPUT_DESC
(
grad
)},
{
6
,
INPUT_DESC
(
lr
)},
{
7
,
INPUT_DESC
(
rho
)},
{
8
,
INPUT_DESC
(
momentum
)},
{
9
,
INPUT_DESC
(
epsilon
)}};
ATTR_MAP
(
ApplyCenteredRMSPropD
)
=
{{
"use_locking"
,
ATTR_DESC
(
use_locking
,
AnyTraits
<
bool
>
())}};
OUTPUT_MAP
(
ApplyCenteredRMSPropD
)
=
{
{
0
,
OUTPUT_DESC
(
var
)},
{
1
,
OUTPUT_DESC
(
mg
)},
{
2
,
OUTPUT_DESC
(
ms
)},
{
3
,
OUTPUT_DESC
(
mom
)}};
// L2Loss
INPUT_MAP
(
L2Loss
)
=
{{
1
,
INPUT_DESC
(
x
)}};
...
...
mindspore/ccsrc/transform/graph_ir/op_declare.h
浏览文件 @
380a57f3
...
...
@@ -486,8 +486,8 @@ DECLARE_OP_USE_OUTPUT(Atan2)
DECLARE_OP_ADAPTER
(
ApplyRMSPropD
)
DECLARE_OP_USE_INPUT_ATTR
(
ApplyRMSPropD
)
DECLARE_OP_USE_OUTPUT
(
ApplyRMSPropD
)
DECLARE_OP_ADAPTER
(
ApplyCenteredRMSProp
)
DECLARE_OP_USE_OUTPUT
(
ApplyCenteredRMSProp
)
DECLARE_OP_ADAPTER
(
ApplyCenteredRMSProp
D
)
DECLARE_OP_USE_OUTPUT
(
ApplyCenteredRMSProp
D
)
DECLARE_OP_ADAPTER
(
L2Loss
)
DECLARE_OP_USE_OUTPUT
(
L2Loss
)
DECLARE_OP_ADAPTER
(
CTCLoss
)
...
...
mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py
浏览文件 @
380a57f3
...
...
@@ -13,15 +13,15 @@
# limitations under the License.
# ============================================================================
"""ApplyCenteredRMSProp op"""
"""ApplyCenteredRMSProp
D
op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
apply_centered_rms_prop_op_info
=
TBERegOp
(
"ApplyCenteredRMSProp"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"apply_centered_rms_prop.so"
)
\
.
binfile_name
(
"apply_centered_rms_prop
_d
.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"apply_centered_rms_prop"
)
\
.
kernel_name
(
"apply_centered_rms_prop
_d
"
)
\
.
partial_flag
(
True
)
\
.
input
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"mg"
,
False
,
"required"
,
"all"
)
\
...
...
@@ -33,34 +33,45 @@ apply_centered_rms_prop_op_info = TBERegOp("ApplyCenteredRMSProp") \
.
input
(
7
,
"epsilon"
,
False
,
"required"
,
"all"
)
\
.
input
(
8
,
"grad"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"mg"
,
False
,
"required"
,
"all"
)
\
.
output
(
2
,
"ms"
,
False
,
"required"
,
"all"
)
\
.
output
(
3
,
"mom"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
)
\
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
,
DataType
.
F16_FracZ
)
\
.
dtype_format
(
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
)
\
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
,
DataType
.
F16_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
dtype_format
(
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
)
\
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
,
DataType
.
F32_FracZ
)
\
.
dtype_format
(
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
)
\
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
,
DataType
.
F32_C1HWNCoC0
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
@
op_info_register
(
apply_centered_rms_prop_op_info
)
def
_apply_centered_rms_prop_tbe
():
"""ApplyCenteredRMSProp TBE register"""
"""ApplyCenteredRMSProp
D
TBE register"""
return
mindspore/ops/operations/nn_ops.py
浏览文件 @
380a57f3
...
...
@@ -1962,6 +1962,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
False
):
self
.
use_locking
=
validator
.
check_value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
name
)
self
.
is_ascend
=
context
.
get_context
(
"device_target"
)
==
"Ascend"
def
infer_shape
(
self
,
var_shape
,
mean_gradient_shape
,
mean_square_shape
,
moment_shape
,
grad_shape
,
learning_rate_shape
,
decay_shape
,
momentum_shape
,
epsilon_shape
):
...
...
@@ -1969,6 +1970,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
validator
.
check
(
"var_shape"
,
var_shape
,
"mean_square_shape"
,
mean_square_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"var_shape"
,
var_shape
,
"moment_shape"
,
moment_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"var_shape"
,
var_shape
,
"grad_shape"
,
grad_shape
,
Rel
.
EQ
,
self
.
name
)
if
self
.
is_ascend
:
return
var_shape
,
mean_gradient_shape
,
mean_square_shape
,
moment_shape
return
var_shape
def
infer_dtype
(
self
,
var_dtype
,
mean_gradient_dtype
,
mean_square_dtype
,
moment_dtype
,
grad_dtype
,
...
...
@@ -1982,6 +1985,8 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
validator
.
check_type_same
(
args_rho
,
valid_types
,
self
.
name
)
args_lr
=
{
"learning_rate"
:
learning_rate_dtype
,
"rho"
:
rho_dtype
}
validator
.
check_scalar_or_tensor_type_same
(
args_lr
,
valid_types
,
self
.
name
,
allow_mix
=
True
)
if
self
.
is_ascend
:
return
var_dtype
,
mean_gradient_dtype
,
mean_square_dtype
,
moment_dtype
return
var_dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录