From 2b04dbe90af798d2f83c6031f6ba7b926daa965f Mon Sep 17 00:00:00 2001 From: Bin Li Date: Fri, 27 Sep 2019 11:22:28 +0800 Subject: [PATCH] Refactor mace_run --- docs/user_guide/quantization_usage.rst | 2 +- mace/tools/mace_run.cc | 95 ++++++++++++-------------- 2 files changed, 45 insertions(+), 52 deletions(-) diff --git a/docs/user_guide/quantization_usage.rst b/docs/user_guide/quantization_usage.rst index 40699cca..28f3afbc 100644 --- a/docs/user_guide/quantization_usage.rst +++ b/docs/user_guide/quantization_usage.rst @@ -62,7 +62,7 @@ MACE provides tools to do statistics with following steps: .. code:: sh - python tools/python/tools/quantize/quantize_stat.py --log_file range_log > overall_range + python tools/python/quantize/quantize_stat.py --log_file range_log > overall_range 4. Convert quantized model (by setting `target_abis` to the final target abis, e.g., `armeabi-v7a`, diff --git a/mace/tools/mace_run.cc b/mace/tools/mace_run.cc index 74c02f23..45ec2f33 100644 --- a/mace/tools/mace_run.cc +++ b/mace/tools/mace_run.cc @@ -318,61 +318,52 @@ bool RunModel(const std::string &model_name, DIR *dir_parent; struct dirent *entry; dir_parent = opendir(FLAGS_input_dir.c_str()); - if (dir_parent) { - while ((entry = readdir(dir_parent))) { - std::string file_name = std::string(entry->d_name); - std::string prefix = FormatName(input_names[0]); - if (file_name.find(prefix) == 0) { - std::string suffix = file_name.substr(prefix.size()); - - for (size_t i = 0; i < input_count; ++i) { - file_name = FLAGS_input_dir + "/" + FormatName(input_names[i]) - + suffix; - std::ifstream in_file(file_name, std::ios::in | std::ios::binary); - std::cout << "Read " << file_name << std::endl; - if (in_file.is_open()) { - in_file.read(reinterpret_cast( - inputs[input_names[i]].data().get()), - inputs_size[input_names[i]] * sizeof(float)); - in_file.close(); - } else { - std::cerr << "Open input file failed" << std::endl; - return -1; - } - } - engine->Run(inputs, &outputs); - - if (!FLAGS_output_dir.empty()) { - for (size_t i = 0; i < output_count; ++i) { - std::string output_name = - FLAGS_output_dir + "/" + FormatName(output_names[i]) + suffix; - std::ofstream out_file(output_name, std::ios::binary); - if (out_file.is_open()) { - int64_t output_size = - std::accumulate(output_shapes[i].begin(), - output_shapes[i].end(), - 1, - std::multiplies()); - out_file.write( - reinterpret_cast( - outputs[output_names[i]].data().get()), - output_size * sizeof(float)); - out_file.flush(); - out_file.close(); - } else { - std::cerr << "Open output file failed" << std::endl; - return -1; - } - } + MACE_CHECK(dir_parent != nullptr, "Open input_dir ", FLAGS_input_dir, + " failed: ", strerror(errno)); + while ((entry = readdir(dir_parent))) { + std::string file_name = std::string(entry->d_name); + std::string prefix = FormatName(input_names[0]); + if (file_name.find(prefix) == 0) { + std::string suffix = file_name.substr(prefix.size()); + + for (size_t i = 0; i < input_count; ++i) { + file_name = FLAGS_input_dir + "/" + FormatName(input_names[i]) + + suffix; + std::ifstream in_file(file_name, std::ios::in | std::ios::binary); + LOG(INFO) << "Read " << file_name; + MACE_CHECK(in_file.is_open(), "Open input file failed: ", + strerror(errno)); + in_file.read(reinterpret_cast( + inputs[input_names[i]].data().get()), + inputs_size[input_names[i]] * sizeof(float)); + in_file.close(); + } + engine->Run(inputs, &outputs); + + if (!FLAGS_output_dir.empty()) { + for (size_t i = 0; i < output_count; ++i) { + std::string output_name = + FLAGS_output_dir + "/" + FormatName(output_names[i]) + suffix; + std::ofstream out_file(output_name, std::ios::binary); + MACE_CHECK(out_file.is_open(), "Open output file failed: ", + strerror(errno)); + int64_t output_size = + std::accumulate(output_shapes[i].begin(), + output_shapes[i].end(), + 1, + std::multiplies()); + out_file.write( + reinterpret_cast( + outputs[output_names[i]].data().get()), + output_size * sizeof(float)); + out_file.flush(); + out_file.close(); } } } - - closedir(dir_parent); - } else { - std::cerr << "Directory " << FLAGS_input_dir << " does not exist." - << std::endl; } + + closedir(dir_parent); } else { LOG(INFO) << "Warm up run"; double warmup_millis; @@ -539,6 +530,8 @@ int Main(int argc, char **argv) { LOG(INFO) << "output shape: " << FLAGS_output_shape; LOG(INFO) << "input_file: " << FLAGS_input_file; LOG(INFO) << "output_file: " << FLAGS_output_file; + LOG(INFO) << "input dir: " << FLAGS_input_dir; + LOG(INFO) << "output dir: " << FLAGS_output_dir; LOG(INFO) << "model_data_file: " << FLAGS_model_data_file; LOG(INFO) << "model_file: " << FLAGS_model_file; LOG(INFO) << "device: " << FLAGS_device; -- GitLab