From 33c277e6d2136fc41f2f5022894337c5ea39425a Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Tue, 17 Mar 2020 10:56:03 +0800 Subject: [PATCH] add sgd op (#3187) * add sgd op, test=develop * test=develop --- lite/kernels/arm/CMakeLists.txt | 1 + lite/kernels/arm/sgd_compute.cc | 54 +++++++++++++++ lite/kernels/arm/sgd_compute.h | 38 +++++++++++ lite/operators/CMakeLists.txt | 1 + lite/operators/sgd_op.cc | 5 +- lite/tests/kernels/CMakeLists.txt | 1 + lite/tests/kernels/sgd_compute_test.cc | 95 ++++++++++++++++++++++++++ 7 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 lite/kernels/arm/sgd_compute.cc create mode 100644 lite/kernels/arm/sgd_compute.h create mode 100644 lite/tests/kernels/sgd_compute_test.cc diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 15b48274d3..514d6069b5 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de if(LITE_WITH_TRAIN) add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) + add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) endif() lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) diff --git a/lite/kernels/arm/sgd_compute.cc b/lite/kernels/arm/sgd_compute.cc new file mode 100644 index 0000000000..8f045fca8f --- /dev/null +++ b/lite/kernels/arm/sgd_compute.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/sgd_compute.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void SGDCompute::Run() {
+  auto& param = this->Param<param_t>();
+  const auto* parameter = param.Param;
+  const auto* grad = param.Grad;
+  const auto* lr_tensor = param.LearningRate;
+  auto* parameter_output = param.ParamOut;
+
+  auto dims = parameter->dims();
+  auto parameter_data = parameter->data<float>();
+  auto grad_data = grad->data<float>();
+  auto lr = *(lr_tensor->data<float>());
+  auto parameter_out_data = parameter_output->mutable_data<float>();
+
+  int element_num = dims.production();
+#pragma omp parallel for
+  for (int i = 0; i < element_num; i++) {
+    parameter_out_data[i] = parameter_data[i] - lr * grad_data[i];
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    sgd, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SGDCompute, def)
+    .BindInput("Param", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("LearningRate", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("ParamOut", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/sgd_compute.h b/lite/kernels/arm/sgd_compute.h
new file mode 100644
index 0000000000..bb3e3931da
--- /dev/null
+++ b/lite/kernels/arm/sgd_compute.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class SGDCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SGDParam;
+
+  SGDCompute() = default;
+
+  void Run() override;
+
+  virtual ~SGDCompute() = default;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index be1e7aaa14..34c7b8d666 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
 if (LITE_WITH_TRAIN)
   add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS})
   add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS})
+  add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS})
 endif()
 
 if (NOT LITE_WITH_X86)
diff --git a/lite/operators/sgd_op.cc b/lite/operators/sgd_op.cc
index cf387def24..6214542595 100644
--- a/lite/operators/sgd_op.cc
+++ b/lite/operators/sgd_op.cc
@@ -25,11 +25,12 @@ bool SGDOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.LearningRate);
   CHECK_OR_FALSE(param_.Grad);
   CHECK_OR_FALSE(param_.ParamOut);
+  CHECK_EQ_OR_FALSE(param_.LearningRate->dims().production(), 1);
+  CHECK_EQ_OR_FALSE(param_.Param->dims(), param_.Grad->dims());
   return true;
 }
 
 bool SGDOpLite::InferShape() const {
-  auto lr_dims = param_.LearningRate->dims().data();
   param_.ParamOut->Resize(param_.Param->dims());
   return true;
 }
@@ -38,6 +39,8 @@ bool SGDOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   auto Param_name = opdesc.Input("Param").front();
   auto LearningRate_name = opdesc.Input("LearningRate").front();
   auto Grad_name = opdesc.Input("Grad").front();
+  // param_out and param usually have the same name,
+  // and share the same memory
   auto ParamOut_name = opdesc.Output("ParamOut").front();
 
   param_.Param = GetVar<lite::Tensor>(scope, Param_name);
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 8f3506c454..4ecab783a1 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
   if (LITE_WITH_TRAIN)
     lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   endif()
 endif()
diff --git a/lite/tests/kernels/sgd_compute_test.cc b/lite/tests/kernels/sgd_compute_test.cc
new file mode 100644
index 0000000000..687494ed34
--- /dev/null
+++ b/lite/tests/kernels/sgd_compute_test.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
+
+namespace paddle {
+namespace lite {
+
+class SGDComputeTester : public arena::TestCase {
+ protected:
+  std::string param_ = "param";
+  std::string param_out_ = "param_out";
+  std::string grad_ = "grad";
+  std::string lr_ = "learning_rate";
+  float learning_rate_ = 0.01;
+  DDim dims_{{2, 5}};
+
+ public:
+  SGDComputeTester(const Place& place,
+                   const std::string& alias,
+                   DDim dims,
+                   float learning_rate)
+      : TestCase(place, alias), dims_(dims), learning_rate_(learning_rate) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto param = scope->FindTensor(param_);
+    auto grad = scope->FindTensor(grad_);
+    auto lr = scope->FindTensor(lr_);
+    auto param_out = scope->NewTensor(param_out_);
+    CHECK(param_out);
+
+    auto param_data = param->data<float>();
+    auto grad_data = grad->data<float>();
+    auto lr_data = *lr->data<float>();
+
+    param_out->Resize(dims_);
+    auto param_out_data = param_out->mutable_data<float>();
+
+    for (int i = 0; i < dims_.production(); i++) {
+      param_out_data[i] = param_data[i] - lr_data * grad_data[i];
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("sgd");
+    op_desc->SetInput("Param", {param_});
+    op_desc->SetInput("Grad", {grad_});
+    op_desc->SetInput("LearningRate", {lr_});
+    op_desc->SetOutput("ParamOut", {param_out_});
+  }
+
+  void PrepareData() override {
+    std::vector<float> param_data(dims_.production());
+    fill_data_rand(param_data.data(), -1.f, 1.f, dims_.production());
+    SetCommonTensor(param_, dims_, param_data.data());
+
+    std::vector<float> grad_data(dims_.production());
+    fill_data_rand(grad_data.data(), -1.f, 1.f, dims_.production());
+    SetCommonTensor(grad_, dims_, grad_data.data());
+
+    std::vector<float> lr_data(1);
+    lr_data[0] = learning_rate_;
+    SetCommonTensor(lr_, DDim{{1}}, lr_data.data());
+  }
+};
+
+TEST(sgd, precision) {
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  std::vector<int64_t> dims{3, 2, 4, 1};
+  float lr = 0.01;
+  std::unique_ptr<arena::TestCase> tester(
+      new SGDComputeTester(place, "def", DDim(dims), lr));
+  arena::Arena arena(std::move(tester), place, 2e-5);
+  arena.TestPrecision();
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
-- 
GitLab