提交 ce4d14b4 编写于 作者: Q qijun

add struct Device

上级 09500917
......@@ -44,5 +44,5 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS device_context scope framework_proto)
# executor.cc includes op_registry.h and calls OpRegistry::CreateOp, hence the
# op_registry dependency on the executor target.
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto)
cc_test(executor_test SRCS executor_test.cc DEPS executor)
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/framework/executor.h"
#include <memory>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h"
......@@ -34,6 +36,9 @@ class ProgramDescView {
// A ProgramDescView that keeps the program's operators in a flat list.
// NOTE(review): std::vector is used below but <vector> is not among the
// includes visible in this file excerpt — presumably it arrives transitively;
// confirm.
class LinearListView : public ProgramDescView {
public:
// Overrides ProgramDescView::Initialize. The definition in this file creates
// one operator per op description of every block and stores it in ops_.
void Initialize(const ProgramDesc*) override;
private:
// Operators held by this view through unique_ptr; filled by Initialize().
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
class GraphView : public ProgramDescView {
......@@ -49,20 +54,36 @@ ProgramDescView* ProgramDescView::Create(bool is_linear) {
}
}
void LinearListView::Initialize(const ProgramDesc*) {
// Builds the linear view of `pdesc`: visits every block in turn and, for each
// operator description in that block, asks OpRegistry to create the operator
// and appends the result to ops_.
void LinearListView::Initialize(const ProgramDesc* pdesc) {
  const auto& blocks = pdesc->blocks();
  for (const auto& block : blocks) {
    const auto& descriptions = block.ops();
    for (const auto& description : descriptions) {
      ops_.emplace_back(OpRegistry::CreateOp(description));
    }
  }
}
void GraphView::Initialize(const ProgramDesc*) {
// Placeholder for building a graph view of the program: the body is empty and
// `pdesc` is not used yet.
void GraphView::Initialize(const ProgramDesc* pdesc) {
// get a GraphView of ProgramDesc
}
// Bundles the device contexts an executor works with.
//
// The struct only stores the pointers it is given; it defines no destructor,
// so it never frees them. In CPU-only builds (PADDLE_ONLY_CPU defined) it
// carries just the CPU context; otherwise it carries a CPU and a CUDA context.
struct Device {
  platform::CPUDeviceContext* cpu_device_context;
#ifndef PADDLE_ONLY_CPU
  // Declared before the constructor that initialises it, and spelled
  // CUDADeviceContext to match the constructor's parameter type.
  platform::CUDADeviceContext* cuda_device_context;

  // cpu: context used for CPU work; gpu: context used for CUDA work.
  Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu)
      : cpu_device_context(cpu), cuda_device_context(gpu) {}
#else
  // cpu: context used for CPU work.
  explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {}
#endif
};
class ExecutorImpl : public Executor {
public:
ExecutorImpl(Scope* scope, const platform::DeviceContext* ctx,
const ProgramDesc* pdesc, bool is_linear)
ExecutorImpl(Scope* scope, const Device* device, const ProgramDesc* pdesc,
bool is_linear)
: scope_(scope),
device_context_(ctx),
device_(device),
program_desc_(pdesc),
view_(ProgramDescView::Create(is_linear)) {}
......@@ -76,7 +97,7 @@ class ExecutorImpl : public Executor {
private:
Scope* scope_;
const platform::DeviceContext* device_context_;
const Device* device_;
const ProgramDesc* program_desc_;
ProgramDescView* view_;
};
......@@ -86,20 +107,36 @@ std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
platform::CPUDeviceContext* GetCPUDeviceContext(platform::CPUPlace& place) {
// Returns the shared CPU device context.
//
// The context lives in a function-local static, so it is constructed from
// `place` on the first call only; every later call returns the same pointer
// regardless of the argument. The returned pointer must not be deleted by the
// caller because the static unique_ptr keeps ownership.
platform::CPUDeviceContext* GetCPUDeviceContext(
    const platform::CPUPlace& place) {
  using Context = platform::CPUDeviceContext;
  static std::unique_ptr<Context> instance = make_unique<Context>(place);
  return instance.get();
}
#ifndef PADDLE_ONLY_CPU
platform::CUDADeviceContext* GetCUDADeviceContext(platform::GPUPlace& place) {
// Returns the shared CUDA device context.
//
// The context lives in a function-local static, so it is constructed from
// `place` on the first call only; later calls return that same context even
// when they pass a different `place`. The static unique_ptr keeps ownership of
// the returned pointer.
platform::CUDADeviceContext* GetCUDADeviceContext(
    const platform::GPUPlace& place) {
  using Context = platform::CUDADeviceContext;
  static std::unique_ptr<Context> instance = make_unique<Context>(place);
  return instance.get();
}
#endif
// Returns the shared Device for `place`.
//
// The Device is a function-local static: it is built on the first successful
// call and the same pointer is returned afterwards, so the caller must not
// delete it.
//
// Throws (via PADDLE_THROW) when `place` cannot be served by this build:
//   * GPU build: `place` must hold a GPUPlace, because the Device is built
//     with a CUDA context created from it.
//   * CPU-only build: a GPUPlace is rejected instead of silently running on
//     the CPU.
Device* GetDevice(const platform::Place& place) {
  platform::CPUPlace cpu_place;
#ifndef PADDLE_ONLY_CPU
  // boost::get<GPUPlace> throws boost::bad_get when `place` holds a different
  // alternative; check first so the failure carries a readable message.
  if (!platform::is_gpu_place(place)) {
    PADDLE_THROW("GetDevice requires a 'GPUPlace' when built with GPU support.");
  }
  platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
  static std::unique_ptr<Device> g_device = make_unique<Device>(
      GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
  return g_device.get();
#else
  if (platform::is_gpu_place(place)) {
    PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
  }
  static std::unique_ptr<Device> g_device =
      make_unique<Device>(GetCPUDeviceContext(cpu_place));
  return g_device.get();
#endif
}
framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>();
......@@ -108,26 +145,16 @@ framework::Scope* GetScope() {
// Creates an executor for `pdesc` on `place`, using the scope returned by
// GetScope() and the device returned by GetDevice(place).
//
// The result is allocated with `new` and handed back as a raw pointer. The
// address of `pdesc` is passed to ExecutorImpl, which stores the pointer
// rather than a copy.
//
// NOTE(review): this listing appears to interleave the previous body (the
// device_context selection and the first return) with its replacement (the
// final return) — confirm which lines are live before relying on it.
Executor* NewLocalExecutor(const platform::Place& place,
const ProgramDesc& pdesc, bool is_linear) {
platform::DeviceContext* device_context = nullptr;
if (platform::is_cpu_place(place)) {
auto cpu_place = boost::get<platform::CPUPlace>(place);
device_context = GetCPUDeviceContext(cpu_place);
} else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_ONLY_CPU
auto gpu_place = boost::get<platform::GPUPlace>(place);
device_context = GetCUDADeviceContext(gpu_place);
}
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
}
#endif
return new ExecutorImpl(GetScope(), device_context, &pdesc, is_linear);
return new ExecutorImpl(GetScope(), GetDevice(place), &pdesc, is_linear);
}
// Runs the executor. This body calls scope_->NewVar() and then waits on the
// device context(s); no operator is invoked here.
void ExecutorImpl::Run() {
// operators running
scope_->NewVar();
// NOTE(review): device_context_ looks like the pre-change member that
// device_ replaces — confirm whether this line is still live.
device_context_->Wait();
// Wait on the CPU context, and additionally on the CUDA context unless this
// is a CPU-only build. Both pointers are dereferenced without a null check.
device_->cpu_device_context->Wait();
#ifndef PADDLE_ONLY_CPU
device_->cuda_device_context->Wait();
#endif
}
void ExecutorImpl::Initialize() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册