diff --git a/paddle/fluid/lite/arm/math/elementwise.cc b/paddle/fluid/lite/arm/math/elementwise.cc index 2a74e7ee4ec4be51b420b1fa2d2a1be7c3f148fb..7c1ea8d3a70451dd790a9eea516b74f58ec91d5e 100644 --- a/paddle/fluid/lite/arm/math/elementwise.cc +++ b/paddle/fluid/lite/arm/math/elementwise.cc @@ -65,9 +65,61 @@ void elementwise_add(const float* dinx, const float* diny, float* dout, } template <> -void elementwise_add_axis(const float* dinx, const float* diny, - float* dout, int batch, int channels, - int num) { +void elementwise_add_relu(const float* dinx, const float* diny, + float* dout, int num) { + int cnt = num >> 4; + int remain = num % 16; + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + dinx0 = vaddq_f32(dinx0, diny0); + dinx1 = vaddq_f32(dinx1, diny1); + dinx2 = vaddq_f32(dinx2, diny2); + dinx3 = vaddq_f32(dinx3, diny3); + + // relu + dinx0 = vmaxq_f32(dinx0, vzero); + dinx1 = vmaxq_f32(dinx1, vzero); + dinx2 = vmaxq_f32(dinx2, vzero); + dinx3 = vmaxq_f32(dinx3, vzero); + + vst1q_f32(dout_ptr, dinx0); + vst1q_f32(dout_ptr + 4, dinx1); + vst1q_f32(dout_ptr + 8, dinx2); + vst1q_f32(dout_ptr + 12, dinx3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + float tmp = *dinx_ptr + *diny_ptr; + *dout_ptr = tmp > 0.f ? tmp : 0.f; + dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +template <> +void elementwise_add_broadcast(const float* dinx, const float* diny, + float* dout, int batch, int channels, + int num) { #pragma omp parallel for collapse(2) for (int i = 0; i < batch; ++i) { for (int j = 0; j < channels; ++j) { @@ -127,6 +179,82 @@ void elementwise_add_axis(const float* dinx, const float* diny, } } +template <> +void elementwise_add_relu_broadcast(const float* dinx, const float* diny, + float* dout, int batch, int channels, + int num) { + float32x4_t vzero = vdupq_n_f32(0.f); +#pragma omp parallel for collapse(2) + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const float* din_ptr = dinx + offset; + const float diny_data = diny[j]; + float* dout_ptr = dout + offset; + + int cnt = num >> 4; + int remain = num % 16; + float32x4_t rb = vdupq_n_f32(diny_data); + for (int k = 0; k < cnt; ++k) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + float32x4_t din2 = vld1q_f32(din_ptr + 8); + float32x4_t din3 = vld1q_f32(din_ptr + 12); + + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + din2 = vaddq_f32(din2, rb); + din3 = vaddq_f32(din3, rb); + + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + din2 = vmaxq_f32(din2, vzero); + din3 = vmaxq_f32(din3, vzero); + + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + vst1q_f32(dout_ptr + 8, din2); + vst1q_f32(dout_ptr + 12, din3); + din_ptr += 16; + dout_ptr += 16; + } + if (remain >= 8) { + float32x4_t din0 = vld1q_f32(din_ptr); + float32x4_t din1 = vld1q_f32(din_ptr + 4); + din0 = vaddq_f32(din0, rb); + din1 = vaddq_f32(din1, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + din1 = vmaxq_f32(din1, vzero); + vst1q_f32(dout_ptr, din0); + vst1q_f32(dout_ptr + 4, din1); + din_ptr += 8; + dout_ptr += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t din0 = vld1q_f32(din_ptr); + din0 = vaddq_f32(din0, rb); + // relu + din0 = vmaxq_f32(din0, vzero); + vst1q_f32(dout_ptr, din0); + din_ptr += 4; + dout_ptr += 4; + remain -= 4; + } + if (remain > 0) { + for (int p = 0; p < remain; p++) { + float tmp = *din_ptr + diny_data; + *dout_ptr = tmp > 0.f ? tmp : 0.f; + dout_ptr++; + din_ptr++; + } + } + } + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/paddle/fluid/lite/arm/math/elementwise.h b/paddle/fluid/lite/arm/math/elementwise.h index ca8f87895fcea80f9a1a178a0bf43b34c44182bb..9300d73753d695819af6ec7066fd95020457bd29 100644 --- a/paddle/fluid/lite/arm/math/elementwise.h +++ b/paddle/fluid/lite/arm/math/elementwise.h @@ -23,8 +23,15 @@ template void elementwise_add(const T* dinx, const T* diny, T* dout, int num); template -void elementwise_add_axis(const T* dinx, const T* diny, T* dout, int batch, - int channels, int num); +void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num); + +template +void elementwise_add_broadcast(const T* dinx, const T* diny, T* dout, int batch, + int channels, int num); + +template +void elementwise_add_relu_broadcast(const T* dinx, const T* diny, T* dout, + int batch, int channels, int num); } // namespace math } // namespace arm diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index 040b80c113162e0a87325c0bf353522677dbc9c8..95c8b95ec16aef37c6642df98c2b011b1d3a15a8 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -11,7 +11,7 @@ cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(batch_norm_compute_arm SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) -cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(elementwise_compute_arm SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(pool_compute_arm SRCS pool_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(split_compute_arm SRCS split_compute.cc DEPS ${lite_kernel_deps} math_arm) cc_library(concat_compute_arm SRCS concat_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -24,7 +24,7 @@ lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_comput lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) lite_cc_test(test_conv_compute_arm SRCS conv_compute_test.cc DEPS conv_compute_arm) lite_cc_test(test_batch_norm_compute_arm SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_arm) -lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) +lite_cc_test(test_elementwise_compute_arm SRCS elementwise_compute_test.cc DEPS elementwise_compute_arm) lite_cc_test(test_pool_compute_arm SRCS pool_compute_test.cc DEPS pool_compute_arm) lite_cc_test(test_mul_compute_arm SRCS mul_compute_test.cc DEPS mul_compute_arm) lite_cc_test(test_split_compute_arm SRCS split_compute_test.cc DEPS split_compute_arm) @@ -40,7 +40,7 @@ set(arm_kernels softmax_compute_arm conv_compute_arm batch_norm_compute_arm - elementwise_add_compute_arm + elementwise_compute_arm pool_compute_arm split_compute_arm concat_compute_arm diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc deleted file mode 100644 index e9d9f4927b7ee18b3e18efa69a00dcb1c813bf3b..0000000000000000000000000000000000000000 --- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h" -#include "paddle/fluid/lite/arm/math/funcs.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -void ElementwiseAddCompute::Run() { - auto& param = Param(); - const float* x_data = param.X->data(); - const float* y_data = param.Y->data(); - float* out_data = param.Out->mutable_data(); - int axis = param.axis; - auto x_dims = param.X->dims(); - auto y_dims = param.Y->dims(); - if (axis < 0) { - axis = x_dims.size() - y_dims.size(); - } - if (x_dims.size() == y_dims.size()) { - lite::arm::math::elementwise_add(x_data, y_data, out_data, - x_dims.production()); - } else { - int batch = 1; - int channels = 1; - int num = 1; - for (int i = 0; i < axis; ++i) { - batch *= x_dims[i]; - } - for (int i = 0; i < y_dims.size(); ++i) { - channels *= y_dims[i]; - } - for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { - num *= x_dims[i]; - } - lite::arm::math::elementwise_add_axis(x_data, y_data, out_data, batch, - channels, num); - } -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -REGISTER_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, - paddle::lite::kernels::arm::ElementwiseAddCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc deleted file mode 100644 index 20b998dc6cfa8a9606fcf0f716470366fdd60338..0000000000000000000000000000000000000000 --- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h" -#include -#include -#include "paddle/fluid/lite/core/op_registry.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -TEST(elementwise_add_arm, retrive_op) { - auto elementwise_add = - KernelRegistry::Global().Create( - "elementwise_add"); - ASSERT_FALSE(elementwise_add.empty()); - ASSERT_TRUE(elementwise_add.front()); -} - -TEST(elementwise_add_arm, init) { - ElementwiseAddCompute elementwise_add; - ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat)); - ASSERT_EQ(elementwise_add.target(), TARGET(kARM)); -} - -template -void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { - const dtype* x_data = param.X->data(); - const dtype* y_data = param.Y->data(); - dtype* out_data = param.Out->mutable_data(); - auto x_dims = param.X->dims(); - auto y_dims = param.Y->dims(); - int axis = param.axis; - if (axis < 0) { - axis = x_dims.size() - y_dims.size(); - } - int batch = 1; - int channels = 1; - int num = 1; - for (int i = 0; i < axis; ++i) { - batch *= x_dims[i]; - } - for (int i = 0; i < y_dims.size(); ++i) { - channels *= y_dims[i]; - } - for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { - num *= x_dims[i]; - } - for (int i = 0; i < batch; ++i) { - for (int j = 0; j < channels; ++j) { - int offset = (i * channels + j) * num; - const dtype* din_ptr = x_data + offset; - const dtype diny_data = y_data[j]; - dtype* dout_ptr = out_data + offset; - for (int k = 0; k < num; ++k) { - *dout_ptr = *din_ptr + diny_data; - dout_ptr++; - din_ptr++; - } - } - } -} - -TEST(elementwise_add, compute) { - ElementwiseAddCompute elementwise_add; - operators::ElementwiseParam param; - lite::Tensor x, y, output, output_ref; - - for (auto n : {1, 3, 4, 11}) { - for (auto c : {1, 3, 4, 11}) { - for (auto h : {1, 3, 4, 11}) { - for (auto w : {1, 3, 4, 11}) { - for (auto axis : {-1, 0, 1, 2, 3}) { - for (auto yd : - {std::vector({n}), std::vector({c}), - std::vector({h}), std::vector({w}), - std::vector({n, c}), std::vector({c, h}), - std::vector({h, w}), std::vector({n, c, h}), - std::vector({c, h, w}), - std::vector({n, c, h, w})}) { - auto x_dim = DDim(std::vector({n, c, h, w})); - auto y_dim = DDim(yd); - int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis; - - if (axis_t + y_dim.size() > 4) continue; - bool flag = false; - for (int i = 0; i < y_dim.size(); i++) { - if (x_dim[i + axis_t] != y_dim[i]) flag = true; - } - if (flag) continue; - - x.Resize(x_dim); - y.Resize(y_dim); - output.Resize(x_dim); - output_ref.Resize(x_dim); - auto* x_data = x.mutable_data(); - auto* y_data = y.mutable_data(); - auto* output_data = output.mutable_data(); - auto* output_ref_data = output_ref.mutable_data(); - for (int i = 0; i < x_dim.production(); i++) { - x_data[i] = i; - } - for (int i = 0; i < y_dim.production(); i++) { - y_data[i] = i; - } - param.X = &x; - param.Y = &y; - param.axis = axis; - param.Out = &output; - elementwise_add.SetParam(param); - elementwise_add.Run(); - param.Out = &output_ref; - elementwise_add_compute_ref(param); - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); - } - } - } - } - } - } - } -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/elementwise_compute.cc b/paddle/fluid/lite/kernels/arm/elementwise_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3b9b41cde1e70ecef580f72cfbb6c558258631d --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/elementwise_compute.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h" +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +inline bool is_broadcast(const DDim& x_dims, const DDim& y_dims, int axis, + int* pre, int* n, int* post) { + if (axis < 0) { + axis = x_dims.size() - y_dims.size(); + } + if (x_dims.size() == y_dims.size()) { + return false; + } + *pre = 1; + *n = 1; + *post = 1; + for (int i = 0; i < axis; ++i) { + (*pre) *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + CHECK_EQ(x_dims[i + axis], y_dims[i]) << "Broadcast dimension mismatch."; + (*n) *= y_dims[i]; + } + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } + return true; +} + +void ElementwiseAddCompute::Run() { + auto& param = Param(); + const float* x_data = param.X->data(); + const float* y_data = param.Y->data(); + float* out_data = param.Out->mutable_data(); + int axis = param.axis; + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + int pre, n, post; + if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { + lite::arm::math::elementwise_add_broadcast(x_data, y_data, out_data, pre, n, + post); + } else { + lite::arm::math::elementwise_add(x_data, y_data, out_data, + x_dims.production()); + } +} + +void ElementwiseAddActivationCompute::Run() { + auto& param = Param(); + const float* x_data = param.X->data(); + const float* y_data = param.Y->data(); + float* out_data = param.Out->mutable_data(); + int axis = param.axis; + std::string act_type = param.act_type; + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + int pre, n, post; + if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { + if (act_type == "relu") { + lite::arm::math::elementwise_add_relu_broadcast(x_data, y_data, out_data, + pre, n, post); + } else { + LOG(FATAL) << "unsupported Activation type: " << act_type; + } + } else { + if (act_type == "relu") { + lite::arm::math::elementwise_add_relu(x_data, y_data, out_data, + x_dims.production()); + } else { + LOG(FATAL) << "unsupported Activation type: " << act_type; + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ElementwiseAddCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + fusion_elementwise_add_activation, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ElementwiseAddActivationCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.h b/paddle/fluid/lite/kernels/arm/elementwise_compute.h similarity index 85% rename from paddle/fluid/lite/kernels/arm/elementwise_add_compute.h rename to paddle/fluid/lite/kernels/arm/elementwise_compute.h index 9939509d0be25eadccdb563e802c98291dea751b..bb80c61221eea2acaad397895d3fbad880e9dce3 100644 --- a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.h +++ b/paddle/fluid/lite/kernels/arm/elementwise_compute.h @@ -30,6 +30,14 @@ class ElementwiseAddCompute virtual ~ElementwiseAddCompute() = default; }; +class ElementwiseAddActivationCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseAddActivationCompute() = default; +}; + } // namespace arm } // namespace kernels } // namespace lite diff --git a/paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc b/paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e242c8cc583ecb418ad0c1ebd9dcbde0003b9e7 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/elementwise_compute_test.cc @@ -0,0 +1,263 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/elementwise_compute.h" +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +TEST(elementwise_add_arm, retrive_op) { + auto elementwise_add = + KernelRegistry::Global().Create( + "elementwise_add"); + ASSERT_FALSE(elementwise_add.empty()); + ASSERT_TRUE(elementwise_add.front()); +} + +TEST(elementwise_add_arm, init) { + ElementwiseAddCompute elementwise_add; + ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat)); + ASSERT_EQ(elementwise_add.target(), TARGET(kARM)); +} + +template +void elementwise_compute_ref(const operators::ElementwiseParam& param, + const std::string elt_type, + const std::string act_type) { + const dtype* x_data = param.X->data(); + const dtype* y_data = param.Y->data(); + dtype* out_data = param.Out->mutable_data(); + auto x_dims = param.X->dims(); + auto y_dims = param.Y->dims(); + int axis = param.axis; + if (axis < 0) { + axis = x_dims.size() - y_dims.size(); + } + int batch = 1; + int channels = 1; + int num = 1; + for (int i = 0; i < axis; ++i) { + batch *= x_dims[i]; + } + for (int i = 0; i < y_dims.size(); ++i) { + channels *= y_dims[i]; + } + for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) { + num *= x_dims[i]; + } + // do elementwise add/sub/max... + if (elt_type == "add") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *din_ptr + diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } else if (elt_type == "sub") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + int offset = (i * channels + j) * num; + const dtype* din_ptr = x_data + offset; + const dtype diny_data = y_data[j]; + dtype* dout_ptr = out_data + offset; + for (int k = 0; k < num; ++k) { + *dout_ptr = *din_ptr - diny_data; + dout_ptr++; + din_ptr++; + } + } + } + } else { + LOG(FATAL) << "unsupported Elementwise type: " << elt_type; + } + // do activation relu/sigmod... + if (act_type.size() > 0) { + if (act_type == "relu") { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < channels; ++j) { + dtype* dout_ptr = out_data + (i * channels + j) * num; + for (int k = 0; k < num; ++k) { + *dout_ptr = *dout_ptr > 0.0f ? *dout_ptr : 0.0f; + dout_ptr++; + } + } + } + } else { + LOG(FATAL) << "unsupported Activation type: " << elt_type; + } + } +} + +TEST(elementwise_add, compute) { + ElementwiseAddCompute elementwise_add; + operators::ElementwiseParam param; + lite::Tensor x, y, output, output_ref; + + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 4, 11}) { + for (auto h : {1, 3, 4, 11}) { + for (auto w : {1, 3, 4, 11}) { + for (auto axis : {-1, 0, 1, 2, 3}) { + for (auto yd : + {std::vector({n}), std::vector({c}), + std::vector({h}), std::vector({w}), + std::vector({n, c}), std::vector({c, h}), + std::vector({h, w}), std::vector({n, c, h}), + std::vector({c, h, w}), + std::vector({n, c, h, w})}) { + auto x_dim = DDim(std::vector({n, c, h, w})); + auto y_dim = DDim(yd); + int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis; + + if (axis_t + y_dim.size() > 4) continue; + bool flag = false; + for (int i = 0; i < y_dim.size(); i++) { + if (x_dim[i + axis_t] != y_dim[i]) flag = true; + } + if (flag) continue; + + x.Resize(x_dim); + y.Resize(y_dim); + output.Resize(x_dim); + output_ref.Resize(x_dim); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x_dim.production(); i++) { + x_data[i] = i; + } + for (int i = 0; i < y_dim.production(); i++) { + y_data[i] = i; + } + param.X = &x; + param.Y = &y; + param.axis = axis; + param.Out = &output; + elementwise_add.SetParam(param); + elementwise_add.Run(); + param.Out = &output_ref; + elementwise_compute_ref(param, "add", ""); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } + } +} + +TEST(fusion_elementwise_add_activation_arm, retrive_op) { + auto fusion_elementwise_add_activation = + KernelRegistry::Global().Create( + "fusion_elementwise_add_activation"); + ASSERT_FALSE(fusion_elementwise_add_activation.empty()); + ASSERT_TRUE(fusion_elementwise_add_activation.front()); +} + +TEST(fusion_elementwise_add_activation_arm, init) { + ElementwiseAddActivationCompute fusion_elementwise_add_activation; + ASSERT_EQ(fusion_elementwise_add_activation.precision(), PRECISION(kFloat)); + ASSERT_EQ(fusion_elementwise_add_activation.target(), TARGET(kARM)); +} + +TEST(fusion_elementwise_add_activation_arm, compute) { + ElementwiseAddActivationCompute fusion_elementwise_add_activation; + operators::FusionElementwiseActivationParam param; + lite::Tensor x, y, output, output_ref; + + for (auto act_type : {"relu"}) { + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 4, 11}) { + for (auto h : {1, 3, 4, 11}) { + for (auto w : {1, 3, 4, 11}) { + for (auto axis : {-1, 0, 1, 2, 3}) { + for (auto yd : + {std::vector({n}), std::vector({c}), + std::vector({h}), std::vector({w}), + std::vector({n, c}), std::vector({c, h}), + std::vector({h, w}), + std::vector({n, c, h}), + std::vector({c, h, w}), + std::vector({n, c, h, w})}) { + auto x_dim = DDim(std::vector({n, c, h, w})); + auto y_dim = DDim(yd); + int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis; + + if (axis_t + y_dim.size() > 4) continue; + bool flag = false; + for (int i = 0; i < y_dim.size(); i++) { + if (x_dim[i + axis_t] != y_dim[i]) flag = true; + } + if (flag) continue; + + x.Resize(x_dim); + y.Resize(y_dim); + output.Resize(x_dim); + output_ref.Resize(x_dim); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x_dim.production(); i++) { + float sign = i % 3 == 0 ? -1.0f : 1.0f; + x_data[i] = i * sign; + } + for (int i = 0; i < y_dim.production(); i++) { + float sign = i % 2 == 0 ? 0.5f : -0.5f; + y_data[i] = i * sign; + } + param.X = &x; + param.Y = &y; + param.axis = axis; + param.Out = &output; + param.act_type = act_type; + fusion_elementwise_add_activation.SetParam(param); + fusion_elementwise_add_activation.Run(); + param.Out = &output_ref; + elementwise_compute_ref(param, "add", act_type); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(fusion_elementwise_add_activation, kARM, kFloat, kNCHW, def);