From ca45ed536a752ddc8683a80a4795848af95d0463 Mon Sep 17 00:00:00 2001 From: liuwei1031 <46661762+liuwei1031@users.noreply.github.com> Date: Wed, 29 May 2019 14:35:10 +0800 Subject: [PATCH] migrate several ops: (#17606) * migrate several ops: mean, mean_grad fill_constant square_grad elementwise_sub_grad mul_grad * add sgd_op * fix kernel platform registration issue * code cleanup * fix platform typo --- paddle/fluid/lite/api/CMakeLists.txt | 2 +- paddle/fluid/lite/api/cxx_api_test.cc | 2 +- paddle/fluid/lite/kernels/x86/CMakeLists.txt | 18 ++- .../lite/kernels/x86/activation_compute.cc | 26 ++- .../lite/kernels/x86/elementwise_compute.cc | 55 ++++++- .../lite/kernels/x86/fill_constant_compute.cc | 56 +++++++ paddle/fluid/lite/kernels/x86/mean_compute.cc | 98 ++++++++++++ paddle/fluid/lite/kernels/x86/mul_compute.cc | 149 ++++++++++++++++++ paddle/fluid/lite/kernels/x86/sgd_compute.cc | 77 +++++++++ paddle/fluid/lite/operators/CMakeLists.txt | 6 + paddle/fluid/lite/operators/activation_ops.cc | 37 +++++ .../fluid/lite/operators/elementwise_ops.cc | 50 +++++- .../fluid/lite/operators/fill_constant_op.cc | 61 +++++++ paddle/fluid/lite/operators/mean_op.cc | 98 ++++++++++++ paddle/fluid/lite/operators/mul_op.cc | 52 +++++- paddle/fluid/lite/operators/mul_op.h | 20 +++ paddle/fluid/lite/operators/op_params.h | 51 +++++- paddle/fluid/lite/operators/sgd_op.cc | 57 +++++++ paddle/fluid/lite/operators/sgd_op.h | 50 ++++++ 19 files changed, 933 insertions(+), 32 deletions(-) create mode 100644 paddle/fluid/lite/kernels/x86/fill_constant_compute.cc create mode 100644 paddle/fluid/lite/kernels/x86/mean_compute.cc create mode 100644 paddle/fluid/lite/kernels/x86/mul_compute.cc create mode 100644 paddle/fluid/lite/kernels/x86/sgd_compute.cc create mode 100644 paddle/fluid/lite/operators/fill_constant_op.cc create mode 100644 paddle/fluid/lite/operators/mean_op.cc create mode 100644 paddle/fluid/lite/operators/sgd_op.cc create mode 100644 paddle/fluid/lite/operators/sgd_op.h diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt index cea38089fbe..d39950f2a03 100644 --- a/paddle/fluid/lite/api/CMakeLists.txt +++ b/paddle/fluid/lite/api/CMakeLists.txt @@ -25,11 +25,11 @@ set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inferenc set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING "A path setting inference demo download directories.") - # lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc # DEPS cxx_api_lite model_parser_lite target_wrapper_host # ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model # --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) + if(WITH_TESTING) lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz") # add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index 1d243ba41f4..7a73982e9f7 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -92,7 +92,7 @@ TEST(CXXTrainer, train) { main_program_desc.ParseFromString(main_program_pb); startup_program_desc.ParseFromString(startup_program_pb); - LOG(INFO) << main_program_desc.DebugString(); + // LOG(INFO) << main_program_desc.DebugString(); for (const auto& op : main_program_desc.blocks(0).ops()) { LOG(INFO) << "get op " << op.type(); diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt
b/paddle/fluid/lite/kernels/x86/CMakeLists.txt index 75133f19f44..4fd44525a40 100644 --- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt @@ -2,5 +2,19 @@ if(NOT LITE_WITH_X86) return() endif() -cc_library(activation_compute SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) -cc_library(elementwise_compute SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op) +cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) +cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) +cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps}) +cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps}) +cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps}) +cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps}) + +set(x86_kernels + activation_compute_x86 + elementwise_compute_x86 + mean_compute_x86 + fill_constant_compute_x86 + mul_compute_x86 + ) + +set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") diff --git a/paddle/fluid/lite/kernels/x86/activation_compute.cc b/paddle/fluid/lite/kernels/x86/activation_compute.cc index 79f3829b61b..4ea1c0f6504 100644 --- a/paddle/fluid/lite/kernels/x86/activation_compute.cc +++ b/paddle/fluid/lite/kernels/x86/activation_compute.cc @@ -55,7 +55,7 @@ void ActivateGrad(const platform::CPUDeviceContext& context, } template -class SquareCompute : public KernelLite { +class SquareCompute : public KernelLite { public: using param_t = operators::ActivationParam; @@ -70,14 +70,11 @@ class SquareCompute : public KernelLite { &param.Out->raw_tensor()); } - // TargetType target() const override; - // PrecisionType precision() const override; - virtual ~SquareCompute() = default; }; template -class SquareGradCompute : public KernelLite { +class SquareGradCompute : public KernelLite { public: using param_t = operators::ActivationGradParam; @@ -93,9 +90,6 @@ class SquareGradCompute : public KernelLite { &param.X_grad->raw_tensor()); } - // TargetType target() const override; - // PrecisionType precision() const override; - virtual ~SquareGradCompute() = default; }; @@ -107,16 +101,16 @@ class SquareGradCompute : public KernelLite { // float REGISTER_LITE_KERNEL(square, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::SquareCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); REGISTER_LITE_KERNEL(square_grad, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::SquareGradCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc
b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc index e2ca9a52df6..9e9b7a86b39 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc @@ -32,7 +32,7 @@ struct SubFunctor { template class ElementwiseSubCompute - : public KernelLite { + : public KernelLite { public: using param_t = operators::ElementwiseParam; @@ -49,22 +49,67 @@ class ElementwiseSubCompute &param.Out->raw_tensor()); } - // TargetType target() const override; - // PrecisionType precision() const override; - virtual ~ElementwiseSubCompute() = default; }; +template +struct SubGradDX { + T operator()(T x, T y, T out, T dout) const { return dout; } +}; + +template +struct SubGradDY { + T operator()(T x, T y, T out, T dout) const { return -dout; } +}; + +template +class ElementwiseSubGradCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseGradParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = context_->As(); + CHECK(context.x86_device_context); + + param.X_grad->template mutable_data(); + param.Y_grad->template mutable_data(); + // skip out, x, y + auto& dout = param.Out_grad->raw_tensor(); + auto& dx = param.X_grad->raw_tensor(); + auto& dy = param.Y_grad->raw_tensor(); + auto& skip = dout; + paddle::operators::ElemwiseExplicitGradCompute< + platform::CPUDeviceContext, T, SubGradDX, SubGradDY>( + *context.x86_execution_context, skip, skip, skip, dout, param.axis, &dx, + &dy, SubGradDX(), SubGradDY()); + } + + virtual ~ElementwiseSubGradCompute() = default; +}; + } // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle // float -REGISTER_LITE_KERNEL(square, kHost, kFloat, kNCHW, +REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::ElementwiseSubCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); + +REGISTER_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ElementwiseSubGradCompute, + def) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("Y"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc b/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc new file mode 100644 index 00000000000..3f8a3fe11c4 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
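The SubGradDX/SubGradDY functors above encode the analytic gradient of out = x - y: since d(out)/dx = 1 and d(out)/dy = -1, the X gradient is dout passed through and the Y gradient is -dout, with ElemwiseExplicitGradCompute handling broadcast shapes along `axis`. A minimal standalone sketch of the same identity, independent of the Lite kernel plumbing (function name is illustrative, not part of the patch):

#include <cstddef>
#include <vector>

// Reference gradient of elementwise_sub without broadcasting:
// dx[i] = dout[i] (SubGradDX), dy[i] = -dout[i] (SubGradDY).
void ElementwiseSubGradRef(const std::vector<float>& dout,
                           std::vector<float>* dx, std::vector<float>* dy) {
  dx->resize(dout.size());
  dy->resize(dout.size());
  for (size_t i = 0; i < dout.size(); ++i) {
    (*dx)[i] = dout[i];   // gradient flows through the minuend unchanged
    (*dy)[i] = -dout[i];  // the subtrahend enters with a minus sign
  }
}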
+ +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class FillConstantCompute : public KernelLite { + public: + using param_t = operators::FillConstantParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = context_->As(); + CHECK(context.x86_device_context); + + param.Out->template mutable_data(); + + paddle::operators::math::set_constant( + *context.x86_device_context, &param.Out->raw_tensor(), param.value); + } + + virtual ~FillConstantCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::FillConstantCompute, + def) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/mean_compute.cc b/paddle/fluid/lite/kernels/x86/mean_compute.cc new file mode 100644 index 00000000000..f1dbc4d53fc --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/mean_compute.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
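On x86, fill_constant reduces to allocating the output and writing the scalar `value` attribute into every element; set_constant is a device-dispatched fill. A minimal sketch of the same behavior (helper name is illustrative, not from the patch):

#include <algorithm>
#include <vector>

// Fill every element of a pre-sized output buffer with one constant,
// which is all the x86 fill_constant kernel does after allocation.
void FillConstantRef(std::vector<float>* out, float value) {
  std::fill(out->begin(), out->end(), value);
}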
+ +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +using EigenScalar = framework::EigenScalar; +template +using EigenVector = framework::EigenVector; + +template +class MeanCompute : public KernelLite { + public: + using param_t = operators::MeanParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = context_->As(); + CHECK(context.x86_device_context); + + param.Out->template mutable_data(); + + auto X = EigenVector::Flatten(param.X->raw_tensor()); + auto y = EigenScalar::From(param.Out->raw_tensor()); + const auto& place = *(context.x86_device_context->eigen_device()); + + y.device(place) = X.mean(); + } + + virtual ~MeanCompute() = default; +}; + +template +class MeanGradCompute : public KernelLite { + public: + using param_t = operators::MeanGradParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = context_->As(); + CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1); + CHECK(context.x86_device_context); + + param.X_grad->template mutable_data(); + T x_grad_size = static_cast(param.X_grad->raw_tensor().numel()); + Eigen::DSizes bcast(static_cast(x_grad_size)); + EigenVector::Flatten(param.X_grad->raw_tensor()) + .device(*(context.x86_device_context->eigen_device())) = + (EigenVector::From(param.Out_grad->raw_tensor()) / x_grad_size) + .broadcast(bcast); + } + + virtual ~MeanGradCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(mean, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MeanCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(mean_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MeanGradCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/mul_compute.cc b/paddle/fluid/lite/kernels/x86/mul_compute.cc new file mode 100644 index 00000000000..f0c962347fb --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/mul_compute.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
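MeanCompute flattens X and reduces it to a single scalar, y = (1/n) * sum(x); MeanGradCompute broadcasts the incoming scalar gradient back, so every input element receives dout / n. A small reference sketch of both directions (helper names are illustrative):

#include <cstddef>
#include <vector>

// Forward: y = (1/n) * sum_i x[i].
float MeanRef(const std::vector<float>& x) {
  float sum = 0.0f;
  for (float v : x) sum += v;
  return sum / static_cast<float>(x.size());
}

// Backward: each element gets an equal share of the scalar gradient.
void MeanGradRef(float dout, std::vector<float>* dx) {
  const float g = dout / static_cast<float>(dx->size());
  for (float& v : *dx) v = g;
}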
+ +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +using Tensor = framework::Tensor; + +template +class MulCompute : public KernelLite { + public: + using param_t = operators::MulParam; + + void Run() override { + auto& context = context_->As(); + auto& param = *param_.get_mutable(); + CHECK(context.x86_device_context); + + param.output->template mutable_data(); + + auto* x = &param.x->raw_tensor(); + auto* y = &param.y->raw_tensor(); + + const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( *x, param.x_num_col_dims) : *x; + const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix( *y, param.y_num_col_dims) : *y; + + auto* z = &param.output->raw_tensor(); + auto z_dim = z->dims(); + if (z_dim.size() != 2) { + z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + + auto blas = paddle::operators::math::GetBlas( + *context.x86_device_context); + + blas.MatMul(x_matrix, y_matrix, z); + if (z_dim.size() != 2) { + z->Resize(z_dim); + } + } + + virtual ~MulCompute() = default; +}; + +template +class MulGradCompute : public KernelLite { + public: + void Run() override { + auto& context = context_->As(); + auto& param = *param_.get_mutable(); + CHECK(context.x86_device_context); + + auto* x = &param.x->raw_tensor(); + auto* y = &param.y->raw_tensor(); + auto x_matrix = x->dims().size() > 2 + ? framework::ReshapeToMatrix(*x, param.x_num_col_dims) + : static_cast(*x); + auto y_matrix = y->dims().size() > 2 + ? framework::ReshapeToMatrix(*y, param.y_num_col_dims) + : static_cast(*y); + auto* dout = &param.output_grad->raw_tensor(); + + Tensor dout_mat; + dout_mat.ShareDataWith(*dout); + dout_mat.Resize( + {framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0], + framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]}); + + auto* dx = &param.x_grad->raw_tensor(); + auto* dy = &param.y_grad->raw_tensor(); + + if (dx != nullptr) { + dx->set_lod(x->lod()); + } + if (dy != nullptr) { + dy->set_lod(y->lod()); + } + + auto blas = paddle::operators::math::GetBlas( + *context.x86_device_context); + if (dx) { + // dx->mutable_data(context.x86_device_context->GetPlace()); + param.x_grad->template mutable_data(); + Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix( *dx, param.x_num_col_dims) : *dx; + + // dx = dout * y'. dx: M x K, dout : M x N, y : K x N + blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); + } + if (dy) { + // dy->mutable_data(context.x86_device_context->GetPlace()); + param.y_grad->template mutable_data(); + Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix( *dy, param.y_num_col_dims) : *dy; + // dy = x' * dout.
dy K x N, dout : M x N, x : M x K + blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); + } + } + + virtual ~MulGradCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(mul, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MulCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(mul_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MulGradCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("Y"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/sgd_compute.cc b/paddle/fluid/lite/kernels/x86/sgd_compute.cc new file mode 100644 index 00000000000..27261fd14d6 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/sgd_compute.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
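The two blas.MatMul calls in MulGradCompute implement the standard matmul gradients stated in the comments: for Out = X * Y with X of shape M x K and Y of shape K x N, dX = dOut * Y^T and dY = X^T * dOut, after ReshapeToMatrix has flattened higher-rank inputs to 2-D. A naive row-major reference for the dX half (illustrative only, not the BLAS path the kernel uses):

#include <cstddef>
#include <vector>

// dX[m][k] = sum_n dOut[m][n] * Y[k][n], i.e. dX = dOut * Y^T.
void MulGradDxRef(const std::vector<float>& dout,  // M x N, row-major
                  const std::vector<float>& y,     // K x N, row-major
                  int M, int N, int K,
                  std::vector<float>* dx) {        // M x K, row-major
  dx->assign(static_cast<size_t>(M) * K, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)
      for (int n = 0; n < N; ++n)
        (*dx)[m * K + k] += dout[m * N + n] * y[k * N + n];
}

The dY half is symmetric: dY[k][n] = sum_m X[m][k] * dOut[m][n].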
+ +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SGDCompute : public KernelLite { + public: + using param_t = operators::SGDParam; + + void Run() override { + auto &context = context_->As(); + auto &sgd_param = *param_.get_mutable(); + CHECK(context.x86_device_context); + + const auto *param = &sgd_param.Param->raw_tensor(); + const auto *grad = &sgd_param.Grad->raw_tensor(); + const auto *learning_rate = &sgd_param.LearningRate->raw_tensor(); + auto *param_out = &sgd_param.ParamOut->raw_tensor(); + + auto sz = param_out->numel(); + PADDLE_ENFORCE_EQ(param->numel(), sz); + PADDLE_ENFORCE_EQ(grad->numel(), sz); + + paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1); + const T *lr = learning_rate->data(); + const T *param_data = param->data(); + const T *grad_data = grad->data(); + int64_t rows_idx = 0; + T *out_data = + param_out->mutable_data(context.x86_device_context->GetPlace()); + + auto sgd = + paddle::operators::jit::KernelFuncs, + platform::CPUPlace>::Cache() + .At(attr); + sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); + } + + virtual ~SGDCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(sgd, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::SGDCompute, def) + .BindInput("Param", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("LearningRate", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Grad", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("ParamOut", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 4ae76cad646..40782010881 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -9,6 +9,9 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS}) cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS}) +cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS}) +cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) +cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS}) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite) set(ops_lite @@ -19,6 +22,9 @@ set(ops_lite feed_op_lite fetch_op_lite io_copy_op_lite + elementwise_ops_lite + mean_op_lite + fill_constant_op_lite PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/activation_ops.cc b/paddle/fluid/lite/operators/activation_ops.cc index e92bc8e6ec2..d53bb0c9e31 100644 --- a/paddle/fluid/lite/operators/activation_ops.cc +++ b/paddle/fluid/lite/operators/activation_ops.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License.
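The jit kernel invoked above performs a plain dense SGD step; with the sgd_attr_t used here (one row, contiguous data) it is equivalent to param_out[i] = param[i] - lr * grad[i] with a single scalar learning rate. A reference sketch under that dense assumption (function name is illustrative):

#include <cstddef>
#include <vector>

// Dense SGD step: w' = w - lr * g, elementwise over the whole parameter.
void SgdRef(float lr, const std::vector<float>& param,
            const std::vector<float>& grad, std::vector<float>* param_out) {
  param_out->resize(param.size());
  for (size_t i = 0; i < param.size(); ++i)
    (*param_out)[i] = param[i] - lr * grad[i];
}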
+#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -36,16 +37,52 @@ class ActivationOp : public OpLite { param_.X = GetVar(scope, X_name); param_.Out = GetMutableVar(scope, Out_name); + return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "activation_op"; } + private: mutable ActivationParam param_; }; +class ActivationGradOp : public OpLite { + public: + explicit ActivationGradOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X_grad); + CHECK_OR_FALSE(param_.Out_grad); + return true; + } + + bool InferShape() const override { + param_.X_grad->Resize(param_.Out_grad->dims()); + return true; + } + + bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); + auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); + + param_.Out_grad = GetVar(scope, Out_grad_name); + param_.X_grad = GetMutableVar(scope, X_grad_name); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "activation_grad_op"; } + + private: + mutable ActivationGradParam param_; +}; + } // namespace operators } // namespace lite } // namespace paddle REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp); +REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); diff --git a/paddle/fluid/lite/operators/elementwise_ops.cc b/paddle/fluid/lite/operators/elementwise_ops.cc index bba9209fa4d..044e621a1df 100644 --- a/paddle/fluid/lite/operators/elementwise_ops.cc +++ b/paddle/fluid/lite/operators/elementwise_ops.cc @@ -31,7 +31,7 @@ class ElementwiseOp : public OpLite { } bool InferShape() const override { - CHECK_OR_FALSE(param_.X->dims() == param_.Y->dims()); + CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); param_.Out->Resize(param_.X->dims()); return true; } @@ -46,16 +46,64 @@ class ElementwiseOp : public OpLite { param_.Y = GetVar(scope, Y_name); param_.Out = GetMutableVar(scope, Out_name); param_.axis = boost::get(opdesc.GetAttr("axis")); + + return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "elementwise_op"; } + private: mutable operators::ElementwiseParam param_; }; +class ElementwiseGradExplicitOp : public OpLite { + public: + explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X_grad); + CHECK_OR_FALSE(param_.Y_grad); + CHECK_OR_FALSE(param_.Out_grad); + return true; + } + + bool InferShape() const override { + param_.X_grad->Resize(param_.Out_grad->dims()); + param_.Y_grad->Resize(param_.Out_grad->dims()); + return true; + } + + bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.Inputs().size(), 1UL); + auto Out_name = opdesc.Input(framework::GradVarName("Out")).front(); + auto X_name = opdesc.Output(framework::GradVarName("X")).front(); + auto Y_name = opdesc.Output(framework::GradVarName("Y")).front(); + + param_.Out_grad = GetVar(scope, Out_name); + param_.X_grad = GetMutableVar(scope, X_name); + param_.Y_grad = GetMutableVar(scope, Y_name); + param_.axis = boost::get(opdesc.GetAttr("axis")); + + return
true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "elementwise_grad_explicit_op"; + } + + private: + mutable operators::ElementwiseGradParam param_; +}; + } // namespace operators } // namespace lite } // namespace paddle REGISTER_LITE_OP(elementwise_sub, paddle::lite::operators::ElementwiseOp); +REGISTER_LITE_OP(elementwise_sub_grad, + paddle::lite::operators::ElementwiseGradExplicitOp); diff --git a/paddle/fluid/lite/operators/fill_constant_op.cc b/paddle/fluid/lite/operators/fill_constant_op.cc new file mode 100644 index 00000000000..7671318fb3e --- /dev/null +++ b/paddle/fluid/lite/operators/fill_constant_op.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FillConstantOp : public OpLite { + public: + explicit FillConstantOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.Out); + return true; + } + + bool InferShape() const override { + param_.Out->Resize(param_.shape); + return true; + } + + bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.Inputs().size(), 2UL); + auto Out_name = opdesc.Output("Out").front(); + + param_.Out = GetMutableVar(scope, Out_name); + param_.dtype = boost::get(opdesc.GetAttr("dtype")); + param_.shape = boost::get>(opdesc.GetAttr("shape")); + param_.value = boost::get(opdesc.GetAttr("value")); + param_.force_cpu = boost::get(opdesc.GetAttr("force_cpu")); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "fill_constant"; } + + private: + mutable operators::FillConstantParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp); diff --git a/paddle/fluid/lite/operators/mean_op.cc b/paddle/fluid/lite/operators/mean_op.cc new file mode 100644 index 00000000000..89798ca2e5f --- /dev/null +++ b/paddle/fluid/lite/operators/mean_op.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
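AttachImpl methods like the one in FillConstantOp pull typed attributes out of the op description with boost::get, which extracts the declared alternative from an attribute variant and fails loudly on a type mismatch. The std::variant analogue below illustrates the access pattern (it is an analogy, not the Lite API):

#include <cstdint>
#include <variant>
#include <vector>

using Attr = std::variant<int, float, bool, std::vector<int64_t>>;

// Typed attribute read: throws std::bad_variant_access if the stored
// type differs from the requested one, mirroring boost::get's behavior.
float GetFloatAttr(const Attr& a) { return std::get<float>(a); }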
+ +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MeanOp : public OpLite { + public: + explicit MeanOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + return true; + } + + bool InferShape() const override { + param_.Out->Resize(std::vector{1}); + return true; + } + + bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.Inputs().size(), 2UL); + auto X_name = opdesc.Input("X").front(); + auto Out_name = opdesc.Output("Out").front(); + + param_.X = GetVar(scope, X_name); + param_.Out = GetMutableVar(scope, Out_name); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "mean"; } + + private: + mutable operators::MeanParam param_; +}; + +class MeanGradOp : public OpLite { + public: + explicit MeanGradOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out_grad); + CHECK_OR_FALSE(param_.X_grad); + return true; + } + + bool InferShape() const override { + param_.X_grad->Resize(param_.X->dims()); + // param_.X_grad->set_lod(param_.X->lod()); + return true; + } + + bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.Inputs().size(), 3UL); + auto X_name = opdesc.Input("X").front(); + auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); + auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); + + param_.X = GetVar(scope, X_name); + param_.Out_grad = GetVar(scope, Out_grad_name); + param_.X_grad = GetMutableVar(scope, X_grad_name); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "mean_grad"; } + + private: + mutable operators::MeanGradParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(mean, paddle::lite::operators::MeanOp); +REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp); diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc index b78ae4578a6..75b536a093d 100644 --- a/paddle/fluid/lite/operators/mul_op.cc +++ b/paddle/fluid/lite/operators/mul_op.cc @@ -28,9 +28,18 @@ bool MulOpLite::CheckShape() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); - CHECK_EQ_OR_FALSE(y_dims.size(), 2UL); CHECK_GT_OR_FALSE(x_dims.size(), static_cast(param_.x_num_col_dims)); + CHECK_GT_OR_FALSE(y_dims.size(), static_cast(param_.y_num_col_dims)); + auto x_mat_dims = + framework::flatten_to_2d(x_dims.data(), param_.x_num_col_dims); + auto y_mat_dims = + framework::flatten_to_2d(y_dims.data(), param_.y_num_col_dims); + + PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0], + "First matrix's width must be equal with second matrix's " + "height.
%s, %s", + x_mat_dims[1], y_mat_dims[0]); return true; } @@ -39,11 +48,16 @@ bool MulOpLite::InferShape() const { const auto y_dims = param_.y->dims(); // Set output dims - std::vector out_dims(param_.x_num_col_dims + 1, 0); + std::vector out_dims( + param_.x_num_col_dims + y_dims.size() - param_.y_num_col_dims, 0); for (int i = 0; i < param_.x_num_col_dims; ++i) { out_dims[i] = x_dims[i]; } - out_dims.back() = y_dims[1]; + + for (auto i = static_cast(param_.y_num_col_dims); i < y_dims.size(); + ++i) { + out_dims[i] = y_dims[i]; + } param_.output->Resize(lite::DDim(out_dims)); @@ -52,6 +66,38 @@ bool MulOpLite::InferShape() const { return true; } +bool MulGradOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.y); + CHECK_OR_FALSE(param_.output_grad); + CHECK_OR_FALSE(param_.x_grad); + CHECK_OR_FALSE(param_.y_grad); + + return true; +} + +bool MulGradOpLite::InferShape() const { + param_.x_grad->Resize(param_.x->dims()); + param_.y_grad->Resize(param_.y->dims()); + return true; +} + +bool MulGradOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { + auto X_name = op_desc.Input("X").front(); + auto Y_name = op_desc.Input("Y").front(); + auto Out_grad_name = op_desc.Output(framework::GradVarName("Out")).front(); + auto X_grad_name = op_desc.Output(framework::GradVarName("X")).front(); + auto Y_grad_name = op_desc.Output(framework::GradVarName("Y")).front(); + + param_.x = GetVar(scope, X_name); + param_.y = GetVar(scope, Y_name); + param_.output_grad = GetVar(scope, Out_grad_name); + param_.x_grad = GetMutableVar(scope, X_grad_name); + param_.y_grad = GetMutableVar(scope, Y_grad_name); + + return true; +} + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 806fdb01f9b..73827753bd2 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -61,6 +61,26 @@ class MulOpLite : public OpLite { mutable MulParam param_; }; +class MulGradOpLite : public OpLite { + public: + MulGradOpLite() {} + + explicit MulGradOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "mul_grad"; } + + private: + mutable MulGradParam param_; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index d21c0e3135d..c970ac2d873 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -72,6 +72,17 @@ struct MulParam { int y_num_col_dims{1}; }; +struct MulGradParam { + const lite::Tensor* x{}; + const lite::Tensor* y{}; + const lite::Tensor* output_grad{}; + lite::Tensor* x_grad{}; + lite::Tensor* y_grad{}; + + int x_num_col_dims{1}; + int y_num_col_dims{1}; +}; + // For Scale Op struct ScaleParam { lite::Tensor* x{}; @@ -91,9 +102,10 @@ struct ElementwiseParam { }; struct ElementwiseGradParam { - const lite::Tensor* X_grad{}; - const lite::Tensor* Y_grad{}; - lite::Tensor* Out_grad{}; + const lite::Tensor* Y{}; + const lite::Tensor* Out_grad{}; + lite::Tensor* X_grad{}; + lite::Tensor* Y_grad{}; int axis{-1}; // for broadcasting. 
}; @@ -111,6 +123,39 @@ struct ActivationGradParam { const lite::Tensor* Out_grad{}; }; +/// ----------------------- mean operators ---------------------- +struct MeanParam { + const lite::Tensor* X{}; + lite::Tensor* Out{}; +}; + +struct MeanGradParam { + const lite::Tensor* X{}; + const lite::Tensor* Out_grad{}; + // for backward + lite::Tensor* X_grad{}; +}; + +/// ----------------------- fill_constant operators ---------------------- +struct FillConstantParam { + int dtype{framework::proto::VarType::FP32}; + std::vector shape{}; + float value{0.0f}; + // useless for x86, keep it for compatibility + bool force_cpu{false}; + lite::Tensor* Out{}; +}; + +/// ----------------------- sgd operators ---------------------- +struct SGDParam { + int dtype{framework::proto::VarType::FP32}; + + const lite::Tensor* Param{}; + const lite::Tensor* LearningRate{}; + const lite::Tensor* Grad{}; + lite::Tensor* ParamOut{}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/sgd_op.cc b/paddle/fluid/lite/operators/sgd_op.cc new file mode 100644 index 00000000000..2571ad0b102 --- /dev/null +++ b/paddle/fluid/lite/operators/sgd_op.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/sgd_op.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SGDOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.Param); + CHECK_OR_FALSE(param_.LearningRate); + CHECK_OR_FALSE(param_.Grad); + CHECK_OR_FALSE(param_.ParamOut); + return true; +} + +bool SGDOpLite::InferShape() const { + auto lr_dims = param_.LearningRate->dims().data(); + CHECK_EQ_OR_FALSE(framework::product(lr_dims), 1); + param_.ParamOut->Resize(param_.Param->dims()); + return true; +} + +bool SGDOpLite::AttachImpl(const OpDesc& opdesc, lite::Scope* scope) { + CHECK_EQ(opdesc.Inputs().size(), 3UL); + auto Param_name = opdesc.Input("Param").front(); + auto LearningRate_name = opdesc.Input("LearningRate").front(); + auto Grad_name = opdesc.Input("Grad").front(); + auto ParamOut_name = opdesc.Output("ParamOut").front(); + + param_.Param = GetVar(scope, Param_name); + param_.LearningRate = GetVar(scope, LearningRate_name); + param_.Grad = GetVar(scope, Grad_name); + param_.ParamOut = GetMutableVar(scope, ParamOut_name); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sgd, paddle::lite::operators::SGDOpLite); diff --git a/paddle/fluid/lite/operators/sgd_op.h b/paddle/fluid/lite/operators/sgd_op.h new file mode 100644 index 00000000000..dea045c0b67 --- /dev/null +++ b/paddle/fluid/lite/operators/sgd_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include <string> +#include <vector> +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SGDOpLite : public OpLite { + public: + SGDOpLite() {} + + explicit SGDOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "sgd"; } + + private: + mutable SGDParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle -- GitLab