/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

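// Copy the sub-tensor of `input` given by [begin_vec[i], end_vec[i]) along
// each axis in `axes_vec` into `out`, using Eigen slicing on the current
// device. The single-axis overload below forwards to this version, e.g.
// Slice<DeviceContext, T, 3>(ctx, &tensor, &out, start, end, /*axes=*/2).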
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out,
                  const std::vector<int64_t>& begin_vec,
                  const std::vector<int64_t>& end_vec,
                  const std::vector<int64_t>& axes_vec) {
  auto& place =
      *context.template device_context<DeviceContext>().eigen_device();
  auto in_dims = input->dims();
  auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
  auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
  for (size_t i = 0; i < D; ++i) {
    offsets[i] = 0;
    extents[i] = in_dims[i];
  }

  std::vector<int64_t> out_shape_vec = phi::vectorize(in_dims);
  for (size_t i = 0; i < axes_vec.size(); ++i) {
    offsets[axes_vec[i]] = begin_vec[i];
    extents[axes_vec[i]] = end_vec[i] - begin_vec[i];
    out_shape_vec[axes_vec[i]] = end_vec[i] - begin_vec[i];
  }

  framework::DDim out_dims(phi::make_ddim(out_shape_vec));
  out->mutable_data<T>(out_dims, context.GetPlace());

  auto in_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *input);
  auto out_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out, out_dims);

  EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(place, out_t, in_t,
                                                        offsets, extents);
  out->Resize(out_dims);
}

template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out, int64_t begin_idx,
                  int64_t end_idx, int64_t axes) {
  std::vector<int64_t> begin_vec = {begin_idx};
  std::vector<int64_t> end_vec = {end_idx};
  std::vector<int64_t> axes_vec = {axes};
  Slice<DeviceContext, T, D>(context, input, out, begin_vec, end_vec, axes_vec);
}

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class ConvTransposeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

class ConvTransposeOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

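// GemmConvTransposeKernel implements transposed convolution as a GEMM
// followed by col2im (2-D) or col2vol (3-D); this is the same data movement
// as the input-gradient pass of an ordinary convolution. When no explicit
// output size is requested, each spatial output extent follows
//   o = (i - 1) * stride - pad_begin - pad_end + dilation * (k - 1) + 1.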
template <typename DeviceContext, typename T>
class GemmConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so do not use a const pointer for it.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    int groups = context.Attr<int>("groups");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_dims = output->dims();
    const int batch_size = static_cast<int>(input->dims()[0]);

    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
    std::vector<int64_t> input_shape_vec = phi::vectorize(input->dims());
    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec = phi::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
    // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    if (data_layout != framework::DataLayout::kNHWC) {
      col_shape_vec[0] = out_dims[1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
    }
    DDim col_shape(phi::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
    DDim col_matrix_shape = phi::flatten_to_2d(col_shape, data_dim + 1);

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);

    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
    DDim output_shape =
        phi::slice_ddim(output->dims(), 1, output->dims().size());

    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }

    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
    }
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    set_zero(dev_ctx, output, static_cast<T>(0));

    int in_step =
        (data_layout != framework::DataLayout::kNHWC
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);

    int out_step =
        (data_layout != framework::DataLayout::kNHWC
             ? static_cast<int>(out_dims[1]) / groups
             : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
    math::ConcatFunctor<DeviceContext, T> concat_functor;

    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
    size_t D = input->dims().size();
    for (int i = 0; i < batch_size; i++) {
      // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
      // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

      // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
      // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

      std::vector<Tensor> output_batch_vec;
      for (int g = 0; g < groups; g++) {
        int64_t start = g * in_step;
        int64_t end = (g + 1) * in_step;
        int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
        Tensor in_slice, out_slice;

        // col_matrix = filter_slice * input_slice
        // of shape (o_c/g * k_h * k_w, h * w)
        // or (o_c/g * k_d * k_h * k_w, d * h * w)
        if (data_layout != framework::DataLayout::kNHWC) {
          in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
          out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        } else {
          Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
                                     end, axes);
          start = g * out_step;
          end = (g + 1) * out_step;
          axes = D - 2;
          if (D == 4U) {
            Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
                                       start, end, axes);
          } else if (D == 5U) {
            Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
                                       start, end, axes);
          }
          blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        }

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
          // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
          // o_c/g)
          col2im(dev_ctx, col, dilations, strides,
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &out_slice, data_layout);
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
          // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
          // or (o_d, o_h, o_w, o_c/g)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                  data_layout);
        }
        if (data_layout == framework::DataLayout::kNHWC) {
          output_batch_vec.push_back(out_slice);
        }
      }
      if (data_layout == framework::DataLayout::kNHWC) {
        concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
                       &output_batch);
      }
    }
  }
};

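// GemmConvTransposeGradKernel computes the gradients of the transposed
// convolution. The output gradient is first rearranged with im2col/vol2col
// (the forward direction of an ordinary convolution); the rearranged columns
// are then multiplied against the filter to produce the input gradient and
// against the input to produce the filter gradient, which is accumulated
// over the batch (note the beta of 1.0 in its MatMul calls).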
template <typename DeviceContext, typename T>
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use a const pointer because we will reshape it,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

    if ((!input_grad) && (!filter_grad)) return;

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    int groups = context.Attr<int>("groups");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_grad_dims = output_grad->dims();
    const int batch_size = static_cast<int>(input->dims()[0]);

    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
    std::vector<int64_t> input_shape_vec = phi::vectorize(input->dims());
    // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec = phi::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
    // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    if (data_layout != framework::DataLayout::kNHWC) {
      col_shape_vec[0] = out_grad_dims[1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
    }
    DDim col_shape(phi::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
    DDim col_matrix_shape = phi::flatten_to_2d(col_shape, data_dim + 1);

    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
    DDim output_shape =
        phi::slice_ddim(output_grad->dims(), 1, output_grad->dims().size());

    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }

    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1],
                             col_matrix_shape[0] / groups};
    }
    filter.Resize(filter_matrix_shape);

    int in_step =
        (data_layout != framework::DataLayout::kNHWC
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

      Tensor filter_grad_;
      phi::funcs::SetConstant<DeviceContext, T> set_zero;

      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      math::ConcatFunctor<DeviceContext, T> concat_functor;

      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
      }
      if (filter_grad) {  // filter_grad_ size (i_c, o_c/g, k_h, k_w)
        filter_grad->mutable_data<T>(context.GetPlace());
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
      }

      size_t D = input->dims().size();
      for (int i = 0; i < batch_size; i++) {
        // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
        // channel_first
        // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
        // channel_last
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

        if (data_dim == 2U) {
          // im2col: dy -> col matrix
          // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_first
          // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_last
          im2col(dev_ctx, output_grad_batch, dilations, strides,
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col, data_layout);
        } else if (data_dim == 3U) {
          // vol2col: dy -> col_matrix
          // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_first
          // from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h *
          // k_w) for channel_last
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
                  &col, data_layout);
        }

        if (input_grad) {
          // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);

          // gemm: dx = filter * dy
          // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
          // * i_w)
          // or
          // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) -> (i_c, i_d * i_h * i_w)
          // gemm: dx = dy^T * filter^T for channel_last

          std::vector<Tensor> input_grad_batch_vec;
          for (int g = 0; g < groups; g++) {
            // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
            // for channel_first
            // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
            // for channel_last
            // filter_slice: (i_c/g, o_c/g * k_h * k_w)
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
            // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
            // k_h * k_w, d * h * w)
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
            if (data_layout != framework::DataLayout::kNHWC) {
              Tensor input_grad_slice =
                  input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(filter_slice, false, col_matrix_slice, false,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
            } else {
              Tensor input_grad_slice;
              Slice<DeviceContext, T, 2>(context, &input_grad_batch,
                                         &input_grad_slice, g * in_step,
                                         (g + 1) * in_step, 1);
              blas.MatMul(col_matrix_slice, true, filter_slice, true,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
              DDim input_grad_slice_shape;
              if (data_dim == 2U) {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
              } else {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
                                          in_step};
              }
              input_grad_slice =
                  input_grad_slice.Resize(input_grad_slice_shape);
              input_grad_batch_vec.push_back(input_grad_slice);
            }
          }
          if (data_layout == framework::DataLayout::kNHWC) {
            concat_functor(dev_ctx, input_grad_batch_vec,
                           static_cast<int>(D - 2), &input_grad_batch);
          }
        }
        if (filter_grad) {
          // input batch: (i_c, i_h * i_w) or (i_h * i_w, i_c)
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
          // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
          // * k_w)
          // or
          // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
          // -> (i_c, o_c * k_d * k_h * k_w)
          // gemm: d_filter = x^T * dy^T for channel_last

          for (int g = 0; g < groups; g++) {
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
            if (data_layout != framework::DataLayout::kNHWC) {
              Tensor in_batch_slice =
                  in_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            } else {
              Tensor in_batch_slice;
              Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
                                         g * in_step, (g + 1) * in_step, 1);
              blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            }
          }
        }
      }
    }
  }
};

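// DepthwiseConvTransposeKernel reuses the depthwise convolution functors:
// the transposed forward pass is the input-gradient computation of a
// depthwise convolution, so DepthwiseConvInputGradFunctor is applied with
// the op's real input supplying the gradient signal and the result written
// into the zero-initialized output.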
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(
        groups, filter.dims()[0],
        platform::errors::InvalidArgument(
            "groups should be equal to the 1st dimension of filter. But "
            "received groups is %d and filter dimension[0] is %d",
            groups, filter.dims()[0]));

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    for (auto v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1, platform::errors::InvalidArgument(
                                  "dilations should be 1 in depthwise conv. "
                                  "But received dilations is %d",
                                  v));
    }

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    output->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<DeviceContext>();
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    depthwiseConvInputGrad(
        dev_ctx, *output, filter, *input, strides,
        std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
        dilations, output, data_layout);
  }
};

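// DepthwiseConvTransposeGradKernel follows the same duality: the input
// gradient of the transposed op is a plain depthwise convolution of the
// output gradient with the filter, and the filter gradient is computed by
// DepthwiseConvFilterGradFunctor with the input and output-gradient roles
// effectively swapped relative to an ordinary depthwise convolution.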
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    if (input_grad) {
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      depthwiseConv(
          dev_ctx, *output_grad, filter, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, input_grad, data_layout);
    }

    if (filter_grad) {
      phi::funcs::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(
          dev_ctx, *output_grad, *input, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, filter_grad, data_layout);
    }
  }
};
}  // namespace operators
}  // namespace paddle