/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

// Copies a rectangular sub-block of the rank-D tensor `input` into `out`.
// For every k, axis axes_vec[k] is restricted to [begin_vec[k], end_vec[k]);
// all other axes are taken in full. `out` is (re)allocated on the context's
// place with the resulting shape.
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out,
                  const std::vector<int64_t>& begin_vec,
                  const std::vector<int64_t>& end_vec,
                  const std::vector<int64_t>& axes_vec) {
  auto in_dims = input->dims();

  // Start from "whole tensor" offsets/extents, then narrow the sliced axes.
  Eigen::array<int, D> offsets;
  Eigen::array<int, D> extents;
  for (size_t dim = 0; dim < D; ++dim) {
    offsets[dim] = 0;
    extents[dim] = in_dims[dim];
  }

  std::vector<int64_t> out_shape_vec = framework::vectorize(in_dims);
  for (size_t k = 0; k < axes_vec.size(); ++k) {
    const int64_t axis = axes_vec[k];
    const int64_t length = end_vec[k] - begin_vec[k];
    offsets[axis] = begin_vec[k];
    extents[axis] = length;
    out_shape_vec[axis] = length;
  }

  framework::DDim out_dims = framework::make_ddim(out_shape_vec);
  out->mutable_data<T>(out_dims, context.GetPlace());

  using EigenTensorD =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>;
  auto in_t = EigenTensorD::From(*input);
  auto out_t = EigenTensorD::From(*out, out_dims);

  auto& place =
      *context.template device_context<DeviceContext>().eigen_device();
  out_t.device(place) = in_t.slice(offsets, extents);
  out->Resize(out_dims);
}

// Single-axis convenience overload: copies [begin_idx, end_idx) along axis
// `axes` of `input` into `out` by forwarding to the vector-based Slice.
template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out, int64_t begin_idx,
                  int64_t end_idx, int64_t axes) {
  Slice<DeviceContext, T, D>(context, input, out,
                             std::vector<int64_t>{begin_idx},
                             std::vector<int64_t>{end_idx},
                             std::vector<int64_t>{axes});
}

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
// Proto/attribute maker for the 2-D transposed convolution operator.
// Make() is only declared here; its definition lives outside this header.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

// Proto/attribute maker for the 3-D transposed convolution operator.
// Make() is only declared here; its definition lives outside this header.
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class ConvTransposeOp : public framework::OperatorWithKernel {
C
chengduoZH 已提交
94 95 96
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
97 98 99 100

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
101 102 103 104

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
C
chengduoZH 已提交
105 106
};

class ConvTransposeOpGrad : public framework::OperatorWithKernel {
C
chengduoZH 已提交
108 109 110
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
111 112 113 114

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
115 116
};

// Operator declaration for the (presumably second-order, per the name)
// gradient of conv transpose. InferShape and GetExpectedKernelType are only
// declared here; their definitions live outside this header.
class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  // Shape inference for this op (defined out of line).
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  // Kernel-type selection for this op (defined out of line).
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

template <typename DeviceContext, typename T>
128
class GemmConvTransposeKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
129 130
 public:
  void Compute(const framework::ExecutionContext& context) const override {
131 132 133 134
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
135 136 137 138 139 140
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so it should not be constant pointer
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
C
chengduoZH 已提交
141
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
142
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
143
    int groups = context.Attr<int>("groups");
144 145
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
146

147 148 149
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_dims = output->dims();
C
chengduoZH 已提交
150
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
151

152
    framework::DDim in_data_dims;
153
    if (data_layout != framework::DataLayout::kNHWC) {
154 155 156 157 158 159 160 161 162 163 164 165
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
166
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
167
    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
168 169 170 171
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
172
    // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
C
chengduoZH 已提交
173 174
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
175
    if (data_layout != framework::DataLayout::kNHWC) {
176 177 178 179 180 181 182 183 184 185 186
      col_shape_vec[0] = out_dims[1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
187
    }
188
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
189 190

    // use col_matrix_shape in the gemm calculation
191
    // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
192
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
193 194 195 196 197 198 199 200 201 202

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);

203 204
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
205 206
    DDim output_shape =
        framework::slice_ddim(output->dims(), 1, output->dims().size());
C
chengduoZH 已提交
207

208 209 210
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
211
    if (data_layout != framework::DataLayout::kNHWC) {
212 213 214 215
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
216

217 218
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
219
    if (data_layout != framework::DataLayout::kNHWC) {
220 221 222 223
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
    }
C
chengduoZH 已提交
224 225 226
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
227 228
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
229
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
Q
QI JUN 已提交
230
    set_zero(dev_ctx, output, static_cast<T>(0));
C
chengduoZH 已提交
231

232
    int in_step =
233
        (data_layout != framework::DataLayout::kNHWC
234 235 236 237
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);

    int out_step =
238
        (data_layout != framework::DataLayout::kNHWC
239 240
             ? static_cast<int>(out_dims[1]) / groups
             : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
Q
QI JUN 已提交
241 242
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
243
    math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
244

245 246
    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
247
    size_t D = input->dims().size();
C
chengduoZH 已提交
248
    for (int i = 0; i < batch_size; i++) {
249 250
      // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
      // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
C
chengduoZH 已提交
251 252
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

253 254
      // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
      // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
C
chengduoZH 已提交
255 256
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

257
      std::vector<Tensor> output_batch_vec;
Y
Yibing Liu 已提交
258
      for (int g = 0; g < groups; g++) {
259 260
        int64_t start = g * in_step;
        int64_t end = (g + 1) * in_step;
261
        int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
Y
Yibing Liu 已提交
262
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
263
        Tensor in_slice, out_slice;
Y
Yibing Liu 已提交
264 265

        // col_matrix = filter_slice * input_slice
266 267
        // of shape (o_c/g * k_h * k_w, h * w)
        // or (o_c/g * k_d * k_h * k_w, d * h * w)
268
        if (data_layout != framework::DataLayout::kNHWC) {
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
          in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
          out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        } else {
          Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
                                     end, axes);
          start = g * out_step;
          end = (g + 1) * out_step;
          axes = D - 2;
          if (D == 4U) {
            Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
                                       start, end, axes);
          } else if (D == 5U) {
            Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
                                       start, end, axes);
          }
          blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        }
Y
Yibing Liu 已提交
289 290 291

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
292 293
          // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
          // o_c/g)
Y
Yibing Liu 已提交
294
          col2im(dev_ctx, col, dilations, strides,
295 296 297
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &out_slice, data_layout);
Y
Yibing Liu 已提交
298 299
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
300 301 302 303
          // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
          // or (o_d, o_h, o_w, o_c/g)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                  data_layout);
Y
Yibing Liu 已提交
304
        }
305 306 307
        if (data_layout == framework::DataLayout::kNHWC) {
          output_batch_vec.push_back(out_slice);
        }
308 309 310 311
      }
      if (data_layout == framework::DataLayout::kNHWC) {
        concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
                       &output_batch);
312
      }
C
chengduoZH 已提交
313 314 315 316
    }
  }
};

template <typename DeviceContext, typename T>
318
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
319 320
 public:
  void Compute(const framework::ExecutionContext& context) const override {
321 322 323 324
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
C
chengduoZH 已提交
325 326 327 328 329 330 331 332 333 334 335
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

336 337
    if ((!input_grad) && (!filter_grad)) return;

C
chengduoZH 已提交
338 339
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
340
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
341
    int groups = context.Attr<int>("groups");
342 343
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
C
chengduoZH 已提交
344

345 346 347
    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_grad_dims = output_grad->dims();
C
chengduoZH 已提交
348
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
349

350
    framework::DDim in_data_dims;
351
    if (data_layout != framework::DataLayout::kNHWC) {
352 353 354 355 356 357 358 359 360 361 362 363
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
364
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
365
    // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
366 367 368 369
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
370
    // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
C
chengduoZH 已提交
371 372
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
373
    if (data_layout != framework::DataLayout::kNHWC) {
374 375 376 377 378 379 380 381 382 383 384
      col_shape_vec[0] = out_grad_dims[1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
C
chengduoZH 已提交
385
    }
386
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
387

388
    // use col_matrix_shape in the gemm calculation
389
    // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
390
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
391

392 393
    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
394 395
    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
                                              output_grad->dims().size());
C
chengduoZH 已提交
396

397 398 399
    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
400
    if (data_layout != framework::DataLayout::kNHWC) {
401 402 403 404
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }
C
chengduoZH 已提交
405

406 407
    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
408
    if (data_layout != framework::DataLayout::kNHWC) {
409 410 411 412 413
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1],
                             col_matrix_shape[0] / groups};
    }
C
chengduoZH 已提交
414
    filter.Resize(filter_matrix_shape);
415 416

    int in_step =
417
        (data_layout != framework::DataLayout::kNHWC
418 419
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
Y
Yibing Liu 已提交
420
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
C
chengduoZH 已提交
421 422 423 424

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
Q
QI JUN 已提交
425
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
426
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
C
chengduoZH 已提交
427 428 429 430 431 432
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
C
chengduoZH 已提交
433 434 435 436
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

C
chengduoZH 已提交
437
      Tensor filter_grad_;
Q
QI JUN 已提交
438
      math::SetConstant<DeviceContext, T> set_zero;
C
chengduoZH 已提交
439

Q
QI JUN 已提交
440 441
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
442
      math::ConcatFunctor<DeviceContext, T> concat_functor;
C
chengduoZH 已提交
443

C
chengduoZH 已提交
444 445
      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
446
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
C
chengduoZH 已提交
447
      }
448
      if (filter_grad) {  // filter_grad_ size (i_c, o_c/g, k_h, k_w)
C
chengduoZH 已提交
449
        filter_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
450
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
C
chengduoZH 已提交
451 452
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
C
chengduoZH 已提交
453 454
      }

455
      size_t D = input->dims().size();
C
chengduoZH 已提交
456
      for (int i = 0; i < batch_size; i++) {
457 458 459 460
        // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
        // channel_first
        // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
        // channel_last
C
chengduoZH 已提交
461 462 463
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

C
chengduoZH 已提交
464
        if (data_dim == 2U) {
465
          // im2col: dy -> col matrix
466 467 468 469
          // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_first
          // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_last
470
          im2col(dev_ctx, output_grad_batch, dilations, strides,
471 472 473
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col, data_layout);
C
chengduoZH 已提交
474
        } else if (data_dim == 3U) {
475
          // vol2col: dy -> col_matrix
476 477 478 479
          // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_first
          // from (o_d, o_h, o_w, o_c) to (i_d * i_h * i_w, o_c * k_d * k_h *
          // k_w) for channel_last
Q
QI JUN 已提交
480
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
481
                  &col, data_layout);
482
        }
C
chengduoZH 已提交
483

C
chengduoZH 已提交
484
        if (input_grad) {
485
          // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
C
chengduoZH 已提交
486 487
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
488

C
chengduoZH 已提交
489
          // gemm: dx = filter * dy
490 491
          // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
          // * i_w)
492
          // or
493 494 495 496 497 498
          // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) -> (i_c,
          // i_d, i_h, i_w)
          // gemm: dx = dy^T * filter^T for channel_last

          std::vector<Tensor> input_grad_batch_vec;
Y
Yibing Liu 已提交
499
          for (int g = 0; g < groups; g++) {
500 501 502 503 504
            // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
            // for channel_first
            // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
            // for channel_last
            // filter_slice: (i_c/g, o_c/g * k_h * k_w)
Y
Yibing Liu 已提交
505
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
506 507
            // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
            // k_h * k_w, d * h * w)
Y
Yibing Liu 已提交
508 509
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
510
            if (data_layout != framework::DataLayout::kNHWC) {
511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
              Tensor input_grad_slice =
                  input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(filter_slice, false, col_matrix_slice, false,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
            } else {
              Tensor input_grad_slice;
              Slice<DeviceContext, T, 2>(context, &input_grad_batch,
                                         &input_grad_slice, g * in_step,
                                         (g + 1) * in_step, 1);
              blas.MatMul(col_matrix_slice, true, filter_slice, true,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
              DDim input_grad_slice_shape;
              if (data_dim == 2U) {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
              } else {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
                                          in_step};
              }
              input_grad_slice =
                  input_grad_slice.Resize(input_grad_slice_shape);
              input_grad_batch_vec.push_back(input_grad_slice);
            }
          }
          if (data_layout == framework::DataLayout::kNHWC) {
            concat_functor(dev_ctx, input_grad_batch_vec,
                           static_cast<int>(D - 2), &input_grad_batch);
Y
Yibing Liu 已提交
539
          }
C
chengduoZH 已提交
540 541
        }
        if (filter_grad) {
542
          // input batch: (i_c, i_h * i_w) or (i_h, i_w * i_c)
C
chengduoZH 已提交
543 544
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
545 546
          // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
          // * k_w)
547
          // or
548 549
          // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
          // -> (i_c, o_c * k_d *
C
chengduoZH 已提交
550
          // k_h * k_w)
551 552
          // gemm: d_filter = x^T * dy^T for channel_last

Y
Yibing Liu 已提交
553 554 555 556 557
          for (int g = 0; g < groups; g++) {
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
558
            if (data_layout != framework::DataLayout::kNHWC) {
559 560 561 562 563 564 565 566 567 568 569 570 571
              Tensor in_batch_slice =
                  in_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            } else {
              Tensor in_batch_slice;
              Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
                                         g * in_step, (g + 1) * in_step, 1);
              blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            }
Y
Yibing Liu 已提交
572
          }
C
chengduoZH 已提交
573
        }
C
chengduoZH 已提交
574 575 576 577
      }
    }
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  // Depthwise transposed convolution. Requires groups == filter.dims()[0] and
  // every dilation == 1 (both enforced below). The output is zeroed and then
  // produced by the depthwise conv input-gradient functor, with `input`
  // playing the role of the output gradient.
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(
        groups, filter.dims()[0],
        platform::errors::InvalidArgument(
            "groups should be equal to the 1st dimension of filter. But "
            "received groups is %d and filter dimension[0] is %d",
            groups, filter.dims()[0]));

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");
    for (const int v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1, platform::errors::InvalidArgument(
                                  "dilations should be 1 in depthwise conv. "
                                  "But received dilations is %d",
                                  v));
    }

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Spatial dims of the input (N and C removed); used to resolve
    // padding_algorithm into explicit paddings/dilations.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // Allocate once (the original allocated twice) and zero the output.
    output->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    // NOTE(review): *output is passed as the functor's "input" argument;
    // presumably only its shape is read there — confirm in depthwise_conv.h.
    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    depthwiseConvInputGrad(
        dev_ctx, *output, filter, *input, strides,
        std::vector<int>{paddings[0], paddings[2], paddings[1], paddings[3]},
        dilations, output, data_layout);
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  // Backward of the depthwise transposed convolution.
  //   input_grad : depthwise convolution of the output gradient with filter.
  //   filter_grad: zeroed, then filled by the depthwise filter-grad functor.
  // Returns immediately when neither gradient is requested.
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();

    // Spatial dims of the input (N and C removed); used to resolve
    // padding_algorithm into explicit paddings/dilations.
    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        framework::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // Paddings reordered as {p0, p2, p1, p3}, the order every depthwise
    // functor call in this file uses.
    const std::vector<int> functor_paddings{paddings[0], paddings[2],
                                            paddings[1], paddings[3]};

    if (input_grad) {
      // Argument order is (strides, paddings, dilations, output, layout).
      // The previous code passed the raw paddings as paddings and the
      // reordered paddings in the dilations slot, dropping `dilations`.
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      depthwiseConv(dev_ctx, *output_grad, filter, strides, functor_paddings,
                    dilations, input_grad, data_layout);
    }

    if (filter_grad) {
      math::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides,
                              functor_paddings, dilations, filter_grad,
                              data_layout);
    }
  }
};
}  // namespace operators
}  // namespace paddle