diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h
index c2ad468fa6029158e6f5aaaafda1b6125fec954f..4a5cd3262217941461f1e950056d64e29834eddb 100644
--- a/paddle/fluid/operators/conv_cudnn_helper.h
+++ b/paddle/fluid/operators/conv_cudnn_helper.h
@@ -14,11 +14,11 @@ limitations under the License. */
 
 #pragma once
 
+#include <algorithm>
 #include <vector>
 
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/platform/cudnn_desc.h"
-
 namespace paddle {
 namespace operators {
 
@@ -57,16 +57,57 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
                      bool deterministic, int algo_cache_id,
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
+    bool has_got_workspace_size = true;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
-
+    size_t workspace_size = 0;
     algo_t algo;
+
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+    auto& dev_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
+      VLOG(5) << "use cudnn_tensor_op_math";
+    } else {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
+    }
+#endif
+
     if (!exhaustive) {
+#if CUDNN_VERSION >= 7001
+      int perf_count;
+      int best_algo_idx = 0;
+      std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
+          args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
+          args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, &perf_count,
+          perf_results.get()));
+      algo = (perf_results.get())[best_algo_idx].algo;
+      workspace_size = GetWorkspaceSize(args, algo);
+
+      if (workspace_size > workspace_size_limit) {
+        has_got_workspace_size = false;
+        VLOG(1) << "Fall back to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+      }
+      if (!has_got_workspace_size) {
+        CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+            args.handle, args.idesc.desc(), args.wdesc.desc(),
+            args.cdesc.desc(), args.odesc.desc(),
+            CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
+            &algo));
+      }
+#else
       CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
           args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
           args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
           workspace_size_limit, &algo));
+#endif
       VLOG(3) << "choose algo " << algo;
     } else {
       AlgorithmsCache<algo_t>& algo_cache =
@@ -128,15 +169,72 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
-
+    size_t workspace_size = 0;
+    bool has_got_workspace_size = true;
     algo_t algo;
+
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+    auto& dev_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
+      VLOG(5) << "use cudnn_tensor_op_math";
+    } else {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
+    }
+#endif
+
     if (!exhaustive && !deterministic) {
+#if CUDNN_VERSION >= 7001
+      int perf_count;
+      int best_algo_idx = 0;
+      std::unique_ptr<perf_t[]> perf_results(
+          new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
+              args.handle, args.wdesc.desc(), args.odesc.desc(),
+              args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
+              &perf_count, perf_results.get()));
+      algo = (perf_results.get())[best_algo_idx].algo;
+
+#if CUDNN_VERSION < 7500
+      int stride_dim = args.x->dims().size() - 2;
+      bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim,
+                                   [=](int n) { return n != 1; });
+      if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
+                            perf_results[best_algo_idx].algo) ==
+                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
+                        static_cast<cudnnConvolutionBwdDataAlgo_t>(
+                            perf_results[best_algo_idx].algo) ==
+                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
+        algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+      }
+#endif
+      workspace_size = GetWorkspaceSize(args, algo);
+      if (workspace_size > workspace_size_limit) {
+        has_got_workspace_size = false;
+        VLOG(1) << "Fall back to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+      }
+      if (!has_got_workspace_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+                args.handle, args.wdesc.desc(), args.odesc.desc(),
+                args.cdesc.desc(), args.idesc.desc(),
+                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &algo));
+      }
+#else
       CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-          args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(),
-          args.odesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+          args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(),
+          args.idesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
           workspace_size_limit, &algo));
+#endif
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
     } else {
@@ -186,8 +284,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
     size_t workspace_size = 0;
     CUDNN_ENFORCE(
         platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-            args.handle, args.wdesc.desc(), args.idesc.desc(),
-            args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
+            args.handle, args.wdesc.desc(), args.odesc.desc(),
+            args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
     return workspace_size;
   }
 };
@@ -203,17 +301,61 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
     bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
+    size_t workspace_size = 0;
+    bool has_got_workspace_size = true;
+
+#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
+    auto& dev_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
+      VLOG(5) << "use cudnn_tensor_op_math";
+    } else {
+      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+          args.cdesc.desc(), CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
+    }
+#endif
     algo_t algo;
     if (!exhaustive && !deterministic) {
+#if CUDNN_VERSION >= 7001
+      using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
+      int perf_count;
+      int best_algo_idx = 0;
+      std::unique_ptr<perf_t[]> perf_results(
+          new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
+      CUDNN_ENFORCE(
+          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
+              args.handle, args.idesc.desc(), args.odesc.desc(),
+              args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
+              &perf_count, perf_results.get()));
+      algo = (perf_results.get())[best_algo_idx].algo;
+      workspace_size = GetWorkspaceSize(args, algo);
+      if (workspace_size > workspace_size_limit) {
+        has_got_workspace_size = false;
+        VLOG(1) << "Fall back to non-v7 method to find conv algorithm because "
+                   "the workspace size request("
+                << workspace_size << ") exceeds the limit("
+                << workspace_size_limit << ")";
+      }
+      if (!has_got_workspace_size) {
+        CUDNN_ENFORCE(
+            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+                args.handle, args.idesc.desc(), args.odesc.desc(),
+                args.cdesc.desc(), args.wdesc.desc(),
+                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                workspace_size_limit, &algo));
+      }
+#else
       CUDNN_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
               args.handle, args.idesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.wdesc.desc(),
              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &algo));
+#endif
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
     } else {
diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index 3dfd1b4ef2dc2194388600e1b9027a4369dddfc6..ec0278e5a230ec9c5cbb38855d0c2a07912f332c 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -79,44 +79,42 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
-    int64_t user_workspace_size =
-        static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
     bool exhaustive_search =
         FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
+    if (exhaustive_search && FLAGS_cudnn_deterministic) {
+      PADDLE_THROW(
+          "Can't set exhaustive_search True and "
+          "FLAGS_cudnn_deterministic True at same time.");
+    }
+
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
-    // ------------------- cudnn descriptors ---------------------
-    ScopedTensorDescriptor input_desc;
-    ScopedTensorDescriptor output_desc;
-    ScopedFilterDescriptor filter_desc;
-    ScopedConvolutionDescriptor conv_desc;
+    ConvArgs args{input, filter, output, strides, paddings, dilations};
+    auto handle = dev_ctx.cudnn_handle();
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
+    auto dtype = platform::CudnnDataType<T>::type;
     DataLayout layout = DataLayout::kNCHW;
     if (input->dims().size() == 5) {
       layout = DataLayout::kNCDHW;
     }
+    auto layout_format = GetCudnnTensorFormat(layout);
 
-    cudnnConvolutionDescriptor_t cudnn_conv_desc =
-        conv_desc.descriptor<T>(paddings, strides, dilations);
-
+    args.handle = handle;
+    args.cdesc.set(dtype, paddings, strides, dilations);
 #if CUDNN_VERSION_MIN(7, 0, 1)
     // cudnn 7 can support groups, no need to do it manually
     // FIXME(typhoonzero): find a better way to disable groups
     // rather than setting it to 1.
     CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
-        cudnn_conv_desc, groups));
+        args.cdesc.desc(), groups));
     groups = 1;
 #endif
-
-    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
-        layout, framework::vectorize2int(input->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
-        layout, framework::vectorize2int(output->dims()), groups);
-    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
-        layout, framework::vectorize2int(filter->dims()), groups);
-
+    args.idesc.set(*input, groups);
+    args.wdesc.set(*filter, layout_format, groups);
+    args.odesc.set(*output, groups);
     int i_n, i_c, i_d, i_h, i_w;
     GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
     int o_n, o_c, o_d, o_h, o_w;
@@ -126,149 +124,27 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     int group_offset_out = o_c / groups * o_h * o_w * o_d;
     int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn conv workspace ---------------------
-    size_t workspace_size_in_bytes;  // final workspace to allocate.
-    size_t workspace_size_limit = 0;
-    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
-      int64_t max_user_size =
-          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
-                   user_workspace_size);
-      workspace_size_limit = max_user_size * 1024 * 1024;
-    }
-
+    size_t workspace_size = 0;  // final workspace to allocate.
     // ------------------- cudnn conv algorithm ---------------------
     cudnnConvolutionFwdAlgo_t algo{};
-    bool half_float = false;
-
-#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
-    // Tensor core is supported since the volta GPU and
-    // is only enabled when input and filter data are float16
-    if (dev_ctx.GetComputeCapability() >= 70 &&
-        std::type_index(typeid(T)) ==
-            std::type_index(typeid(platform::float16))) {
-      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
-          cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
-      // Currently tensor core is only enabled using this algo
-      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
-      half_float = true;
-      VLOG(5) << "use cudnn_tensor_op_math";
-    } else {
-      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
-          cudnn_conv_desc, CUDNN_DEFAULT_MATH));
-      VLOG(5) << "NOT use cudnn_tensor_op_math";
-    }
-#endif
-    auto handle = dev_ctx.cudnn_handle();
-    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-    auto x_dims = framework::vectorize(input->dims());
-    auto f_dims = framework::vectorize(filter->dims());
-
-    // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
-    // conv_cudnn_helper.h
-    bool has_got_workspace_size = false;
-    if ((!exhaustive_search) && (!half_float)) {
-#if CUDNN_VERSION >= 7001
-      using perf_t = cudnnConvolutionFwdAlgoPerf_t;
-      int perf_count;
-      int best_algo_idx = 0;
-      std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
-      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
-          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-          cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
-          perf_results.get()));
-      algo = (perf_results.get())[best_algo_idx].algo;
-
-      // get workspace size able to allocate
-      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
-          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-          cudnn_output_desc, algo, &workspace_size_in_bytes));
-
-      // NOTE(zjl): cudnnGetConvolutionForwardAlgorithm_v7 cannot limit
-      // workspace size. If the workspace size found by v7 exceeds the limit,
-      // we should fallback to non-v7 method to find another algorithm.
-      if (workspace_size_in_bytes > workspace_size_limit) {
-        VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
-                   "the workspace size request("
-                << workspace_size_in_bytes << ") exceeds the limit("
-                << workspace_size_limit << ")";
-#endif
-      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
-          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-          workspace_size_limit, &algo));
-#if CUDNN_VERSION >= 7001
-      } else {
-        has_got_workspace_size = true;
-      }
-#endif
+    using search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
+    algo = search::Find<T>(args, exhaustive_search, false, 0, ctx);
+    workspace_size = search::GetWorkspaceSize(args, algo);
-      VLOG(3) << "cuDNN forward algo " << algo;
-    } else if (exhaustive_search && (!half_float)) {
-      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
-          ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
-
-      algo = algo_cache.GetAlgorithm(
-          x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
-            int returned_algo_count;
-            std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
-                fwd_perf_stat;
-
-            auto cudnn_find_func = [&](void* cudnn_workspace) {
-              CUDNN_ENFORCE(
-                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
-                      handle, cudnn_input_desc, input_data, cudnn_filter_desc,
-                      filter_data, cudnn_conv_desc, cudnn_output_desc,
-                      output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
-                      fwd_perf_stat.data(), cudnn_workspace,
-                      workspace_size_limit));
-            };
-            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
-
-            VLOG(3) << "Perf result: (algo: stat, time, memory)";
-            for (int i = 0; i < returned_algo_count; ++i) {
-              const auto& stat = fwd_perf_stat[i];
-              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
-                      << " " << stat.memory;
-            }
-            return fwd_perf_stat[0].algo;
-          });
-      VLOG(3) << "choose algo " << algo;
-    } else {
-      PADDLE_ENFORCE(half_float,
-                     "cuDNN exhaustive search doesn't support half float.");
-    }
-
-    if (!has_got_workspace_size) {
-      // get workspace size able to allocate
-      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
-          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-          cudnn_output_desc, algo, &workspace_size_in_bytes));
-    }
-
-    // It is possible for float16 on Volta GPU to allocate more memory than
-    // the limit because the algo is overrided to use tensor core.
-    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
-                      "workspace_size to be allocated exceeds the limit");
-
-    // Allocate on GPU memory
-    Tensor cudnn_workspace =
-        ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
-            framework::make_ddim(
-                {static_cast<int64_t>(workspace_size_in_bytes)}),
-            dev_ctx);
-    void* cudnn_workspace_ptr =
-        static_cast<void*>(cudnn_workspace.data<int8_t>());
-    VLOG(2) << "Cudnn workspace size fwd: "
-            << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
-            << " MB";
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
-      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-          handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-          cudnn_filter_desc, filter_data + i * group_offset_filter,
-          cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes,
-          &beta, cudnn_output_desc, output_data + i * group_offset_out));
+      workspace_handle.RunFunc(
+          [&](void* workspace_ptr) {
+            CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+                handle, &alpha, args.idesc.desc(),
+                input_data + i * group_offset_in, args.wdesc.desc(),
+                filter_data + i * group_offset_filter, args.cdesc.desc(), algo,
+                workspace_ptr, workspace_size, &beta, args.odesc.desc(),
+                output_data + i * group_offset_out));
+          },
+          workspace_size);
     }
   }
 };
@@ -294,62 +170,30 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
-    int64_t user_workspace_size =
-        static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
     bool exhaustive_search =
         FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
-    if (exhaustive_search && FLAGS_cudnn_deterministic) {
+    bool deterministic = FLAGS_cudnn_deterministic;
+    if (exhaustive_search && deterministic) {
       PADDLE_THROW(
           "Can't set exhaustive_search True and "
          "FLAGS_cudnn_deterministic True at same time.");
     }
-    // ------------------- cudnn descriptors ---------------------
-    ScopedTensorDescriptor input_desc;
-    ScopedTensorDescriptor output_grad_desc;
-
-    ScopedFilterDescriptor filter_desc;
-    ScopedFilterDescriptor filter_grad_desc;
-    ScopedConvolutionDescriptor conv_desc;
+
+    T* filter_grad_data = nullptr;
+    T* input_grad_data = nullptr;
+    ConvArgs args1{input_grad, filter, output_grad,
+                   strides, paddings, dilations};
+    ConvArgs args2{input, filter_grad, output_grad,
+                   strides, paddings, dilations};
+    // conv_cudnn_helper.h
+    auto handle = dev_ctx.cudnn_handle();
+    auto dtype = platform::CudnnDataType<T>::type;
     DataLayout layout = DataLayout::kNCHW;
     if (input->dims().size() == 5) {
       layout = DataLayout::kNCDHW;
     }
-
-    cudnnConvolutionDescriptor_t cudnn_conv_desc =
-        conv_desc.descriptor<T>(paddings, strides, dilations);
-
-#if CUDNN_VERSION_MIN(7, 0, 1)
-    // cudnn 7 can support groups, no need to do it manually
-    // FIXME(typhoonzero): find a better way to disable groups
-    // rather than setting it to 1.
-    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
-        cudnn_conv_desc, groups));
-    groups = 1;
-#endif
-
-    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
-        layout, framework::vectorize2int(input->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_output_grad_desc =
-        output_grad_desc.descriptor<T>(
-            layout, framework::vectorize2int(output_grad->dims()), groups);
-    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
-        layout, framework::vectorize2int(filter->dims()), groups);
-
-#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
-    // Enable Tensor Core for cudnn backward
-    if (dev_ctx.GetComputeCapability() >= 70 &&
-        std::type_index(typeid(T)) ==
-            std::type_index(typeid(platform::float16))) {
-      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
-          cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
-      VLOG(5) << "use cudnn_tensor_op_math for backward";
-    } else {
-      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
-          cudnn_conv_desc, CUDNN_DEFAULT_MATH));
-      VLOG(5) << "NOT use cudnn_tensor_op_math for backward";
-    }
-#endif
+    auto layout_tensor = GetCudnnTensorFormat(layout);
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
 
     int i_n, i_c, i_d, i_h, i_w;
     GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
@@ -361,263 +205,83 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     int group_offset_out = o_c / groups * o_h * o_w * o_d;
     int group_offset_filter = filter->numel() / groups;
     // ------------------- cudnn backward algorithm ---------------------
-    cudnnConvolutionBwdDataAlgo_t data_algo{};
-    cudnnConvolutionBwdFilterAlgo_t filter_algo{};
-    size_t workspace_size_in_bytes = 0, tmp_size = 0;
-    size_t workspace_size_limit = 0;
-    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
-      int64_t max_user_size =
-          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
-                   user_workspace_size);
-      workspace_size_limit = max_user_size * 1024 * 1024;
-    }
-
-    Tensor cudnn_workspace;
-    void* cudnn_workspace_ptr = nullptr;
-    if ((input_data || filter_data) && exhaustive_search) {
-      cudnn_workspace =
-          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
-              framework::make_ddim(
-                  {static_cast<int64_t>(workspace_size_limit)}),
-              dev_ctx);
-      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
-    }
+    cudnnConvolutionBwdDataAlgo_t data_algo =
+        static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
+    cudnnConvolutionBwdFilterAlgo_t filter_algo =
+        static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
+    size_t workspace_size = 0;
+    int iwo_groups, c_groups;
 
-    // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
-    // conv_cudnn_helper.h
-    auto x_dims = framework::vectorize(input->dims());
-    auto f_dims = framework::vectorize(filter->dims());
-    auto handle = dev_ctx.cudnn_handle();
+#if CUDNN_VERSION_MIN(7, 0, 1)
+    iwo_groups = 1;
+    c_groups = groups;
+    groups = 1;
+#endif
 
-    bool has_got_bwd_data_ws_size = false;
     if (input_grad) {
-      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-      if (exhaustive_search) {
-        AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>& data_algo_cache =
-            ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>(
-                0);
-
-        data_algo = data_algo_cache.GetAlgorithm(
-            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
-              int returned_algo_count;
-              std::array<cudnnConvolutionBwdDataAlgoPerf_t,
-                         kNUM_CUDNN_BWD_DATA_ALGS>
-                  data_perf_stat;
-
-              CUDNN_ENFORCE(platform::dynload::
-                                cudnnFindConvolutionBackwardDataAlgorithmEx(
-                                    handle, cudnn_filter_desc, filter_data,
-                                    cudnn_output_grad_desc, output_grad_data,
-                                    cudnn_conv_desc, cudnn_input_desc,
-                                    input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS,
-                                    &returned_algo_count, data_perf_stat.data(),
-                                    cudnn_workspace_ptr, workspace_size_limit));
-
-              VLOG(3) << "Perf result: (algo: stat, time, memory)";
-              for (int i = 0; i < returned_algo_count; ++i) {
-                const auto& stat = data_perf_stat[i];
-                VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
-                        << " " << stat.memory;
-              }
-              return data_perf_stat[0].algo;
-            });
-        VLOG(3) << "cuDNN backward data algo " << data_algo;
-      } else if (FLAGS_cudnn_deterministic) {
-        data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
-      } else {
-#if CUDNN_VERSION >= 7001
-        using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
-        int perf_count;
-        int best_algo_idx = 0;
-        std::unique_ptr<perf_t[]> perf_results(
-            new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
-                handle, cudnn_filter_desc,
-                // dyDesc: Handle to the previously initialized input
-                // differential
-                // tensor descriptor.
-                cudnn_output_grad_desc, cudnn_conv_desc,
-                // dxDesc: Handle to the previously initialized output tensor
-                // descriptor.
-                cudnn_input_desc, kNUM_CUDNN_BWD_DATA_ALGS, &perf_count,
-                perf_results.get()));
-        data_algo = (perf_results.get())[best_algo_idx].algo;
-        int stride_dim = input->dims().size() - 2;
-        bool blacklist =
-            std::any_of(strides.begin(), strides.begin() + stride_dim,
-                        [=](int n) { return n != 1; });
-        if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
-                              perf_results[best_algo_idx].algo) ==
-                              CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
-                          static_cast<cudnnConvolutionBwdDataAlgo_t>(
-                              perf_results[best_algo_idx].algo) ==
-                              CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
-          data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
-        }
-
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-                handle, cudnn_filter_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
-        auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
-
-        if (new_workspace_size > workspace_size_limit) {
-          VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
-                     "the workspace size request("
-                  << new_workspace_size << ") exceeds the limit("
-                  << workspace_size_limit << ")";
-#endif
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
-                handle, cudnn_filter_desc,
-                // dyDesc: Handle to the previously initialized input
-                // differential
-                // tensor descriptor.
-                cudnn_output_grad_desc, cudnn_conv_desc,
-                // dxDesc: Handle to the previously initialized output tensor
-                // descriptor.
-                cudnn_input_desc,
-                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                workspace_size_limit, &data_algo));
-#if CUDNN_VERSION >= 7001
-        } else {
-          workspace_size_in_bytes = new_workspace_size;
-          has_got_bwd_data_ws_size = true;
-        }
-#endif
-      }
+      // ------------------- cudnn descriptors ---------------------
+      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
+      args1.handle = handle;
+      args1.idesc.set(*input_grad, iwo_groups);
+      args1.wdesc.set(*filter, layout_tensor, iwo_groups);
+      args1.odesc.set(*output_grad, iwo_groups);
+      args1.cdesc.set(dtype, paddings, strides, dilations, c_groups);
 
-      if (!has_got_bwd_data_ws_size) {
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
-                handle, cudnn_filter_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
-        workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
-      }
+      using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
+      data_algo =
+          search1::Find<T>(args1, exhaustive_search, deterministic, 0, ctx);
+      workspace_size =
+          std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
     }
 
-    bool has_got_bwd_filter_ws_size = false;
     if (filter_grad) {
-      T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
-      if (exhaustive_search) {
-        AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>& f_algo_cache =
-            ctx.GetKernelConfig<
-                AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>(1);
-
-        filter_algo = f_algo_cache.GetAlgorithm(
-            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
-              int returned_algo_count;
-              std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
-                         kNUM_CUDNN_BWD_FILTER_ALGS>
-                  filter_perf_stat;
-
-              CUDNN_ENFORCE(
-                  platform::dynload::
-                      cudnnFindConvolutionBackwardFilterAlgorithmEx(
-                          handle, cudnn_input_desc, input_data,
-                          cudnn_output_grad_desc, output_grad_data,
-                          cudnn_conv_desc, cudnn_filter_desc, filter_grad_data,
-                          kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
-                          filter_perf_stat.data(), cudnn_workspace_ptr,
-                          workspace_size_limit));
-              return filter_perf_stat[0].algo;
-            });
-        VLOG(3) << "cuDNN backward filter algo " << filter_algo;
-      } else if (FLAGS_cudnn_deterministic) {
-        filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
-      } else {
-#if CUDNN_VERSION >= 7001
-        using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
-        int perf_count;
-        int best_algo_idx = 0;
-        std::unique_ptr<perf_t[]> perf_results(
-            new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
-
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
-                handle, cudnn_input_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
-                &perf_count, perf_results.get()));
-        filter_algo = (perf_results.get())[best_algo_idx].algo;
-
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
-                handle, cudnn_input_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
-        auto new_workspace_size = std::max(workspace_size_in_bytes, tmp_size);
-
-        if (new_workspace_size > workspace_size_limit) {
-          VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
-                     "the workspace size request("
-                  << new_workspace_size << ") exceeds the limit("
-                  << workspace_size_limit << ")";
-#endif
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
-                handle, cudnn_input_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_filter_desc,
-                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                workspace_size_limit, &filter_algo));
-#if CUDNN_VERSION >= 7001
-        } else {
-          workspace_size_in_bytes = new_workspace_size;
-          has_got_bwd_filter_ws_size = true;
-        }
-#endif
-      }
-
-      if (!has_got_bwd_filter_ws_size) {
-        CUDNN_ENFORCE(
-            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
-                handle, cudnn_input_desc, cudnn_output_grad_desc,
-                cudnn_conv_desc, cudnn_filter_desc, filter_algo, &tmp_size));
-        workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
-      }
-    }
-
-    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
-                      "workspace_size to be allocated exceeds the limit");
-
-    // ------------------- cudnn conv workspace ---------------------
-    if (!cudnn_workspace_ptr) {
-      cudnn_workspace =
-          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
-              framework::make_ddim(
-                  {static_cast<int64_t>(workspace_size_in_bytes)}),
-              dev_ctx);
-      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
-      VLOG(2) << "Cudnn workspace size bwd: "
-              << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
-              << " MB";
+      // ------------------- cudnn descriptors ---------------------
+      filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
+      args2.handle = handle;
+      args2.idesc.set(*input, iwo_groups);
+      args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups);
+      args2.odesc.set(*output_grad, iwo_groups);
+      args2.cdesc.set(dtype, paddings, strides, dilations, c_groups);
+
+      using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
+      filter_algo =
+          search2::Find<T>(args2, exhaustive_search, deterministic, 1, ctx);
+      workspace_size = std::max(workspace_size,
+                                search2::GetWorkspaceSize(args2, filter_algo));
     }
 
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
-      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.
-
       for (int i = 0; i < groups; i++) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
-            handle, &alpha, cudnn_filter_desc,
-            filter_data + i * group_offset_filter, cudnn_output_grad_desc,
-            output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
-            cudnn_workspace_ptr, workspace_size_in_bytes, &beta,
-            cudnn_input_desc, input_grad_data + i * group_offset_in));
+        workspace_handle.RunFunc(
+            [&](void* cudnn_workspace_ptr) {
+              CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
+                  handle, &alpha, args1.wdesc.desc(),
+                  filter_data + i * group_offset_filter, args1.odesc.desc(),
+                  output_grad_data + i * group_offset_out, args1.cdesc.desc(),
+                  data_algo, cudnn_workspace_ptr, workspace_size, &beta,
+                  args1.idesc.desc(), input_grad_data + i * group_offset_in));
+            },
+            workspace_size);
       }
     }
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
-      T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset filter_grad.
       for (int i = 0; i < groups; i++) {
-        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
-            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
-            cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
-            cudnn_conv_desc, filter_algo, cudnn_workspace_ptr,
-            workspace_size_in_bytes, &beta, cudnn_filter_desc,
-            filter_grad_data + i * group_offset_filter));
+        workspace_handle.RunFunc(
+            [&](void* cudnn_workspace_ptr) {
+              CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
+                  handle, &alpha, args2.idesc.desc(),
+                  input_data + i * group_offset_in, args2.odesc.desc(),
+                  output_grad_data + i * group_offset_out, args2.cdesc.desc(),
+                  filter_algo, cudnn_workspace_ptr, workspace_size, &beta,
+                  args2.wdesc.desc(),
+                  filter_grad_data + i * group_offset_filter));
+            },
+            workspace_size);
       }
     }
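
Note: the core policy this patch moves into SearchAlgorithm::Find is that the
cuDNN 7 heuristic query (cudnnGet*Algorithm_v7) ranks algorithms but cannot
enforce a workspace limit, so the chosen algorithm's workspace requirement is
checked afterwards, and the legacy SPECIFY_WORKSPACE_LIMIT query is used as a
fallback when the request exceeds FLAGS_conv_workspace_size_limit. The sketch
below restates that policy against the raw cuDNN API, outside Paddle, for the
forward direction only. It is illustrative, not part of the patch: ChooseFwdAlgo
and its parameter names are hypothetical, it assumes cuDNN 7.x headers (the
legacy query was removed in cuDNN 8), it requests a single perf result where
Find requests kNUM_CUDNN_FWD_ALGS but only reads index 0, and error checking of
the cudnnStatus_t return values is omitted for brevity.

#include <cudnn.h>

static cudnnConvolutionFwdAlgo_t ChooseFwdAlgo(
    cudnnHandle_t handle, cudnnTensorDescriptor_t xdesc,
    cudnnFilterDescriptor_t wdesc, cudnnConvolutionDescriptor_t cdesc,
    cudnnTensorDescriptor_t ydesc, size_t workspace_limit) {
  cudnnConvolutionFwdAlgo_t algo;
#if CUDNN_VERSION >= 7001
  // The v7 heuristic returns algorithms ranked by expected performance but
  // takes no workspace limit, so the limit must be checked after the fact.
  cudnnConvolutionFwdAlgoPerf_t perf;
  int returned = 0;
  cudnnGetConvolutionForwardAlgorithm_v7(handle, xdesc, wdesc, cdesc, ydesc,
                                         /*requestedAlgoCount=*/1, &returned,
                                         &perf);
  algo = perf.algo;
  size_t needed = 0;
  cudnnGetConvolutionForwardWorkspaceSize(handle, xdesc, wdesc, cdesc, ydesc,
                                          algo, &needed);
  if (needed <= workspace_limit) {
    return algo;  // the heuristic's choice fits within the limit
  }
  // Otherwise fall through to the legacy query below, which takes the limit
  // directly and returns the fastest algorithm that fits inside it.
#endif
  cudnnGetConvolutionForwardAlgorithm(
      handle, xdesc, wdesc, cdesc, ydesc,
      CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit, &algo);
  return algo;
}

The same two-step shape is repeated in the patch for the backward-data and
backward-filter searches; there the fallback also fixes a pre-existing bug by
passing the descriptors in dyDesc/dxDesc order (odesc before idesc) to the
backward-data queries.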