diff --git a/mace/dsp/BUILD b/mace/dsp/BUILD index 8deaa99dc6ba867cbf5511bdeec70f7cfd7a7a8e..cbe9f834a40dba000bf7da74320dae422bdf99e1 100644 --- a/mace/dsp/BUILD +++ b/mace/dsp/BUILD @@ -64,3 +64,21 @@ cc_test( "//mace/kernels:kernels", ], ) + +cc_binary( + name = "mace_dsp_run", + srcs = [ + "tool/mace_dsp_run.cc", + ], + copts = ["-std=c++11", "-D_GLIBCXX_USE_C99_MATH_TR1"], + linkopts = if_android([ + "-ldl", + "-lm", + ]), + linkstatic = 1, + deps = [ + ":dsp", + "//mace/kernels:kernels", + "//mace/utils:command_line_flags", + ], +) \ No newline at end of file diff --git a/mace/dsp/tool/mace_dsp_run.cc b/mace/dsp/tool/mace_dsp_run.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c8e7afae7acfccec8418b4b63da75d0a6d47af4 --- /dev/null +++ b/mace/dsp/tool/mace_dsp_run.cc @@ -0,0 +1,109 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +/** + * Usage: + * mace_dsp_run --model=mobi_mace.pb \ + * --input_shape=1,3,224,224 \ + * --input_file=input_data \ + * --output_file=mace.out + */ +#include +#include +#include "mace/dsp/hexagon_control_wrapper.h" +#include "mace/core/net.h" +#include "mace/utils/command_line_flags.h" + +using namespace std; +using namespace mace; + +void ParseShape(const string &str, vector *shape) { + string tmp = str; + while (!tmp.empty()) { + int dim = atoi(tmp.data()); + shape->push_back(dim); + size_t next_offset = tmp.find(","); + if (next_offset == string::npos) { + break; + } else { + tmp = tmp.substr(next_offset + 1); + } + } +} + +int main(int argc, char **argv) { + string model_file; + string input_shape; + string input_file; + string output_file; + int round = 1; + + std::vector flag_list = { + Flag("model", &model_file, "model file name"), + Flag("input_shape", &input_shape, "input shape, separated by comma"), + Flag("input_file", &input_file, "input file name"), + Flag("output_file", &output_file, "output file name"), + Flag("round", &round, "round"), + }; + + string usage = Flags::Usage(argv[0], flag_list); + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + + if (!parse_result) { + LOG(ERROR) << usage; + return -1; + } + + VLOG(0) << "model: " << model_file << std::endl + << "input_shape: " << input_shape << std::endl + << "input_file: " << input_file << std::endl + << "output_file: " << output_file << std::endl + << "round: " << round << std::endl; + + vector shape; + ParseShape(input_shape, &shape); + + // load input + Tensor input_tensor; + input_tensor.Resize(shape); + float *input_data = input_tensor.mutable_data(); + ifstream in_file(input_file, ios::in | ios::binary); + in_file.read(reinterpret_cast(input_data), + input_tensor.size() * sizeof(float)); + in_file.close(); + + // execute + HexagonControlWrapper wrapper; + VLOG(0) << "version: " << wrapper.GetVersion(); + wrapper.Init(); + wrapper.SetDebugLevel(0); + wrapper.Config(); + VLOG(0) << wrapper.SetupGraph(model_file); + wrapper.PrintGraph(); + + Tensor output_tensor; + timeval tv1, tv2; + gettimeofday(&tv1, NULL); + for (int i = 0; i < round; ++i) { + VLOG(0) << wrapper.ExecuteGraph(input_tensor, &output_tensor); + } + gettimeofday(&tv2, NULL); + cout << "avg duration: " + << ((tv2.tv_sec - tv1.tv_sec) * 1000 + + (tv2.tv_usec - tv1.tv_usec) / 1000) / + round + << endl; + + wrapper.GetPerfInfo(); + wrapper.PrintLog(); + VLOG(0) << wrapper.TeardownGraph(); + wrapper.Finalize(); + + // save output + ofstream out_file(output_file, ios::binary); + out_file.write((const char *) (output_tensor.data()), + output_tensor.size() * sizeof(float)); + out_file.flush(); + out_file.close(); +} \ No newline at end of file