提交 b630d401 编写于 作者: Q qijun

add scope

上级 d4be9730
......@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS device_context framework_proto)
cc_library(executor SRCS executor.cc DEPS device_context scope framework_proto)
cc_test(executor_test SRCS executor_test.cc DEPS executor)
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/executor.h"
#include <memory>
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h"
namespace paddle {
......@@ -58,9 +59,10 @@ void GraphView::Initialize(const ProgramDesc*) {
class ExecutorImpl : public Executor {
public:
ExecutorImpl(const platform::DeviceContext* ctx, const ProgramDesc* pdesc,
bool is_linear)
: device_context_(ctx),
ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx,
const ProgramDesc* pdesc, bool is_linear)
: scope_(scope),
device_context_(ctx),
program_desc_(pdesc),
view_(ProgramDescView::Create(is_linear)) {}
......@@ -73,6 +75,7 @@ class ExecutorImpl : public Executor {
void Initialize();
private:
Scope* scope_;
const platform::DeviceContext* device_context_;
const ProgramDesc* program_desc_;
ProgramDescView* view_;
......@@ -80,23 +83,29 @@ class ExecutorImpl : public Executor {
template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
platform::CPUDeviceContext* GetCPUDeviceContext() {
static std::unique_ptr<platform::CPUDeviceContext> g_cpu_device_context =
make_unique<platform::CPUDeviceContext>(platform::CPUPlace());
make_unique<platform::CPUDeviceContext>(platform::CPUPlace());
return g_cpu_device_context.get();
}
#ifndef PADDLE_ONLY_CPU
platform::CUDADeviceContext* GetCUDADeviceContext() {
static std::unique_ptr<platform::CUDADeviceContext> g_cuda_device_context =
make_unique<platform::CUDADeviceContext>(platform::GPUPlace(0));
make_unique<platform::CUDADeviceContext>(platform::GPUPlace(0));
return g_cuda_device_context.get();
}
#endif
framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>();
return g_scope.get();
}
Executor* NewLocalExecutor(const platform::Place& place,
const ProgramDesc& pdesc, bool is_linear) {
platform::DeviceContext* device_context = nullptr;
......@@ -110,11 +119,12 @@ Executor* NewLocalExecutor(const platform::Place& place,
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
}
#endif
return new ExecutorImpl(device_context, &pdesc, is_linear);
return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear);
}
void ExecutorImpl::Run() {
// operators running
scope_->NewVar();
device_context_->Wait();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册