diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt index 883e7bc4609b09dcea485eb85607fe7e8f2136cf..0f832029c86246efa00a19f6d80813071936f88d 100644 --- a/paddle/fluid/lite/arm/math/CMakeLists.txt +++ b/paddle/fluid/lite/arm/math/CMakeLists.txt @@ -35,6 +35,8 @@ cc_library(math_arm SRCS split.cc activation.cc dropout.cc + gemm_prepacked_int8.cc + gemv_arm_int8.cc DEPS ${lite_kernel_deps} eigen3 framework_proto_lite) # TODO(TJ): fix me do not deps proto diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h index c4a870ab83f0c61fc4a5116f8c3dd379e8ead9db..858ce65853ec792e3ec3b5a92db5fa0de223f505 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -38,6 +38,7 @@ enum class PrecisionType : int { kUnk = 0, kFloat, kInt8, + kInt32, kAny, // any precision NUM, // number of fields. }; @@ -48,6 +49,17 @@ enum class DataLayoutType : int { NUM, // number of fields. }; +static size_t PrecisionTypeLength(PrecisionType type) { + switch (type) { + case PrecisionType::kFloat: + return 4; + case PrecisionType::kInt8: + return 1; + default: + return 4; + } +} + // Some helper macro to get a specific TargetType. #define TARGET(item__) paddle::lite::TargetType::item__ // Some helper macro to get a specific PrecisionType. diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.cc b/paddle/fluid/lite/kernels/arm/conv_compute.cc index a7cd385be9873837307fc89d8ac5a1a2ed7171a9..4ac6cd4b76121dca1ba9dc2fde541d32f1b377c0 100644 --- a/paddle/fluid/lite/kernels/arm/conv_compute.cc +++ b/paddle/fluid/lite/kernels/arm/conv_compute.cc @@ -94,6 +94,9 @@ void ConvCompute::Run() { // } } +void ConvComputeInt8::PrepareForRun() {} +void ConvComputeInt8::Run() {} + } // namespace arm } // namespace kernels } // namespace lite @@ -114,3 +117,23 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); + +REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW, + paddle::lite::kernels::arm::ConvComputeInt8, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kInt8, kNCHW, + paddle::lite::kernels::arm::ConvComputeInt8, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/conv_compute.h b/paddle/fluid/lite/kernels/arm/conv_compute.h index 21fabf8c3e8f7983a891265135c39b96aaf42e8d..e5d5721a3b30256bd14a165400723cc4563cd942 100644 --- a/paddle/fluid/lite/kernels/arm/conv_compute.h +++ b/paddle/fluid/lite/kernels/arm/conv_compute.h @@ -41,6 +41,25 @@ class ConvCompute : public KernelLite { nullptr}; }; +class ConvComputeInt8 : public KernelLite { + public: + using param_t = operators::ConvParam; + + void PrepareForRun() override; + + void Run() override; + + ~ConvComputeInt8() { + if (impl_ != nullptr) { + delete impl_; + } + } + + private: + lite::arm::math::ImplBase* impl_{ + nullptr}; +}; + } // namespace arm } // namespace kernels } // namespace lite