Unverified · Commit 51d15cc1, authored by cc, committed by GitHub

Update benchmark to support setting model and param filename (#2993)

* Update benchmark to support setting model and param filename
Parent 2882edb1
@@ -29,22 +29,22 @@
 DEFINE_string(model_dir,
               "",
-              "the path of the model, set model_dir when the model is no "
-              "combined formate. This option will be ignored if model_file "
-              "and param_file are exist.");
-DEFINE_string(model_file,
+              "the path of the model, the model and param files is under "
+              "model_dir.");
+DEFINE_string(model_filename,
               "",
-              "the path of model file, set model_file when the model is "
-              "combined formate.");
-DEFINE_string(param_file,
+              "the filename of model file. When the model is combined formate, "
+              "please set model_file.");
+DEFINE_string(param_filename,
               "",
-              "the path of param file, set param_file when the model is "
+              "the filename of param file, set param_file when the model is "
               "combined formate.");
 DEFINE_string(input_shape,
               "1,3,224,224",
               "set input shapes according to the model, "
               "separated by colon and comma, "
-              "such as 1,3,244,244:1,3,300,300.");
+              "such as 1,3,244,244");
+DEFINE_string(input_img_path, "", "the path of input image");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
 DEFINE_int32(power_mode,
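Each DEFINE_string / DEFINE_int32 above is a gflags macro: it creates a global FLAGS_<name> variable and a matching --<name> command-line switch. A minimal self-contained sketch of that mechanism, assuming only the gflags library (the two flag names mirror the benchmark's, but this program is illustrative, not part of the commit):

```cpp
#include <gflags/gflags.h>
#include <iostream>

// Same pattern as the benchmark: flag name, default value, help text.
DEFINE_string(model_dir, "", "directory holding the model and param files");
DEFINE_string(model_filename, "", "model filename, for combined-format models");

int main(int argc, char** argv) {
  // Fills FLAGS_model_dir / FLAGS_model_filename from --model_dir=... etc.
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "model_dir = " << FLAGS_model_dir << "\n"
            << "model_filename = " << FLAGS_model_filename << "\n";
  return 0;
}
```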
@@ -77,12 +77,13 @@ inline double GetCurrentUS() {
   return 1e+6 * time.tv_sec + time.tv_usec;
 }
 
-void OutputOptModel(const std::string& save_optimized_model_dir,
-                    const std::vector<std::vector<int64_t>>& input_shapes) {
+void OutputOptModel(const std::string& save_optimized_model_dir) {
   lite_api::CxxConfig config;
   config.set_model_dir(FLAGS_model_dir);
-  config.set_model_file(FLAGS_model_file);
-  config.set_param_file(FLAGS_param_file);
+  if (!FLAGS_model_filename.empty() && !FLAGS_param_filename.empty()) {
+    config.set_model_file(FLAGS_model_dir + "/" + FLAGS_model_filename);
+    config.set_param_file(FLAGS_model_dir + "/" + FLAGS_param_filename);
+  }
   std::vector<Place> vaild_places = {
       Place{TARGET(kARM), PRECISION(kFloat)},
   };
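The new branch selects between Paddle's two model layouts: when both --model_filename and --param_filename are set, the combined format (a single model file plus a single param file) is configured with full paths; otherwise only model_dir is set and the uncombined layout (one file per variable under model_dir) is assumed. A compilable sketch of that decision, using a stand-in Config struct so it builds without Paddle-Lite headers (lite_api::CxxConfig exposes the same three setters):

```cpp
#include <iostream>
#include <string>

// Stand-in for lite_api::CxxConfig; prints instead of configuring.
struct Config {
  void set_model_dir(const std::string& d) { std::cout << "model_dir:  " << d << "\n"; }
  void set_model_file(const std::string& f) { std::cout << "model_file: " << f << "\n"; }
  void set_param_file(const std::string& f) { std::cout << "param_file: " << f << "\n"; }
};

void Configure(Config& config,
               const std::string& model_dir,
               const std::string& model_filename,
               const std::string& param_filename) {
  config.set_model_dir(model_dir);
  // Combined format only when both filenames are given; otherwise the
  // uncombined layout under model_dir is used as-is.
  if (!model_filename.empty() && !param_filename.empty()) {
    config.set_model_file(model_dir + "/" + model_filename);
    config.set_param_file(model_dir + "/" + param_filename);
  }
}

int main() {
  Config config;
  Configure(config, "mobilenet_v1", "model", "params");  // hypothetical names
  return 0;
}
```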
@@ -106,7 +107,7 @@ void OutputOptModel(const std::string& save_optimized_model_dir,
 }
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
-void Run(const std::vector<std::vector<int64_t>>& input_shapes,
+void Run(const std::vector<int64_t>& input_shape,
          const std::string& model_dir,
          const std::string model_name) {
   // set config and create predictor
@@ -118,17 +119,27 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
   auto predictor = lite_api::CreatePaddlePredictor(config);
 
   // set input
-  for (int j = 0; j < input_shapes.size(); ++j) {
-    auto input_tensor = predictor->GetInput(j);
-    input_tensor->Resize(input_shapes[j]);
-    auto input_data = input_tensor->mutable_data<float>();
-    int input_num = 1;
-    for (size_t i = 0; i < input_shapes[j].size(); ++i) {
-      input_num *= input_shapes[j][i];
-    }
-    for (int i = 0; i < input_num; ++i) {
-      input_data[i] = 1.f;
-    }
+  auto input_tensor = predictor->GetInput(0);
+  input_tensor->Resize(input_shape);
+  auto input_data = input_tensor->mutable_data<float>();
+  int input_num = 1;
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    input_num *= input_shape[i];
+  }
+  if (FLAGS_input_img_path.empty()) {
+    for (int i = 0; i < input_num; ++i) {
+      input_data[i] = 1.f;
+    }
+  } else {
+    std::fstream fs(FLAGS_input_img_path);
+    if (!fs.is_open()) {
+      LOG(FATAL) << "open input image " << FLAGS_input_img_path << " error.";
+    }
+    for (int i = 0; i < input_num; i++) {
+      fs >> input_data[i];
+    }
+    // LOG(INFO) << "input data:" << input_data[0] << " " <<
+    // input_data[input_num-1];
   }
 
   // warmup
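Note what the new else branch implies: --input_img_path is consumed with operator>>, so the "input image" must be a plain-text file of whitespace-separated float values whose count equals the product of --input_shape, not an encoded image such as a JPEG. A standalone sketch of the same reading pattern (file name and shape below are hypothetical):

```cpp
#include <fstream>
#include <iostream>
#include <vector>

int main() {
  const char* path = "input.txt";  // hypothetical --input_img_path
  std::ifstream fs(path);
  if (!fs.is_open()) {
    std::cerr << "open input image " << path << " error.\n";
    return 1;
  }
  std::vector<float> data(1 * 3 * 224 * 224);  // product of --input_shape
  for (float& v : data) {
    fs >> v;  // operator>> skips spaces and newlines between values
  }
  std::cout << "first: " << data.front() << ", last: " << data.back() << "\n";
  return 0;
}
```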
@@ -175,25 +186,12 @@ int main(int argc, char** argv) {
     exit(0);
   }
 
+  if (FLAGS_model_dir.back() == '/') {
+    FLAGS_model_dir.pop_back();
+  }
   std::size_t found = FLAGS_model_dir.find_last_of("/");
   std::string model_name = FLAGS_model_dir.substr(found + 1);
-  std::string save_optimized_model_dir = FLAGS_model_dir + "opt2";
+  std::string save_optimized_model_dir = FLAGS_model_dir + "_opt2";
 
-  auto split_string =
-      [](const std::string& str_in) -> std::vector<std::string> {
-    std::vector<std::string> str_out;
-    std::string tmp_str = str_in;
-    while (!tmp_str.empty()) {
-      size_t next_offset = tmp_str.find(":");
-      str_out.push_back(tmp_str.substr(0, next_offset));
-      if (next_offset == std::string::npos) {
-        break;
-      } else {
-        tmp_str = tmp_str.substr(next_offset + 1);
-      }
-    }
-    return str_out;
-  };
-
   auto get_shape = [](const std::string& str_shape) -> std::vector<int64_t> {
     std::vector<int64_t> shape;
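The added trailing-slash normalization matters for the lines right after it: without it, a --model_dir ending in '/' would produce an empty model_name and a save directory ending in "/_opt2". A small sketch of the path handling (the example path is hypothetical):

```cpp
#include <iostream>
#include <string>

int main() {
  std::string model_dir = "/data/mobilenet_v1/";  // hypothetical --model_dir
  if (model_dir.back() == '/') {
    model_dir.pop_back();  // "/data/mobilenet_v1"
  }
  std::size_t found = model_dir.find_last_of("/");
  std::string model_name = model_dir.substr(found + 1);  // "mobilenet_v1"
  std::string save_dir = model_dir + "_opt2";            // "/data/mobilenet_v1_opt2"
  std::cout << model_name << "\n" << save_dir << "\n";
  return 0;
}
```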
@@ -211,22 +209,18 @@ int main(int argc, char** argv) {
     return shape;
   };
 
-  std::vector<std::string> str_input_shapes = split_string(FLAGS_input_shape);
-  std::vector<std::vector<int64_t>> input_shapes;
-  for (size_t i = 0; i < str_input_shapes.size(); ++i) {
-    input_shapes.push_back(get_shape(str_input_shapes[i]));
-  }
+  std::vector<int64_t> input_shape = get_shape(FLAGS_input_shape);
 
   // Output optimized model if needed
   if (FLAGS_run_model_optimize) {
-    paddle::lite_api::OutputOptModel(save_optimized_model_dir, input_shapes);
+    paddle::lite_api::OutputOptModel(save_optimized_model_dir);
   }
 
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
   // Run inference using optimized model
   std::string run_model_dir =
       FLAGS_run_model_optimize ? save_optimized_model_dir : FLAGS_model_dir;
-  paddle::lite_api::Run(input_shapes, run_model_dir, model_name);
+  paddle::lite_api::Run(input_shape, run_model_dir, model_name);
 #endif
   return 0;
 }
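The body of the get_shape lambda is elided by the hunk above. Based on its signature and the --input_shape format ("1,3,224,224"), a typical comma-splitting implementation would look like this sketch (GetShape is a hypothetical stand-in, not the committed code):

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::vector<int64_t> GetShape(const std::string& str_shape) {
  std::vector<int64_t> shape;
  std::stringstream ss(str_shape);
  std::string item;
  while (std::getline(ss, item, ',')) {  // split on commas
    shape.push_back(std::stoll(item));
  }
  return shape;
}

int main() {
  for (int64_t d : GetShape("1,3,224,224")) {
    std::cout << d << " ";  // prints: 1 3 224 224
  }
  std::cout << "\n";
  return 0;
}
```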