From 7783d3bd43b144d7bca932a6aa8c1fafc150b91e Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 16 Oct 2019 17:40:14 +0800 Subject: [PATCH] Conv refine (#20644) * add condition judgement for performance improvement test=develop * add condition judgement for performance improvement test=develop * refine code style test=develop --- paddle/fluid/operators/conv_cudnn_op.cu | 60 +++++++++++++------------ 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu index d34d152bc7..274da9abf0 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu +++ b/paddle/fluid/operators/conv_cudnn_op.cu @@ -540,23 +540,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { workspace_size); } - std::vector starts(transformed_input_channel.dims().size(), 0); - std::vector axes(transformed_input_channel.dims().size(), 0); + if (!is_sys_pad) { + std::vector starts(transformed_input_channel.dims().size(), 0); + std::vector axes(transformed_input_channel.dims().size(), 0); - for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { - starts[i] = input_pad[2 * i]; - axes[i] = i; - } + for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) { + starts[i] = input_pad[2 * i]; + axes[i] = i; + } - transformed_input_grad_channel.mutable_data(ctx.GetPlace()); - if (transformed_input_channel.dims().size() == 4) { - Slice_2( - ctx, &transformed_input_grad, &transformed_input_grad_channel, - starts, axes); - } else { - Slice_2( - ctx, &transformed_input_grad, &transformed_input_grad_channel, - starts, axes); + transformed_input_grad_channel.mutable_data(ctx.GetPlace()); + if (transformed_input_channel.dims().size() == 4) { + Slice_2( + ctx, &transformed_input_grad, &transformed_input_grad_channel, + starts, axes); + } else { + Slice_2( + ctx, &transformed_input_grad, &transformed_input_grad_channel, + starts, axes); + } } if (channel_last) { @@ -982,20 +984,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel { workspace_size); } - // reverse padded input - std::vector starts(X->dims().size(), 0); - std::vector axes(X->dims().size(), 0); + if (!is_sys_pad) { + // reverse padded input + std::vector starts(X->dims().size(), 0); + std::vector axes(X->dims().size(), 0); - for (size_t i = 0; i < X->dims().size(); ++i) { - starts[i] = input_pad[2 * i]; - axes[i] = i; - } - if (X->dims().size() == 4) { - Slice_2( - ctx, &transformed_dX, &transformed_dX_channel, starts, axes); - } else { - Slice_2( - ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + for (size_t i = 0; i < X->dims().size(); ++i) { + starts[i] = input_pad[2 * i]; + axes[i] = i; + } + if (X->dims().size() == 4) { + Slice_2( + ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + } else { + Slice_2( + ctx, &transformed_dX, &transformed_dX_channel, starts, axes); + } } if (channel_last) { TransToChannelLast( -- GitLab