From c92348c3b927912ba06ba291a0790af1a542dabf Mon Sep 17 00:00:00 2001
From: lvmengsi
Date: Sat, 28 Sep 2019 16:04:50 +0800
Subject: [PATCH] fix conv_grad_grad (#20054)

---
 paddle/fluid/operators/conv_cudnn_op.cu.cc | 49 +++++++++++----------
 paddle/fluid/operators/conv_op.cc          | 17 +++++---
 paddle/fluid/operators/conv_op.h           | 50 +++++++++++-----------
 3 files changed, 63 insertions(+), 53 deletions(-)

diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index 6629a203f80..f82d9f6d2b9 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -355,15 +355,17 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     size_t workspace_size = 0;
     if (ddO) {
       ddy = ddO->mutable_data<T>(ctx.GetPlace());
-      args1.handle = handle;
-      args1.idesc.set(*ddX, iwo_group);
-      args1.wdesc.set(*W, layout, iwo_group);
-      args1.odesc.set(*ddO, iwo_group);
-      args1.cdesc.set(dtype, paddings, strides, dilations, c_group);
-
-      using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-      fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
-      workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
+      if (ddX) {
+        args1.handle = handle;
+        args1.idesc.set(*ddX, iwo_group);
+        args1.wdesc.set(*W, layout, iwo_group);
+        args1.odesc.set(*ddO, iwo_group);
+        args1.cdesc.set(dtype, paddings, strides, dilations, c_group);
+
+        using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
+        fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
+        workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
+      }
 
       if (ddW) {
         ddw = ddW->data<T>();
@@ -380,7 +382,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       }
     }
 
-    if (dW) {
+    if (dW && ddX) {
       dw = dW->mutable_data<T>(ctx.GetPlace());
       args3.handle = handle;
       args3.idesc.set(*ddX, iwo_group);
@@ -423,17 +425,20 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     auto wkspace_handle = dev_ctx.cudnn_workspace_handle();
     if (ddO) {
-      ddx = ddX->data<T>();
-      for (int i = 0; i < groups; i++) {
-        wkspace_handle.RunFunc(
-            [&](void* workspace_ptr) {
-              CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
-                  handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in,
-                  args1.wdesc.desc(), w + i * group_offset_filter,
-                  args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size,
-                  &beta, args1.odesc.desc(), ddy + i * group_offset_out));
-            },
-            workspace_size);
+      if (ddX) {
+        ddx = ddX->data<T>();
+        for (int i = 0; i < groups; i++) {
+          wkspace_handle.RunFunc(
+              [&](void* workspace_ptr) {
+                CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
+                    handle, &alpha, args1.idesc.desc(),
+                    ddx + i * group_offset_in, args1.wdesc.desc(),
+                    w + i * group_offset_filter, args1.cdesc.desc(), fwd_algo1,
+                    workspace_ptr, workspace_size, &beta, args1.odesc.desc(),
+                    ddy + i * group_offset_out));
+              },
+              workspace_size);
+        }
       }
       if (ddW) {
         for (int i = 0; i < groups; i++) {
@@ -451,7 +456,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       }
     }
 
-    if (dW) {
+    if (dW && ddX) {
       ddx = ddX->data<T>();
       for (int i = 0; i < groups; i++) {
         wkspace_handle.RunFunc(
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 1cfdf7da86a..5528f758732 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -553,9 +553,10 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
     auto ddw = OutputGrad(framework::GradVarName("Filter"));
     std::vector<std::string> empty_str = {};
 
-    op->SetOutput(
-        "DDOutput",
-        ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DDOutput",
+                  (ddx.empty() && ddw.empty())
+                      ? empty_str
+                      : InputGrad(framework::GradVarName("Output")));
     op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
     op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
 
@@ -587,9 +588,10 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
     auto ddw = OutputGrad(framework::GradVarName("Filter"));
     std::vector<std::string> empty_str = {};
 
-    op->SetOutput(
-        "DDOutput",
-        ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DDOutput",
+                  (ddx.empty() && ddw.empty())
+                      ? empty_str
+                      : InputGrad(framework::GradVarName("Output")));
     op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
     op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
 
@@ -604,7 +606,8 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
   auto w_dims = ctx->GetInputDim("Filter");
   auto do_dims = ctx->GetInputDim("DOutput");
 
-  if (ctx->HasOutput("DDOutput") && ctx->HasInput("DDInput")) {
+  if (ctx->HasOutput("DDOutput") &&
+      (ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
     ctx->SetOutputDim("DDOutput", do_dims);
   }
   if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index aa621529b52..a6882897ad7 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -506,7 +506,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     // dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
    // oH, oW)
     // dw convolution double grad: im2col(vol2col) + gemm
-    if (dW) {
+    if (dW && ddX) {
       dW->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, dW, static_cast<T>(0));
       Tensor dW_arr = *dW;
@@ -549,36 +549,38 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
     math::Vol2ColFunctor<DeviceContext, T> vol2col;
     for (int i = 0; i < batch_size; ++i) {
-      Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
-      Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
       Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
       for (int g = 0; g < groups; ++g) {
-        Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
-        Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
-        if (!is_expand) {
-          col.ShareDataWith(ddx_slice);
-          col_matrix.ShareDataWith(col);
-          col_matrix.Resize(col_matrix_shape);
-        } else if (data_dim == 2U) {
-          // im2col
-          im2col(dev_ctx, ddx_slice, dilations, strides,
-                 std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                  paddings[1]},
-                 &col);
-        } else if (data_dim == 3U) {
-          // vol2col
-          vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
-        }
-
-        // gemm
         Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
-        Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
-        blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
-                    T(0.0));
+        if (ddX) {
+          Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
+          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
+          if (!is_expand) {
+            col.ShareDataWith(ddx_slice);
+            col_matrix.ShareDataWith(col);
+            col_matrix.Resize(col_matrix_shape);
+          } else if (data_dim == 2U) {
+            // im2col
+            im2col(dev_ctx, ddx_slice, dilations, strides,
+                   std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                    paddings[1]},
+                   &col);
+          } else if (data_dim == 3U) {
+            // vol2col
+            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
+          }
+
+          // gemm
+          Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
+          blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
+                      T(0.0));
+        }
 
         if (ddW_in) {
           Tensor ddW;
           ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
+          Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
+          Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
           if (!is_expand) {
             col.ShareDataWith(x_slice);
-- 
GitLab