diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index 186ad19735799dcb91641354af4b4f09692bfce9..47a4d427f5400212a80fc31336e462a1c48bd640 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -266,6 +266,72 @@ void elementwise_add_relu_broadcast(const float* dinx,
   }
 }
 
+template <>
+void elementwise_add_grad<float>(const float* dout_grad,
+                                 float* x_grad,
+                                 int num) {
+  int cnt = num >> 4;
+  int remain = num & 0x0f;
+#pragma omp parallel for
+  for (int i = 0; i < cnt; ++i) {
+    const float* out_data = dout_grad + 16 * i;
+    float* x_data = x_grad + 16 * i;
+    float32x4_t din0 = vld1q_f32(out_data);
+    float32x4_t din1 = vld1q_f32(out_data + 4);
+    float32x4_t din2 = vld1q_f32(out_data + 8);
+    float32x4_t din3 = vld1q_f32(out_data + 12);
+    vst1q_f32(x_data, din0);
+    vst1q_f32(x_data + 4, din1);
+    vst1q_f32(x_data + 8, din2);
+    vst1q_f32(x_data + 12, din3);
+  }
+  if (remain > 0) {
+    const float* out_data = dout_grad + 16 * cnt;
+    float* x_data = x_grad + 16 * cnt;
+    for (int i = 0; i < remain; ++i) {
+      x_data[i] = out_data[i];
+    }
+  }
+}
+// We assume y_data has fewer elements than x_data; otherwise, call this
+// function with the x_grad and y_grad arguments swapped.
+template <>
+void elementwise_add_grad_broadcast<float>(const float* dout_grad,
+                                           float* x_grad,
+                                           float* y_grad,
+                                           int pre,
+                                           int n,
+                                           int post) {
+  if (x_grad) {
+    elementwise_add_grad(dout_grad, x_grad, pre * n * post);
+  }
+  if (y_grad) {
+    memset(y_grad, 0, n * sizeof(float));
+#pragma omp parallel for
+    for (int i = 0; i < pre; ++i) {
+      for (int j = 0; j < n; ++j) {
+        float sum = 0;
+        int cnt = post >> 2;
+        int remain = post & 0x03;
+        const float* out_data = dout_grad + (i * n + j) * post;
+        float32x4_t sum_v = vdupq_n_f32(0);
+        for (int ci = 0; ci < cnt; ++ci) {
+          float32x4_t din = vld1q_f32(out_data + 4 * ci);
+          sum_v = vaddq_f32(sum_v, din);
+        }
+        out_data += 4 * cnt;
+        for (int ci = 0; ci < remain; ++ci) {
+          sum += out_data[ci];
+        }
+        float32x2_t high = vget_high_f32(sum_v);
+        float32x2_t low = vget_low_f32(sum_v);
+        sum += vget_lane_f32(high, 0) + vget_lane_f32(high, 1) +
+               vget_lane_f32(low, 0) + vget_lane_f32(low, 1);
+        y_grad[j] += sum;
+      }
+    }
+  }
+}
 template <>
 void elementwise_sub<float>(const float* dinx,
                             const float* diny,
@@ -510,6 +576,84 @@ void elementwise_sub_relu_broadcast(const float* dinx,
     }
   }
 }
 
+// We assume the formula is out = x - y.
+template <>
+void elementwise_sub_grad<float>(const float* dout_grad,
+                                 float* x_grad,
+                                 float* y_grad,
+                                 int num) {
+  if (x_grad) {
+    elementwise_add_grad(dout_grad, x_grad, num);
+  }
+  if (y_grad) {
+    int cnt = num >> 4;
+    int remain = num & 0x0f;
+    float32x4_t minus = vdupq_n_f32(-1);
+#pragma omp parallel for
+    for (int i = 0; i < cnt; ++i) {
+      const float* out_data = dout_grad + 16 * i;
+      float* y_data = y_grad + 16 * i;
+      float32x4_t din0 = vld1q_f32(out_data);
+      float32x4_t din1 = vld1q_f32(out_data + 4);
+      float32x4_t din2 = vld1q_f32(out_data + 8);
+      float32x4_t din3 = vld1q_f32(out_data + 12);
+      din0 = vmulq_f32(din0, minus);
+      din1 = vmulq_f32(din1, minus);
+      din2 = vmulq_f32(din2, minus);
+      din3 = vmulq_f32(din3, minus);
+      vst1q_f32(y_data, din0);
+      vst1q_f32(y_data + 4, din1);
+      vst1q_f32(y_data + 8, din2);
+      vst1q_f32(y_data + 12, din3);
+    }
+    if (remain > 0) {
+      const float* out_data = dout_grad + 16 * cnt;
+      float* y_data = y_grad + 16 * cnt;
+      for (int i = 0; i < remain; ++i) {
+        y_data[i] = -out_data[i];
+      }
+    }
+  }
+}
+// We assume y_data has fewer elements than x_data; otherwise, call this
+// function with the x_grad and y_grad arguments swapped.
+template <>
+void elementwise_sub_grad_broadcast<float>(const float* dout_grad,
+                                           float* x_grad,
+                                           float* y_grad,
+                                           int pre,
+                                           int n,
+                                           int post) {
+  if (x_grad) {
+    elementwise_add_grad(dout_grad, x_grad, pre * n * post);
+  }
+  if (y_grad) {
+    memset(y_grad, 0, n * sizeof(float));
+#pragma omp parallel for
+    for (int i = 0; i < pre; ++i) {
+      for (int j = 0; j < n; ++j) {
+        float sum = 0;
+        int cnt = post >> 2;
+        int remain = post & 0x03;
+        const float* out_data = dout_grad + (i * n + j) * post;
+        float32x4_t sum_v = vdupq_n_f32(0);
+        for (int ci = 0; ci < cnt; ++ci) {
+          float32x4_t din = vld1q_f32(out_data + 4 * ci);
+          sum_v = vaddq_f32(sum_v, din);
+        }
+        out_data += 4 * cnt;
+        for (int ci = 0; ci < remain; ++ci) {
+          sum -= out_data[ci];
+        }
+        float32x2_t high = vget_high_f32(sum_v);
+        float32x2_t low = vget_low_f32(sum_v);
+        sum -= vget_lane_f32(high, 0) + vget_lane_f32(high, 1) +
+               vget_lane_f32(low, 0) + vget_lane_f32(low, 1);
+        y_grad[j] += sum;
+      }
+    }
+  }
+}
 template <>
 void elementwise_mul<float>(const float* dinx,
diff --git a/lite/backends/arm/math/elementwise.h b/lite/backends/arm/math/elementwise.h
index 60d702742dec58f1502837617f5d4059dbb43e22..06ecab08edcaf06614de94b99084be2ee80647aa 100644
--- a/lite/backends/arm/math/elementwise.h
+++ b/lite/backends/arm/math/elementwise.h
@@ -183,6 +183,13 @@ template <typename T>
 void elementwise_add_relu_broadcast(
     const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
 
+template <typename T>
+void elementwise_add_grad(const T* dout, T* dinx, int num);
+
+template <typename T>
+void elementwise_add_grad_broadcast(
+    const T* dout_grad, T* x_grad, T* y_grad, int pre, int n, int post);
+
 template <typename T>
 void elementwise_sub(const T* dinx, const T* diny, T* dout, int num);
 
@@ -197,6 +204,13 @@ template <typename T>
 void elementwise_sub_relu_broadcast(
     const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
 
+template <typename T>
+void elementwise_sub_grad(const T* dout, T* dinx, T* diny, int num);
+
+template <typename T>
+void elementwise_sub_grad_broadcast(
+    const T* dout_grad, T* x_grad, T* y_grad, int pre, int n, int post);
+
 template <typename T>
 void elementwise_mul(const T* dinx, const T* diny, T* dout, int num);
 
diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index a9f15ebd70df864ce545bb01118473b66ae729fb..75dee596dd8b9b63bb45610659d558f6c82e574d 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -109,6 +109,7 @@ add_kernel(mean_compute_arm ARM extra SRCS mean_compute.cc DEPS ${lite_kernel_de
 if(LITE_WITH_TRAIN)
     add_kernel(mean_grad_compute_arm ARM extra SRCS mean_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
     add_kernel(activation_grad_compute_arm ARM basic SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
+    add_kernel(elementwise_grad_compute_arm ARM basic SRCS elementwise_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
     add_kernel(mul_grad_compute_arm ARM extra SRCS mul_grad_compute.cc DEPS ${lite_kernel_deps} math_arm)
     add_kernel(sgd_compute_arm ARM extra SRCS sgd_compute.cc DEPS ${lite_kernel_deps} math_arm)
 endif()
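Editor's note: the `(pre, n, post)` decomposition used by `elementwise_add_grad_broadcast` above views `dout_grad` as a `pre x n x post` volume; `x_grad` is a plain copy of `dout_grad`, while `y_grad` (length `n`) sums `dout_grad` over the `pre` and `post` axes. The standalone scalar sketch below is not part of the patch (the function name `add_grad_broadcast_ref` is illustrative); it shows the reduction that the NEON code vectorizes.

```cpp
#include <cstdio>
#include <vector>

// Scalar reference for the broadcast add-gradient: dout has pre*n*post
// elements, x_grad receives a copy of dout, and y_grad[j] accumulates every
// element whose middle index is j.
void add_grad_broadcast_ref(const float* dout, float* x_grad, float* y_grad,
                            int pre, int n, int post) {
  for (int i = 0; i < pre * n * post; ++i) x_grad[i] = dout[i];
  for (int j = 0; j < n; ++j) y_grad[j] = 0.f;
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        y_grad[j] += dout[(i * n + j) * post + k];
      }
    }
  }
}

int main() {
  // x shape {2, 3, 4}, y shape {3}, broadcast on axis 1 -> pre=2, n=3, post=4.
  const int pre = 2, n = 3, post = 4;
  std::vector<float> dout(pre * n * post, 1.f);
  std::vector<float> x_grad(pre * n * post), y_grad(n);
  add_grad_broadcast_ref(dout.data(), x_grad.data(), y_grad.data(), pre, n, post);
  // With an all-ones dout, each y_grad[j] should equal pre * post = 8.
  for (int j = 0; j < n; ++j) printf("y_grad[%d] = %.1f\n", j, y_grad[j]);
  return 0;
}
```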
diff --git a/lite/kernels/arm/elementwise_grad_compute.cc b/lite/kernels/arm/elementwise_grad_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..53971e26706f15032dae244e7bd0493a49376cd6
--- /dev/null
+++ b/lite/kernels/arm/elementwise_grad_compute.cc
@@ -0,0 +1,199 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/elementwise_grad_compute.h"
+#include <string>
+#include <vector>
+#include "lite/backends/arm/math/funcs.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+inline DDim trim_trailing_singular_dims(const DDim& dims) {
+  // Remove trailing dimensions of size 1 for y
+  auto actual_dims_size = dims.size();
+  for (; actual_dims_size != 0; --actual_dims_size) {
+    if (dims[actual_dims_size - 1] != 1) break;
+  }
+
+  std::vector<int64_t> trim_dims;
+  trim_dims.resize(actual_dims_size);
+  for (int i = 0; i < actual_dims_size; ++i) {
+    trim_dims[i] = dims[i];
+  }
+  if (trim_dims.size() == 0) {
+    return DDim();
+  }
+  return DDim(trim_dims);
+}
+
+inline bool is_broadcast(const DDim& x_dims,
+                         const DDim& y_dims,
+                         int axis,
+                         int* pre,
+                         int* n,
+                         int* post) {
+  if (axis < 0) {
+    axis = x_dims.size() - y_dims.size();
+  }
+  DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
+  axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
+  if (x_dims.size() == y_dim_trim.size()) {
+    return false;
+  }
+  *pre = 1;
+  *n = 1;
+  *post = 1;
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= x_dims[i];
+  }
+  for (int i = 0; i < y_dim_trim.size(); ++i) {
+    CHECK_EQ(x_dims[i + axis], y_dim_trim[i])
+        << "Broadcast dimension mismatch.";
+    (*n) *= y_dim_trim[i];
+  }
+  for (int i = axis + y_dim_trim.size(); i < x_dims.size(); ++i) {
+    (*post) *= x_dims[i];
+  }
+  return true;
+}
+
+void ElementwiseAddGradCompute::Run() {
+  auto& param = Param<operators::ElementwiseGradParam>();
+  const float* x_data = param.X->data<float>();
+  const float* y_data = param.Y->data<float>();
+  const float* out_grad_data = param.OutGrad->data<float>();
+  float* x_grad_data = param.XGrad->mutable_data<float>();
+  float* y_grad_data = param.YGrad->mutable_data<float>();
+  int axis = param.axis;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int pre, n, post;
+  if (x_dims.size() < y_dims.size() &&
+      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_add_grad_broadcast(
+        out_grad_data, y_grad_data, x_grad_data, pre, n, post);
+  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_add_grad_broadcast(
+        out_grad_data, x_grad_data, y_grad_data, pre, n, post);
+  } else {
+    lite::arm::math::elementwise_add_grad(
+        out_grad_data, x_grad_data, x_dims.production());
+    lite::arm::math::elementwise_add_grad(
+        out_grad_data, y_grad_data, y_dims.production());
+  }
+}
+
+void ElementwiseSubGradCompute::Run() {
+  auto& param = Param<operators::ElementwiseGradParam>();
+  const float* x_data = param.X->data<float>();
+  const float* y_data = param.Y->data<float>();
+  const float* out_data = param.OutGrad->data<float>();
+  float* x_grad_data = param.XGrad->mutable_data<float>();
+  float* y_grad_data = param.YGrad->mutable_data<float>();
+  int axis = param.axis;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int pre, n, post;
+  if (x_dims.size() < y_dims.size()) {
+    LOG(FATAL) << "elementwise sub_grad does not support x_dims size < y_dims size";
+  }
+  if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_sub_grad_broadcast(
+        out_data, x_grad_data, y_grad_data, pre, n, post);
+  } else {
+    lite::arm::math::elementwise_sub_grad(
+        out_data, x_grad_data, y_grad_data, x_dims.production());
+  }
+}
+
+template <typename T, PrecisionType PType>
+void ElementwiseMulGradCompute<T, PType>::Run() {
+  LOG(FATAL) << "elementwise mul_grad is not implemented yet";
+}
+
+void ElementwiseMaxGradCompute::Run() {
+  LOG(FATAL) << "elementwise max_grad is not implemented yet";
+}
+
+void ElementwiseDivGradCompute::Run() {
+  LOG(FATAL) << "elementwise div_grad is not implemented yet";
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+using elementwise_mul_grad_float =
+    paddle::lite::kernels::arm::ElementwiseMulGradCompute<float,
+                                                          PRECISION(kFloat)>;
+
+REGISTER_LITE_KERNEL(elementwise_add_grad,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ElementwiseAddGradCompute,
+                     def)
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(elementwise_sub_grad,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ElementwiseSubGradCompute,
+                     def)
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(elementwise_div_grad,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ElementwiseDivGradCompute,
+                     def)
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(
+    elementwise_mul_grad, kARM, kFloat, kNCHW, elementwise_mul_grad_float, def)
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(elementwise_max_grad,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::ElementwiseMaxGradCompute,
+                     def)
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Out@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("X@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Y@Grad", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/elementwise_grad_compute.h b/lite/kernels/arm/elementwise_grad_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..1273d8317410ce6689637e28597f9867702e1c2c
--- /dev/null
+++ b/lite/kernels/arm/elementwise_grad_compute.h
@@ -0,0 +1,68 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ElementwiseAddGradCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseAddGradCompute() = default; +}; + +class ElementwiseSubGradCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseSubGradCompute() = default; +}; + +template +class ElementwiseMulGradCompute : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseMulGradCompute() = default; +}; + +class ElementwiseMaxGradCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseMaxGradCompute() = default; +}; + +class ElementwiseDivGradCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseDivGradCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 4a606458d88f84bc8720d41dd22462132d14293e..48e27560317c089446e8dbc5040786f34ca962c4 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -144,6 +144,7 @@ add_operator(mean_op extra SRCS mean_op.cc DEPS ${op_DEPS}) if (LITE_WITH_TRAIN) add_operator(mean_grad_op extra SRCS mean_grad_op.cc DEPS ${op_DEPS}) add_operator(activation_grad_ops basic SRCS activation_grad_ops.cc DEPS ${op_DEPS}) + add_operator(elementwise_grad_op extra SRCS elementwise_grad_ops.cc DEPS ${op_DEPS}) add_operator(mul_grad_op basic SRCS mul_grad_op.cc DEPS ${op_DEPS}) add_operator(sgd_op extra SRCS sgd_op.cc DEPS ${op_DEPS}) endif() diff --git a/lite/operators/elementwise_grad_ops.cc b/lite/operators/elementwise_grad_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d9e1040976a98d890dc8c841cb4f70d81453d61 --- /dev/null +++ b/lite/operators/elementwise_grad_ops.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/operators/elementwise_grad_ops.h" +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace operators { + +bool ElementwiseGradOp::CheckShape() const { + CHECK_OR_FALSE(param_.XGrad); + CHECK_OR_FALSE(param_.YGrad); + CHECK_OR_FALSE(param_.OutGrad); + return true; +} + +bool ElementwiseGradOp::InferShape() const { + auto x_dim = param_.X->dims(); + auto y_dim = param_.Y->dims(); + param_.XGrad->Resize(x_dim); + param_.YGrad->Resize(y_dim); + return true; +} + +bool ElementwiseGradOp::AttachImpl(const cpp::OpDesc& opdesc, + lite::Scope* scope) { + auto Y_name = opdesc.Input("Y").front(); + auto X_name = opdesc.Input("X").front(); + auto Out_name = opdesc.Input("Out@Grad").front(); + auto x_grad_name = opdesc.Output("X@Grad").front(); + auto y_grad_name = opdesc.Output("Y@Grad").front(); + + param_.X = GetVar(scope, X_name); + param_.Y = GetVar(scope, Y_name); + param_.XGrad = GetMutableVar(scope, x_grad_name); + param_.YGrad = GetMutableVar(scope, y_grad_name); + param_.OutGrad = GetVar(scope, Out_name); + param_.axis = opdesc.GetAttr("axis"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(elementwise_grad_sub, + paddle::lite::operators::ElementwiseGradOp); +REGISTER_LITE_OP(elementwise_grad_add, + paddle::lite::operators::ElementwiseGradOp); + +REGISTER_LITE_OP(elementwise_grad_mul, + paddle::lite::operators::ElementwiseGradOp); +REGISTER_LITE_OP(elementwise_grad_max, + paddle::lite::operators::ElementwiseGradOp); diff --git a/lite/operators/elementwise_grad_ops.h b/lite/operators/elementwise_grad_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..c45d581936207f0b37ee70a0505b912d0b509e35 --- /dev/null +++ b/lite/operators/elementwise_grad_ops.h @@ -0,0 +1,44 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ElementwiseGradOp : public OpLite { + public: + explicit ElementwiseGradOp(const std::string& op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "elementwise_grad_op"; } + + private: + mutable operators::ElementwiseGradParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 6d18f1bf348530fc111499ca7cbb89e9bec88d9d..36d3b42c6b315a3858f475bd5756579137528051 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -387,10 +387,11 @@ struct ElementwiseParam { }; struct ElementwiseGradParam { + const lite::Tensor* X{}; const lite::Tensor* Y{}; - const lite::Tensor* Out_grad{}; - lite::Tensor* X_grad{}; - lite::Tensor* Y_grad{}; + const lite::Tensor* OutGrad{}; + lite::Tensor* XGrad{}; + lite::Tensor* YGrad{}; int axis{-1}; // for broadcasting. }; diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index de0d530b867ea69b10f2eac7d8861fd8e8856b8f..ffed48cdc612bd7d5c7e701b0e198390976b7bef 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -65,6 +65,7 @@ if(LITE_BUILD_EXTRA) if (LITE_WITH_TRAIN) lite_cc_test(test_kernel_mean_compute SRCS mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_activation_grad_compute SRCS activation_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) endif() diff --git a/lite/tests/kernels/elementwise_grad_compute_test.cc b/lite/tests/kernels/elementwise_grad_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2b5fbbb65d3d7e17bf90afb71f5c8154f0d88488 --- /dev/null +++ b/lite/tests/kernels/elementwise_grad_compute_test.cc @@ -0,0 +1,541 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/elementwise_grad_compute.h" +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/arm/elementwise_compute.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +using param_t = operators::ElementwiseParam; +using grad_param_t = operators::ElementwiseGradParam; +using kernel_add_t = ElementwiseAddCompute; +using grad_kernel_add_t = ElementwiseAddGradCompute; +using kernel_sub_t = ElementwiseSubCompute; +using grad_kernel_sub_t = ElementwiseSubGradCompute; + +void elementwise_common(grad_param_t& param, // NOLINT + std::vector& out_grad, // NOLINT + std::vector& x_grad, // NOLINT + std::vector& y_grad, // NOLINT + std::string flag) { + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + if (x_dims == y_dims) { + for (int i = 0; i < x_dims.production(); ++i) { + if (flag == "add") { + x_grad[i] = out_grad[i]; + y_grad[i] = out_grad[i]; + } + if (flag == "sub") { + x_grad[i] = out_grad[i]; + y_grad[i] = -out_grad[i]; + } + } + } else { + LOG(FATAL) << "unsupport dims"; + } +} + +class ElementwiseAddGradTester { + public: + explicit ElementwiseAddGradTester(const DDim& x_dims, + const DDim& y_dims, + int axis) + : x_dims_(x_dims), y_dims_(y_dims), axis_(axis) {} + + void prepare_kernel() { + std::unique_ptr ctx1(new KernelContext); + ctx1->As(); + kernel_.SetContext(std::move(ctx1)); + + std::unique_ptr ctx3(new KernelContext); + ctx3->As(); + grad_kernel_.SetContext(std::move(ctx3)); + } + + void run_forward(param_t* param, + kernel_add_t* kernel, + const std::vector& x_vec, + const std::vector& y_vec, + float* out_vec) { + Tensor x; + Tensor y; + Tensor output; + x.Resize(x_dims_); + y.Resize(y_dims_); + output.Resize(DDim(out_dims_)); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = x_vec[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_data[i] = y_vec[i]; + } + + param->X = &x; + param->Y = &y; + param->Out = &output; + param->axis = axis_; + kernel->SetParam(*param); + kernel->Launch(); + + auto* output_data = output.mutable_data(); + for (int i = 0; i < out_dims_.production(); i++) { + out_vec[i] = output_data[i]; + } + } + + void run_backward(grad_param_t* param, + grad_kernel_add_t* kernel, + const std::vector& x_vec, + const std::vector& y_vec, + const std::vector& out_grad_vec, + float* x_grad_vec, + float* y_grad_vec) { + Tensor x; + Tensor x_grad; + Tensor y; + Tensor y_grad; + Tensor out_grad; + x.Resize(x_dims_); + x_grad.Resize(x_dims_); + y.Resize(y_dims_); + y_grad.Resize(y_dims_); + out_grad.Resize(out_dims_); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* out_grad_data = out_grad.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = x_vec[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_data[i] = y_vec[i]; + } + for (int i = 0; i < out_dims_.production(); i++) { + out_grad_data[i] = out_grad_vec[i]; + } + + param->X = &x; + param->XGrad = &x_grad; + param->Y = &y; + param->YGrad = &y_grad; + param->OutGrad = &out_grad; + param->axis = axis_; + + kernel->SetParam(*param); + kernel->Launch(); + + auto* x_grad_data = x_grad.mutable_data(); + auto* y_grad_data = y_grad.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_grad_vec[i] = x_grad_data[i]; + } + for 
(int i = 0; i < y_dims_.production(); i++) { + y_grad_vec[i] = y_grad_data[i]; + } + } + + void check_grad(float delta2, float max_grad_delta2) { + std::vector out_shape; + // infer shape + auto x_dim = x_dims_; + auto y_dim = y_dims_; + if (x_dim == y_dim) { + out_dims_ = x_dim; + } else { + int max_dim = (x_dim.size() > y_dim.size() ? x_dim.size() : y_dim.size()); + int axis = param_.axis; + axis = + (axis == -1 ? std::abs(static_cast(x_dim.size() - y_dim.size())) + : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + + if (x_dim.size() > y_dim.size()) { + for (int i = 0; i < axis; ++i) { + y_dims_array[i] = 1; + } + if (axis + y_dim.size() < max_dim) { + for (int i = axis + y_dim.size(); i < max_dim; ++i) { + y_dims_array[i] = 1; + } + } + x_dims_array = x_dim.Vectorize(); + for (int i = 0; i < y_dim.size(); ++i) { + y_dims_array[i + axis] = y_dim[i]; + } + } else { + for (int i = 0; i < axis; ++i) { + x_dims_array[i] = 1; + } + if (axis + x_dim.size() < max_dim) { + for (int i = axis + x_dim.size(); i < max_dim; ++i) { + x_dims_array[i] = 1; + } + } + y_dims_array = y_dim.Vectorize(); + for (int i = 0; i < x_dim.size(); ++i) { + x_dims_array[i + axis] = x_dim[i]; + } + } + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] == -1 || y_dims_array[i] == -1) { + out_dims_array[i] = -1; + } else { + out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } + } + out_dims_ = DDim(out_dims_array); + } + // infer end + // forward + std::vector x(x_dims_.production()); + std::vector y(y_dims_.production()); + std::vector out(out_dims_.production()); + fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production()); + fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production()); + this->run_forward(¶m_, &kernel_, x, y, out.data()); + + for (int i = 0; i < x_dims_.production(); i++) { + LOG(INFO) << "x_" << i << ": " << x[i]; + } + + for (int i = 0; i < y_dims_.production(); i++) { + LOG(INFO) << "y_" << i << ": " << y[i]; + } + + for (int i = 0; i < out_dims_.production(); i++) { + LOG(INFO) << "out_" << i << ": " << out[i]; + } + + // backward + std::vector out_grad(out_dims_.production()); + std::vector x_grad(x_dims_.production()); + std::vector y_grad(y_dims_.production()); + for (int i = 0; i < out_dims_.production(); i++) { + out_grad[i] = 1.0; + } + this->run_backward(&grad_param_, + &grad_kernel_, + x, + y, + out_grad, + x_grad.data(), + y_grad.data()); + + for (int i = 0; i < x_grad.size(); i++) { + LOG(INFO) << "x_grad_" << i << ": " << x_grad[i]; + } + + for (int i = 0; i < y_grad.size(); i++) { + LOG(INFO) << "y_grad_" << i << ": " << y_grad[i]; + } + + // get numeric gradient + std::vector x_delta(x_dims_.production()); + std::vector y_delta(y_dims_.production()); + std::vector out_delta(out_dims_.production()); + Tensor tensor_x; + Tensor tensor_y; + tensor_x.Resize(x_dims_); + tensor_y.Resize(y_dims_); + grad_param_.X = &tensor_x; + grad_param_.Y = &tensor_y; + + elementwise_common(grad_param_, out_grad, x_delta, y_delta, "add"); + + float max_grad_delta = 0.0005; + for (int i = 0; i < x_dims_.production(); i++) { + EXPECT_NEAR(x_grad[i], x_delta[i], max_grad_delta); + EXPECT_NEAR(y_grad[i], y_delta[i], max_grad_delta); + } + } + + private: + DDim x_dims_; + DDim y_dims_; + DDim out_dims_; + int axis_; + kernel_add_t kernel_; + grad_kernel_add_t grad_kernel_; + param_t param_; + grad_param_t grad_param_; +}; + +class ElementwiseSubGradTester { + public: + explicit ElementwiseSubGradTester(const DDim& 
x_dims, + const DDim& y_dims, + int axis) + : x_dims_(x_dims), y_dims_(y_dims), axis_(axis) {} + + void prepare_kernel() { + std::unique_ptr ctx1(new KernelContext); + ctx1->As(); + kernel_.SetContext(std::move(ctx1)); + + std::unique_ptr ctx3(new KernelContext); + ctx3->As(); + grad_kernel_.SetContext(std::move(ctx3)); + } + + void run_forward(param_t* param, + kernel_sub_t* kernel, + const std::vector& x_vec, + const std::vector& y_vec, + float* out_vec) { + Tensor x; + Tensor y; + Tensor output; + x.Resize(x_dims_); + y.Resize(y_dims_); + output.Resize(DDim(out_dims_)); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = x_vec[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_data[i] = y_vec[i]; + } + + param->X = &x; + param->Y = &y; + param->Out = &output; + param->axis = axis_; + kernel->SetParam(*param); + kernel->Launch(); + + auto* output_data = output.mutable_data(); + for (int i = 0; i < out_dims_.production(); i++) { + out_vec[i] = output_data[i]; + } + } + + void run_backward(grad_param_t* param, + grad_kernel_sub_t* kernel, + const std::vector& x_vec, + const std::vector& y_vec, + const std::vector& out_grad_vec, + float* x_grad_vec, + float* y_grad_vec) { + Tensor x; + Tensor x_grad; + Tensor y; + Tensor y_grad; + Tensor out_grad; + x.Resize(x_dims_); + x_grad.Resize(x_dims_); + y.Resize(y_dims_); + y_grad.Resize(y_dims_); + out_grad.Resize(out_dims_); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* out_grad_data = out_grad.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_data[i] = x_vec[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_data[i] = y_vec[i]; + } + for (int i = 0; i < out_dims_.production(); i++) { + out_grad_data[i] = out_grad_vec[i]; + } + + param->X = &x; + param->XGrad = &x_grad; + param->Y = &y; + param->YGrad = &y_grad; + param->OutGrad = &out_grad; + param->axis = axis_; + + kernel->SetParam(*param); + kernel->Launch(); + + auto* x_grad_data = x_grad.mutable_data(); + auto* y_grad_data = y_grad.mutable_data(); + for (int i = 0; i < x_dims_.production(); i++) { + x_grad_vec[i] = x_grad_data[i]; + } + for (int i = 0; i < y_dims_.production(); i++) { + y_grad_vec[i] = y_grad_data[i]; + } + } + + void check_grad(float delta2, float max_grad_delta2) { + std::vector out_shape; + // infer shape + auto x_dim = x_dims_; + auto y_dim = y_dims_; + if (x_dim == y_dim) { + out_dims_ = x_dim; + } else { + int max_dim = (x_dim.size() > y_dim.size() ? x_dim.size() : y_dim.size()); + int axis = param_.axis; + axis = + (axis == -1 ? 
std::abs(static_cast(x_dim.size() - y_dim.size())) + : axis); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + + if (x_dim.size() > y_dim.size()) { + for (int i = 0; i < axis; ++i) { + y_dims_array[i] = 1; + } + if (axis + y_dim.size() < max_dim) { + for (int i = axis + y_dim.size(); i < max_dim; ++i) { + y_dims_array[i] = 1; + } + } + x_dims_array = x_dim.Vectorize(); + for (int i = 0; i < y_dim.size(); ++i) { + y_dims_array[i + axis] = y_dim[i]; + } + } else { + for (int i = 0; i < axis; ++i) { + x_dims_array[i] = 1; + } + if (axis + x_dim.size() < max_dim) { + for (int i = axis + x_dim.size(); i < max_dim; ++i) { + x_dims_array[i] = 1; + } + } + y_dims_array = y_dim.Vectorize(); + for (int i = 0; i < x_dim.size(); ++i) { + x_dims_array[i + axis] = x_dim[i]; + } + } + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] == -1 || y_dims_array[i] == -1) { + out_dims_array[i] = -1; + } else { + out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]); + } + } + out_dims_ = DDim(out_dims_array); + } + // infer end + // forward + std::vector x(x_dims_.production()); + std::vector y(y_dims_.production()); + std::vector out(out_dims_.production()); + fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production()); + fill_data_rand(y.data(), -1.f, 1.f, y_dims_.production()); + this->run_forward(¶m_, &kernel_, x, y, out.data()); + + for (int i = 0; i < x_dims_.production(); i++) { + LOG(INFO) << "x_" << i << ": " << x[i]; + } + + for (int i = 0; i < y_dims_.production(); i++) { + LOG(INFO) << "y_" << i << ": " << y[i]; + } + + for (int i = 0; i < out_dims_.production(); i++) { + LOG(INFO) << "out_" << i << ": " << out[i]; + } + + // backward + std::vector out_grad(out_dims_.production()); + std::vector x_grad(x_dims_.production()); + std::vector y_grad(y_dims_.production()); + for (int i = 0; i < out_dims_.production(); i++) { + out_grad[i] = 1.0; + } + this->run_backward(&grad_param_, + &grad_kernel_, + x, + y, + out_grad, + x_grad.data(), + y_grad.data()); + + for (int i = 0; i < x_grad.size(); i++) { + LOG(INFO) << "x_grad_" << i << ": " << x_grad[i]; + } + + for (int i = 0; i < y_grad.size(); i++) { + LOG(INFO) << "y_grad_" << i << ": " << y_grad[i]; + } + + // get numeric gradient + std::vector x_delta(x_dims_.production()); + std::vector y_delta(y_dims_.production()); + std::vector out_delta(out_dims_.production()); + Tensor tensor_x; + Tensor tensor_y; + tensor_x.Resize(x_dims_); + tensor_y.Resize(y_dims_); + grad_param_.X = &tensor_x; + grad_param_.Y = &tensor_y; + + elementwise_common(grad_param_, out_grad, x_delta, y_delta, "sub"); + + float max_grad_delta = 0.0005; + for (int i = 0; i < x_dims_.production(); i++) { + EXPECT_NEAR(x_grad[i], x_delta[i], max_grad_delta); + EXPECT_NEAR(y_grad[i], y_delta[i], max_grad_delta); + } + } + + private: + DDim x_dims_; + DDim y_dims_; + DDim out_dims_; + int axis_; + kernel_sub_t kernel_; + grad_kernel_sub_t grad_kernel_; + param_t param_; + grad_param_t grad_param_; +}; +void TestNormalCase(const std::vector& x_dims, + const std::vector& y_dims, + int axis) { + std::unique_ptr tester_add( + new ElementwiseAddGradTester(DDim(x_dims), DDim(y_dims), axis)); + std::unique_ptr tester_sub( + new ElementwiseSubGradTester(DDim(x_dims), DDim(y_dims), axis)); + + tester_add->prepare_kernel(); + tester_sub->prepare_kernel(); + float delta = 0.001; + float max_grad_delta = 0.005; + tester_add->check_grad(delta, max_grad_delta); + tester_sub->check_grad(delta, max_grad_delta); +} + 
+TEST(elementwise_grad_arm, compute) {
+  LOG(INFO) << "Test Elementwise grad";
+  DeviceInfo::Init();
+  TestNormalCase({3, 2}, {3, 2}, 0);
+  TestNormalCase({3, 5}, {3, 5}, 1);
+  TestNormalCase({3, 4, 3}, {3, 4, 3}, 0);
+  TestNormalCase({9, 2, 5}, {9, 2, 5}, 1);
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+USE_LITE_KERNEL(elementwise_add_grad, kARM, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
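Editor's note: the test above compares the kernels against the closed-form gradients (with an all-ones out_grad, x_grad is all ones, and y_grad is all ones for add and all minus-ones for sub). A standalone central-difference check, sketched below for the subtract case, is one way to confirm those expected values; it is not part of the patch and uses only plain C++.

```cpp
#include <cstdio>
#include <vector>

// Central-difference gradient of loss = sum(x - y) with respect to x[i] and
// y[i]; since out_grad is effectively 1 everywhere, the analytic answers are
// +1 for x and -1 for y.
int main() {
  const int num = 8;
  const float eps = 1e-3f;
  std::vector<float> x(num, 0.5f), y(num, -0.25f);
  auto loss = [&](const std::vector<float>& xs, const std::vector<float>& ys) {
    float s = 0.f;
    for (int i = 0; i < num; ++i) s += xs[i] - ys[i];
    return s;
  };
  for (int i = 0; i < num; ++i) {
    std::vector<float> xp = x, xm = x;
    xp[i] += eps;
    xm[i] -= eps;
    float dx = (loss(xp, y) - loss(xm, y)) / (2.f * eps);  // expect ~ +1
    std::vector<float> yp = y, ym = y;
    yp[i] += eps;
    ym[i] -= eps;
    float dy = (loss(x, yp) - loss(x, ym)) / (2.f * eps);  // expect ~ -1
    printf("i=%d dx=%.4f dy=%.4f\n", i, dx, dy);
  }
  return 0;
}
```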