提交 383b96f3 编写于 作者: L liaogang

FIX: merge conflicts

上级 f404282d
...@@ -25,7 +25,7 @@ MetadataCache::MetadataCache(bool uses_gpu) : uses_gpu_(uses_gpu) {} ...@@ -25,7 +25,7 @@ MetadataCache::MetadataCache(bool uses_gpu) : uses_gpu_(uses_gpu) {}
Metadata MetadataCache::load(const MemoryBlock* block) { Metadata MetadataCache::load(const MemoryBlock* block) {
if (uses_gpu_) { if (uses_gpu_) {
auto existing_metadata = cache_.find(block); auto existing_metadata = cache_.find(block);
assert(existing_metadata->second.check_guards()); PADDLE_ASSERT(existing_metadata->second.check_guards());
return existing_metadata->second; return existing_metadata->second;
} else { } else {
PADDLE_ASSERT(reinterpret_cast<const Metadata*>(block)->check_guards()); PADDLE_ASSERT(reinterpret_cast<const Metadata*>(block)->check_guards());
......
...@@ -52,7 +52,7 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) { ...@@ -52,7 +52,7 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
static detail::BuddyAllocator** as = NULL; static detail::BuddyAllocator** as = NULL;
if (as == NULL) { if (as == NULL) {
int gpu_num = platform::GpuDeviceCount(); int gpu_num = platform::GetDeviceCount();
as = new detail::BuddyAllocator*[gpu_num]; as = new detail::BuddyAllocator*[gpu_num];
for (int gpu = 0; gpu < gpu_num; gpu++) { for (int gpu = 0; gpu < gpu_num; gpu++) {
platform::SetDeviceId(gpu); platform::SetDeviceId(gpu);
......
...@@ -8,4 +8,4 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) ...@@ -8,4 +8,4 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_library(dynamic_loader SRCS dynload/dynamic_loader.cc DEPS gflags glog) cc_library(dynamic_loader SRCS dynload/dynamic_loader.cc DEPS gflags glog)
nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3) nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 gpu_info)
...@@ -16,10 +16,11 @@ limitations under the License. */ ...@@ -16,10 +16,11 @@ limitations under the License. */
#include "paddle/framework/enforce.h" #include "paddle/framework/enforce.h"
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
#include "paddle/platform/cuda.h"
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h" #include "paddle/platform/dynload/curand.h"
#include "paddle/platform/error.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
......
...@@ -23,11 +23,11 @@ DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, ...@@ -23,11 +23,11 @@ DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
namespace paddle { namespace paddle {
namespace platform { namespace platform {
int GpuDeviceCount() { int GetDeviceCount() {
int count; int count;
throw_on_error( throw_on_error(
cudaGetDeviceCount(&count), cudaGetDeviceCount(&count),
"cudaGetDeviceCount failed in paddle::platform::GpuDeviceCount"); "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount");
return count; return count;
} }
......
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
namespace platform { namespace platform {
//! Get the total number of GPU devices in system. //! Get the total number of GPU devices in system.
int GpuDeviceCount(); int GetDeviceCount();
//! Get the current GPU device id in system. //! Get the current GPU device id in system.
int GetCurrentDeviceId(); int GetCurrentDeviceId();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册