diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 3c29c60b215655269b2ff683eb13fe4a4700ef0a..1311f812be118b2b7bfce55e9b95beece7c48dea 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -276,11 +276,12 @@ struct SearchAlgorithm { args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.odesc.desc(), static_cast(algo), &workspace_size); - if (status == CUDNN_STATUS_SUCCESS) { + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { max_workspace_size = std::max(workspace_size, max_workspace_size); } } - return std::min(max_workspace_size, workspace_size_limit); + return max_workspace_size; } else { return workspace_size_limit; } @@ -425,11 +426,12 @@ struct SearchAlgorithm { args.cdesc.desc(), args.idesc.desc(), static_cast(algo), &workspace_size); - if (status == CUDNN_STATUS_SUCCESS) { + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { max_workspace_size = std::max(workspace_size, max_workspace_size); } } - return std::min(max_workspace_size, workspace_size_limit); + return max_workspace_size; } else { return workspace_size_limit; } @@ -588,11 +590,12 @@ struct SearchAlgorithm { args.cdesc.desc(), args.wdesc.desc(), static_cast(algo), &workspace_size); - if (status == CUDNN_STATUS_SUCCESS) { + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { max_workspace_size = std::max(workspace_size, max_workspace_size); } } - return std::min(max_workspace_size, workspace_size_limit); + return max_workspace_size; } else { return workspace_size_limit; }