提交 3828ac54 编写于 作者: L liuqi

Fix benchmark model with new APIs.

上级 e9eaa4d4
......@@ -12,6 +12,7 @@
#include "gflags/gflags.h"
#include "mace/public/mace.h"
#include "mace/public/mace_runtime.h"
#include "mace/utils/logging.h"
#include "mace/benchmark/stat_summarizer.h"
......@@ -95,9 +96,23 @@ inline int64_t NowMicros() {
return static_cast<int64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
}
// Maps a device name ("CPU", "NEON", "OPENCL", "HEXAGON") to the
// corresponding DeviceType enum value.
// Any unrecognized name (including "CPU" itself) falls back to CPU,
// so callers always receive a usable device type.
DeviceType ParseDeviceType(const std::string &device_str) {
  if (device_str == "NEON") {
    return DeviceType::NEON;
  }
  if (device_str == "OPENCL") {
    return DeviceType::OPENCL;
  }
  if (device_str == "HEXAGON") {
    return DeviceType::HEXAGON;
  }
  // Default: "CPU" and every unknown string map to the CPU device.
  return DeviceType::CPU;
}
bool RunInference(MaceEngine *engine,
const std::vector<mace::MaceInputInfo> &input_infos,
std::map<std::string, float*> *output_infos,
const std::map<std::string, mace::MaceTensor> &input_infos,
std::map<std::string, mace::MaceTensor> *output_infos,
StatSummarizer *summarizer,
int64_t *inference_time_us) {
MACE_CHECK_NOTNULL(output_infos);
......@@ -106,28 +121,16 @@ bool RunInference(MaceEngine *engine,
if (summarizer) {
run_metadata_ptr = &run_metadata;
}
if (input_infos.size() == 1 && output_infos->size() == 1) {
const int64_t start_time = NowMicros();
bool s = engine->Run(input_infos[0].data, input_infos[0].shape,
output_infos->begin()->second, run_metadata_ptr);
const int64_t end_time = NowMicros();
if (!s) {
LOG(ERROR) << "Error during inference.";
return s;
}
*inference_time_us = end_time - start_time;
} else {
const int64_t start_time = NowMicros();
bool s = engine->Run(input_infos, *output_infos, run_metadata_ptr);
const int64_t end_time = NowMicros();
const int64_t start_time = NowMicros();
mace::MaceStatus s = engine->Run(input_infos, output_infos, run_metadata_ptr);
const int64_t end_time = NowMicros();
if (!s) {
LOG(ERROR) << "Error during inference.";
return s;
}
*inference_time_us = end_time - start_time;
if (s != mace::MaceStatus::MACE_SUCCESS) {
LOG(ERROR) << "Error during inference.";
return false;
}
*inference_time_us = end_time - start_time;
if (summarizer != nullptr) {
summarizer->ProcessMetadata(run_metadata);
......@@ -137,8 +140,8 @@ bool RunInference(MaceEngine *engine,
}
bool Run(MaceEngine *engine,
const std::vector<mace::MaceInputInfo> &input_infos,
std::map<std::string, float*> *output_infos,
const std::map<std::string, mace::MaceTensor> &input_infos,
std::map<std::string, mace::MaceTensor> *output_infos,
StatSummarizer *summarizer,
int num_runs,
double max_time_sec,
......@@ -261,12 +264,7 @@ int Main(int argc, char **argv) {
stats_options.show_summary = FLAGS_show_summary;
stats.reset(new StatSummarizer(stats_options));
DeviceType device_type = CPU;
if (FLAGS_device == "OPENCL") {
device_type = OPENCL;
} else if (FLAGS_device == "NEON") {
device_type = NEON;
}
mace::DeviceType device_type = ParseDeviceType(FLAGS_device);
// config runtime
mace::ConfigOmpThreads(FLAGS_omp_num_threads);
......@@ -302,50 +300,44 @@ int Main(int argc, char **argv) {
mace::MACE_MODEL_TAG::LoadModelData(FLAGS_model_data_file.c_str());
NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(model_data);
std::vector<mace::MaceInputInfo> input_infos(input_count);
std::map<std::string, float*> output_infos;
std::vector<std::unique_ptr<float[]>> input_datas(input_count);
std::vector<std::unique_ptr<float[]>> output_datas(output_count);
std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs;
for (size_t i = 0; i < input_count; ++i) {
int64_t input_size = std::accumulate(input_shape_vec[i].begin(),
input_shape_vec[i].end(), 1,
std::multiplies<int64_t>());
input_datas[i].reset(new float[input_size]);
// Allocate input and output
int64_t input_size =
std::accumulate(input_shape_vec[i].begin(), input_shape_vec[i].end(), 1,
std::multiplies<int64_t>());
auto buffer_in = std::shared_ptr<float>(new float[input_size],
std::default_delete<float[]>());
// load input
std::ifstream in_file(FLAGS_input_file + "_" + FormatName(input_names[i]),
std::ios::in | std::ios::binary);
if (in_file.is_open()) {
in_file.read(reinterpret_cast<char *>(input_datas[i].get()),
in_file.read(reinterpret_cast<char *>(buffer_in.get()),
input_size * sizeof(float));
in_file.close();
} else {
LOG(INFO) << "Open input file failed";
return -1;
}
input_infos[i].name = input_names[i];
input_infos[i].shape = input_shape_vec[i];
input_infos[i].data = input_datas[i].get();
inputs[input_names[i]] = mace::MaceTensor(input_shape_vec[i], buffer_in);
}
for (size_t i = 0; i < output_count; ++i) {
int64_t output_size = std::accumulate(output_shape_vec[i].begin(),
output_shape_vec[i].end(), 1,
std::multiplies<int64_t>());
output_datas[i].reset(new float[output_size]);
output_infos[output_names[i]] = output_datas[i].get();
int64_t output_size =
std::accumulate(output_shape_vec[i].begin(),
output_shape_vec[i].end(), 1,
std::multiplies<int64_t>());
auto buffer_out = std::shared_ptr<float>(new float[output_size],
std::default_delete<float[]>());
outputs[output_names[i]] = mace::MaceTensor(output_shape_vec[i], buffer_out);
}
// Init model
LOG(INFO) << "Run init";
std::unique_ptr<mace::MaceEngine> engine_ptr;
if (input_count == 1 && output_count == 1) {
engine_ptr.reset(new mace::MaceEngine(&net_def, device_type));
} else {
engine_ptr.reset(new mace::MaceEngine(&net_def, device_type,
input_names, output_names));
}
if (device_type == DeviceType::OPENCL) {
std::unique_ptr<mace::MaceEngine> engine_ptr(
new mace::MaceEngine(&net_def, device_type, input_names, output_names));
if (device_type == DeviceType::OPENCL || device_type == DeviceType::HEXAGON) {
mace::MACE_MODEL_TAG::UnloadModelData(model_data);
}
......@@ -355,7 +347,7 @@ int Main(int argc, char **argv) {
int64_t num_warmup_runs = 0;
if (FLAGS_warmup_runs > 0) {
bool status =
Run(engine_ptr.get(), input_infos, &output_infos, nullptr,
Run(engine_ptr.get(), inputs, &outputs, nullptr,
FLAGS_warmup_runs, -1.0,
inter_inference_sleep_seconds, &warmup_time_us, &num_warmup_runs);
if (!status) {
......@@ -370,7 +362,7 @@ int Main(int argc, char **argv) {
int64_t no_stat_time_us = 0;
int64_t no_stat_runs = 0;
bool status =
Run(engine_ptr.get(), input_infos, &output_infos,
Run(engine_ptr.get(), inputs, &outputs,
nullptr, FLAGS_max_num_runs, max_benchmark_time_seconds,
inter_inference_sleep_seconds, &no_stat_time_us, &no_stat_runs);
if (!status) {
......@@ -379,7 +371,7 @@ int Main(int argc, char **argv) {
int64_t stat_time_us = 0;
int64_t stat_runs = 0;
status = Run(engine_ptr.get(), input_infos, &output_infos,
status = Run(engine_ptr.get(), inputs, &outputs,
stats.get(), FLAGS_max_num_runs, max_benchmark_time_seconds,
inter_inference_sleep_seconds, &stat_time_us, &stat_runs);
if (!status) {
......
#!/bin/bash
set -x
# Print a one-line usage synopsis for the benchmark helper script.
Usage() {
  printf '%s\n' "Usage: bash tools/benchmark.sh target_soc model_output_dir option_args"
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册