/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out,
                  const std::vector<int64_t>& begin_vec,
                  const std::vector<int64_t>& end_vec,
                  const std::vector<int64_t>& axes_vec) {
  auto& place =
      *context.template device_context<DeviceContext>().eigen_device();
  auto in_dims = input->dims();
  auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
  auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
  for (size_t i = 0; i < D; ++i) {
    offsets[i] = 0;
    extents[i] = in_dims[i];
  }

  std::vector<int64_t> out_shape_vec = phi::vectorize(in_dims);
  for (size_t i = 0; i < axes_vec.size(); ++i) {
    offsets[axes_vec[i]] = begin_vec[i];
    extents[axes_vec[i]] = end_vec[i] - begin_vec[i];
    out_shape_vec[axes_vec[i]] = end_vec[i] - begin_vec[i];
  }

  framework::DDim out_dims(phi::make_ddim(out_shape_vec));
  out->mutable_data<T>(out_dims, context.GetPlace());

  auto in_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *input);
  auto out_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out, out_dims);

  EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(place, out_t, in_t,
                                                        offsets, extents);
  out->Resize(out_dims);
}

template <typename DeviceContext, typename T, size_t D>
static void Slice(const framework::ExecutionContext& context,
                  const Tensor* input, Tensor* out, int64_t begin_idx,
                  int64_t end_idx, int64_t axes) {
  std::vector<int64_t> begin_vec = {begin_idx};
  std::vector<int64_t> end_vec = {end_idx};
  std::vector<int64_t> axes_vec = {axes};
  Slice<DeviceContext, T, D>(context, input, out, begin_vec, end_vec, axes_vec);
}
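
// Illustrative usage of the helpers above (hypothetical shapes): with an
// NHWC input batch viewed as an (h * w, i_c) matrix and groups = 2, the
// g-th group of channels can be extracted with
//   Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice,
//                              g * in_step, (g + 1) * in_step, /*axes=*/1);
// which copies columns [g * in_step, (g + 1) * in_step) of input_batch into
// in_slice. The kernels below use this pattern for channel_last data.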

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

class ConvTransposeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

class ConvTransposeOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

template <typename DeviceContext, typename T>
class GemmConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so it should not be constant pointer
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    int groups = context.Attr<int>("groups");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_dims = output->dims();
    const int batch_size = static_cast<int>(input->dims()[0]);

    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);
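    // Assumed semantics (see UpdatePaddingAndDilation in conv_op.h):
    // padding_algorithm is typically "EXPLICIT", "SAME" or "VALID", and the
    // call above rewrites `paddings` / `dilations` accordingly before the
    // shapes below are derived from them.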

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
    std::vector<int64_t> input_shape_vec = phi::vectorize(input->dims());
    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec = phi::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
    // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    if (data_layout != framework::DataLayout::kNHWC) {
      col_shape_vec[0] = out_dims[1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_dims[out_dims.size() - 1] / groups;
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
    }
    DDim col_shape(phi::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d * k_h * k_w, d * h * w)
    DDim col_matrix_shape = phi::flatten_to_2d(col_shape, data_dim + 1);

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);
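    // Shape walk-through (illustrative numbers, not taken from the op): for a
    // channel_first input (n, i_c, h, w) = (1, 8, 16, 16), a filter
    // (i_c, o_c, k_h, k_w) = (8, 4, 3, 3) and groups = 1, col_shape is
    // (4, 3, 3, 16, 16) and col_matrix_shape is
    // (4 * 3 * 3, 16 * 16) = (36, 256).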

    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
    DDim output_shape =
        phi::slice_ddim(output->dims(), 1, output->dims().size());

    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }

    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
    }
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    set_zero(dev_ctx, output, static_cast<T>(0));

    int in_step =
        (data_layout != framework::DataLayout::kNHWC
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);

    int out_step =
        (data_layout != framework::DataLayout::kNHWC
             ? static_cast<int>(out_dims[1]) / groups
             : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
    math::ConcatFunctor<DeviceContext, T> concat_functor;

    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
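    // Per-batch, per-group sketch (illustrative, channel_first, symbols as in
    // the comments above):
    //   col_matrix(o_c/g * k_h * k_w, h * w) =
    //       filter_slice(i_c/g, o_c/g * k_h * k_w)^T * in_slice(i_c/g, h * w)
    // col2im / col2vol then scatters col_matrix into out_slice, summing the
    // contributions of overlapping windows.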
    size_t D = input->dims().size();
    for (int i = 0; i < batch_size; i++) {
      // batch with size (i_c, h * w) or (i_c, d * h * w) for channel_first
      // batch with size (h * w, i_c) or (d * h * w, i_c) for channel_last
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

      // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
      // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

      std::vector<Tensor> output_batch_vec;
      for (int g = 0; g < groups; g++) {
        int64_t start = g * in_step;
        int64_t end = (g + 1) * in_step;
        int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
        Tensor in_slice, out_slice;

        // col_matrix = filter_slice * input_slice
        // of shape (o_c/g * k_h * k_w, h * w)
        // or (o_c/g * k_d * k_h * k_w, d * h * w)
        if (data_layout != framework::DataLayout::kNHWC) {
          in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
          out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
          blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        } else {
          Slice<DeviceContext, T, 2>(context, &input_batch, &in_slice, start,
                                     end, axes);
          start = g * out_step;
          end = (g + 1) * out_step;
          axes = D - 2;
          if (D == 4U) {
            Slice<DeviceContext, T, 3>(context, &output_batch, &out_slice,
                                       start, end, axes);
          } else if (D == 5U) {
            Slice<DeviceContext, T, 4>(context, &output_batch, &out_slice,
                                       start, end, axes);
          }
          blas.MatMul(filter_slice, true, in_slice, true, static_cast<T>(1.0),
                      &col_matrix, static_cast<T>(0.0));
        }

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
          // from (o_c/g * k_h * k_w, h * w) to (o_c/g, o_h, o_w) or (o_h, o_w,
          // o_c/g)
          col2im(dev_ctx, col, dilations, strides,
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &out_slice, data_layout);
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
          // from (o_c/g * k_d * k_h * k_w, d * h * w) to (o_c/g, o_d, o_h, o_w)
          // or (o_d, o_h, o_w, o_c/g)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
                  data_layout);
        }
        if (data_layout == framework::DataLayout::kNHWC) {
          output_batch_vec.push_back(out_slice);
        }
      }
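      // For channel_last, the per-group output slices were materialized as
      // separate tensors above, so they are concatenated back along the
      // channel axis here (axis D - 2 of the per-batch output).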
      if (data_layout == framework::DataLayout::kNHWC) {
        concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
                       &output_batch);
      }
    }
  }
};

template <typename DeviceContext, typename T>
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string data_layout_str =
        context.Attr<std::string>("data_format");
    const framework::DataLayout data_layout =
        framework::StringToDataLayout(data_layout_str);
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

    if ((!input_grad) && (!filter_grad)) return;

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    int groups = context.Attr<int>("groups");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    auto in_dims = input->dims();
    auto filter_dims = filter.dims();
    auto out_grad_dims = output_grad->dims();
    const int batch_size = static_cast<int>(input->dims()[0]);

    framework::DDim in_data_dims;
    if (data_layout != framework::DataLayout::kNHWC) {
      in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
    } else {
      in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
    }
    framework::DDim filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
    std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             in_data_dims, strides, ksize);

    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w} for channel_first
    // input_shape_vec: {n, h, w, c} or {n, d, h, w, c} for channel_last
    std::vector<int64_t> input_shape_vec = phi::vectorize(input->dims());
    // filter_shape_vec: {i_c, o_c, k_h, k_w} or {i_c, o_c, k_d, k_h, k_w}
    std::vector<int64_t> filter_shape_vec = phi::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
    // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w}
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    if (data_layout != framework::DataLayout::kNHWC) {
      col_shape_vec[0] = out_grad_dims[1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
      }
    } else {
      col_shape_vec[0] = out_grad_dims[out_grad_dims.size() - 1];
      for (size_t j = 0; j < data_dim; ++j) {
        col_shape_vec[j + 1] = filter_shape_vec[j + 2];
        col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 1];
      }
    }
    DDim col_shape(phi::make_ddim(col_shape_vec));

    // use col_matrix_shape in the gemm calculation
    // size: (o_c * k_h * k_w, h * w) or (o_c * k_d * k_h * k_w, d * h * w)
    DDim col_matrix_shape = phi::flatten_to_2d(col_shape, data_dim + 1);

    // output size: (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for channel_first
    // output size: (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for channel_last
    DDim output_shape =
        phi::slice_ddim(output_grad->dims(), 1, output_grad->dims().size());

    // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
    // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
    DDim input_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
    } else {
      input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
    }

    // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape;
    if (data_layout != framework::DataLayout::kNHWC) {
      filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
    } else {
      filter_matrix_shape = {in_dims[in_dims.size() - 1],
                             col_matrix_shape[0] / groups};
    }
    filter.Resize(filter_matrix_shape);

    int in_step =
        (data_layout != framework::DataLayout::kNHWC
             ? static_cast<int>(in_dims[1]) / groups
             : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // used when the input gradient needs to be computed
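    // Per-batch, per-group sketch (illustrative, channel_first):
    //   col_matrix(o_c * k_h * k_w, i_h * i_w)  <-  im2col/vol2col of dy
    //   dx_slice       = filter_slice * col_matrix_slice
    //   d_filter_slice += in_batch_slice * col_matrix_slice^T
    // with filter_slice (i_c/g, o_c/g * k_h * k_w), col_matrix_slice
    // (o_c/g * k_h * k_w, i_h * i_w), in_batch_slice (i_c/g, i_h * i_w),
    // and dx_slice (i_c/g, i_h * i_w).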
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

      Tensor filter_grad_;
      phi::funcs::SetConstant<DeviceContext, T> set_zero;

      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
      math::ConcatFunctor<DeviceContext, T> concat_functor;

      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
        set_zero(dev_ctx, input_grad, static_cast<T>(0));
      }
      if (filter_grad) {  // filter_grad_ size (i_c, o_c/g, k_h, k_w)
        filter_grad->mutable_data<T>(context.GetPlace());
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
      }

      size_t D = input->dims().size();
      for (int i = 0; i < batch_size; i++) {
        // batch with size (o_c, o_h, o_w) or (o_c, o_d, o_h, o_w) for
        // channel_first
        // batch with size (o_h, o_w, o_c) or (o_d, o_h, o_w, o_c) for
        // channel_last
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

        if (data_dim == 2U) {
          // im2col: dy -> col matrix
          // from (o_c, o_h, o_w) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_first
          // from (o_h, o_w, o_c) to (o_c * k_h * k_w, i_h * i_w) for
          // channel_last
          im2col(dev_ctx, output_grad_batch, dilations, strides,
                 std::vector<int>{paddings[0], paddings[2], paddings[1],
                                  paddings[3]},
                 &col, data_layout);
        } else if (data_dim == 3U) {
          // vol2col: dy -> col_matrix
          // from (o_c, o_d, o_h, o_w) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_first
          // from (o_d, o_h, o_w, o_c) to (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) for channel_last
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
                  &col, data_layout);
        }

        if (input_grad) {
          // batch with size (i_c, i_h, i_w) or (i_h, i_w, i_c)
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);

          // gemm: dx = filter * dy
          // (i_c, o_c * k_h * k_w) * (o_c * k_h * k_w, i_h * i_w) -> (i_c, i_h
          // * i_w)
          // or
          // (i_c, o_c * k_d * k_h * k_w) * (o_c * k_d * k_h * k_w, i_d * i_h *
          // i_w) -> (i_c, i_d * i_h * i_w)
          // gemm: dx = dy^T * filter^T for channel_last

          std::vector<Tensor> input_grad_batch_vec;
          for (int g = 0; g < groups; g++) {
            // input_grad_slice: (i_c/g, i_h * i_w) or (i_c/g, i_d * i_h * i_w)
            // for channel_first
            // input_grad_slice: (i_h * i_w, i_c/g) or (i_d * i_h * i_w, i_c/g)
            // for channel_last
            // filter_slice: (i_c/g, o_c/g * k_h * k_w)
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
            // col_matrix_slice: (o_c/g * k_h * k_w, h * w) or (o_c/g * k_d *
            // k_h * k_w, d * h * w)
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
            if (data_layout != framework::DataLayout::kNHWC) {
              Tensor input_grad_slice =
                  input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(filter_slice, false, col_matrix_slice, false,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
            } else {
              Tensor input_grad_slice;
              Slice<DeviceContext, T, 2>(context, &input_grad_batch,
                                         &input_grad_slice, g * in_step,
                                         (g + 1) * in_step, 1);
              blas.MatMul(col_matrix_slice, true, filter_slice, true,
                          static_cast<T>(1.0), &input_grad_slice,
                          static_cast<T>(0.0));
              DDim input_grad_slice_shape;
              if (data_dim == 2U) {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_step};
              } else {
                input_grad_slice_shape = {in_dims[1], in_dims[2], in_dims[3],
                                          in_step};
              }
              input_grad_slice =
                  input_grad_slice.Resize(input_grad_slice_shape);
              input_grad_batch_vec.push_back(input_grad_slice);
            }
          }
          if (data_layout == framework::DataLayout::kNHWC) {
            concat_functor(dev_ctx, input_grad_batch_vec,
                           static_cast<int>(D - 2), &input_grad_batch);
          }
        }
        if (filter_grad) {
          // input batch: (i_c, i_h * i_w) or (i_h * i_w, i_c)
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
          // (i_c, i_h * i_w) * (i_h * i_w, o_c * k_h * k_w) -> (i_c, o_c * k_h
          // * k_w)
          // or
          // (i_c, i_d * i_h * i_w) * (i_d * i_h * i_w, o_c * k_d * k_h * k_w)
          // -> (i_c, o_c * k_d * k_h * k_w)
          // gemm: d_filter = x^T * dy^T for channel_last

          for (int g = 0; g < groups; g++) {
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
            if (data_layout != framework::DataLayout::kNHWC) {
              Tensor in_batch_slice =
                  in_batch.Slice(g * in_step, (g + 1) * in_step);
              blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            } else {
              Tensor in_batch_slice;
              Slice<DeviceContext, T, 2>(context, &in_batch, &in_batch_slice,
                                         g * in_step, (g + 1) * in_step, 1);
              blas.MatMul(in_batch_slice, true, col_matrix_slice, true,
                          static_cast<T>(1.0), &filter_grad_slice,
                          static_cast<T>(1.0));
            }
          }
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle