From 1607e87cb9ee09d48c5b2eabadbc6df0c2f5fa03 Mon Sep 17 00:00:00 2001
From: Chengmo
Date: Tue, 13 Oct 2020 23:00:02 +0800
Subject: [PATCH] add xpu sgd & momentum (#27728)

* add xpu sgd & momentum
---
 .../operators/optimizers/momentum_op_xpu.cc   | 62 +++++++++++++++
 .../fluid/operators/optimizers/sgd_op_xpu.cc  | 79 +++++++++++++++++++
 .../unittests/xpu/test_momentum_op_xpu.py     | 68 ++++++++++++++++
 .../tests/unittests/xpu/test_sgd_op_xpu.py    | 75 ++++++++++++++++++
 4 files changed, 284 insertions(+)
 create mode 100644 paddle/fluid/operators/optimizers/momentum_op_xpu.cc
 create mode 100644 paddle/fluid/operators/optimizers/sgd_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py

diff --git a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
new file mode 100644
index 00000000000..38b06c39816
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc
@@ -0,0 +1,62 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#ifdef PADDLE_WITH_XPU
+#include <string>
+#include "paddle/fluid/operators/optimizers/sgd_op.h"
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class MomentumOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
+
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto* velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    param_out->mutable_data<T>(ctx.GetPlace());
+    velocity_out->mutable_data<T>(ctx.GetPlace());
+    auto* lr = learning_rate->data<T>();
+
+    auto* grad_var = ctx.InputVar("Grad");
+    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::PermissionDenied(
+                          "Unsupported Variable Type of Grad in "
+                          "MomentumOp-XPU. Expected LoDTensor, "
+                          "but received [%s]",
+                          paddle::framework::ToTypeName(grad_var->Type())));
+
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    int r = xpu::momentum(
+        dev_ctx.x_context(), param->data<T>(), velocity->data<T>(),
+        grad->data<T>(), lr, use_nesterov, mu, param_out->numel(),
+        param_out->data<T>(), velocity_out->data<T>());
+    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
+                      platform::errors::PermissionDenied("XPU kernel error!"));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    momentum,
+    ops::MomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+#endif
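
For reference, the update rule this kernel delegates to xpu::momentum can be
sketched in NumPy. This is a minimal sketch, assuming the device-side
semantics match Paddle's CPU momentum reference (the same formulas the unit
test later in this patch checks against); the helper name momentum_update is
illustrative, and the authoritative implementation lives in the XPU runtime.

    import numpy as np

    def momentum_update(param, grad, velocity, lr, mu, use_nesterov=False):
        # Accumulate the damped velocity first.
        velocity_out = mu * velocity + grad
        if use_nesterov:
            # Nesterov look-ahead: step along grad plus the damped velocity.
            param_out = param - (grad + mu * velocity_out) * lr
        else:
            # Classic heavy-ball update.
            param_out = param - lr * velocity_out
        return param_out, velocity_out
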
diff --git a/paddle/fluid/operators/optimizers/sgd_op_xpu.cc b/paddle/fluid/operators/optimizers/sgd_op_xpu.cc
new file mode 100644
index 00000000000..1d78e561101
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/sgd_op_xpu.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/operators/optimizers/sgd_op.h"
+#include <string>
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class SGDOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
+    const auto *param_var = ctx.InputVar("Param");
+    const auto *grad_var = ctx.InputVar("Grad");
+
+    if (param_var->IsType<framework::LoDTensor>() &&
+        grad_var->IsType<framework::LoDTensor>()) {
+      const auto *param = ctx.Input<framework::Tensor>("Param");
+      auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
+      // Actually, all tensors are LoDTensor except SelectedRows.
+      const auto *grad = ctx.Input<framework::Tensor>("Grad");
+      auto sz = param_out->numel();
+      PADDLE_ENFORCE_EQ(param->numel(), sz,
+                        platform::errors::InvalidArgument(
+                            "The input tensor Param's numel of SgdOp "
+                            "should be equal to ParamOut's numel. "
+                            "But received Param's "
+                            "numel = [%s], ParamOut's numel = [%s]",
+                            param->numel(), sz));
+      PADDLE_ENFORCE_EQ(grad->numel(), sz,
+                        platform::errors::InvalidArgument(
+                            "The input tensor Grad's numel of SgdOp "
+                            "should be equal to ParamOut's numel. "
+                            "But received Grad's "
+                            "numel = [%s], ParamOut's numel = [%s]",
+                            grad->numel(), sz));
+
+      const T *lr = learning_rate->data<T>();
+      const T *param_data = param->data<T>();
+      const T *grad_data = grad->data<T>();
+      T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
+
+      auto &dev_ctx = ctx.template device_context<DeviceContext>();
+      int r = xpu::sgd(dev_ctx.x_context(), sz, grad_data, param_data, lr,
+                       out_data);
+      PADDLE_ENFORCE_EQ(
+          r, xpu::Error_t::SUCCESS,
+          platform::errors::PermissionDenied("XPU kernel error!"));
+    } else {
+      PADDLE_ENFORCE_EQ(false, true,
+                        platform::errors::PermissionDenied(
+                            "Unsupported Variable Type of Param & Grad in "
+                            "SgdOp-XPU. Expected LoDTensor, but received [%s] and [%s]",
+                            paddle::framework::ToTypeName(param_var->Type()),
+                            paddle::framework::ToTypeName(grad_var->Type())));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    sgd, ops::SGDOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
+#endif
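
The dense branch above is a flat elementwise update: xpu::sgd receives a
buffer of sz = numel() elements, so tensor shape is irrelevant as long as the
element counts agree, which is exactly the invariant the two
PADDLE_ENFORCE_EQ checks guard. A minimal NumPy sketch of the same update
(illustrative reference code, not part of the patch):

    import numpy as np

    def sgd_update(param, grad, lr):
        # The XPU kernel sees flat buffers; only element counts must match.
        assert param.size == grad.size
        return param - lr * grad

    w = np.random.random((102, 105)).astype(np.float32)
    g = np.random.random((102, 105)).astype(np.float32)
    assert np.allclose(sgd_update(w, g, np.float32(0.1)), w - np.float32(0.1) * g)

Sparse (SelectedRows) gradients fall into the else branch and are rejected,
so only dense LoDTensor parameters and gradients are supported on XPU here.
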
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py
new file mode 100644
index 00000000000..ccee79e8cd7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_momentum_op_xpu.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+import os
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+from paddle.fluid import core
+from paddle.fluid.op import Operator
+
+
+class TestMomentumOp1(OpTest):
+    def setUp(self):
+        self.op_type = "momentum"
+        self.dtype = np.float32
+        self.init_dtype()
+
+        param = np.random.random((123, 321)).astype(self.dtype)
+        grad = np.random.random((123, 321)).astype(self.dtype)
+        velocity = np.zeros((123, 321)).astype(self.dtype)
+        learning_rate = np.array([0.001]).astype(self.dtype)
+        mu = 0.0001
+        use_nesterov = False
+
+        self.inputs = {
+            'Param': param,
+            'Grad': grad,
+            'Velocity': velocity,
+            'LearningRate': learning_rate
+        }
+
+        self.attrs = {'mu': mu}
+
+        velocity_out = mu * velocity + grad
+        if use_nesterov:
+            param_out = param - grad * learning_rate - \
+                velocity_out * mu * learning_rate
+        else:
+            param_out = param - learning_rate * velocity_out
+
+        self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
+
+    def init_dtype(self):
+        pass
+
+    def test_check_output_with_place(self):
+        self.check_output_with_place(paddle.XPUPlace(0))
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
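
The expected outputs built in TestMomentumOp1 follow directly from the update
rule. With a zero initial velocity, velocity_out reduces to grad and
param_out to a plain SGD step, which allows a quick standalone sanity check
of the reference math (a CPU-only sketch mirroring the tensors built in
setUp; no XPU is required to run it):

    import numpy as np

    param = np.random.random((123, 321)).astype(np.float32)
    grad = np.random.random((123, 321)).astype(np.float32)
    velocity = np.zeros((123, 321)).astype(np.float32)
    lr, mu = np.float32(0.001), np.float32(0.0001)

    velocity_out = mu * velocity + grad
    param_out = param - lr * velocity_out
    assert np.allclose(velocity_out, grad)            # zero initial velocity
    assert np.allclose(param_out, param - lr * grad)  # reduces to plain SGD

Note that this suite only exercises use_nesterov=False, so the Nesterov
branch of the XPU kernel is left uncovered here.
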
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py
new file mode 100644
index 00000000000..c29150ef921
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+import os
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+from paddle.fluid.op import Operator
+
+
+class TestSGDOp(OpTest):
+    def setUp(self):
+        self.op_type = "sgd"
+        self.conf()
+        w = np.random.random((self.h, self.w)).astype("float32")
+        g = np.random.random((self.h, self.w)).astype("float32")
+        lr = np.array([0.1]).astype("float32")
+
+        self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
+        self.outputs = {'ParamOut': w - lr * g}
+
+    def conf(self):
+        self.h = 102
+        self.w = 105
+
+    def test_check_output_with_place(self):
+        self.check_output_with_place(paddle.XPUPlace(0))
+
+
+class TestSGDOpCase8X(TestSGDOp):
+    def conf(self):
+        self.h = 10
+        self.w = 64
+
+
+class TestSGDOpWithLargeInput(unittest.TestCase):
+    def runTest(self):
+        data = fluid.layers.fill_constant(shape=[1], value=128, dtype='int64')
+        label = fluid.layers.fill_constant(
+            shape=[1, 150], value=0.5, dtype='float32')
+        emb = fluid.embedding(input=data, size=(10000, 150), dtype='float32')
+        out = fluid.layers.l2_normalize(x=emb, axis=-1)
+
+        cost = fluid.layers.square_error_cost(input=out, label=label)
+        avg_cost = fluid.layers.mean(cost)
+        sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+        sgd_optimizer.minimize(avg_cost)
+
+        place = paddle.XPUPlace(0)
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        result = exe.run(fluid.default_main_program(), fetch_list=[avg_cost])
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
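
All of these suites require XPU device 0 (paddle.XPUPlace(0)) at runtime. A
guard one might add so they skip cleanly on builds without XPU support, as a
sketch assuming paddle.is_compiled_with_xpu() is available in this build of
Paddle (neither the guard nor the class name TestXPUGuardExample is part of
this patch):

    import unittest
    import paddle

    # Hypothetical guard: skips the whole case on non-XPU builds.
    @unittest.skipIf(not paddle.is_compiled_with_xpu(),
                     "core is not compiled with XPU")
    class TestXPUGuardExample(unittest.TestCase):
        def test_place(self):
            self.assertIsNotNone(paddle.XPUPlace(0))

--
GitLab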