提交 b630d401 编写于 作者: Q qijun

add scope

上级 d4be9730
...@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS device_context framework_proto) cc_library(executor SRCS executor.cc DEPS device_context scope framework_proto)
cc_test(executor_test SRCS executor_test.cc DEPS executor) cc_test(executor_test SRCS executor_test.cc DEPS executor)
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include <memory> #include <memory>
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -58,9 +59,10 @@ void GraphView::Initialize(const ProgramDesc*) { ...@@ -58,9 +59,10 @@ void GraphView::Initialize(const ProgramDesc*) {
class ExecutorImpl : public Executor { class ExecutorImpl : public Executor {
public: public:
ExecutorImpl(const platform::DeviceContext* ctx, const ProgramDesc* pdesc, ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx,
bool is_linear) const ProgramDesc* pdesc, bool is_linear)
: device_context_(ctx), : scope_(scope),
device_context_(ctx),
program_desc_(pdesc), program_desc_(pdesc),
view_(ProgramDescView::Create(is_linear)) {} view_(ProgramDescView::Create(is_linear)) {}
...@@ -73,6 +75,7 @@ class ExecutorImpl : public Executor { ...@@ -73,6 +75,7 @@ class ExecutorImpl : public Executor {
void Initialize(); void Initialize();
private: private:
Scope* scope_;
const platform::DeviceContext* device_context_; const platform::DeviceContext* device_context_;
const ProgramDesc* program_desc_; const ProgramDesc* program_desc_;
ProgramDescView* view_; ProgramDescView* view_;
...@@ -80,23 +83,29 @@ class ExecutorImpl : public Executor { ...@@ -80,23 +83,29 @@ class ExecutorImpl : public Executor {
template <typename T, typename... Args> template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args) { std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
} }
platform::CPUDeviceContext* GetCPUDeviceContext() { platform::CPUDeviceContext* GetCPUDeviceContext() {
static std::unique_ptr<platform::CPUDeviceContext> g_cpu_device_context = static std::unique_ptr<platform::CPUDeviceContext> g_cpu_device_context =
make_unique<platform::CPUDeviceContext>(platform::CPUPlace()); make_unique<platform::CPUDeviceContext>(platform::CPUPlace());
return g_cpu_device_context.get(); return g_cpu_device_context.get();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
platform::CUDADeviceContext* GetCUDADeviceContext() { platform::CUDADeviceContext* GetCUDADeviceContext() {
static std::unique_ptr<platform::CUDADeviceContext> g_cuda_device_context = static std::unique_ptr<platform::CUDADeviceContext> g_cuda_device_context =
make_unique<platform::CUDADeviceContext>(platform::GPUPlace(0)); make_unique<platform::CUDADeviceContext>(platform::GPUPlace(0));
return g_cuda_device_context.get(); return g_cuda_device_context.get();
} }
#endif #endif
framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>();
return g_scope.get();
}
Executor* NewLocalExecutor(const platform::Place& place, Executor* NewLocalExecutor(const platform::Place& place,
const ProgramDesc& pdesc, bool is_linear) { const ProgramDesc& pdesc, bool is_linear) {
platform::DeviceContext* device_context = nullptr; platform::DeviceContext* device_context = nullptr;
...@@ -110,11 +119,12 @@ Executor* NewLocalExecutor(const platform::Place& place, ...@@ -110,11 +119,12 @@ Executor* NewLocalExecutor(const platform::Place& place,
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
} }
#endif #endif
return new ExecutorImpl(device_context, &pdesc, is_linear); return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear);
} }
void ExecutorImpl::Run() { void ExecutorImpl::Run() {
// operators running // operators running
scope_->NewVar();
device_context_->Wait(); device_context_->Wait();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册