From 436ea50d5fc8867848892fc53b7f82aa59ae3b41 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 10 Oct 2017 14:31:47 -0700 Subject: [PATCH] follow comments --- paddle/framework/executor.cc | 4 +++- paddle/framework/executor_test.cc | 17 +++++++++-------- paddle/framework/scope.cc | 4 ++-- paddle/framework/scope.h | 2 +- paddle/operators/feed_op.h | 2 +- paddle/operators/fetch_op.h | 2 +- 6 files changed, 17 insertions(+), 14 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index def1d1fd06..1db5c878d6 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -44,7 +44,9 @@ Executor::Executor(const std::vector& places) { device_contexts_[i] = new platform::CUDADeviceContext( boost::get(places[i])); #else - PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); + PADDLE_THROW( + "'GPUPlace' is not supported, Please re-compile with WITH_GPU " + "option"); #endif } } diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 5ad5b98e7b..f36284b528 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -67,7 +67,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, template void SetFeedVariable(const std::vector>& inputs, const std::vector>& dims) { - Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value"); + Variable* g_feed_value = GetGlobalScope().FindVar("feed_value"); auto& feed_inputs = *(g_feed_value->GetMutable>()); size_t size = inputs.size(); @@ -82,7 +82,7 @@ void SetFeedVariable(const std::vector>& inputs, // So we can memcpy the data from fetch_value to vector template std::vector> GetFetchVariable() { - Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value"); + Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value"); auto& fetch_outputs = *(g_fetch_value->GetMutable>()); @@ -232,8 +232,9 @@ TEST_F(ExecutorTesterRandom, CPU) { std::unique_ptr executor(new Executor(places)); - executor->Run(init_pdesc_, GetGlobalScope(), 0); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector> result = GetFetchVariable(); } @@ -252,7 +253,7 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) { for (int batch_id = 0; batch_id < 3; batch_id++) { SetFeedVariable(inputs_, dims_); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector> result = GetFetchVariable(); PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); for (size_t i = 0; i < result.size(); ++i) { @@ -280,10 +281,10 @@ TEST_F(ExecutorTesterRandom, GPU) { std::unique_ptr executor(new Executor(places)); - executor->Run(init_pdesc_, GetGlobalScope(), 0); + executor->Run(init_pdesc_, &GetGlobalScope(), 0); for (int batch_id = 0; batch_id < 3; batch_id++) { SetFeedVariable(inputs_, dims_); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(pdesc_, &GetGlobalScope(), 0); } } @@ -304,7 +305,7 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) { for (int batch_id = 0; batch_id < 3; batch_id++) { SetFeedVariable(inputs_, dims_); - executor->Run(pdesc_, GetGlobalScope(), 0); + executor->Run(pdesc_, &GetGlobalScope(), 0); std::vector> result = GetFetchVariable(); PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); for (size_t i = 0; i < result.size(); ++i) { diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index c9e53a0d85..5821bac928 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -67,14 +67,14 @@ void Scope::DropKids() { std::once_flag feed_variable_flag; -framework::Scope* GetGlobalScope() { +framework::Scope& GetGlobalScope() { static std::unique_ptr g_scope{nullptr}; std::call_once(feed_variable_flag, [&]() { g_scope.reset(new framework::Scope()); g_scope->NewVar("feed_value"); g_scope->NewVar("fetch_value"); }); - return g_scope.get(); + return *(g_scope.get()); } } // namespace framework diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 319d291efe..a8cfb107c2 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -73,7 +73,7 @@ class Scope { DISABLE_COPY_AND_ASSIGN(Scope); }; -framework::Scope* GetGlobalScope(); +framework::Scope& GetGlobalScope(); } // namespace framework } // namespace paddle diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h index e406d22209..9d8158299f 100644 --- a/paddle/operators/feed_op.h +++ b/paddle/operators/feed_op.h @@ -26,7 +26,7 @@ class FeedKernel : public framework::OpKernel { framework::Tensor* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); framework::Variable* g_feed_variable = - framework::GetGlobalScope()->FindVar("feed_value"); + framework::GetGlobalScope().FindVar("feed_value"); const auto& tensors = g_feed_variable->Get>(); int col = ctx.template Attr("col"); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h index 6fee8b0589..eb9c3a7b59 100644 --- a/paddle/operators/fetch_op.h +++ b/paddle/operators/fetch_op.h @@ -25,7 +25,7 @@ class FetchKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { const framework::Tensor* input = ctx.Input("Input"); framework::Variable* g_fetch_variable = - framework::GetGlobalScope()->FindVar("fetch_value"); + framework::GetGlobalScope().FindVar("fetch_value"); auto* tensors = g_fetch_variable->GetMutable>(); int col = ctx.template Attr("col"); -- GitLab