From 2cb7e4ff795deb8cbf53fc45bfeba747603faa56 Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Fri, 20 Mar 2020 10:20:54 +0800 Subject: [PATCH] Mul grad (#3201) * rm grad code * add mul_grad, test=develop --- lite/kernels/arm/CMakeLists.txt | 1 + lite/kernels/arm/mul_grad_compute.cc | 163 +++++++++++ lite/kernels/arm/mul_grad_compute.h | 42 +++ lite/operators/CMakeLists.txt | 1 + lite/operators/mul_grad_op.cc | 101 +++++++ lite/operators/mul_grad_op.h | 62 ++++ lite/operators/mul_op.h | 22 -- lite/tests/kernels/CMakeLists.txt | 1 + lite/tests/kernels/mul_compute_test.cc | 2 +- lite/tests/kernels/mul_grad_compute_test.cc | 298 ++++++++++++++++++++ 10 files changed, 670 insertions(+), 23 deletions(-) create mode 100644 lite/kernels/arm/mul_grad_compute.cc create mode 100644 lite/kernels/arm/mul_grad_compute.h create mode 100644 lite/operators/mul_grad_op.cc create mode 100644 lite/operators/mul_grad_op.h create mode 100644 lite/tests/kernels/mul_grad_compute_test.cc diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 514d6069b5..a9f15ebd70 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de if(LITE_WITH_TRAIN) add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) + add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm) endif() diff --git a/lite/kernels/arm/mul_grad_compute.cc b/lite/kernels/arm/mul_grad_compute.cc new file mode 100644 index 0000000000..405d61d2ac --- /dev/null +++ b/lite/kernels/arm/mul_grad_compute.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/mul_grad_compute.h" +#include +#include "lite/backends/arm/math/funcs.h" +#include "lite/backends/arm/math/sgemm.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void MulGradCompute::PrepareForRun() { + auto& ctx = this->ctx_->template As(); +} + +void MulGradCompute::Run() { + // step1 flatten_2d + auto& param = Param(); + const auto x_dims = param.x->dims(); + const auto y_dims = param.y->dims(); + const auto out_dims = param.output_grad->dims(); + + m_ = static_cast(x_dims.Slice(0, param.x_num_col_dims).production()); + + k_ = static_cast( + x_dims.Slice(param.x_num_col_dims, x_dims.size()).production()); + n_ = static_cast( + y_dims.Slice(param.y_num_col_dims, y_dims.size()).production()); + + const auto* out_grad_data = param.output_grad->data(); + const auto* x_data = param.x->data(); + const auto* y_data = param.y->data(); + float* x_grad_data; + float* y_grad_data; + if (param.x_grad) { + x_grad_data = param.x_grad->mutable_data(); + } + + if (param.y_grad) { + y_grad_data = param.y_grad->mutable_data(); + } + + paddle::lite::operators::ActivationParam act_param; + act_param.has_active = false; + // out_grad * y^T = x_grad + // (m, n), (n, k) -> (m, k) + auto& ctx = this->ctx_->template As(); + if (param.x_grad) { + if (m_ == 1) { + lite::arm::math::sgemv(y_data, + out_grad_data, + x_grad_data, + false, + k_, // M + n_, // N + false, + nullptr, + false, + lite_api::ActivationType::kIndentity, + &ctx); + } else { + paddle::lite::arm::math::sgemm(false, + true, // is_transB, + m_, // M + k_, // N + n_, // K + 1.0f, // alpha + out_grad_data, // A + n_, // lda + y_data, // B + n_, // ldb + 0.f, // beta + x_grad_data, // C + k_, // ldc + NULL, // bias + false, // is_bias + act_param, // act_param + &ctx); // ctx + } + } + + // x^T * out_grad = y_grad + // (k, m) (m, n) -> (k, n) + if (param.y_grad) { + if (n_ == 1) { + lite::arm::math::sgemv(x_data, + out_grad_data, + y_grad_data, + true, + k_, // M + m_, // N + false, + nullptr, + false, + lite_api::ActivationType::kIndentity, + &ctx); + } else { + paddle::lite::arm::math::sgemm(true, // is_transA + false, // is_transB, + k_, // M + n_, // N + m_, // K + 1.0f, // alpha + x_data, // A + k_, // lda + out_grad_data, // B + n_, // ldb + 0.f, // beta + y_grad_data, // C + n_, // ldc + NULL, // bias + false, // is_bias + act_param, // act_param + &ctx); // ctx + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(mul_grad, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::MulGradCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Out@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("X@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Y@GRAD", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/mul_grad_compute.h b/lite/kernels/arm/mul_grad_compute.h new file mode 100644 index 0000000000..2cdaff3f10 --- /dev/null +++ b/lite/kernels/arm/mul_grad_compute.h @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class MulGradCompute : public KernelLite { + public: + using param_t = operators::MulGradParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~MulGradCompute() = default; + + private: + int m_, n_, k_; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 34c7b8d666..4a606458d8 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) if (LITE_WITH_TRAIN) add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS}) + add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS}) add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS}) endif() diff --git a/lite/operators/mul_grad_op.cc b/lite/operators/mul_grad_op.cc new file mode 100644 index 0000000000..8215521637 --- /dev/null +++ b/lite/operators/mul_grad_op.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/mul_grad_op.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool MulGradOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.y); + CHECK_OR_FALSE(param_.output_grad); + CHECK_OR_FALSE(param_.x_grad || param_.y_grad); + CHECK_OR_FALSE(param_.x_num_col_dims); + CHECK_OR_FALSE(param_.y_num_col_dims); + + const auto x_dims = param_.x->dims(); + const auto y_dims = param_.y->dims(); + const auto out_dims = param_.output_grad->dims(); + + CHECK_GT_OR_FALSE(x_dims.size(), static_cast(param_.x_num_col_dims)); + CHECK_GT_OR_FALSE(y_dims.size(), static_cast(param_.y_num_col_dims)); + + auto x_flatten_dims = flatten_2d(x_dims, param_.x_num_col_dims); + auto y_flatten_dims = flatten_2d(y_dims, param_.y_num_col_dims); + auto out_flatten_dims = flatten_2d(out_dims, param_.x_num_col_dims); + + // Out = X * Y; + CHECK_EQ_OR_FALSE(x_flatten_dims[1], y_flatten_dims[0]); + CHECK_EQ_OR_FALSE(x_flatten_dims[0], out_flatten_dims[0]); + CHECK_EQ_OR_FALSE(y_flatten_dims[1], out_flatten_dims[1]); + return true; +} + +bool MulGradOpLite::InferShape() const { + const auto x_dims = param_.x->dims(); + const auto y_dims = param_.y->dims(); + if (param_.x_grad) { + param_.x_grad->Resize(x_dims); + param_.x_grad->set_lod(param_.x->lod()); + } + if (param_.y_grad) { + param_.y_grad->Resize(y_dims); + param_.y_grad->set_lod(param_.y->lod()); + } +} + +bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + CHECK(!op_desc.Input("X").empty()); + CHECK(!op_desc.Input("Y").empty()); + CHECK(!op_desc.Input("Out@GRAD").empty()); + CHECK(!op_desc.Output("X@GRAD").empty() || !op_desc.Output("Y@GRAD").empty()) + << "at least one of 'X@GRAD' and 'Y@GRAD' is not empty"; + + auto *x_var = scope->FindVar(op_desc.Input("X").front()); + CHECK(x_var); + param_.x = &x_var->Get(); + + auto *y_var = scope->FindVar(op_desc.Input("Y").front()); + CHECK(y_var); + param_.y = &y_var->Get(); + + auto *out_grad_var = scope->FindVar(op_desc.Input("Out@GRAD").front()); + CHECK(out_grad_var); + param_.output_grad = &out_grad_var->Get(); + + if (!op_desc.Output("X@GRAD").empty()) { + auto *x_grad_var = scope->FindVar(op_desc.Output("X@GRAD").front()); + CHECK(x_grad_var); + param_.x_grad = x_grad_var->GetMutable(); + } + + if (!op_desc.Output("Y@GRAD").empty()) { + auto *y_grad_var = scope->FindVar(op_desc.Output("Y@GRAD").front()); + CHECK(y_grad_var); + param_.y_grad = y_grad_var->GetMutable(); + } + param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims"); + param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(mul_grad, paddle::lite::operators::MulGradOpLite); diff --git a/lite/operators/mul_grad_op.h b/lite/operators/mul_grad_op.h new file mode 100644 index 0000000000..ef61f54f9b --- /dev/null +++ b/lite/operators/mul_grad_op.h @@ -0,0 +1,62 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/kernel.h" +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/operators/op_params.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MulGradOpLite : public OpLite { + public: + MulGradOpLite() {} + + explicit MulGradOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "mul_grad"; } + + private: + mutable MulGradParam param_; +}; + +std::vector flatten_2d(DDim dims, int num_col_dims) { + std::vector flatten_dims{1, 1}; + for (int i = 0; i < dims.size(); i++) { + if (i < num_col_dims) { + flatten_dims[0] *= dims[i]; + } else { + flatten_dims[1] *= dims[i]; + } + } + return flatten_dims; +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/mul_op.h b/lite/operators/mul_op.h index e53168e00e..caf7bf6ae9 100644 --- a/lite/operators/mul_op.h +++ b/lite/operators/mul_op.h @@ -66,28 +66,6 @@ class MulOpLite : public OpLite { mutable MulParam param_; }; -#ifdef LITE_WITH_TRAIN -class MulGradOpLite : public OpLite { - public: - MulGradOpLite() {} - - explicit MulGradOpLite(const std::string &type) : OpLite(type) {} - - bool CheckShape() const override; - - bool InferShape() const override; - - void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - - bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; - - std::string DebugString() const override { return "mul_grad"; } - - private: - mutable MulGradParam param_; -}; -#endif - } // namespace operators } // namespace lite } // namespace paddle diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index f1b51b13ab..de0d530b86 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA) if (LITE_WITH_TRAIN) lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) endif() diff --git a/lite/tests/kernels/mul_compute_test.cc b/lite/tests/kernels/mul_compute_test.cc index d9bbfaa8d0..d070292332 100644 --- a/lite/tests/kernels/mul_compute_test.cc +++ b/lite/tests/kernels/mul_compute_test.cc @@ -109,6 +109,7 @@ void TestMul(const std::vector& x_dims, int y_num_col_dims, const Place& place, float abs_error) { + LOG(INFO) << "run test arm"; std::unique_ptr tester(new MulComputeTester(place, "def", DDim(x_dims), @@ -131,7 +132,6 @@ TEST(Mul, precision) { #else return; #endif - TestMul({4, 5}, {5, 4}, 1, 1, place, abs_error); TestMul({4, 5}, {5, 4, 3, 2}, 1, 1, place, abs_error); TestMul({4, 20}, {5, 4, 3, 2}, 1, 2, place, abs_error); diff --git a/lite/tests/kernels/mul_grad_compute_test.cc b/lite/tests/kernels/mul_grad_compute_test.cc new file mode 100644 index 0000000000..e7a64d181c --- /dev/null +++ b/lite/tests/kernels/mul_grad_compute_test.cc @@ -0,0 +1,298 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/mul_grad_compute.h" +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/arm/mul_compute.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +using param_t = operators::MulParam; +using grad_param_t = operators::MulGradParam; +using kernel_t = MulCompute; +using grad_kernel_t = MulGradCompute; + +class MulGradTester { + public: + explicit MulGradTester(const DDim& x_dims, + const DDim& y_dims, + int x_num_col_dims, + int y_num_col_dims) + : x_dims_(x_dims), + y_dims_(y_dims), + x_num_col_dims_(x_num_col_dims), + y_num_col_dims_(y_num_col_dims) {} + + void prepare_kernel() { + std::unique_ptr ctx1(new KernelContext); + ctx1->As(); + kernel_.SetContext(std::move(ctx1)); + + std::unique_ptr ctx2(new KernelContext); + ctx2->As(); + delta_kernel_.SetContext(std::move(ctx2)); + + std::unique_ptr ctx3(new KernelContext); + ctx3->As(); + grad_kernel_.SetContext(std::move(ctx3)); + } + + void run_forward(param_t* param, + kernel_t* kernel, + const std::vector& x_vec, + const std::vector& y_vec, + float* out_vec) { + Tensor x; + Tensor y; + Tensor output; + x.Resize(x_dims_); + y.Resize(y_dims_); + output.Resize(DDim(out_dims_)); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = x_vec[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_data[i] = y_vec[i]; + } + + param->x = &x; + param->y = &y; + param->output = &output; + param->x_num_col_dims = x_num_col_dims_; + param->y_num_col_dims = y_num_col_dims_; + kernel->SetParam(*param); + kernel->Launch(); + + auto* output_data = output.mutable_data(); + for (int i = 0; i < out_dims_.production(); i++) { + out_vec[i] = output_data[i]; + } + } + + void run_backward(grad_param_t* param, + grad_kernel_t* kernel, + const std::vector& x_vec, + const std::vector& y_vec, + const std::vector& out_grad_vec, + float* x_grad_vec, + float* y_grad_vec) { + Tensor x; + Tensor x_grad; + Tensor y; + Tensor y_grad; + Tensor out_grad; + x.Resize(x_dims_); + x_grad.Resize(x_dims_); + y.Resize(y_dims_); + y_grad.Resize(y_dims_); + out_grad.Resize(out_dims_); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* out_grad_data = out_grad.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = x_vec[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_data[i] = y_vec[i]; + } + for (int i = 0; i < out_dims_.production(); i++) { + out_grad_data[i] = out_grad_vec[i]; + } + + param->x = &x; + param->x_grad = &x_grad; + param->y = &y; + param->y_grad = &y_grad; + param->output_grad = &out_grad; + param->x_num_col_dims = x_num_col_dims_; + param->y_num_col_dims = y_num_col_dims_; + kernel->SetParam(*param); + kernel->Launch(); + + auto* x_grad_data = x_grad.mutable_data(); + auto* y_grad_data = y_grad.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_grad_vec[i] = x_grad_data[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_grad_vec[i] = y_grad_data[i]; + } + } + + void check_grad() { + std::vector out_shape; + for (int i = 0; i < x_num_col_dims_; i++) { + out_shape.push_back(x_dims_[i]); + } + for (int i = y_num_col_dims_; i < y_dims_.size(); i++) { + out_shape.push_back(y_dims_[i]); + } + out_dims_ = DDim(out_shape); + + // forward + std::vector x(x_dims_.production()); + std::vector y(y_dims_.production()); + std::vector out(out_dims_.production()); + fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production()); + fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production()); + this->run_forward(¶m_, &kernel_, x, y, out.data()); + + for (int i = 0; i < x_dims_.production(); i++) { + LOG(INFO) << "x_" << i << ": " << x[i]; + } + + for (int i = 0; i < y_dims_.production(); i++) { + LOG(INFO) << "y_" << i << ": " << y[i]; + } + + for (int i = 0; i < out_dims_.production(); i++) { + LOG(INFO) << "out_" << i << ": " << out[i]; + } + + // backward + std::vector out_grad(out_dims_.production()); + std::vector x_grad(x_dims_.production()); + std::vector y_grad(y_dims_.production()); + for (int i = 0; i < out_dims_.production(); i++) { + out_grad[i] = 1.0; + } + this->run_backward(&grad_param_, + &grad_kernel_, + x, + y, + out_grad, + x_grad.data(), + y_grad.data()); + + // get numeric gradient + std::vector x_delta(x_dims_.production()); + std::vector y_delta(y_dims_.production()); + std::vector out_delta(out_dims_.production()); + + float delta = 0.001; + float max_grad_delta = 0.005; + for (int i = 0; i < x_dims_.production(); i++) { + LOG(INFO) << "--------------------"; + LOG(INFO) << "delta: " << delta; + LOG(INFO) << "max_grad_delta: " << max_grad_delta; + for (int j = 0; j < x_dims_.production(); j++) { + // x_delta[j] = i == j ? x[j] + delta : x[j]; + + if (i == j) { + x_delta[j] = x[j] + delta; + } else { + x_delta[j] = x[j]; + } + } + this->run_forward( + &delta_param_, &delta_kernel_, x_delta, y, out_delta.data()); + for (int j = 0; j < x_dims_.production(); j++) { + LOG(INFO) << "x_" << j << ": " << x[j]; + LOG(INFO) << "x_delta_" << j << ": " << x_delta[j]; + } + + for (int j = 0; j < y_dims_.production(); j++) { + LOG(INFO) << "y_" << j << ": " << y[j]; + } + + for (int j = 0; j < out_dims_.production(); j++) { + LOG(INFO) << "out_delta_" << j << ": " << out_delta[j]; + } + + float sum = 0; + for (int j = 0; j < out_dims_.production(); j++) { + sum += (out_delta[j] - out[j]); + } + + LOG(INFO) << "x_grad_" << i << ": " << x_grad[i]; + LOG(INFO) << "x_grad_num_" << i << ": " << sum / delta; + EXPECT_NEAR(x_grad[i], sum / delta, max_grad_delta); + } + + for (int i = 0; i < y_dims_.production(); i++) { + for (int j = 0; j < y_dims_.production(); j++) { + y_delta[j] = i == j ? y[j] + delta : y[j]; + } + this->run_forward( + &delta_param_, &delta_kernel_, x, y_delta, out_delta.data()); + float sum = 0; + for (int j = 0; j < out_dims_.production(); j++) { + sum += out_delta[j] - out[j]; + } + LOG(INFO) << "y_grad_" << i << ": " << y_grad[i]; + LOG(INFO) << "y_grad_num_" << i << ": " << sum / delta; + EXPECT_NEAR(y_grad[i], sum / delta, max_grad_delta); + } + } + + private: + DDim x_dims_; + DDim y_dims_; + DDim out_dims_; + int x_num_col_dims_; + int y_num_col_dims_; + kernel_t kernel_; + kernel_t delta_kernel_; + grad_kernel_t grad_kernel_; + param_t param_; + param_t delta_param_; + grad_param_t grad_param_; +}; + +void TestNormalCase(const std::vector& x_dims, + const std::vector& y_dims, + int x_num_col_dims, + int y_num_col_dims) { + std::unique_ptr tester(new MulGradTester( + DDim(x_dims), DDim(y_dims), x_num_col_dims, y_num_col_dims)); + + tester->prepare_kernel(); + float delta = 0.001; + float max_grad_delta = 0.005; + tester->check_grad(); +} + +TEST(mul_grad_arm, compute) { + LOG(INFO) << "Test Mul grad"; + DeviceInfo::Init(); + TestNormalCase({1, 3}, {3, 2}, 1, 1); + TestNormalCase({3, 2}, {2, 1}, 1, 1); + TestNormalCase({3, 1}, {1, 7}, 1, 1); + TestNormalCase({2, 3}, {3, 2}, 1, 1); + TestNormalCase({4, 5}, {5, 4}, 1, 1); + TestNormalCase({4, 5}, {5, 4, 3, 2}, 1, 1); + TestNormalCase({3, 4}, {2, 2, 3}, 1, 2); + TestNormalCase({4, 20}, {5, 4, 3, 2}, 1, 2); + TestNormalCase({4, 60}, {5, 4, 3, 2}, 1, 3); + TestNormalCase({2, 3, 4, 5}, {60, 4}, 1, 1); + TestNormalCase({2, 3, 4, 5}, {20, 4}, 2, 1); + TestNormalCase({2, 3, 4, 5}, {5, 4}, 3, 1); + TestNormalCase({2, 3, 4, 5}, {60, 3, 4, 5}, 1, 1); + TestNormalCase({2, 3, 4, 5}, {4, 5, 6, 2}, 2, 2); + TestNormalCase({2, 3, 4, 5}, {5, 1, 4, 2}, 3, 2); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul_grad, kARM, kFloat, kNCHW, def); -- GitLab