/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/pten/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

constexpr int kConvMKLDNNFP32 = 1;
constexpr int kConvMKLDNNINT8 = 2;
constexpr int MaxKeyLength = 256;

// Base convolution operator definitions for other conv
// like operators to reuse the implementation.
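// Computes one spatial output extent of a convolution:
//   output_size = (input_size + 2 * padding - dkernel) / stride + 1,
// where dkernel = dilation * (filter_size - 1) + 1 is the dilated filter span.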
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                          int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
  PADDLE_ENFORCE_GT(
      output_size, 0,
      platform::errors::InvalidArgument(
          "The output's size is expected to be greater than 0. "
          "But received: output's size is %d. The output's size is computed by "
          "((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / "
          "stride + 1), where input_size is %d, padding is %d, "
          "filter_size is %d, dilation is %d, stride is %d.",
          output_size, input_size, padding, filter_size, dilation, stride));

  return output_size;
}

inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                          int padding_1, int padding_2, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1;
  PADDLE_ENFORCE_GT(
      output_size, 0,
      platform::errors::InvalidArgument(
          "The output's size is expected to be greater than 0. "
          "But received: output's size is %d. The output's size is computed by "
          "((input_size + padding_1 + padding_2 - (dilation * (filter_size - "
          "1) + 1)) / stride + 1), where input_size is %d, padding is "
          "(%d, %d), filter_size is %d, dilation is %d, stride is %d.",
          output_size, input_size, padding_1, padding_2, filter_size, dilation,
          stride));

  return output_size;
}

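// Expands `paddings` to two entries (begin, end) per spatial dimension when
// only one value per dimension was given, then applies `padding_algorithm`:
// "SAME" chooses paddings so that output_size == ceil(input_size / stride)
// and resets dilations to 1, while "VALID" zeroes all paddings; any other
// value keeps the explicit paddings untouched.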
template <typename T = int>
inline void UpdatePaddingAndDilation(std::vector<T>* paddings,
                                     std::vector<T>* dilation,
                                     const std::string padding_algorithm,
                                     const framework::DDim data_dims,
                                     const std::vector<T>& strides,
                                     const std::vector<T>& ksize) {
  // set padding size == data_dims.size() * 2
  auto data_shape = pten::vectorize<T>(data_dims);
  if (static_cast<int>(paddings->size()) == data_dims.size()) {
    for (int i = 0; i < data_dims.size(); ++i) {
      T copy_pad = *(paddings->begin() + 2 * i);
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
    PADDLE_ENFORCE_EQ(
        data_dims.size() * 2, paddings->size(),
        platform::errors::InvalidArgument(
            "Attribute padding's size should be the same or twice as the "
            "input's dimension. "
            "But received: padding's size is %d, padding is [%s]; input's "
            "dimension is %d, input's shape is [%s].",
            paddings->size(), pten::make_ddim(*paddings), data_dims.size(),
            data_dims));
  }

  // when padding_algorithm is "VALID" or "SAME"
  if (padding_algorithm == "SAME") {
    for (int i = 0; i < data_dims.size(); ++i) {
      T out_size = (data_dims[i] + strides[i] - 1) / strides[i];
      T pad_sum =
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i],
                   static_cast<T>(0));
      T pad_0 = pad_sum / 2;
      T pad_1 = pad_sum - pad_0;
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;

      // dilation
      *(dilation->begin() + i) = 1;
    }

  } else if (padding_algorithm == "VALID") {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }
}

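// True when the im2col/vol2col expansion is actually required, i.e. unless
// every filter extent is 1, every stride is 1, every padding is 0 and every
// dilation is 1 (in which case the input can feed the GEMM directly).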
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  if (paddings.size() != strides.size()) {
    for (size_t j = 0; j < paddings.size(); ++j) {
      padding_0 = padding_0 && (paddings[j] == 0);
    }
  }
  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}

// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final;

 protected:
  virtual void Apply() {}
};

class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final;

 protected:
  virtual void Apply() {}
};

class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{
        {"Input", /*->*/ "Output"}};
    return m;
  }
};

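// Shared forward operator: InferShape derives the output dims via
// ComputeOutputShape; the kernel-type and layout hooks below are only
// declared here and are implemented outside this header.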
class ConvOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    std::vector<int64_t> output_shape = ComputeOutputShape(ctx);

    OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv");
    ctx->SetOutputDim("Output", pten::make_ddim(output_shape));
    ctx->ShareLoD("Input", "Output");
  }

 protected:
  std::vector<int64_t> ComputeOutputShape(
      framework::InferShapeContext* ctx) const;

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

class ConvOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

class ConvOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

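// Reference forward kernel: per sample and group, the input is (optionally)
// expanded with im2col/vol2col and the output slice is computed as a single
// GEMM, out_slice = filter_slice * col_matrix.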
template <typename DeviceContext, typename T>
class GemmConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped in the calculations,
    // so here use an assignment operation,
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    const int groups = context.Attr<int>("groups");
    const std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    const std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    const std::string data_format = context.Attr<std::string>("data_format");
    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    Tensor transformed_input(input->dtype());
    Tensor transformed_output(output->dtype());

    if (channel_last) {
      ResizeToChannelFirst<DeviceContext, T>(context, input,
                                             &transformed_input);
      TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);

      ResizeToChannelFirst<DeviceContext, T>(context, output,
                                             &transformed_output);

    } else {
      transformed_input = *input;
      transformed_output = *output;
    }

    // update padding and dilation
    auto trans_in_dims = transformed_input.dims();
    auto filter_dims = filter.dims();

    framework::DDim in_data_dims =
        pten::slice_ddim(trans_in_dims, 2, trans_in_dims.size());
    framework::DDim filter_data_dims =
        pten::slice_ddim(filter_dims, 2, filter_dims.size());

    std::vector<int> ksize = pten::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    auto& dev_ctx = context.template device_context<DeviceContext>();

    const int batch_size = static_cast<int>(transformed_input.dims()[0]);

    // filter_shape_vec:
    // {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec(pten::vectorize(filter.dims()));

    // output_shape_vec:
    // {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
    std::vector<int64_t> output_shape_vec(
        pten::vectorize(transformed_output.dims()));

    // use col_shape in the im2col calculation
    // col_shape_vec:
    // {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w,
    // o_d,o_h, o_w}
    size_t data_dim = filter_shape_vec.size() - 2;

    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    col_shape_vec[0] = trans_in_dims[1] / groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
    }

    framework::DDim col_shape(pten::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size:
    // (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d * o_h *
    // o_w)

    framework::DDim col_matrix_shape = pten::flatten_to_2d(col_shape, data_dim);

    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);

    Tensor col;
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    if (is_expand) {
      col = context.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    framework::DDim in_matrix_shape = pten::slice_ddim(
        transformed_input.dims(), 1, transformed_input.dims().size());

    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
    filter.Resize(filter_matrix_shape);

    framework::DDim output_matrix_shape = {
        transformed_output.dims()[1],
        transformed_output.numel() /
            (transformed_output.dims()[0] * transformed_output.dims()[1])};

    // convolution operator: im2col(or vol2col) + gemm
    int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
    int out_step = static_cast<int>(transformed_output.dims()[1]) / groups;

    math::Vol2ColFunctor<DeviceContext, T> vol2col;
    math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;

    auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    for (int i = 0; i < batch_size; i++) {
      Tensor in_batch =
          transformed_input.Slice(i, i + 1).Resize(in_matrix_shape);
      Tensor out_batch =
          transformed_output.Slice(i, i + 1).Resize(output_matrix_shape);

      for (int g = 0; g < groups; g++) {
        Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

        if (!is_expand) {
          col.ShareDataWith(in_slice);
          col_matrix.ShareDataWith(col);
          col_matrix.Resize(col_matrix_shape);
        } else if (data_dim == 2U) {
          im2col(dev_ctx, in_slice, dilations, strides,
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col);

        } else if (data_dim == 3U) {
          vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
        }

        // gemm
        Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
        blas.MatMul(filter_slice, false, col_matrix, false, T(1.0), &out_slice,
                    T(0.0));
      }
    }
    if (channel_last) {
      TransToChannelLast<DeviceContext, T>(context, &transformed_output,
                                           output);
    }
  }
};

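// Reference kernel for the conv backward pass. Per sample and group, the
// input gradient is col_matrix = W^T * dY followed by col2im/col2vol, and the
// filter gradient accumulates dW += dY * col(X)^T over the batch.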
template <typename DeviceContext, typename T>
class GemmConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    // The filter and filter_grad will be reshaped in the calculations,
    // so here use an assignment operation,
    // that avoids modifying the variable in the Scope.
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    int groups = context.Attr<int>("groups");
    const std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    const std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    const std::string data_format = context.Attr<std::string>("data_format");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    Tensor transformed_input(input->dtype());
    Tensor transformed_output_grad(output_grad->dtype());

    if (channel_last) {
      ResizeToChannelFirst<DeviceContext, T>(context, input,
                                             &transformed_input);
      TransToChannelFirst<DeviceContext, T>(context, input, &transformed_input);

      ResizeToChannelFirst<DeviceContext, T>(context, output_grad,
                                             &transformed_output_grad);
      TransToChannelFirst<DeviceContext, T>(context, output_grad,
                                            &transformed_output_grad);
    } else {
      transformed_input = *input;
      transformed_output_grad = *output_grad;
    }

    // update padding and dilation
    auto in_dims = transformed_input.dims();
    auto filter_dims = filter.dims();
    framework::DDim in_data_dims = pten::slice_ddim(in_dims, 2, in_dims.size());
    framework::DDim filter_data_dims =
        pten::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = pten::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    const int batch_size = static_cast<int>(transformed_input.dims()[0]);

    auto& dev_ctx = context.template device_context<DeviceContext>();

    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec(pten::vectorize(filter.dims()));
    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
    std::vector<int64_t> output_shape_vec(
        pten::vectorize(transformed_output_grad.dims()));

    // use col_shape in the im2col calculation
    // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
    // o_h, o_w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    col_shape_vec[0] = transformed_input.dims()[1] / groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
    }
    framework::DDim col_shape(pten::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (i_c/g * k_h * k_w, o_h * o_w)
    // or
    // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
    framework::DDim col_matrix_shape =
        pten::flatten_to_2d(col_shape, data_dim + 1);

    framework::DDim input_shape = pten::slice_ddim(
        transformed_input.dims(), 1, transformed_input.dims().size());

    framework::DDim filter_matrix_shape = {filter.dims()[0],
                                           filter.numel() / filter.dims()[0]};
    filter.Resize(filter_matrix_shape);

    framework::DDim output_matrix_shape = {
        transformed_output_grad.dims()[1],
        transformed_output_grad.numel() / (transformed_output_grad.dims()[0] *
                                           transformed_output_grad.dims()[1])};

    // convolution backward input operator:  gemm + col2im(or col2vol)
    // convolution backward weight operator: im2col(or vol2col) + gemm
    int in_step = static_cast<int>(transformed_input.dims()[1]) / groups;
    int out_step = static_cast<int>(transformed_output_grad.dims()[1]) / groups;

    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);

    Tensor col;
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    if (is_expand) {
      col = context.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    pten::funcs::SetConstant<DeviceContext, T> set_zero;
    auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);

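    // Input gradient: per sample and group, compute col_matrix = W^T * dY and,
    // when the im2col expansion was used, scatter it back with col2im/col2vol.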
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      Tensor transformed_input_grad(input_grad->dtype());
      if (channel_last) {
        ResizeToChannelFirst<DeviceContext, T>(context, input_grad,
                                               &transformed_input_grad);

      } else {
        transformed_input_grad = *input_grad;
      }
      // if is_expand is false, the operation of set_zero is unnecessary,
      // because math::matmul will reset input_grad.
      if (is_expand) {
        set_zero(dev_ctx, &transformed_input_grad, static_cast<T>(0));
      }
      math::Col2VolFunctor<DeviceContext, T> col2vol;
      math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;

      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor in_grad_batch =
            transformed_input_grad.Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // gemm
          Tensor out_grad_slice =
              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);

          Tensor in_grad_slice =
              in_grad_batch.Slice(g * in_step, (g + 1) * in_step);

          if (!is_expand) {
            col_matrix.ShareDataWith(in_grad_slice);
            col_matrix.Resize(col_matrix_shape);
          }
          blas.MatMul(filter_slice, true, out_grad_slice, false, T(1.0),
                      &col_matrix, T(0.0));

          if (is_expand && data_dim == 2U) {
            col2im(dev_ctx, col, dilations, strides,
                   std::vector<int>{paddings[0], paddings[2], paddings[1],
                                    paddings[3]},
                   &in_grad_slice);
          } else if (is_expand && data_dim == 3U) {
            col2vol(dev_ctx, col, dilations, strides, paddings, &in_grad_slice);
          }
        }
      }
      if (channel_last) {
        TransToChannelLast<DeviceContext, T>(context, &transformed_input_grad,
                                             input_grad);
      }
    }

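    // Filter gradient: expand each input slice with im2col/vol2col and
    // accumulate dW += dY * col^T over all samples and groups.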
    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      Tensor filter_grad_ = *filter_grad;
      filter_grad_.Resize(filter_matrix_shape);
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      for (int i = 0; i < batch_size; i++) {
        Tensor out_grad_batch =
            transformed_output_grad.Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor in_batch = transformed_input.Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // im2col
          Tensor out_grad_slice =
              out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);

          if (!is_expand) {
            col.ShareDataWith(in_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
          } else if (data_dim == 2U) {
            im2col(dev_ctx, in_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[2], paddings[1],
                                    paddings[3]},
                   &col);

          } else if (data_dim == 3U) {
            vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
          }

          // gemm
          Tensor filter_grad_slice =
              filter_grad_.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(out_grad_slice, false, col_matrix, true, T(1.0),
                      &filter_grad_slice, T(1.0));
        }
      }
    }
  }
};

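// Reference CPU kernel for the conv double-grad: with the same
// im2col/vol2col + GEMM machinery it computes dX = ddW * dY, dW = ddX * dY,
// and ddY = W * ddX + X * ddW.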
template <typename DeviceContext, typename T>
class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
    PADDLE_ENFORCE_EQ(
        platform::is_cpu_place(ctx.GetPlace()), true,
        paddle::platform::errors::PreconditionNotMet("It must use CPUPlace."));
    const Tensor* X = ctx.Input<Tensor>("Input");
    const Tensor* dY = ctx.Input<Tensor>("DOutput");
    const Tensor* ddX = ctx.Input<Tensor>("DDInput");
    const Tensor* ddW_in = ctx.Input<Tensor>("DDFilter");

    Tensor* ddY = ctx.Output<Tensor>("DDOutput");
    Tensor* dW = ctx.Output<Tensor>("DFilter");
    Tensor* dX = ctx.Output<Tensor>("DInput");
    Tensor W = GET_DATA_SAFELY(ctx.Input<Tensor>("Filter"), "Input", "Filter",
                               "GemmConvDoubleGrad");
    if (!ddY && !dW && !dX) return;

    const int groups = ctx.Attr<int>("groups");
    const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    const std::string padding_algorithm =
        ctx.Attr<std::string>("padding_algorithm");
    const std::string data_format = ctx.Attr<std::string>("data_format");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // transform Tensor
    Tensor transformed_X(X->dtype());
    Tensor transformed_dY(dY->dtype());
    Tensor transformed_ddX(X->dtype());

    if (channel_last) {
      ResizeToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);
      TransToChannelFirst<DeviceContext, T>(ctx, X, &transformed_X);

      ResizeToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);
      TransToChannelFirst<DeviceContext, T>(ctx, dY, &transformed_dY);

      if (ddX) {
        ResizeToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
        TransToChannelFirst<DeviceContext, T>(ctx, ddX, &transformed_ddX);
      }
    } else {
      transformed_X = *X;
      transformed_dY = *dY;
      if (ddX) {
        transformed_ddX = *ddX;
      }
    }

    // update padding and dilation
    auto in_dims = transformed_X.dims();
    auto filter_dims = W.dims();

    framework::DDim in_data_dims = pten::slice_ddim(in_dims, 2, in_dims.size());
    framework::DDim filter_data_dims =
        pten::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = pten::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    const int batch_size = static_cast<int>(transformed_X.dims()[0]);
    std::vector<int64_t> filter_shape_vec(pten::vectorize(W.dims()));
    std::vector<int64_t> output_shape_vec(
        pten::vectorize(transformed_dY.dims()));

    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    // col_shape [in_channel/group, kh, kw, oh, ow]
    col_shape_vec[0] = transformed_X.dims()[1] / groups;
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
    }
    framework::DDim col_shape(pten::make_ddim(col_shape_vec));
    // col_matrix_shape [in_channel/group * kh * kw, oh * ow]
    framework::DDim col_matrix_shape =
        pten::flatten_to_2d(col_shape, data_dim + 1);
    // input_shape [Cin, H, W]
    framework::DDim input_shape =
        pten::slice_ddim(transformed_X.dims(), 1, transformed_X.dims().size());
    // filter_matrix_shape [Cout, Cin * kh * kw]
    framework::DDim filter_matrix_shape = {W.dims()[0],
                                           W.numel() / W.dims()[0]};

    W.Resize(filter_matrix_shape);
    framework::DDim output_matrix_shape = {
        transformed_dY.dims()[1],
        transformed_dY.numel() /
            (transformed_dY.dims()[0] * transformed_dY.dims()[1])};
    int in_step = static_cast<int>(transformed_X.dims()[1]) / groups;
    int out_step = static_cast<int>(transformed_dY.dims()[1]) / groups;

    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
    Tensor col;
    Tensor col_matrix;
    if (is_expand) {
      col = ctx.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);
    }

    pten::funcs::SetConstant<DeviceContext, T> set_zero;
    auto blas = pten::funcs::GetBlas<DeviceContext, T>(dev_ctx);

    // dx convolution double grad:  gemm + col2im(col2vol)
    // dx = ddw * dy  ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout,
    // oH, oW)
    if (dX && ddW_in) {
      Tensor ddW;
      ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
      dX->mutable_data<T>(ctx.GetPlace());

      Tensor transformed_dX(dX->dtype());

      if (channel_last) {
        ResizeToChannelFirst<DeviceContext, T>(ctx, dX, &transformed_dX);

      } else {
        transformed_dX = *dX;
      }
      // if is_expand is false, the operation of set_zero is unnecessary
      // because math::matmul will reset dx
      if (is_expand) {
        set_zero(dev_ctx, &transformed_dX, static_cast<T>(0));
      }
      math::Col2VolFunctor<DeviceContext, T> col2vol;
      math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;

      for (int i = 0; i < batch_size; i++) {
        Tensor dy_batch =
            transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor dx_batch = transformed_dX.Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; g++) {
          // gemm
          Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
          Tensor dx_slice = dx_batch.Slice(g * in_step, (g + 1) * in_step);
          if (!is_expand) {
            col_matrix.ShareDataWith(dx_slice);
            col_matrix.Resize(col_matrix_shape);
          }
          blas.MatMul(ddw_slice, true, dy_slice, false, T(1.0), &col_matrix,
                      T(0.0));

          if (is_expand && data_dim == 2U) {
            col2im(dev_ctx, col, dilations, strides,
                   std::vector<int>{paddings[0], paddings[2], paddings[1],
                                    paddings[3]},
                   &dx_slice);
          } else if (is_expand && data_dim == 3U) {
            col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
          }
        }
      }
      if (channel_last) {
        TransToChannelLast<DeviceContext, T>(ctx, &transformed_dX, dX);
      }
    }

    // dw = ddx * dy  ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
    // oH, oW)
    // dw convolution double grad:  im2col(vol2col) + gemm
    if (dW && ddX) {
      dW->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, dW, static_cast<T>(0));
      Tensor dW_arr = *dW;
      dW_arr.Resize(filter_matrix_shape);
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      for (int i = 0; i < batch_size; ++i) {
        Tensor dy_batch =
            transformed_dY.Slice(i, i + 1).Resize(output_matrix_shape);
        Tensor ddx_batch = transformed_ddX.Slice(i, i + 1).Resize(input_shape);
        for (int g = 0; g < groups; ++g) {
          // im2col
          Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
          if (!is_expand) {
            col.ShareDataWith(ddx_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
          } else if (data_dim == 2U) {
            im2col(dev_ctx, ddx_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[2], paddings[1],
                                    paddings[3]},
                   &col);
          } else if (data_dim == 3U) {
            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
          }

          Tensor dw_slice = dW_arr.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(dy_slice, false, col_matrix, true, T(1.0), &dw_slice,
                      T(1.0));
        }
      }
    }

    // ddy = w * ddx + x * ddw ==> ddy(N, Cout, oH, oW), x/ddx(N, Cin, H, W),
    // w/ddw(Cout, Cin, kh, kw)
    // ddy convolution double grad: im2col(vol2col) + gemm
    if (ddY) {
      ddY->mutable_data<T>(ctx.GetPlace());

      Tensor transformed_ddY(ddY->dtype());
      if (channel_last) {
        ResizeToChannelFirst<DeviceContext, T>(ctx, ddY, &transformed_ddY);
      } else {
        transformed_ddY = *ddY;
      }

      set_zero(dev_ctx, &transformed_ddY, static_cast<T>(0));
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      for (int i = 0; i < batch_size; ++i) {
        Tensor ddy_batch =
            transformed_ddY.Slice(i, i + 1).Resize(output_matrix_shape);
        for (int g = 0; g < groups; ++g) {
          // gemm
          Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);

          if (ddX) {
            Tensor ddx_batch =
                transformed_ddX.Slice(i, i + 1).Resize(input_shape);
            Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
            if (!is_expand) {
              col.ShareDataWith(ddx_slice);
              col_matrix.ShareDataWith(col);
              col_matrix.Resize(col_matrix_shape);
            } else if (data_dim == 2U) {
              im2col(dev_ctx, ddx_slice, dilations, strides,
                     std::vector<int>{paddings[0], paddings[2], paddings[1],
                                      paddings[3]},
                     &col);
            } else if (data_dim == 3U) {
              vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
            }
            Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
            blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
                        T(0.0));
          }

          if (ddW_in) {
            Tensor x_batch = transformed_X.Slice(i, i + 1).Resize(input_shape);
            Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);

            Tensor ddW;
            ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
            if (!is_expand) {
              col.ShareDataWith(x_slice);
              col_matrix.ShareDataWith(col);
              col_matrix.Resize(col_matrix_shape);
            } else if (data_dim == 2U) {
              im2col(dev_ctx, x_slice, dilations, strides,
                     std::vector<int>{paddings[0], paddings[2], paddings[1],
                                      paddings[3]},
                     &col);
            } else if (data_dim == 3U) {
              vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
            }

            // gemm
            Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
            blas.MatMul(ddw_slice, false, col_matrix, false, T(1.0), &ddy_slice,
                        T(1.0));
          }
        }
      }
      if (channel_last) {
        TransToChannelLast<DeviceContext, T>(ctx, &transformed_ddY, ddY);
      }
    }
  }
};

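// Depthwise convolution forward kernel: checks that the output channel count
// is a multiple of the input channel count, then dispatches to
// math::DepthwiseConvFunctor (with or without the fused pre-conv ReLU).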
template <typename DeviceContext, typename T>
class DepthwiseConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    const std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");

    const std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    const std::string data_format = context.Attr<std::string>("data_format");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
    if (channel_last) {
      PADDLE_ENFORCE_EQ(
          output->dims()[output->dims().size() - 1] %
              input->dims()[input->dims().size() - 1],
          0, platform::errors::InvalidArgument(
                 "ShapeError: The output channels must be a multiple of the "
                 "input channels. But received output channel number is %d "
                 "and input channel number is %d",
                 output->dims()[output->dims().size() - 1],
                 input->dims()[input->dims().size() - 1]));
    } else {
      PADDLE_ENFORCE_EQ(
          output->dims()[1] % input->dims()[1], 0,
          platform::errors::InvalidArgument(
              "ShapeError: The output channels must be a multiple of the "
              "input channels. But received output channel number is %d "
              "and input channel number is %d",
              output->dims()[1], input->dims()[1]));
    }

    // update padding and dilation
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    framework::DDim in_data_dims;
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_format);
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = pten::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = pten::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }

    framework::DDim filter_data_dims =
        pten::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = pten::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    bool is_sys_pad = strides.size() * 2 == paddings.size() ? false : true;
    if (!is_sys_pad) {
      for (size_t i = 0; i < strides.size(); ++i) {
        paddings.erase(paddings.begin() + i + 1);
      }
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();

    if (fuse_relu) {
      math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
                    output, data_layout);
    } else {
      math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
                    output, data_layout);
    }
  }
};

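// Depthwise convolution backward kernel: fills the input and/or filter
// gradients via math::DepthwiseConvInputGradFunctor and
// math::DepthwiseConvFilterGradFunctor, mirroring the forward fuse_relu flag.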
template <typename DeviceContext, typename T>
class DepthwiseConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
    const std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    const std::string data_format = context.Attr<std::string>("data_format");

    // update padding and dilation
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    framework::DDim in_data_dims;
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_format);
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = pten::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = pten::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        pten::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = pten::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    bool is_sys_pad = strides.size() * 2 == paddings.size() ? false : true;
    if (!is_sys_pad) {
      for (size_t i = 0; i < strides.size(); ++i) {
        paddings.erase(paddings.begin() + i + 1);
      }
    }
    pten::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();

    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, input_grad, static_cast<T>(0));

      if (fuse_relu) {
        math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
            depthwiseConvInputGrad;
        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
                               paddings, dilations, input_grad, data_layout);
      } else {
        math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
            depthwiseConvInputGrad;
        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
                               paddings, dilations, input_grad, data_layout);
      }
    }

    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
      if (fuse_relu) {
        math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
            depthwiseConvFilterGrad;
        depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
                                paddings, dilations, filter_grad, data_layout);
      } else {
        math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
            depthwiseConvFilterGrad;
        depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
                                paddings, dilations, filter_grad, data_layout);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle