提交 436ea50d 编写于 作者: Q qijun

follow comments

上级 062ff4d7
@@ -44,7 +44,9 @@ Executor::Executor(const std::vector<platform::Place>& places) {
      device_contexts_[i] = new platform::CUDADeviceContext(
          boost::get<platform::GPUPlace>(places[i]));
#else
      PADDLE_THROW(
          "'GPUPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    }
  }
......
@@ -67,7 +67,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
template <typename T>
void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
                     const std::vector<std::vector<int64_t>>& dims) {
  Variable* g_feed_value = GetGlobalScope().FindVar("feed_value");
  auto& feed_inputs =
      *(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
  size_t size = inputs.size();
@@ -82,7 +82,7 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
// So we can memcpy the data from fetch_value to vector<T>
template <typename T>
std::vector<std::vector<T>> GetFetchVariable() {
  Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value");
  auto& fetch_outputs =
      *(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
@@ -232,8 +232,9 @@ TEST_F(ExecutorTesterRandom, CPU) {
  std::unique_ptr<Executor> executor(new Executor(places));

  executor->Run(init_pdesc_, &GetGlobalScope(), 0);
  SetFeedVariable<float>(inputs_, dims_);
  executor->Run(pdesc_, &GetGlobalScope(), 0);
  std::vector<std::vector<float>> result = GetFetchVariable<float>();
}
@@ -252,7 +253,7 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
  for (int batch_id = 0; batch_id < 3; batch_id++) {
    SetFeedVariable<float>(inputs_, dims_);
    executor->Run(pdesc_, &GetGlobalScope(), 0);
    std::vector<std::vector<float>> result = GetFetchVariable<float>();
    PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
    for (size_t i = 0; i < result.size(); ++i) {
@@ -280,10 +281,10 @@ TEST_F(ExecutorTesterRandom, GPU) {
  std::unique_ptr<Executor> executor(new Executor(places));

  executor->Run(init_pdesc_, &GetGlobalScope(), 0);
  for (int batch_id = 0; batch_id < 3; batch_id++) {
    SetFeedVariable<float>(inputs_, dims_);
    executor->Run(pdesc_, &GetGlobalScope(), 0);
  }
}
@@ -304,7 +305,7 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
  for (int batch_id = 0; batch_id < 3; batch_id++) {
    SetFeedVariable<float>(inputs_, dims_);
    executor->Run(pdesc_, &GetGlobalScope(), 0);
    std::vector<std::vector<float>> result = GetFetchVariable<float>();
    PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
    for (size_t i = 0; i < result.size(); ++i) {
......
@@ -67,14 +67,14 @@ void Scope::DropKids() {
std::once_flag feed_variable_flag;

framework::Scope& GetGlobalScope() {
  static std::unique_ptr<framework::Scope> g_scope{nullptr};
  std::call_once(feed_variable_flag, [&]() {
    g_scope.reset(new framework::Scope());
    g_scope->NewVar("feed_value");
    g_scope->NewVar("fetch_value");
  });
  return *(g_scope.get());
}

}  // namespace framework
......
@@ -73,7 +73,7 @@ class Scope {
  DISABLE_COPY_AND_ASSIGN(Scope);
};

framework::Scope& GetGlobalScope();

}  // namespace framework
}  // namespace paddle
@@ -26,7 +26,7 @@ class FeedKernel : public framework::OpKernel<T> {
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    framework::Variable* g_feed_variable =
        framework::GetGlobalScope().FindVar("feed_value");
    const auto& tensors =
        g_feed_variable->Get<std::vector<framework::Tensor>>();
    int col = ctx.template Attr<int>("col");
......
@@ -25,7 +25,7 @@ class FetchKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
    framework::Variable* g_fetch_variable =
        framework::GetGlobalScope().FindVar("fetch_value");
    auto* tensors =
        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
    int col = ctx.template Attr<int>("col");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册