提交 a41a3f77 编写于 作者: L liuqi

Add arguments checking for mace_run tool.

上级 279947d1
...@@ -99,16 +99,16 @@ DEFINE_string(model_name, ...@@ -99,16 +99,16 @@ DEFINE_string(model_name,
"", "",
"model name in yaml"); "model name in yaml");
DEFINE_string(input_node, DEFINE_string(input_node,
"input_node0,input_node1", "",
"input nodes, separated by comma"); "input nodes, separated by comma");
DEFINE_string(input_shape, DEFINE_string(input_shape,
"1,224,224,3:1,1,1,10", "",
"input shapes, separated by colon and comma"); "input shapes, separated by colon and comma");
DEFINE_string(output_node, DEFINE_string(output_node,
"output_node0,output_node1", "",
"output nodes, separated by comma"); "output nodes, separated by comma");
DEFINE_string(output_shape, DEFINE_string(output_shape,
"1,224,224,2:1,1,1,10", "",
"output shapes, separated by colon and comma"); "output shapes, separated by colon and comma");
DEFINE_string(input_data_format, DEFINE_string(input_data_format,
"NHWC", "NHWC",
...@@ -222,6 +222,10 @@ bool RunModel(const std::string &model_name, ...@@ -222,6 +222,10 @@ bool RunModel(const std::string &model_name,
// Create Engine // Create Engine
int64_t t0 = NowMicros(); int64_t t0 = NowMicros();
#ifdef MODEL_GRAPH_FORMAT_CODE #ifdef MODEL_GRAPH_FORMAT_CODE
if (model_name.empty()) {
LOG(INFO) << "Please specify model name you want to run";
return false;
}
create_engine_status = create_engine_status =
CreateMaceEngineFromCode(model_name, CreateMaceEngineFromCode(model_name,
reinterpret_cast<const unsigned char *>( reinterpret_cast<const unsigned char *>(
...@@ -233,6 +237,10 @@ bool RunModel(const std::string &model_name, ...@@ -233,6 +237,10 @@ bool RunModel(const std::string &model_name,
&engine); &engine);
#else #else
(void)(model_name); (void)(model_name);
if (model_graph_data == nullptr || model_weights_data == nullptr) {
LOG(INFO) << "Please specify model graph file and model data file";
return false;
}
create_engine_status = create_engine_status =
CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>( CreateMaceEngineFromProto(reinterpret_cast<const unsigned char *>(
model_graph_data->data()), model_graph_data->data()),
...@@ -425,11 +433,19 @@ bool RunModel(const std::string &model_name, ...@@ -425,11 +433,19 @@ bool RunModel(const std::string &model_name,
} }
int Main(int argc, char **argv) { int Main(int argc, char **argv) {
std::string usage = "mace run model\nusage: " + std::string(argv[0]) std::string usage = "MACE run model tool, please specify proper arguments.\n"
+ " [flags]"; "usage: " + std::string(argv[0])
+ " --help";
gflags::SetUsageMessage(usage); gflags::SetUsageMessage(usage);
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
std::vector<std::string> input_names = Split(FLAGS_input_node, ',');
std::vector<std::string> output_names = Split(FLAGS_output_node, ',');
if (input_names.empty() || output_names.empty()) {
LOG(INFO) << gflags::ProgramUsage();
return 0;
}
LOG(INFO) << "model name: " << FLAGS_model_name; LOG(INFO) << "model name: " << FLAGS_model_name;
LOG(INFO) << "mace version: " << MaceVersion(); LOG(INFO) << "mace version: " << MaceVersion();
LOG(INFO) << "input node: " << FLAGS_input_node; LOG(INFO) << "input node: " << FLAGS_input_node;
...@@ -448,8 +464,6 @@ int Main(int argc, char **argv) { ...@@ -448,8 +464,6 @@ int Main(int argc, char **argv) {
LOG(INFO) << "omp_num_threads: " << FLAGS_omp_num_threads; LOG(INFO) << "omp_num_threads: " << FLAGS_omp_num_threads;
LOG(INFO) << "cpu_affinity_policy: " << FLAGS_cpu_affinity_policy; LOG(INFO) << "cpu_affinity_policy: " << FLAGS_cpu_affinity_policy;
std::vector<std::string> input_names = Split(FLAGS_input_node, ',');
std::vector<std::string> output_names = Split(FLAGS_output_node, ',');
std::vector<std::string> input_shapes = Split(FLAGS_input_shape, ':'); std::vector<std::string> input_shapes = Split(FLAGS_input_shape, ':');
std::vector<std::string> output_shapes = Split(FLAGS_output_shape, ':'); std::vector<std::string> output_shapes = Split(FLAGS_output_shape, ':');
...@@ -463,6 +477,12 @@ int Main(int argc, char **argv) { ...@@ -463,6 +477,12 @@ int Main(int argc, char **argv) {
for (size_t i = 0; i < output_count; ++i) { for (size_t i = 0; i < output_count; ++i) {
ParseShape(output_shapes[i], &output_shape_vec[i]); ParseShape(output_shapes[i], &output_shape_vec[i]);
} }
if (input_names.size() != input_shape_vec.size()
|| output_names.size() != output_shape_vec.size()) {
LOG(INFO) << "inputs' names do not match inputs' shapes "
"or outputs' names do not match outputs' shapes";
return 0;
}
std::vector<std::string> raw_input_data_formats = std::vector<std::string> raw_input_data_formats =
Split(FLAGS_input_data_format, ','); Split(FLAGS_input_data_format, ',');
std::vector<std::string> raw_output_data_formats = std::vector<std::string> raw_output_data_formats =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册