conv_transpose_op.cc 10.1 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

C
chengduoZH 已提交
15
#include "paddle/operators/conv_transpose_op.h"
C
chengduoZH 已提交
16 17 18 19

namespace paddle {
namespace operators {

C
chengduoZH 已提交
20
// Infers the shape of Output for both conv2d_transpose and conv3d_transpose.
//
// The checks below require:
//   * Input to be a 4-D or 5-D tensor, and Filter to have the same rank;
//   * strides and paddings to have one entry per spatial axis (rank - 2);
//   * Input's channel dimension (in_dims[1]) to equal filter_dims[0].
// Output keeps Input's batch dimension (in_dims[0]) and takes its channel
// count from filter_dims[1]. For each spatial axis i its extent is
//   (in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + filter_dims[i + 2].
// NOTE(review): the "dilations" attribute is not read here, so it has no
// effect on the inferred output shape.
void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of ConvTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of ConvTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Output"),
                 "Output(Output) of ConvTransposeOp should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

  PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
                 "ConvTransposeOp input should be 4-D or 5-D tensor.");
  PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
                    "ConvTransposeOp input dimension and filter dimension "
                    "should be the same.");
  PADDLE_ENFORCE(in_dims.size() - strides.size() == 2U,
                 "ConvTransposeOp input dimension and strides dimension should "
                 "be consistent.");
  PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
                    "ConvTransposeOp paddings dimension and strides "
                    "dimension should be the same.");
  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                    "In ConvTransposeOp, the input channel should be the same "
                    "as the number of filters.");

  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
  for (size_t i = 0; i < strides.size(); ++i) {
    const int64_t output_size = (in_dims[i + 2] - 1) * strides[i] -
                                2 * paddings[i] + filter_dims[i + 2];
    // A non-positive extent means strides/paddings do not fit the input and
    // filter sizes. Fail here rather than publish an invalid output shape.
    PADDLE_ENFORCE(output_size > 0,
                   "ConvTransposeOp output size of spatial dimension %d "
                   "should be positive, but got %d. Please check the strides "
                   "and paddings.",
                   i, output_size);
    output_shape.push_back(output_size);
  }
  ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}

C
chengduoZH 已提交
56 57 58 59 60 61 62
// Declares the inputs, output, attributes and user-facing documentation of
// the conv2d_transpose operator. Attribute sizes are validated in
// ConvTransposeOp::InferShape.
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution transpose operator. "
      "The format of input tensor is NCHW. Where N is batch size, C is the "
      "number of input channels, H is the height of the feature, and "
      "W is the width of the feature.");
  AddInput(
      "Filter",
      "(Tensor) The filter tensor of convolution transpose operator. "
      "The format of the filter tensor is MCHW, where M is the number of "
      "input feature channels, C is the number of "
      "output feature channels, "
      "H is the height of the filter, and W is the width of the filter. "
      "We enforce groups number == 1 in the convolution transpose scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator. "
            "The format of output tensor is also NCHW.");

  AddAttr<std::vector<int>>("dilations",
                            "(vector<int> default:{1, 1}), the "
                            "dilations(h_dilation, w_dilation) of convolution "
                            "transpose operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>(
      "strides",
      "(vector<int> default:{1, 1}), the strides(h_stride, w_stride) of "
      "convolution transpose operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>(
      "paddings",
      "(vector<int> default:{0, 0}), the paddings(h_pad, w_pad) of convolution "
      "transpose operator.")
      .SetDefault({0, 0});
  AddComment(R"DOC(
Convolution2D Transpose Operator.

The convolution transpose operation calculates the output based on the input,
the filter, and the dilations, strides and paddings attributes. The sizes of
these attributes are checked during shape inference.
Input(Input) and Output(Output) are in NCHW format, where N is the batch size,
C is the number of channels, H is the height of the feature, and W is the
width of the feature.
Input(Filter) is in MCHW format, where M is the number of input feature
channels, C is the number of output feature channels, H is the height of the
filter, and W is the width of the filter.
The strides and paddings attributes each have two elements, representing
height and width, respectively.
The sizes of Input(Input) and Output(Output) may differ.

Example:
  Input:
       Input shape: $(N, C_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{in}, C_{out}, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, H_{out}, W_{out})$
  Where
  $$
       H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + H_f \\
       W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + W_f
  $$
)DOC");
}

C
chengduoZH 已提交
121 122 123
// Declares the inputs, output, attributes and user-facing documentation of
// the conv3d_transpose operator. Attribute sizes are validated in
// ConvTransposeOp::InferShape.
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  // Adjacent string literals are concatenated by the compiler, so each piece
  // that ends a sentence carries its own trailing space.
  AddInput("Input",
           "(Tensor) The input tensor of convolution transpose operator. "
           "The format of input tensor is NCDHW. Where N is batch size, C is "
           "the number of channels, D is the depth of the feature, H is the "
           "height of the feature, and "
           "W is the width of the feature.");
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution transpose operator. "
           "The format of the filter tensor is MCDHW, where M is the number of "
           "input feature channels, C is the number of "
           "output feature channels, D "
           "is the depth of the filter, H is the height of the filter, and "
           "W is the width of the filter. "
           "We enforce groups number == 1 and padding == 0 in "
           "the convolution3d transpose scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator. "
            "The format of output tensor is also NCDHW. "
            "Where N is batch size, C is "
            "the number of channels, D is the depth of the feature, H is the "
            "height of the feature, and W is the width of the feature.");

  AddAttr<std::vector<int>>(
      "dilations",
      "(vector<int> default:{1, 1, 1}), the "
      "dilations(d_dilation, h_dilation, w_dilation) of convolution "
      "transpose operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::vector<int>>("strides",
                            "(vector<int> default:{1, 1, 1}), the "
                            "strides(d_stride, h_stride, w_stride) of "
                            "convolution transpose operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "(vector<int> default:{0, 0, 0}), paddings(d_pad, "
                            "h_pad, w_pad) of convolution transpose operator.")
      .SetDefault({0, 0, 0});
  AddComment(R"DOC(
Convolution3D Transpose Operator.

The convolution transpose operation calculates the output based on the input,
the filter, and the dilations, strides and paddings attributes. The sizes of
these attributes are checked during shape inference.
Input(Input) and Output(Output) are in NCDHW format, where N is the batch size,
C is the number of channels, D is the depth of the feature, H is the height of
the feature, and W is the width of the feature.
Input(Filter) is in MCDHW format, where M is the number of input feature
channels, C is the number of output feature channels, D is the depth of the
filter, H is the height of the filter, and W is the width of the filter.
The strides and paddings attributes each have three elements, representing
depth, height and width, respectively.
The sizes of Input(Input) and Output(Output) may differ.

Example:
  Input:
       Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{in}, C_{out}, D_f, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
  Where
  $$
       D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + D_f \\
       H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + H_f \\
       W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + W_f
  $$
)DOC");
}

C
chengduoZH 已提交
192
// Shape inference for the conv transpose gradient op: every requested
// gradient output gets exactly the shape of its forward input.
void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
  // Both forward shapes are fetched up front, before looking at which
  // gradients are actually requested.
  const auto input_shape = ctx->GetInputDim("Input");
  const auto filter_shape = ctx->GetInputDim("Filter");

  const auto input_grad_name = framework::GradVarName("Input");
  if (ctx->HasOutput(input_grad_name)) {
    ctx->SetOutputDim(input_grad_name, input_shape);
  }

  const auto filter_grad_name = framework::GradVarName("Filter");
  if (ctx->HasOutput(filter_grad_name)) {
    ctx->SetOutputDim(filter_grad_name, filter_shape);
  }
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
C
chengduoZH 已提交
207

C
chengduoZH 已提交
208 209
REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
            conv2d_transpose_grad, ops::ConvTransposeOpGrad);
C
chengduoZH 已提交
210 211

REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
212
    conv2d_transpose,
C
chengduoZH 已提交
213 214
    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
C
chengduoZH 已提交
215
REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
216
    conv2d_transpose_grad,
C
chengduoZH 已提交
217 218
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
C
chengduoZH 已提交
219

C
chengduoZH 已提交
220 221
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
            conv3d_transpose_grad, ops::ConvTransposeOpGrad);
C
chengduoZH 已提交
222 223

REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
224
    conv3d_transpose,
C
chengduoZH 已提交
225 226
    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
C
chengduoZH 已提交
227
REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
228
    conv3d_transpose_grad,
C
chengduoZH 已提交
229 230
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);