Commit fbd3604c authored by Liu Yiqun

Split Executor::Run into Executor::Prepare and Executor::RunPreparedContext for inference.

Parent 172c887d
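For context, a minimal usage sketch of the new two-step API, based only on the signatures introduced in this diff; the setup of inference_program, scope, feed_targets, and fetch_targets is assumed to be as in the inference tests below:

// A sketch, not part of the commit: prepare once, then run the cached
// context repeatedly while the program stays unchanged.
paddle::framework::Executor executor(paddle::platform::CPUPlace());

// Prepare() builds the operators of block 0 once and caches them in an
// ExecutorPrepareContext, so repeated runs skip that per-call cost.
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx =
    executor.Prepare(*inference_program, 0);

for (int i = 0; i < FLAGS_repeat; ++i) {
  executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets);
}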
@@ -129,6 +129,7 @@ static bool has_feed_operators(
       feed_count, feed_targets.size(),
       "The number of feed operators should match 'feed_targets'");

+  if (!feed_holder_name.empty()) {
     // When feed operator are present, so should be feed_holder
     auto var = block.FindVar(feed_holder_name);
     PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
@@ -137,6 +138,7 @@ static bool has_feed_operators(
                       "'%s' variable should be 'FEED_MINIBATCH' type",
                       feed_holder_name);
   }
+  }

   return feed_count > 0;
 }
@@ -169,6 +171,7 @@ static bool has_fetch_operators(
       fetch_count, fetch_targets.size(),
       "The number of fetch operators should match 'fetch_targets'");

+  if (!fetch_holder_name.empty()) {
     // When fetch operator are present, so should be fetch_holder
     auto var = block.FindVar(fetch_holder_name);
     PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
@@ -177,6 +180,7 @@ static bool has_fetch_operators(
                       "'%s' variable should be 'FETCH_LIST' type",
                       fetch_holder_name);
   }
+  }

   return fetch_count > 0;
 }
@@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }

-  // map the data of feed_targets to feed_holder
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == kFeedOpType) {
-      std::string feed_target_name = op->Output("Out")[0];
-      int idx = boost::get<int>(op->GetAttr("col"));
-      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
-                      idx);
-    }
-  }
-
   if (!has_fetch_ops) {
     // create fetch_holder variable
     auto* fetch_holder = global_block->Var(fetch_holder_name);
@@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
     }
   }

-  Run(*copy_program, scope, 0, create_vars, create_vars);
-
-  // obtain the data of fetch_targets from fetch_holder
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == kFetchOpType) {
-      std::string fetch_target_name = op->Input("X")[0];
-      int idx = boost::get<int>(op->GetAttr("col"));
-      *fetch_targets[fetch_target_name] =
-          GetFetchVariable(*scope, fetch_holder_name, idx);
-    }
-  }
+  auto ctx = Prepare(*copy_program, 0);
+  RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
+                     feed_holder_name, fetch_holder_name, create_vars);
 }

 std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
@@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }
 }

+void Executor::RunPreparedContext(
+    ExecutorPrepareContext* ctx, Scope* scope,
+    std::map<std::string, const LoDTensor*>& feed_targets,
+    std::map<std::string, LoDTensor*>& fetch_targets,
+    const std::string& feed_holder_name, const std::string& fetch_holder_name,
+    bool create_vars) {
+  auto& global_block = ctx->prog_.Block(ctx->block_id_);
+
+  // map the data of feed_targets to feed_holder
+  for (auto* op : global_block.AllOps()) {
+    if (op->Type() == kFeedOpType) {
+      std::string feed_target_name = op->Output("Out")[0];
+      PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(),
+                     "Variable %s is not fed.", feed_target_name);
+      int idx = boost::get<int>(op->GetAttr("col"));
+      SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
+                      idx);
+    }
+  }
+
+  RunPreparedContext(ctx, scope, create_vars, create_vars);
+
+  // obtain the data of fetch_targets from fetch_holder
+  for (auto* op : global_block.AllOps()) {
+    if (op->Type() == kFetchOpType) {
+      std::string fetch_target_name = op->Input("X")[0];
+      PADDLE_ENFORCE(
+          fetch_targets.find(fetch_target_name) != fetch_targets.end(),
+          "Variable %s is not fetched.", fetch_target_name);
+      int idx = boost::get<int>(op->GetAttr("col"));
+      *fetch_targets[fetch_target_name] =
+          GetFetchVariable(*scope, fetch_holder_name, idx);
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
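The new overload leaves it to the caller to build the name-to-tensor maps that drive the feed/fetch loops above. A hedged sketch of how they might be constructed; the variable names "image" and "fc_out" are illustrative, not taken from this diff:

// Illustrative only: the map keys must match the targets of the program's
// feed/fetch operators, or the PADDLE_ENFORCE checks above will fire.
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;

paddle::framework::LoDTensor input;   // filled by the caller beforehand
paddle::framework::LoDTensor output;  // overwritten via GetFetchVariable()
feed_targets["image"] = &input;
fetch_targets["fc_out"] = &output;

// feed_holder_name and fetch_holder_name default to "feed" and "fetch"
// (see the header change below), so they can usually be omitted.
executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets);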
@@ -65,6 +65,13 @@ class Executor {
            bool create_local_scope = true,
            bool create_vars = true);

+  void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
+                          std::map<std::string, const LoDTensor*>& feed_targets,
+                          std::map<std::string, LoDTensor*>& fetch_targets,
+                          const std::string& feed_holder_name = "feed",
+                          const std::string& fetch_holder_name = "fetch",
+                          bool create_vars = true);
+
  private:
   const platform::Place place_;
 };
@@ -48,7 +48,7 @@ TEST(inference, image_classification) {
   // Run inference on CPU
   LOG(INFO) << "--- CPU Runs: ---";
-  TestInference<paddle::platform::CPUPlace>(
+  TestInference<paddle::platform::CPUPlace, true>(
       dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
   LOG(INFO) << output1.dims();
@@ -59,7 +59,7 @@ TEST(inference, image_classification) {
   // Run inference on CUDA GPU
   LOG(INFO) << "--- GPU Runs: ---";
-  TestInference<paddle::platform::CUDAPlace>(
+  TestInference<paddle::platform::CUDAPlace, true>(
       dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
   LOG(INFO) << output2.dims();
@@ -88,7 +88,7 @@ void CheckError(paddle::framework::LoDTensor& output1,
   EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
 }

-template <typename Place>
+template <typename Place, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                    std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
@@ -170,7 +170,14 @@ void TestInference(const std::string& dirname,
   // 6. Run the inference program
   {
     // Ignore the profiling results of the first run
-    executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
+    if (PrepareContext) {
+      ctx = executor.Prepare(*inference_program, 0);
+      executor.RunPreparedContext(ctx.get(), scope, feed_targets,
+                                  fetch_targets);
+    } else {
+      executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    }

     // Enable the profiler
     paddle::platform::EnableProfiler(state);
@@ -181,8 +188,15 @@ void TestInference(const std::string& dirname,
         "run_inference",
         paddle::platform::DeviceContextPool::Instance().Get(place));

+    if (PrepareContext) {
+      // Note: if you changed the inference_program, you need to call
+      // executor.Prepare() again to get a new ExecutorPrepareContext.
+      executor.RunPreparedContext(ctx.get(), scope, feed_targets,
+                                  fetch_targets);
+    } else {
       executor.Run(*inference_program, scope, feed_targets, fetch_targets);
+    }
   }

   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(
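The new PrepareContext template parameter selects the execution path at compile time, so existing tests keep the old Executor::Run() behavior by default. A sketch of the two call forms, with arguments as in the image-classification test above:

// Default (PrepareContext = false): exercises the original Executor::Run().
TestInference<paddle::platform::CPUPlace>(
    dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);

// PrepareContext = true: exercises Executor::Prepare() followed by
// Executor::RunPreparedContext() on the cached context.
TestInference<paddle::platform::CPUPlace, true>(
    dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);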