conv_transpose_op.cc 9.7 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

C
chengduoZH 已提交
15
#include "paddle/operators/conv_transpose_op.h"
C
chengduoZH 已提交
16 17 18 19

namespace paddle {
namespace operators {

C
chengduoZH 已提交
20
// Infers the shape of Output for conv2d_transpose / conv3d_transpose.
//
// Input is laid out as (N, C_in, spatial...) and Filter as
// (C_in, C_out, kernel...). For every spatial axis i the output extent is
//   (in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + filter_dims[i + 2]
// so Output is (N, C_out, out_spatial...).
void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of ConvTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Filter"),
                 "Input(Filter) of ConvTransposeOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Output"),
                 "Output(Output) of ConvTransposeOp should not be null.");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

  PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
                 "ConvTransposeOp input should be 4-D or 5-D tensor.");
  PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
                    "ConvTransposeOp input dimension and filter dimension "
                    "should be the same.");
  // One stride per spatial axis: rank(Input) == strides.size() + 2 (N and C).
  // Stated as an equality rather than a subtraction so the check does not
  // depend on unsigned wrap-around, and a failure reports both values.
  PADDLE_ENFORCE_EQ(static_cast<size_t>(in_dims.size()), strides.size() + 2U,
                    "ConvTransposeOp input dimension and strides dimension "
                    "should be consistent.");
  PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
                    "ConvTransposeOp paddings dimension and strides "
                    "dimension should be the same.");
  PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
                    "In ConvTransposeOp, The input channel should be the same "
                    "as the number of filters.");

  // Output layout: batch size from Input, channel count from Filter[1], then
  // one extent per spatial axis.
  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
  for (size_t i = 0; i < strides.size(); ++i) {
    // A zero or negative stride makes the output extent meaningless.
    PADDLE_ENFORCE(strides[i] > 0,
                   "ConvTransposeOp strides should be positive.");
    output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] +
                           filter_dims[i + 2]);
  }
  ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}

C
chengduoZH 已提交
56 57 58 59 60 61 62
// Declares the proto of the conv2d_transpose operator: its inputs (Input,
// Filter), output (Output), attributes (strides, paddings) and user-facing
// documentation. Shape consistency is checked in ConvTransposeOp::InferShape.
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution transpose operator. "
      "The format of input tensor is NCHW. Where N is batch size, C is the "
      "number of input channels, H is the height of the feature, and "
      "W is the width of the feature.");
  AddInput(
      "Filter",
      "(Tensor) The filter tensor of convolution transpose operator. "
      "The format of the filter tensor is MCHW, where M is the number of "
      "input feature channels, C is the number of "
      "output feature channels, "
      "H is the height of the filter, and W is the width of the filter. "
      "We enforce groups number == 1 in the convolution transpose scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator. "
            "The format of output tensor is also NCHW.");
  // Both attributes hold one element per spatial axis (height, width).
  // InferShape checks their lengths against the input rank.
  AddAttr<std::vector<int>>(
      "strides",
      "(vector<int> default:{1, 1}), the strides(h_stride, w_stride) of "
      "convolution transpose operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>(
      "paddings",
      "(vector<int> default:{0, 0}), the paddings(h_pad, w_pad) of convolution "
      "transpose operator.")
      .SetDefault({0, 0});
  AddComment(R"DOC(
Convolution2D Transpose Operator.

The convolution transpose operation calculates the output based on the input, filter
and strides, paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
Filter(Input) is in MCHW format. Where M is the number of input feature channels,
C is the number of output feature channels, H is the height of the filter,
and W is the width of the filter.
Parameters(strides, paddings) are two elements. These two elements represent height
and width, respectively.
The input(Input) size and output(Output) size may be different.

Example:
  Input:
       Input shape: $(N, C_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{in}, C_{out}, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, H_{out}, W_{out})$
  Where
  $$
       H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + H_f \\
       W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + W_f
  $$
)DOC");
}

C
chengduoZH 已提交
115 116 117
// Declares the proto of the conv3d_transpose operator: its inputs (Input,
// Filter), output (Output), attributes (strides, paddings) and user-facing
// documentation. Shape consistency is checked in ConvTransposeOp::InferShape.
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
    framework::OpProto* proto, framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput("Input",
           "(Tensor) The input tensor of convolution transpose operator. "
           "The format of input tensor is NCDHW. Where N is batch size, C is "
           "the number of channels, D is the depth of the feature, H is the "
           "height of the feature, and "
           "W is the width of the feature.");
  // NOTE(review): the text below says padding == 0 is enforced, yet the
  // "paddings" attribute exists and InferShape uses it in the output-size
  // formula. Confirm against the kernel before relying on either statement.
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution transpose operator. "
           "The format of the filter tensor is MCDHW, where M is the number of "
           "input feature channels, C is the number of "
           "output feature channels, D "
           "is the depth of the filter, H is the height of the filter, and "
           "W is the width of the filter. "
           "We enforce groups number == 1 and padding == 0 in "
           "the convolution3d transpose scenario.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution transpose operator. "
            "The format of output tensor is also NCDHW. "
            "Where N is batch size, C is "
            "the number of channels, D is the depth of the feature, H is the "
            "height of the feature, and W is the width of the feature.");
  // Both attributes hold one element per spatial axis (depth, height, width).
  // InferShape checks their lengths against the input rank.
  AddAttr<std::vector<int>>("strides",
                            "(vector<int> default:{1, 1, 1}), the "
                            "strides(d_stride, h_stride, w_stride) of "
                            "convolution transpose operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "(vector<int> default:{0, 0, 0}), paddings(d_pad, "
                            "h_pad, w_pad) of convolution transpose operator.")
      .SetDefault({0, 0, 0});
  AddComment(R"DOC(
Convolution3D Transpose Operator.

The convolution transpose operation calculates the output based on the input, filter
and strides, paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the
number of channels, D is the depth of the feature, H is the height of the feature,
and W is the width of the feature.
Filter(Input) is in MCDHW format. Where M is the number of input feature channels,
C is the number of output feature channels, D is the depth of the filter, H is the
height of the filter, and W is the width of the filter.
Parameters(strides, paddings) are three elements. These three elements represent
depth, height and width, respectively.
The input(Input) size and output(Output) size may be different.

Example:
  Input:
       Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{in}, C_{out}, D_f, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
  Where
  $$
       D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + D_f \\
       H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + H_f \\
       W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + W_f
  $$
)DOC");
}

C
chengduoZH 已提交
179
// Shape inference for the gradient op. Each gradient output, when it is
// requested, takes exactly the shape of the forward variable it belongs to.
// Gradients that were not requested are left alone.
void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
  const auto input_dims = ctx->GetInputDim("Input");
  const auto filter_dims = ctx->GetInputDim("Filter");

  const auto input_grad_name = framework::GradVarName("Input");
  if (ctx->HasOutput(input_grad_name)) {
    ctx->SetOutputDim(input_grad_name, input_dims);
  }

  const auto filter_grad_name = framework::GradVarName("Filter");
  if (ctx->HasOutput(filter_grad_name)) {
    ctx->SetOutputDim(filter_grad_name, filter_dims);
  }
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
C
chengduoZH 已提交
194

C
chengduoZH 已提交
195 196
REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
            conv2d_transpose_grad, ops::ConvTransposeOpGrad);
C
chengduoZH 已提交
197 198

REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
199
    conv2d_transpose,
Q
QI JUN 已提交
200 201
    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
C
chengduoZH 已提交
202
REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
203
    conv2d_transpose_grad,
Q
QI JUN 已提交
204 205 206
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                     double>);
C
chengduoZH 已提交
207

C
chengduoZH 已提交
208 209
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
            conv3d_transpose_grad, ops::ConvTransposeOpGrad);
C
chengduoZH 已提交
210 211

REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
212
    conv3d_transpose,
Q
QI JUN 已提交
213 214
    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
C
chengduoZH 已提交
215
REGISTER_OP_CPU_KERNEL(
C
chengduoZH 已提交
216
    conv3d_transpose_grad,
Q
QI JUN 已提交
217 218 219
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                     double>);