diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 749258183ba058cf0ed8d91c4406813694314b85..d2de4e80f751d4938ac9cad60871b470fccf225c 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -261,8 +261,12 @@ class GemmConvGradKernel : public framework::OpKernel { if (input_grad) { input_grad->mutable_data(context.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); + // if is_expand is false, the operation of set_zero is unnecessary, + // because math::matmul will reset input_grad. + if (is_expand) { + set_zero(dev_ctx, input_grad, static_cast(0)); + } math::Col2VolFunctor col2vol; math::Col2ImFunctor col2im; diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 80600b53614994ba0c740aed0d75c9944333fecc..1171b0435fd2b1abe541043e8283a8fc09dc13c7 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -225,7 +225,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (input_grad) { input_grad->mutable_data(context.GetPlace()); - set_zero(dev_ctx, input_grad, static_cast(0)); } if (filter_grad) { // filter size (m, c, k_h, k_w) filter_grad->mutable_data(context.GetPlace());