diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc
index efb706a5ef5fc31ff2cfc22e04ee5a808e4991cd..f0cb6841d5b73ea600b9e2b7e2f055192811b6c3 100644
--- a/lite/api/benchmark.cc
+++ b/lite/api/benchmark.cc
@@ -30,7 +30,19 @@
 #include "lite/utils/cp_logging.h"
 #include "lite/utils/string.h"
 
-DEFINE_string(model_dir, "", "model dir");
+DEFINE_string(model_dir,
+              "",
+              "the path of the model; set model_dir when the model is not "
+              "in combined format. This option will be ignored if "
+              "model_file and param_file exist.");
+DEFINE_string(model_file,
+              "",
+              "the path of the model file; set model_file when the model "
+              "is in combined format.");
+DEFINE_string(param_file,
+              "",
+              "the path of the param file; set param_file when the model "
+              "is in combined format.");
 DEFINE_string(input_shape,
               "1,3,224,224",
               "set input shapes according to the model, "
@@ -68,11 +80,12 @@ inline double GetCurrentUS() {
   return 1e+6 * time.tv_sec + time.tv_usec;
 }
 
-void OutputOptModel(const std::string& load_model_dir,
-                    const std::string& save_optimized_model_dir,
+void OutputOptModel(const std::string& save_optimized_model_dir,
                     const std::vector<std::vector<int64_t>>& input_shapes) {
   lite_api::CxxConfig config;
-  config.set_model_dir(load_model_dir);
+  config.set_model_dir(FLAGS_model_dir);
+  config.set_model_file(FLAGS_model_file);
+  config.set_param_file(FLAGS_param_file);
   std::vector<Place> vaild_places = {
       Place{TARGET(kARM), PRECISION(kFloat)},
   };
@@ -91,7 +104,7 @@ void OutputOptModel(const std::string& load_model_dir,
   }
   predictor->SaveOptimizedModel(save_optimized_model_dir,
                                 LiteModelType::kNaiveBuffer);
-  LOG(INFO) << "Load model from " << load_model_dir;
+  LOG(INFO) << "Load model from " << FLAGS_model_dir;
   LOG(INFO) << "Save optimized model to " << save_optimized_model_dir;
 }
 
@@ -146,7 +159,7 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
     LOG(FATAL) << "open result file failed";
   }
   ofs.precision(5);
-  ofs << std::setw(20) << std::fixed << std::left << model_name;
+  ofs << std::setw(30) << std::fixed << std::left << model_name;
   ofs << "min = " << std::setw(12) << min_res;
   ofs << "max = " << std::setw(12) << max_res;
   ofs << "average = " << std::setw(12) << avg_res;
@@ -209,8 +222,7 @@ int main(int argc, char** argv) {
 
   // Output optimized model if needed
   if (FLAGS_run_model_optimize) {
-    paddle::lite_api::OutputOptModel(
-        FLAGS_model_dir, save_optimized_model_dir, input_shapes);
+    paddle::lite_api::OutputOptModel(save_optimized_model_dir, input_shapes);
   }
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
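
For reference, below is a minimal sketch (not part of the patch) of how the new combined-format flags map onto the lite_api::CxxConfig calls that OutputOptModel now makes. The header name, the example paths, and the output directory are assumptions for illustration; only the set_model_file/set_param_file configuration and the SaveOptimizedModel call in naive-buffer format are taken from the diff above.

    // sketch.cc: load a combined-format model (one model file plus one param
    // file) and save it as an optimized naive-buffer model, mirroring the
    // logic OutputOptModel uses after this patch. Build/link details and
    // the op/kernel registration headers some builds require are omitted.
    #include <memory>
    #include "paddle_api.h"  // Paddle-Lite public C++ API (assumed include path)

    int main() {
      using namespace paddle::lite_api;  // NOLINT

      CxxConfig config;
      // Combined format: when model_file and param_file are set,
      // model_dir is ignored (per the new flag descriptions).
      config.set_model_file("./mobilenet_v1/model");   // placeholder path
      config.set_param_file("./mobilenet_v1/params");  // placeholder path
      config.set_valid_places({Place{TARGET(kARM), PRECISION(kFloat)}});

      auto predictor = CreatePaddlePredictor<CxxConfig>(config);
      // Persist the optimized model in naive-buffer format, as the benchmark does.
      predictor->SaveOptimizedModel("./mobilenet_v1_opt",
                                    LiteModelType::kNaiveBuffer);
      return 0;
    }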