提交 f29a6b02 编写于 作者: Q qijun

fix gpu build error

上级 f1c5d9e7
...@@ -69,10 +69,13 @@ void GraphView::Initialize(const ProgramDesc* pdesc) { ...@@ -69,10 +69,13 @@ void GraphView::Initialize(const ProgramDesc* pdesc) {
struct Device { struct Device {
platform::CPUDeviceContext* cpu_device_context; platform::CPUDeviceContext* cpu_device_context;
#ifndef PADDLE_ONLY_CPU
platform::CUDADeviceContext* cuda_device_context;
#endif
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu) Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu)
: cpu_device_context(cpu), cuda_device_context(gpu) {} : cpu_device_context(cpu), cuda_device_context(gpu) {}
platform::CDUADeviceContext* cuda_device_context;
#else #else
explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {} explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {}
#endif #endif
...@@ -126,10 +129,16 @@ platform::CUDADeviceContext* GetCUDADeviceContext( ...@@ -126,10 +129,16 @@ platform::CUDADeviceContext* GetCUDADeviceContext(
Device* GetDevice(const platform::Place& place) { Device* GetDevice(const platform::Place& place) {
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
if (platform::is_gpu_place(place)) {
platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place); platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
static std::unique_ptr<Device> g_device = make_unique<Device>( static std::unique_ptr<Device> g_device = make_unique<Device>(
GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place)); GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
return g_device.get(); return g_device.get();
} else {
static std::unique_ptr<Device> g_device =
make_unique<Device>(GetCPUDeviceContext(cpu_place), nullptr);
return g_device.get();
}
#else #else
static std::unique_ptr<Device> g_device = static std::unique_ptr<Device> g_device =
make_unique<Device>(GetCPUDeviceContext(cpu_place)); make_unique<Device>(GetCPUDeviceContext(cpu_place));
...@@ -153,7 +162,9 @@ void ExecutorImpl::Run() { ...@@ -153,7 +162,9 @@ void ExecutorImpl::Run() {
scope_->NewVar(); scope_->NewVar();
device_->cpu_device_context->Wait(); device_->cpu_device_context->Wait();
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
if (device_->cuda_device_context) {
device_->cuda_device_context->Wait(); device_->cuda_device_context->Wait();
}
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册