diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index e04ffb4611a77b4edcac2c3dc65f59b620c5b831..04a52a5e9caea7643157ac035aef4957a7dec37e 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -314,14 +314,23 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { - cudaError_t e_sync = cudaStreamSynchronize(stream_); - if (e_sync != 0) { + cudaError_t e_sync = cudaSuccess; +#if !defined(_WIN32) + e_sync = cudaStreamSynchronize(stream_); +#else + while (e_sync = cudaStreamQuery(stream_)) { + if (e_sync == cudaErrorNotReady) continue; + break; + } +#endif + + if (cudaSuccess != e_sync) { LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync) << " errno: " << e_sync; } cudaError_t e_get = cudaGetLastError(); - if (e_get != 0) { + if (cudaSuccess != e_get) { LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get) << " errno: " << e_get; }