diff --git a/CMakeLists.txt b/CMakeLists.txt index b4bfe5981e5c3f524c6d781665c100e33b0713ca..bf1d35bc51e16cc2975edd905a2c938a4eb1af83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,7 @@ lite_option(LITE_WITH_ARM "Enable ARM in lite mode" OFF) lite_option(LITE_WITH_NPU "Enable NPU in lite mode" OFF) lite_option(LITE_WITH_XPU "Enable XPU in lite mode" OFF) lite_option(LITE_WITH_BM "Enable BM in lite mode" OFF) +lite_option(LITE_WITH_TRAIN "Enable training operators and kernels in lite" OFF) lite_option(LITE_WITH_OPENMP "Enable OpenMP in lite framework" ON) lite_option(LITE_WITH_OPENCL "Enable OpenCL support in lite" OFF) lite_option(LITE_WITH_FPGA "Enable FPGA support in lite" OFF) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 752b22461d9d1c36b3ca6a0bfe472a5dcc3ab976..d38c78f62fa2bed4f4483355de0683f1f5b7656b 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -122,6 +122,9 @@ if (LITE_WITH_ARM) endif() endif() +if (LITE_WITH_TRAIN) + add_definitions("-DLITE_WITH_TRAIN") +endif() if (WITH_ARM_DOTPROD) add_definitions("-DWITH_ARM_DOTPROD") diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index f6f7ec75e65ff54e3f3642822e51057d3522ae3a..b739c78f7c883d62b39d88ae1a7f4bf76ae8932c 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -333,16 +333,16 @@ lite::Tensor *Predictor::GetInputByName(const std::string &name) { } } -#ifdef LITE_WITH_TRAIN -void Predictor::FeedVars(const std::vector &tensors) { - auto var = scope_->FindVar("feed"); - auto &feed_list = *(var->GetMutable>()); - feed_list.resize(tensors.size()); - - for (size_t i = 0; i < tensors.size(); ++i) - feed_list[i].ShareDataWith(tensors[i]); -} -#endif +// #ifdef LITE_WITH_TRAIN +// void Predictor::FeedVars(const std::vector &tensors) { +// auto var = scope_->FindVar("feed"); +// auto &feed_list = *(var->GetMutable>()); +// feed_list.resize(tensors.size()); + +// for (size_t i = 0; i < tensors.size(); ++i) +// feed_list[i].ShareDataWith(tensors[i]); +// } +// #endif } // namespace lite } // namespace paddle diff --git a/lite/api/cxx_api.h b/lite/api/cxx_api.h index 504710d9fa29420b8762f31e0c675b59c6c626bd..e63893cb91e112beb6be50bd661a57b9738e5fb1 100644 --- a/lite/api/cxx_api.h +++ b/lite/api/cxx_api.h @@ -101,14 +101,14 @@ class LITE_API Predictor { bool record_info = false); void SaveOpKernelInfo(const std::string& model_dir); -#ifdef LITE_WITH_TRAIN - void Run(const std::vector& tensors) { - FeedVars(tensors); - program_->Run(); - } - - void FeedVars(const std::vector& tensors); -#endif + // #ifdef LITE_WITH_TRAIN + // void Run(const std::vector& tensors) { + // FeedVars(tensors); + // program_->Run(); + // } + + // void FeedVars(const std::vector& tensors); + // #endif private: Optimizer optimizer_; diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 49078844354c8597ea489a986826edbdbcc9eb62..26ae22ce9d27cffcb6adf53ca16b01181edddf9e 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -105,6 +105,11 @@ add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm) +# 4. training kernels +add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_deps} math_arm) +if(LITE_WITH_TRAIN) + add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) +endif() lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) diff --git a/lite/kernels/arm/mean_compute.cc b/lite/kernels/arm/mean_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..5d1a5b8508c4cd7bd36bf7809c445e57bbef428e --- /dev/null +++ b/lite/kernels/arm/mean_compute.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/mean_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void MeanCompute::Run() { + auto& param = this->Param(); + const auto* input = param.X; + auto* output = param.Out; + auto x_dim = input->dims(); + auto x_data = input->data(); + auto out_data = output->mutable_data(); + + int x_size = x_dim.production(); + float sum = 0; + for (int i = 0; i < x_size; i++) { + sum += x_data[i]; + } + out_data[0] = sum / x_size; +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + mean, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::MeanCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/mean_compute.h b/lite/kernels/arm/mean_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..ba4650c6c84e270ba9b174cc09fba4bd63b486f5 --- /dev/null +++ b/lite/kernels/arm/mean_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/operators/mean_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class MeanCompute : public KernelLite { + public: + using param_t = operators::MeanParam; + + void Run() override; + + virtual ~MeanCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/arm/mean_grad_compute.cc b/lite/kernels/arm/mean_grad_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7a5be8be1ebd4e02a188ab40026de04b319c76e --- /dev/null +++ b/lite/kernels/arm/mean_grad_compute.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/mean_grad_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void MeanGradCompute::Run() { + auto& param = this->Param(); + const auto* input = param.X; + const auto* out_grad = param.Out_grad; + auto* input_grad = param.X_grad; + + auto out_grad_data = out_grad->data(); + auto input_data = input->data(); + auto input_grad_data = input_grad->mutable_data(); + + int input_grad_size = input_grad->dims().production(); + + // TODO(mapingshuo): use parallel methods to accelerate this for loop + for (int i = 0; i < input_grad_size; i++) { + input_grad_data[i] = out_grad_data[0] / input_grad_size; + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(mean_grad, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::MeanGradCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/mean_grad_compute.h b/lite/kernels/arm/mean_grad_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..49a72e9520df95082ba0a518b664b21613baf153 --- /dev/null +++ b/lite/kernels/arm/mean_grad_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/operators/mean_grad_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class MeanGradCompute : public KernelLite { + public: + using param_t = operators::MeanGradParam; + + void Run() override; + + virtual ~MeanGradCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 512473c7d52d0f43b01611048012a8c54e2b7244..36d69e68b5c43b2da3fd3794186cb8c46a78dabc 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -18,7 +18,6 @@ add_operator(activation_ops basic SRCS activation_ops.cc DEPS ${op_DEPS}) add_operator(elementwise_ops basic SRCS elementwise_ops.cc DEPS ${op_DEPS}) add_operator(box_coder_op_lite basic SRCS box_coder_op.cc DEPS ${op_DEPS}) add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DEPS}) -add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS}) add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS}) add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS}) add_operator(shuffle_channel_op basic SRCS shuffle_channel_op.cc DEPS ${op_DEPS}) @@ -139,6 +138,12 @@ add_operator(sequence_topk_avg_pooling_op basic SRCS sequence_topk_avg_pooling_o add_operator(search_fc_op basic SRCS search_fc_op.cc DEPS ${op_DEPS}) add_operator(lstm_op extra SRCS lstm_op.cc DEPS ${op_DEPS}) +# 4. training op +add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) +if (LITE_WITH_TRAIN) + add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) +endif() + if (NOT LITE_WITH_X86) lite_cc_test(test_fc_op SRCS fc_op_test.cc DEPS fc_op memory diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc index 4ccb1be2a96a06ec56adacde64b963c9038871d1..0e31d758b728dc9d87b07b3f0833e0512b38a19d 100644 --- a/lite/operators/activation_ops.cc +++ b/lite/operators/activation_ops.cc @@ -78,45 +78,45 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { return true; } -#ifdef LITE_WITH_TRAIN - -bool ActivationGradOp::CheckShape() const { - CHECK_OR_FALSE(param_.X_grad); - CHECK_OR_FALSE(param_.Out_grad); - return true; -} - -bool ActivationGradOp::InferShape() const { - param_.X_grad->Resize(param_.Out_grad->dims()); - return true; -} - -bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc, - lite::Scope* scope) { - auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); - auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); - - param_.Out_grad = GetVar(scope, Out_grad_name); - param_.X_grad = GetMutableVar(scope, X_grad_name); - - if (opdesc.HasInput("X")) { - auto X_name = opdesc.Input("X").front(); - param_.X = GetVar(scope, X_name); - } else { - param_.X = param_.X_grad; - } - - if (opdesc.HasInput("Out")) { - auto Out_name = opdesc.Input("Out").front(); - param_.Out = GetVar(scope, Out_name); - } else { - param_.Out = param_.Out_grad; - } - - return true; -} - -#endif +// #ifdef LITE_WITH_TRAIN + +// bool ActivationGradOp::CheckShape() const { +// CHECK_OR_FALSE(param_.X_grad); +// CHECK_OR_FALSE(param_.Out_grad); +// return true; +// } + +// bool ActivationGradOp::InferShape() const { +// param_.X_grad->Resize(param_.Out_grad->dims()); +// return true; +// } + +// bool ActivationGradOp::AttachImpl(const cpp::OpDesc& opdesc, +// lite::Scope* scope) { +// auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); +// auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); + +// param_.Out_grad = GetVar(scope, Out_grad_name); +// param_.X_grad = GetMutableVar(scope, X_grad_name); + +// if (opdesc.HasInput("X")) { +// auto X_name = opdesc.Input("X").front(); +// param_.X = GetVar(scope, X_name); +// } else { +// param_.X = param_.X_grad; +// } + +// if (opdesc.HasInput("Out")) { +// auto Out_name = opdesc.Input("Out").front(); +// param_.Out = GetVar(scope, Out_name); +// } else { +// param_.Out = param_.Out_grad; +// } + +// return true; +// } + +// #endif } // namespace operators } // namespace lite @@ -139,6 +139,6 @@ REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(gelu, paddle::lite::operators::ActivationOp); -#ifdef LITE_WITH_TRAIN -REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); -#endif +// #ifdef LITE_WITH_TRAIN +// REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); +// #endif diff --git a/lite/operators/activation_ops.h b/lite/operators/activation_ops.h index 7ff91f7bcd2dce0fbdc4b5e8e4573ecc52387d72..87e1f8e72c74a9ac63eb83ca2f2f4c72f4d80e53 100644 --- a/lite/operators/activation_ops.h +++ b/lite/operators/activation_ops.h @@ -38,25 +38,26 @@ class ActivationOp : public OpLite { mutable operators::ActivationParam param_; }; -#ifdef LITE_WITH_TRAIN -class ActivationGradOp : public OpLite { - public: - explicit ActivationGradOp(const std::string& type) : OpLite(type) {} +// #ifdef LITE_WITH_TRAIN +// class ActivationGradOp : public OpLite { +// public: +// explicit ActivationGradOp(const std::string& type) : OpLite(type) {} - bool CheckShape() const override; +// bool CheckShape() const override; - bool InferShape() const override; +// bool InferShape() const override; - bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; +// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } +// void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); +// } - std::string DebugString() const override { return "activation_grad_op"; } +// std::string DebugString() const override { return "activation_grad_op"; } - private: - mutable operators::ActivationGradParam param_; -}; -#endif +// private: +// mutable operators::ActivationGradParam param_; +// }; +// #endif } // namespace operators } // namespace lite diff --git a/lite/operators/elementwise_ops.cc b/lite/operators/elementwise_ops.cc index c5613ae52d3aadd6d863cadecac1870c40ed4cbf..3dc6f06955d421bc1f25994139cfee5dee9bc472 100644 --- a/lite/operators/elementwise_ops.cc +++ b/lite/operators/elementwise_ops.cc @@ -96,39 +96,39 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { return true; } -#ifdef LITE_WITH_TRAIN -bool ElementwiseGradExplicitOp::CheckShape() const { - CHECK_OR_FALSE(param_.Y); - CHECK_OR_FALSE(param_.X_grad); - CHECK_OR_FALSE(param_.Out_grad); - return true; -} +// #ifdef LITE_WITH_TRAIN +// bool ElementwiseGradExplicitOp::CheckShape() const { +// CHECK_OR_FALSE(param_.Y); +// CHECK_OR_FALSE(param_.X_grad); +// CHECK_OR_FALSE(param_.Out_grad); +// return true; +//} -bool ElementwiseGradExplicitOp::InferShape() const { - param_.X_grad->Resize(param_.Out_grad->dims()); - if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); - return true; -} +// bool ElementwiseGradExplicitOp::InferShape() const { +// param_.X_grad->Resize(param_.Out_grad->dims()); +// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); +// return true; +// } -bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc, - lite::Scope* scope) { - CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL); - auto Y_name = opdesc.Input("Y").front(); - auto Out_name = opdesc.Input(framework::GradVarName("Out")).front(); - auto X_grad = opdesc.Output(framework::GradVarName("X")).front(); +// bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc, +// lite::Scope* scope) { +// CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL); +// auto Y_name = opdesc.Input("Y").front(); +// auto Out_name = opdesc.Input(framework::GradVarName("Out")).front(); +// auto X_grad = opdesc.Output(framework::GradVarName("X")).front(); - if (opdesc.Output(framework::GradVarName("Y")).size() > 0) { - auto Y_grad = opdesc.Output(framework::GradVarName("Y")).front(); - param_.Y_grad = GetMutableVar(scope, Y_grad); - } - param_.Y = GetVar(scope, Y_name); - param_.Out_grad = GetVar(scope, Out_name); - param_.X_grad = GetMutableVar(scope, X_grad); - param_.axis = opdesc.GetAttr("axis"); +// if (opdesc.Output(framework::GradVarName("Y")).size() > 0) { +// auto Y_grad = opdesc.Output(framework::GradVarName("Y")).front(); +// param_.Y_grad = GetMutableVar(scope, Y_grad); +// } +// param_.Y = GetVar(scope, Y_name); +// param_.Out_grad = GetVar(scope, Out_name); +// param_.X_grad = GetMutableVar(scope, X_grad); +// param_.axis = opdesc.GetAttr("axis"); - return true; -} -#endif +// return true; +// } +// #endif } // namespace operators } // namespace lite @@ -141,7 +141,9 @@ REGISTER_LITE_OP(elementwise_mul, paddle::lite::operators::ElementwiseOp); REGISTER_LITE_OP(elementwise_max, paddle::lite::operators::ElementwiseOp); REGISTER_LITE_OP(elementwise_div, paddle::lite::operators::ElementwiseOp); -#ifdef LITE_WITH_TRAIN -REGISTER_LITE_OP(elementwise_sub_grad, - paddle::lite::operators::ElementwiseGradExplicitOp); -#endif +// #ifdef LITE_WITH_TRAIN +// REGISTER_LITE_OP(elementwise_sub_grad, +// paddle::lite::operators::ElementwiseGradExplicitOp); +// REGISTER_LITE_OP(elementwise_add_grad, +// paddle::lite::operators::ElementwiseGradExplicitOp); +// #endif diff --git a/lite/operators/elementwise_ops.h b/lite/operators/elementwise_ops.h index b86d35e282c893b422677395dffe871a0d7f829b..d888e3d1c14b5d3129e01d12c75e1f590c17f297 100644 --- a/lite/operators/elementwise_ops.h +++ b/lite/operators/elementwise_ops.h @@ -39,27 +39,29 @@ class ElementwiseOp : public OpLite { mutable operators::ElementwiseParam param_; }; -#ifdef LITE_WITH_TRAIN -class ElementwiseGradExplicitOp : public OpLite { - public: - explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {} +// #ifdef LITE_WITH_TRAIN +// class ElementwiseGradExplicitOp : public OpLite { +// public: +// explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) +// {} - bool CheckShape() const override; +// bool CheckShape() const override; - bool InferShape() const override; +// bool InferShape() const override; - bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; +// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } +// void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); +// } - std::string DebugString() const override { - return "elementwise_grad_explicit_op"; - } +// std::string DebugString() const override { +// return "elementwise_grad_explicit_op"; +// } - private: - mutable operators::ElementwiseGradParam param_; -}; -#endif +// private: +// mutable operators::ElementwiseGradParam param_; +// }; +// #endif } // namespace operators } // namespace lite diff --git a/lite/operators/fusion_elementwise_activation_ops.cc b/lite/operators/fusion_elementwise_activation_ops.cc index b82c6454b4d8e9ca1af45374ea05925cbbadf0ed..244394b95aafede6956bc548430f5c14f28ae910 100644 --- a/lite/operators/fusion_elementwise_activation_ops.cc +++ b/lite/operators/fusion_elementwise_activation_ops.cc @@ -100,8 +100,8 @@ REGISTER_LITE_OP(fusion_elementwise_max_activation, REGISTER_LITE_OP(fusion_elementwise_div_activation, paddle::lite::operators::FusionElementwiseActivationOp); -#ifdef LITE_WITH_TRAIN -REGISTER_LITE_OP( - fusion_elementwise_sub_activation_grad, - paddle::lite::operators::FusionElementwiseActivationGradExplicitOp); -#endif +// #ifdef LITE_WITH_TRAIN +// REGISTER_LITE_OP( +// fusion_elementwise_sub_activation_grad, +// paddle::lite::operators::FusionElementwiseActivationGradExplicitOp); +// #endif diff --git a/lite/operators/fusion_elementwise_activation_ops.h b/lite/operators/fusion_elementwise_activation_ops.h index 1999ebd7220c81c313492ae106812c0eb755cb6e..db521284f0fc96c542fd5e7104b045f83f837f97 100644 --- a/lite/operators/fusion_elementwise_activation_ops.h +++ b/lite/operators/fusion_elementwise_activation_ops.h @@ -43,28 +43,29 @@ class FusionElementwiseActivationOp : public OpLite { mutable operators::FusionElementwiseActivationParam param_; }; -#ifdef LITE_WITH_TRAIN -class FusionElementwiseActivationGradExplicitOp : public OpLite { - public: - explicit FusionElementwiseActivationGradExplicitOp(const std::string& type) - : OpLite(type) {} +// #ifdef LITE_WITH_TRAIN +// class FusionElementwiseActivationGradExplicitOp : public OpLite { +// public: +// explicit FusionElementwiseActivationGradExplicitOp(const std::string& type) +// : OpLite(type) {} - bool CheckShape() const override; +// bool CheckShape() const override; - bool InferShape() const override; +// bool InferShape() const override; - bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; +// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } +// void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); +// } - std::string DebugString() const override { - return "fusion_elementwise_activation_grad_explicit_op"; - } +// std::string DebugString() const override { +// return "fusion_elementwise_activation_grad_explicit_op"; +// } - private: - mutable operators::FusionElementwiseActivationGradParam param_; -}; -#endif +// private: +// mutable operators::FusionElementwiseActivationGradParam param_; +// }; +// #endif } // namespace operators } // namespace lite diff --git a/lite/operators/mean_grad_op.cc b/lite/operators/mean_grad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fd17cac14fca153499a52e93f6f09ea44ea9a559 --- /dev/null +++ b/lite/operators/mean_grad_op.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/mean_grad_op.h" +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool MeanGradOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out_grad); + CHECK_OR_FALSE(param_.X_grad); + return true; +} + +bool MeanGradOp::InferShape() const { + param_.X_grad->Resize(param_.X->dims()); + return true; +} + +bool MeanGradOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { + CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL); + auto X_name = opdesc.Input("X").front(); + auto Out_grad_name = opdesc.Input("Out@GRAD").front(); + auto X_grad_name = opdesc.Output("X@GRAD").front(); + + param_.X = GetVar(scope, X_name); + param_.Out_grad = GetVar(scope, Out_grad_name); + param_.X_grad = GetMutableVar(scope, X_grad_name); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp); diff --git a/lite/operators/mean_grad_op.h b/lite/operators/mean_grad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1bd604518bfc088fc45566e393fd997ae4eed06e --- /dev/null +++ b/lite/operators/mean_grad_op.h @@ -0,0 +1,44 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MeanGradOp : public OpLite { + public: + explicit MeanGradOp(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "mean_grad"; } + + private: + mutable operators::MeanGradParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/mean_op.cc b/lite/operators/mean_op.cc index 33ad7ed7fe1c2c89339a689d4a6316d85307d871..618e9001db056b935de6aef8feff9125155d0e1a 100644 --- a/lite/operators/mean_op.cc +++ b/lite/operators/mean_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "lite/operators/mean_op.h" +#include #include "lite/core/op_lite.h" #include "lite/core/op_registry.h" @@ -19,82 +21,28 @@ namespace paddle { namespace lite { namespace operators { -class MeanOp : public OpLite { - public: - explicit MeanOp(const std::string& type) : OpLite(type) {} +bool MeanOp::CheckShape() const { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + return true; +} - bool CheckShape() const override { - CHECK_OR_FALSE(param_.X); - CHECK_OR_FALSE(param_.Out); - return true; - } +bool MeanOp::InferShape() const { + param_.Out->Resize(std::vector{1}); + return true; +} - bool InferShape() const override { - param_.Out->Resize(std::vector{1}); - return true; - } +bool MeanOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { + auto X_name = opdesc.Input("X").front(); + auto Out_name = opdesc.Output("Out").front(); - bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { - auto X_name = opdesc.Input("X").front(); - auto Out_name = opdesc.Output("Out").front(); - - param_.X = GetVar(scope, X_name); - param_.Out = GetMutableVar(scope, Out_name); - return true; - } - - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } - - std::string DebugString() const override { return "mean"; } - - private: - mutable operators::MeanParam param_; -}; - -#ifdef LITE_WITH_TRAIN -class MeanGradOp : public OpLite { - public: - explicit MeanGradOp(const std::string& type) : OpLite(type) {} - - bool CheckShape() const override { - CHECK_OR_FALSE(param_.X); - CHECK_OR_FALSE(param_.Out_grad); - CHECK_OR_FALSE(param_.X_grad); - return true; - } - - bool InferShape() const override { - param_.X_grad->Resize(param_.X->dims()); - // param_.X_grad->set_lod(param_.X->lod()); - return true; - } - - bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { - CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL); - auto X_name = opdesc.Input("X").front(); - auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); - auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); - - param_.X = GetVar(scope, X_name); - param_.Out_grad = GetVar(scope, Out_grad_name); - param_.X_grad = GetMutableVar(scope, X_grad_name); - return true; - } - - void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } - - std::string DebugString() const override { return "mean_grad"; } - - private: - mutable operators::MeanGradParam param_; -}; -#endif + param_.X = GetVar(scope, X_name); + param_.Out = GetMutableVar(scope, Out_name); + return true; +} } // namespace operators } // namespace lite } // namespace paddle REGISTER_LITE_OP(mean, paddle::lite::operators::MeanOp); -#ifdef LITE_WITH_TRAIN -REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp); -#endif diff --git a/lite/operators/mean_op.h b/lite/operators/mean_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8526842f93cb1d01debad9c6cb28ec28b98e43e9 --- /dev/null +++ b/lite/operators/mean_op.h @@ -0,0 +1,44 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MeanOp : public OpLite { + public: + explicit MeanOp(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "mean"; } + + private: + mutable operators::MeanParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 6d63b7054176fae6d2c88cf2e330fce2c6f7eb6f..79005151caac88c44e10581655c9704015dffe8f 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -60,6 +60,9 @@ if(LITE_BUILD_EXTRA) lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_lookup_table_dequant_compute SRCS lookup_table_dequant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + # for training kernel + lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + endif() lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/mean_compute_test.cc b/lite/tests/kernels/mean_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..71eb86a6c15a77af92dc2b9c5da4fa2eb6aba406 --- /dev/null +++ b/lite/tests/kernels/mean_compute_test.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +class MeanComputeTester : public arena::TestCase { + protected: + DDim input_dims_{{2, 5}}; + std::string input_ = "x"; + std::string output_ = "out"; + + public: + MeanComputeTester(const Place& place, + const std::string& alias, + const DDim& input_dims) + : TestCase(place, alias), input_dims_(input_dims) {} + + void RunBaseline(Scope* scope) override { + auto input = scope->FindTensor(input_); + auto output = scope->NewTensor(output_); + + std::vector out_dims{1}; + output->Resize(out_dims); + + auto input_data = input->data(); + auto output_data = output->mutable_data(); + + int x_size = input_dims_.production(); + float sum = 0; + for (int i = 0; i < x_size; i++) { + sum += input_data[i]; + } + output_data[0] = sum / x_size; + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("mean"); + op_desc->SetInput("X", {input_}); + op_desc->SetOutput("Out", {output_}); + } + + void PrepareData() override { + std::vector input(input_dims_.production()); + fill_data_rand(input.data(), -1.f, 1.f, input_dims_.production()); + SetCommonTensor(input_, input_dims_, input.data()); + } +}; + +void TestNormalCase(Place place, float abs_error = 2e-5) { + LOG(INFO) << "Test Mean"; + for (std::vector dims : std::vector>{ + {5}, {4, 5}, {3, 4, 5}, {2, 3, 4, 5}}) { + std::unique_ptr tester( + new MeanComputeTester(place, "def", DDim(dims))); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); + } +} + +#ifdef LITE_WITH_TRAIN +class MeanGradComputeTester : public arena::TestCase { + protected: + DDim input_dims_{{2, 5}}; + DDim output_grad_dims_{{1}}; + std::string input_ = "x"; + std::string input_grad_ = "x_grad"; + std::string output_grad_ = "out_grad"; + + public: + MeanGradComputeTester(const Place& place, + const std::string& alias, + const DDim& input_dims) + : TestCase(place, alias), input_dims_(input_dims) {} + + void RunBaseline(Scope* scope) override { + auto input = scope->FindTensor(input_); + auto output_grad = scope->FindTensor(output_grad_); + auto input_grad = scope->NewTensor(input_grad_); + + input_grad->Resize(input_dims_); + + auto input_data = input->data(); + auto output_grad_data = output_grad->data(); + auto input_grad_data = input_grad->mutable_data(); + + int x_size = input_dims_.production(); + float d_x = output_grad_data[0] / x_size; + + for (int i = 0; i < x_size; i++) { + input_grad_data[i] = d_x; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("mean_grad"); + op_desc->SetInput("X", {input_}); + op_desc->SetInput("Out@GRAD", {output_grad_}); + op_desc->SetOutput("X@GRAD", {input_grad_}); + } + + void PrepareData() override { + std::vector input(input_dims_.production()); + fill_data_rand(input.data(), -1.f, 1.f, input_dims_.production()); + SetCommonTensor(input_, input_dims_, input.data()); + + std::vector output_grad(1); + fill_data_rand(output_grad.data(), -1.f, 1.f, 1); + SetCommonTensor(output_grad_, output_grad_dims_, output_grad.data()); + } +}; + +void TestGradNormalCase(Place place, float abs_error = 2e-5) { + LOG(INFO) << "Test Mean Grad"; + for (std::vector dims : std::vector>{ + {5}, {4, 5}, {3, 4, 5}, {2, 3, 4, 5}}) { + std::unique_ptr tester( + new MeanGradComputeTester(place, "def", DDim(dims))); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); + } +} +#endif + +TEST(Mean, precision) { +#ifdef LITE_WITH_ARM + float abs_error = 2e-5; + Place place(TARGET(kARM)); + TestNormalCase(place, abs_error); +#ifdef LITE_WITH_TRAIN + TestGradNormalCase(place, abs_error); +#endif +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/lite/tools/ci_build.sh b/lite/tools/ci_build.sh index a4223215d5c10d6eda6dc54ec57981d4f1954e40..6105b138c6789544716d3908dd1bf4847ba81c9a 100755 --- a/lite/tools/ci_build.sh +++ b/lite/tools/ci_build.sh @@ -594,6 +594,7 @@ function cmake_arm { -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ -DWITH_TESTING=ON \ -DLITE_BUILD_EXTRA=ON \ + -DLITE_WITH_TRAIN=ON \ -DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3 } diff --git a/lite/tools/cmake_tools/parse_op_registry.py b/lite/tools/cmake_tools/parse_op_registry.py index 7eb3337ed87b708102b2032de9a279fcae2d321c..44ee09c28ff70ada782b9393f4fc0d5c07943b2c 100644 --- a/lite/tools/cmake_tools/parse_op_registry.py +++ b/lite/tools/cmake_tools/parse_op_registry.py @@ -45,8 +45,6 @@ for path in paths: op_parser = RegisterLiteOpParser(str_info) ops = op_parser.parse() for op in ops: - if "_grad" in op: - continue if tailored == "ON": if op not in minlines: continue out = "USE_LITE_OP(%s);" % op