Unverified commit 33c277e6 authored by mapingshuo, committed by GitHub

add sgd op (#3187)

* add sgd op, test=develop

* test=develop
Parent e579287a
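For context, the op implements vanilla (momentum-free) SGD, updating every parameter element as

    $\theta_{t+1} = \theta_t - \eta \, \nabla_\theta L(\theta_t)$

where $\eta$ is the scalar learning rate read from the LearningRate tensor.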
lite/kernels/arm/CMakeLists.txt:
@@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de
 if(LITE_WITH_TRAIN)
   add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
   add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
+  add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
 endif()
 lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
......
lite/kernels/arm/sgd_compute.cc (new file):

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/sgd_compute.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void SGDCompute::Run() {
auto& param = this->Param<param_t>();
const auto* parameter = param.Param;
const auto* grad = param.Grad;
const auto* lr_tensor = param.LearningRate;
auto* parameter_output = param.ParamOut;
auto dims = parameter->dims();
auto parameter_data = parameter->data<float>();
auto grad_data = grad->data<float>();
auto lr = *(lr_tensor->data<float>());
auto parameter_out_data = parameter_output->mutable_data<float>();
int element_num = dims.production();
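  // Element-wise vanilla SGD update: param_out[i] = param[i] - lr * grad[i].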
#pragma omp parallel for
for (int i = 0; i < element_num; i++) {
parameter_out_data[i] = parameter_data[i] - lr * grad_data[i];
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
sgd, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SGDCompute, def)
.BindInput("Param", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Grad", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("LearningRate", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("ParamOut", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
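For a quick sanity check of the update rule, here is a minimal standalone sketch of the same arithmetic (plain C++, independent of the Lite API; all names are illustrative):

#include <cassert>
#include <cmath>

// Mirrors the kernel's element-wise update: out[i] = param[i] - lr * grad[i].
void sgd_step(const float* param, const float* grad, float lr,
              float* param_out, int n) {
  for (int i = 0; i < n; ++i) {
    param_out[i] = param[i] - lr * grad[i];
  }
}

int main() {
  const float param[2] = {1.0f, 2.0f};
  const float grad[2] = {0.5f, -0.5f};
  float out[2];
  sgd_step(param, grad, 0.01f, out, 2);
  assert(std::fabs(out[0] - 0.995f) < 1e-6f);  // 1.0 - 0.01 * 0.5
  assert(std::fabs(out[1] - 2.005f) < 1e-6f);  // 2.0 - 0.01 * (-0.5)
  return 0;
}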
lite/kernels/arm/sgd_compute.h (new file):

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SGDCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::SGDParam;
SGDCompute() = default;
void Run() override;
virtual ~SGDCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
lite/operators/CMakeLists.txt:
@@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS})
 if (LITE_WITH_TRAIN)
   add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS})
   add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS})
+  add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS})
 endif()
 if (NOT LITE_WITH_X86)
......
lite/operators/sgd_op.cc:
@@ -25,11 +25,12 @@ bool SGDOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.LearningRate);
   CHECK_OR_FALSE(param_.Grad);
   CHECK_OR_FALSE(param_.ParamOut);
+  CHECK_EQ_OR_FALSE(param_.LearningRate->dims().production(), 1);
+  CHECK_EQ_OR_FALSE(param_.Param->dims(), param_.Grad->dims());
   return true;
 }
 
 bool SGDOpLite::InferShape() const {
-  auto lr_dims = param_.LearningRate->dims().data();
   param_.ParamOut->Resize(param_.Param->dims());
   return true;
 }
@@ -38,6 +39,8 @@ bool SGDOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   auto Param_name = opdesc.Input("Param").front();
   auto LearningRate_name = opdesc.Input("LearningRate").front();
   auto Grad_name = opdesc.Input("Grad").front();
+  // param_out and param usually have the same name,
+  // and share the same memory
   auto ParamOut_name = opdesc.Output("ParamOut").front();
   param_.Param = GetVar<lite::Tensor>(scope, Param_name);
......
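The comment added above has a practical consequence: ParamOut usually aliases Param (same variable in the scope, hence the same buffer), so the kernel effectively updates the parameter in place. That is safe because each element is read exactly once before being overwritten, as in this equivalent in-place sketch (illustrative C++, not the Lite API):

// In-place form of the same update; valid even when output aliases input.
void sgd_step_inplace(float* param, const float* grad, float lr, int n) {
  for (int i = 0; i < n; ++i) {
    param[i] -= lr * grad[i];
  }
}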
lite/tests/kernels/CMakeLists.txt:
@@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA)
 if (LITE_WITH_TRAIN)
   lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
 endif()
......
lite/tests/kernels/sgd_compute_test.cc (new file):

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class SGDComputeTester : public arena::TestCase {
protected:
std::string param_ = "param";
std::string param_out_ = "param_out";
std::string grad_ = "grad";
std::string lr_ = "learning_rate";
float learning_rate_ = 0.01;
DDim dims_{{2, 5}};
public:
SGDComputeTester(const Place& place,
const std::string& alias,
DDim dims,
float learning_rate)
      : TestCase(place, alias), learning_rate_(learning_rate), dims_(dims) {}
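  // Reference implementation on the host: param_out = param - lr * grad.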
void RunBaseline(Scope* scope) override {
auto param = scope->FindTensor(param_);
auto grad = scope->FindTensor(grad_);
auto lr = scope->FindTensor(lr_);
auto param_out = scope->NewTensor(param_out_);
CHECK(param_out);
auto param_data = param->data<float>();
auto grad_data = grad->data<float>();
auto lr_data = *lr->data<float>();
param_out->Resize(dims_);
auto param_out_data = param_out->mutable_data<float>();
for (int i = 0; i < dims_.production(); i++) {
param_out_data[i] = param_data[i] - lr_data * grad_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("sgd");
op_desc->SetInput("Param", {param_});
op_desc->SetInput("Grad", {grad_});
op_desc->SetInput("LearningRate", {lr_});
op_desc->SetOutput("ParamOut", {param_out_});
}
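  // Fills Param and Grad with uniform random values in [-1, 1] and sets
  // the scalar learning-rate tensor.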
void PrepareData() override {
std::vector<float> param_data(dims_.production());
fill_data_rand(param_data.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(param_, dims_, param_data.data());
std::vector<float> grad_data(dims_.production());
fill_data_rand(grad_data.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(grad_, dims_, grad_data.data());
std::vector<float> lr_data(1);
lr_data[0] = learning_rate_;
SetCommonTensor(lr_, DDim{{1}}, lr_data.data());
}
};
TEST(sgd, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
std::vector<int64_t> dims{3, 2, 4, 1};
float lr = 0.01;
std::unique_ptr<arena::TestCase> tester(
new SGDComputeTester(place, "def", DDim(dims), lr));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
} // namespace lite
} // namespace paddle