提交 a42ec8f6 编写于 作者: Z zhaojichen

add applyrmsprop op for vm

上级 d9c74e0a
...@@ -92,7 +92,10 @@ static std::map<string, string> tbe_func_adapter_map = { ...@@ -92,7 +92,10 @@ static std::map<string, string> tbe_func_adapter_map = {
{"l_ars_update", "lars_v2_update"}, {"l_ars_update", "lars_v2_update"},
{"n_ms_with_mask", "nms_with_mask"}, {"n_ms_with_mask", "nms_with_mask"},
{"square_sum_all", "square_sum_all"}, {"square_sum_all", "square_sum_all"},
{"cum_sum", "cumsum_d"}}; {"cum_sum", "cumsum_d"},
{"apply_rms_prop", "apply_rms_prop_d"},
{"cum_prod", "cumprod_d"},
{"reduce_prod", "reduce_prod_d"}};
void TbeAdapter::NormalizeFuncName(std::string *func_name) { void TbeAdapter::NormalizeFuncName(std::string *func_name) {
if (func_name == nullptr) { if (func_name == nullptr) {
......
...@@ -167,6 +167,7 @@ const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal"); ...@@ -167,6 +167,7 @@ const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less"); const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual"); const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
// NN // NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
......
...@@ -173,6 +173,7 @@ extern const PrimitivePtr kPrimEqual; ...@@ -173,6 +173,7 @@ extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess; extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual; extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimCumSum; extern const PrimitivePtr kPrimCumSum;
extern const PrimitivePtr kPrimCumProd;
// NN // NN
extern const PrimitivePtr kPrimFlatten; extern const PrimitivePtr kPrimFlatten;
......
...@@ -41,6 +41,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { ...@@ -41,6 +41,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimOneHot->name(), {1}); Register(prim::kPrimOneHot->name(), {1});
Register(prim::kPrimConcat->name(), {0}); Register(prim::kPrimConcat->name(), {0});
Register(prim::kPrimCumSum->name(), {1}); Register(prim::kPrimCumSum->name(), {1});
Register(prim::kPrimCumProd->name(), {1});
Register(kUnsortedSegmentProdOpName, {2}); Register(kUnsortedSegmentProdOpName, {2});
Register(kUnsortedSegmentMinOpName, {2}); Register(kUnsortedSegmentMinOpName, {2});
Register(kSimpleMeanGradOpName, {1}); Register(kSimpleMeanGradOpName, {1});
...@@ -60,7 +61,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { ...@@ -60,7 +61,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(kResizeNearestNeighborGradOpName, {1}); Register(kResizeNearestNeighborGradOpName, {1});
Register(kResizeNearestNeighborV2OpName, {1}); Register(kResizeNearestNeighborV2OpName, {1});
Register(kResizeNearestNeighborV2GradOpName, {1}); Register(kResizeNearestNeighborV2GradOpName, {1});
Register(kApplyRMSPropOpname, {4, 5, 6}); Register(kApplyRMSPropOpname, {5, 6, 7});
Register(kResizeBilinearV2OpName, {1}); Register(kResizeBilinearV2OpName, {1});
Register(kReduceProdOpName, {1}); Register(kReduceProdOpName, {1});
Register(kCumprodOpName, {1}); Register(kCumprodOpName, {1});
......
...@@ -26,7 +26,7 @@ centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") ...@@ -26,7 +26,7 @@ centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad): def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
"""Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" """Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True success = True
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon))
return success return success
......
...@@ -180,7 +180,6 @@ from .check_valid import _check_valid_tbe ...@@ -180,7 +180,6 @@ from .check_valid import _check_valid_tbe
from .iou import _iou_tbe from .iou import _iou_tbe
from .arg_max import _arg_max_tbe from .arg_max import _arg_max_tbe
from .nms_with_mask import _nms_with_mask_tbe from .nms_with_mask import _nms_with_mask_tbe
from .random_choice_with_mask import _random_choice_with_mask_tbe
from .sgd import _sgd_tbe from .sgd import _sgd_tbe
from .lars_update import _lars_update_tbe from .lars_update import _lars_update_tbe
from .bn_training_update_v2 import _bn_training_update_v2_tbe from .bn_training_update_v2 import _bn_training_update_v2_tbe
...@@ -195,3 +194,6 @@ from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe ...@@ -195,3 +194,6 @@ from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
from .sin import _sin_tbe from .sin import _sin_tbe
from .cos import _cos_tbe from .cos import _cos_tbe
from .cum_sum import _cum_sum_tbe from .cum_sum import _cum_sum_tbe
from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyRMSProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_rms_prop_op_info = TBERegOp("ApplyRMSProp") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_rms_prop.so") \
.compute_cost(10) \
.kernel_name("apply_rms_prop_d") \
.partial_flag(True) \
.attr("rho", "required", "float", "all") \
.attr("momentum", "required", "float", "all") \
.attr("epsilon", "required", "float", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "ms", False, "required", "all") \
.input(2, "mom", False, "required", "all") \
.input(3, "lr", False, "required", "all") \
.input(4, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "ms", False, "required", "all") \
.output(2, "mom", False, "required", "all") \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(apply_rms_prop_op_info)
def _apply_rms_prop_tbe():
"""ApplyRMSProp TBE register"""
return
...@@ -13,29 +13,30 @@ ...@@ -13,29 +13,30 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""RandomChoiceWithMask op""" """CumProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \ cumprop_op_info = TBERegOp("CumProd") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("random_choice_with_mask.so") \ .binfile_name("cumprod_d.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("random_choice_with_mask") \ .kernel_name("cumprod_d") \
.partial_flag(True) \ .partial_flag(True) \
.attr("max_shape", "optional", "listInt", "all") \ .attr("axis", "optional", "int", "all") \
.attr("means", "optional", "listFloat", "all") \ .attr("exclusive", "optional", "bool", "all") \
.attr("stds", "optional", "listFloat", "all") \ .attr("reverse", "optional", "bool", "all") \
.attr("wh_ratio_clip", "optional", "float", "all") \ .input(0, "x", False, "required", "all") \
.input(0, "rois", False, "required", "all") \ .output(0, "y", False, "required", "all") \
.input(1, "deltas", False, "required", "all") \ .dtype_format(DataType.I32_Default, DataType.I32_Default) \
.output(0, "bboxes", False, "required", "all") \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.get_op_info() .get_op_info()
@op_info_register(random_choice_with_mask_op_info) @op_info_register(cumprop_op_info)
def _random_choice_with_mask_tbe(): def _cumprop_tbe():
"""RandomChoiceWithMask TBE register""" """CumProd TBE register"""
return return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reduce_prod_op_info = TBERegOp("ReduceProd") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("reduce_prod_d.so") \
.compute_cost(10) \
.kernel_name("reduce_prod_d") \
.partial_flag(True) \
.attr("axis", "required", "listInt", "all") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()
@op_info_register(reduce_prod_op_info)
def _reduce_prod_tbe():
"""ReduceProd TBE register"""
return
...@@ -475,6 +475,7 @@ class CumProd(PrimitiveWithInfer): ...@@ -475,6 +475,7 @@ class CumProd(PrimitiveWithInfer):
cls_name = self.name cls_name = self.name
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name) self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name) self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
def infer_shape(self, x_shape, axis_shape): def infer_shape(self, x_shape, axis_shape):
return x_shape return x_shape
...@@ -2022,8 +2023,10 @@ class NMSWithMask(PrimitiveWithInfer): ...@@ -2022,8 +2023,10 @@ class NMSWithMask(PrimitiveWithInfer):
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
if not self.is_ge: if not self.is_ge:
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name) validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
else: num = bboxes_shape[0]
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) return ((num, 5), (num,), (num,))
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
num = bboxes_shape[0] num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,)) return (bboxes_shape, (num,), (num,))
...@@ -2171,8 +2174,8 @@ class SquareSumAll(PrimitiveWithInfer): ...@@ -2171,8 +2174,8 @@ class SquareSumAll(PrimitiveWithInfer):
- **output_y2** (Tensor) - The same type as the `input_x1`. - **output_y2** (Tensor) - The same type as the `input_x1`.
Examples: Examples:
>>> input_x1 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32) >>> input_x1 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
>>> input_x2 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32) >>> input_x2 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
>>> square_sum_all = P.SquareSumAll() >>> square_sum_all = P.SquareSumAll()
>>> square_sum_all(input_x1, input_x2) >>> square_sum_all(input_x1, input_x2)
""" """
......
...@@ -1721,15 +1721,21 @@ class ApplyRMSProp(PrimitiveWithInfer): ...@@ -1721,15 +1721,21 @@ class ApplyRMSProp(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=False): def __init__(self, use_locking=False):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
'rho', 'momentum', 'epsilon'], outputs=['output'])
self.is_ge = context.get_context("enable_ge")
self.is_d = context.get_context("device_target") == "Ascend"
def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
momentum_shape, epsilon_shape): momentum_shape, epsilon_shape):
validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
if not self.is_ge and self.is_d:
return var_shape, var_shape, var_shape
return var_shape return var_shape
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype, def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
momentum_dtype, epsilon_dtype): momentum_dtype, epsilon_dtype):
args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name) validator.check_tensor_type_same(args, mstype.number_type, self.name)
...@@ -1739,6 +1745,8 @@ class ApplyRMSProp(PrimitiveWithInfer): ...@@ -1739,6 +1745,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
validator.check_type_same(args_decay, valid_types, self.name) validator.check_type_same(args_decay, valid_types, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype} args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
if not self.is_ge and self.is_d:
return var_dtype, var_dtype, var_dtype
return var_dtype return var_dtype
......
...@@ -37,7 +37,7 @@ class NetRMSProp(nn.Cell): ...@@ -37,7 +37,7 @@ class NetRMSProp(nn.Cell):
if self.use_centered: if self.use_centered:
return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon) return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
else: else:
return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon) return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)
def rmsprop_numpy(variable, gradients, mean_square, moment, def rmsprop_numpy(variable, gradients, mean_square, moment,
......
...@@ -202,6 +202,21 @@ class ApplyFtrlNet(nn.Cell): ...@@ -202,6 +202,21 @@ class ApplyFtrlNet(nn.Cell):
out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power) out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
return out return out
class ApplyRMSNet(nn.Cell):
def __init__(self):
super(ApplyRMSNet, self).__init__()
self.apply_rms = P.ApplyRMSProp()
self.lr = 0.001
self.rho = 0.0
self.momentum= 0.0
self.epsilon = 1e-10
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.ms = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="ms")
self.moment = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="moment")
def construct(self, grad):
out = self.apply_rms(self.var, self.ms, self.moment, self.lr, grad, self.rho, self.momentum, self.epsilon)
return out
test_case_math_ops = [ test_case_math_ops = [
('Neg', { ('Neg', {
...@@ -914,9 +929,8 @@ test_case_nn_ops = [ ...@@ -914,9 +929,8 @@ test_case_nn_ops = [
'desc_bprop': [3, 3], 'desc_bprop': [3, 3],
'skip': ['backward']}), 'skip': ['backward']}),
('ApplyRMSProp', { ('ApplyRMSProp', {
'block': P.ApplyRMSProp(), 'block': ApplyRMSNet(),
'desc_const': [0.9, 0.0, 1e-10, 0.001], 'desc_inputs': [[3, 3]],
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
'desc_bprop': [3, 3], 'desc_bprop': [3, 3],
'skip': ['backward']}), 'skip': ['backward']}),
('ApplyCenteredRMSProp', { ('ApplyCenteredRMSProp', {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册