diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc
index 27bdd5528efb1413853d8461c2e748f586262d39..c942b43f174895d1bfa9688bb6d651f440b9bf41
--- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc
@@ -19,6 +19,10 @@ limitations under the License. */
 #include "paddle/fluid/inference/tests/test_helper.h"
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
+DEFINE_int32(repeat, 100, "Running the inference program repeat times");
+DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
+DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
+DEFINE_bool(prepare_context, true, "Prepare Context before executor");
 
 TEST(inference, understand_sentiment) {
   if (FLAGS_dirname.empty()) {
@@ -61,10 +65,29 @@ TEST(inference, understand_sentiment) {
   std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
   cpu_fetchs1.push_back(&output1);
 
-  int repeat = 100;
   // Run inference on CPU
-  TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds,
-                                            cpu_fetchs1, repeat);
+  const bool model_combined = false;
+  if (FLAGS_prepare_vars) {
+    if (FLAGS_prepare_context) {
+      TestInference<paddle::platform::CPUPlace, false, true>(
+          dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
+          FLAGS_use_mkldnn);
+    } else {
+      TestInference<paddle::platform::CPUPlace, false, false>(
+          dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
+          FLAGS_use_mkldnn);
+    }
+  } else {
+    if (FLAGS_prepare_context) {
+      TestInference<paddle::platform::CPUPlace, true, true>(
+          dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
+          FLAGS_use_mkldnn);
+    } else {
+      TestInference<paddle::platform::CPUPlace, true, false>(
+          dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
+          FLAGS_use_mkldnn);
+    }
+  }
 
   LOG(INFO) << output1.lod();
   LOG(INFO) << output1.dims();
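
Note (not part of the patch): the four branches differ only in the two boolean template arguments to TestInference, which select whether the executor creates the scope variables itself and whether the execution context is prepared once before the repeated runs. A sketch of the helper declaration this patch assumes follows; the template and parameter names are an assumption based on test_helper.h and are not shown in this diff.

// Assumed declaration of the benchmark helper in
// paddle/fluid/inference/tests/test_helper.h (a sketch; only the call
// sites in the patch above are authoritative).
template <typename Place, bool CreateVars = true, bool PrepareContext = false>
void TestInference(const std::string& dirname,
                   const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                   const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
                   const int repeat = 1, const bool is_combined = false,
                   const bool use_mkldnn = false);  // MKLDNN on/off per call

Under that reading, FLAGS_prepare_vars=true maps to CreateVars=false (feed/fetch variables are prepared outside the executor) and FLAGS_prepare_context maps directly to PrepareContext. At run time the new gflags are set on the test binary's command line, for example --dirname=<model dir> --repeat=10 --use_mkldnn=true --prepare_vars=false (binary name and paths are illustrative, not taken from this patch).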