提交 fbd3604c 编写于 作者: L Liu Yiqun

Split Executor.Run to Executor.Prepare and Executor.RunPreparedContext for inference.

上级 172c887d
......@@ -129,6 +129,7 @@ static bool has_feed_operators(
feed_count, feed_targets.size(),
"The number of feed operators should match 'feed_targets'");
if (!feed_holder_name.empty()) {
// When feed operator are present, so should be feed_holder
auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
......@@ -137,6 +138,7 @@ static bool has_feed_operators(
"'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name);
}
}
return feed_count > 0;
}
......@@ -169,6 +171,7 @@ static bool has_fetch_operators(
fetch_count, fetch_targets.size(),
"The number of fetch operators should match 'fetch_targets'");
if (!fetch_holder_name.empty()) {
// When fetch operator are present, so should be fetch_holder
auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
......@@ -177,6 +180,7 @@ static bool has_fetch_operators(
"'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name);
}
}
return fetch_count > 0;
}
......@@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
// map the data of feed_targets to feed_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
}
}
if (!has_fetch_ops) {
// create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name);
......@@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
Run(*copy_program, scope, 0, create_vars, create_vars);
// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
auto ctx = Prepare(*copy_program, 0);
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
feed_holder_name, fetch_holder_name, create_vars);
}
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
......@@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
}
void Executor::RunPreparedContext(
ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name, const std::string& fetch_holder_name,
bool create_vars) {
auto& global_block = ctx->prog_.Block(ctx->block_id_);
// map the data of feed_targets to feed_holder
for (auto* op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(),
"Variable %s is not feeded.");
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
}
}
RunPreparedContext(ctx, scope, create_vars, create_vars);
// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
PADDLE_ENFORCE(
fetch_targets.find(fetch_target_name) != fetch_targets.end(),
"Variable %s is not fetched.");
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
}
} // namespace framework
} // namespace paddle
......@@ -65,6 +65,13 @@ class Executor {
bool create_local_scope = true,
bool create_vars = true);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
private:
const platform::Place place_;
};
......
......@@ -48,7 +48,7 @@ TEST(inference, image_classification) {
// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace>(
TestInference<paddle::platform::CPUPlace, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();
......@@ -59,7 +59,7 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace>(
TestInference<paddle::platform::CUDAPlace, true>(
dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims();
......
......@@ -88,7 +88,7 @@ void CheckError(paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
template <typename Place>
template <typename Place, bool PrepareContext = false>
void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
......@@ -170,7 +170,14 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
// Ignore the profiling results of the first run
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
if (PrepareContext) {
ctx = executor.Prepare(*inference_program, 0);
executor.RunPreparedContext(
ctx.get(), scope, feed_targets, fetch_targets);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}
// Enable the profiler
paddle::platform::EnableProfiler(state);
......@@ -181,8 +188,15 @@ void TestInference(const std::string& dirname,
"run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place));
if (PrepareContext) {
// Note: if you changed the inference_program, you need to call
// executor.Prepare() again to get a new ExecutorPrepareContext.
executor.RunPreparedContext(
ctx.get(), scope, feed_targets, fetch_targets);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册