From 8ae1b325d61c2789bdfc9d541e89dd70fb5156c0 Mon Sep 17 00:00:00 2001 From: liaogang Date: Thu, 1 Sep 2016 06:32:51 +0000 Subject: [PATCH] =?UTF-8?q?=C3=A2=C2=80fix=20bug=20in=20cuda=5Faggregate?= =?UTF-8?q?=20ISSUE=3D4608831?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1498 1ad973e4-5ce8-4261-8a94-b56d1f490c56 --- paddle/cuda/include/hl_cuda.h | 11 ++++++----- paddle/cuda/include/stub/hl_cuda_stub.h | 2 +- paddle/cuda/src/hl_cuda_aggregate.cu | 22 ++++++++++------------ paddle/cuda/src/hl_cuda_device.cc | 5 +++-- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/paddle/cuda/include/hl_cuda.h b/paddle/cuda/include/hl_cuda.h index ffdf71229a..3196db67f6 100644 --- a/paddle/cuda/include/hl_cuda.h +++ b/paddle/cuda/include/hl_cuda.h @@ -321,13 +321,14 @@ extern const char* hl_get_device_error_string(size_t err); extern int hl_get_device_last_error(); /** - * @brief hppl query event. + * @brief check cuda event is ready * - * @param[in] event cuda event to query. - * @param[out] isNotReady this work under device has not yet been - * completed, vice versa. + * @param[in] event cuda event to query. + * + * @return true cuda event is ready. + * false cuda event is not ready. */ -extern void hl_cuda_event_query(hl_event_t event, bool& isNotReady); +extern bool hl_cuda_event_is_ready(hl_event_t event); /** * @brief hppl device synchronization. diff --git a/paddle/cuda/include/stub/hl_cuda_stub.h b/paddle/cuda/include/stub/hl_cuda_stub.h index 395101c6f7..675ac03b0e 100644 --- a/paddle/cuda/include/stub/hl_cuda_stub.h +++ b/paddle/cuda/include/stub/hl_cuda_stub.h @@ -89,7 +89,7 @@ inline const char* hl_get_device_error_string() { return NULL; } inline const char* hl_get_device_error_string(size_t err) { return NULL; } -inline void hl_cuda_event_query(hl_event_t event, bool& isNotReady) {} +inline bool hl_cuda_event_is_ready(hl_event_t event) { return true; } inline void hl_device_synchronize() {} diff --git a/paddle/cuda/src/hl_cuda_aggregate.cu b/paddle/cuda/src/hl_cuda_aggregate.cu index c0b84b087b..4eb775eb79 100644 --- a/paddle/cuda/src/hl_cuda_aggregate.cu +++ b/paddle/cuda/src/hl_cuda_aggregate.cu @@ -261,11 +261,7 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) { struct _hl_event_st hl_event_st = {.cu_event = t_resource.event}; hl_event_t hl_event = &hl_event_st; - - bool isNotReady = false; - do { - hl_cuda_event_query(hl_event, isNotReady); - } while (isNotReady == cudaErrorNotReady); + while (!hl_cuda_event_is_ready(hl_event)) {} KeVectorSum<128><<< grid, threads, 0, STREAM_DEFAULT >>> (A_d, t_resource.gpu_mem, dimM); @@ -275,7 +271,10 @@ void hl_vector_sum(real *A_d, real *C_h, int dimM) { hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT); hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event); - CHECK_SYNC("hl_vector_sum failed"); + hl_stream_synchronize(HPPL_STREAM_DEFAULT); + cudaError_t err = (cudaError_t)hl_get_device_last_error(); + CHECK_EQ(cudaSuccess, err) + << "CUDA error: " << hl_get_device_error_string((size_t)err); } template @@ -317,11 +316,7 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) { struct _hl_event_st hl_event_st = {.cu_event = t_resource.event}; hl_event_t hl_event = &hl_event_st; - - bool isNotReady = false; - do { - hl_cuda_event_query(hl_event, isNotReady); - } while (isNotReady == cudaErrorNotReady); + while (!hl_cuda_event_is_ready(hl_event)) {} KeVectorAbsSum<128><<< grid, threads, 0, STREAM_DEFAULT >>> (A_d, t_resource.gpu_mem, dimM); @@ -331,5 +326,8 @@ void hl_vector_abs_sum(real *A_d, real *C_h, int dimM) { hl_memcpy_async(C_h, t_resource.cpu_mem, sizeof(real), HPPL_STREAM_DEFAULT); hl_stream_record_event(HPPL_STREAM_DEFAULT, hl_event); - CHECK_SYNC("hl_vector_abs_sum failed"); + hl_stream_synchronize(HPPL_STREAM_DEFAULT); + cudaError_t err = (cudaError_t)hl_get_device_last_error(); + CHECK_EQ(cudaSuccess, err) + << "CUDA error: " << hl_get_device_error_string((size_t)err); } diff --git a/paddle/cuda/src/hl_cuda_device.cc b/paddle/cuda/src/hl_cuda_device.cc index 774eef8b89..f07538d6ba 100644 --- a/paddle/cuda/src/hl_cuda_device.cc +++ b/paddle/cuda/src/hl_cuda_device.cc @@ -751,11 +751,12 @@ void hl_set_device_flags_block() { cudaDeviceScheduleBlockingSync)); } -void hl_cuda_event_query(hl_event_t event, bool& isNotReady) { +bool hl_cuda_event_is_ready(hl_event_t event) { cudaError_t err = dynload::cudaEventQuery(event->cu_event); CHECK(cudaSuccess == err || cudaErrorNotReady == err); if (cudaErrorNotReady == err) { - isNotReady = true; + return false; } + return true; } -- GitLab