提交 a41a3f77 编写于 作者: L liuqi

Add arguments checking for mace_run tool.

上级 279947d1
......@@ -99,16 +99,16 @@ DEFINE_string(model_name,
"",
"model name in yaml");
DEFINE_string(input_node,
"input_node0,input_node1",
"",
"input nodes, separated by comma");
DEFINE_string(input_shape,
"1,224,224,3:1,1,1,10",
"",
"input shapes, separated by colon and comma");
DEFINE_string(output_node,
"output_node0,output_node1",
"",
"output nodes, separated by comma");
DEFINE_string(output_shape,
"1,224,224,2:1,1,1,10",
"",
"output shapes, separated by colon and comma");
DEFINE_string(input_data_format,
"NHWC",
......@@ -222,6 +222,10 @@ bool RunModel(const std::string &model_name,
// Create Engine
int64_t t0 = NowMicros();
#ifdef MODEL_GRAPH_FORMAT_CODE
if (model_name.empty()) {
LOG(INFO) << "Please specify model name you want to run";
return false;
}
create_engine_status =
CreateMaceEngineFromCode(model_name,
reinterpret_cast<const unsigned char *>(
......@@ -233,6 +237,10 @@ bool RunModel(const std::string &model_name,
&engine);
#else
(void)(model_name);
if (model_graph_data == nullptr || model_weights_data == nullptr) {
LOG(INFO) << "Please specify model graph file and model data file";
return false;
}
create_engine_status =
CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>(
model_graph_data->data()),
......@@ -425,11 +433,19 @@ bool RunModel(const std::string &model_name,
}
int Main(int argc, char **argv) {
std::string usage = "mace run model\nusage: " + std::string(argv[0])
+ " [flags]";
std::string usage = "MACE run model tool, please specify proper arguments.\n"
"usage: " + std::string(argv[0])
+ " --help";
gflags::SetUsageMessage(usage);
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::vector<std::string> input_names = Split(FLAGS_input_node, ',');
std::vector<std::string> output_names = Split(FLAGS_output_node, ',');
if (input_names.empty() || output_names.empty()) {
LOG(INFO) << gflags::ProgramUsage();
return 0;
}
LOG(INFO) << "model name: " << FLAGS_model_name;
LOG(INFO) << "mace version: " << MaceVersion();
LOG(INFO) << "input node: " << FLAGS_input_node;
......@@ -448,8 +464,6 @@ int Main(int argc, char **argv) {
LOG(INFO) << "omp_num_threads: " << FLAGS_omp_num_threads;
LOG(INFO) << "cpu_affinity_policy: " << FLAGS_cpu_affinity_policy;
std::vector<std::string> input_names = Split(FLAGS_input_node, ',');
std::vector<std::string> output_names = Split(FLAGS_output_node, ',');
std::vector<std::string> input_shapes = Split(FLAGS_input_shape, ':');
std::vector<std::string> output_shapes = Split(FLAGS_output_shape, ':');
......@@ -463,6 +477,12 @@ int Main(int argc, char **argv) {
for (size_t i = 0; i < output_count; ++i) {
ParseShape(output_shapes[i], &output_shape_vec[i]);
}
if (input_names.size() != input_shape_vec.size()
|| output_names.size() != output_shape_vec.size()) {
LOG(INFO) << "inputs' names do not match inputs' shapes "
"or outputs' names do not match outputs' shapes";
return 0;
}
std::vector<std::string> raw_input_data_formats =
Split(FLAGS_input_data_format, ',');
std::vector<std::string> raw_output_data_formats =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册