From 32b5caa0c017891ec62f8617e7246070248696fd Mon Sep 17 00:00:00 2001
From: zhangwen31
Date: Mon, 14 Sep 2020 05:21:03 +0000
Subject: [PATCH] [arm][math]feat: elementwise_add support i32 and i64 now

---
 lite/backends/arm/math/elementwise.cc   | 83 +++++++++++++++++++++++++
 lite/kernels/arm/elementwise_compute.cc | 20 ++++++
 2 files changed, 103 insertions(+)

diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index 3f92300ae8..7b9c324451 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -21,6 +21,41 @@ namespace lite {
 namespace arm {
 namespace math {
 
+// todo: remove this function when all elementwise_add works
+template <class T>
+static void naive_elementwise_op(
+    const T* dinx, const T* diny, T* dout, int num, std::function<T(T, T)> op) {
+  for (int i = 0; i < num; ++i) {
+    *dout = op(*dinx, *diny);
+    ++dinx;
+    ++diny;
+    ++dout;
+  }
+}
+
+// todo: remove this function when all elementwise_add works
+template <class T>
+static T naive_add(T l, T r) {
+  return l + r;
+}
+
+// todo: use arm intrinsics
+template <>
+void elementwise_add<int32_t>(const int32_t* dinx,
+                              const int32_t* diny,
+                              int32_t* dout,
+                              int num) {
+  naive_elementwise_op<int32_t>(dinx, diny, dout, num, naive_add<int32_t>);
+}
+
+// todo: use arm intrinsics
+template <>
+void elementwise_add<int64_t>(const int64_t* dinx,
+                              const int64_t* diny,
+                              int64_t* dout,
+                              int num) {
+  naive_elementwise_op<int64_t>(dinx, diny, dout, num, naive_add<int64_t>);
+}
+
 template <>
 void elementwise_add<float>(const float* dinx,
                             const float* diny,
@@ -178,6 +213,54 @@ void elementwise_add_tanh(const float* dinx,
   }
 }
 
+// todo: remove this function when all elementwise_add works
+template <class T>
+static void naive_elementwise_op_broadcast(const T* x_data,
+                                           const T* y_data,
+                                           T* out_data,
+                                           int batch,
+                                           int channels,
+                                           int num,
+                                           std::function<T(T, T)> op) {
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const T* din_ptr = x_data + offset;
+      const T diny_data = y_data[j];
+      T* dout_ptr = out_data + offset;
+      for (int k = 0; k < num; ++k) {
+        *dout_ptr = op(*din_ptr, diny_data);
+        dout_ptr++;
+        din_ptr++;
+      }
+    }
+  }
+}
+
+// todo: use arm intrinsics
+template <>
+void elementwise_add_broadcast<int32_t>(const int32_t* dinx,
+                                        const int32_t* diny,
+                                        int32_t* dout,
+                                        int batch,
+                                        int channels,
+                                        int num) {
+  naive_elementwise_op_broadcast<int32_t>(
+      dinx, diny, dout, batch, channels, num, naive_add<int32_t>);
+}
+
+// todo: use arm intrinsics
+template <>
+void elementwise_add_broadcast<int64_t>(const int64_t* dinx,
+                                        const int64_t* diny,
+                                        int64_t* dout,
+                                        int batch,
+                                        int channels,
+                                        int num) {
+  naive_elementwise_op_broadcast<int64_t>(
+      dinx, diny, dout, batch, channels, num, naive_add<int64_t>);
+}
+
 template <>
 void elementwise_add_broadcast<float>(const float* dinx,
                                       const float* diny,
diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index 44ce1c8652..9251201734 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -387,6 +387,26 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
+using elementwise_add_int32_t =
+    paddle::lite::kernels::arm::ElementwiseAddCompute<int32_t,
+                                                      PRECISION(kInt32)>;
+REGISTER_LITE_KERNEL(
+    elementwise_add, kARM, kInt32, kNCHW, elementwise_add_int32_t, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .Finalize();
+
+using elementwise_add_int64_t =
+    paddle::lite::kernels::arm::ElementwiseAddCompute<int64_t,
+                                                      PRECISION(kInt64)>;
+REGISTER_LITE_KERNEL(
+    elementwise_add, kARM, kInt64, kNCHW, elementwise_add_int64_t, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     fusion_elementwise_add_activation, kARM,
-- 
GitLab