From 4a8708bb3c24d375ef5e7a5a49c3e7daa1afd609 Mon Sep 17 00:00:00 2001 From: Wilber Date: Wed, 4 Jan 2023 19:27:07 +0800 Subject: [PATCH] [Inference] Add conv_fusion nhwc impl. (#49047) --- .../fluid/operators/fused/conv_fusion_op.cc | 66 +- .../kernels/fusion/gpu/conv_fusion_kernel.cu | 653 ++++++++++++++++++ paddle/phi/ops/compat/conv_fusion_sig.cc | 38 + .../tests/unittests/test_conv2d_fusion_op.py | 59 +- 4 files changed, 769 insertions(+), 47 deletions(-) create mode 100644 paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu create mode 100644 paddle/phi/ops/compat/conv_fusion_sig.cc diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cc b/paddle/fluid/operators/fused/conv_fusion_op.cc index 022c21a205d..e50b42832f1 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cc +++ b/paddle/fluid/operators/fused/conv_fusion_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/core/ddim.h" namespace paddle { namespace operators { @@ -55,6 +56,10 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker { "search_times", "The number of exhaustive search times for convolution algorithm.") .SetDefault(-1); + AddAttr( + "use_cudnn", + "(bool, default false) Only used in cudnn kernel, need install cudnn") + .SetDefault(true); } }; @@ -67,31 +72,14 @@ class Conv2DFusionOp : public operators::ConvOp { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv2DFusion"); OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "Conv2DFusion"); - auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ( - in_dims.size(), - 4U, - platform::errors::InvalidArgument( - "The input's dimension of Operator(Conv2DFusion) is expected " - "to be 4. But received: input's dimension = %u, shape = [%s].", - in_dims.size(), - in_dims)); - // In some case, attribute data_format is "AnyLayout". std::string data_format = ctx->Attrs().Get("data_format"); - PADDLE_ENFORCE_NE( - data_format, - "NDHWC", - platform::errors::PermissionDenied( - "Operator(Conv2DFusion) supports data format of " - "channel first (NCHW,NCDHW) and data format of channel last(NHWC) " - "now. But received: data_format = '%s'.", - data_format)); // MKL-DNN Kernels are using NCHW order of dims description // so we ignore data_format consideration for MKL-DNN kernel const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) && (data_format == "NHWC" || data_format == "NDHWC"); - std::vector output_shape = ComputeOutputShape(ctx); + std::vector output_shape = + ComputeOutputShape(ctx, data_format, channel_last); ctx->SetOutputDim("Output", phi::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); @@ -145,8 +133,9 @@ class Conv2DFusionOp : public operators::ConvOp { } } - std::vector ComputeOutputShape( - framework::InferShapeContext* ctx) const { + std::vector ComputeOutputShape(framework::InferShapeContext* ctx, + const std::string& data_format, + bool channel_last) const { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv"); OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv"); @@ -170,24 +159,6 @@ class Conv2DFusionOp : public operators::ConvOp { "dilation is %d.", dilations[i])); } - const std::string data_format = - ctx->Attrs().Get("data_format"); - - // if data_format is NHWC, we convert the weight dimension to the form of - // nchw to minimize program changes. 
- if (data_format == "NHWC") { - int kh = filter_dims[1]; - int kw = filter_dims[2]; - int ic = filter_dims[3]; - filter_dims[1] = ic; - filter_dims[2] = kh; - filter_dims[3] = kw; - } - - // MKL-DNN Kernels are using NCHW order of dims description - // so we ignore data_format consideration for MKL-DNN kernel - const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) && - (data_format == "NHWC" || data_format == "NDHWC"); PADDLE_ENFORCE_EQ( in_dims.size() == 4 || in_dims.size() == 5, @@ -223,7 +194,6 @@ class Conv2DFusionOp : public operators::ConvOp { strides[i])); } - int in_sub_stride_size = in_dims.size() - stride_size; PADDLE_ENFORCE_EQ( in_dims.size(), strides.size() + 2U, @@ -237,14 +207,15 @@ class Conv2DFusionOp : public operators::ConvOp { in_dims, strides.size(), phi::make_ddim(strides), - in_sub_stride_size)); + in_dims.size() - stride_size)); const auto input_channels = channel_last ? in_dims[in_dims.size() - 1] : in_dims[1]; PADDLE_ENFORCE_EQ( input_channels, - filter_dims[1] * groups, + (channel_last ? filter_dims[filter_dims.size() - 1] : filter_dims[1]) * + groups, platform::errors::InvalidArgument( "The number of input's channels should be equal to filter's " "channels " @@ -254,7 +225,7 @@ class Conv2DFusionOp : public operators::ConvOp { "The error may come from wrong data_format setting.", input_channels, in_dims, - filter_dims[1], + channel_last ? filter_dims[filter_dims.size() - 1] : filter_dims[1], filter_dims, groups, data_format)); @@ -285,8 +256,13 @@ class Conv2DFusionOp : public operators::ConvOp { in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size()); } - framework::DDim filter_data_dims = - phi::slice_ddim(filter_dims, 2, filter_dims.size()); + framework::DDim filter_data_dims; + if (channel_last) { + filter_data_dims = + phi::slice_ddim(filter_dims, 1, filter_dims.size() - 1); + } else { + filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size()); + } std::vector ksize = phi::vectorize(filter_data_dims); UpdatePaddingAndDilation( diff --git a/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu b/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu new file mode 100644 index 00000000000..7aa93911734 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu @@ -0,0 +1,653 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUDA +#include + +#include +#include +#include +#include +#include +#include + +#include "paddle/phi/backends/dynload/cudnn.h" +#include "paddle/phi/backends/gpu/cuda/cudnn_desc.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/impl/conv_cudnn_impl.h" +#include "paddle/utils/optional.h" + +namespace phi { +namespace fusion { + +namespace { +// TODO(wilber): Add a LRU strategy. 
+class CudnnConvDescManager { + public: + static CudnnConvDescManager* Instance() { + static CudnnConvDescManager global; + return &global; + } + + struct CudnnCacheInfo { + phi::backends::gpu::TensorDescriptor* x_desc{nullptr}; + phi::backends::gpu::FilterDescriptor* w_desc{nullptr}; + phi::backends::gpu::TensorDescriptor* b_desc{nullptr}; + phi::backends::gpu::TensorDescriptor* o_desc{nullptr}; + phi::backends::gpu::ConvolutionDescriptor* conv_desc{nullptr}; + phi::backends::gpu::ActivationDescriptor* act_desc{nullptr}; + size_t workspace_size; + cudnnConvolutionFwdAlgo_t algo; + + std::vector paddings; + std::vector dilations; + std::vector input_pad; + std::vector new_input_shape_vec; + bool is_sys_pad; + + // TODO(wilber): The destruction of cudnn descriptor depends on the + // phi::dynload::cudnn singleton, but when the process exits, the singleton + // destruction order cannot be determined. + // After testing, it is found that the phi::dynload::cudnn related singleton + // on Windows is destructed first, causing the descriptor to be destructed + // and failed, while the descriptor on Linux is destructed first, and the + // phi::dynload::cudnn singleton is destructed later, so that it is correct. + // To circumvent this problem, we rely entirely on freeing resources when + // the process exits. + + // ~CudnnCacheInfo() { + // if (x_desc) delete x_desc; + // if (w_desc) delete w_desc; + // if (b_desc) delete b_desc; + // if (o_desc) delete o_desc; + // if (conv_desc) delete conv_desc; + // if (act_desc) delete act_desc; + // } + }; + + CudnnCacheInfo* GetCudnnCacheInfo( + const std::vector& input_dims, + const std::vector& filter_dims, + const std::vector& bias_dims, + const std::vector& output_dims, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + phi::DataType input_dtype, + int groups, + cudnnDataType_t dtype, + cudnnTensorFormat_t format, + const std::function& search_func, + const std::string& act, + double value_max = std::numeric_limits::max()) { + // std::hash takes about 5us, xxhash can optimize to 2.5us. 
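+    // The cache key is an XXH64 digest of the input/filter/bias dims, the
+    // paddings, strides and dilations, the input dtype, group count, cuDNN
+    // compute dtype, tensor format and activation string; the output dims
+    // and the clipped-relu ceiling are intentionally excluded (see the
+    // commented-out updates below).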
+ XXH64_state_t* const state = XXH64_createState(); + if (state == nullptr) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "xxhash create state failed, maybe a environment error.")); + } + XXH64_hash_t const seed = 0; + if (XXH64_reset(state, seed) == XXH_ERROR) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "xxhash reset state failed, maybe a environment error.")); + } + XXH64_update(state, input_dims.data(), input_dims.size() * sizeof(int)); + XXH64_update(state, filter_dims.data(), filter_dims.size() * sizeof(int)); + XXH64_update(state, bias_dims.data(), bias_dims.size() * sizeof(int)); + // XXH64_update(state, output_dims.data(), output_dims.size() * + // sizeof(int)); + XXH64_update(state, paddings.data(), paddings.size() * sizeof(int)); + XXH64_update(state, strides.data(), strides.size() * sizeof(int)); + XXH64_update(state, dilations.data(), dilations.size() * sizeof(int)); + XXH64_update(state, &input_dtype, sizeof(int)); + XXH64_update(state, &groups, sizeof(int)); + XXH64_update(state, &dtype, sizeof(int)); + XXH64_update(state, &format, sizeof(int)); + XXH64_update(state, act.data(), act.length() * sizeof(char)); + // XXH64_update(state, &value_max, sizeof(double)); + XXH64_hash_t hash_key = XXH64_digest(state); + XXH64_freeState(state); + + if (!cudnn_conv_cache_.count(hash_key)) { + std::lock_guard lock(cache_mutex_); + if (!cudnn_conv_cache_.count(hash_key)) { + cudnn_conv_cache_[hash_key] = CudnnCacheInfo(); + cudnn_conv_cache_[hash_key].x_desc = + GetTensorDescInfo(input_dims, input_dtype, format); + cudnn_conv_cache_[hash_key].w_desc = + GetFilterDescInfo(filter_dims, input_dtype, format); + cudnn_conv_cache_[hash_key].o_desc = + GetTensorDescInfo(output_dims, input_dtype, format); + cudnn_conv_cache_[hash_key].b_desc = + GetTensorDescInfo(bias_dims, input_dtype, format); + cudnn_conv_cache_[hash_key].conv_desc = + GetConvDescInfo(paddings, strides, dilations, groups, dtype); + cudnn_conv_cache_[hash_key].act_desc = + GetActivationDescInfo(act, value_max); + + size_t workspace_size; + cudnnConvolutionFwdAlgo_t algo; + search_func(&algo, + &workspace_size, + cudnn_conv_cache_[hash_key].x_desc->desc(), + cudnn_conv_cache_[hash_key].w_desc->desc(), + cudnn_conv_cache_[hash_key].o_desc->desc(), + cudnn_conv_cache_[hash_key].conv_desc->desc()); + cudnn_conv_cache_[hash_key].workspace_size = workspace_size; + cudnn_conv_cache_[hash_key].algo = algo; + } + } + + return &cudnn_conv_cache_.at(hash_key); + } + + struct ConvAttrCacheInfo { + std::vector paddings; + std::vector dilations; + std::vector input_pad; + std::vector new_input_shape_vec; + bool is_sys_pad; + }; + ConvAttrCacheInfo* GetConvAttr(const std::vector& paddings_t, + const std::vector& dilations_t, + const std::string& padding_algorithm, + const std::vector& input_dims, + const std::vector& filter_dims, + const std::vector& strides, + cudnnTensorFormat_t format) { + XXH64_state_t* const state = XXH64_createState(); + if (state == nullptr) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "xxhash create state failed, maybe a environment error.")); + } + XXH64_hash_t const seed = 0; + if (XXH64_reset(state, seed) == XXH_ERROR) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "xxhash create state failed, maybe a environment error.")); + } + XXH64_update(state, paddings_t.data(), paddings_t.size() * sizeof(int)); + XXH64_update(state, dilations_t.data(), dilations_t.size() * sizeof(int)); + XXH64_update(state, input_dims.data(), input_dims.size() * sizeof(int)); + XXH64_update(state, filter_dims.data(), 
filter_dims.size() * sizeof(int)); + XXH64_update(state, strides.data(), strides.size() * sizeof(int)); + XXH64_update(state, &format, sizeof(int)); + XXH64_update(state, + padding_algorithm.data(), + padding_algorithm.length() * sizeof(char)); + XXH64_hash_t hash_key = XXH64_digest(state); + XXH64_freeState(state); + + if (!conv_attr_cache_.count(hash_key)) { + std::lock_guard lock(attr_mutex_); + if (!conv_attr_cache_.count(hash_key)) { + ConvAttrCacheInfo cache; + auto paddings = paddings_t; + auto dilations = dilations_t; + std::vector in_data_dims(input_dims.size() - 2); + std::vector ksize(filter_dims.size() - 2); + if (format == CUDNN_TENSOR_NHWC) { + for (size_t i = 1; i < input_dims.size() - 1; ++i) { + in_data_dims[i - 1] = input_dims[i]; + } + for (size_t i = 1; i < filter_dims.size() - 1; ++i) { + ksize[i - 1] = filter_dims[i]; + } + } else { + for (size_t i = 2; i < input_dims.size(); ++i) { + in_data_dims[i - 2] = input_dims[i]; + } + for (size_t i = 2; i < filter_dims.size(); ++i) { + ksize[i - 2] = filter_dims[i]; + } + } + phi::UpdatePaddingAndDilation(&paddings, + &dilations, + padding_algorithm, + make_ddim(in_data_dims), + strides, + ksize); + + int data_dim = strides.size(); // 2d or 3d + bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim); + std::vector padding_common(data_dim, 0); + if (!is_sys_pad) { + std::vector padding_diff(data_dim); + std::vector new_input_shape_vec(data_dim + 2); + new_input_shape_vec[0] = input_dims[0]; + + if (format == CUDNN_TENSOR_NCHW) { + new_input_shape_vec[1] = input_dims[1]; + } else { + new_input_shape_vec[data_dim + 1] = input_dims[data_dim + 1]; + } + + std::vector input_pad(input_dims.size() * 2, 0); + for (size_t i = 0; i < data_dim; ++i) { + padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); + padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); + if (format == CUDNN_TENSOR_NCHW) { + new_input_shape_vec[i + 2] = input_dims[i + 2] + padding_diff[i]; + } else { + new_input_shape_vec[i + 1] = input_dims[i + 1] + padding_diff[i]; + } + if (format == CUDNN_TENSOR_NCHW) { + input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 4 + 1] = + paddings[2 * i + 1] - padding_common[i]; + } else { + input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i]; + input_pad[2 * i + 2 + 1] = + paddings[2 * i + 1] - padding_common[i]; + } + } + + cache.is_sys_pad = false; + cache.input_pad = input_pad; + cache.new_input_shape_vec = new_input_shape_vec; + } else { + cache.is_sys_pad = true; + if (paddings.size() == data_dim) { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[i]; + } + } else { + for (size_t i = 0; i < data_dim; ++i) { + padding_common[i] = paddings[2 * i]; + } + } + } + + cache.dilations = dilations; + cache.paddings = padding_common; + conv_attr_cache_[hash_key] = cache; + } + } + + return &conv_attr_cache_.at(hash_key); + } + + private: + phi::backends::gpu::TensorDescriptor* GetTensorDescInfo( + const std::vector& input_dims, + phi::DataType input_dtype, + cudnnTensorFormat_t input_format) { + auto* desc = new phi::backends::gpu::TensorDescriptor(); + desc->set( + input_dims, input_format, backends::gpu::ToCudnnDataType(input_dtype)); + return desc; + } + + phi::backends::gpu::FilterDescriptor* GetFilterDescInfo( + const std::vector& input_dims, + phi::DataType input_dtype, + cudnnTensorFormat_t input_format) { + auto* desc = new phi::backends::gpu::FilterDescriptor(); + desc->set( + input_dims, input_format, 
backends::gpu::ToCudnnDataType(input_dtype));
+    return desc;
+  }
+
+  phi::backends::gpu::ConvolutionDescriptor* GetConvDescInfo(
+      const std::vector<int>& paddings,
+      const std::vector<int>& strides,
+      const std::vector<int>& dilations,
+      int groups,
+      cudnnDataType_t dtype) {
+    auto* desc = new phi::backends::gpu::ConvolutionDescriptor();
+    desc->set(dtype,
+              paddings,
+              strides,
+              dilations,
+              paddle::platform::AllowTF32Cudnn(),
+              groups);
+    return desc;
+  }
+
+  phi::backends::gpu::ActivationDescriptor* GetActivationDescInfo(
+      const std::string& act,
+      double value_max = std::numeric_limits<double>::max()) {
+    auto* desc = new phi::backends::gpu::ActivationDescriptor();
+    cudnnActivationMode_t mode;
+    double relu_ceiling = 0.0;
+    if (act == "identity") {
+      mode = CUDNN_ACTIVATION_IDENTITY;
+    } else if (act == "relu") {
+      mode = CUDNN_ACTIVATION_RELU;
+    } else if (act == "relu6") {
+      relu_ceiling = 6.0;
+      mode = CUDNN_ACTIVATION_CLIPPED_RELU;
+    } else if (act == "sigmoid") {
+      mode = CUDNN_ACTIVATION_SIGMOID;
+    } else if (act == "relux") {
+      relu_ceiling = value_max;
+      mode = CUDNN_ACTIVATION_CLIPPED_RELU;
+    } else if (act == "tanh") {
+      mode = CUDNN_ACTIVATION_TANH;
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Unknown CUDNN activation string: %s.", act));
+    }
+    desc->set(mode, relu_ceiling);
+    return desc;
+  }
+
+  std::mutex cache_mutex_;
+  std::unordered_map<size_t, CudnnCacheInfo> cudnn_conv_cache_;
+
+  std::mutex attr_mutex_;
+  std::unordered_map<size_t, ConvAttrCacheInfo> conv_attr_cache_;
+};
+}  // namespace
+
+template <typename T, typename Context>
+void ConvFusionKernel(const Context& ctx,
+                      const DenseTensor& input,
+                      const DenseTensor& filter,
+                      const DenseTensor& bias,
+                      const paddle::optional<DenseTensor>& residual,
+                      const std::vector<int>& strides,
+                      const std::vector<int>& paddings_t,
+                      const std::string& padding_algorithm,
+                      const std::vector<int>& dilations_t,
+                      int groups,
+                      const std::string& data_format,
+                      const std::string& activation,
+                      bool exhaustive_search,
+                      const std::vector<int>& channels,
+                      int user_workspace_size,
+                      DenseTensor* output,
+                      std::vector<DenseTensor*> outs) {
+  auto handle = ctx.cudnn_handle();
+  ctx.template Alloc<T>(output);
+  auto workspace_handle = ctx.cudnn_workspace_handle();
+
+  exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search;
+  bool deterministic = FLAGS_cudnn_deterministic;
+  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
+                    false,
+                    phi::errors::InvalidArgument(
+                        "Can't set exhaustive_search True and "
+                        "FLAGS_cudnn_deterministic True at same time."));
+
+  size_t workspace_size_limit = 0;
+  if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
+    int64_t max_user_size =
+        std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
+                 static_cast<int64_t>(user_workspace_size));
+    workspace_size_limit = max_user_size * 1024 * 1024;
+  }
+
+  auto dtype = phi::backends::gpu::CudnnDataType<T>::type;
+  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
+  // Choose NHWC or NCHW by data_format attr.
+  auto compute_format = channel_last ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
+  VLOG(3) << "Compute ConvFusionOp with cuDNN:"
+          << " data_format=" << data_format << " compute_format="
+          << (compute_format == CUDNN_TENSOR_NHWC ? 
"NHWC" : "NCHW"); + + auto* conv_attr_cache = CudnnConvDescManager::Instance()->GetConvAttr( + paddings_t, + dilations_t, + padding_algorithm, + phi::vectorize(input.dims()), + phi::vectorize(filter.dims()), + strides, + compute_format); + + DenseTensor transformed_input; + auto unsys_pad_process = [&](const std::vector& new_input_shape_vec, + const std::vector& input_pad) { + DDim new_input_shape(make_ddim(new_input_shape_vec)); + transformed_input.Resize(new_input_shape); + ctx.template Alloc(&transformed_input); + + const int rank = input.dims().size(); + T pad_value(0.0); + switch (rank) { + case 4: { + funcs::PadFunction( + ctx, input_pad, input, pad_value, &transformed_input); + } break; + case 5: { + funcs::PadFunction( + ctx, input_pad, input, pad_value, &transformed_input); + } break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "ConvOp only support tensors with 4 or 5 dimensions.")); + } + }; + if (conv_attr_cache->is_sys_pad) { + transformed_input.ShareDataWith(input); + } else { + unsys_pad_process(conv_attr_cache->new_input_shape_vec, + conv_attr_cache->input_pad); + } + + std::vector b_dims(input.dims().size(), 1); + if (compute_format == CUDNN_TENSOR_NCHW) { + b_dims[1] = static_cast(bias.dims()[0]); + } else { + b_dims[input.dims().size() - 1] = static_cast(bias.dims()[0]); + } + + auto search_func = [&](cudnnConvolutionFwdAlgo_t* cudnn_algo, + size_t* wks_bytes, + cudnnTensorDescriptor_t x_desc, + cudnnFilterDescriptor_t w_desc, + cudnnTensorDescriptor_t o_desc, + cudnnConvolutionDescriptor_t cudnn_conv_desc) { + if (!exhaustive_search) { +#if CUDNN_VERSION >= 8000 + int perf_count; + int best_algo_idx = 0; + size_t tmp_size = 0; + std::unique_ptr perf_results( + new cudnnConvolutionFwdAlgoPerf_t[phi::kNUM_CUDNN_FWD_ALGS]); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetConvolutionForwardAlgorithm_v7( + handle, + x_desc, + w_desc, + cudnn_conv_desc, + o_desc, + phi::kNUM_CUDNN_FWD_ALGS, + &perf_count, + perf_results.get())); + *cudnn_algo = (perf_results.get())[best_algo_idx].algo; +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetConvolutionForwardAlgorithm( + handle, + x_desc, + w_desc, + cudnn_conv_desc, + o_desc, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, + cudnn_algo)); +#endif + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetConvolutionForwardWorkspaceSize(handle, + x_desc, + w_desc, + cudnn_conv_desc, + o_desc, + *cudnn_algo, + wks_bytes)); + } else { + std::array + fwd_perf_stat; + int returned_algo_count; + auto cudnn_find_func = [&](void* cudnn_workspace) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, + x_desc, + transformed_input.data(), + w_desc, + filter.data(), + cudnn_conv_desc, + o_desc, + output->data(), + phi::kNUM_CUDNN_FWD_ALGS, + &returned_algo_count, + fwd_perf_stat.data(), + cudnn_workspace, + workspace_size_limit)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); + *cudnn_algo = fwd_perf_stat[0].algo; + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetConvolutionForwardWorkspaceSize( + handle, + x_desc, + w_desc, + cudnn_conv_desc, + o_desc, + fwd_perf_stat[0].algo, + wks_bytes)); + } + }; + + auto cudnn_cache_info = CudnnConvDescManager::Instance()->GetCudnnCacheInfo( + phi::vectorize(transformed_input.dims()), + phi::vectorize(filter.dims()), + b_dims, + phi::vectorize(output->dims()), + conv_attr_cache->paddings, + strides, + conv_attr_cache->dilations, + transformed_input.dtype(), + groups, + 
phi::backends::gpu::CudnnDataType<T>::type,
+      compute_format,
+      search_func,
+      activation);
+
+  auto x_desc = cudnn_cache_info->x_desc->desc();
+  auto w_desc = cudnn_cache_info->w_desc->desc();
+  auto b_desc = cudnn_cache_info->b_desc->desc();
+  auto o_desc = cudnn_cache_info->o_desc->desc();
+  auto cudnn_conv_desc = cudnn_cache_info->conv_desc->desc();
+  auto act_desc = cudnn_cache_info->act_desc->desc();
+  auto algo = cudnn_cache_info->algo;
+  auto workspace_size = cudnn_cache_info->workspace_size;
+
+  if ((activation == "identity") && (!residual.get_ptr())) {
+    // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
+    // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
+    // But test in some case, the speed is slower, change to use
+    // cudnnConvolutionForward and cudnnAddTensor
+    // ------------- cudnn conv forward and bias add ---------------------
+    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
+    auto cudnn_func = [&](void* cudnn_workspace) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          phi::dynload::cudnnConvolutionForward(handle,
+                                                &alpha,
+                                                x_desc,
+                                                transformed_input.data<T>(),
+                                                w_desc,
+                                                filter.data<T>(),
+                                                cudnn_conv_desc,
+                                                algo,
+                                                cudnn_workspace,
+                                                workspace_size,
+                                                &beta,
+                                                o_desc,
+                                                output->data<T>()));
+    };
+    workspace_handle.RunFunc(cudnn_func, workspace_size);
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnAddTensor(
+        handle, &alpha, b_desc, bias.data<T>(), &alpha, o_desc,
+        output->data<T>()));
+  } else {
+    // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
+    // enabled with CUDNN_ACTIVATION_IDENTITY.
+    if (activation == "identity") {
+      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+    }
+
+    ScalingParamType<T> alpha = 1.0f;
+    ScalingParamType<T> beta = residual.get_ptr() ? 1.0f : 0.0f;
+    auto cudnn_func = [&](void* cudnn_workspace) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          phi::dynload::cudnnConvolutionBiasActivationForward(
+              handle,
+              &alpha,
+              x_desc,
+              transformed_input.data<T>(),
+              w_desc,
+              filter.data<T>(),
+              cudnn_conv_desc,
+              algo,
+              cudnn_workspace,
+              workspace_size,
+              &beta,
+              o_desc,
+              residual.get_ptr() ? residual->data<T>() : output->data<T>(),
+              b_desc,
+              bias.data<T>(),
+              act_desc,
+              o_desc,
+              output->data<T>()));
+    };
+    workspace_handle.RunFunc(cudnn_func, workspace_size);
+  }
+
+  if (!channels.empty()) {
+    if (transformed_input.dims()[0] == 1 &&
+        compute_format == CUDNN_TENSOR_NCHW) {
+      // share data with Output
+      phi::DenseTensor t;
+      t.ShareDataWith(*output);
+      auto y_dims = output->dims();
+      t.Resize({y_dims[1], y_dims[2], y_dims[3]});
+      int s = 0;
+      for (size_t i = 0; i < channels.size(); ++i) {
+        int e = s + channels[i];
+        outs[i]->ShareDataWith(t.Slice(s, e));
+        outs[i]->Resize(
+            {transformed_input.dims()[0], channels[i], y_dims[2], y_dims[3]});
+        s = e;
+      }
+    } else {
+      // TODO(qingiqng): do copy when batch size large than 1
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Input with batch size greater than 1 is unsupported. The received "
+          "batch size is %d, Input's shape is [%s].",
+          transformed_input.dims()[0],
+          transformed_input.dims()));
+    }
+  }
+}
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(conv2d_fusion,  // cuda_only
+                   GPUDNN,
+                   ALL_LAYOUT,
+                   phi::fusion::ConvFusionKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/ops/compat/conv_fusion_sig.cc b/paddle/phi/ops/compat/conv_fusion_sig.cc
new file mode 100644
index 00000000000..4cadfe87f53
--- /dev/null
+++ b/paddle/phi/ops/compat/conv_fusion_sig.cc
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ConvFusionOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("conv2d_fusion", + {"Input", "Filter", "Bias", "ResidualData"}, + { + "strides", + "paddings", + "padding_algorithm", + "dilations", + "groups", + "data_format", + "activation", + "exhaustive_search", + "split_channels", + "workspace_size_MB", + }, + {"Output", "Outputs"}); +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(conv2d_fusion, phi::ConvFusionOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py index 3e5ec0f8a97..6cc922204f3 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_fusion_op.py @@ -43,6 +43,30 @@ def create_test_padding_VALID_class(parent): globals()[cls_name] = TestPaddingVALIDCase +def create_test_cudnn_channel_last_class(parent): + @unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + ) + class TestCudnnChannelLastCase(parent): + def init_test_case(self): + super().init_test_case() + self.data_format = "NHWC" + N, C, H, W = self.input_size + self.input_size = [N, H, W, C] + K1, K2, R, S = self.filter_size + self.filter_size = [K1, R, S, K2] + + def test_check_output(self): + print(self.attrs) + if self.has_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-5) + + cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLast") + TestCudnnChannelLastCase.__name__ = cls_name + globals()[cls_name] = TestCudnnChannelLastCase + + class TestConv2DFusionOp(OpTest): def setUp(self): self.op_type = "conv2d_fusion" @@ -73,9 +97,14 @@ class TestConv2DFusionOp(OpTest): filter = np.random.random(self.filter_size).astype(self.dtype) bias = np.random.random(self.filter_size[0]).astype(self.dtype) + if self.data_format == "NHWC": + filter_nchw = np.transpose(filter, [0, 3, 1, 2]) + else: + filter_nchw = filter + self.output, _, _, _, _ = conv2d_forward_naive( input, - filter, + filter_nchw, self.groups, conv2d_param, self.padding_algorithm, @@ -100,7 +129,10 @@ class TestConv2DFusionOp(OpTest): self.output += residual_data # Add bias - self.output = self.output + bias.reshape((1, bias.size, 1, 1)) + if self.data_format == "NCHW": + self.output = self.output + bias.reshape((1, bias.size, 1, 1)) + else: + self.output = self.output + bias.reshape((1, 1, 1, bias.size)) assert self.activation in ['relu', 'identity'] if self.activation == 'relu': @@ -359,6 +391,23 @@ class TestWithInput1x1Filter1x1_AsyPadding(TestConv2DFusionOp): self.padding_algorithm = "EXPLICIT" +class TestSimpleNHWC(TestConv2DFusionOp): + def init_test_case(self): + self.stride = [1, 1] + self.input_size = [3, 5, 5, 2] # NHWC + self.data_format = "NHWC" + assert np.mod(self.input_size[3], self.groups) == 0 + f_c = self.input_size[3] // 
self.groups + self.filter_size = [4, 3, 3, f_c] + + def init_group(self): + self.groups = 1 + + def init_paddings(self): + self.pad = [1, 1] + self.padding_algorithm = "EXPLICIT" + + create_test_padding_SAME_class(TestAsyPadding) create_test_padding_SAME_class(TestWithPad_AsyPadding) create_test_padding_SAME_class(TestWithStride_AsyPadding) @@ -371,5 +420,11 @@ create_test_padding_VALID_class(TestWithStride_AsyPadding) create_test_padding_VALID_class(TestWithGroup_AsyPadding) create_test_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding) +create_test_cudnn_channel_last_class(TestAsyPadding) +create_test_cudnn_channel_last_class(TestWithPad_AsyPadding) +create_test_cudnn_channel_last_class(TestWithStride_AsyPadding) +create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding) +create_test_cudnn_channel_last_class(TestWithInput1x1Filter1x1_AsyPadding) + if __name__ == '__main__': unittest.main() -- GitLab