From dfcd059b61c2111d81fb9002c50e9c3af2a6c11d Mon Sep 17 00:00:00 2001
From: zhangwen31
Date: Mon, 14 Sep 2020 14:36:54 +0000
Subject: [PATCH] [arm][kernel][math] feat: add i32 and i64 support for
 elementwise sub

---
 lite/backends/arm/math/elementwise.cc   | 49 +++++++++++++++++++++++++
 lite/kernels/arm/elementwise_compute.cc | 20 ++++++++++
 2 files changed, 69 insertions(+)

diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index 7b9c324451..eb3874d74e 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -38,6 +38,12 @@ static T naive_add(T l, T r) {
   return l + r;
 }
 
+// todo: remove this function when all elementwise sub works
+template <typename T>
+static T naive_sub(T l, T r) {
+  return l - r;
+}
+
 // todo: use arm intrinsics
 template <>
 void elementwise_add<int32_t>(const int32_t* dinx,
@@ -472,6 +478,25 @@ void elementwise_add_grad_broadcast(const float* dout_grad,
     }
   }
 }
+
+// todo: use arm intrinsics
+template <>
+void elementwise_sub<int32_t>(const int32_t* dinx,
+                              const int32_t* diny,
+                              int32_t* dout,
+                              int num) {
+  naive_elementwise_op<int32_t>(dinx, diny, dout, num, naive_sub<int32_t>);
+}
+
+// todo: use arm intrinsics
+template <>
+void elementwise_sub<int64_t>(const int64_t* dinx,
+                              const int64_t* diny,
+                              int64_t* dout,
+                              int num) {
+  naive_elementwise_op<int64_t>(dinx, diny, dout, num, naive_sub<int64_t>);
+}
+
 template <>
 void elementwise_sub<float>(const float* dinx,
                             const float* diny,
@@ -572,6 +597,30 @@ void elementwise_sub_relu(const float* dinx,
     }
   }
 }
+
+// todo: use arm intrinsics
+template <>
+void elementwise_sub_broadcast<int32_t>(const int32_t* dinx,
+                                        const int32_t* diny,
+                                        int32_t* dout,
+                                        int batch,
+                                        int channels,
+                                        int num) {
+  naive_elementwise_op_broadcast<int32_t>(
+      dinx, diny, dout, batch, channels, num, naive_sub<int32_t>);
+}
+
+// todo: use arm intrinsics
+template <>
+void elementwise_sub_broadcast<int64_t>(const int64_t* dinx,
+                                        const int64_t* diny,
+                                        int64_t* dout,
+                                        int batch,
+                                        int channels,
+                                        int num) {
+  naive_elementwise_op_broadcast<int64_t>(
+      dinx, diny, dout, batch, channels, num, naive_sub<int64_t>);
+}
+
 template <>
 void elementwise_sub_broadcast<float>(const float* dinx,
                                       const float* diny,
diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index 71fedc9665..cf38f76989 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -429,6 +429,26 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
+using elementwise_sub_int32_t =
+    paddle::lite::kernels::arm::ElementwiseSubCompute<int32_t,
+                                                      PRECISION(kInt32)>;
+REGISTER_LITE_KERNEL(
+    elementwise_sub, kARM, kInt32, kNCHW, elementwise_sub_int32_t, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .Finalize();
+
+using elementwise_sub_int64_t =
+    paddle::lite::kernels::arm::ElementwiseSubCompute<int64_t,
+                                                      PRECISION(kInt64)>;
+REGISTER_LITE_KERNEL(
+    elementwise_sub, kARM, kInt64, kNCHW, elementwise_sub_int64_t, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     fusion_elementwise_sub_activation,
     kARM,
-- 
GitLab