magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
a42ec8f6
编写于
5月 18, 2020
作者:
Z
zhaojichen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add applyrmsprop op for vm
上级
d9c74e0a
Showing 13 changed files with 159 additions and 30 deletions (+159 −30).
Changed files:
  mindspore/ccsrc/kernel/tbe/tbe_adapter.cc                          +4  −1
  mindspore/ccsrc/operator/ops.cc                                    +1  −0
  mindspore/ccsrc/operator/ops.h                                     +1  −0
  mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc  +2  −1
  mindspore/nn/optim/rmsprop.py                                      +1  −1
  mindspore/ops/_op_impl/tbe/__init__.py                             +3  −1
  mindspore/ops/_op_impl/tbe/apply_rms_prop.py (new)                 +51 −0
  mindspore/ops/_op_impl/tbe/cumprod.py (renamed)                    +17 −16
  mindspore/ops/_op_impl/tbe/reduce_prod.py (new)                    +44 −0
  mindspore/ops/operations/math_ops.py                               +7  −4
  mindspore/ops/operations/nn_ops.py                                 +10 −2
  tests/st/ops/test_rmsprop.py                                       +1  −1
  tests/ut/python/ops/test_ops.py                                    +17 −3
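For orientation before the per-file diffs: ApplyRMSProp exposes the standard non-centered RMSProp update as a single fused op, and the st test's rmsprop_numpy helper checks the same recurrence. A minimal NumPy sketch, assuming the textbook update equations:

import numpy as np

def rmsprop_step(var, grad, ms, mom, lr, rho, momentum, epsilon):
    """One non-centered RMSProp step; ms is the running mean of squared gradients."""
    ms = rho * ms + (1.0 - rho) * grad * grad
    mom = momentum * mom + lr * grad / np.sqrt(ms + epsilon)
    var = var - mom
    return var, ms, mom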
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc

@@ -92,7 +92,10 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"l_ars_update", "lars_v2_update"},
   {"n_ms_with_mask", "nms_with_mask"},
   {"square_sum_all", "square_sum_all"},
-  {"cum_sum", "cumsum_d"}};
+  {"cum_sum", "cumsum_d"},
+  {"apply_rms_prop", "apply_rms_prop_d"},
+  {"cum_prod", "cumprod_d"},
+  {"reduce_prod", "reduce_prod_d"}};

 void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   if (func_name == nullptr) {
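The map keys are the snake_case names produced by NormalizeFuncName just below the table; its values are the TBE kernel names. A rough Python sketch of that lookup (the real logic is the C++ above, so the regex and function name here are illustrative only):

import re

TBE_FUNC_ADAPTER_MAP = {
    "cum_sum": "cumsum_d",
    "apply_rms_prop": "apply_rms_prop_d",
    "cum_prod": "cumprod_d",
    "reduce_prod": "reduce_prod_d",
}

def normalize_func_name(name: str) -> str:
    """Lowercase a camel-case op name (ApplyRMSProp -> apply_rms_prop), then remap it."""
    snake = re.sub(r"(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", "_", name).lower()
    return TBE_FUNC_ADAPTER_MAP.get(snake, snake)

assert normalize_func_name("ApplyRMSProp") == "apply_rms_prop_d"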
mindspore/ccsrc/operator/ops.cc

@@ -167,6 +167,7 @@ const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
 const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
 const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
 const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
+const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");

 // NN
 const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
mindspore/ccsrc/operator/ops.h

@@ -173,6 +173,7 @@ extern const PrimitivePtr kPrimEqual;
 extern const PrimitivePtr kPrimLess;
 extern const PrimitivePtr kPrimLessEqual;
 extern const PrimitivePtr kPrimCumSum;
+extern const PrimitivePtr kPrimCumProd;

 // NN
 extern const PrimitivePtr kPrimFlatten;
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc

@@ -41,6 +41,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimOneHot->name(), {1});
   Register(prim::kPrimConcat->name(), {0});
   Register(prim::kPrimCumSum->name(), {1});
+  Register(prim::kPrimCumProd->name(), {1});
   Register(kUnsortedSegmentProdOpName, {2});
   Register(kUnsortedSegmentMinOpName, {2});
   Register(kSimpleMeanGradOpName, {1});
@@ -60,7 +61,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(kResizeNearestNeighborGradOpName, {1});
   Register(kResizeNearestNeighborV2OpName, {1});
   Register(kResizeNearestNeighborV2GradOpName, {1});
-  Register(kApplyRMSPropOpname, {4, 5, 6});
+  Register(kApplyRMSPropOpname, {5, 6, 7});
   Register(kResizeBilinearV2OpName, {1});
   Register(kReduceProdOpName, {1});
   Register(kCumprodOpName, {1});
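With the op's new input order (var, mean_square, moment, learning_rate, grad, rho, momentum, epsilon), the scalar constants rho, momentum and epsilon now sit at indices 5, 6 and 7, hence the change from {4, 5, 6}. A hypothetical sketch of what the const-input-to-attr pass does with those indices (helper and variable names here are invented for illustration):

INPUTS = ['var', 'mean_square', 'moment', 'learning_rate', 'grad', 'rho', 'momentum', 'epsilon']
CONST_TO_ATTR = {5, 6, 7}

def split_inputs(values):
    """Split runtime tensor inputs from the constants that become node attributes."""
    tensors = [v for i, v in enumerate(values) if i not in CONST_TO_ATTR]
    attrs = {INPUTS[i]: values[i] for i in CONST_TO_ATTR}
    return tensors, attrs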
mindspore/nn/optim/rmsprop.py

@@ -26,7 +26,7 @@ centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
 def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
     """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
     success = True
-    success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
+    success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon))
     return success
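The swap passes learning_rate before grad, matching the primitive's new input order. For context, a hedged usage sketch of the optimizer this helper backs (the network is a placeholder, and the nn.RMSProp keyword names should be double-checked against this release):

from mindspore import nn

net = LeNet5()  # placeholder Cell; any network with trainable_params() works
opt = nn.RMSProp(params=net.trainable_params(), learning_rate=0.001,
                 decay=0.9, momentum=0.0, epsilon=1e-10)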
mindspore/ops/_op_impl/tbe/__init__.py

@@ -180,7 +180,6 @@ from .check_valid import _check_valid_tbe
 from .iou import _iou_tbe
 from .arg_max import _arg_max_tbe
 from .nms_with_mask import _nms_with_mask_tbe
-from .random_choice_with_mask import _random_choice_with_mask_tbe
 from .sgd import _sgd_tbe
 from .lars_update import _lars_update_tbe
 from .bn_training_update_v2 import _bn_training_update_v2_tbe
@@ -195,3 +194,6 @@ from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
 from .sin import _sin_tbe
 from .cos import _cos_tbe
 from .cum_sum import _cum_sum_tbe
+from .apply_rms_prop import _apply_rms_prop_tbe
+from .cumprod import _cumprop_tbe
+from .reduce_prod import _reduce_prod_tbe
mindspore/ops/_op_impl/tbe/apply_rms_prop.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyRMSProp op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

apply_rms_prop_op_info = TBERegOp("ApplyRMSProp") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_rms_prop.so") \
    .compute_cost(10) \
    .kernel_name("apply_rms_prop_d") \
    .partial_flag(True) \
    .attr("rho", "required", "float", "all") \
    .attr("momentum", "required", "float", "all") \
    .attr("epsilon", "required", "float", "all") \
    .input(0, "var", False, "required", "all") \
    .input(1, "ms", False, "required", "all") \
    .input(2, "mom", False, "required", "all") \
    .input(3, "lr", False, "required", "all") \
    .input(4, "grad", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "ms", False, "required", "all") \
    .output(2, "mom", False, "required", "all") \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(apply_rms_prop_op_info)
def _apply_rms_prop_tbe():
    """ApplyRMSProp TBE register"""
    return
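Each dtype_format row supplies one format per declared I/O, in order: five inputs (var, ms, mom, lr, grad) then three outputs (var, ms, mom). Note lr stays F32_Default in every row, since the learning rate is a scalar. A hedged sketch of calling the primitive with this input order (shapes and values arbitrary; the unit test further down wraps the same call in a Cell):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

apply_rms_prop = P.ApplyRMSProp()
var = Tensor(np.random.rand(3, 3).astype(np.float32))
ms = Tensor(np.random.rand(3, 3).astype(np.float32))
mom = Tensor(np.zeros((3, 3)).astype(np.float32))
grad = Tensor(np.random.rand(3, 3).astype(np.float32))
# order: var, mean_square, moment, learning_rate, grad, rho, momentum, epsilon
out = apply_rms_prop(var, ms, mom, 0.001, grad, 0.9, 0.0, 1e-10)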
mindspore/ops/_op_impl/tbe/random_choice_with_mask.py → mindspore/ops/_op_impl/tbe/cumprod.py

@@ -13,29 +13,30 @@
 # limitations under the License.
 # ============================================================================
-"""RandomChoiceWithMask op"""
+"""CumProd op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

-random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \
+cumprop_op_info = TBERegOp("CumProd") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("random_choice_with_mask.so") \
+    .binfile_name("cumprod_d.so") \
     .compute_cost(10) \
-    .kernel_name("random_choice_with_mask") \
+    .kernel_name("cumprod_d") \
     .partial_flag(True) \
-    .attr("max_shape", "optional", "listInt", "all") \
-    .attr("means", "optional", "listFloat", "all") \
-    .attr("stds", "optional", "listFloat", "all") \
-    .attr("wh_ratio_clip", "optional", "float", "all") \
-    .input(0, "rois", False, "required", "all") \
-    .input(1, "deltas", False, "required", "all") \
-    .output(0, "bboxes", False, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .attr("axis", "optional", "int", "all") \
+    .attr("exclusive", "optional", "bool", "all") \
+    .attr("reverse", "optional", "bool", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
     .get_op_info()


-@op_info_register(random_choice_with_mask_op_info)
-def _random_choice_with_mask_tbe():
-    """RandomChoiceWithMask TBE register"""
+@op_info_register(cumprop_op_info)
+def _cumprop_tbe():
+    """CumProd TBE register"""
     return
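CumProd computes a running product along an axis. A NumPy sketch of the semantics the three attrs control, for the default and the two non-default settings:

import numpy as np

x = np.array([1, 2, 3, 4])
print(np.cumprod(x))  # [ 1  2  6 24]  (exclusive=False, reverse=False)
# exclusive=True shifts the products right, starting from 1: [1, 1, 2, 6]
# reverse=True accumulates from the end instead:             [24, 24, 12, 4]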
mindspore/ops/_op_impl/tbe/reduce_prod.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

reduce_prod_op_info = TBERegOp("ReduceProd") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_prod_d.so") \
    .compute_cost(10) \
    .kernel_name("reduce_prod_d") \
    .partial_flag(True) \
    .attr("axis", "required", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .get_op_info()


@op_info_register(reduce_prod_op_info)
def _reduce_prod_tbe():
    """ReduceProd TBE register"""
    return
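ReduceProd multiplies elements away along the given axes; axis is a required listInt attr here, which is why input 1 of the Python primitive is registered for const-input-to-attr conversion above. A NumPy sketch of the semantics:

import numpy as np

x = np.array([[1., 2., 3.], [4., 5., 6.]], dtype=np.float32)
print(np.prod(x, axis=1))                 # [  6. 120.]
print(np.prod(x, axis=1, keepdims=True))  # [[  6.] [120.]]  (keep_dims=True)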
mindspore/ops/operations/math_ops.py

@@ -475,6 +475,7 @@ class CumProd(PrimitiveWithInfer):
         cls_name = self.name
         self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
         self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
+        self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])

     def infer_shape(self, x_shape, axis_shape):
         return x_shape
@@ -2022,8 +2023,10 @@ class NMSWithMask(PrimitiveWithInfer):
         validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
-        validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
-        num = bboxes_shape[0]
-        return (bboxes_shape, (num,), (num,))
+        if not self.is_ge:
+            validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
+        else:
+            validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
+        num = bboxes_shape[0]
+        return ((num, 5), (num,), (num,))
@@ -2171,8 +2174,8 @@ class SquareSumAll(PrimitiveWithInfer):
         - **output_y2** (Tensor) - The same type as the `input_x1`.

     Examples:
-        >>> input_x1 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32)
-        >>> input_x2 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32)
+        >>> input_x1 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
+        >>> input_x2 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
         >>> square_sum_all = P.SquareSumAll()
         >>> square_sum_all(input_x1, input_x2)
     """
mindspore/ops/operations/nn_ops.py

@@ -1721,15 +1721,21 @@ class ApplyRMSProp(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, use_locking=False):
         self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
+                                        'rho', 'momentum', 'epsilon'], outputs=['output'])
+        self.is_ge = context.get_context("enable_ge")
+        self.is_d = context.get_context("device_target") == "Ascend"

-    def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape,
+    def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
                     momentum_shape, epsilon_shape):
         validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
+        if not self.is_ge and self.is_d:
+            return var_shape, var_shape, var_shape
         return var_shape

-    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype,
+    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
                     momentum_dtype, epsilon_dtype):
         args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
@@ -1739,6 +1745,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
         validator.check_type_same(args_decay, valid_types, self.name)
         args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
         validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
+        if not self.is_ge and self.is_d:
+            return var_dtype, var_dtype, var_dtype
         return var_dtype
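On Ascend without GE (the new TBE path), the primitive now yields the three updated tensors (var, ms, mom); elsewhere it keeps returning var alone. A minimal pure-Python sketch of the shape-inference branch added here:

def infer_shape(var_shape, is_ge, is_d):
    """Ascend (non-GE) returns shapes for all three updated tensors; otherwise only var."""
    if not is_ge and is_d:
        return var_shape, var_shape, var_shape
    return var_shape

assert infer_shape((3, 3), is_ge=False, is_d=True) == ((3, 3), (3, 3), (3, 3))
assert infer_shape((3, 3), is_ge=True, is_d=True) == (3, 3)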
tests/st/ops/test_rmsprop.py

@@ -37,7 +37,7 @@ class NetRMSProp(nn.Cell):
         if self.use_centered:
             return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
         else:
-            return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon)
+            return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)

 def rmsprop_numpy(variable, gradients, mean_square, moment,
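For reference, the use_centered branch above exercises the centered variant, where mg additionally tracks the running mean of the gradients. A hedged NumPy sketch following the textbook equations:

import numpy as np

def centered_rmsprop_step(var, grad, mg, ms, mom, lr, rho, momentum, epsilon):
    """One centered RMSProp step; centering subtracts mg**2 from the second moment."""
    mg = rho * mg + (1.0 - rho) * grad
    ms = rho * ms + (1.0 - rho) * grad * grad
    mom = momentum * mom + lr * grad / np.sqrt(ms - mg * mg + epsilon)
    return var - mom, mg, ms, mom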
tests/ut/python/ops/test_ops.py

@@ -202,6 +202,21 @@ class ApplyFtrlNet(nn.Cell):
         out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
         return out


+class ApplyRMSNet(nn.Cell):
+    def __init__(self):
+        super(ApplyRMSNet, self).__init__()
+        self.apply_rms = P.ApplyRMSProp()
+        self.lr = 0.001
+        self.rho = 0.0
+        self.momentum = 0.0
+        self.epsilon = 1e-10
+        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
+        self.ms = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="ms")
+        self.moment = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="moment")
+
+    def construct(self, grad):
+        out = self.apply_rms(self.var, self.ms, self.moment, self.lr, grad, self.rho, self.momentum, self.epsilon)
+        return out
+

 test_case_math_ops = [
     ('Neg', {
@@ -914,9 +929,8 @@ test_case_nn_ops = [
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
     ('ApplyRMSProp', {
-        'block': P.ApplyRMSProp(),
-        'desc_const': [0.9, 0.0, 1e-10, 0.001],
-        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
+        'block': ApplyRMSNet(),
+        'desc_inputs': [[3, 3]],
         'desc_bprop': [3, 3],
         'skip': ['backward']}),
     ('ApplyCenteredRMSProp', {