diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 52963d20f066dca78294e1246e93e84789e5345f..74153f244990aa61f4bbf1aba08bce9526f16b06 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -86,16 +86,16 @@ std::unique_ptr make_unique(Args&&... args) { return std::unique_ptr(new T(std::forward(args)...)); } -platform::CPUDeviceContext* GetCPUDeviceContext() { +platform::CPUDeviceContext* GetCPUDeviceContext(platform::CPUPlace& place) { static std::unique_ptr g_cpu_device_context = - make_unique(platform::CPUPlace()); + make_unique(place); return g_cpu_device_context.get(); } #ifndef PADDLE_ONLY_CPU -platform::CUDADeviceContext* GetCUDADeviceContext() { +platform::CUDADeviceContext* GetCUDADeviceContext(platform::GPUPlace& place) { static std::unique_ptr g_cuda_device_context = - make_unique(platform::GPUPlace(0)); + make_unique(place); return g_cuda_device_context.get(); } #endif @@ -110,10 +110,12 @@ Executor* NewLocalExecutor(const platform::Place& place, const ProgramDesc& pdesc, bool is_linear) { platform::DeviceContext* device_context = nullptr; if (platform::is_cpu_place(place)) { - device_context = GetCPUDeviceContext(); + auto cpu_place = boost::get(place); + device_context = GetCPUDeviceContext(cpu_place); } else if (platform::is_gpu_place(place)) { #ifndef PADDLE_ONLY_CPU - device_context = GetCUDADeviceContext(); + auto gpu_place = boost::get(place); + device_context = GetCUDADeviceContext(gpu_place); } #else PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");