diff --git a/paddle/operators/deconv2d_op.cc b/paddle/operators/deconv2d_op.cc
index 6b71a1fea769bb4857c965c5e83566a570ced51d..0abe2a8fba944f876d7ac143c20edffcb8884b26 100644
--- a/paddle/operators/deconv2d_op.cc
+++ b/paddle/operators/deconv2d_op.cc
@@ -31,22 +31,23 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   int groups = ctx->Attrs().Get<int>("groups");
-  int input_channels = in_dims[1];
-  int output_channels = filter_dims[0];
-
-  PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
-  PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
-  PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
-                    "The number of input channels should be equal to filter "
-                    "channels * groups.");
-  PADDLE_ENFORCE_EQ(
-      output_channels % groups, 0,
-      "The number of output channels should be divided by groups.");
+
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    PADDLE_ENFORCE_EQ(paddings[i], 0, "No padding is allowed in the deconv op.");
+  }
+
+  PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Deconv2DOp input should be 4-D.");
+  PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Deconv2DOp filter should be 4-D.");
+  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
+                    "The input channel dimension and the filter's input "
+                    "channel dimension should be equal.");
+
+  PADDLE_ENFORCE_EQ(groups, 1,
+                    "The number of groups should be 1 in the deconv op.");
 
   auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
   auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
   ctx->SetOutputDim("Output",
-                    {in_dims[0], filter_dims[0], output_height, output_width});
+                    {in_dims[0], filter_dims[1], output_height, output_width});
 }
 
 Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
@@ -55,12 +56,12 @@ Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
   AddInput(
       "Input",
       "The input tensor of deconvolution operator. "
-      "The format of input tensor is NCHW. Where N is batch size, C is the "
-      "number of channels, H and W is the height and width of image.");
+      "The format of the input tensor is NMHW, where N is batch size, M is the "
+      "number of input channels, and H and W are the height and width of the image.");
   AddInput("Filter",
            "The filter tensor of deconvolution operator."
            "The format of the filter tensor is MCHW, where M is the number of "
-           "output image channels, C is the number of input image channels, "
+           "input image channels, C is the number of output image channels, "
            "H and W is height and width of filter. "
            "We enforce groups number == 1 and padding == 0 in our "
            "deconvolution Scenario.");
@@ -97,6 +98,6 @@ REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad,
             ops::Deconv2DOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
-    deconv2d, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
+    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     deconv2d_grad, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/deconv2d_op.h b/paddle/operators/deconv2d_op.h
index 4f5a0242b114874a3fcb1f261a6b80988d709bca..fbba421ae930f7360211e45b5a5e6152cb4dc573 100644
--- a/paddle/operators/deconv2d_op.h
+++ b/paddle/operators/deconv2d_op.h
@@ -23,6 +23,7 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+using DDim = framework::DDim;
 
 // Define Op classes in .h file so that other deconv
 // operator implementations can reuse the code.
@@ -48,5 +49,167 @@ class Deconv2DOpGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
 };
 
+template <typename Place, typename T>
+class GemmDeconv2DKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    // filter will be reshaped, so we do not use a constant pointer here
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    Tensor* output = context.Output<Tensor>("Output");
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+
+    // no paddings or groups allowed in deconv
+
+    int N = input->dims()[0];
+    int M = input->dims()[1];
+    int H = input->dims()[2];
+    int W = input->dims()[3];
+
+    int K_H = filter.dims()[2];
+    int K_W = filter.dims()[3];
+
+    int C = output->dims()[1];  // output channels
+    int O_H = output->dims()[2];
+    int O_W = output->dims()[3];
+
+    paddle::operators::math::Col2ImFunctor<
+        paddle::operators::math::ColFormat::kCFO, Place, T>
+        col2im;
+
+    // use col_shape in the im2col and col2im calculation
+    framework::DDim col_shape = {C, K_H, K_W, H, W};
+
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {C * K_H * K_W, H * W};
+
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
+
+    DDim output_shape = {C, O_H, O_W};
+    DDim input_matrix_shape = {M, H * W};
+
+    DDim filter_matrix_shape = {M, C * K_H * K_W};
+    filter.Resize(filter_matrix_shape);
+
+    // deconvolution: gemm + col2im (similar to conv-backward on input)
+
+    output->mutable_data<T>(context.GetPlace());
+    auto t = framework::EigenVector<T>::Flatten(*output);
+    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+    for (int i = 0; i < N; i++) {
+      // batch with size (M, H * W)
+      Tensor input_batch =
+          input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+      // output size: (C, O_H, O_W)
+      Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);
+
+      // filter size: (M, C * K_H * K_W)
+
+      // col_matrix = filter^T * input_batch
+      // of shape (C * K_H * K_W, H * W)
+      math::matmul<Place, T>(context.device_context(), filter, true,
+                             input_batch, false, T(1.0), &col_matrix, T(0.0));
+
+      col2im(context.device_context(), output_batch, col_matrix, strides[0],
+             strides[1], 0, 0);
+    }
+  }
+};
+
+/*
+template <typename Place, typename T>
+class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+
+    // For filter, we do not use a const pointer because it will be reshaped,
+    // but we should avoid modifying its data.
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+
+    // no paddings or groups allowed in deconv
+
+    int N = input->dims()[0];
+    int M = input->dims()[1];
+    int H = input->dims()[2];
+    int W = input->dims()[3];
+
+    int K_H = filter.dims()[2];
+    int K_W = filter.dims()[3];
+
+    int C = output_grad->dims()[1];  // output channels
+    int O_H = output_grad->dims()[2];
+    int O_W = output_grad->dims()[3];
+
+    paddle::operators::math::Col2ImFunctor<
+        paddle::operators::math::ColFormat::kCFO, Place, T>
+        col2im;
+
+    // use col_shape in the im2col and col2im calculation
+    framework::DDim col_shape = {C, K_H, K_W, H, W};
+
+    // use col_matrix_shape in the gemm calculation
+    framework::DDim col_matrix_shape = {C * K_H * K_W, H * W};
+
+    Tensor col;
+    col.mutable_data<T>(col_shape, context.GetPlace());
+    // col_matrix shares the same piece of data with col,
+    // but will be reshaped into a two-dimensional matrix shape
+    // to call the matrix multiplication interface.
+    Tensor col_matrix = col;
+    col_matrix.Resize(col_matrix_shape);
+
+    DDim output_shape = {C, O_H, O_W};
+    DDim input_matrix_shape = {M, H * W};
+
+    DDim filter_matrix_shape = {M, C * K_H * K_W};
+    filter.Resize(filter_matrix_shape);
+
+    // deconvolution: gemm + col2im (similar to conv-backward on input)
+
+    output->mutable_data<T>(context.GetPlace());
+    auto t = framework::EigenVector<T>::Flatten(*output);
+    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+
+    for (int i = 0; i < N; i++) {
+      // batch with size (M, H * W)
+      Tensor input_batch =
+          input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+      // output size: (C, O_H, O_W)
+      Tensor output_batch =
+          output->Slice<T>(i, i + 1).Resize(output_shape);
+
+      // filter size: (M, C * K_H * K_W)
+
+      // col_matrix = filter^T * input_batch
+      // of shape (C * K_H * K_W, H * W)
+      math::matmul<Place, T>(context.device_context(), filter, true,
+                             input_batch, false, T(1.0), &col_matrix,
+                             T(0.0));
+
+      col2im(context.device_context(), output_batch, col_matrix, strides[0],
+             strides[1], 0, 0);
+    }
+  }
+};
+*/
+
 }  // namespace operators
 }  // namespace paddle
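For context, the forward pass registered above reduces to a GEMM followed by col2im, with the output shape given by O_H = (H - 1) * stride_h + K_H and O_W = (W - 1) * stride_w + K_W (no padding, groups == 1). The sketch below is a minimal, standalone C++ illustration of that computation for a single batch item; it does not use Paddle's APIs, and the function name DeconvForward and its layout conventions are illustrative assumptions only, not part of the patch.

// Naive deconvolution (transposed convolution) for one batch item.
// input:  (M, H, W) flattened row-major; filter: (M, C, K_H, K_W) flattened.
// Step 1 (gemm):   col(C*K_H*K_W, H*W) = filter^T(C*K_H*K_W, M) * input(M, H*W)
// Step 2 (col2im): output[c][h*s_h + kh][w*s_w + kw] += col[c][kh][kw][h][w]
#include <cassert>
#include <vector>

std::vector<float> DeconvForward(const std::vector<float>& input,   // M*H*W
                                 const std::vector<float>& filter,  // M*C*K_H*K_W
                                 int M, int H, int W, int C, int K_H, int K_W,
                                 int s_h, int s_w) {
  const int O_H = (H - 1) * s_h + K_H;
  const int O_W = (W - 1) * s_w + K_W;
  assert(input.size() == static_cast<size_t>(M * H * W));
  assert(filter.size() == static_cast<size_t>(M * C * K_H * K_W));

  // gemm: col has shape (C * K_H * K_W, H * W)
  std::vector<float> col(static_cast<size_t>(C * K_H * K_W) * H * W, 0.f);
  for (int r = 0; r < C * K_H * K_W; ++r)
    for (int p = 0; p < H * W; ++p)
      for (int m = 0; m < M; ++m)
        col[r * H * W + p] +=
            filter[m * C * K_H * K_W + r] * input[m * H * W + p];

  // col2im: scatter-add the columns back into the (C, O_H, O_W) output
  std::vector<float> out(static_cast<size_t>(C) * O_H * O_W, 0.f);
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < K_H; ++kh)
      for (int kw = 0; kw < K_W; ++kw)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w) {
            int oh = h * s_h + kh, ow = w * s_w + kw;
            out[(c * O_H + oh) * O_W + ow] +=
                col[(((c * K_H + kh) * K_W + kw) * H + h) * W + w];
          }
  return out;
}

For example, with a 1x1 input, stride 2, and a 3x3 filter, this produces a 3x3 output (each filter element scaled by the single input value), matching the (H - 1) * stride + K_H rule used in Deconv2DOp::InferShape.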