未验证 提交 a91b8014 编写于 作者: L liuwei1031 提交者: GitHub

cherry-pick (#21201) to release/1.6 (#21306)

cudaStreamSynchronize randomly hang when used in multi-thread environment, replace it with cudaStreamQuery API on windows
上级 3848f720
...@@ -313,14 +313,23 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -313,14 +313,23 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; } Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const { void CUDADeviceContext::Wait() const {
cudaError_t e_sync = cudaStreamSynchronize(stream_); cudaError_t e_sync = cudaSuccess;
if (e_sync != 0) { #if !defined(_WIN32)
e_sync = cudaStreamSynchronize(stream_);
#else
while (e_sync = cudaStreamQuery(stream_)) {
if (e_sync == cudaErrorNotReady) continue;
break;
}
#endif
if (cudaSuccess != e_sync) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync) LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
<< " errno: " << e_sync; << " errno: " << e_sync;
} }
cudaError_t e_get = cudaGetLastError(); cudaError_t e_get = cudaGetLastError();
if (e_get != 0) { if (cudaSuccess != e_get) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get) LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno: " << e_get; << " errno: " << e_get;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册