diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc index eb3874d74eb7ee84213b0b7a932c4c1c5b0cf06f..658a235d621426a04647f1c5802147836232be89 100644 --- a/lite/backends/arm/math/elementwise.cc +++ b/lite/backends/arm/math/elementwise.cc @@ -44,6 +44,12 @@ static T naive_sub(T l, T r) { return l - r; } +// todo: remove this function when all elementwise div works +template +static T naive_div(T l, T r) { + return l / r; +} + // todo: use arm intrinsics template <> void elementwise_add(const int32_t* dinx, @@ -1511,6 +1517,15 @@ void elementwise_max_relu_broadcast(const float* dinx, } } +// todo: use arm intrinsics +template <> +void elementwise_div(const int32_t* dinx, + const int32_t* diny, + int32_t* dout, + int num) { + naive_elementwise_op(dinx, diny, dout, num, naive_div); +} + template <> void elementwise_div(const int64_t* dinx, const int64_t* diny, @@ -1576,6 +1591,18 @@ void elementwise_div(const float* dinx, } } +// todo: use arm intrinsics +template <> +void elementwise_div_broadcast(const int32_t* dinx, + const int32_t* diny, + int32_t* dout, + int batch, + int channels, + int num) { + naive_elementwise_op_broadcast( + dinx, diny, dout, batch, channels, num, naive_div); +} + template <> void elementwise_div_broadcast(const int64_t* dinx, const int64_t* diny, diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc index cf38f7698907928bd16d8068a498239f46834be9..42cfa2471564ff15480af78944ed97416a5f93ca 100644 --- a/lite/kernels/arm/elementwise_compute.cc +++ b/lite/kernels/arm/elementwise_compute.cc @@ -548,6 +548,16 @@ REGISTER_LITE_KERNEL( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); +using elementwise_div_int32_t = + paddle::lite::kernels::arm::ElementwiseDivCompute; +REGISTER_LITE_KERNEL( + elementwise_div, kARM, kInt32, kNCHW, elementwise_div_int32_t, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .Finalize(); + using elementwise_div_int64_t = paddle::lite::kernels::arm::ElementwiseDivCompute;