/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

// Copies the sub-tensor input[begin, end) along the given axes into `out`,
// using an Eigen slice expression. D is the tensor rank.
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out,
                  const std::vector<int64_t>& begin_vec,
                  const std::vector<int64_t>& end_vec,
                  const std::vector<int64_t>& axes_vec) {
  auto& eigen_place =
      *context.template device_context<DeviceContext>().eigen_device();
  auto input_dims = input->dims();

  // Start from the identity slice (whole tensor) ...
  Eigen::array<int, D> slice_offsets;
  Eigen::array<int, D> slice_extents;
  for (size_t i = 0; i < D; ++i) {
    slice_offsets[i] = 0;
    slice_extents[i] = input_dims[i];
  }

  // ... then narrow each requested axis to [begin, end).
  std::vector<int64_t> sliced_shape = framework::vectorize(input_dims);
  for (size_t i = 0; i < axes_vec.size(); ++i) {
    const int64_t axis = axes_vec[i];
    const int64_t length = end_vec[i] - begin_vec[i];
    slice_offsets[axis] = begin_vec[i];
    slice_extents[axis] = length;
    sliced_shape[axis] = length;
  }

  framework::DDim out_dims(framework::make_ddim(sliced_shape));
  out->mutable_data<T>(out_dims, context.GetPlace());

  auto input_eigen =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *input);
  auto output_eigen =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out, out_dims);

  output_eigen.device(eigen_place) =
      input_eigen.slice(slice_offsets, slice_extents);
  out->Resize(out_dims);
}

// Convenience overload: slice a single axis to the range [begin_idx, end_idx).
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out, int64_t begin_idx,
                  int64_t end_idx, int64_t axes) {
  Slice<DeviceContext, T, D>(context, input, out,
                             std::vector<int64_t>{begin_idx},
                             std::vector<int64_t>{end_idx},
                             std::vector<int64_t>{axes});
}

// Normalizes the paddings/dilation attributes in place:
//  * A per-dimension padding list [p0, p1, ...] is expanded into the explicit
//    [p0_begin, p0_end, p1_begin, p1_end, ...] form (begin == end).
//  * "SAME": paddings are computed so out_size = ceil(in_size / stride) for
//    each spatial dim, and dilation is forced to 1.
//  * "VALID": all paddings are zeroed.
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
                                     std::vector<int>* dilation,
                                     const std::string padding_algorithm,
                                     const framework::DDim data_dims,
                                     const std::vector<int>& strides,
                                     const std::vector<int>& ksize) {
  // set padding size == data_dims.size() * 2
  auto data_shape = framework::vectorize<int>(data_dims);
  if (paddings->size() == data_dims.size()) {
    // One value per spatial dim: duplicate it into a (begin, end) pair.
    for (size_t i = 0; i < data_dims.size(); ++i) {
      int copy_pad = *(paddings->begin() + 2 * i);
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
    PADDLE_ENFORCE_EQ(
        data_dims.size() * 2, paddings->size(),
        "Paddings size should be the same or twice as the input data size.");
  }

  // when padding_algorithm is "VALID" or "SAME"
  if (padding_algorithm == "SAME") {
    for (size_t i = 0; i < data_dims.size(); ++i) {
      // BUGFIX: was `/ strides[0]` — the ceil-division must use this
      // dimension's own stride, otherwise SAME padding is wrong whenever
      // the strides differ across spatial dimensions.
      int out_size = (data_dims[i] + strides[i] - 1) / strides[i];
      int pad_sum =
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
      // Split the total padding; the extra element (if odd) goes at the end.
      int pad_0 = pad_sum / 2;
      int pad_1 = pad_sum - pad_0;
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;

      // dilation must be 1 under SAME padding
      *(dilation->begin() + i) = 1;
    }

  } else if (padding_algorithm == "VALID") {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }
}

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
// Proto maker for conv2d_transpose: declares the op's inputs, outputs and
// attributes. Make() is implemented out of line (in the .cc file).
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

// Proto maker for conv3d_transpose: declares the op's inputs, outputs and
// attributes. Make() is implemented out of line (in the .cc file).
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class ConvTransposeOp : public framework::OperatorWithKernel {
C
chengduoZH 已提交
134 135 136
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
137 138 139 140

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
141 142
};

class ConvTransposeOpGrad : public framework::OperatorWithKernel {
C
chengduoZH 已提交
144 145 146
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
147 148 149 150

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
151 152
};

template <typename DeviceContext, typename T>
154
class GemmConvTransposeKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
155 156
 public:
  void Compute(const framework::ExecutionContext& context) const override {
157 158 159 160
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
161 162 163 164 165 166
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so it should not be constant pointer
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
C
chengduoZH 已提交
167
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
168
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
169
    int groups = context.Attr<int>("groups");
170 171
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
172

173 174 175
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_dims = output->dims();
C
chengduoZH 已提交
176
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
177

178
    framework::DDim in_data_dims;
179
    if (data_layout != framework::DataLayout::kNHWC) {
180 181 182 183 184 185 186 187 188 189 190 191
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
192
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
193
    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
194 195 196 197
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
198
    // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
C
chengduoZH 已提交
199 200
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
201
    if (data_layout != framework::DataLayout::kNHWC) {
202 203 204 205 206 207 208 209 210 211 212
      col_shape_vec[0] = out_dims[1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
213
    }
214
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
215 216

    // use col_matrix_shape in the gemm calculation
217
    // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
218
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
219 220 221 222 223 224 225 226 227 228

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);

229 230
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
231 232
    DDim output_shape =
        framework::slice_ddim(output->dims(), 1, output->dims().size());
C
chengduoZH 已提交
233

234 235 236
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
237
    if (data_layout != framework::DataLayout::kNHWC) {
238 239 240 241
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
242

243 244
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
245
    if (data_layout != framework::DataLayout::kNHWC) {
246 247 248 249
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
    }
C
chengduoZH 已提交
250 251 252
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
253 254
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
255
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
Q
QI JUN 已提交
256
    set_zero(dev_ctx, output, static_cast<T>(0));
C
chengduoZH 已提交
257

258
    int in_step =
259
        (data_layout != framework::DataLayout::kNHWC
260 261 262 263
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);

    int out_step =
264
        (data_layout != framework::DataLayout::kNHWC
265 266
             ? static_cast<int>(out_dims[1]) / groups
             : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
Q
QI JUN 已提交
267 268
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
269
    math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
270

271 272
    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
273
    size_t D = input->dims().size();
C
chengduoZH 已提交
274
    for (int i = 0; i < batch_size; i++) {
275 276
      // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
      // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
C
chengduoZH 已提交
277 278
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

279 280
      // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
      // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
C
chengduoZH 已提交
281 282
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

283
      std::vector<Tensor> output_batch_vec;
Y
Yibing Liu 已提交
284
      for (int g = 0; g < groups; g++) {
285 286
        int64_t start = g * in_step;
        int64_t end = (g + 1) * in_step;
287
        int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
Y
Yibing Liu 已提交
288
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
289
        Tensor in_slice, out_slice;
Y
Yibing Liu 已提交
290 291

        // col_matrix = filter_slice * input_slice
292 293
        // of shape (o_c/g * k_h * k_w, h * w)
        // or (o_c/g * k_d * k_h * k_w, d * h * w)
294
        if (data_layout != framework::DataLayout::kNHWC) {
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314
          in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
          out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        } else {
          Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
                                     end, axes);
          start = g * out_step;
          end = (g + 1) * out_step;
          axes = D - 2;
          if (D == 4U) {
            Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
                                       start, end, axes);
          } else if (D == 5U) {
            Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
                                       start, end, axes);
          }
          blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        }
Y
Yibing Liu 已提交
315 316 317

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
318 319
          // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
          // o_c/g)
Y
Yibing Liu 已提交
320
          col2im(dev_ctx, col, dilations, strides,
321 322 323
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &out_slice, data_layout);
Y
Yibing Liu 已提交
324 325
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
326 327 328 329
          // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
          // or (o_d, o_h, o_w, o_c/g)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                  data_layout);
Y
Yibing Liu 已提交
330
        }
331 332 333
        if (data_layout == framework::DataLayout::kNHWC) {
          output_batch_vec.push_back(out_slice);
        }
334 335 336 337
      }
      if (data_layout == framework::DataLayout::kNHWC) {
        concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
                       &output_batch);
338
      }
C
chengduoZH 已提交
339 340 341 342
    }
  }
};

template <typename DeviceContext, typename T>
344
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
345 346
 public:
  void Compute(const framework::ExecutionContext& context) const override {
347 348 349 350
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
351 352 353 354 355 356 357 358 359 360 361
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

362 363
    if ((!input_grad) && (!filter_grad)) return;

C
chengduoZH 已提交
364 365
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
366
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
367
    int groups = context.Attr<int>("groups");
368 369
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
370

371 372 373
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_grad_dims = output_grad->dims();
C
chengduoZH 已提交
374
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
375

376
    framework::DDim in_data_dims;
377
    if (data_layout != framework::DataLayout::kNHWC) {
378 379 380 381 382 383 384 385 386 387 388 389
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
390
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
391
    // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
392 393 394 395
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
396
    // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
C
chengduoZH 已提交
397 398
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
399
    if (data_layout != framework::DataLayout::kNHWC) {
400 401 402 403 404 405 406 407 408 409 410
      col_shape_vec[0] = out_grad_dims[1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
411
    }
412
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
413

414
    // use col_matrix_shape in the gemm calculation
415
    // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
416
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
417

418 419
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
420 421
    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
                                              output_grad->dims().size());
C
chengduoZH 已提交
422

423 424 425
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
426
    if (data_layout != framework::DataLayout::kNHWC) {
427 428 429 430
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
C
chengduoZH 已提交
431

432 433
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
434
    if (data_layout != framework::DataLayout::kNHWC) {
435 436 437 438 439
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1],
                             col_matrix_shape[0] / groups};
    }
C
chengduoZH 已提交
440
    filter.Resize(filter_matrix_shape);
441 442

    int in_step =
443
        (data_layout != framework::DataLayout::kNHWC
444 445
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
Y
Yibing Liu 已提交
446
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
C
chengduoZH 已提交
447 448 449 450

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
Q
QI JUN 已提交
451
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
452
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
C
chengduoZH 已提交
453 454 455 456 457 458
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
C
chengduoZH 已提交
459 460 461 462
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

C
chengduoZH 已提交
463
      Tensor filter_grad_;
Q
QI JUN 已提交
464
      math::SetConstant<DeviceContext, T> set_zero;
C
chengduoZH 已提交
465

Q
QI JUN 已提交
466 467
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
468
      math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
469

C
chengduoZH 已提交
470 471
      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
472
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
C
chengduoZH 已提交
473
      }
474
      if (filter_grad) {  // filter_grad_ size (i_c, o_c/g, k_h, k_w)
C
chengduoZH 已提交
475
        filter_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
476
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
C
chengduoZH 已提交
477 478
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
C
chengduoZH 已提交
479 480
      }

481
      size_t D = input->dims().size();
C
chengduoZH 已提交
482
      for (int i = 0; i < batch_size; i++) {
483 484 485 486
        // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
        // channel_first
        // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
        // channel_last
C
chengduoZH 已提交
487 488 489
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

C
chengduoZH 已提交
490
        if (data_dim == 2U) {
491
          // im2col: dy -> col matrix
492 493 494 495
          // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_first
          // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_last
496
          im2col(dev_ctx, output_grad_batch, dilations, strides,
497 498 499
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col, data_layout);
C
chengduoZH 已提交
500
        } else if (data_dim == 3U) {
501
          // vol2col: dy -> col_matrix
502 503 504 505
          // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_first
          // from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h *
          // k_w) for channel_last
Q
QI JUN 已提交
506
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
507
                  &col, data_layout);
508
        }
C
chengduoZH 已提交
509

C
chengduoZH 已提交
510
        if (input_grad) {
511
          // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
C
chengduoZH 已提交
512 513
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
514

C
chengduoZH 已提交
515
          // gemm: dx = filter * dy
516 517
          // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
          // * i_w)
518
          // or
519 520 521 522 523 524
          // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) -> (i_c,
          // i_d, i_h, i_w)
          // gemm: dx = dy^T * filter^T for channel_last

          std::vector<Tensor> input_grad_batch_vec;
Y
Yibing Liu 已提交
525
          for (int g = 0; g < groups; g++) {
526 527 528 529 530
            // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
            // for channel_first
            // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
            // for channel_last
            // filter_slice: (i_c/g, o_c/g * k_h * k_w)
Y
Yibing Liu 已提交
531
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
532 533
            // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
            // k_h * k_w, d * h * w)
Y
Yibing Liu 已提交
534 535
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
536
            if (data_layout != framework::DataLayout::kNHWC) {
537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564
              Tensor input_grad_slice =
                  input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(filter_slice, false, col_matrix_slice, false,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
            } else {
              Tensor input_grad_slice;
              Slice<DeviceContext, T, 2>(context, &input_grad_batch,
                                         &input_grad_slice, g * in_step,
                                         (g + 1) * in_step, 1);
              blas.MatMul(col_matrix_slice, true, filter_slice, true,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
              DDim input_grad_slice_shape;
              if (data_dim == 2U) {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
              } else {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
                                          in_step};
              }
              input_grad_slice =
                  input_grad_slice.Resize(input_grad_slice_shape);
              input_grad_batch_vec.push_back(input_grad_slice);
            }
          }
          if (data_layout == framework::DataLayout::kNHWC) {
            concat_functor(dev_ctx, input_grad_batch_vec,
                           static_cast<int>(D - 2), &input_grad_batch);
Y
Yibing Liu 已提交
565
          }
C
chengduoZH 已提交
566 567
        }
        if (filter_grad) {
568
          // input batch: (i_c, i_h * i_w) or (i_h, i_w * i_c)
C
chengduoZH 已提交
569 570
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
571 572
          // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
          // * k_w)
573
          // or
574 575
          // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
          // -> (i_c, o_c * k_d *
C
chengduoZH 已提交
576
          // k_h * k_w)
577 578
          // gemm: d_filter = x^T * dy^T for channel_last

Y
Yibing Liu 已提交
579 580 581 582 583
          for (int g = 0; g < groups; g++) {
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
584
            if (data_layout != framework::DataLayout::kNHWC) {
585 586 587 588 589 590 591 592 593 594 595 596 597
              Tensor in_batch_slice =
                  in_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            } else {
              Tensor in_batch_slice;
              Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
                                         g * in_step, (g + 1) * in_step, 1);
              blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            }
Y
Yibing Liu 已提交
598
          }
C
chengduoZH 已提交
599
        }
C
chengduoZH 已提交
600 601 602 603
      }
    }
  }
};

// Conv transpose for the depthwise case (groups == number of filter
// channels): since conv-transpose forward equals conv backward w.r.t. input,
// this reuses the depthwise-conv input-grad functor with `input` playing the
// role of the output gradient.
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(groups, filter.dims()[0],
                      "groups should be equal to the first dimension of "
                      "Filter for depthwise conv transpose.");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    // The depthwise functor below does not support dilation.
    for (auto v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1,
                        "dilations should all be 1 for depthwise conv "
                        "transpose.");
    }

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Spatial dims of the input, laid out according to data_format.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // NOTE: removed a second, redundant output->mutable_data call that was
    // here; the buffer is already allocated above.
    auto& dev_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    depthwiseConvInputGrad(
        dev_ctx, *output, filter, *input, strides,
        std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
        dilations, output, data_layout);
  }
};

// Gradient kernel for depthwise conv transpose.
// dInput is a plain depthwise convolution of dOutput with the filter
// (conv-transpose backward w.r.t. input == conv forward); dFilter is
// computed with the depthwise filter-grad functor.
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    // Nothing to compute when neither gradient is requested.
    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Spatial dims of the input, laid out according to data_format.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    if (input_grad) {
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      // BUGFIX: previously the raw `paddings` was passed in the paddings slot
      // and the re-ordered paddings vector in the dilations slot. Pass the
      // re-ordered paddings and the real dilations, consistent with the
      // depthwiseConvInputGrad / depthwiseConvFilterGrad calls in this file.
      depthwiseConv(
          dev_ctx, *output_grad, filter, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, input_grad, data_layout);
    }

    if (filter_grad) {
      math::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(
          dev_ctx, *output_grad, *input, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, filter_grad, data_layout);
    }
  }
};
}  // namespace operators
}  // namespace paddle