/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

/**
 * Copies a rectangular window of `input` into `out` via Eigen's slice().
 *
 * The window initially spans every one of the D dimensions completely. For
 * each k, dimension `axes_vec[k]` is then narrowed to the half-open range
 * [begin_vec[k], end_vec[k]). `out` is allocated on the context's place with
 * the resulting shape before the copy is issued on the Eigen device.
 *
 * @tparam D       rank used for the Eigen tensors built from input and out.
 * @param context  supplies the device context and the allocation place.
 * @param input    tensor to read from.
 * @param out      tensor that receives the window; resized by this call.
 * @param begin_vec first index kept along each listed axis.
 * @param end_vec   one-past-last index kept along each listed axis.
 * @param axes_vec  axes the begin/end pairs refer to.
 */
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out,
                  const std::vector<int64_t>& begin_vec,
                  const std::vector<int64_t>& end_vec,
                  const std::vector<int64_t>& axes_vec) {
  auto src_dims = input->dims();

  // Default window: offset 0 and full extent on every dimension.
  Eigen::array<int, D> offsets;
  Eigen::array<int, D> extents;
  for (size_t d = 0; d < D; ++d) {
    offsets[d] = 0;
    extents[d] = src_dims[d];
  }

  // Narrow the requested axes and track the shape of the result.
  std::vector<int64_t> result_shape = framework::vectorize(src_dims);
  const size_t num_axes = axes_vec.size();
  for (size_t k = 0; k < num_axes; ++k) {
    const int64_t axis = axes_vec[k];
    const int64_t first = begin_vec[k];
    const int64_t length = end_vec[k] - first;
    offsets[axis] = first;
    extents[axis] = length;
    result_shape[axis] = length;
  }

  const framework::DDim result_dims = framework::make_ddim(result_shape);
  out->mutable_data<T>(result_dims, context.GetPlace());

  using RowMajorTensor =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>;
  auto src_view = RowMajorTensor::From(*input);
  auto dst_view = RowMajorTensor::From(*out, result_dims);

  auto& eigen_dev =
      *context.template device_context<DeviceContext>().eigen_device();
  dst_view.device(eigen_dev) = src_view.slice(offsets, extents);
  out->Resize(result_dims);
}

/**
 * Single-axis convenience overload: keeps [begin_idx, end_idx) along `axes`
 * and every other dimension in full, by forwarding one-element vectors to the
 * vector-based Slice overload.
 */
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out, int64_t begin_idx,
                  int64_t end_idx, int64_t axes) {
  Slice<DeviceContext, T, D>(context, input, out,
                             std::vector<int64_t>{begin_idx},
                             std::vector<int64_t>{end_idx},
                             std::vector<int64_t>{axes});
}

/**
 * Normalises `paddings` to two entries per data dimension and applies the
 * "SAME" / "VALID" padding algorithms in place.
 *
 * - If `paddings` has one entry per data dimension, each entry is duplicated
 *   so the vector becomes {p0, p0, p1, p1, ...}. Otherwise its size must
 *   already be twice the number of data dimensions (enforced).
 * - "SAME": for every dimension i the two paddings are recomputed from
 *   data_dims[i], strides[i] and ksize[i], and dilation[i] is reset to 1.
 * - "VALID": every padding is set to 0.
 * - Any other value leaves the (normalised) paddings and dilation untouched.
 *
 * @param paddings           in/out padding vector (see above).
 * @param dilation           in/out dilation vector; only written for "SAME".
 * @param padding_algorithm  "SAME", "VALID", or anything else for explicit.
 * @param data_dims          the data (non-batch, non-channel) dimensions.
 * @param strides            stride per data dimension.
 * @param ksize              kernel size per data dimension.
 */
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
                                     std::vector<int>* dilation,
                                     const std::string padding_algorithm,
                                     const framework::DDim data_dims,
                                     const std::vector<int>& strides,
                                     const std::vector<int>& ksize) {
  // set padding size == data_dims.size() * 2
  auto data_shape = framework::vectorize<int>(data_dims);
  // Hoisted as size_t so the comparisons below are not signed/unsigned.
  const size_t data_rank = static_cast<size_t>(data_dims.size());
  if (paddings->size() == data_rank) {
    for (size_t i = 0; i < data_rank; ++i) {
      int copy_pad = *(paddings->begin() + 2 * i);
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
    PADDLE_ENFORCE_EQ(
        data_dims.size() * 2, paddings->size(),
        "Paddings size should be the same or twice as the input data size.");
  }

  // when padding_algorithm is "VALID" or "SAME"
  if (padding_algorithm == "SAME") {
    for (size_t i = 0; i < data_rank; ++i) {
      // ceil(data / stride) using this dimension's own stride. Dividing by
      // strides[0] would give a wrong size whenever the strides differ
      // between dimensions.
      int out_size = (data_dims[i] + strides[i] - 1) / strides[i];
      int pad_sum =
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
      // When pad_sum is odd the extra element goes to the second padding.
      int pad_0 = pad_sum / 2;
      int pad_1 = pad_sum - pad_0;
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;

      // dilation
      *(dilation->begin() + i) = 1;
    }

  } else if (padding_algorithm == "VALID") {
    std::fill(paddings->begin(), paddings->end(), 0);
  }
}

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
// Op maker for the 2-D transposed convolution operator (derives from
// framework::OpProtoAndCheckerMaker). Make() is only declared here; its
// definition is not part of this header.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

// Op maker for the 3-D transposed convolution operator (derives from
// framework::OpProtoAndCheckerMaker). Make() is only declared here; its
// definition is not part of this header.
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class ConvTransposeOp : public framework::OperatorWithKernel {
C
chengduoZH 已提交
134 135 136
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
137 138 139 140

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
141 142
};

class ConvTransposeOpGrad : public framework::OperatorWithKernel {
C
chengduoZH 已提交
144 145 146
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
147 148 149 150

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
151 152
};

template <typename DeviceContext, typename T>
154
class GemmConvTransposeKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
155 156
 public:
  void Compute(const framework::ExecutionContext& context) const override {
157 158 159 160
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
161 162 163 164 165 166
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so it should not be constant pointer
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
C
chengduoZH 已提交
167
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
168
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
169
    int groups = context.Attr<int>("groups");
170 171
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
172

173 174 175
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_dims = output->dims();
C
chengduoZH 已提交
176
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
177

178
    framework::DDim in_data_dims;
179
    if (data_layout != framework::DataLayout::kNHWC) {
180 181 182 183 184 185 186 187 188 189 190 191
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
192
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
193
    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
194 195 196 197
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
198
    // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
C
chengduoZH 已提交
199 200
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
201
    if (data_layout != framework::DataLayout::kNHWC) {
202 203 204 205 206 207 208 209 210 211 212
      col_shape_vec[0] = out_dims[1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
213
    }
214
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
215 216

    // use col_matrix_shape in the gemm calculation
217
    // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
218
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
219 220 221 222 223 224 225 226 227 228

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);

229 230
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
231 232
    DDim output_shape =
        framework::slice_ddim(output->dims(), 1, output->dims().size());
C
chengduoZH 已提交
233

234 235 236
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
237
    if (data_layout != framework::DataLayout::kNHWC) {
238 239 240 241
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
242

243 244
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
245
    if (data_layout != framework::DataLayout::kNHWC) {
246 247 248 249
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
    }
C
chengduoZH 已提交
250 251 252
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
253 254
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
255
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
Q
QI JUN 已提交
256
    set_zero(dev_ctx, output, static_cast<T>(0));
C
chengduoZH 已提交
257

258
    int in_step =
259
        (data_layout != framework::DataLayout::kNHWC
260 261 262 263
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);

    int out_step =
264
        (data_layout != framework::DataLayout::kNHWC
265 266
             ? static_cast<int>(out_dims[1]) / groups
             : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
Q
QI JUN 已提交
267 268
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
269
    math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
270

271 272
    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
273
    size_t D = input->dims().size();
C
chengduoZH 已提交
274
    for (int i = 0; i < batch_size; i++) {
275 276
      // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
      // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
C
chengduoZH 已提交
277 278
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

279 280
      // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
      // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
C
chengduoZH 已提交
281 282
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

283
      std::vector<Tensor> output_batch_vec;
Y
Yibing Liu 已提交
284
      for (int g = 0; g < groups; g++) {
285 286
        int64_t start = g * in_step;
        int64_t end = (g + 1) * in_step;
287
        int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
Y
Yibing Liu 已提交
288
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
289
        Tensor in_slice, out_slice;
Y
Yibing Liu 已提交
290 291

        // col_matrix = filter_slice * input_slice
292 293
        // of shape (o_c/g * k_h * k_w, h * w)
        // or (o_c/g * k_d * k_h * k_w, d * h * w)
294
        if (data_layout != framework::DataLayout::kNHWC) {
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314
          in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
          out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        } else {
          Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
                                     end, axes);
          start = g * out_step;
          end = (g + 1) * out_step;
          axes = D - 2;
          if (D == 4U) {
            Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
                                       start, end, axes);
          } else if (D == 5U) {
            Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
                                       start, end, axes);
          }
          blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        }
Y
Yibing Liu 已提交
315 316 317

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
318 319
          // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
          // o_c/g)
Y
Yibing Liu 已提交
320
          col2im(dev_ctx, col, dilations, strides,
321 322 323
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &out_slice, data_layout);
Y
Yibing Liu 已提交
324 325
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
326 327 328 329
          // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
          // or (o_d, o_h, o_w, o_c/g)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                  data_layout);
Y
Yibing Liu 已提交
330
        }
331 332 333 334 335
        output_batch_vec.push_back(out_slice);
      }
      if (data_layout == framework::DataLayout::kNHWC) {
        concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
                       &output_batch);
336
      }
C
chengduoZH 已提交
337 338 339 340
    }
  }
};

template <typename DeviceContext, typename T>
342
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
343 344
 public:
  void Compute(const framework::ExecutionContext& context) const override {
345 346 347 348
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
349 350 351 352 353 354 355 356 357 358 359
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

360 361
    if ((!input_grad) && (!filter_grad)) return;

C
chengduoZH 已提交
362 363
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
364
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
365
    int groups = context.Attr<int>("groups");
366 367
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
368

369 370 371
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_grad_dims = output_grad->dims();
C
chengduoZH 已提交
372
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
373

374
    framework::DDim in_data_dims;
375
    if (data_layout != framework::DataLayout::kNHWC) {
376 377 378 379 380 381 382 383 384 385 386 387
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
388
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
389
    // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
390 391 392 393
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
394
    // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
C
chengduoZH 已提交
395 396
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
397
    if (data_layout != framework::DataLayout::kNHWC) {
398 399 400 401 402 403 404 405 406 407 408
      col_shape_vec[0] = out_grad_dims[1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
409
    }
410
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
411

412
    // use col_matrix_shape in the gemm calculation
413
    // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
414
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
415

416 417
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
418 419
    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
                                              output_grad->dims().size());
C
chengduoZH 已提交
420

421 422 423
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
424
    if (data_layout != framework::DataLayout::kNHWC) {
425 426 427 428
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
C
chengduoZH 已提交
429

430 431
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
432
    if (data_layout != framework::DataLayout::kNHWC) {
433 434 435 436 437
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1],
                             col_matrix_shape[0] / groups};
    }
C
chengduoZH 已提交
438
    filter.Resize(filter_matrix_shape);
439 440

    int in_step =
441
        (data_layout != framework::DataLayout::kNHWC
442 443
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
Y
Yibing Liu 已提交
444
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
C
chengduoZH 已提交
445 446 447 448

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
Q
QI JUN 已提交
449
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
450
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
C
chengduoZH 已提交
451 452 453 454 455 456
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
C
chengduoZH 已提交
457 458 459 460
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

C
chengduoZH 已提交
461
      Tensor filter_grad_;
Q
QI JUN 已提交
462
      math::SetConstant<DeviceContext, T> set_zero;
C
chengduoZH 已提交
463

Q
QI JUN 已提交
464 465
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
466
      math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
467

C
chengduoZH 已提交
468 469
      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
470
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
C
chengduoZH 已提交
471
      }
472
      if (filter_grad) {  // filter_grad_ size (i_c, o_c/g, k_h, k_w)
C
chengduoZH 已提交
473
        filter_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
474
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
C
chengduoZH 已提交
475 476
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
C
chengduoZH 已提交
477 478
      }

479
      size_t D = input->dims().size();
C
chengduoZH 已提交
480
      for (int i = 0; i < batch_size; i++) {
481 482 483 484
        // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
        // channel_first
        // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
        // channel_last
C
chengduoZH 已提交
485 486 487
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

C
chengduoZH 已提交
488
        if (data_dim == 2U) {
489
          // im2col: dy -> col matrix
490 491 492 493
          // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_first
          // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_last
494
          im2col(dev_ctx, output_grad_batch, dilations, strides,
495 496 497
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col, data_layout);
C
chengduoZH 已提交
498
        } else if (data_dim == 3U) {
499
          // vol2col: dy -> col_matrix
500 501 502 503
          // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_first
          // from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h *
          // k_w) for channel_last
Q
QI JUN 已提交
504
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
505
                  &col, data_layout);
506
        }
C
chengduoZH 已提交
507

C
chengduoZH 已提交
508
        if (input_grad) {
509
          // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
C
chengduoZH 已提交
510 511
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
512

C
chengduoZH 已提交
513
          // gemm: dx = filter * dy
514 515
          // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
          // * i_w)
516
          // or
517 518 519 520 521 522
          // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) -> (i_c,
          // i_d, i_h, i_w)
          // gemm: dx = dy^T * filter^T for channel_last

          std::vector<Tensor> input_grad_batch_vec;
Y
Yibing Liu 已提交
523
          for (int g = 0; g < groups; g++) {
524 525 526 527 528
            // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
            // for channel_first
            // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
            // for channel_last
            // filter_slice: (i_c/g, o_c/g * k_h * k_w)
Y
Yibing Liu 已提交
529
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
530 531
            // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
            // k_h * k_w, d * h * w)
Y
Yibing Liu 已提交
532 533
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
534
            if (data_layout != framework::DataLayout::kNHWC) {
535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562
              Tensor input_grad_slice =
                  input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(filter_slice, false, col_matrix_slice, false,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
            } else {
              Tensor input_grad_slice;
              Slice<DeviceContext, T, 2>(context, &input_grad_batch,
                                         &input_grad_slice, g * in_step,
                                         (g + 1) * in_step, 1);
              blas.MatMul(col_matrix_slice, true, filter_slice, true,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
              DDim input_grad_slice_shape;
              if (data_dim == 2U) {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
              } else {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
                                          in_step};
              }
              input_grad_slice =
                  input_grad_slice.Resize(input_grad_slice_shape);
              input_grad_batch_vec.push_back(input_grad_slice);
            }
          }
          if (data_layout == framework::DataLayout::kNHWC) {
            concat_functor(dev_ctx, input_grad_batch_vec,
                           static_cast<int>(D - 2), &input_grad_batch);
Y
Yibing Liu 已提交
563
          }
C
chengduoZH 已提交
564 565
        }
        if (filter_grad) {
566
          // input batch: (i_c, i_h * i_w) or (i_h, i_w * i_c)
C
chengduoZH 已提交
567 568
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
569 570
          // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
          // * k_w)
571
          // or
572 573
          // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
          // -> (i_c, o_c * k_d *
C
chengduoZH 已提交
574
          // k_h * k_w)
575 576
          // gemm: d_filter = x^T * dy^T for channel_last

Y
Yibing Liu 已提交
577 578 579 580 581
          for (int g = 0; g < groups; g++) {
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
582
            if (data_layout != framework::DataLayout::kNHWC) {
583 584 585 586 587 588 589 590 591 592 593 594 595
              Tensor in_batch_slice =
                  in_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            } else {
              Tensor in_batch_slice;
              Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
                                         g * in_step, (g + 1) * in_step, 1);
              blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            }
Y
Yibing Liu 已提交
596
          }
C
chengduoZH 已提交
597
        }
C
chengduoZH 已提交
598 599 600 601
      }
    }
  }
};
/**
 * Forward kernel for depthwise transposed convolution.
 *
 * The output is zero-filled and then computed by invoking
 * math::DepthwiseConvInputGradFunctor with the op's input in the
 * "output gradient" position. The kernel enforces that `groups` equals the
 * first filter dimension and that every dilation is 1.
 */
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(groups, filter.dims()[0]);

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    // Only unit dilation is accepted by this kernel.
    for (auto v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1);
    }

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Data (spatial) dimensions of the input, without batch and channel.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    // Expands paddings to two entries per data dimension and applies the
    // SAME / VALID algorithms.
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    output->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    // paddings holds two entries per dimension; they are passed on in the
    // order {[0], [2], [1], [3]}.
    depthwiseConvInputGrad(
        dev_ctx, *output, filter, *input, strides,
        std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
        dilations, output, data_layout);
  }
};

/**
 * Backward kernel for depthwise transposed convolution.
 *
 * The input gradient is computed by running math::DepthwiseConvFunctor over
 * the output gradient. The filter gradient is zero-filled and then computed
 * by math::DepthwiseConvFilterGradFunctor. Returns immediately when neither
 * Input@GRAD nor Filter@GRAD is requested.
 */
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Data (spatial) dimensions of the input, without batch and channel.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    // Expands paddings to two entries per data dimension and applies the
    // SAME / VALID algorithms.
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    if (input_grad) {
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      // Pass (strides, reordered paddings, dilations) like the other
      // depthwise functor calls in this file. The previous call handed the
      // raw paddings in the paddings slot and the reordered paddings in the
      // dilations slot.
      // NOTE(review): confirm the parameter order against the declaration of
      // math::DepthwiseConvFunctor in depthwise_conv.h.
      depthwiseConv(
          dev_ctx, *output_grad, filter, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, input_grad, data_layout);
    }

    if (filter_grad) {
      math::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(
          dev_ctx, *output_grad, *input, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, filter_grad, data_layout);
    }
  }
};
}  // namespace operators
}  // namespace paddle