diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu
index 14c119a40ca3adae12f014c555d90e230eaa3435..d34d152bc7d9c68cb760d6cb2a5aa483ed6126f1 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu
+++ b/paddle/fluid/operators/conv_cudnn_op.cu
@@ -609,6 +609,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     auto dX = ctx.Output<Tensor>("DInput");
     if (ddO) {
       ddO->mutable_data<T>(ctx.GetPlace());
+      math::SetConstant<platform::CUDADeviceContext, T> set_zero;
+      set_zero(dev_ctx, ddO, static_cast<T>(0));
     }
     if (dW) {
       dW->mutable_data<T>(ctx.GetPlace());
@@ -646,7 +648,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     // transform Tensors to channel first-----------
     Tensor transformed_X_channel(X->type());
     Tensor transformed_dO_channel(dO->type());
-    Tensor transformed_ddX_channel(ddX->type());
+    Tensor transformed_ddX_channel(X->type());
     Tensor transformed_ddO_channel(dO->type());
     Tensor transformed_dX_channel(X->type());
 
@@ -662,10 +664,12 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       TransToChannelFirst<platform::CUDADeviceContext, T>(
          ctx, dO, &transformed_dO_channel);
 
-      ResizeToChannelFirst<platform::CUDADeviceContext, T>(
-          ctx, ddX, &transformed_ddX_channel);
-      TransToChannelFirst<platform::CUDADeviceContext, T>(
-          ctx, ddX, &transformed_ddX_channel);
+      if (ddX) {
+        ResizeToChannelFirst<platform::CUDADeviceContext, T>(
+            ctx, ddX, &transformed_ddX_channel);
+        TransToChannelFirst<platform::CUDADeviceContext, T>(
+            ctx, ddX, &transformed_ddX_channel);
+      }
 
       if (ddO) {
         ResizeToChannelFirst<platform::CUDADeviceContext, T>(
@@ -680,7 +684,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     } else {
       transformed_X_channel = *X;
       transformed_dO_channel = *dO;
-      transformed_ddX_channel = *ddX;
+      if (ddX) {
+        transformed_ddX_channel = *ddX;
+      }
       if (ddO) {
         transformed_ddO_channel.ShareDataWith(*ddO);
       }
@@ -729,15 +735,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       transformed_X.Resize(new_input_shape);
       transformed_ddX.Resize(new_input_shape);
       transformed_dX.Resize(new_input_shape);
 
-      auto& dev_ctx =
-          ctx.template device_context<paddle::platform::CUDADeviceContext>();
       transformed_X =
           ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
              new_input_shape, dev_ctx);
-      transformed_ddX =
-          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
-              new_input_shape, dev_ctx);
+      if (ddX) {
+        transformed_ddX =
+            ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
+                new_input_shape, dev_ctx);
+      }
       if (dX) {
         transformed_dX =
             ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
@@ -751,16 +757,20 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
        case 4: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
-         math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-             ctx, input_pad, transformed_ddX_channel, pad_value,
-             &transformed_ddX);
+         if (ddX) {
+           math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+               ctx, input_pad, transformed_ddX_channel, pad_value,
+               &transformed_ddX);
+         }
        } break;
        case 5: {
          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
-         math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-             ctx, input_pad, transformed_ddX_channel, pad_value,
-             &transformed_ddX);
+         if (ddX) {
+           math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+               ctx, input_pad, transformed_ddX_channel, pad_value,
+               &transformed_ddX);
+         }
        } break;
        default:
          PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
@@ -768,7 +778,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       }
     } else {
       transformed_X.ShareDataWith(transformed_X_channel);
-      transformed_ddX.ShareDataWith(transformed_ddX_channel);
+      if (ddX) {
+        transformed_ddX.ShareDataWith(transformed_ddX_channel);
+      }
       if (dX) {
         transformed_dX.ShareDataWith(transformed_dX_channel);
       }
@@ -936,10 +948,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
            ctx, &transformed_ddO_channel, ddO);
       }
     }
-    T* transformed_dy_channel = nullptr;
+    T* transformed_dy_channel = transformed_dO_channel.data<T>();
     if (dW && ddX) {
       ddx = transformed_ddX.data<T>();
-      transformed_dy_channel = transformed_dO_channel.data<T>();
       for (int i = 0; i < groups; i++) {
         wkspace_handle.RunFunc(
             [&](void* workspace_ptr) {
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index b6820a7bef4fd27f689df365549330dc5d384ca3..c1b8d868b29cc9f9e9c698f5055dceb2f446e5c2 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -651,7 +651,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     // transform Tensor
     Tensor transformed_X(X->type());
     Tensor transformed_dY(dY->type());
-    Tensor transformed_ddX(ddX->type());
+    Tensor transformed_ddX(X->type());
 
     if (channel_last) {
       ResizeToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);
@@ -660,13 +660,16 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
       ResizeToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
       TransToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
 
-      ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
-      TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
-
+      if (ddX) {
+        ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
+        TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
+      }
     } else {
       transformed_X = *X;
       transformed_dY = *dY;
-      transformed_ddX = *ddX;
+      if (ddX) {
+        transformed_ddX = *ddX;
+      }
     }
 
     // update padding and dilation
@@ -857,12 +860,11 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
            } else if (data_dim == 3U) {
              vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
            }
+           Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
+           blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
+                       T(0.0));
          }
-         Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
-         blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
-                     T(0.0));
-
 
          if (ddW_in) {
            Tensor x_batch = transformed_X.Slice(i, i + 1).Resize(input_shape);
            Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
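
Taken together, the two files fix the same failure mode: the double-grad kernels previously dereferenced `ddX` unconditionally, but `ddX` (the forward-mode gradient flowing into the input) is an optional input and may be null when only the weight path is differentiated. Every use of `ddX` is now guarded, and `ddO` is zero-filled up front so the skipped `ddX` accumulation cannot leave stale or uninitialized memory in the output. A minimal standalone sketch of that control flow follows; `Tensor` and `ConvDoubleGrad` here are hypothetical stand-ins (an elementwise product replaces the real convolution), not Paddle's actual API:

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a framework tensor; not Paddle's real Tensor.
struct Tensor {
  std::vector<float> buf;
};

// Sketch of the flow the hunks above establish: every use of the optional
// ddX input is guarded, and ddO is zeroed first because the ddX term that
// previously initialized it may never run.
void ConvDoubleGrad(const Tensor* ddX, const Tensor& W, Tensor* ddO) {
  if (ddO) {
    ddO->buf.assign(ddO->buf.size(), 0.0f);  // mirrors the added SetConstant
  }
  if (ddX && ddO) {
    // ddO += conv(ddX, W); reduced to an elementwise product for brevity.
    for (size_t i = 0; i < ddO->buf.size(); ++i) {
      ddO->buf[i] += ddX->buf[i] * W.buf[i];
    }
  }
  // (the second term, ddO += conv(X, ddW), would accumulate here)
}

int main() {
  Tensor W{{2.0f, 2.0f}};
  Tensor ddO{{9.0f, 9.0f}};  // simulate stale data in the output buffer
  ConvDoubleGrad(/*ddX=*/nullptr, W, &ddO);
  std::printf("%g %g\n", ddO.buf[0], ddO.buf[1]);  // "0 0", not stale "9 9"
  return 0;
}
```

The same reasoning motivates the last `GemmConvDoubleGradKernel` hunk: `col_matrix` is only populated from `ddx_slice` (via im2col/vol2col), so the `w_slice` × `col_matrix` product that writes `ddy_slice` is moved inside the block that runs only when `ddX` is present, instead of multiplying unpopulated data unconditionally.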