未验证 提交 3ce879db 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the finding of max workspace size. (#41741)

上级 64237c3f
...@@ -276,11 +276,12 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -276,11 +276,12 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
args.handle, args.idesc.desc(), args.wdesc.desc(), args.handle, args.idesc.desc(), args.wdesc.desc(),
args.cdesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.odesc.desc(),
static_cast<cudnnConvolutionFwdAlgo_t>(algo), &workspace_size); static_cast<cudnnConvolutionFwdAlgo_t>(algo), &workspace_size);
if (status == CUDNN_STATUS_SUCCESS) { if (status == CUDNN_STATUS_SUCCESS &&
workspace_size <= workspace_size_limit) {
max_workspace_size = std::max(workspace_size, max_workspace_size); max_workspace_size = std::max(workspace_size, max_workspace_size);
} }
} }
return std::min(max_workspace_size, workspace_size_limit); return max_workspace_size;
} else { } else {
return workspace_size_limit; return workspace_size_limit;
} }
...@@ -425,11 +426,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -425,11 +426,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
args.cdesc.desc(), args.idesc.desc(), args.cdesc.desc(), args.idesc.desc(),
static_cast<cudnnConvolutionBwdDataAlgo_t>(algo), static_cast<cudnnConvolutionBwdDataAlgo_t>(algo),
&workspace_size); &workspace_size);
if (status == CUDNN_STATUS_SUCCESS) { if (status == CUDNN_STATUS_SUCCESS &&
workspace_size <= workspace_size_limit) {
max_workspace_size = std::max(workspace_size, max_workspace_size); max_workspace_size = std::max(workspace_size, max_workspace_size);
} }
} }
return std::min(max_workspace_size, workspace_size_limit); return max_workspace_size;
} else { } else {
return workspace_size_limit; return workspace_size_limit;
} }
...@@ -588,11 +590,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -588,11 +590,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
args.cdesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.wdesc.desc(),
static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo), static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo),
&workspace_size); &workspace_size);
if (status == CUDNN_STATUS_SUCCESS) { if (status == CUDNN_STATUS_SUCCESS &&
workspace_size <= workspace_size_limit) {
max_workspace_size = std::max(workspace_size, max_workspace_size); max_workspace_size = std::max(workspace_size, max_workspace_size);
} }
} }
return std::min(max_workspace_size, workspace_size_limit); return max_workspace_size;
} else { } else {
return workspace_size_limit; return workspace_size_limit;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册