From f149ce86f5ad86e35d58ac94ad5084a7040a0ee5 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 20 Mar 2020 11:20:22 +0800 Subject: [PATCH] Add arm int8 place automatic, test=develop (#3234) --- lite/api/cxx_api.cc | 26 ++++++++++++++++++++++++++ lite/api/opt.cc | 7 ------- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index b739c78f7c..556a9e0af0 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -294,6 +294,32 @@ void Predictor::Build(const cpp::ProgramDesc &desc, inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); inner_places.emplace_back( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); + + const std::vector quant_dequant_op = { + "fake_quantize_abs_max", + "fake_quantize_range_abs_max", + "fake_quantize_moving_average_abs_max", + "fake_quantize_dequantize_moving_average_abs_max", + "fake_dequantize_max_abs", + "fake_channel_wise_dequantize_max_abs"}; + bool is_quantized_model = false; + for (size_t i = 0; i < program_desc_.BlocksSize() && !is_quantized_model; + ++i) { + auto *block_desc = program_desc_.GetBlock(i); + for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) { + auto *op_desc = block_desc->GetOp(j); + std::string op_type = op_desc->Type(); + if (std::find(quant_dequant_op.begin(), + quant_dequant_op.end(), + op_type) != quant_dequant_op.end()) { + is_quantized_model = true; + } + } + } + if (is_quantized_model) { + inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)}); + } + Program program(desc, scope_, inner_places); core::KernelPickFactor factor; diff --git a/lite/api/opt.cc b/lite/api/opt.cc index 50c1e7d580..b849719968 100644 --- a/lite/api/opt.cc +++ b/lite/api/opt.cc @@ -67,7 +67,6 @@ DEFINE_string(valid_targets, "arm", "The targets this model optimized for, should be one of (arm, " "opencl, x86), splitted by space"); -DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); DEFINE_bool(print_supported_ops, false, "Print supported operators on the inputed target"); @@ -121,11 +120,6 @@ std::vector ParserValidPlaces() { << "At least one target should be set, should set the " "command argument 'valid_targets'"; - if (FLAGS_prefer_int8_kernel) { - LOG(WARNING) << "Int8 mode is only support by ARM target"; - valid_places.insert(valid_places.begin(), - Place{TARGET(kARM), PRECISION(kInt8)}); - } return valid_places; } @@ -255,7 +249,6 @@ void PrintHelpInfo() { " `--optimize_out_type=(protobuf|naive_buffer)`\n" " `--optimize_out=`\n" " `--valid_targets=(arm|opencl|x86|npu|xpu)`\n" - " `--prefer_int8_kernel=(true|false)`\n" " `--record_tailoring_info=(true|false)`\n" " Arguments of model checking and ops information:\n" " `--print_all_ops=true` Display all the valid operators of " -- GitLab