/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <array>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_helper.h"

// The kernel takes the max of this flag and the op's `search_times`
// attribute. A positive result selects the op's own "AlgoCache" input;
// otherwise the shared kCUDNNFwdAlgoCache variable is used.
DEFINE_int64(cudnn_exhaustive_search_times, -1,
             "Exhaustive search times for cuDNN convolution, "
             "default is -1, which defers to the op's search_times "
             "attribute.");

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7100
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

/**
 * CUDA kernel for the `conv2d_fusion` op.
 *
 * Runs a cuDNN convolution followed by a bias add and, unless the activation
 * is "identity" with no residual, a fused residual-add + activation through
 * cudnnConvolutionBiasActivationForward.
 *
 * Context inputs : "Input", "Filter", "Bias" (required), "ResidualData"
 *                  (may be null) and "AlgoCache" (only read when the
 *                  effective search_times is positive).
 * Context outputs: "Output", plus "Outputs" when the `split_channels`
 *                  attribute is non-empty (batch size 1 only).
 */
template <typename T>
class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* bias = ctx.Input<Tensor>("Bias");
    PADDLE_ENFORCE(bias, "The bias should not be null.");
    auto* residual = ctx.Input<Tensor>("ResidualData");
    auto* output = ctx.Output<Tensor>("Output");

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    const std::string activation = ctx.Attr<std::string>("activation");
    int groups = ctx.Attr<int>("groups");
    int64_t user_workspace_size =
        static_cast<int64_t>(ctx.Attr<int>("workspace_size_MB"));
    bool exhaustive_search =
        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");

    // With an identity activation and no residual, the fused cuDNN call is
    // replaced by cudnnConvolutionForward + cudnnAddTensor (see below).
    const bool use_fused_kernel = !((activation == "identity") && (!residual));

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    const T* bias_data = bias->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // When there is no residual the output buffer is passed in its place; it
    // is then scaled by alpha2 == 0 in the fused call.
    const T* residual_data = residual ? residual->data<T>() : output_data;

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedFilterDescriptor filter_desc;
    ScopedTensorDescriptor bias_desc;
    ScopedConvolutionDescriptor conv_desc;
    ScopedActivationDescriptor act_desc;
    DataLayout layout = DataLayout::kNCHW;
    if (input->dims().size() == 5) {
      layout = DataLayout::kNCDHW;
    }

    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(paddings, strides, dilations);
    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
        cudnn_conv_desc, groups));

    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output->dims()));
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize2int(filter->dims()));
    // Now only support NCHW
    std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1};
    cudnnTensorDescriptor_t cudnn_bias_desc =
        bias_desc.descriptor<T>(layout, bias_dim);
    cudnnActivationDescriptor_t cudnn_act_desc =
        act_desc.descriptor<T>(activation);

    // ------------------- cudnn conv workspace ---------------------
    size_t workspace_size_in_bytes;  // final workspace to allocate.
    // Default limit, overridden by the larger of the global flag and the op
    // attribute (both expressed in MB) when either one is positive.
    size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }

    // ------------------- cudnn conv algorithm ---------------------
    cudnnConvolutionFwdAlgo_t algo;
    auto handle = dev_ctx.cudnn_handle();

    Tensor cudnn_workspace;
    void* cudnn_workspace_ptr = nullptr;

    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
        cudnn_conv_desc, CUDNN_DEFAULT_MATH));

    auto x_dims = framework::vectorize(input->dims());
    auto f_dims = framework::vectorize(filter->dims());
    if (!exhaustive_search) {
      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
          workspace_size_limit, &algo));
      VLOG(3) << "cuDNN forward algo " << algo;
    } else {
      // The search itself needs scratch space, so the full limit is allocated
      // up front and reused for the real forward call afterwards.
      cudnn_workspace =
          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
              framework::make_ddim(
                  {static_cast<int64_t>(workspace_size_limit)}),
              dev_ctx);
      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());

      // Benchmarks the forward algorithms and returns the first entry of the
      // result list reported by cuDNN.
      auto search_func = [&]() {
        int returned_algo_count = 0;
        std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
            fwd_perf_stat;
        CUDNN_ENFORCE(platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
            handle, cudnn_input_desc, input_data, cudnn_filter_desc,
            filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
            kNUM_CUDNN_FWD_ALGS, &returned_algo_count, fwd_perf_stat.data(),
            cudnn_workspace_ptr, workspace_size_limit));
        // Guard against reading an uninitialized entry below.
        PADDLE_ENFORCE_GT(returned_algo_count, 0,
                          "cuDNN returned no forward algorithm candidate.");
        VLOG(3) << "Perf result: (algo: stat, time, memory)";
        for (int i = 0; i < returned_algo_count; ++i) {
          const auto& stat = fwd_perf_stat[i];
          VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                  << " " << stat.memory;
        }
        return fwd_perf_stat[0].algo;
      };

      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
      int search_times = ctx.Attr<int>("search_times");
      search_times = std::max(
          static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
      if (search_times > 0) {
        // The searched algo will be cached by `search_times` times for
        // different input dimension. For other dimensions, select the algo
        // of closest area.
        const auto& algo_cache_names = ctx.Inputs("AlgoCache");
        PADDLE_ENFORCE(!algo_cache_names.empty(),
                       "Input(AlgoCache) should be set when search_times > 0.");
        auto* algo_cache_var = ctx.scope().FindVar(algo_cache_names[0]);
        PADDLE_ENFORCE(algo_cache_var,
                       "The variable of Input(AlgoCache) is not found.");
        algo_cache =
            algo_cache_var
                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
        algo = algo_cache->GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0,
                                        search_func);
      } else {
        // Cache searched algo in Var(kCUDNNFwdAlgoCache).
        // all conv ops use the same kCUDNNFwdAlgoCache variable.
        auto* shared_cache_var = ctx.scope().FindVar(kCUDNNFwdAlgoCache);
        if (shared_cache_var == nullptr) {
          // TODO(qingqing) remove const_cast
          shared_cache_var =
              const_cast<framework::Scope*>(ctx.scope().parent())
                  ->Var(kCUDNNFwdAlgoCache);
        }
        algo_cache =
            shared_cache_var
                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
        algo = algo_cache->GetAlgorithm(x_dims, f_dims, strides, paddings,
                                        dilations, 0, search_func);
      }
      VLOG(3) << "choose algo " << algo;
    }

    if (use_fused_kernel && activation == "identity") {
      // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
      // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. The override is
      // applied before the workspace query so that the workspace is sized
      // for the algorithm that actually runs.
      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
    }

    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
        cudnn_output_desc, algo, &workspace_size_in_bytes));
    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                      "workspace_size to be allocated exceeds the limit");

    // Non-exhaustive path: nothing has been allocated yet, so allocate
    // exactly what the chosen algorithm asks for.
    if (!cudnn_workspace_ptr) {
      cudnn_workspace =
          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
              framework::make_ddim(
                  {static_cast<int64_t>(workspace_size_in_bytes)}),
              dev_ctx);
      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
    }

    if (!use_fused_kernel) {
      // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
      // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
      // But test in some case, the speed is slower, change to use
      // cudnnConvolutionForward and cudnnAddTensor
      // ------------- cudnn conv forward and bias add ---------------------
      ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
          handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
          filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
          workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
      CUDNN_ENFORCE(platform::dynload::cudnnAddTensor(
          handle, &alpha, cudnn_bias_desc, bias_data, &alpha,
          cudnn_output_desc, output_data));
    } else {
      // ------------------- cudnn conv+bias+act forward --------------------
      ScalingParamType<T> alpha1 = 1.0f;
      // alpha2 scales the residual term; 0 disables it when there is none.
      ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
          handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
          filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
          workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
          cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
          output_data));
    }

    // Optionally expose channel slices of Output as separate tensors.
    std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
    if (!channels.empty()) {
      auto outs = ctx.MultiOutput<framework::Tensor>("Outputs");
      PADDLE_ENFORCE_EQ(outs.size(), channels.size(),
                        "The number of Outputs should be equal to the size "
                        "of split_channels.");
      if (x_dims[0] == 1) {
        // share data with Output
        framework::Tensor t;
        t.ShareDataWith(*output);
        auto y_dims = output->dims();
        // Drop the unit batch dimension so Slice() cuts along channels.
        t.Resize({y_dims[1], y_dims[2], y_dims[3]});
        int s = 0;
        for (size_t i = 0; i < channels.size(); ++i) {
          int e = s + channels[i];
          outs[i]->ShareDataWith(t.Slice(s, e));
          outs[i]->Resize({x_dims[0], channels[i], y_dims[2], y_dims[3]});
          s = e;
        }
      } else {
        // TODO(qingiqng): do copy when batch size large than 1
        PADDLE_THROW("Batch size greater than 1 is Unsupported");
      }
    }
  }
};
#endif

}  // namespace operators
}  // namespace paddle

#if CUDNN_VERSION >= 7100
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
                        ops::CUDNNConvFusionOpKernel<double>);
#endif