未验证 提交 3ce879db 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the finding of max workspace size. (#41741)

上级 64237c3f
......@@ -276,11 +276,12 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
args.handle, args.idesc.desc(), args.wdesc.desc(),
args.cdesc.desc(), args.odesc.desc(),
static_cast<cudnnConvolutionFwdAlgo_t>(algo), &workspace_size);
if (status == CUDNN_STATUS_SUCCESS) {
if (status == CUDNN_STATUS_SUCCESS &&
workspace_size <= workspace_size_limit) {
max_workspace_size = std::max(workspace_size, max_workspace_size);
}
}
return std::min(max_workspace_size, workspace_size_limit);
return max_workspace_size;
} else {
return workspace_size_limit;
}
......@@ -425,11 +426,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
args.cdesc.desc(), args.idesc.desc(),
static_cast<cudnnConvolutionBwdDataAlgo_t>(algo),
&workspace_size);
if (status == CUDNN_STATUS_SUCCESS) {
if (status == CUDNN_STATUS_SUCCESS &&
workspace_size <= workspace_size_limit) {
max_workspace_size = std::max(workspace_size, max_workspace_size);
}
}
return std::min(max_workspace_size, workspace_size_limit);
return max_workspace_size;
} else {
return workspace_size_limit;
}
......@@ -588,11 +590,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
args.cdesc.desc(), args.wdesc.desc(),
static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo),
&workspace_size);
if (status == CUDNN_STATUS_SUCCESS) {
if (status == CUDNN_STATUS_SUCCESS &&
workspace_size <= workspace_size_limit) {
max_workspace_size = std::max(workspace_size, max_workspace_size);
}
}
return std::min(max_workspace_size, workspace_size_limit);
return max_workspace_size;
} else {
return workspace_size_limit;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册