Commit 6c68024d authored by xingzhaolong

Merge branch 'xzl/incubate/lite' into 'incubate/lite'

add gemv gemm int8 neon kernel

See merge request inference/paddlelite!41
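The bodies of the new kernels (gemm_prepacked_int8.cc, gemv_arm_int8.cc) are collapsed out of this diff view. As a rough illustration of the technique only — not the code from this merge request — a minimal int8 GEMV with NEON intrinsics could look like the sketch below; it assumes int32 accumulation (consistent with the kInt32 Bias binding later in this commit) and AArch64 for vaddvq_s32.

// Sketch: y[i] = sum_j A[i*n+j] * x[j], int8 inputs, int32 accumulators.
// Widen int8 products to int16 with vmull_s8, then pairwise-accumulate
// into int32 lanes with vpadalq_s16.
#include <arm_neon.h>
#include <cstdint>

void gemv_int8_sketch(const int8_t* A, const int8_t* x, int32_t* y,
                      int m, int n) {
  for (int i = 0; i < m; ++i) {
    const int8_t* row = A + i * n;
    int32x4_t acc = vdupq_n_s32(0);
    int j = 0;
    for (; j + 16 <= n; j += 16) {
      int8x16_t a = vld1q_s8(row + j);
      int8x16_t b = vld1q_s8(x + j);
      int16x8_t lo = vmull_s8(vget_low_s8(a), vget_low_s8(b));
      int16x8_t hi = vmull_s8(vget_high_s8(a), vget_high_s8(b));
      acc = vpadalq_s16(acc, lo);  // pairwise add int16 -> int32 lanes
      acc = vpadalq_s16(acc, hi);
    }
    int32_t sum = vaddvq_s32(acc);            // horizontal add (AArch64)
    for (; j < n; ++j) sum += row[j] * x[j];  // scalar tail
    y[i] = sum;
  }
}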
......@@ -50,6 +50,7 @@ class ExecutorLite {
optimizer_.KernelPickPreferPlace(prefer_place);
core::KernelPickFactor factor;
factor.ConsiderTarget();
factor.ConsiderPrecision();
optimizer_.Run(std::move(program), valid_places, factor);
program_ = optimizer_.GenRuntimeProgram();
}
......
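The added ConsiderPrecision() call means the static kernel pick pass now scores a precision match as well as a target match, so an int8 kernel can win over a float one for int8-tagged ops. A minimal sketch of how such a "consider" factor might gate scoring (hypothetical names; the real logic lives in static_kernel_pick_pass and differs in detail):

// Hypothetical sketch of factor-gated kernel scoring.
struct PickFactorSketch {
  bool consider_target{false};
  bool consider_precision{false};
};

int ScoreKernel(const PickFactorSketch& f, bool target_match,
                bool precision_match) {
  int score = 0;
  if (f.consider_target && target_match) score += 2;
  // With the precision bit set, an int8 kernel outscores a float kernel
  // for an int8 op instead of tying with it.
  if (f.consider_precision && precision_match) score += 1;
  return score;
}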
......@@ -35,6 +35,8 @@ cc_library(math_arm SRCS
split.cc
activation.cc
dropout.cc
gemm_prepacked_int8.cc
gemv_arm_int8.cc
DEPS ${lite_kernel_deps} eigen3 framework_proto_lite)
# TODO(TJ): fix me do not deps proto
......
......@@ -25,7 +25,7 @@ namespace fusion {
/* The model trained by fluid quantization is a simulation of real int8.
* The quantized Ops(conv2d, mul, depthwise conv2d etc) have fake_quantop
* in front and fake_dequantop behind.
*
* When in int8 mode, the pattern like "fake_quant + quantized_op +
* fake_dequant"
......
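The comment above describes the simulated-quantization subgraph that the fuser rewrites into a single int8 op. Schematically (an illustration, not taken from the source):

Before fuse:  in -> fake_quant -> conv2d (float weights) -> fake_dequant -> out
After fuse:   in -> conv2d (int8 weights, scales carried as attributes) -> out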
......@@ -41,7 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
-  LOG(INFO) << "keys: " << key2nodes_.size();
+  VLOG(4) << "keys: " << key2nodes_.size();
std::unordered_set<const Node *> nodes2rm;
for (auto &matched : key2nodes_) {
for (const auto &key : keys) {
......
......@@ -80,6 +80,8 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kARM), PRECISION(kAny),
DATALAYOUT(kAny)> *, //
KernelRegistryForTarget<TARGET(kARM), PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM), PRECISION(kInt8),
DATALAYOUT(kNCHW)> * //
>;
......
......@@ -58,7 +58,6 @@ class Optimizer {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", //
#endif
"lite_fc_fuse_pass", //
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
......
......@@ -38,6 +38,7 @@ enum class PrecisionType : int {
kUnk = 0,
kFloat,
kInt8,
kInt32,
kAny, // any precision
NUM, // number of fields.
};
......@@ -48,6 +49,19 @@ enum class DataLayoutType : int {
NUM, // number of fields.
};
static size_t PrecisionTypeLength(PrecisionType type) {
switch (type) {
case PrecisionType::kFloat:
return 4;
case PrecisionType::kInt8:
return 1;
case PrecisionType::kInt32:
return 4;
default:
return 4;
}
}
// Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__
// Some helper macro to get a specific PrecisionType.
......@@ -87,7 +101,7 @@ static const std::string& TargetRepr(TargetType target) {
static const std::string& PrecisionRepr(PrecisionType precision) {
static const std::string precision2string[] = {"kUnk", "kFloat", "kInt8",
"kAny"};
"kInt32", "kAny"};
auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x];
......
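The new PrecisionTypeLength() helper above maps a precision tag to its element width in bytes. A typical caller would use it to size raw buffers; BufferBytes below is a hypothetical example, not part of the commit:

// Hypothetical caller: byte size of a tensor's backing buffer from its
// element count and precision tag.
size_t BufferBytes(size_t numel, PrecisionType precision) {
  return numel * PrecisionTypeLength(precision);
}
// BufferBytes(64, PrecisionType::kInt8)  == 64   (1 byte per element)
// BufferBytes(64, PrecisionType::kFloat) == 256  (4 bytes per element)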
......@@ -92,6 +92,9 @@ void ConvCompute::Run() {
// }
}
void ConvComputeInt8::PrepareForRun() {}
void ConvComputeInt8::Run() {}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -112,3 +115,23 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW,
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(depthwise_conv2d, kARM, kInt8, kNCHW,
paddle::lite::kernels::arm::ConvComputeInt8, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Filter",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
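Note that the registrations bind Bias to PRECISION(kInt32) while Input, Filter, and Output are kInt8: int8 multiply-accumulates widen into 32-bit accumulators, and the bias is added at accumulator precision before requantization. A self-contained check of why a narrower accumulator would overflow:

#include <cstdint>
#include <cstdio>

int main() {
  // Worst-case int8 product: (-128) * (-128) = 16384. A reduction over
  // K = 4096 reaches 16384 * 4096 = 67108864 -- far beyond int16 range
  // (max 32767) but comfortably within int32.
  int32_t acc = 0;
  for (int k = 0; k < 4096; ++k) acc += int32_t(-128) * int32_t(-128);
  std::printf("%d\n", acc);  // prints 67108864
  return 0;
}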
......@@ -41,6 +41,25 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
nullptr};
};
class ConvComputeInt8 : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
using param_t = operators::ConvParam;
void PrepareForRun() override;
void Run() override;
~ConvComputeInt8() {
if (impl_ != nullptr) {
delete impl_;
}
}
private:
lite::arm::math::ImplBase<TARGET(kARM), PRECISION(kInt8), param_t>* impl_{
nullptr};
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
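ConvComputeInt8::PrepareForRun() and Run() are still empty stubs in this commit, but the impl_ member suggests the same pattern as the float ConvCompute: select a concrete implementation once, then forward on the hot path. A self-contained, hypothetical sketch of that pattern (names are illustrative; the real int8 bodies land in later commits):

// Sketch of "select impl once, then forward".
#include <cstdio>
#include <memory>

struct Impl {                 // stands in for lite::arm::math::ImplBase
  virtual void run() = 0;
  virtual ~Impl() = default;
};
struct GemmInt8Impl : Impl {  // e.g. backed by gemm_prepacked_int8.cc
  void run() override { std::puts("int8 gemm conv"); }
};

struct ConvInt8Sketch {
  void PrepareForRun() { impl_ = std::make_unique<GemmInt8Impl>(); }
  void Run() { impl_->run(); }
  std::unique_ptr<Impl> impl_;
};

int main() {
  ConvInt8Sketch k;
  k.PrepareForRun();  // pick the concrete implementation once
  k.Run();            // hot path just forwards
}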