From 5b3cf4ee61482b44c89ac8ebe9cf656e9d6fac7c Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Wed, 3 Jan 2018 17:26:22 +0800 Subject: [PATCH] Use gflags to parse arguments from command-line. --- paddle/inference/example.cc | 41 +++++++++++++++++++++++++++-------- paddle/inference/inference.cc | 1 - 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/paddle/inference/example.cc b/paddle/inference/example.cc index 30cdd96327..9711b20e6f 100644 --- a/paddle/inference/example.cc +++ b/paddle/inference/example.cc @@ -14,16 +14,37 @@ limitations under the License. */ #include #include +#include "gflags/gflags.h" #include "paddle/inference/inference.h" -int main(int argc, char* argv[]) { - std::string dirname = - "/home/work/liuyiqun/PaddlePaddle/Paddle/paddle/inference/" - "recognize_digits_mlp.inference.model"; - std::vector feed_var_names = {"x"}; - std::vector fetch_var_names = {"fc_2.tmp_2"}; - paddle::InferenceEngine* desc = new paddle::InferenceEngine(); - desc->LoadInferenceModel(dirname, feed_var_names, fetch_var_names); +DEFINE_string(dirname, "", "Directory of the inference model."); +DEFINE_string(feed_var_names, "", "Names of feeding variables"); +DEFINE_string(fetch_var_names, "", "Names of fetching variables"); + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() || + FLAGS_fetch_var_names.empty()) { + // Example: + // ./example --dirname=recognize_digits_mlp.inference.model + // --feed_var_names="x" + // --fetch_var_names="fc_2.tmp_2" + std::cout << "Usage: ./example --dirname=path/to/your/model " + "--feed_var_names=x --fetch_var_names=y" + << std::endl; + exit(1); + } + + std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; + std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl; + std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl; + + std::string dirname = FLAGS_dirname; + std::vector feed_var_names = {FLAGS_feed_var_names}; + std::vector fetch_var_names = {FLAGS_fetch_var_names}; + + paddle::InferenceEngine* engine = new paddle::InferenceEngine(); + engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names); paddle::framework::LoDTensor input; srand(time(0)); @@ -36,7 +57,7 @@ int main(int argc, char* argv[]) { std::vector feeds; feeds.push_back(input); std::vector fetchs; - desc->Execute(feeds, fetchs); + engine->Execute(feeds, fetchs); for (size_t i = 0; i < fetchs.size(); ++i) { auto dims_i = fetchs[i].dims(); @@ -52,5 +73,7 @@ int main(int argc, char* argv[]) { } std::cout << std::endl; } + + delete engine; return 0; } diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index ebfdcd7456..48a51efcd2 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -94,7 +94,6 @@ void InferenceEngine::GenerateLoadProgram(const std::string& dirname) { if (IsParameter(var)) { LOG(INFO) << "parameter's name: " << var->Name(); - // framework::VarDesc new_var = *var; framework::VarDesc* new_var = load_block->Var(var->Name()); new_var->SetShape(var->Shape()); new_var->SetDataType(var->GetDataType()); -- GitLab