/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/conv_op.h"

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

std::vector<int64_t> ConvOp::ComputeOutputShape(
    framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
  OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");

  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");

  std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
  std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
  std::string padding_algorithm =
      ctx->Attrs().Get<std::string>("padding_algorithm");
  int groups = ctx->Attrs().Get<int>("groups");
  std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
  int dilation_size = dilations.size();
  for (int i = 0; i < dilation_size; ++i) {
    PADDLE_ENFORCE_GT(
        dilations[i],
        0,
        platform::errors::InvalidArgument(
            "The dilation of Op(Conv) should be larger than 0, but received "
            "dilation is %d.",
            dilations[i]));
  }
  const std::string data_format = ctx->Attrs().Get<std::string>("data_format");

  // MKL-DNN kernels use the NCHW order of dims description,
  // so we ignore the data_format attribute for the MKL-DNN kernel.
  const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                            (data_format == "NHWC" || data_format == "NDHWC");

  PADDLE_ENFORCE_EQ(
      in_dims.size() == 4 || in_dims.size() == 5,
      true,
      platform::errors::InvalidArgument(
          "The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
          "received: input's dimension is %u, input's shape is [%s].",
          in_dims.size(),
          in_dims));

  PADDLE_ENFORCE_EQ(
      in_dims.size(),
      filter_dims.size(),
      platform::errors::InvalidArgument(
          "The input's dimension and filter's dimension of "
          "Op(Conv) should be equal. But received: the input's shape is [%s], "
          "the input's dimension is %d; the filter's shape is [%s], "
          "the filter's dimension is %d.",
          in_dims,
          in_dims.size(),
          filter_dims,
          filter_dims.size()));

  int stride_size = strides.size();
  for (int i = 0; i < stride_size; ++i) {
    PADDLE_ENFORCE_GT(
        strides[i],
        0,
        platform::errors::InvalidArgument(
            "The stride of Op(Conv) should be larger than 0, but received "
            "stride is %d.",
            strides[i]));
  }

  int in_sub_stride_size = in_dims.size() - stride_size;
  PADDLE_ENFORCE_EQ(
      in_dims.size(),
      strides.size() + 2U,
      platform::errors::InvalidArgument(
          "The difference of input's dimension and Attr(strides)'s "
          "length must be equal to 2 for Op(Conv). "
          "But received: input's dimension is %d, input's shape is [%s]; "
          "Attr(stride)'s length is %d, Attr(stride) is [%s]; "
          "difference of input's dimension and Attr(strides)'s length = %u.",
          in_dims.size(),
          in_dims,
          strides.size(),
          phi::make_ddim(strides),
          in_sub_stride_size));

  const auto input_channels =
      channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];

  PADDLE_ENFORCE_EQ(
      input_channels,
      filter_dims[1] * groups,
      platform::errors::InvalidArgument(
          "The number of input's channels should be equal to filter's channels "
          "* groups for Op(Conv). But received: the input's channels is %d, "
          "the input's shape is [%s]; the filter's channels is %d, the "
          "filter's shape is [%s]; the groups is %d, the data_format is %s. "
          "The error may come from a wrong data_format setting.",
          input_channels,
          in_dims,
          filter_dims[1],
          filter_dims,
          groups,
          data_format));
  PADDLE_ENFORCE_EQ(
      filter_dims[0] % groups,
      0,
      platform::errors::InvalidArgument(
          "The number of output's channels (filter's first dimension) of "
          "Op(Conv) should be divisible by groups. But received: "
          "the output channels is %d, the filter's shape is [%s], "
          "the groups is %d.",
          filter_dims[0],
          filter_dims,
          groups));
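  // A worked example of the two checks above (illustrative shapes): with
  // groups = 2 and an NCHW input of shape [N, 8, H, W], filter_dims[1] must
  // be 4 (the 8 input channels are split into 2 groups of 4 channels each),
  // and filter_dims[0], the number of output channels, must be a multiple
  // of 2.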

  if (ctx->IsRuntime()) {
    PADDLE_ENFORCE_GT(
        filter_dims[0],
        0,
        platform::errors::InvalidArgument(
            "The size of the filter at axis 0 should be greater than 0."));
  }

  framework::DDim in_data_dims;
  if (channel_last) {
    in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
  } else {
    in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
  }

  framework::DDim filter_data_dims =
      phi::slice_ddim(filter_dims, 2, filter_dims.size());

  std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
  UpdatePaddingAndDilation(
      &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
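  // The call above rewrites `paddings` (and `dilations`) into the explicit
  // per-side form used below. A sketch of the usual SAME/VALID semantics
  // (see the helper's definition for the authoritative behavior): "VALID"
  // resets all paddings to zero, while "SAME" picks paddings so that the
  // output spatial size becomes ceil(input_size / stride).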

  std::vector<int64_t> output_shape({in_dims[0]});
  if (!channel_last) {
    output_shape.push_back(filter_dims[0]);
  }
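  // At compile time (!ctx->IsRuntime()), a non-positive extent below means
  // the dimension is not known yet, so -1 (unknown) is propagated instead
  // of calling ConvOutputSize.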
  for (int i = 0; i < in_data_dims.size(); ++i) {
    if ((!ctx->IsRuntime()) &&
        (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
      output_shape.push_back(-1);
    } else {
      output_shape.push_back(ConvOutputSize(in_data_dims[i],
                                            filter_data_dims[i],
                                            dilations[i],
                                            paddings[2 * i],
                                            paddings[2 * i + 1],
                                            strides[i]));
    }
  }
  if (channel_last) {
    output_shape.push_back(filter_dims[0]);
  }

  return output_shape;
}

framework::OpKernelType ConvOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  // TODO: enable data layout when it's ready
  // (https://github.com/PaddlePaddle/Paddle/pull/20042)

  if (input_data_type != framework::proto::VarType::INT8 &&
      input_data_type != framework::proto::VarType::UINT8 &&
      input_data_type != framework::proto::VarType::BF16) {
    auto filter_data_type = framework::TransToProtoVarType(
        ctx.Input<phi::DenseTensor>("Filter")->dtype());
    PADDLE_ENFORCE_EQ(
        input_data_type,
        filter_data_type,
        platform::errors::InvalidArgument(
            "input and filter data type should be consistent, "
            "but received input data type is %s and filter type "
            "is %s",
            paddle::framework::DataTypeToString(input_data_type),
            paddle::framework::DataTypeToString(filter_data_type)));
  }
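  // INT8/UINT8/BF16 inputs are exempt from the check above: quantized and
  // bfloat16 conv kernels may legitimately keep the filter in a different
  // precision than the input, so an exact dtype match is not required there.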

  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::OpKernelType ConvOp::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only the input requires reshaping; the weights and bias
  // have their shape in NCHW order.
  if ((var_name == "Input") &&
      (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN)) {
    auto attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_format = ar.Get<std::string>("data_format");
    auto dl = phi::StringToDataLayout(data_format);
    // Some models may have intentionally set "AnyLayout" for the conv
    // op. Treat this as NCHW (the default data_format value).
    if (dl != phi::DataLayout::kAnyLayout) {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), dl);
    }
  }
#endif
  return framework::OpKernelType(
      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}

void Conv2DOpMaker::Make() {
  AddInput("Input",
           "(Tensor) The input tensor of convolution operator. "
           "The format of input tensor is NCHW or NHWC, where N is batch "
           "size, C is the number of channels, H is the height of the "
           "feature, and W is the width of the feature.");
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution operator. "
           "The format of the filter tensor is MCHW, where M is the number of "
           "output image channels, C is the number of input image channels, "
           "H is the height of the filter, and W is the width of the filter. "
           "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
  AddInput("Bias",
           "(Tensor) Bias to be added to each output of filter application. "
           "The format of output tensor is X (one-dimensional) of size equal "
           "to the number of output channels. Only used with MKL-DNN.")
      .AsDispensable()
      .AsExtra();
  AddInput("ResidualData",
           "(Tensor) Tensor with residual data "
           "to which convolution output will be added. "
           "Used with fuse_residual_connection fusion.")
      .AsDispensable()
      .AsExtra();
  AddOutput("Output",
            "(Tensor) The output tensor of convolution operator. "
            "It has the same data format and data type as the Input.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int> default:{1, 1}), the "
                            "strides(h_stride, w_stride) of "
                            "convolution operator.")
      .SetDefault({1, 1});
  AddAttr<std::vector<int>>("paddings",
                            "(vector<int> default:{0, 0}), the "
                            "paddings(pad_height_top, pad_height_bottom, "
                            "pad_width_left, pad_width_right) of "
                            "convolution operator.")
      .SetDefault({0, 0});
  AddAttr<std::string>(
      "padding_algorithm",
      "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\", "
      "\"SAME\", \"VALID\". Set to \"EXPLICIT\" for explicit padding. "
      "Set to \"SAME\" or \"VALID\" to have the paddings computed "
      "automatically. ")
      .SetDefault("EXPLICIT");
  AddAttr<int>(
      "groups",
      "(int, default:1), the number of groups of the convolution operator. "
      "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
      "when group=2, the first half of the filters is only connected to the "
      "first half of the input channels, while the second half of the filters "
      "is only connected to the second half of the input channels.")
      .SetDefault(1);
  AddAttr<std::vector<int>>("dilations",
                            "(vector<int> default:{1, 1}), the "
                            "dilations(h_dilation, w_dilation) of "
                            "convolution operator.")
      .SetDefault({1, 1});
  AddAttr<std::string>(
      "data_format",
      "(string, default \"NCHW\") An optional string from: \"NHWC\", "
      "\"NCHW\". Specify the data format of the output data; "
      "the input will be transformed automatically. ")
      .SetDefault("NCHW");
  // TODO(dzhwinter): need to register the layout transform function
  AddComment(R"DOC(
Convolution Operator.

The convolution operation calculates the output based on the input, filter
and the strides, paddings, dilations and groups parameters. The size of each
dimension of the parameters is checked during shape inference.
Input(Input) and Output(Output) are in NCHW or NHWC format, where N is batch
size, C is the number of channels, H is the height of the feature, and W is
the width of the feature.
Filters(Input) is in MCHW format, where M is the number of output image
channels, C is the number of input image channels, H is the height of the
filter, and W is the width of the filter.
Parameters(strides, paddings, dilations) have two elements each. These two
elements represent height and width, respectively.
The input(X) size and output(Out) size may be different.

Example:
  Input:
       Input shape: $(N, C_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, H_{out}, W_{out})$
  Where
$$
       H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
       W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
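
For example, with $H_{in} = W_{in} = 32$, $H_f = W_f = 3$, strides = {1, 1},
paddings = {1, 1, 1, 1} and dilations = {1, 1} (illustrative values):
$$
       H_{out} = (32 + 1 + 1 - (1 * (3 - 1) + 1)) / 1 + 1 = 32 \\
       W_{out} = (32 + 1 + 1 - (1 * (3 - 1) + 1)) / 1 + 1 = 32
$$
so a 3x3 filter with stride 1 and one unit of padding on each side preserves
the spatial size.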
)DOC");
  Apply();
}

class DepthwiseConv2DOpMaker : public Conv2DOpMaker {
 protected:
  void Apply() override {
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default false) Only used in the cuDNN kernel; cuDNN needs "
        "to be installed.")
        .SetDefault(false)
        .AsExtra();
  }
};

void Conv3DOpMaker::Make() {
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution operator. "
      "The format of input tensor is NCDHW or NDHWC, where N is batch size, "
      "C is the number of channels, D is the depth of the feature, H is the "
      "height of the feature, and W is the width of the feature.");
  AddInput("Filter",
           "(Tensor) The filter tensor of convolution operator. "
           "The format of the filter tensor is MCDHW, where M is the number of "
           "output image channels, C is the number of input image channels, "
           "D is the depth of the filter, H is the height of the filter, and W "
           "is the width of the filter. "
           "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
  AddOutput("Output",
            "(Tensor) The output tensor of convolution operator. "
            "It has the same data format and data type as the Input.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int>, default:{1, 1, 1}), the "
                            "strides(d_stride, h_stride, w_stride) of "
                            "convolution operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::vector<int>>(
      "paddings",
      "(vector<int>, default:{0, 0, 0}), the "
      "paddings(pad_depth_front, pad_depth_back, pad_height_top, "
      "pad_height_bottom, pad_width_left, pad_width_right) of convolution "
      "operator.")
      .SetDefault({0, 0, 0});
  AddAttr<std::string>(
      "padding_algorithm",
      "(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\", "
      "\"SAME\", \"VALID\". Set to \"EXPLICIT\" for explicit padding. "
      "Set to \"SAME\" or \"VALID\" to have the paddings computed "
      "automatically. ")
      .SetDefault("EXPLICIT");
  AddAttr<int>(
      "groups",
      "(int, default:1), the number of groups of the convolution operator. "
      "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
      "when group=2, the first half of the filters is only connected to the "
      "first half of the input channels, while the second half of the filters "
      "is only connected to the second half of the input channels.")
      .SetDefault(1);
  AddAttr<std::vector<int>>("dilations",
                            "(vector<int> default:{1, 1, 1}), the "
                            "dilations(d_dilation, h_dilation, w_dilation) of "
                            "convolution operator.")
      .SetDefault({1, 1, 1});
  AddAttr<std::string>(
      "data_format",
      "(string, default \"NCDHW\") An optional string from: \"NDHWC\", "
      "\"NCDHW\". Specify the data format of the output data; "
      "the input will be transformed automatically. ")
      .SetDefault("NCDHW");
  AddComment(R"DOC(
Convolution3D Operator.

The convolution operation calculates the output based on the input, filter
and the strides, paddings, dilations and groups parameters. The size of each
dimension of the parameters is checked during shape inference.
Input(Input) and output(Output) are in NCDHW or NDHWC format, where N is batch
size, C is the number of channels, D is the depth of the feature, H is the
height of the feature, and W is the width of the feature.
Filters(Input) is in MCDHW format, where M is the number of output image
channels, C is the number of input image channels, D is the depth of the
filter, H is the height of the filter, and W is the width of the filter.
Parameters(strides, paddings, dilations) have three elements each. These
three elements represent depth, height and width, respectively.
The input(X) size and output(Out) size may be different.

Example:
  Input:
       Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
       Filter shape: $(C_{out}, C_{in}, D_f, H_f, W_f)$
  Output:
       Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
  Where
  $$
       D_{out}= \frac{(D_{in} + pad_depth_front + pad_depth_back - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\
       H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\
       W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
  $$
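
For example, with $D_{in} = 16$, $D_f = 3$, one unit of front and back
padding, dilations[0] = 1 and strides[0] = 1 (illustrative values):
$$
       D_{out} = (16 + 1 + 1 - (1 * (3 - 1) + 1)) / 1 + 1 = 16
$$
and the height and width axes follow the same pattern.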
)DOC");
  Apply();
}

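// The conv gradients simply mirror the shapes of the forward Input and
// Filter, so the backward InferShape needs no convolution shape arithmetic.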
void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
  auto in_dims = ctx->GetInputDim("Input");
  auto filter_dims = ctx->GetInputDim("Filter");
  if (ctx->HasOutput(framework::GradVarName("Input"))) {
    ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
  }
  if (ctx->HasOutput(framework::GradVarName("Filter"))) {
    ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
  }
}

framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  return framework::OpKernelType(data_type, ctx.GetPlace());
}

framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only the input requires reshaping; the weights and bias
  // have their shape in NCHW order.
  if (((var_name == "Input") ||
       (var_name == framework::GradVarName("Output"))) &&
      (expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN)) {
    auto attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_format = ar.Get<std::string>("data_format");
    auto dl = phi::StringToDataLayout(data_format);
    // Some models may have intentionally set "AnyLayout" for the conv
    // op. Treat this as NCHW (the default data_format value).
    if (dl != phi::DataLayout::kAnyLayout) {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), dl);
    }
  }
#endif
  return framework::OpKernelType(
      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}

template <typename T>
class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));

    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));

    if (this->HasInput("Bias")) {
      op->SetInput("Bias", this->Input("Bias"));
      op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    }
    op->SetAttrMap(this->Attrs());
  }
};

template <typename T>
class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));

    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));

    if (this->HasInput("ResidualData")) {
      op->SetInput("ResidualData", this->Input("ResidualData"));
    }

    op->SetAttrMap(this->Attrs());
  }
};

/*
 * Inputs:  I (Input), W (Filter), dO (grad of Output),
 *          ddI (grad of the Input gradient), ddW (grad of the Filter gradient)
 * Outputs: ddO (grad of the Output gradient), dW, dI
 */
template <typename T>
class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    // I, W, dO, ddI, ddW
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
    op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
    op->SetInput("DDFilter",
                 this->OutputGrad(framework::GradVarName("Filter")));

    // ddO, dI, dW
    // Unlike the grad op, the double grad op does not use name@GRAD@GRAD
    // as the key of its inputs and outputs.
    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
    auto ddw = this->OutputGrad(framework::GradVarName("Filter"));

    op->SetOutput("DDOutput",
                  ddx.empty()
                      ? this->EmptyInputGrad()
                      : this->InputGrad(framework::GradVarName("Output")));
    op->SetOutput(
        "DFilter",
        ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
    op->SetOutput(
        "DInput",
        ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));

    op->SetAttrMap(this->Attrs());
  }
};

/*
 * Inputs:  I (Input), W (Filter), dO (grad of Output),
 *          ddI (grad of the Input gradient), ddW (grad of the Filter gradient)
 * Outputs: ddO (grad of the Output gradient), dW, dI
 */
template <typename T>
class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    // I, W, dO, ddI, ddW
    op->SetInput("Input", this->Input("Input"));
    op->SetInput("Filter", this->Input("Filter"));
    op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
    op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
    op->SetInput("DDFilter",
                 this->OutputGrad(framework::GradVarName("Filter")));

    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
    auto ddw = this->OutputGrad(framework::GradVarName("Filter"));

    op->SetOutput("DDOutput",
                  ddx.empty()
                      ? this->EmptyInputGrad()
                      : this->InputGrad(framework::GradVarName("Output")));
    op->SetOutput(
        "DFilter",
        ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
    op->SetOutput(
        "DInput",
        ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));

    op->SetAttrMap(this->Attrs());
  }
};

void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
  auto x_dims = ctx->GetInputDim("Input");
  auto w_dims = ctx->GetInputDim("Filter");
  auto do_dims = ctx->GetInputDim("DOutput");

  if (ctx->HasOutput("DDOutput") &&
      (ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
    ctx->SetOutputDim("DDOutput", do_dims);
  }
  if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
    ctx->SetOutputDim("DFilter", w_dims);
  }
  if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) {
    ctx->SetOutputDim("DInput", x_dims);
  }
}

framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  return framework::OpKernelType(data_type, ctx.GetPlace());
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
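// Each forward op is registered together with its VarType inference and
// grad makers for both static-graph (OpDesc) and imperative (OpBase) modes;
// the grad ops carry the double-grad makers, and the *_grad_grad ops
// complete the chain.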
REGISTER_OPERATOR(conv2d,
                  ops::ConvOp,
                  ops::Conv2DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_grad,
                  ops::ConvOpGrad,
                  ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);

// depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d,
                  ops::ConvOp,
                  ops::DepthwiseConv2DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad,
                  ops::ConvOpGrad,
                  ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad_grad, ops::ConvOpDoubleGrad);

REGISTER_OPERATOR(conv3d,
                  ops::ConvOp,
                  ops::Conv3DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv3DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv3DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad,
                  ops::ConvOpGrad,
                  ops::Conv3DDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Conv3DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);

REGISTER_OP_VERSION(conv2d).AddCheckpoint(
    R"ROC(
      Upgrade conv2d, add a new attribute [use_addto].
    )ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_addto",
        "In order to support the new feature (inplace addto strategy) for "
        "gradient accumulation.",
        false));

REGISTER_OP_VERSION(depthwise_conv2d)
    .AddCheckpoint(
        R"ROC(
      Upgrade depthwise_conv2d, add a new attribute [use_addto].
    )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "use_addto",
            "In order to support the new feature (inplace addto strategy) for "
            "gradient accumulation.",
            false));

REGISTER_OP_VERSION(conv3d).AddCheckpoint(
    R"ROC(
      Upgrade conv3d, add a new attribute [use_addto].
    )ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_addto",
        "In order to support the new feature (inplace addto strategy) for "
        "gradient accumulation.",
        false));