提交 2b04dbe9 编写于 作者: B Bin Li 提交者: 叶剑武

Refactor mace_run

上级 d83cab22
......@@ -62,7 +62,7 @@ MACE provides tools to do statistics with following steps:
.. code:: sh
python tools/python/tools/quantize/quantize_stat.py --log_file range_log > overall_range
python tools/python/quantize/quantize_stat.py --log_file range_log > overall_range
4. Convert quantized model (by setting `target_abis` to the final target abis, e.g., `armeabi-v7a`,
......
......@@ -318,61 +318,52 @@ bool RunModel(const std::string &model_name,
DIR *dir_parent;
struct dirent *entry;
dir_parent = opendir(FLAGS_input_dir.c_str());
if (dir_parent) {
while ((entry = readdir(dir_parent))) {
std::string file_name = std::string(entry->d_name);
std::string prefix = FormatName(input_names[0]);
if (file_name.find(prefix) == 0) {
std::string suffix = file_name.substr(prefix.size());
for (size_t i = 0; i < input_count; ++i) {
file_name = FLAGS_input_dir + "/" + FormatName(input_names[i])
+ suffix;
std::ifstream in_file(file_name, std::ios::in | std::ios::binary);
std::cout << "Read " << file_name << std::endl;
if (in_file.is_open()) {
in_file.read(reinterpret_cast<char *>(
inputs[input_names[i]].data().get()),
inputs_size[input_names[i]] * sizeof(float));
in_file.close();
} else {
std::cerr << "Open input file failed" << std::endl;
return -1;
}
}
engine->Run(inputs, &outputs);
if (!FLAGS_output_dir.empty()) {
for (size_t i = 0; i < output_count; ++i) {
std::string output_name =
FLAGS_output_dir + "/" + FormatName(output_names[i]) + suffix;
std::ofstream out_file(output_name, std::ios::binary);
if (out_file.is_open()) {
int64_t output_size =
std::accumulate(output_shapes[i].begin(),
output_shapes[i].end(),
1,
std::multiplies<int64_t>());
out_file.write(
reinterpret_cast<char *>(
outputs[output_names[i]].data().get()),
output_size * sizeof(float));
out_file.flush();
out_file.close();
} else {
std::cerr << "Open output file failed" << std::endl;
return -1;
}
}
MACE_CHECK(dir_parent != nullptr, "Open input_dir ", FLAGS_input_dir,
" failed: ", strerror(errno));
while ((entry = readdir(dir_parent))) {
std::string file_name = std::string(entry->d_name);
std::string prefix = FormatName(input_names[0]);
if (file_name.find(prefix) == 0) {
std::string suffix = file_name.substr(prefix.size());
for (size_t i = 0; i < input_count; ++i) {
file_name = FLAGS_input_dir + "/" + FormatName(input_names[i])
+ suffix;
std::ifstream in_file(file_name, std::ios::in | std::ios::binary);
LOG(INFO) << "Read " << file_name;
MACE_CHECK(in_file.is_open(), "Open input file failed: ",
strerror(errno));
in_file.read(reinterpret_cast<char *>(
inputs[input_names[i]].data().get()),
inputs_size[input_names[i]] * sizeof(float));
in_file.close();
}
engine->Run(inputs, &outputs);
if (!FLAGS_output_dir.empty()) {
for (size_t i = 0; i < output_count; ++i) {
std::string output_name =
FLAGS_output_dir + "/" + FormatName(output_names[i]) + suffix;
std::ofstream out_file(output_name, std::ios::binary);
MACE_CHECK(out_file.is_open(), "Open output file failed: ",
strerror(errno));
int64_t output_size =
std::accumulate(output_shapes[i].begin(),
output_shapes[i].end(),
1,
std::multiplies<int64_t>());
out_file.write(
reinterpret_cast<char *>(
outputs[output_names[i]].data().get()),
output_size * sizeof(float));
out_file.flush();
out_file.close();
}
}
}
closedir(dir_parent);
} else {
std::cerr << "Directory " << FLAGS_input_dir << " does not exist."
<< std::endl;
}
closedir(dir_parent);
} else {
LOG(INFO) << "Warm up run";
double warmup_millis;
......@@ -539,6 +530,8 @@ int Main(int argc, char **argv) {
LOG(INFO) << "output shape: " << FLAGS_output_shape;
LOG(INFO) << "input_file: " << FLAGS_input_file;
LOG(INFO) << "output_file: " << FLAGS_output_file;
LOG(INFO) << "input dir: " << FLAGS_input_dir;
LOG(INFO) << "output dir: " << FLAGS_output_dir;
LOG(INFO) << "model_data_file: " << FLAGS_model_data_file;
LOG(INFO) << "model_file: " << FLAGS_model_file;
LOG(INFO) << "device: " << FLAGS_device;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册