提交 2b04dbe9 编写于 作者: B Bin Li 提交者: 叶剑武

Refactor mace_run

上级 d83cab22
...@@ -62,7 +62,7 @@ MACE provides tools to do statistics with following steps: ...@@ -62,7 +62,7 @@ MACE provides tools to do statistics with following steps:
.. code:: sh .. code:: sh
python tools/python/tools/quantize/quantize_stat.py --log_file range_log > overall_range python tools/python/quantize/quantize_stat.py --log_file range_log > overall_range
4. Convert quantized model (by setting `target_abis` to the final target abis, e.g., `armeabi-v7a`, 4. Convert quantized model (by setting `target_abis` to the final target abis, e.g., `armeabi-v7a`,
......
...@@ -318,61 +318,52 @@ bool RunModel(const std::string &model_name, ...@@ -318,61 +318,52 @@ bool RunModel(const std::string &model_name,
DIR *dir_parent; DIR *dir_parent;
struct dirent *entry; struct dirent *entry;
dir_parent = opendir(FLAGS_input_dir.c_str()); dir_parent = opendir(FLAGS_input_dir.c_str());
if (dir_parent) { MACE_CHECK(dir_parent != nullptr, "Open input_dir ", FLAGS_input_dir,
while ((entry = readdir(dir_parent))) { " failed: ", strerror(errno));
std::string file_name = std::string(entry->d_name); while ((entry = readdir(dir_parent))) {
std::string prefix = FormatName(input_names[0]); std::string file_name = std::string(entry->d_name);
if (file_name.find(prefix) == 0) { std::string prefix = FormatName(input_names[0]);
std::string suffix = file_name.substr(prefix.size()); if (file_name.find(prefix) == 0) {
std::string suffix = file_name.substr(prefix.size());
for (size_t i = 0; i < input_count; ++i) {
file_name = FLAGS_input_dir + "/" + FormatName(input_names[i]) for (size_t i = 0; i < input_count; ++i) {
+ suffix; file_name = FLAGS_input_dir + "/" + FormatName(input_names[i])
std::ifstream in_file(file_name, std::ios::in | std::ios::binary); + suffix;
std::cout << "Read " << file_name << std::endl; std::ifstream in_file(file_name, std::ios::in | std::ios::binary);
if (in_file.is_open()) { LOG(INFO) << "Read " << file_name;
in_file.read(reinterpret_cast<char *>( MACE_CHECK(in_file.is_open(), "Open input file failed: ",
inputs[input_names[i]].data().get()), strerror(errno));
inputs_size[input_names[i]] * sizeof(float)); in_file.read(reinterpret_cast<char *>(
in_file.close(); inputs[input_names[i]].data().get()),
} else { inputs_size[input_names[i]] * sizeof(float));
std::cerr << "Open input file failed" << std::endl; in_file.close();
return -1; }
} engine->Run(inputs, &outputs);
}
engine->Run(inputs, &outputs); if (!FLAGS_output_dir.empty()) {
for (size_t i = 0; i < output_count; ++i) {
if (!FLAGS_output_dir.empty()) { std::string output_name =
for (size_t i = 0; i < output_count; ++i) { FLAGS_output_dir + "/" + FormatName(output_names[i]) + suffix;
std::string output_name = std::ofstream out_file(output_name, std::ios::binary);
FLAGS_output_dir + "/" + FormatName(output_names[i]) + suffix; MACE_CHECK(out_file.is_open(), "Open output file failed: ",
std::ofstream out_file(output_name, std::ios::binary); strerror(errno));
if (out_file.is_open()) { int64_t output_size =
int64_t output_size = std::accumulate(output_shapes[i].begin(),
std::accumulate(output_shapes[i].begin(), output_shapes[i].end(),
output_shapes[i].end(), 1,
1, std::multiplies<int64_t>());
std::multiplies<int64_t>()); out_file.write(
out_file.write( reinterpret_cast<char *>(
reinterpret_cast<char *>( outputs[output_names[i]].data().get()),
outputs[output_names[i]].data().get()), output_size * sizeof(float));
output_size * sizeof(float)); out_file.flush();
out_file.flush(); out_file.close();
out_file.close();
} else {
std::cerr << "Open output file failed" << std::endl;
return -1;
}
}
} }
} }
} }
closedir(dir_parent);
} else {
std::cerr << "Directory " << FLAGS_input_dir << " does not exist."
<< std::endl;
} }
closedir(dir_parent);
} else { } else {
LOG(INFO) << "Warm up run"; LOG(INFO) << "Warm up run";
double warmup_millis; double warmup_millis;
...@@ -539,6 +530,8 @@ int Main(int argc, char **argv) { ...@@ -539,6 +530,8 @@ int Main(int argc, char **argv) {
LOG(INFO) << "output shape: " << FLAGS_output_shape; LOG(INFO) << "output shape: " << FLAGS_output_shape;
LOG(INFO) << "input_file: " << FLAGS_input_file; LOG(INFO) << "input_file: " << FLAGS_input_file;
LOG(INFO) << "output_file: " << FLAGS_output_file; LOG(INFO) << "output_file: " << FLAGS_output_file;
LOG(INFO) << "input dir: " << FLAGS_input_dir;
LOG(INFO) << "output dir: " << FLAGS_output_dir;
LOG(INFO) << "model_data_file: " << FLAGS_model_data_file; LOG(INFO) << "model_data_file: " << FLAGS_model_data_file;
LOG(INFO) << "model_file: " << FLAGS_model_file; LOG(INFO) << "model_file: " << FLAGS_model_file;
LOG(INFO) << "device: " << FLAGS_device; LOG(INFO) << "device: " << FLAGS_device;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册