From 03dbdbd1bc2300cd6bf2661e0e177553b964c12b Mon Sep 17 00:00:00 2001
From: Zhang Jun
Date: Mon, 12 Jun 2023 10:28:06 +0800
Subject: [PATCH] [inference]conv_fusion support bias's rank equal to input's rank (#54477)

* support bias's rank equal to input's rank
---
 .../phi/kernels/fusion/gpu/conv_fusion_kernel.cu | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu b/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu
index da71c0bf7d3..5a8d2769e66 100644
--- a/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/conv_fusion_kernel.cu
@@ -413,15 +413,15 @@ void ConvFusionKernel(const Context& ctx,
                        compute_format);
 
   DenseTensor transformed_input;
+  const int input_rank = input.dims().size();
   auto unsys_pad_process = [&](const std::vector<int>& new_input_shape_vec,
                                const std::vector<int>& input_pad) {
     DDim new_input_shape(make_ddim(new_input_shape_vec));
     transformed_input.Resize(new_input_shape);
     ctx.template Alloc<T>(&transformed_input);
 
-    const int rank = input.dims().size();
     T pad_value(0.0);
-    switch (rank) {
+    switch (input_rank) {
       case 4: {
         funcs::PadFunction<Context, T, 4>(
             ctx, input_pad, input, pad_value, &transformed_input);
@@ -442,11 +442,16 @@ void ConvFusionKernel(const Context& ctx,
                       conv_attr_cache->input_pad);
   }
 
-  std::vector<int> b_dims(input.dims().size(), 1);
+  std::vector<int> b_dims(input_rank, 1);
   if (compute_format == CUDNN_TENSOR_NCHW) {
-    b_dims[1] = static_cast<int>(bias.dims()[0]);
+    auto bias_rank = bias.dims().size();
+    if (input_rank == bias_rank) {
+      b_dims[1] = static_cast<int>(bias.dims()[1]);
+    } else {
+      b_dims[1] = static_cast<int>(bias.dims()[0]);
+    }
   } else {
-    b_dims[input.dims().size() - 1] = static_cast<int>(bias.dims()[0]);
+    b_dims[input_rank - 1] = static_cast<int>(bias.dims()[0]);
   }
 
   auto search_func = [&](cudnnConvolutionFwdAlgo_t* cudnn_algo,
-- 
GitLab