diff --git a/paddle/operators/deconv2d_op.cc b/paddle/operators/deconv2d_op.cc index 6b20fe4589378fa77a6b378f3f2542d487dbb05a..331fbd59825ec2c65f5cacb6323af53377300a8f 100644 --- a/paddle/operators/deconv2d_op.cc +++ b/paddle/operators/deconv2d_op.cc @@ -30,7 +30,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { auto filter_dims = ctx->GetInputDim("Filter"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); - int groups = ctx->Attrs().Get("groups"); for (int i = 0; i < paddings.size(); ++i) { PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op."); @@ -41,9 +40,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], "input and kernel input dimension should be equal."); - PADDLE_ENFORCE_EQ(groups, 1, - "The number of groups should be 1 in case of deconv op."); - auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; ctx->SetOutputDim("Output", diff --git a/paddle/operators/deconv2d_op.h b/paddle/operators/deconv2d_op.h index 0c6b6cc09463667fc00fa9b144668551099d0309..9036801a6589e6860763bbece00908c3a3294bb3 100644 --- a/paddle/operators/deconv2d_op.h +++ b/paddle/operators/deconv2d_op.h @@ -83,7 +83,7 @@ class GemmDeconv2DKernel : public framework::OpKernel { DDim col_shape = {C, K_H, K_W, H, W}; // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape = {M * K_H * K_W, H * W}; + DDim col_matrix_shape = {C * K_H * K_W, H * W}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -108,11 +108,11 @@ class GemmDeconv2DKernel : public framework::OpKernel { for (int i = 0; i < N; i++) { // batch with size (M, H * W) Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + // filter size: (M, C * K_H * K_W) + // output size: (C, O_H, O_W) Tensor output_batch = output->Slice(i, i + 
1).Resize(output_shape); - // filter size: (Co, Ci * Hf * Wf) - // col_matrix = filter * input_batch // of shape (C * K_H * K_W, H * W) math::matmul(context.device_context(), filter, true, @@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { const Tensor* output_grad = context.Input(framework::GradVarName("Output")); - // For filter, we do not use const pointer - // but we should avoid + // For filter, we do not use const pointer b/c we will do reshape + // but we should avoid modifying its value Tensor filter = *context.Input("Filter"); Tensor* input_grad = @@ -157,7 +157,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { int O_H = output_grad->dims()[2]; int O_W = output_grad->dims()[3]; - // Two functors required to get to the right shape + // Only im2col functor required for bp to get to the right shape paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kCFO, Place, T> im2col; @@ -166,15 +166,13 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { DDim col_shape = {C, K_H, K_W, H, W}; // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape = {C * K_H * K_W, H * W}; + DDim col_matrix_shape_f = {C * H * W, K_H * K_W}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. 
- Tensor col_matrix = col; - col_matrix.Resize(col_matrix_shape); DDim output_shape = {C, O_H, O_W}; DDim input_matrix_shape = {M, H * W}; @@ -186,6 +184,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // im2col + gemm (similar to conv-forward) // input need to compute gradient if (input_grad) { + Tensor col_matrix = col; + DDim col_matrix_shape = {C * K_H * K_W, H * W}; + col_matrix.Resize(col_matrix_shape); + input_grad->mutable_data(context.GetPlace()); auto t = framework::EigenVector::Flatten(*input_grad); t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); @@ -194,14 +196,18 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // batch with size (C, O_H * O_W) Tensor output_grad_batch = output_grad->Slice(i, i + 1).Resize(output_shape); + // filter of size (M, C * K_H * K_W) + // batch with size (M, H * W) Tensor input_grad_batch = input_grad->Slice(i, i + 1).Resize(input_matrix_shape); - // im2col: (C * K_H * K_W, H * W) + // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W) im2col(context.device_context(), output_grad_batch, col_matrix, strides[0], strides[1], paddings[0], paddings[1]); + // gemm: dx = filter * dy + // (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, H * W) math::matmul(context.device_context(), filter, false, col_matrix, false, T(1.0), &input_grad_batch, T(0.0)); @@ -210,6 +216,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // filter gradient required if (filter_grad) { + Tensor col_matrix_f = col; + DDim col_matrix_shape_f = {C * H * W, K_H * K_W}; + col_matrix_f.Resize(col_matrix_shape_f); + filter_grad->mutable_data(context.GetPlace()); Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); @@ -223,10 +233,12 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // input batch Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // im2col: (C * K_H * K_W, H * W) - im2col(context.device_context(), output_grad_batch, 
col_matrix, + // im2col: (C * H * W, K_H * K_W) + im2col(context.device_context(), output_grad_batch, col_matrix_f, strides[0], strides[1], paddings[0], paddings[1]); + // gemm: d_filter = x * y_grad^T + // (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H) math::matmul(context.device_context(), in_batch, false, col_matrix, true, T(1.0), &filter_grad_, T(1.0)); }