提交 436ea50d 编写于 作者: Q qijun

follow comments

上级 062ff4d7
......@@ -44,7 +44,9 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i]));
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
......
......@@ -67,7 +67,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
template <typename T>
void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
const std::vector<std::vector<int64_t>>& dims) {
Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
Variable* g_feed_value = GetGlobalScope().FindVar("feed_value");
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
size_t size = inputs.size();
......@@ -82,7 +82,7 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
// So we can memcpy the data from fetch_value to vector<T>
template <typename T>
std::vector<std::vector<T>> GetFetchVariable() {
Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value");
auto& fetch_outputs =
*(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
......@@ -232,8 +232,9 @@ TEST_F(ExecutorTesterRandom, CPU) {
std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, GetGlobalScope(), 0);
executor->Run(init_pdesc_, &GetGlobalScope(), 0);
SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>();
}
......@@ -252,7 +253,7 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
for (int batch_id = 0; batch_id < 3; batch_id++) {
SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) {
......@@ -280,10 +281,10 @@ TEST_F(ExecutorTesterRandom, GPU) {
std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope(), 0);
executor->Run(init_pdesc_, &GetGlobalScope(), 0);
for (int batch_id = 0; batch_id < 3; batch_id++) {
SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, &GetGlobalScope(), 0);
}
}
......@@ -304,7 +305,7 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
for (int batch_id = 0; batch_id < 3; batch_id++) {
SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, GetGlobalScope(), 0);
executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) {
......
......@@ -67,14 +67,14 @@ void Scope::DropKids() {
std::once_flag feed_variable_flag;
framework::Scope* GetGlobalScope() {
framework::Scope& GetGlobalScope() {
static std::unique_ptr<framework::Scope> g_scope{nullptr};
std::call_once(feed_variable_flag, [&]() {
g_scope.reset(new framework::Scope());
g_scope->NewVar("feed_value");
g_scope->NewVar("fetch_value");
});
return g_scope.get();
return *(g_scope.get());
}
} // namespace framework
......
......@@ -73,7 +73,7 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope);
};
framework::Scope* GetGlobalScope();
framework::Scope& GetGlobalScope();
} // namespace framework
} // namespace paddle
......@@ -26,7 +26,7 @@ class FeedKernel : public framework::OpKernel<T> {
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value");
framework::GetGlobalScope().FindVar("feed_value");
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
......
......@@ -25,7 +25,7 @@ class FetchKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value");
framework::GetGlobalScope().FindVar("fetch_value");
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册