From 3ce879dba700ef20415e95722de1c5845deab403 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Thu, 14 Apr 2022 19:57:40 +0800 Subject: [PATCH] Optimize the finding of max workspace size. (#41741) --- paddle/fluid/operators/conv_cudnn_helper.h | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 3c29c60b21..1311f812be 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -276,11 +276,12 @@ struct SearchAlgorithm { args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.odesc.desc(), static_cast(algo), &workspace_size); - if (status == CUDNN_STATUS_SUCCESS) { + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { max_workspace_size = std::max(workspace_size, max_workspace_size); } } - return std::min(max_workspace_size, workspace_size_limit); + return max_workspace_size; } else { return workspace_size_limit; } @@ -425,11 +426,12 @@ struct SearchAlgorithm { args.cdesc.desc(), args.idesc.desc(), static_cast(algo), &workspace_size); - if (status == CUDNN_STATUS_SUCCESS) { + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { max_workspace_size = std::max(workspace_size, max_workspace_size); } } - return std::min(max_workspace_size, workspace_size_limit); + return max_workspace_size; } else { return workspace_size_limit; } @@ -588,11 +590,12 @@ struct SearchAlgorithm { args.cdesc.desc(), args.wdesc.desc(), static_cast(algo), &workspace_size); - if (status == CUDNN_STATUS_SUCCESS) { + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { max_workspace_size = std::max(workspace_size, max_workspace_size); } } - return std::min(max_workspace_size, workspace_size_limit); + return max_workspace_size; } else { return workspace_size_limit; } -- GitLab