提交 2b13dd2c 编写于 作者: L liuqi

Model benchmark tool support multiple inputs or outputs.

上级 a2f173c2
...@@ -274,12 +274,12 @@ bool MultipleInputOrOutput(const std::vector<std::string> &input_names, ...@@ -274,12 +274,12 @@ bool MultipleInputOrOutput(const std::vector<std::string> &input_names,
LOG(INFO) << "Net init latency: " << t1 - t0 << " us"; LOG(INFO) << "Net init latency: " << t1 - t0 << " us";
LOG(INFO) << "Total init latency: " << init_micros << " us"; LOG(INFO) << "Total init latency: " << init_micros << " us";
const int input_count = input_names.size(); const size_t input_count = input_names.size();
const int output_count = output_names.size(); const size_t output_count = output_names.size();
std::vector<mace::MaceInputInfo> input_infos(input_count); std::vector<mace::MaceInputInfo> input_infos(input_count);
std::map<std::string, float*> outputs; std::map<std::string, float*> outputs;
std::vector<std::unique_ptr<float[]>> input_datas(input_count); std::vector<std::unique_ptr<float[]>> input_datas(input_count);
for (int i = 0; i < input_count; ++i) { for (size_t i = 0; i < input_count; ++i) {
// Allocate input and output // Allocate input and output
int64_t input_size = int64_t input_size =
std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1, std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1,
...@@ -300,7 +300,7 @@ bool MultipleInputOrOutput(const std::vector<std::string> &input_names, ...@@ -300,7 +300,7 @@ bool MultipleInputOrOutput(const std::vector<std::string> &input_names,
input_infos[i].data = input_datas[i].get(); input_infos[i].data = input_datas[i].get();
} }
std::vector<std::unique_ptr<float[]>> output_datas(output_count); std::vector<std::unique_ptr<float[]>> output_datas(output_count);
for (int i = 0; i < output_count; ++i) { for (size_t i = 0; i < output_count; ++i) {
int64_t output_size = int64_t output_size =
std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1, std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1,
std::multiplies<int64_t>()); std::multiplies<int64_t>());
...@@ -329,7 +329,7 @@ bool MultipleInputOrOutput(const std::vector<std::string> &input_names, ...@@ -329,7 +329,7 @@ bool MultipleInputOrOutput(const std::vector<std::string> &input_names,
LOG(INFO) << "Averate latency: " << (t1 - t0) / FLAGS_round << " us"; LOG(INFO) << "Averate latency: " << (t1 - t0) / FLAGS_round << " us";
} }
for (int i = 0; i < output_count; ++i) { for (size_t i = 0; i < output_count; ++i) {
std::string output_name = FLAGS_output_file + "_" + FormatName(output_names[i]); std::string output_name = FLAGS_output_file + "_" + FormatName(output_names[i]);
ofstream out_file(output_name, ios::binary); ofstream out_file(output_name, ios::binary);
int64_t output_size = int64_t output_size =
...@@ -370,14 +370,14 @@ int main(int argc, char **argv) { ...@@ -370,14 +370,14 @@ int main(int argc, char **argv) {
std::vector<std::string> input_shapes = str_util::Split(FLAGS_input_shape, ':'); std::vector<std::string> input_shapes = str_util::Split(FLAGS_input_shape, ':');
std::vector<std::string> output_shapes = str_util::Split(FLAGS_output_shape, ':'); std::vector<std::string> output_shapes = str_util::Split(FLAGS_output_shape, ':');
const int input_count = input_shapes.size(); const size_t input_count = input_shapes.size();
const int output_count = output_shapes.size(); const size_t output_count = output_shapes.size();
std::vector<vector<int64_t>> input_shape_vec(input_count); std::vector<vector<int64_t>> input_shape_vec(input_count);
std::vector<vector<int64_t>> output_shape_vec(output_count); std::vector<vector<int64_t>> output_shape_vec(output_count);
for (int i = 0; i < input_count; ++i) { for (size_t i = 0; i < input_count; ++i) {
ParseShape(input_shapes[i], &input_shape_vec[i]); ParseShape(input_shapes[i], &input_shape_vec[i]);
} }
for (int i = 0; i < output_count; ++i) { for (size_t i = 0; i < output_count; ++i) {
ParseShape(output_shapes[i], &output_shape_vec[i]); ParseShape(output_shapes[i], &output_shape_vec[i]);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册