提交 f168843e 编写于 作者: Q qijun

fix gpu build error

上级 5e62605c
add_subdirectory(detail) add_subdirectory(detail)
cc_library(memory SRCS memory.cc) cc_library(memory SRCS memory.cc)
cc_library(memcpy SRCS memcpy.cc DEPS device_context) cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory cc_library(paddle_memory
DEPS DEPS
......
...@@ -16,5 +16,8 @@ ELSE() ...@@ -16,5 +16,8 @@ ELSE()
set(GPU_CTX_DEPS) set(GPU_CTX_DEPS)
ENDIF() ENDIF()
cc_library(device_context SRCS device_context.cc DEPS memory place eigen3 ${GPU_CTX_DEPS}) # memcpy deoends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
...@@ -57,7 +57,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -57,7 +57,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
} }
void* allocate(size_t num_bytes) const override { void* allocate(size_t num_bytes) const override {
paddle::memory::Alloc(place_, num_bytes); return paddle::memory::Alloc(place_, num_bytes);
} }
void deallocate(void* buffer) const override { void deallocate(void* buffer) const override {
...@@ -86,7 +86,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -86,7 +86,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
GPUPlace place_; GPUPlace place_;
const cudaStream_t* stream_; // not owned; const cudaStream_t* stream_; // not owned;
const cudaDeviceProp* device_prop_; // not owned; const cudaDeviceProp* device_prop_; // not owned;
mutable char* scratch_; mutable void* scratch_;
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
}; };
...@@ -145,7 +145,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { ...@@ -145,7 +145,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) { if (!cudnn_handle_) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnnHandle_t, stream_)); PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
} }
return cudnn_handle_; return cudnn_handle_;
} }
...@@ -160,7 +160,7 @@ curandGenerator_t CUDADeviceContext::curand_generator() { ...@@ -160,7 +160,7 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
PADDLE_ENFORCE( PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curandGenerator_t, stream_)); PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
} }
return curand_generator_; return curand_generator_;
} }
......
...@@ -52,6 +52,7 @@ class CPUDeviceContext : public DeviceContext { ...@@ -52,6 +52,7 @@ class CPUDeviceContext : public DeviceContext {
}; };
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
...@@ -92,7 +93,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -92,7 +93,7 @@ class CUDADeviceContext : public DeviceContext {
uint64_t seed_; uint64_t seed_;
// clang-format off // clang-format off
cudaStream_t stream_{nullptr} cudaStream_t stream_{nullptr};
cudnnHandle_t cudnn_handle_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr};
cublasHandle_t cublas_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr};
curandGenerator_t curand_generator_{nullptr}; curandGenerator_t curand_generator_{nullptr};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册