diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 3c35102ff997114ebc95ba55dea3006dfcdf5581..69c21d745779c26fe8df583651168a9a1fdf640b 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -32,7 +32,68 @@ namespace framework { const std::string kFeedOpType = "feed"; const std::string kFetchOpType = "fetch"; -std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id) { +Executor::Executor(const std::vector<platform::Place>& places) { + PADDLE_ENFORCE_GT(places.size(), 0); + device_contexts_.resize(places.size()); + for (size_t i = 0; i < places.size(); i++) { + if (platform::is_cpu_place(places[i])) { + device_contexts_[i] = new platform::CPUDeviceContext( + boost::get<platform::CPUPlace>(places[i])); + } else if (platform::is_gpu_place(places[i])) { +#ifdef PADDLE_WITH_CUDA + device_contexts_[i] = new platform::CUDADeviceContext( + boost::get<platform::GPUPlace>(places[i])); +#else + PADDLE_THROW( + "'GPUPlace' is not supported, Please re-compile with WITH_GPU " + "option"); +#endif + } + } +} + +Executor::~Executor() { + for (auto& device_context : device_contexts_) { + delete device_context; + } +} + +void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { + // TODO(tonyyang-svail): + // - only runs on the first device (i.e. 
no interdevice communication) + // - will change to use multiple blocks for RNN op and Cond Op + PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id); + auto& block = pdesc.blocks(block_id); + auto& device = device_contexts_[0]; + + // Instantiate all the vars in the global scope + for (auto& var : block.vars()) { + scope->NewVar(var.name()); + } + + Scope& local_scope = scope->NewScope(); + + std::vector<bool> should_run = Prune(pdesc, block_id); + PADDLE_ENFORCE_EQ(should_run.size(), static_cast<size_t>(block.ops_size())); + for (size_t i = 0; i < should_run.size(); ++i) { + if (should_run[i]) { + for (auto& var : block.ops(i).outputs()) { + for (auto& argu : var.arguments()) { + if (local_scope.FindVar(argu) == nullptr) { + local_scope.NewVar(argu); + } + } + } + auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i)); + op->Run(local_scope, *device); + } + } + + // TODO(tonyyang-svail): + // - Destroy local_scope +} + +std::vector<bool> Executor::Prune(const ProgramDesc& pdesc, int block_id) { // TODO(tonyyang-svail): // - will change to use multiple blocks for RNN op and Cond Op diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 1cd72702400aebd1fd472be1a3dd28c589ad6431..137e53d849542e48080228e0002931867c4d7fb2 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -66,7 +66,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, template <typename T> void SetFeedVariable(const std::vector<std::vector<T>>& inputs, const std::vector<std::vector<int64_t>>& dims) { - Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value"); + Variable* g_feed_value = GetGlobalScope().FindVar("feed_value"); auto& feed_inputs = *(g_feed_value->GetMutable<std::vector<Tensor>>()); size_t size = inputs.size(); @@ -81,7 +81,7 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs, // So we can memcpy the data from fetch_value to vector<T> template <typename T> std::vector<std::vector<T>> GetFetchVariable() { - Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value"); + Variable* g_fetch_value = 
GetGlobalScope().FindVar("fetch_value"); auto& fetch_outputs = *(g_fetch_value->GetMutable<std::vector<Tensor>>()); @@ -231,8 +231,9 @@ TEST_F(ExecutorTesterRandom, CPU) { std::unique_ptr<Executor> executor(new Executor(places)); - executor->Run(init_pdesc_, GetGlobalScope(), 0); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + SetFeedVariable<float>(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector<std::vector<float>> result = GetFetchVariable<float>(); } @@ -251,7 +252,7 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) { for (int batch_id = 0; batch_id < 3; batch_id++) { SetFeedVariable<float>(inputs_, dims_); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector<std::vector<float>> result = GetFetchVariable<float>(); PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); for (size_t i = 0; i < result.size(); ++i) { @@ -279,10 +280,10 @@ TEST_F(ExecutorTesterRandom, GPU) { std::unique_ptr<Executor> executor(new Executor(places)); - executor->Run(init_pdesc_, GetGlobalScope(), 0); + executor->Run(init_pdesc_, &GetGlobalScope(), 0); for (int batch_id = 0; batch_id < 3; batch_id++) { SetFeedVariable<float>(inputs_, dims_); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(pdesc_, &GetGlobalScope(), 0); } } @@ -303,7 +304,7 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) { for (int batch_id = 0; batch_id < 3; batch_id++) { SetFeedVariable<float>(inputs_, dims_); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector<std::vector<float>> result = GetFetchVariable<float>(); PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); for (size_t i = 0; i < result.size(); ++i) { diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index c9e53a0d850f49882ad3a8d0ba7ff92684af2208..5821bac928ed898971d61a3e2a86f59155d76991 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -67,14 +67,14 @@ void Scope::DropKids() { std::once_flag feed_variable_flag; -framework::Scope* GetGlobalScope() { +framework::Scope& GetGlobalScope() { 
static std::unique_ptr<framework::Scope> g_scope{nullptr}; std::call_once(feed_variable_flag, [&]() { g_scope.reset(new framework::Scope()); g_scope->NewVar("feed_value"); g_scope->NewVar("fetch_value"); }); - return g_scope.get(); + return *(g_scope.get()); } } // namespace framework diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 319d291efec95d463f085cb03b1c06c0a637d8d9..a8cfb107c25ccd62039db7349cc1c1dbff772f39 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -73,7 +73,7 @@ class Scope { DISABLE_COPY_AND_ASSIGN(Scope); }; -framework::Scope* GetGlobalScope(); +framework::Scope& GetGlobalScope(); } // namespace framework } // namespace paddle diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h index e406d22209dc5597f232918634634f3cf3b44e4a..9d8158299fea97a464a7bb64321b1adf8b7b2fab 100644 --- a/paddle/operators/feed_op.h +++ b/paddle/operators/feed_op.h @@ -26,7 +26,7 @@ class FeedKernel : public framework::OpKernel { framework::Tensor* out = ctx.Output<framework::Tensor>("Out"); out->mutable_data<T>(ctx.GetPlace()); framework::Variable* g_feed_variable = - framework::GetGlobalScope()->FindVar("feed_value"); + framework::GetGlobalScope().FindVar("feed_value"); const auto& tensors = g_feed_variable->Get<std::vector<framework::Tensor>>(); int col = ctx.template Attr<int>("col"); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h index 6fee8b05892687d06eb1d3f7c92f0df92a8a63e6..eb9c3a7b593b84da7c8dc12d71c4f748269c64e6 100644 --- a/paddle/operators/fetch_op.h +++ b/paddle/operators/fetch_op.h @@ -25,7 +25,7 @@ class FetchKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor* input = ctx.Input<framework::Tensor>("Input"); framework::Variable* g_fetch_variable = - framework::GetGlobalScope()->FindVar("fetch_value"); + framework::GetGlobalScope().FindVar("fetch_value"); auto* tensors = g_fetch_variable->GetMutable<std::vector<framework::Tensor>>(); int col = ctx.template Attr<int>("col");