未验证 提交 c10157a5 编写于 作者: W wangchaochaohu 提交者: GitHub

Revise the cuDNN convolution algorithm selection to improve performance (Mask R-CNN benchmark) (#17753)

* revise the cuDNN algorithm selection in the conv layer test=develop

* update for code style test=develop

* update for code style test=develop
上级 863c7516
...@@ -166,10 +166,23 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -166,10 +166,23 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
// conv_cudnn_helper.h // conv_cudnn_helper.h
if ((!exhaustive_search) && (!half_float)) { if ((!exhaustive_search) && (!half_float)) {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
#else
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo)); workspace_size_limit, &algo));
#endif
VLOG(3) << "cuDNN forward algo " << algo; VLOG(3) << "cuDNN forward algo " << algo;
} else if (exhaustive_search && (!half_float)) { } else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache = AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
...@@ -388,6 +401,37 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -388,6 +401,37 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} else if (FLAGS_cudnn_deterministic) { } else if (FLAGS_cudnn_deterministic) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else { } else {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
handle, cudnn_filter_desc,
// dyDesc: Handle to the previously initialized input
// differential
// tensor descriptor.
cudnn_output_grad_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
cudnn_input_desc, kNUM_CUDNN_BWD_DATA_ALGS, &perf_count,
perf_results.get()));
data_algo = (perf_results.get())[best_algo_idx].algo;
int stride_dim = input->dims().size() - 2;
bool blacklist =
std::any_of(strides.begin(), strides.begin() + stride_dim,
[=](int n) { return n != 1; });
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
#else
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, handle, cudnn_filter_desc,
...@@ -400,6 +444,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -400,6 +444,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_input_desc, cudnn_input_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo)); workspace_size_limit, &data_algo));
#endif
} }
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
...@@ -437,12 +482,27 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -437,12 +482,27 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} else if (FLAGS_cudnn_deterministic) { } else if (FLAGS_cudnn_deterministic) {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else { } else {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
&perf_count, perf_results.get()));
filter_algo = (perf_results.get())[best_algo_idx].algo;
#else
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc, handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, cudnn_conv_desc, cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo)); workspace_size_limit, &filter_algo));
#endif
} }
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
......
...@@ -181,7 +181,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) ...@@ -181,7 +181,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
__macro(cudnnGetCTCLossDescriptor); \ __macro(cudnnGetCTCLossDescriptor); \
__macro(cudnnSetCTCLossDescriptor); \ __macro(cudnnSetCTCLossDescriptor); \
__macro(cudnnGetCTCLossWorkspaceSize); \ __macro(cudnnGetCTCLossWorkspaceSize); \
__macro(cudnnCTCLoss); __macro(cudnnCTCLoss); \
__macro(cudnnGetConvolutionBackwardDataAlgorithm_v7); \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm_v7); \
__macro(cudnnGetConvolutionForwardAlgorithm_v7);
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册