Commit 6c68024d authored by xingzhaolong

Merge branch 'xzl/incubate/lite' into 'incubate/lite'

add gemv gemm int8 neon kernel

See merge request inference/paddlelite!41
@@ -50,6 +50,7 @@ class ExecutorLite {
     optimizer_.KernelPickPreferPlace(prefer_place);
     core::KernelPickFactor factor;
     factor.ConsiderTarget();
+    factor.ConsiderPrecision();
     optimizer_.Run(std::move(program), valid_places, factor);
     program_ = optimizer_.GenRuntimeProgram();
   }
......
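With ConsiderPrecision() added above, the static kernel picker now scores candidate kernels on precision match as well as target match. A minimal sketch of the idea, with hypothetical names (not Paddle-Lite's actual scoring code):

#include <string>

// Hypothetical candidate kernel: the score sums one term per considered
// factor, so a kernel matching both target and precision outranks one
// that matches only the target.
struct Candidate {
  std::string target;
  std::string precision;
};

int Score(const Candidate& k, const Candidate& want, bool consider_target,
          bool consider_precision) {
  int s = 0;
  if (consider_target && k.target == want.target) s += 2;  // weightier factor
  if (consider_precision && k.precision == want.precision) s += 1;
  return s;
}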
@@ -35,6 +35,8 @@ cc_library(math_arm SRCS
             split.cc
             activation.cc
             dropout.cc
+            gemm_prepacked_int8.cc
+            gemv_arm_int8.cc
             DEPS ${lite_kernel_deps} eigen3 framework_proto_lite)
 # TODO(TJ): fix me do not deps proto
......
@@ -25,7 +25,7 @@ namespace fusion {
 /* The model trained by fluid quantization is a simulation of real int8.
  * The quantized Ops(conv2d, mul, depthwise conv2d etc) have fake_quantop
  * in front and fake_dequantop behind.
- * 
+ *
  * When in int8 mode, the pattern like "fake_quant + quantized_op +
  * fake_dequant"
......
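For reference, symmetric int8 simulated quantization of this kind typically computes the round trip below; this is a generic illustration of fake_quant/fake_dequant semantics, not fluid's exact kernels:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Generic symmetric int8 fake quantization: snap to an int8 grid, then
// immediately dequantize, so training sees the rounding error that real
// int8 execution would introduce.
int8_t Quantize(float x, float scale) {
  float q = std::round(x / scale);
  q = std::max(-127.0f, std::min(127.0f, q));
  return static_cast<int8_t>(q);
}

float FakeQuantDequant(float x, float scale) {
  return Quantize(x, scale) * scale;  // fake_quant + fake_dequant round trip
}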
@@ -41,7 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
     }
   }
-  LOG(INFO) << "keys: " << key2nodes_.size();
+  VLOG(4) << "keys: " << key2nodes_.size();
   std::unordered_set<const Node *> nodes2rm;
   for (auto &matched : key2nodes_) {
     for (const auto &key : keys) {
......
@@ -80,6 +80,8 @@ class KernelRegistry final {
       KernelRegistryForTarget<TARGET(kARM), PRECISION(kAny),
                               DATALAYOUT(kAny)> *,  //
       KernelRegistryForTarget<TARGET(kARM), PRECISION(kFloat),
+                              DATALAYOUT(kNCHW)> *,  //
+      KernelRegistryForTarget<TARGET(kARM), PRECISION(kInt8),
                               DATALAYOUT(kNCHW)> *  //
       >;
......
@@ -58,7 +58,6 @@ class Optimizer {
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
          "lite_elementwise_add_activation_fuse_pass",  //
 #endif
-         "lite_fc_fuse_pass",                          //
          "static_kernel_pick_pass",                    //
          "variable_place_inference_pass",              //
          "argument_type_display_pass",                 //
......
@@ -38,6 +38,7 @@ enum class PrecisionType : int {
   kUnk = 0,
   kFloat,
   kInt8,
+  kInt32,
   kAny,  // any precision
   NUM,   // number of fields.
 };
@@ -48,6 +49,19 @@ enum class DataLayoutType : int {
   NUM,  // number of fields.
 };
+
+static size_t PrecisionTypeLength(PrecisionType type) {
+  switch (type) {
+    case PrecisionType::kFloat:
+      return 4;
+    case PrecisionType::kInt8:
+      return 1;
+    case PrecisionType::kInt32:
+      return 4;
+    default:
+      return 4;
+  }
+}
 // Some helper macro to get a specific TargetType.
 #define TARGET(item__) paddle::lite::TargetType::item__
 // Some helper macro to get a specific PrecisionType.
@@ -87,7 +101,7 @@ static const std::string& TargetRepr(TargetType target) {
 static const std::string& PrecisionRepr(PrecisionType precision) {
   static const std::string precision2string[] = {"kUnk", "kFloat", "kInt8",
-                                                 "kAny"};
+                                                 "kInt32", "kAny"};
   auto x = static_cast<int>(precision);
   CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
   return precision2string[x];
......
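PrecisionTypeLength returns the element size in bytes, so a tensor's raw buffer size is its element count times that length. A small usage sketch, assuming the PrecisionType definitions added above:

#include <cstddef>

// An int8 tensor with 1000 elements needs 1000 * 1 = 1000 bytes,
// while the same shape in kFloat needs 1000 * 4 = 4000 bytes.
size_t TensorBytes(size_t numel, PrecisionType precision) {
  return numel * PrecisionTypeLength(precision);
}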
@@ -92,6 +92,9 @@ void ConvCompute::Run() {
   // }
 }
+
+void ConvComputeInt8::PrepareForRun() {}
+void ConvComputeInt8::Run() {}
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
@@ -112,3 +115,23 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
     .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW,
+                     paddle::lite::kernels::arm::ConvComputeInt8, def)
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("Filter",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kInt8, kNCHW,
+                     paddle::lite::kernels::arm::ConvComputeInt8, def)
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("Filter",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .BindOutput("Output",
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
+    .Finalize();
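Note that both int8 kernels bind Bias to kInt32: products of int8 inputs and int8 filters are accumulated in 32-bit integers, the bias is added in that accumulator domain, and only the final result is scaled back to int8. A scalar illustration of one output element (a generic sketch of the arithmetic, not the NEON kernel itself):

#include <cstdint>

// One output element of an int8 conv/mul: accumulate in int32, add the
// int32 bias, then requantize with the combined input*weight/output scale.
int8_t DotInt8(const int8_t* x, const int8_t* w, int n, int32_t bias,
               float scale) {
  int32_t acc = bias;
  for (int i = 0; i < n; ++i) {
    acc += static_cast<int32_t>(x[i]) * static_cast<int32_t>(w[i]);
  }
  float r = acc * scale;  // back toward the int8 output range
  r = r > 127.f ? 127.f : (r < -127.f ? -127.f : r);
  return static_cast<int8_t>(r >= 0 ? r + 0.5f : r - 0.5f);  // round to nearest
}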
@@ -41,6 +41,25 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
       nullptr};
 };
+
+class ConvComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
+ public:
+  using param_t = operators::ConvParam;
+
+  void PrepareForRun() override;
+  void Run() override;
+
+  ~ConvComputeInt8() {
+    if (impl_ != nullptr) {
+      delete impl_;
+    }
+  }
+
+ private:
+  lite::arm::math::ImplBase<TARGET(kARM), PRECISION(kInt8), param_t>* impl_{
+      nullptr};
+};
+
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
......
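The files this commit adds (gemm_prepacked_int8.cc, gemv_arm_int8.cc) implement these inner loops with NEON; a plain scalar reference for the gemv case shows what the optimized code computes:

#include <cstdint>

// Reference y = A * x for int8 gemv: A is m x n row-major int8, x is
// int8, and y accumulates in int32. A NEON version vectorizes the inner
// loop (e.g. widening multiplies plus pairwise adds) but must produce
// the same int32 results.
void GemvInt8Ref(const int8_t* A, const int8_t* x, int32_t* y, int m, int n) {
  for (int r = 0; r < m; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < n; ++c) {
      acc += static_cast<int32_t>(A[r * n + c]) * static_cast<int32_t>(x[c]);
    }
    y[r] = acc;
  }
}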