diff --git a/lite/api/cxx_api.cc b/lite/api/cxx_api.cc index b739c78f7c883d62b39d88ae1a7f4bf76ae8932c..556a9e0af01854ff5c57a14dade72b81ed255964 100644 --- a/lite/api/cxx_api.cc +++ b/lite/api/cxx_api.cc @@ -294,6 +294,32 @@ void Predictor::Build(const cpp::ProgramDesc &desc, inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); inner_places.emplace_back( TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); + + const std::vector quant_dequant_op = { + "fake_quantize_abs_max", + "fake_quantize_range_abs_max", + "fake_quantize_moving_average_abs_max", + "fake_quantize_dequantize_moving_average_abs_max", + "fake_dequantize_max_abs", + "fake_channel_wise_dequantize_max_abs"}; + bool is_quantized_model = false; + for (size_t i = 0; i < program_desc_.BlocksSize() && !is_quantized_model; + ++i) { + auto *block_desc = program_desc_.GetBlock(i); + for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) { + auto *op_desc = block_desc->GetOp(j); + std::string op_type = op_desc->Type(); + if (std::find(quant_dequant_op.begin(), + quant_dequant_op.end(), + op_type) != quant_dequant_op.end()) { + is_quantized_model = true; + } + } + } + if (is_quantized_model) { + inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)}); + } + Program program(desc, scope_, inner_places); core::KernelPickFactor factor; diff --git a/lite/api/opt.cc b/lite/api/opt.cc index 50c1e7d580bab49c3cf56e4ef96ad260eb573194..b8497199684cb4f6d4cc602291be5762eb93f7f9 100644 --- a/lite/api/opt.cc +++ b/lite/api/opt.cc @@ -67,7 +67,6 @@ DEFINE_string(valid_targets, "arm", "The targets this model optimized for, should be one of (arm, " "opencl, x86), splitted by space"); -DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels"); DEFINE_bool(print_supported_ops, false, "Print supported operators on the inputed target"); @@ -121,11 +120,6 @@ std::vector ParserValidPlaces() { << "At least one target should be set, should set the " "command argument 'valid_targets'"; - if (FLAGS_prefer_int8_kernel) { - LOG(WARNING) << "Int8 mode is only support by ARM target"; - valid_places.insert(valid_places.begin(), - Place{TARGET(kARM), PRECISION(kInt8)}); - } return valid_places; } @@ -255,7 +249,6 @@ void PrintHelpInfo() { " `--optimize_out_type=(protobuf|naive_buffer)`\n" " `--optimize_out=`\n" " `--valid_targets=(arm|opencl|x86|npu|xpu)`\n" - " `--prefer_int8_kernel=(true|false)`\n" " `--record_tailoring_info=(true|false)`\n" " Arguments of model checking and ops information:\n" " `--print_all_ops=true` Display all the valid operators of "