From 9cbba97b3d3fcd4c2f4ca1bf8b6088df93af2cf9 Mon Sep 17 00:00:00 2001
From: lzzyzlbb <287246233@qq.com>
Date: Wed, 18 Aug 2021 19:49:14 +0800
Subject: [PATCH] [NPU]add rmsprop op (#34864)

* [npu]add rmsprop op
---
 .../operators/optimizers/rmsprop_op_npu.cc | 101 ++++++++++++
 .../unittests/npu/test_rmsprop_op_npu.py | 152 ++++++++++++++++++
 2 files changed, 253 insertions(+)
 create mode 100644 paddle/fluid/operators/optimizers/rmsprop_op_npu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_rmsprop_op_npu.py

diff --git a/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc b/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc
new file mode 100644
index 0000000000..2edde1dd9c
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc
@@ -0,0 +1,101 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/fluid/operators/optimizers/rmsprop_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class RMSPROPNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *grad_var = ctx.InputVar("Grad"); + auto *param_out = ctx.Output("ParamOut"); + auto *moment_out = ctx.Output("MomentOut"); + auto *mean_square_out = ctx.Output("MeanSquareOut"); + + param_out->mutable_data(ctx.GetPlace()); + moment_out->mutable_data(ctx.GetPlace()); + mean_square_out->mutable_data(ctx.GetPlace()); + + auto epsilon = static_cast(ctx.Attr("epsilon")); + auto rho = static_cast(ctx.Attr("decay")); + auto momentum = static_cast(ctx.Attr("momentum")); + auto *p_tensor = ctx.Input("Param"); + auto *ms_tensor = ctx.Input("MeanSquare"); + auto *lr_tensor = ctx.Input("LearningRate"); + auto *mom_tensor = ctx.Input("Moment"); + bool centered = ctx.Attr("centered"); + + auto stream = + ctx.template device_context() + .stream(); + if (grad_var->IsType()) { + auto *grad_tensor = ctx.Input("Grad"); + if (centered) { + framework::NPUAttributeMap attr_input = {{"use_locking", false}}; + const Tensor *rho_tensor = nullptr; + const Tensor *momentum_tensor = nullptr; + const Tensor *epsilon_tensor = nullptr; + Tensor rho_tmp(framework::proto::VarType::FP32); + rho_tmp.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&rho_tmp, rho); + rho_tensor = &rho_tmp; + Tensor momentum_tmp(framework::proto::VarType::FP32); + momentum_tmp.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&momentum_tmp, momentum); + momentum_tensor = &momentum_tmp; + Tensor epsilon_tmp(framework::proto::VarType::FP32); + epsilon_tmp.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&epsilon_tmp, epsilon); + epsilon_tensor = &epsilon_tmp; + auto *mg_tensor = 
ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + mean_grad_out->mutable_data(ctx.GetPlace()); + const auto &runner_applycenterrmsprop = NpuOpRunner( + std::string("ApplyCenteredRMSPropD"), + {*p_tensor, *mg_tensor, *ms_tensor, *mom_tensor, *lr_tensor, + *rho_tensor, *momentum_tensor, *epsilon_tensor, *grad_tensor}, + {*param_out, *mean_grad_out, *mean_square_out, *moment_out}, + {attr_input}); + runner_applycenterrmsprop.Run(stream); + } else { + framework::NPUAttributeMap attr_input = { + {"rho", rho}, {"momentum", momentum}, {"epsilon", epsilon}}; + const auto &runner_applyrmsprop = NpuOpRunner( + std::string("ApplyRMSPropD"), + {*p_tensor, *ms_tensor, *mom_tensor, *lr_tensor, *grad_tensor}, + {*param_out, *mean_square_out, *moment_out}, {attr_input}); + runner_applyrmsprop.Run(stream); + } + } else { + PADDLE_ENFORCE_EQ(false, true, + platform::errors::PermissionDenied( + "Unsupported Variable Type of Grad " + "in RmspropOp. Excepted LodTensor, " + "But received [%s]", + paddle::framework::ToTypeName(grad_var->Type()))); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + rmsprop, ops::RMSPROPNPUKernel) diff --git a/python/paddle/fluid/tests/unittests/npu/test_rmsprop_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_rmsprop_op_npu.py new file mode 100644 index 0000000000..8bdf841c5c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_rmsprop_op_npu.py @@ -0,0 +1,152 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +import paddle + +paddle.enable_static() +SEED = 2021 + + +class TestNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.pow(sum, 2.0) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + rmsprop = fluid.optimizer.RMSProp(learning_rate=0.01) + rmsprop.minimize(loss) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, 
loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + cpu_pred, cpu_loss = self._test(False) + npu_pred, npu_loss = self._test(True) + + self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-3)) + self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-3)) + + +class TestCenteredNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.pow(sum, 2.0) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + rmsprop = fluid.optimizer.RMSProp(learning_rate=0.01, centered=True) + rmsprop.minimize(loss) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + 
epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + cpu_pred, cpu_loss = self._test(False) + npu_pred, npu_loss = self._test(True) + + self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-3)) + self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-3)) + + +if __name__ == "__main__": + unittest.main() -- GitLab