提交 a42ec8f6 编写于 作者: Z zhaojichen

add applyrmsprop op for vm

上级 d9c74e0a
......@@ -92,7 +92,10 @@ static std::map<string, string> tbe_func_adapter_map = {
{"l_ars_update", "lars_v2_update"},
{"n_ms_with_mask", "nms_with_mask"},
{"square_sum_all", "square_sum_all"},
{"cum_sum", "cumsum_d"}};
{"cum_sum", "cumsum_d"},
{"apply_rms_prop", "apply_rms_prop_d"},
{"cum_prod", "cumprod_d"},
{"reduce_prod", "reduce_prod_d"}};
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
if (func_name == nullptr) {
......
......@@ -167,6 +167,7 @@ const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
// NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
......
......@@ -173,6 +173,7 @@ extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimCumSum;
extern const PrimitivePtr kPrimCumProd;
// NN
extern const PrimitivePtr kPrimFlatten;
......
......@@ -41,6 +41,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimOneHot->name(), {1});
Register(prim::kPrimConcat->name(), {0});
Register(prim::kPrimCumSum->name(), {1});
Register(prim::kPrimCumProd->name(), {1});
Register(kUnsortedSegmentProdOpName, {2});
Register(kUnsortedSegmentMinOpName, {2});
Register(kSimpleMeanGradOpName, {1});
......@@ -60,7 +61,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(kResizeNearestNeighborGradOpName, {1});
Register(kResizeNearestNeighborV2OpName, {1});
Register(kResizeNearestNeighborV2GradOpName, {1});
Register(kApplyRMSPropOpname, {4, 5, 6});
Register(kApplyRMSPropOpname, {5, 6, 7});
Register(kResizeBilinearV2OpName, {1});
Register(kReduceProdOpName, {1});
Register(kCumprodOpName, {1});
......
......@@ -26,7 +26,7 @@ centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad):
"""Apply rmsprop optimizer to the weight parameter using dynamic learning rate."""
success = True
success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon))
success = F.depend(success, opt(weight, ms, mom, learning_rate, grad, decay, momentum, epsilon))
return success
......
......@@ -180,7 +180,6 @@ from .check_valid import _check_valid_tbe
from .iou import _iou_tbe
from .arg_max import _arg_max_tbe
from .nms_with_mask import _nms_with_mask_tbe
from .random_choice_with_mask import _random_choice_with_mask_tbe
from .sgd import _sgd_tbe
from .lars_update import _lars_update_tbe
from .bn_training_update_v2 import _bn_training_update_v2_tbe
......@@ -195,3 +194,6 @@ from .binary_cross_entropy_grad import _binary_cross_entropy_grad_tbe
from .sin import _sin_tbe
from .cos import _cos_tbe
from .cum_sum import _cum_sum_tbe
from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyRMSProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_rms_prop_op_info = TBERegOp("ApplyRMSProp") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("apply_rms_prop.so") \
.compute_cost(10) \
.kernel_name("apply_rms_prop_d") \
.partial_flag(True) \
.attr("rho", "required", "float", "all") \
.attr("momentum", "required", "float", "all") \
.attr("epsilon", "required", "float", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "ms", False, "required", "all") \
.input(2, "mom", False, "required", "all") \
.input(3, "lr", False, "required", "all") \
.input(4, "grad", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.output(1, "ms", False, "required", "all") \
.output(2, "mom", False, "required", "all") \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default,
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default,
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(apply_rms_prop_op_info)
def _apply_rms_prop_tbe():
"""ApplyRMSProp TBE register"""
return
......@@ -13,29 +13,30 @@
# limitations under the License.
# ============================================================================
"""RandomChoiceWithMask op"""
"""CumProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \
cumprop_op_info = TBERegOp("CumProd") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("random_choice_with_mask.so") \
.binfile_name("cumprod_d.so") \
.compute_cost(10) \
.kernel_name("random_choice_with_mask") \
.kernel_name("cumprod_d") \
.partial_flag(True) \
.attr("max_shape", "optional", "listInt", "all") \
.attr("means", "optional", "listFloat", "all") \
.attr("stds", "optional", "listFloat", "all") \
.attr("wh_ratio_clip", "optional", "float", "all") \
.input(0, "rois", False, "required", "all") \
.input(1, "deltas", False, "required", "all") \
.output(0, "bboxes", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.attr("axis", "optional", "int", "all") \
.attr("exclusive", "optional", "bool", "all") \
.attr("reverse", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(random_choice_with_mask_op_info)
def _random_choice_with_mask_tbe():
"""RandomChoiceWithMask TBE register"""
@op_info_register(cumprop_op_info)
def _cumprop_tbe():
"""CumProd TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceProd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
reduce_prod_op_info = TBERegOp("ReduceProd") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("reduce_prod_d.so") \
.compute_cost(10) \
.kernel_name("reduce_prod_d") \
.partial_flag(True) \
.attr("axis", "required", "listInt", "all") \
.attr("keep_dims", "optional", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()
@op_info_register(reduce_prod_op_info)
def _reduce_prod_tbe():
"""ReduceProd TBE register"""
return
......@@ -475,6 +475,7 @@ class CumProd(PrimitiveWithInfer):
cls_name = self.name
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
def infer_shape(self, x_shape, axis_shape):
return x_shape
......@@ -2022,8 +2023,10 @@ class NMSWithMask(PrimitiveWithInfer):
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
if not self.is_ge:
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
else:
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
num = bboxes_shape[0]
return ((num, 5), (num,), (num,))
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))
......@@ -2171,8 +2174,8 @@ class SquareSumAll(PrimitiveWithInfer):
- **output_y2** (Tensor) - The same type as the `input_x1`.
Examples:
>>> input_x1 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32)
>>> input_x2 = Tensor(np.random.randint([3, 2, 5,7]), mindspore.float32)
>>> input_x1 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
>>> input_x2 = Tensor(np.random.randint([3, 2, 5, 7]), mindspore.float32)
>>> square_sum_all = P.SquareSumAll()
>>> square_sum_all(input_x1, input_x2)
"""
......
......@@ -1721,15 +1721,21 @@ class ApplyRMSProp(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, use_locking=False):
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
'rho', 'momentum', 'epsilon'], outputs=['output'])
self.is_ge = context.get_context("enable_ge")
self.is_d = context.get_context("device_target") == "Ascend"
def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape,
def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
momentum_shape, epsilon_shape):
validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
if not self.is_ge and self.is_d:
return var_shape, var_shape, var_shape
return var_shape
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype,
def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
momentum_dtype, epsilon_dtype):
args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
......@@ -1739,6 +1745,8 @@ class ApplyRMSProp(PrimitiveWithInfer):
validator.check_type_same(args_decay, valid_types, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
if not self.is_ge and self.is_d:
return var_dtype, var_dtype, var_dtype
return var_dtype
......
......@@ -37,7 +37,7 @@ class NetRMSProp(nn.Cell):
if self.use_centered:
return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
else:
return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon)
return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)
def rmsprop_numpy(variable, gradients, mean_square, moment,
......
......@@ -202,6 +202,21 @@ class ApplyFtrlNet(nn.Cell):
out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power)
return out
class ApplyRMSNet(nn.Cell):
def __init__(self):
super(ApplyRMSNet, self).__init__()
self.apply_rms = P.ApplyRMSProp()
self.lr = 0.001
self.rho = 0.0
self.momentum= 0.0
self.epsilon = 1e-10
self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
self.ms = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="ms")
self.moment = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="moment")
def construct(self, grad):
out = self.apply_rms(self.var, self.ms, self.moment, self.lr, grad, self.rho, self.momentum, self.epsilon)
return out
test_case_math_ops = [
('Neg', {
......@@ -914,9 +929,8 @@ test_case_nn_ops = [
'desc_bprop': [3, 3],
'skip': ['backward']}),
('ApplyRMSProp', {
'block': P.ApplyRMSProp(),
'desc_const': [0.9, 0.0, 1e-10, 0.001],
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],
'block': ApplyRMSNet(),
'desc_inputs': [[3, 3]],
'desc_bprop': [3, 3],
'skip': ['backward']}),
('ApplyCenteredRMSProp', {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册