conv2d_transpose_op.cc 4.5 KB
Newer Older
Z
deconv  
zchen0211 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

Z
zchen0211 已提交
15
#include "paddle/operators/conv2d_transpose_op.h"
Z
deconv  
zchen0211 已提交
16 17 18 19

namespace paddle {
namespace operators {

Z
deconv  
zchen0211 已提交
20
// Infers the shape of Output for the 2-D transposed convolution.
//
// Input must be a 4-D tensor whose dim 1 is the channel count (described as
// NCHW by the op maker). Filter must be 4-D with filter_dims[0] equal to the
// input channel count; filter_dims[1] becomes the output channel count and
// filter_dims[2] / filter_dims[3] are the kernel height / width.
// Only zero padding is supported, so each spatial output extent is
//   (input_size - 1) * stride + filter_size.
void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of Conv2DTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of Conv2DTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Output"),
                 "Output(Output) of Conv2DTransposeOp should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

  for (int padding : paddings) {
    PADDLE_ENFORCE_EQ(padding, 0, "No Padding allowed in conv transpose op.");
  }

  // strides[0] and strides[1] are indexed below; validate the attribute first
  // so a malformed value raises an error instead of reading out of bounds.
  PADDLE_ENFORCE_EQ(strides.size(), 2UL,
                    "Conv2DTransposeOp strides should contain 2 elements.");
  for (int stride : strides) {
    PADDLE_ENFORCE_GT(stride, 0,
                      "Conv2DTransposeOp strides should be positive.");
  }

  PADDLE_ENFORCE_EQ(in_dims.size(), 4,
                    "Conv2DTransposeOp input should be 4-D tensor.");
  PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
                    "Conv2DTransposeOp filter should be 4-D tensor.");
  // The filter's first dimension must match the input channel count.
  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                    "input and kernel input dimension should be equal.");

  auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
  auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
  ctx->SetOutputDim("Output",
                    {in_dims[0], filter_dims[1], output_height, output_width});
}

Z
deconv  
zchen0211 已提交
51 52
// Declares the inputs, outputs, attributes and documentation of the
// conv2d_transpose operator.
//
// The Filter layout described here matches Conv2DTransposeOp::InferShape,
// which requires filter_dims[0] == input channels and uses filter_dims[1] as
// the output channel count.
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution transpose operator. "
      "The format of input tensor is NCHW. Where N is batch size, C is the "
      "number of input channels, H and W are the height and width of image.");
  // NOTE: adjacent string literals are concatenated, so every fragment that
  // ends a sentence carries a trailing space.
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution transpose operator. "
           "The format of the filter tensor is MCHW, where M is the number of "
           "input image channels, C is the number of output image channels, "
           "H and W are height and width of filter. "
           "We enforce groups number == 1 and padding == 0 in "
           "convolution transpose Scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator. "
            "The format of output tensor is also NCHW.");
  AddAttr<std::vector<int>>("strides",
                            "strides of convolution transpose operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "paddings of convolution transpose operator.")
      .SetDefault({0, 0});
  AddComment(R"DOC(
The convolution transpose operation calculates the output based on the input,
filter, and the strides and paddings attributes. The size of each dimension of
the parameters is checked in the infer-shape.
)DOC");
}
Z
deconv  
zchen0211 已提交
81

Z
deconv  
zchen0211 已提交
82 83
// Infers the shapes of the gradient outputs of conv2d_transpose.
//
// Each requested gradient (Input@GRAD, Filter@GRAD) has the same shape as the
// corresponding forward input. Gradients that are not requested are skipped.
void Conv2DTransposeOpGrad::InferShape(
    framework::InferShapeContext* ctx) const {
  // Validate presence up front, mirroring the forward op, so a missing input
  // produces a clear error rather than a failure inside GetInputDim.
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of Conv2DTransposeOpGrad should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of Conv2DTransposeOpGrad should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  if (ctx->HasOutput(framework::GradVarName("Input"))) {
    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
  }
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
  }
}
Z
deconv  
zchen0211 已提交
93 94 95 96 97

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Z
zchen0211 已提交
98 99
REGISTER_OP(conv2d_transpose, ops::Conv2DTransposeOp,
            ops::Conv2DTransposeOpMaker, conv2d_transpose_grad,
Z
deconv  
zchen0211 已提交
100
            ops::Conv2DTransposeOpGrad);
Z
deconv  
zchen0211 已提交
101 102

REGISTER_OP_CPU_KERNEL(
Z
zchen0211 已提交
103
    conv2d_transpose,
Z
deconv  
zchen0211 已提交
104
    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
Z
deconv  
zchen0211 已提交
105
REGISTER_OP_CPU_KERNEL(
Z
zchen0211 已提交
106
    conv2d_transpose_grad,
Z
deconv  
zchen0211 已提交
107
    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);