From fbd3604cad8fdb3ad7fa2f6717395b1c40e6ecaf Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Tue, 3 Apr 2018 05:31:52 +0000
Subject: [PATCH] Split Executor.Run to Executor.Prepare and
 Executor.RunPreparedContext for inference.

---
 paddle/fluid/framework/executor.cc           | 94 ++++++++++++-------
 paddle/fluid/framework/executor.h            |  7 ++
 .../test_inference_image_classification.cc   |  4 +-
 paddle/fluid/inference/tests/test_helper.h   | 20 +++-
 4 files changed, 85 insertions(+), 40 deletions(-)

diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 64c06687b6b..009d0fbeb88 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -129,13 +129,15 @@ static bool has_feed_operators(
         feed_count, feed_targets.size(),
         "The number of feed operators should match 'feed_targets'");
 
-    // When feed operator are present, so should be feed_holder
-    auto var = block.FindVar(feed_holder_name);
-    PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
-                            feed_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
-                      "'%s' variable should be 'FEED_MINIBATCH' type",
-                      feed_holder_name);
+    if (!feed_holder_name.empty()) {
+      // When feed operator are present, so should be feed_holder
+      auto var = block.FindVar(feed_holder_name);
+      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
+                              feed_holder_name);
+      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
+                        "'%s' variable should be 'FEED_MINIBATCH' type",
+                        feed_holder_name);
+    }
   }
   return feed_count > 0;
 }
@@ -169,13 +171,15 @@ static bool has_fetch_operators(
         fetch_count, fetch_targets.size(),
         "The number of fetch operators should match 'fetch_targets'");
 
-    // When fetch operator are present, so should be fetch_holder
-    auto var = block.FindVar(fetch_holder_name);
-    PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
-                            fetch_holder_name);
-    PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
-                      "'%s' variable should be 'FETCH_LIST' type",
-                      fetch_holder_name);
+    if (!fetch_holder_name.empty()) {
+      // When fetch operator are present, so should be fetch_holder
+      auto var = block.FindVar(fetch_holder_name);
+      PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
+                              fetch_holder_name);
+      PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
+                        "'%s' variable should be 'FETCH_LIST' type",
+                        fetch_holder_name);
+    }
   }
   return fetch_count > 0;
 }
@@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }
 
-  // map the data of feed_targets to feed_holder
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == kFeedOpType) {
-      std::string feed_target_name = op->Output("Out")[0];
-      int idx = boost::get<int>(op->GetAttr("col"));
-      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
-                      idx);
-    }
-  }
-
   if (!has_fetch_ops) {
     // create fetch_holder variable
     auto* fetch_holder = global_block->Var(fetch_holder_name);
@@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }
 
-  Run(*copy_program, scope, 0, create_vars, create_vars);
-
-  // obtain the data of fetch_targets from fetch_holder
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == kFetchOpType) {
-      std::string fetch_target_name = op->Input("X")[0];
-      int idx = boost::get<int>(op->GetAttr("col"));
-      *fetch_targets[fetch_target_name] =
-          GetFetchVariable(*scope, fetch_holder_name, idx);
-    }
-  }
+  auto ctx = Prepare(*copy_program, 0);
+  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
+                     feed_holder_name, fetch_holder_name, create_vars);
 }
 
 std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
@@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }
 }
 
+void Executor::RunPreparedContext(
+    ExecutorPrepareContext* ctx, Scope* scope,
+    std::map<std::string, const LoDTensor*>& feed_targets,
+    std::map<std::string, LoDTensor*>& fetch_targets,
+    const std::string& feed_holder_name, const std::string& fetch_holder_name,
+    bool create_vars) {
+  auto& global_block = ctx->prog_.Block(ctx->block_id_);
+
+  // map the data of feed_targets to feed_holder
+  for (auto* op : global_block.AllOps()) {
+    if (op->Type() == kFeedOpType) {
+      std::string feed_target_name = op->Output("Out")[0];
+      PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(),
+                     "Variable %s is not fed.", feed_target_name);
+
+      int idx = boost::get<int>(op->GetAttr("col"));
+      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
+                      idx);
+    }
+  }
+
+  RunPreparedContext(ctx, scope, create_vars, create_vars);
+
+  // obtain the data of fetch_targets from fetch_holder
+  for (auto* op : global_block.AllOps()) {
+    if (op->Type() == kFetchOpType) {
+      std::string fetch_target_name = op->Input("X")[0];
+      PADDLE_ENFORCE(
+          fetch_targets.find(fetch_target_name) != fetch_targets.end(),
+          "Variable %s is not fetched.", fetch_target_name);
+
+      int idx = boost::get<int>(op->GetAttr("col"));
+      *fetch_targets[fetch_target_name] =
+          GetFetchVariable(*scope, fetch_holder_name, idx);
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 7173c51c95e..b0e64d5de0c 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -65,6 +65,13 @@ class Executor {
                           bool create_local_scope = true,
                           bool create_vars = true);
 
+  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                          std::map<std::string, const LoDTensor*>& feed_targets,
+                          std::map<std::string, LoDTensor*>& fetch_targets,
+                          const std::string& feed_holder_name = "feed",
+                          const std::string& fetch_holder_name = "fetch",
+                          bool create_vars = true);
+
  private:
   const platform::Place place_;
 };
diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index e9a27171f1c..9126efb8c2e 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -48,7 +48,7 @@ TEST(inference, image_classification) {
 
   // Run inference on CPU
   LOG(INFO) << "--- CPU Runs: ---";
-  TestInference<paddle::platform::CPUPlace>(
+  TestInference<paddle::platform::CPUPlace, false>(
       dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
   LOG(INFO) << output1.dims();
 
@@ -59,7 +59,7 @@ TEST(inference, image_classification) {
 
   // Run inference on CUDA GPU
   LOG(INFO) << "--- GPU Runs: ---";
-  TestInference<paddle::platform::CUDAPlace>(
+  TestInference<paddle::platform::CUDAPlace, true>(
       dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
   LOG(INFO) << output2.dims();
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index dce541c0971..d559cc7d038 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -88,7 +88,7 @@ void CheckError(paddle::framework::LoDTensor& output1,
   EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
 }
 
-template <typename Place>
+template <typename Place, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
@@ -170,7 +170,14 @@ void TestInference(const std::string& dirname,
   // 6. Run the inference program
   {
     // Ignore the profiling results of the first run
-    executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
+    if (PrepareContext) {
+      ctx = executor.Prepare(*inference_program, 0);
+      executor.RunPreparedContext(
+          ctx.get(), scope, feed_targets, fetch_targets);
+    } else {
+      executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    }
 
     // Enable the profiler
     paddle::platform::EnableProfiler(state);
@@ -181,7 +188,14 @@ void TestInference(const std::string& dirname,
         "run_inference",
         paddle::platform::DeviceContextPool::Instance().Get(place));
 
-    executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    if (PrepareContext) {
+      // Note: if you changed the inference_program, you need to call
+      // executor.Prepare() again to get a new ExecutorPrepareContext.
+      executor.RunPreparedContext(
+          ctx.get(), scope, feed_targets, fetch_targets);
+    } else {
+      executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    }
   }
 
   // Disable the profiler and print the timing information
-- 
GitLab
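
Usage sketch (not part of the patch): the snippet below mirrors the PrepareContext=true branch of TestInference and shows how a caller might drive the new two-step API, preparing the program once and then feeding, running, and fetching through RunPreparedContext. The Load() helper, the GetFeedTargetNames()/GetFetchTargetNames() accessors, and all variable names are assumptions based on how test_helper.h uses the executor, not code introduced by this commit.

    #include <map>
    #include <memory>
    #include <string>

    #include "paddle/fluid/framework/executor.h"
    #include "paddle/fluid/framework/lod_tensor.h"
    #include "paddle/fluid/framework/scope.h"
    #include "paddle/fluid/inference/io.h"
    #include "paddle/fluid/platform/place.h"

    // Runs one forward pass with the split Prepare/RunPreparedContext API.
    // `dirname`, `input`, and `output` are illustrative; error handling omitted.
    void RunInferenceOnce(const std::string& dirname,
                          const paddle::framework::LoDTensor& input,
                          paddle::framework::LoDTensor* output) {
      paddle::platform::CPUPlace place;
      paddle::framework::Executor executor(place);
      auto* scope = new paddle::framework::Scope();

      // Load the saved inference program (helper assumed from
      // paddle/fluid/inference/io.h, as used by test_helper.h).
      std::unique_ptr<paddle::framework::ProgramDesc> program =
          paddle::inference::Load(&executor, scope, dirname);

      // Prepare once; the context can be reused across repeated runs as long
      // as the program is not modified (see the note added in test_helper.h).
      std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx =
          executor.Prepare(*program, 0 /* block_id */);

      // Map the program's feed/fetch target names to user-provided tensors.
      std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
      std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
      feed_targets[program->GetFeedTargetNames()[0]] = &input;
      fetch_targets[program->GetFetchTargetNames()[0]] = output;

      // Feed, run the prepared block, and fetch, using the default
      // feed_holder_name ("feed") and fetch_holder_name ("fetch").
      executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets);

      delete scope;
    }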