diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc index 70aa42ac4111c0524a55e26aaefa864338c1d6c1..a0e83a17058a4edcb8f23f23ce155e644ae0cf3b 100644 --- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc +++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc @@ -101,23 +101,22 @@ void SplitData( } void ThreadRunInfer( - const int tid, paddle::framework::Executor* executor, - paddle::framework::Scope* scope, - const std::unique_ptr& inference_program, + const int tid, paddle::framework::Scope* scope, const std::vector>& jobs) { - auto copy_program = std::unique_ptr( - new paddle::framework::ProgramDesc(*inference_program)); + // maybe framework:ProgramDesc is not thread-safe auto& sub_scope = scope->NewScope(); + auto place = paddle::platform::CPUPlace(); + auto executor = paddle::framework::Executor(place); + auto inference_program = + paddle::inference::Load(&executor, scope, FLAGS_model_path); - std::string feed_holder_name = "feed_" + paddle::string::to_string(tid); - std::string fetch_holder_name = "fetch_" + paddle::string::to_string(tid); - copy_program->SetFeedHolderName(feed_holder_name); - copy_program->SetFetchHolderName(fetch_holder_name); + auto ctx = executor.Prepare(*inference_program, /*block_id*/ 0); + executor.CreateVariables(*inference_program, &sub_scope, /*block_id*/ 0); const std::vector& feed_target_names = - copy_program->GetFeedTargetNames(); + inference_program->GetFeedTargetNames(); const std::vector& fetch_target_names = - copy_program->GetFetchTargetNames(); + inference_program->GetFetchTargetNames(); PADDLE_ENFORCE_EQ(fetch_target_names.size(), 1UL); std::map fetch_targets; @@ -131,9 +130,8 @@ void ThreadRunInfer( auto start_ms = GetCurrentMs(); for (size_t i = 0; i < inputs.size(); ++i) { feed_targets[feed_target_names[0]] = inputs[i]; - executor->Run(*copy_program, &sub_scope, &feed_targets, &fetch_targets, - true /*create_local_scope*/, true /*create_vars*/, - feed_holder_name, fetch_holder_name); + executor.RunPreparedContext(ctx.get(), &sub_scope, &feed_targets, + &fetch_targets, false /*create_local_scope*/); } auto stop_ms = GetCurrentMs(); scope->DeleteScope(&sub_scope); @@ -158,22 +156,10 @@ TEST(inference, nlp) { LOG(INFO) << "Number of samples (seq_len<1024): " << datasets.size(); LOG(INFO) << "Total number of words: " << num_total_words; - const bool model_combined = false; // 0. Call `paddle::framework::InitDevices()` initialize all the devices - // 1. Define place, executor, scope - auto place = paddle::platform::CPUPlace(); - auto executor = paddle::framework::Executor(place); std::unique_ptr scope( new paddle::framework::Scope()); - // 2. Initialize the inference_program and load parameters - std::unique_ptr inference_program; - inference_program = - InitProgram(&executor, scope.get(), FLAGS_model_path, model_combined); - if (FLAGS_use_mkldnn) { - EnableMKLDNN(inference_program); - } - #ifdef PADDLE_WITH_MKLML // only use 1 thread number per std::thread omp_set_dynamic(0); @@ -189,21 +175,30 @@ TEST(inference, nlp) { start_ms = GetCurrentMs(); for (int i = 0; i < FLAGS_num_threads; ++i) { threads.emplace_back( - new std::thread(ThreadRunInfer, i, &executor, scope.get(), - std::ref(inference_program), std::ref(jobs))); + new std::thread(ThreadRunInfer, i, scope.get(), std::ref(jobs))); } for (int i = 0; i < FLAGS_num_threads; ++i) { threads[i]->join(); } stop_ms = GetCurrentMs(); } else { - if (FLAGS_prepare_vars) { - executor.CreateVariables(*inference_program, scope.get(), 0); + // 1. Define place, executor, scope + auto place = paddle::platform::CPUPlace(); + auto executor = paddle::framework::Executor(place); + + // 2. Initialize the inference_program and load parameters + std::unique_ptr inference_program; + inference_program = InitProgram(&executor, scope.get(), FLAGS_model_path, + /*model combined*/ false); + if (FLAGS_use_mkldnn) { + EnableMKLDNN(inference_program); } // always prepare context std::unique_ptr ctx; ctx = executor.Prepare(*inference_program, 0); - + if (FLAGS_prepare_vars) { + executor.CreateVariables(*inference_program, scope.get(), 0); + } // preapre fetch const std::vector& fetch_target_names = inference_program->GetFetchTargetNames();