提交 2df6eaf3 编写于 作者: Z zhouneng

fix applycenteredrmspop params map bug

上级 c51d90d8
......@@ -177,6 +177,18 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec
for (size_t i = 3; i < inputs_list.size(); ++i) {
inputs_json->push_back(inputs_list[i]);
}
} else if (op_name == "ApplyCenteredRMSProp") {
// Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map
// TBE parameter to correspond python API parameter by latter's index using hardcode
inputs_json->push_back(inputs_list[0]);
inputs_json->push_back(inputs_list[1]);
inputs_json->push_back(inputs_list[2]);
inputs_json->push_back(inputs_list[3]);
inputs_json->push_back(inputs_list[5]);
inputs_json->push_back(inputs_list[6]);
inputs_json->push_back(inputs_list[7]);
inputs_json->push_back(inputs_list[8]);
inputs_json->push_back(inputs_list[4]);
} else {
inputs_json->push_back(inputs_list[1]);
inputs_json->push_back(inputs_list[0]);
......
......@@ -1807,18 +1807,23 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
Examples:
>>> centered_rms_prop = P.ApplyCenteredRMSProp()
>>> input_x = Tensor(1., mindspore.float32)
>>> mean_grad = Tensor(2., mindspore.float32)
>>> mean_square = Tensor(1., mindspore.float32)
>>> moment = Tensor(2., mindspore.float32)
>>> grad = Tensor(1., mindspore.float32)
>>> input_x = Tensor(np.arange(-6, 6).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> mean_grad = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> mean_square = Tensor(np.arange(-8, 4).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> moment = Tensor(np.arange(12).astype(np.float32).reshape(2, 3, 2), mindspore.float32)
>>> grad = Tensor(np.arange(12).astype(np.float32).rehspae(2, 3, 2), mindspore.float32)
>>> learning_rate = Tensor(0.9, mindspore.float32)
>>> decay = 0.0
>>> momentum = 1e-10
>>> epsilon = 0.001
>>> epsilon = 0.05
>>> result = centered_rms_prop(input_x, mean_grad, mean_square, moment, grad,
>>> learning_rate, decay, momentum, epsilon)
-27.460497
[[[ -6. -9.024922]
[-12.049845 -15.074766]
[-18.09969 -21.124613]]
[[-24.149532 -27.174456]
[-30.199379 -33.2243 ]
[-36.249226 -39.274143]]]
"""
@prim_attr_register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册