From ce4d14b4ed5384dc5fb9eb4e2c6d7f1c6b9bc6dd Mon Sep 17 00:00:00 2001 From: qijun Date: Sun, 1 Oct 2017 15:08:20 -0700 Subject: [PATCH] add struct Device --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/executor.cc | 73 ++++++++++++++++++++++----------- 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 1168fc38af..129a0eb707 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) -cc_library(executor SRCS executor.cc DEPS device_context scope framework_proto) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto) cc_test(executor_test SRCS executor_test.cc DEPS executor) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 74153f2449..559cbe125f 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/framework/executor.h" #include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" #include "paddle/framework/scope.h" #include "paddle/platform/device_context.h" @@ -34,6 +36,9 @@ class ProgramDescView { class LinearListView : public ProgramDescView { public: void Initialize(const ProgramDesc*) override; + + private: + std::vector> ops_; }; class GraphView : public ProgramDescView { @@ -49,20 +54,36 @@ ProgramDescView* ProgramDescView::Create(bool is_linear) { } } -void LinearListView::Initialize(const ProgramDesc*) { +void LinearListView::Initialize(const ProgramDesc* pdesc) { // get a LinearView of ProgramDesc + for (auto& block_desc : pdesc->blocks()) { + for (auto& op_desc : block_desc.ops()) { + ops_.emplace_back(OpRegistry::CreateOp(op_desc)); + } + } } -void GraphView::Initialize(const ProgramDesc*) { +void GraphView::Initialize(const ProgramDesc* pdesc) { // get a GraphView of ProgramDesc } +struct Device { + platform::CPUDeviceContext* cpu_device_context; +#ifndef PADDLE_ONLY_CPU + Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu) + : cpu_device_context(cpu), cuda_device_context(gpu) {} + platform::CDUADeviceContext* cuda_device_context; +#else + explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {} +#endif +}; + class ExecutorImpl : public Executor { public: - ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx, - const ProgramDesc* pdesc, bool is_linear) + ExecutorImpl(Scope* scope, const Device* device, const ProgramDesc* pdesc, + bool is_linear) : scope_(scope), - device_context_(ctx), + device_(device), program_desc_(pdesc), view_(ProgramDescView::Create(is_linear)) {} @@ -76,7 +97,7 @@ class ExecutorImpl : public Executor { private: Scope* scope_; - const platform::DeviceContext* device_context_; + const Device* device_; const ProgramDesc* program_desc_; ProgramDescView* view_; }; @@ -86,20 +107,36 @@ std::unique_ptr make_unique(Args&&... args) { return std::unique_ptr(new T(std::forward(args)...)); } -platform::CPUDeviceContext* GetCPUDeviceContext(platform::CPUPlace& place) { +platform::CPUDeviceContext* GetCPUDeviceContext( + const platform::CPUPlace& place) { static std::unique_ptr g_cpu_device_context = make_unique(place); return g_cpu_device_context.get(); } #ifndef PADDLE_ONLY_CPU -platform::CUDADeviceContext* GetCUDADeviceContext(platform::GPUPlace& place) { +platform::CUDADeviceContext* GetCUDADeviceContext( + const platform::GPUPlace& place) { static std::unique_ptr g_cuda_device_context = make_unique(place); return g_cuda_device_context.get(); } #endif +Device* GetDevice(const platform::Place& place) { + platform::CPUPlace cpu_place; +#ifndef PADDLE_ONLY_CPU + platform::GPUPlace gpu_place = boost::get(place); + static std::unique_ptr g_device = make_unique( + GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place)); + return g_device.get(); +#else + static std::unique_ptr g_device = + make_unique(GetCPUDeviceContext(cpu_place)); + return g_device.get(); +#endif +} + framework::Scope* GetScope() { static std::unique_ptr g_scope = make_unique(); @@ -108,26 +145,16 @@ framework::Scope* GetScope() { Executor* NewLocalExecutor(const platform::Place& place, const ProgramDesc& pdesc, bool is_linear) { - platform::DeviceContext* device_context = nullptr; - if (platform::is_cpu_place(place)) { - auto cpu_place = boost::get(place); - device_context = GetCPUDeviceContext(cpu_place); - } else if (platform::is_gpu_place(place)) { -#ifndef PADDLE_ONLY_CPU - auto gpu_place = boost::get(place); - device_context = GetCUDADeviceContext(gpu_place); - } -#else - PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); - } -#endif - return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear); + return new ExecutorImpl(GetScope(), GetDevice(place), &pdesc, is_linear); } void ExecutorImpl::Run() { // operators running scope_->NewVar(); - device_context_->Wait(); + device_->cpu_device_context->Wait(); +#ifndef PADDLE_ONLY_CPU + device_->cuda_device_context->Wait(); +#endif } void ExecutorImpl::Initialize() { -- GitLab