Unverified commit c10157a5, authored by wangchaochaohu, committed by GitHub

Revise the cuDNN convolution algorithm selection to improve performance (Mask R-CNN benchmark) (#17753)

* Revise the conv layer cuDNN algorithm selection. test=develop

* Update for code style. test=develop

* Update for code style. test=develop
Parent 863c7516
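The change switches algorithm selection to the cuDNN 7 heuristic query APIs (cudnnGetConvolutionForwardAlgorithm_v7 and the corresponding backward-data/backward-filter variants), which return a list of candidate algorithms ranked by expected performance instead of a single recommendation. The sketch below is a simplified illustration of that forward-pass pattern, not the patch code itself; it assumes a valid cuDNN 7+ handle and already-configured input/filter/convolution/output descriptors, and omits error handling.

// Simplified sketch: heuristic forward-algorithm choice with the cuDNN 7 API.
// handle, in_desc, filter_desc, conv_desc and out_desc are assumed to be
// valid, fully configured objects; status checks are omitted for brevity.
#include <cudnn.h>
#include <memory>

cudnnConvolutionFwdAlgo_t ChooseFwdAlgo(cudnnHandle_t handle,
                                        cudnnTensorDescriptor_t in_desc,
                                        cudnnFilterDescriptor_t filter_desc,
                                        cudnnConvolutionDescriptor_t conv_desc,
                                        cudnnTensorDescriptor_t out_desc) {
  const int requested = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
  int returned = 0;
  std::unique_ptr<cudnnConvolutionFwdAlgoPerf_t[]> perf(
      new cudnnConvolutionFwdAlgoPerf_t[requested]);
  // Candidates come back sorted by expected performance, fastest first.
  cudnnGetConvolutionForwardAlgorithm_v7(handle, in_desc, filter_desc,
                                         conv_desc, out_desc, requested,
                                         &returned, perf.get());
  return perf[0].algo;  // index 0 is the heuristically best algorithm
}

Because the full ranked list is available, unsuitable candidates can also be rejected; the backward-data path in the diff below uses this to avoid FFT-based algorithms for strided convolutions.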
@@ -166,10 +166,23 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in
// conv_cudnn_helper.h
if ((!exhaustive_search) && (!half_float)) {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
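// cudnnGetConvolutionForwardAlgorithm_v7 fills perf_results with candidates
// sorted by expected performance, so entry 0 is the heuristically fastest
// algorithm.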
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
#else
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
#endif
VLOG(3) << "cuDNN forward algo " << algo;
} else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
@@ -388,6 +401,37 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} else if (FLAGS_cudnn_deterministic) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
handle, cudnn_filter_desc,
// dyDesc: Handle to the previously initialized input
// differential tensor descriptor.
cudnn_output_grad_desc, cudnn_conv_desc,
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
cudnn_input_desc, kNUM_CUDNN_BWD_DATA_ALGS, &perf_count,
perf_results.get()));
data_algo = (perf_results.get())[best_algo_idx].algo;
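// The heuristic may still rank an FFT-based algorithm first even for
// strided convolutions, which these algorithms do not support, so such a
// pick is overridden with ALGO_1 below.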
int stride_dim = input->dims().size() - 2;
bool blacklist =
std::any_of(strides.begin(), strides.begin() + stride_dim,
[=](int n) { return n != 1; });
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
static_cast<cudnnConvolutionBwdDataAlgo_t>(
perf_results[best_algo_idx].algo) ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
#else
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc,
@@ -400,6 +444,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_input_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
#endif
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
@@ -437,12 +482,27 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} else if (FLAGS_cudnn_deterministic) {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
#if CUDNN_VERSION >= 7001
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
int perf_count;
int best_algo_idx = 0;
std::unique_ptr<perf_t[]> perf_results(
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
&perf_count, perf_results.get()));
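// As in the forward path, entry 0 of the sorted perf results is the
// heuristically fastest backward-filter algorithm.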
filter_algo = (perf_results.get())[best_algo_idx].algo;
#else
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));
#endif
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
@@ -172,16 +172,19 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 7001
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionGroupCount); \
__macro(cudnnSetConvolutionMathType); \
__macro(cudnnConvolutionBiasActivationForward); \
__macro(cudnnCreateCTCLossDescriptor); \
__macro(cudnnDestroyCTCLossDescriptor); \
__macro(cudnnGetCTCLossDescriptor); \
__macro(cudnnSetCTCLossDescriptor); \
__macro(cudnnGetCTCLossWorkspaceSize); \
__macro(cudnnCTCLoss);
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionGroupCount); \
__macro(cudnnSetConvolutionMathType); \
__macro(cudnnConvolutionBiasActivationForward); \
__macro(cudnnCreateCTCLossDescriptor); \
__macro(cudnnDestroyCTCLossDescriptor); \
__macro(cudnnGetCTCLossDescriptor); \
__macro(cudnnSetCTCLossDescriptor); \
__macro(cudnnGetCTCLossWorkspaceSize); \
__macro(cudnnCTCLoss); \
__macro(cudnnGetConvolutionBackwardDataAlgorithm_v7); \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm_v7); \
__macro(cudnnGetConvolutionForwardAlgorithm_v7);
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
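Because Paddle calls cuDNN through platform::dynload, the three new *_v7 query functions also have to be listed in the CUDNN_DNN_ROUTINE_EACH_R7 table above so that wrappers are generated for them. As a rough, hypothetical illustration of what such a generated wrapper does (this is not the actual DECLARE_DYNAMIC_LOAD_CUDNN_WRAP expansion), the call is forwarded through a function pointer resolved from the cuDNN shared library:

// Hypothetical sketch of a dynamic-load forwarding wrapper; the real
// macro-generated wrappers differ in detail (library lookup, caching, and
// error handling). Assumes cuDNN 7+ headers.
#include <cudnn.h>
#include <dlfcn.h>

template <typename... Args>
cudnnStatus_t DynloadGetConvolutionForwardAlgorithm_v7(Args... args) {
  using FuncType = decltype(&cudnnGetConvolutionForwardAlgorithm_v7);
  // Open libcudnn and resolve the symbol once; error handling omitted.
  static void* dso = dlopen("libcudnn.so", RTLD_LAZY | RTLD_LOCAL);
  static auto func = reinterpret_cast<FuncType>(
      dlsym(dso, "cudnnGetConvolutionForwardAlgorithm_v7"));
  return func(args...);
}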