/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/operators/conv2d_transpose_op.h"

namespace paddle {
namespace operators {

// Infers the shape of Output for conv2d_transpose.
//
// Input must be 4-D and Filter must be 4-D with filter_dims[0] equal to the
// input's channel dimension (in_dims[1]). Output has shape
//   {in_dims[0], filter_dims[1], H_out, W_out}
// where, since only zero padding is supported,
//   H_out = (in_dims[2] - 1) * strides[0] + filter_dims[2]
//   W_out = (in_dims[3] - 1) * strides[1] + filter_dims[3]
void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of Conv2DTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of Conv2DTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Output"),
                 "Output(Output) of Conv2DTransposeOp should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  const auto& strides = ctx->Attrs().Get<std::vector<int>>("strides");
  const auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

  // strides[0] and strides[1] are indexed below; a shorter attribute would
  // otherwise be an out-of-bounds read.
  PADDLE_ENFORCE_EQ(strides.size(), 2UL,
                    "Conv2DTransposeOp strides should have 2 elements.");
  for (int padding : paddings) {
    PADDLE_ENFORCE_EQ(padding, 0, "No Padding allowed in conv transpose op.");
  }

  PADDLE_ENFORCE_EQ(in_dims.size(), 4,
                    "Conv2DTransposeOp input should be 4-D tensor.");
  PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
                    "Conv2DTransposeOp filter should be 4-D tensor.");
  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                    "input and kernel input dimension should be equal.");

  auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
  auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
  ctx->SetOutputDim("Output",
                    {in_dims[0], filter_dims[1], output_height, output_width});
}

// Declares the inputs (Input, Filter), output (Output), attributes
// (strides, paddings) and user-facing documentation of conv2d_transpose.
// The Filter layout described here matches the checks in
// Conv2DTransposeOp::InferShape: filter_dims[0] must equal the input
// channel count, and filter_dims[1] becomes the output channel count.
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution transpose operator. "
      "The format of input tensor is NCHW, where N is batch size, C is the "
      "number of input channels, H is the height of the image, and "
      "W is the width of the image.");
  // Adjacent literals are concatenated by the compiler, so each sentence
  // needs an explicit trailing space before the next literal.
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution transpose operator. "
           "The format of the filter tensor is MCHW, where M is the number of "
           "input image channels, C is the number of output image channels, "
           "H is the height of the filter, and W is the width of the filter. "
           "We enforce groups number == 1 and padding == 0 in "
           "the convolution transpose scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator. "
            "The format of output tensor is also NCHW.");
  AddAttr<std::vector<int>>("strides",
                            "strides of convolution transpose operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "paddings of convolution transpose operator.")
      .SetDefault({0, 0});
  AddComment(R"DOC(
Convolution Transpose Operator.

The convolution transpose operation calculates the output based on the input,
filter, strides and paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape method.

Only zero padding is supported. For an input of shape (N, C_in, H_in, W_in)
and a filter of shape (C_in, C_out, K_h, K_w), the output has shape
(N, C_out, H_out, W_out), where
  H_out = (H_in - 1) * strides[0] + K_h
  W_out = (W_in - 1) * strides[1] + K_w

)DOC");
}
// Infers the shapes of the gradient outputs of conv2d_transpose.
//
// The gradient w.r.t. Input gets Input's dims and the gradient w.r.t. Filter
// gets Filter's dims. A gradient output is only shaped when it is present.
// Input and Filter are read unconditionally, so their presence is checked
// up front, matching the forward op's checks.
void Conv2DTransposeOpGrad::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of Conv2DTransposeOpGrad should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of Conv2DTransposeOpGrad should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  if (ctx->HasOutput(framework::GradVarName("Input"))) {
    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
  }
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
  }
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Z
zchen0211 已提交
102 103
REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp,
            ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
Z
deconv  
zchen0211 已提交
104
            ops::Conv2DTransposeOpGrad);
Z
deconv  
zchen0211 已提交
105 106

REGISTER_OP_CPU_KERNEL(
Z
zchen0211 已提交
107
    conv2d_transpose,
Z
deconv  
zchen0211 已提交
108
    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
Z
deconv  
zchen0211 已提交
109
REGISTER_OP_CPU_KERNEL(
Z
zchen0211 已提交
110
    conv2d_transpose_grad,
Z
deconv  
zchen0211 已提交
111
    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);