提交 09500917 编写于 作者: Q qijun

pass place to GetCUDADeviceContext

上级 39b2ff36
...@@ -86,16 +86,16 @@ std::unique_ptr<T> make_unique(Args&&... args) { ...@@ -86,16 +86,16 @@ std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
} }
platform::CPUDeviceContext* GetCPUDeviceContext() { platform::CPUDeviceContext* GetCPUDeviceContext(platform::CPUPlace& place) {
static std::unique_ptr<platform::CPUDeviceContext> g_cpu_device_context = static std::unique_ptr<platform::CPUDeviceContext> g_cpu_device_context =
make_unique<platform::CPUDeviceContext>(platform::CPUPlace()); make_unique<platform::CPUDeviceContext>(place);
return g_cpu_device_context.get(); return g_cpu_device_context.get();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
platform::CUDADeviceContext* GetCUDADeviceContext() { platform::CUDADeviceContext* GetCUDADeviceContext(platform::GPUPlace& place) {
static std::unique_ptr<platform::CUDADeviceContext> g_cuda_device_context = static std::unique_ptr<platform::CUDADeviceContext> g_cuda_device_context =
make_unique<platform::CUDADeviceContext>(platform::GPUPlace(0)); make_unique<platform::CUDADeviceContext>(place);
return g_cuda_device_context.get(); return g_cuda_device_context.get();
} }
#endif #endif
...@@ -110,10 +110,12 @@ Executor* NewLocalExecutor(const platform::Place& place, ...@@ -110,10 +110,12 @@ Executor* NewLocalExecutor(const platform::Place& place,
const ProgramDesc& pdesc, bool is_linear) { const ProgramDesc& pdesc, bool is_linear) {
platform::DeviceContext* device_context = nullptr; platform::DeviceContext* device_context = nullptr;
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
device_context = GetCPUDeviceContext(); auto cpu_place = boost::get<platform::CPUPlace>(place);
device_context = GetCPUDeviceContext(cpu_place);
} else if (platform::is_gpu_place(place)) { } else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
device_context = GetCUDADeviceContext(); auto gpu_place = boost::get<platform::GPUPlace>(place);
device_context = GetCUDADeviceContext(gpu_place);
} }
#else #else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册