Commit ce4d14b4 authored by qijun

add struct Device

Parent 09500917
paddle/framework/CMakeLists.txt

@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
 cc_library(backward SRCS backward.cc DEPS net_op)
 cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
-cc_library(executor SRCS executor.cc DEPS device_context scope framework_proto)
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto)
 cc_test(executor_test SRCS executor_test.cc DEPS executor)
paddle/framework/executor.cc

@@ -14,6 +14,8 @@ limitations under the License. */
 #include "paddle/framework/executor.h"
 #include <memory>
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
 #include "paddle/platform/device_context.h"
@@ -34,6 +36,9 @@ class ProgramDescView {
 class LinearListView : public ProgramDescView {
  public:
   void Initialize(const ProgramDesc*) override;
+
+ private:
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
 };
 
 class GraphView : public ProgramDescView {
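Note: the new `ops_` member makes `LinearListView` the owner of the operator instances it creates; the `op_registry` dependency added to the CMake target above exists for exactly this reason. The next hunk populates it by walking every block of the `ProgramDesc` in program order. As a hedged sketch of where this is headed (not part of this commit), a linear view could then execute the flattened list directly; this assumes `OperatorBase` exposes `Run(const Scope&, const platform::DeviceContext&)`, the operator interface used elsewhere in the codebase at this time:

```cpp
// Hypothetical sketch, not in this commit: run a flattened op list in
// program order. RunAll is an illustrative name, not framework API.
void RunAll(const std::vector<std::unique_ptr<OperatorBase>>& ops,
            const Scope& scope, const platform::DeviceContext& ctx) {
  for (auto& op : ops) {
    op->Run(scope, ctx);  // each op reads/writes variables in the scope
  }
}
```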
@@ -49,20 +54,36 @@ ProgramDescView* ProgramDescView::Create(bool is_linear) {
   }
 }
 
-void LinearListView::Initialize(const ProgramDesc*) {
+void LinearListView::Initialize(const ProgramDesc* pdesc) {
   // get a LinearView of ProgramDesc
+  for (auto& block_desc : pdesc->blocks()) {
+    for (auto& op_desc : block_desc.ops()) {
+      ops_.emplace_back(OpRegistry::CreateOp(op_desc));
+    }
+  }
 }
 
-void GraphView::Initialize(const ProgramDesc*) {
+void GraphView::Initialize(const ProgramDesc* pdesc) {
   // get a GraphView of ProgramDesc
 }
 
+struct Device {
+  platform::CPUDeviceContext* cpu_device_context;
+#ifndef PADDLE_ONLY_CPU
+  Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu)
+      : cpu_device_context(cpu), cuda_device_context(gpu) {}
+  platform::CUDADeviceContext* cuda_device_context;
+#else
+  explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {}
+#endif
+};
+
 class ExecutorImpl : public Executor {
  public:
-  ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx,
-               const ProgramDesc* pdesc, bool is_linear)
+  ExecutorImpl(Scope* scope, const Device* device, const ProgramDesc* pdesc,
+               bool is_linear)
       : scope_(scope),
-        device_context_(ctx),
+        device_(device),
         program_desc_(pdesc),
         view_(ProgramDescView::Create(is_linear)) {}
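Note: `struct Device` bundles one device context per place into a single handle, so `ExecutorImpl` no longer commits to a single `platform::DeviceContext*`. The `#ifndef PADDLE_ONLY_CPU` guard gives the struct two shapes: with CUDA it carries both contexts and a two-argument constructor; in a CPU-only build the CUDA member and constructor disappear entirely. A hedged usage sketch, using the `Get*DeviceContext` helpers defined later in this diff (`GPUPlace(0)` assumes the integer device-id constructor on `platform::GPUPlace`):

```cpp
// Sketch: how Device is constructed under each build mode.
void DeviceCtorDemo() {
#ifndef PADDLE_ONLY_CPU
  platform::CPUPlace cpu_place;
  platform::GPUPlace gpu_place(0);  // assumes GPUPlace's int device-id ctor
  Device d(GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
#else
  platform::CPUPlace cpu_place;
  Device d(GetCPUDeviceContext(cpu_place));
#endif
  (void)d;  // silence unused-variable warnings in this illustration
}
```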
@@ -76,7 +97,7 @@ class ExecutorImpl : public Executor {
 
  private:
   Scope* scope_;
-  const platform::DeviceContext* device_context_;
+  const Device* device_;
   const ProgramDesc* program_desc_;
   ProgramDescView* view_;
 };
@@ -86,20 +107,36 @@ std::unique_ptr<T> make_unique(Args&&... args) {
   return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
 }
 
-platform::CPUDeviceContext* GetCPUDeviceContext(platform::CPUPlace& place) {
+platform::CPUDeviceContext* GetCPUDeviceContext(
+    const platform::CPUPlace& place) {
   static std::unique_ptr<platform::CPUDeviceContext> g_cpu_device_context =
       make_unique<platform::CPUDeviceContext>(place);
   return g_cpu_device_context.get();
 }
 
 #ifndef PADDLE_ONLY_CPU
-platform::CUDADeviceContext* GetCUDADeviceContext(platform::GPUPlace& place) {
+platform::CUDADeviceContext* GetCUDADeviceContext(
+    const platform::GPUPlace& place) {
   static std::unique_ptr<platform::CUDADeviceContext> g_cuda_device_context =
       make_unique<platform::CUDADeviceContext>(place);
   return g_cuda_device_context.get();
 }
 #endif
 
+Device* GetDevice(const platform::Place& place) {
+  platform::CPUPlace cpu_place;
+#ifndef PADDLE_ONLY_CPU
+  platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
+  static std::unique_ptr<Device> g_device = make_unique<Device>(
+      GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
+  return g_device.get();
+#else
+  static std::unique_ptr<Device> g_device =
+      make_unique<Device>(GetCPUDeviceContext(cpu_place));
+  return g_device.get();
+#endif
+}
+
 framework::Scope* GetScope() {
   static std::unique_ptr<framework::Scope> g_scope =
       make_unique<framework::Scope>();
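Note: all of these getters use the function-local-static pattern: the `unique_ptr` is initialized exactly once, on the first call (thread-safe under C++11 static-initialization rules), and every later call returns the cached pointer. A consequence worth flagging is that the `place` argument only matters on the first call and is ignored afterwards. Likewise, `GetDevice` in a GPU build unconditionally does `boost::get<platform::GPUPlace>(place)`, which throws `boost::bad_get` if the caller passes a `CPUPlace`. A small sketch of the caching behavior, inferred from the code above rather than documented API:

```cpp
#include <cassert>

// Inferred caching behavior: the first call fixes the device, and a later
// call with a different place still returns the originally built context.
void StaticCacheDemo() {
  platform::GPUPlace gpu0(0);
  platform::GPUPlace gpu1(1);
  auto* a = GetCUDADeviceContext(gpu0);  // builds the context for device 0
  auto* b = GetCUDADeviceContext(gpu1);  // cached: still the device-0 context
  assert(a == b);
}
```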
@@ -108,26 +145,16 @@ framework::Scope* GetScope() {
 
 Executor* NewLocalExecutor(const platform::Place& place,
                            const ProgramDesc& pdesc, bool is_linear) {
-  platform::DeviceContext* device_context = nullptr;
-  if (platform::is_cpu_place(place)) {
-    auto cpu_place = boost::get<platform::CPUPlace>(place);
-    device_context = GetCPUDeviceContext(cpu_place);
-  } else if (platform::is_gpu_place(place)) {
-#ifndef PADDLE_ONLY_CPU
-    auto gpu_place = boost::get<platform::GPUPlace>(place);
-    device_context = GetCUDADeviceContext(gpu_place);
-  }
-#else
-    PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
-  }
-#endif
-  return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear);
+  return new ExecutorImpl(GetScope(), GetDevice(place), &pdesc, is_linear);
 }
 
 void ExecutorImpl::Run() {
   // operators running
   scope_->NewVar();
-  device_context_->Wait();
+  device_->cpu_device_context->Wait();
+#ifndef PADDLE_ONLY_CPU
+  device_->cuda_device_context->Wait();
+#endif
 }
 
 void ExecutorImpl::Initialize() {
......
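Note: after this change `NewLocalExecutor` collapses to a single line, because the place dispatch (and the `PADDLE_THROW` for requesting a GPU in a CPU-only build) now lives behind `GetDevice`, and `Run` waits on the CPU context plus, in GPU builds, the CUDA context. A hedged end-to-end usage sketch; the namespace placement and the caller-owns-the-pointer assumption are inferences, not guarantees from this diff:

```cpp
#include "paddle/framework/executor.h"

// Sketch: build a local executor for a program and run it once on CPU.
void Demo(const paddle::framework::ProgramDesc& pdesc) {
  paddle::platform::CPUPlace place;
  paddle::framework::Executor* exec =
      paddle::framework::NewLocalExecutor(place, pdesc, /*is_linear=*/true);
  exec->Run();   // creates a variable in the scope, then waits on the device
  delete exec;   // assumption: the raw pointer is owned by the caller
}
```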