conv_transpose_op.h 28.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16 17
#include <algorithm>
#include <string>
S
Siddharth Goyal 已提交
18
#include <vector>
Y
Yi Wang 已提交
19 20
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
21
#include "paddle/fluid/operators/conv_op.h"
Y
Yu Yang 已提交
22
#include "paddle/fluid/operators/math/blas.h"
23
#include "paddle/fluid/operators/math/concat_and_split.h"
24
#include "paddle/fluid/operators/math/depthwise_conv.h"
Y
Yi Wang 已提交
25 26
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
C
chengduoZH 已提交
27 28 29 30 31 32 33

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
// Copies a sub-tensor of `input` into `out` using Eigen's slice expression.
// For every entry k, dimension axes_vec[k] is restricted to the half-open
// range [begin_vec[k], end_vec[k]); all other dimensions are kept whole.
// `out` is allocated (and finally resized) to the sliced shape.
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out,
                  const std::vector<int64_t>& begin_vec,
                  const std::vector<int64_t>& end_vec,
                  const std::vector<int64_t>& axes_vec) {
  auto& eigen_place =
      *context.template device_context<DeviceContext>().eigen_device();
  auto src_dims = input->dims();

  // Start from the identity slice (offset 0, full extent per dimension).
  Eigen::array<int, D> slice_offsets;
  Eigen::array<int, D> slice_extents;
  for (size_t d = 0; d < D; ++d) {
    slice_offsets[d] = 0;
    slice_extents[d] = src_dims[d];
  }

  // Narrow the requested axes and record the resulting output shape.
  std::vector<int64_t> sliced_shape = framework::vectorize(src_dims);
  for (size_t k = 0; k < axes_vec.size(); ++k) {
    const int64_t axis = axes_vec[k];
    const int64_t length = end_vec[k] - begin_vec[k];
    slice_offsets[axis] = begin_vec[k];
    slice_extents[axis] = length;
    sliced_shape[axis] = length;
  }

  framework::DDim sliced_dims(framework::make_ddim(sliced_shape));
  out->mutable_data<T>(sliced_dims, context.GetPlace());

  auto src_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *input);
  auto dst_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out, sliced_dims);

  dst_t.device(eigen_place) = src_t.slice(slice_offsets, slice_extents);
  out->Resize(sliced_dims);
}

// Convenience overload: slice a single axis over [begin_idx, end_idx).
// Forwards to the vector-based Slice above.
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out, int64_t begin_idx,
                  int64_t end_idx, int64_t axes) {
  Slice<DeviceContext, T, D>(context, input, out,
                             std::vector<int64_t>{begin_idx},
                             std::vector<int64_t>{end_idx},
                             std::vector<int64_t>{axes});
}

C
chengduoZH 已提交
81 82
// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
C
chengduoZH 已提交
83 84
// Declares the proto/attribute description of the 2-D conv transpose op.
// Make() is defined in the corresponding .cc file.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

C
chengduoZH 已提交
88 89
// Declares the proto/attribute description of the 3-D conv transpose op.
// Make() is defined in the corresponding .cc file.
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

C
chengduoZH 已提交
93
class ConvTransposeOp : public framework::OperatorWithKernel {
C
chengduoZH 已提交
94 95 96
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
97 98 99 100

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
101 102
};

C
chengduoZH 已提交
103
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
C
chengduoZH 已提交
104 105 106
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
107 108 109 110

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
111 112
};

Q
QI JUN 已提交
113
template <typename DeviceContext, typename T>
114
class GemmConvTransposeKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
115 116
 public:
  void Compute(const framework::ExecutionContext& context) const override {
117 118 119 120
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
121 122 123 124 125 126
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so it should not be constant pointer
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
C
chengduoZH 已提交
127
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
128
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
129
    int groups = context.Attr<int>("groups");
130 131
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
132

133 134 135
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_dims = output->dims();
C
chengduoZH 已提交
136
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
137

138
    framework::DDim in_data_dims;
139
    if (data_layout != framework::DataLayout::kNHWC) {
140 141 142 143 144 145 146 147 148 149 150 151
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
152
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
153
    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
154 155 156 157
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
158
    // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
C
chengduoZH 已提交
159 160
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
161
    if (data_layout != framework::DataLayout::kNHWC) {
162 163 164 165 166 167 168 169 170 171 172
      col_shape_vec[0] = out_dims[1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
173
    }
174
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
175 176

    // use col_matrix_shape in the gemm calculation
177
    // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
178
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
179 180 181 182 183 184 185 186 187 188

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);

189 190
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
191 192
    DDim output_shape =
        framework::slice_ddim(output->dims(), 1, output->dims().size());
C
chengduoZH 已提交
193

194 195 196
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
197
    if (data_layout != framework::DataLayout::kNHWC) {
198 199 200 201
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
202

203 204
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
205
    if (data_layout != framework::DataLayout::kNHWC) {
206 207 208 209
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
    }
C
chengduoZH 已提交
210 211 212
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
213 214
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
215
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
Q
QI JUN 已提交
216
    set_zero(dev_ctx, output, static_cast<T>(0));
C
chengduoZH 已提交
217

218
    int in_step =
219
        (data_layout != framework::DataLayout::kNHWC
220 221 222 223
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);

    int out_step =
224
        (data_layout != framework::DataLayout::kNHWC
225 226
             ? static_cast<int>(out_dims[1]) / groups
             : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
Q
QI JUN 已提交
227 228
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
229
    math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
230

231 232
    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
233
    size_t D = input->dims().size();
C
chengduoZH 已提交
234
    for (int i = 0; i < batch_size; i++) {
235 236
      // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
      // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
C
chengduoZH 已提交
237 238
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

239 240
      // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
      // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
C
chengduoZH 已提交
241 242
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

243
      std::vector<Tensor> output_batch_vec;
Y
Yibing Liu 已提交
244
      for (int g = 0; g < groups; g++) {
245 246
        int64_t start = g * in_step;
        int64_t end = (g + 1) * in_step;
247
        int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
Y
Yibing Liu 已提交
248
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
249
        Tensor in_slice, out_slice;
Y
Yibing Liu 已提交
250 251

        // col_matrix = filter_slice * input_slice
252 253
        // of shape (o_c/g * k_h * k_w, h * w)
        // or (o_c/g * k_d * k_h * k_w, d * h * w)
254
        if (data_layout != framework::DataLayout::kNHWC) {
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
          in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
          out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        } else {
          Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
                                     end, axes);
          start = g * out_step;
          end = (g + 1) * out_step;
          axes = D - 2;
          if (D == 4U) {
            Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
                                       start, end, axes);
          } else if (D == 5U) {
            Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
                                       start, end, axes);
          }
          blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        }
Y
Yibing Liu 已提交
275 276 277

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
278 279
          // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
          // o_c/g)
Y
Yibing Liu 已提交
280
          col2im(dev_ctx, col, dilations, strides,
281 282 283
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &out_slice, data_layout);
Y
Yibing Liu 已提交
284 285
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
286 287 288 289
          // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
          // or (o_d, o_h, o_w, o_c/g)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                  data_layout);
Y
Yibing Liu 已提交
290
        }
291 292 293
        if (data_layout == framework::DataLayout::kNHWC) {
          output_batch_vec.push_back(out_slice);
        }
294 295 296 297
      }
      if (data_layout == framework::DataLayout::kNHWC) {
        concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
                       &output_batch);
298
      }
C
chengduoZH 已提交
299 300 301 302
    }
  }
};

Q
QI JUN 已提交
303
template <typename DeviceContext, typename T>
304
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
305 306
 public:
  void Compute(const framework::ExecutionContext& context) const override {
307 308 309 310
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
311 312 313 314 315 316 317 318 319 320 321
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

322 323
    if ((!input_grad) && (!filter_grad)) return;

C
chengduoZH 已提交
324 325
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
326
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
327
    int groups = context.Attr<int>("groups");
328 329
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
330

331 332 333
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_grad_dims = output_grad->dims();
C
chengduoZH 已提交
334
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
335

336
    framework::DDim in_data_dims;
337
    if (data_layout != framework::DataLayout::kNHWC) {
338 339 340 341 342 343 344 345 346 347 348 349
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
350
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
351
    // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
352 353 354 355
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
356
    // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
C
chengduoZH 已提交
357 358
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
359
    if (data_layout != framework::DataLayout::kNHWC) {
360 361 362 363 364 365 366 367 368 369 370
      col_shape_vec[0] = out_grad_dims[1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
371
    }
372
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
373

374
    // use col_matrix_shape in the gemm calculation
375
    // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
376
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
377

378 379
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
380 381
    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
                                              output_grad->dims().size());
C
chengduoZH 已提交
382

383 384 385
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
386
    if (data_layout != framework::DataLayout::kNHWC) {
387 388 389 390
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
C
chengduoZH 已提交
391

392 393
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
394
    if (data_layout != framework::DataLayout::kNHWC) {
395 396 397 398 399
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1],
                             col_matrix_shape[0] / groups};
    }
C
chengduoZH 已提交
400
    filter.Resize(filter_matrix_shape);
401 402

    int in_step =
403
        (data_layout != framework::DataLayout::kNHWC
404 405
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
Y
Yibing Liu 已提交
406
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
C
chengduoZH 已提交
407 408 409 410

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
Q
QI JUN 已提交
411
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
412
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
C
chengduoZH 已提交
413 414 415 416 417 418
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
C
chengduoZH 已提交
419 420 421 422
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

C
chengduoZH 已提交
423
      Tensor filter_grad_;
Q
QI JUN 已提交
424
      math::SetConstant<DeviceContext, T> set_zero;
C
chengduoZH 已提交
425

Q
QI JUN 已提交
426 427
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
428
      math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
429

C
chengduoZH 已提交
430 431
      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
432
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
C
chengduoZH 已提交
433
      }
434
      if (filter_grad) {  // filter_grad_ size (i_c, o_c/g, k_h, k_w)
C
chengduoZH 已提交
435
        filter_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
436
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
C
chengduoZH 已提交
437 438
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
C
chengduoZH 已提交
439 440
      }

441
      size_t D = input->dims().size();
C
chengduoZH 已提交
442
      for (int i = 0; i < batch_size; i++) {
443 444 445 446
        // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
        // channel_first
        // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
        // channel_last
C
chengduoZH 已提交
447 448 449
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

C
chengduoZH 已提交
450
        if (data_dim == 2U) {
451
          // im2col: dy -> col matrix
452 453 454 455
          // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_first
          // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_last
456
          im2col(dev_ctx, output_grad_batch, dilations, strides,
457 458 459
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col, data_layout);
C
chengduoZH 已提交
460
        } else if (data_dim == 3U) {
461
          // vol2col: dy -> col_matrix
462 463 464 465
          // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_first
          // from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h *
          // k_w) for channel_last
Q
QI JUN 已提交
466
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
467
                  &col, data_layout);
468
        }
C
chengduoZH 已提交
469

C
chengduoZH 已提交
470
        if (input_grad) {
471
          // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
C
chengduoZH 已提交
472 473
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
474

C
chengduoZH 已提交
475
          // gemm: dx = filter * dy
476 477
          // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
          // * i_w)
478
          // or
479 480 481 482 483 484
          // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) -> (i_c,
          // i_d, i_h, i_w)
          // gemm: dx = dy^T * filter^T for channel_last

          std::vector<Tensor> input_grad_batch_vec;
Y
Yibing Liu 已提交
485
          for (int g = 0; g < groups; g++) {
486 487 488 489 490
            // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
            // for channel_first
            // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
            // for channel_last
            // filter_slice: (i_c/g, o_c/g * k_h * k_w)
Y
Yibing Liu 已提交
491
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
492 493
            // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
            // k_h * k_w, d * h * w)
Y
Yibing Liu 已提交
494 495
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
496
            if (data_layout != framework::DataLayout::kNHWC) {
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524
              Tensor input_grad_slice =
                  input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(filter_slice, false, col_matrix_slice, false,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
            } else {
              Tensor input_grad_slice;
              Slice<DeviceContext, T, 2>(context, &input_grad_batch,
                                         &input_grad_slice, g * in_step,
                                         (g + 1) * in_step, 1);
              blas.MatMul(col_matrix_slice, true, filter_slice, true,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
              DDim input_grad_slice_shape;
              if (data_dim == 2U) {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
              } else {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
                                          in_step};
              }
              input_grad_slice =
                  input_grad_slice.Resize(input_grad_slice_shape);
              input_grad_batch_vec.push_back(input_grad_slice);
            }
          }
          if (data_layout == framework::DataLayout::kNHWC) {
            concat_functor(dev_ctx, input_grad_batch_vec,
                           static_cast<int>(D - 2), &input_grad_batch);
Y
Yibing Liu 已提交
525
          }
C
chengduoZH 已提交
526 527
        }
        if (filter_grad) {
528
          // input batch: (i_c, i_h * i_w) or (i_h, i_w * i_c)
C
chengduoZH 已提交
529 530
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
531 532
          // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
          // * k_w)
533
          // or
534 535
          // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
          // -> (i_c, o_c * k_d *
C
chengduoZH 已提交
536
          // k_h * k_w)
537 538
          // gemm: d_filter = x^T * dy^T for channel_last

Y
Yibing Liu 已提交
539 540 541 542 543
          for (int g = 0; g < groups; g++) {
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
544
            if (data_layout != framework::DataLayout::kNHWC) {
545 546 547 548 549 550 551 552 553 554 555 556 557
              Tensor in_batch_slice =
                  in_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            } else {
              Tensor in_batch_slice;
              Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
                                         g * in_step, (g + 1) * in_step, 1);
              blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            }
Y
Yibing Liu 已提交
558
          }
C
chengduoZH 已提交
559
        }
C
chengduoZH 已提交
560 561 562 563
      }
    }
  }
};
564 565 566 567 568

// DepthwiseConvTransposeKernel computes depthwise conv transpose by reusing
// the input-gradient functor of the depthwise forward convolution.
// Preconditions enforced below: groups == filter.dims()[0] and every
// dilation == 1.
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(groups, filter.dims()[0]);

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    for (auto v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1);
    }

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Resolve "SAME"/"VALID" padding algorithms into explicit paddings.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    output->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    // The 4-element padding vector is reordered ({0,2,1,3}) into the layout
    // the functor expects, matching the sibling calls in this file.
    depthwiseConvInputGrad(
        dev_ctx, *output, filter, *input, strides,
        std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
        dilations, output, data_layout);
  }
};

// DepthwiseConvTransposeGradKernel computes d(Input) and d(Filter) of the
// depthwise conv transpose by reusing the forward/filter-grad functors of
// the depthwise convolution.
template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    // Nothing to do when neither gradient is requested.
    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Resolve "SAME"/"VALID" padding algorithms into explicit paddings.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    if (input_grad) {
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      // BUGFIX: the functor takes (strides, paddings, dilations, output).
      // Previously the raw `paddings` was passed in the paddings slot AND the
      // reordered vector was passed where `dilations` belongs, dropping the
      // dilations entirely. Pass the reordered paddings and `dilations`, in
      // line with the depthwiseConvFilterGrad call below.
      depthwiseConv(
          dev_ctx, *output_grad, filter, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, input_grad, data_layout);
    }

    if (filter_grad) {
      math::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(
          dev_ctx, *output_grad, *input, strides,
          std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
          dilations, filter_grad, data_layout);
    }
  }
};
C
chengduoZH 已提交
682 683
}  // namespace operators
}  // namespace paddle