diff --git a/lite/backends/arm/math/elementwise.cc b/lite/backends/arm/math/elementwise.cc
index 407cdc9ef18fbf7f403895fd7249bde9f7e80d51..04373992e4802a0b0c2529daac851e00ebcb56cf 100644
--- a/lite/backends/arm/math/elementwise.cc
+++ b/lite/backends/arm/math/elementwise.cc
@@ -1254,6 +1254,19 @@ void elementwise_max_relu_broadcast(const float* dinx,
   }
 }
 
+template <>
+void elementwise_div<int64_t>(const int64_t* dinx,
+                              const int64_t* diny,
+                              int64_t* dout,
+                              int num) {
+  for (int i = 0; i < num; i++) {
+    *dout = *dinx / *diny;
+    dout++;
+    dinx++;
+    diny++;
+  }
+}
+
 template <>
 void elementwise_div<float>(const float* dinx,
                             const float* diny,
@@ -1306,6 +1319,28 @@ void elementwise_div<float>(const float* dinx,
   }
 }
 
+template <>
+void elementwise_div_broadcast<int64_t>(const int64_t* dinx,
+                                        const int64_t* diny,
+                                        int64_t* dout,
+                                        int batch,
+                                        int channels,
+                                        int num) {
+  for (int i = 0; i < batch; ++i) {
+    for (int j = 0; j < channels; ++j) {
+      int offset = (i * channels + j) * num;
+      const int64_t* din_ptr = dinx + offset;
+      const int64_t diny_data = diny[j];
+      int64_t* dout_ptr = dout + offset;
+      for (int p = 0; p < num; p++) {
+        *dout_ptr = *din_ptr / diny_data;
+        dout_ptr++;
+        din_ptr++;
+      }
+    }
+  }
+}
+
 template <>
 void elementwise_div_broadcast<float>(const float* dinx,
                                       const float* diny,
diff --git a/lite/kernels/arm/calib_compute.cc b/lite/kernels/arm/calib_compute.cc
index 6dac97dcbc59991d4680ab1a98a54a900573f631..383e868843b43f4081e1eac330b1422b79307d9c 100644
--- a/lite/kernels/arm/calib_compute.cc
+++ b/lite/kernels/arm/calib_compute.cc
@@ -33,6 +33,17 @@ void CalibComputeFp32ToInt8<DLType>::Run() {
       din, dout, scale.data(), 1, 1, param.input->numel());
 }
 
+template <DataLayoutType DLType>
+void CalibComputeInt64ToInt32<DLType>::Run() {
+  auto& param = this->template Param<operators::CalibParam>();
+  const auto* din = param.input->template data<int64_t>();
+  std::vector<float> scale = {param.scale};
+  auto* dout = param.output->template mutable_data<int32_t>();
+  for (auto i = 0; i < param.input->numel(); ++i) {
+    dout[i] = din[i];
+  }
+}
+
 template <DataLayoutType DLType>
 void CalibComputeInt8ToFp32<DLType>::Run() {
   auto& param = this->template Param<operators::CalibParam>();
@@ -105,6 +116,23 @@ REGISTER_LITE_KERNEL(
                                        DATALAYOUT(kNHWC))})
     .Finalize();
 
+REGISTER_LITE_KERNEL(
+    calib,
+    kARM,
+    kInt64,
+    kNCHW,
+    paddle::lite::kernels::arm::CalibComputeInt64ToInt32<DATALAYOUT(kNCHW)>,
+    int64_to_int32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kARM),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kARM),
+                                       PRECISION(kInt32),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     calib_once,
     kARM,
@@ -161,3 +189,20 @@ REGISTER_LITE_KERNEL(
                                        PRECISION(kFloat),
                                        DATALAYOUT(kNHWC))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(
+    calib_once,
+    kARM,
+    kInt64,
+    kNCHW,
+    paddle::lite::kernels::arm::CalibComputeInt64ToInt32<DATALAYOUT(kNCHW)>,
+    int64_to_int32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(TARGET(kARM),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kNCHW))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kARM),
+                                       PRECISION(kInt32),
+                                       DATALAYOUT(kNCHW))})
+    .Finalize();
diff --git a/lite/kernels/arm/calib_compute.h b/lite/kernels/arm/calib_compute.h
index a4c8b4c1232101416e95171d70ab629f6a37177b..f10bb931df9b276bc3bb01da16906f3e5b5a7dce 100644
--- a/lite/kernels/arm/calib_compute.h
+++ b/lite/kernels/arm/calib_compute.h
@@ -34,6 +34,19 @@ class CalibComputeFp32ToInt8
  private:
 };
 
+template <DataLayoutType DLType>
+class CalibComputeInt64ToInt32
+    : public KernelLite<TARGET(kARM), PRECISION(kInt64), DLType> {
+ public:
+  using param_t = operators::CalibParam;
+
+  void Run() override;
+
+  ~CalibComputeInt64ToInt32() override{};
+
+ private:
+};
+
 template <DataLayoutType DLType>
 class CalibComputeInt8ToFp32
     : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
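Note (not part of the patch): CalibComputeInt64ToInt32 above builds the `scale` vector like the other calib kernels but never uses it for this precision pair; the kernel is a plain narrowing copy. A minimal standalone sketch of the same behaviour, with a hypothetical helper name, would be:

    #include <cstdint>

    // Equivalent of CalibComputeInt64ToInt32<DLType>::Run() outside the Lite
    // runtime. The cast is narrowing: values outside the int32_t range wrap
    // rather than saturate, and the kernel does not guard against that.
    void int64_to_int32_ref(const int64_t* din, int32_t* dout, int64_t numel) {
      for (int64_t i = 0; i < numel; ++i) {
        dout[i] = static_cast<int32_t>(din[i]);
      }
    }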
diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc
index 3b3ef07e105c583b7e3eb8b64b14610ca0f9e41a..919e9c603edff4383f086ac795c3dff4ed856c4f 100644
--- a/lite/kernels/arm/cast_compute.cc
+++ b/lite/kernels/arm/cast_compute.cc
@@ -62,8 +62,19 @@ void CastCompute::Run() {
     int32_t* out_data = param.Out->mutable_data<int32_t>();
     std::transform(
         x_data_begin, x_data_end, out_data, TransOp<unsigned char, int32_t>);
+  } else if (param.in_dtype == 0 && param.out_dtype == 5) {  // bool->fp32
+    const bool* x_data_begin = param.X->data<bool>();
+    const bool* x_data_end = x_data_begin + param.X->numel();
+    float* out_data = param.Out->mutable_data<float>();
+    std::transform(x_data_begin, x_data_end, out_data, TransOp<bool, float>);
+  } else if (param.in_dtype == 3 && param.out_dtype == 5) {  // int64->fp32
+    const int64_t* x_data_begin = param.X->data<int64_t>();
+    const int64_t* x_data_end = x_data_begin + param.X->numel();
+    float* out_data = param.Out->mutable_data<float>();
+    std::transform(x_data_begin, x_data_end, out_data, TransOp<int64_t, float>);
   } else {
-    LOG(FATAL) << "other has not been implemented";
+    LOG(FATAL) << "cast from in_dtype " << param.in_dtype << " to out_dtype "
+               << param.out_dtype << " has not been implemented";
   }
 }
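Note on the dtype codes in cast_compute.cc: they follow Paddle's framework.proto VarType enum, where BOOL = 0, INT32 = 2, INT64 = 3, and FP32 = 5, so the two new branches cover bool->fp32 and int64->fp32. The `TransOp` functor they reuse is assumed to be the same trivial static_cast helper the existing uint8->int32 branch uses, roughly:

    // Presumed shape of TransOp (already defined in cast_compute.cc); shown
    // here only so the new branches read on their own.
    template <class in_type, class out_type>
    out_type TransOp(in_type in) {
      return static_cast<out_type>(in);
    }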
diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index 64e29a7c3aad0e947320cd59ec27f8f8429265c6..28082785e1c726097a8bfd2165f0d09b9962a5e7 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -300,11 +300,12 @@ void ElementwiseMaxActivationCompute::Run() {
   }
 }
 
-void ElementwiseDivCompute::Run() {
-  auto& param = Param<operators::ElementwiseParam>();
-  const float* x_data = param.X->data<float>();
-  const float* y_data = param.Y->data<float>();
-  float* out_data = param.Out->mutable_data<float>();
+template <typename T, PrecisionType PType>
+void ElementwiseDivCompute<T, PType>::Run() {
+  auto& param = this->template Param<operators::ElementwiseParam>();
+  auto* x_data = param.X->template data<T>();
+  auto* y_data = param.Y->template data<T>();
+  auto* out_data = param.Out->template mutable_data<T>();
   int axis = param.axis;
   auto x_dims = param.X->dims();
   auto y_dims = param.Y->dims();
@@ -313,10 +314,10 @@ void ElementwiseDivCompute::Run() {
     LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
   }
   if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
-    lite::arm::math::elementwise_div_broadcast<float>(
+    lite::arm::math::elementwise_div_broadcast<T>(
         x_data, y_data, out_data, pre, n, post);
   } else {
-    lite::arm::math::elementwise_div<float>(
+    lite::arm::math::elementwise_div<T>(
         x_data, y_data, out_data, x_dims.production());
   }
 }
@@ -488,17 +489,27 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
-REGISTER_LITE_KERNEL(elementwise_div,
-                     kARM,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::arm::ElementwiseDivCompute,
-                     def)
+using elementwise_div_fp32 =
+    paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>;
+
+REGISTER_LITE_KERNEL(
+    elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
 
+using elementwise_div_int64 =
+    paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t, PRECISION(kInt64)>;
+
+REGISTER_LITE_KERNEL(
+    elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     fusion_elementwise_div_activation,
     kARM,
diff --git a/lite/kernels/arm/elementwise_compute.h b/lite/kernels/arm/elementwise_compute.h
index b6a10fecff0fff486f93f31510d04f3956674309..7d7a93bf6954de9bbcd1b44061e614cd041fafe8 100644
--- a/lite/kernels/arm/elementwise_compute.h
+++ b/lite/kernels/arm/elementwise_compute.h
@@ -86,8 +86,8 @@ class ElementwiseMaxActivationCompute
   virtual ~ElementwiseMaxActivationCompute() = default;
 };
 
-class ElementwiseDivCompute
-    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+template <typename T, PrecisionType PType>
+class ElementwiseDivCompute : public KernelLite<TARGET(kARM), PType> {
  public:
   void Run() override;
diff --git a/lite/kernels/arm/gather_compute.cc b/lite/kernels/arm/gather_compute.cc
index 3efacc4aacefcb150d53738c950ec9e797ed78c7..2a9c70aede7475b36f70c628ff6ccaa823f030b2 100644
--- a/lite/kernels/arm/gather_compute.cc
+++ b/lite/kernels/arm/gather_compute.cc
@@ -73,7 +73,6 @@ void GatherCompute::Run() {
 REGISTER_LITE_KERNEL(
     gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindInput("Index",
-               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("Index", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .Finalize();
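Note (illustrative, not part of the patch): the int64 kernels added in elementwise.cc reduce to plain C++ integer division, which truncates toward zero and is undefined when the divisor is zero; neither loop guards against a zero element in Y. A self-contained reference for checking the non-broadcast path, with a hypothetical name:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Scalar reference for the new elementwise_div<int64_t>:
    // out[i] = x[i] / y[i]. The caller must guarantee y[i] != 0,
    // exactly as the ARM kernel assumes.
    std::vector<int64_t> elementwise_div_ref(const std::vector<int64_t>& x,
                                             const std::vector<int64_t>& y) {
      assert(x.size() == y.size());
      std::vector<int64_t> out(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] / y[i];
      }
      return out;
    }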