提交 f29a6b02 编写于 作者: Q qijun

fix gpu build error

上级 f1c5d9e7
......@@ -69,10 +69,13 @@ void GraphView::Initialize(const ProgramDesc* pdesc) {
struct Device {
platform::CPUDeviceContext* cpu_device_context;
#ifndef PADDLE_ONLY_CPU
platform::CUDADeviceContext* cuda_device_context;
#endif
#ifndef PADDLE_ONLY_CPU
Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu)
: cpu_device_context(cpu), cuda_device_context(gpu) {}
platform::CDUADeviceContext* cuda_device_context;
#else
explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {}
#endif
......@@ -126,10 +129,16 @@ platform::CUDADeviceContext* GetCUDADeviceContext(
Device* GetDevice(const platform::Place& place) {
platform::CPUPlace cpu_place;
#ifndef PADDLE_ONLY_CPU
platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
static std::unique_ptr<Device> g_device = make_unique<Device>(
GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
return g_device.get();
if (platform::is_gpu_place(place)) {
platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
static std::unique_ptr<Device> g_device = make_unique<Device>(
GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
return g_device.get();
} else {
static std::unique_ptr<Device> g_device =
make_unique<Device>(GetCPUDeviceContext(cpu_place), nullptr);
return g_device.get();
}
#else
static std::unique_ptr<Device> g_device =
make_unique<Device>(GetCPUDeviceContext(cpu_place));
......@@ -153,7 +162,9 @@ void ExecutorImpl::Run() {
scope_->NewVar();
device_->cpu_device_context->Wait();
#ifndef PADDLE_ONLY_CPU
device_->cuda_device_context->Wait();
if (device_->cuda_device_context) {
device_->cuda_device_context->Wait();
}
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册