diff --git a/mace/tools/validation/mace_run.cc b/mace/tools/validation/mace_run.cc index fca4a0fd42958110130e6317274b32a600106ab3..ba25cced37aaeaff60e5e10d13b4add7580fd389 100644 --- a/mace/tools/validation/mace_run.cc +++ b/mace/tools/validation/mace_run.cc @@ -99,16 +99,16 @@ DEFINE_string(model_name, "", "model name in yaml"); DEFINE_string(input_node, - "input_node0,input_node1", + "", "input nodes, separated by comma"); DEFINE_string(input_shape, - "1,224,224,3:1,1,1,10", + "", "input shapes, separated by colon and comma"); DEFINE_string(output_node, - "output_node0,output_node1", + "", "output nodes, separated by comma"); DEFINE_string(output_shape, - "1,224,224,2:1,1,1,10", + "", "output shapes, separated by colon and comma"); DEFINE_string(input_data_format, "NHWC", @@ -222,6 +222,10 @@ bool RunModel(const std::string &model_name, // Create Engine int64_t t0 = NowMicros(); #ifdef MODEL_GRAPH_FORMAT_CODE + if (model_name.empty()) { + LOG(INFO) << "Please specify model name you want to run"; + return false; + } create_engine_status = CreateMaceEngineFromCode(model_name, reinterpret_cast( @@ -233,6 +237,10 @@ bool RunModel(const std::string &model_name, &engine); #else (void)(model_name); + if (model_graph_data == nullptr || model_weights_data == nullptr) { + LOG(INFO) << "Please specify model graph file and model data file"; + return false; + } create_engine_status = CreateMaceEngineFromProto(reinterpret_cast( model_graph_data->data()), @@ -425,11 +433,19 @@ bool RunModel(const std::string &model_name, } int Main(int argc, char **argv) { - std::string usage = "mace run model\nusage: " + std::string(argv[0]) - + " [flags]"; + std::string usage = "MACE run model tool, please specify proper arguments.\n" + "usage: " + std::string(argv[0]) + + " --help"; gflags::SetUsageMessage(usage); gflags::ParseCommandLineFlags(&argc, &argv, true); + std::vector input_names = Split(FLAGS_input_node, ','); + std::vector output_names = Split(FLAGS_output_node, ','); + if (input_names.empty() || output_names.empty()) { + LOG(INFO) << gflags::ProgramUsage(); + return 0; + } + LOG(INFO) << "model name: " << FLAGS_model_name; LOG(INFO) << "mace version: " << MaceVersion(); LOG(INFO) << "input node: " << FLAGS_input_node; @@ -448,8 +464,6 @@ int Main(int argc, char **argv) { LOG(INFO) << "omp_num_threads: " << FLAGS_omp_num_threads; LOG(INFO) << "cpu_affinity_policy: " << FLAGS_cpu_affinity_policy; - std::vector input_names = Split(FLAGS_input_node, ','); - std::vector output_names = Split(FLAGS_output_node, ','); std::vector input_shapes = Split(FLAGS_input_shape, ':'); std::vector output_shapes = Split(FLAGS_output_shape, ':'); @@ -463,6 +477,12 @@ int Main(int argc, char **argv) { for (size_t i = 0; i < output_count; ++i) { ParseShape(output_shapes[i], &output_shape_vec[i]); } + if (input_names.size() != input_shape_vec.size() + || output_names.size() != output_shape_vec.size()) { + LOG(INFO) << "inputs' names do not match inputs' shapes " + "or outputs' names do not match outputs' shapes"; + return 0; + } std::vector raw_input_data_formats = Split(FLAGS_input_data_format, ','); std::vector raw_output_data_formats =