未验证 提交 f962bd34 编写于 作者: L Leo Chen 提交者: GitHub

Fix cudnn workspace limit in cudnn-8 (#28611)

上级 90805e2d
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/conv_search_cache.h" #include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
...@@ -101,6 +102,24 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) { ...@@ -101,6 +102,24 @@ inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
return max_algos; return max_algos;
} }
template <typename PerfType, typename AlgoType>
void ChooseAlgoByWorkspace(PerfType* perf_results, size_t perf_num,
size_t workspace_byte, AlgoType* algo) {
for (size_t i = 0; i < perf_num; ++i) {
auto result = perf_results[i];
if (result.status == CUDNN_STATUS_SUCCESS &&
result.memory < workspace_byte) {
*algo = result.algo;
VLOG(3) << " algo: " << result.algo << ", time: " << result.time
<< " ms, wksp = " << result.memory
<< ", status = " << result.status;
return;
}
}
VLOG(3) << "Can not find alog that requires memory < "
<< static_cast<double>(workspace_byte) / (1 << 20) << " MB";
}
template <typename PerfType, typename AlgoType> template <typename PerfType, typename AlgoType>
void ChooseAlgo(const std::vector<PerfType>& perf_results, void ChooseAlgo(const std::vector<PerfType>& perf_results,
size_t workspace_byte, AlgoType* algo) { size_t workspace_byte, AlgoType* algo) {
...@@ -219,7 +238,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -219,7 +238,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
if (workspace_size > workspace_size_limit) { if (workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000 #if CUDNN_VERSION >= 8000
workspace_size_limit = workspace_size; // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_FWD_ALGS,
workspace_size_limit, &algo);
#else #else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request(" "the workspace size request("
...@@ -316,7 +338,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -316,7 +338,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
size_t workspace_size = 0; size_t workspace_size = 0;
bool has_got_workspace_size = true; bool has_got_workspace_size = true;
algo_t algo; algo_t algo;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
...@@ -362,9 +383,10 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -362,9 +383,10 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
if (workspace_size > workspace_size_limit) { if (workspace_size > workspace_size_limit) {
has_got_workspace_size = false; has_got_workspace_size = false;
#if CUDNN_VERSION >= 8000 #if CUDNN_VERSION >= 8000
// There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8 // cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
// version. ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
workspace_size_limit = workspace_size; kNUM_CUDNN_BWD_DATA_ALGS,
workspace_size_limit, &algo);
#else #else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request(" "the workspace size request("
...@@ -493,6 +515,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -493,6 +515,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
workspace_size = GetWorkspaceSize(args, algo); workspace_size = GetWorkspaceSize(args, algo);
if (workspace_size > workspace_size_limit) { if (workspace_size > workspace_size_limit) {
workspace_size = workspace_size_limit; workspace_size = workspace_size_limit;
#if CUDNN_VERSION >= 8000
// cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
ChooseAlgoByWorkspace<perf_t, algo_t>(perf_results.get(),
kNUM_CUDNN_BWD_FILTER_ALGS,
workspace_size_limit, &algo);
#else
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
"the workspace size request("
<< workspace_size << ") exceeds the limit("
<< workspace_size_limit << ")";
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
} }
#else #else
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册