未验证 提交 b51bb292 编写于 作者: C cc 提交者: GitHub

Automatically add the ARM int8 place, test=develop (#3234)

上级 fefc28ca
...@@ -294,6 +294,32 @@ void Predictor::Build(const cpp::ProgramDesc &desc, ...@@ -294,6 +294,32 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)); inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
inner_places.emplace_back( inner_places.emplace_back(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
const std::vector<std::string> quant_dequant_op = {
"fake_quantize_abs_max",
"fake_quantize_range_abs_max",
"fake_quantize_moving_average_abs_max",
"fake_quantize_dequantize_moving_average_abs_max",
"fake_dequantize_max_abs",
"fake_channel_wise_dequantize_max_abs"};
bool is_quantized_model = false;
for (size_t i = 0; i < program_desc_.BlocksSize() && !is_quantized_model;
++i) {
auto *block_desc = program_desc_.GetBlock<cpp::BlockDesc>(i);
for (size_t j = 0; j < block_desc->OpsSize() && !is_quantized_model; ++j) {
auto *op_desc = block_desc->GetOp<cpp::OpDesc>(j);
std::string op_type = op_desc->Type();
if (std::find(quant_dequant_op.begin(),
quant_dequant_op.end(),
op_type) != quant_dequant_op.end()) {
is_quantized_model = true;
}
}
}
if (is_quantized_model) {
inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
}
Program program(desc, scope_, inner_places); Program program(desc, scope_, inner_places);
core::KernelPickFactor factor; core::KernelPickFactor factor;
......
...@@ -67,7 +67,6 @@ DEFINE_string(valid_targets, ...@@ -67,7 +67,6 @@ DEFINE_string(valid_targets,
"arm", "arm",
"The targets this model optimized for, should be one of (arm, " "The targets this model optimized for, should be one of (arm, "
"opencl, x86), splitted by space"); "opencl, x86), splitted by space");
DEFINE_bool(prefer_int8_kernel, false, "Prefer to run model with int8 kernels");
DEFINE_bool(print_supported_ops, DEFINE_bool(print_supported_ops,
false, false,
"Print supported operators on the inputed target"); "Print supported operators on the inputed target");
...@@ -121,11 +120,6 @@ std::vector<Place> ParserValidPlaces() { ...@@ -121,11 +120,6 @@ std::vector<Place> ParserValidPlaces() {
<< "At least one target should be set, should set the " << "At least one target should be set, should set the "
"command argument 'valid_targets'"; "command argument 'valid_targets'";
if (FLAGS_prefer_int8_kernel) {
LOG(WARNING) << "Int8 mode is only support by ARM target";
valid_places.insert(valid_places.begin(),
Place{TARGET(kARM), PRECISION(kInt8)});
}
return valid_places; return valid_places;
} }
...@@ -255,7 +249,6 @@ void PrintHelpInfo() { ...@@ -255,7 +249,6 @@ void PrintHelpInfo() {
" `--optimize_out_type=(protobuf|naive_buffer)`\n" " `--optimize_out_type=(protobuf|naive_buffer)`\n"
" `--optimize_out=<output_optimize_model_dir>`\n" " `--optimize_out=<output_optimize_model_dir>`\n"
" `--valid_targets=(arm|opencl|x86|npu|xpu)`\n" " `--valid_targets=(arm|opencl|x86|npu|xpu)`\n"
" `--prefer_int8_kernel=(true|false)`\n"
" `--record_tailoring_info=(true|false)`\n" " `--record_tailoring_info=(true|false)`\n"
" Arguments of model checking and ops information:\n" " Arguments of model checking and ops information:\n"
" `--print_all_ops=true` Display all the valid operators of " " `--print_all_ops=true` Display all the valid operators of "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册