提交 b630d401 编写于 作者: Q qijun

add scope

上级 d4be9730
...@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS device_context framework_proto) cc_library(executor SRCS executor.cc DEPS device_context scope framework_proto)
cc_test(executor_test SRCS executor_test.cc DEPS executor) cc_test(executor_test SRCS executor_test.cc DEPS executor)
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include <memory> #include <memory>
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -58,9 +59,10 @@ void GraphView::Initialize(const ProgramDesc*) { ...@@ -58,9 +59,10 @@ void GraphView::Initialize(const ProgramDesc*) {
class ExecutorImpl : public Executor { class ExecutorImpl : public Executor {
public: public:
ExecutorImpl(const platform::DeviceContext* ctx, const ProgramDesc* pdesc, ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx,
bool is_linear) const ProgramDesc* pdesc, bool is_linear)
: device_context_(ctx), : scope_(scope),
device_context_(ctx),
program_desc_(pdesc), program_desc_(pdesc),
view_(ProgramDescView::Create(is_linear)) {} view_(ProgramDescView::Create(is_linear)) {}
...@@ -73,6 +75,7 @@ class ExecutorImpl : public Executor { ...@@ -73,6 +75,7 @@ class ExecutorImpl : public Executor {
void Initialize(); void Initialize();
private: private:
Scope* scope_;
const platform::DeviceContext* device_context_; const platform::DeviceContext* device_context_;
const ProgramDesc* program_desc_; const ProgramDesc* program_desc_;
ProgramDescView* view_; ProgramDescView* view_;
...@@ -97,6 +100,12 @@ platform::CUDADeviceContext* GetCUDADeviceContext() { ...@@ -97,6 +100,12 @@ platform::CUDADeviceContext* GetCUDADeviceContext() {
} }
#endif #endif
framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>();
return g_scope.get();
}
Executor* NewLocalExecutor(const platform::Place& place, Executor* NewLocalExecutor(const platform::Place& place,
const ProgramDesc& pdesc, bool is_linear) { const ProgramDesc& pdesc, bool is_linear) {
platform::DeviceContext* device_context = nullptr; platform::DeviceContext* device_context = nullptr;
...@@ -110,11 +119,12 @@ Executor* NewLocalExecutor(const platform::Place& place, ...@@ -110,11 +119,12 @@ Executor* NewLocalExecutor(const platform::Place& place,
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
} }
#endif #endif
return new ExecutorImpl(device_context, &pdesc, is_linear); return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear);
} }
void ExecutorImpl::Run() { void ExecutorImpl::Run() {
// operators running // operators running
scope_->NewVar();
device_context_->Wait(); device_context_->Wait();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册