未验证 提交 51d15cc1 编写于 作者: C cc 提交者: GitHub

Update benchmark to support setting model and param filename (#2993)

* Update benchmark to support setting model and param filename
上级 2882edb1
......@@ -29,22 +29,22 @@
DEFINE_string(model_dir,
"",
"the path of the model, set model_dir when the model is no "
"combined formate. This option will be ignored if model_file "
"and param_file are exist.");
DEFINE_string(model_file,
"the path of the model, the model and param files is under "
"model_dir.");
DEFINE_string(model_filename,
"",
"the path of model file, set model_file when the model is "
"combined formate.");
DEFINE_string(param_file,
"the filename of model file. When the model is combined formate, "
"please set model_file.");
DEFINE_string(param_filename,
"",
"the path of param file, set param_file when the model is "
"the filename of param file, set param_file when the model is "
"combined formate.");
DEFINE_string(input_shape,
"1,3,224,224",
"set input shapes according to the model, "
"separated by colon and comma, "
"such as 1,3,244,244:1,3,300,300.");
"such as 1,3,244,244");
DEFINE_string(input_img_path, "", "the path of input image");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_int32(power_mode,
......@@ -77,12 +77,13 @@ inline double GetCurrentUS() {
return 1e+6 * time.tv_sec + time.tv_usec;
}
void OutputOptModel(const std::string& save_optimized_model_dir,
const std::vector<std::vector<int64_t>>& input_shapes) {
void OutputOptModel(const std::string& save_optimized_model_dir) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_model_file(FLAGS_model_file);
config.set_param_file(FLAGS_param_file);
if (!FLAGS_model_filename.empty() && !FLAGS_param_filename.empty()) {
config.set_model_file(FLAGS_model_dir + "/" + FLAGS_model_filename);
config.set_param_file(FLAGS_model_dir + "/" + FLAGS_param_filename);
}
std::vector<Place> vaild_places = {
Place{TARGET(kARM), PRECISION(kFloat)},
};
......@@ -106,7 +107,7 @@ void OutputOptModel(const std::string& save_optimized_model_dir,
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void Run(const std::vector<std::vector<int64_t>>& input_shapes,
void Run(const std::vector<int64_t>& input_shape,
const std::string& model_dir,
const std::string model_name) {
// set config and create predictor
......@@ -118,17 +119,27 @@ void Run(const std::vector<std::vector<int64_t>>& input_shapes,
auto predictor = lite_api::CreatePaddlePredictor(config);
// set input
for (int j = 0; j < input_shapes.size(); ++j) {
auto input_tensor = predictor->GetInput(j);
input_tensor->Resize(input_shapes[j]);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (size_t i = 0; i < input_shapes[j].size(); ++i) {
input_num *= input_shapes[j][i];
}
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(input_shape);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (size_t i = 0; i < input_shape.size(); ++i) {
input_num *= input_shape[i];
}
if (FLAGS_input_img_path.empty()) {
for (int i = 0; i < input_num; ++i) {
input_data[i] = 1.f;
}
} else {
std::fstream fs(FLAGS_input_img_path);
if (!fs.is_open()) {
LOG(FATAL) << "open input image " << FLAGS_input_img_path << " error.";
}
for (int i = 0; i < input_num; i++) {
fs >> input_data[i];
}
// LOG(INFO) << "input data:" << input_data[0] << " " <<
// input_data[input_num-1];
}
// warmup
......@@ -175,25 +186,12 @@ int main(int argc, char** argv) {
exit(0);
}
if (FLAGS_model_dir.back() == '/') {
FLAGS_model_dir.pop_back();
}
std::size_t found = FLAGS_model_dir.find_last_of("/");
std::string model_name = FLAGS_model_dir.substr(found + 1);
std::string save_optimized_model_dir = FLAGS_model_dir + "opt2";
auto split_string =
[](const std::string& str_in) -> std::vector<std::string> {
std::vector<std::string> str_out;
std::string tmp_str = str_in;
while (!tmp_str.empty()) {
size_t next_offset = tmp_str.find(":");
str_out.push_back(tmp_str.substr(0, next_offset));
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return str_out;
};
std::string save_optimized_model_dir = FLAGS_model_dir + "_opt2";
auto get_shape = [](const std::string& str_shape) -> std::vector<int64_t> {
std::vector<int64_t> shape;
......@@ -211,22 +209,18 @@ int main(int argc, char** argv) {
return shape;
};
std::vector<std::string> str_input_shapes = split_string(FLAGS_input_shape);
std::vector<std::vector<int64_t>> input_shapes;
for (size_t i = 0; i < str_input_shapes.size(); ++i) {
input_shapes.push_back(get_shape(str_input_shapes[i]));
}
std::vector<int64_t> input_shape = get_shape(FLAGS_input_shape);
// Output optimized model if needed
if (FLAGS_run_model_optimize) {
paddle::lite_api::OutputOptModel(save_optimized_model_dir, input_shapes);
paddle::lite_api::OutputOptModel(save_optimized_model_dir);
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Run inference using optimized model
std::string run_model_dir =
FLAGS_run_model_optimize ? save_optimized_model_dir : FLAGS_model_dir;
paddle::lite_api::Run(input_shapes, run_model_dir, model_name);
paddle::lite_api::Run(input_shape, run_model_dir, model_name);
#endif
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册