diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 2af70bd44b5a492df367c5c08faa9d9bc6177830..d3fd00d401fbf290a20cbc0bb0f3503883fdd281 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -177,6 +177,18 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vectorpush_back(inputs_list[i]); } + } else if (op_name == "ApplyCenteredRMSProp") { + // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map + // TBE parameter to correspond python API parameter by latter's index using hardcode + inputs_json->push_back(inputs_list[0]); + inputs_json->push_back(inputs_list[1]); + inputs_json->push_back(inputs_list[2]); + inputs_json->push_back(inputs_list[3]); + inputs_json->push_back(inputs_list[5]); + inputs_json->push_back(inputs_list[6]); + inputs_json->push_back(inputs_list[7]); + inputs_json->push_back(inputs_list[8]); + inputs_json->push_back(inputs_list[4]); } else { inputs_json->push_back(inputs_list[1]); inputs_json->push_back(inputs_list[0]); diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index dd3339fe8a74d4416b4a202e40b045d136545170..39876cabcac676b1ac6fa7482b07c5fc668d0d6a 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1807,18 +1807,23 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): Examples: >>> centered_rms_prop = P.ApplyCenteredRMSProp() - >>> input_x = Tensor(1., mindspore.float32) - >>> mean_grad = Tensor(2., mindspore.float32) - >>> mean_square = Tensor(1., mindspore.float32) - >>> moment = Tensor(2., mindspore.float32) - >>> grad = Tensor(1., mindspore.float32) + >>> input_x = Tensor(np.arange(-6, 6).astype(np.float32).reshape(2, 3, 2), mindspore.float32) + >>> mean_grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32) + >>> mean_square = Tensor(np.arange(-8, 4).astype(np.float32).reshape(2, 3, 2), mindspore.float32) + >>> moment = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32) + >>> grad = Tensor(np.arange(12).astype(np.float32).rehspae(2, 3, 2), mindspore.float32) >>> learning_rate = Tensor(0.9, mindspore.float32) >>> decay = 0.0 >>> momentum = 1e-10 - >>> epsilon = 0.001 + >>> epsilon = 0.05 >>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad, >>> learning_rate, decay, momentum, epsilon) - -27.460497 + [[[ -6. -9.024922] + [-12.049845 -15.074766] + [-18.09969 -21.124613]] + [[-24.149532 -27.174456] + [-30.199379 -33.2243 ] + [-36.249226 -39.274143]]] """ @prim_attr_register