diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 3166593365404e98fad0e91a7d7b5cd7176cd9ed..b19929127ee8105c2f1a5cf34d831edeffc36b4e 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -313,14 +313,23 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { - cudaError_t e_sync = cudaStreamSynchronize(stream_); - if (e_sync != 0) { + cudaError_t e_sync = cudaSuccess; +#if !defined(_WIN32) + e_sync = cudaStreamSynchronize(stream_); +#else + while (e_sync = cudaStreamQuery(stream_)) { + if (e_sync == cudaErrorNotReady) continue; + break; + } +#endif + + if (cudaSuccess != e_sync) { LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync) << " errno: " << e_sync; } cudaError_t e_get = cudaGetLastError(); - if (e_get != 0) { + if (cudaSuccess != e_get) { LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get) << " errno: " << e_get; }