提交 5b3cf4ee 编写于 作者: L Liu Yiqun

Use gflags to parse arguments from the command line.

上级 42a0603e
......@@ -14,16 +14,37 @@ limitations under the License. */
#include <time.h>
#include <iostream>
#include "gflags/gflags.h"
#include "paddle/inference/inference.h"
int main(int argc, char* argv[]) {
std::string dirname =
"/home/work/liuyiqun/PaddlePaddle/Paddle/paddle/inference/"
"recognize_digits_mlp.inference.model";
std::vector<std::string> feed_var_names = {"x"};
std::vector<std::string> fetch_var_names = {"fc_2.tmp_2"};
paddle::InferenceEngine* desc = new paddle::InferenceEngine();
desc->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(feed_var_names, "", "Names of feeding variables");
DEFINE_string(fetch_var_names, "", "Names of fetching variables");
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
FLAGS_fetch_var_names.empty()) {
// Example:
// ./example --dirname=recognize_digits_mlp.inference.model
// --feed_var_names="x"
// --fetch_var_names="fc_2.tmp_2"
std::cout << "Usage: ./example --dirname=path/to/your/model "
"--feed_var_names=x --fetch_var_names=y"
<< std::endl;
exit(1);
}
std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;
std::string dirname = FLAGS_dirname;
std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};
paddle::InferenceEngine* engine = new paddle::InferenceEngine();
engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
paddle::framework::LoDTensor input;
srand(time(0));
......@@ -36,7 +57,7 @@ int main(int argc, char* argv[]) {
std::vector<paddle::framework::LoDTensor> feeds;
feeds.push_back(input);
std::vector<paddle::framework::LoDTensor> fetchs;
desc->Execute(feeds, fetchs);
engine->Execute(feeds, fetchs);
for (size_t i = 0; i < fetchs.size(); ++i) {
auto dims_i = fetchs[i].dims();
......@@ -52,5 +73,7 @@ int main(int argc, char* argv[]) {
}
std::cout << std::endl;
}
delete engine;
return 0;
}
......@@ -94,7 +94,6 @@ void InferenceEngine::GenerateLoadProgram(const std::string& dirname) {
if (IsParameter(var)) {
LOG(INFO) << "parameter's name: " << var->Name();
// framework::VarDesc new_var = *var;
framework::VarDesc* new_var = load_block->Var(var->Name());
new_var->SetShape(var->Shape());
new_var->SetDataType(var->GetDataType());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册