/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
C
chengduoZH 已提交
32 33
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
34
  void Make() override;
C
chengduoZH 已提交
35 36
};

C
chengduoZH 已提交
37 38
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
39
  void Make() override;
C
chengduoZH 已提交
40 41
};

C
chengduoZH 已提交
42
class ConvTransposeOp : public framework::OperatorWithKernel {
C
chengduoZH 已提交
43 44 45
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
46 47 48 49

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
50 51
};

C
chengduoZH 已提交
52
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
C
chengduoZH 已提交
53 54 55
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
56 57 58 59

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
C
chengduoZH 已提交
60 61
};

Q
QI JUN 已提交
62
template <typename DeviceContext, typename T>
63
class GemmConvTransposeKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
64 65 66 67 68 69 70 71
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    // The filter will be reshaped, so it should not be constant pointer
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
C
chengduoZH 已提交
72
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
73
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
74
    int groups = context.Attr<int>("groups");
C
chengduoZH 已提交
75

C
chengduoZH 已提交
76
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
77

C
chengduoZH 已提交
78
    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
79
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
C
chengduoZH 已提交
80
    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
81 82 83 84
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
Y
Yibing Liu 已提交
85
    // col_shape_vec: {c/g, k_h, k_w, h, w} or {c/g, k_d, k_h, k_w, d, h, w}
C
chengduoZH 已提交
86 87
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
Y
Yibing Liu 已提交
88
    col_shape_vec[0] = output->dims()[1] / groups;
C
chengduoZH 已提交
89 90 91 92
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
    }
93
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
94 95

    // use col_matrix_shape in the gemm calculation
Y
Yibing Liu 已提交
96
    // size: (c/g * k_h * k_w, h * w) or (c/g * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
97
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
98 99 100 101 102 103 104 105 106 107

    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());
    // col_matrix shares the same piece of data with col,
    // but will be reshaped into a two-dimensional matrix shape
    // to call the matrix multiplication interface.
    Tensor col_matrix;
    col_matrix.ShareDataWith(col);
    col_matrix.Resize(col_matrix_shape);

108 109 110
    // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
    DDim output_shape =
        framework::slice_ddim(output->dims(), 1, output->dims().size());
C
chengduoZH 已提交
111

112 113 114
    // input matrix size: (m, h * w) or (m, d * h * w)
    DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};

Y
Yibing Liu 已提交
115
    // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
116
    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
C
chengduoZH 已提交
117 118 119
    filter.Resize(filter_matrix_shape);

    output->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
120 121
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
122
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
Q
QI JUN 已提交
123
    set_zero(dev_ctx, output, static_cast<T>(0));
C
chengduoZH 已提交
124

Y
Yibing Liu 已提交
125 126
    int in_step = static_cast<int>(input->dims()[1]) / groups;
    int out_step = static_cast<int>(output->dims()[1]) / groups;
Q
QI JUN 已提交
127 128
    math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
    math::Col2VolFunctor<DeviceContext, T> col2vol;
C
chengduoZH 已提交
129

130 131
    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
    // on input)
C
chengduoZH 已提交
132
    for (int i = 0; i < batch_size; i++) {
133
      // batch with size (m, h * w) or (m, d * h * w)
C
chengduoZH 已提交
134 135
      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);

136
      // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
C
chengduoZH 已提交
137 138
      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

Y
Yibing Liu 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
      for (int g = 0; g < groups; g++) {
        Tensor in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
        Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);

        // col_matrix = filter_slice * input_slice
        // of shape (c/g * k_h * k_w, h * w)
        // or (c/g * k_d * k_h * k_w, d * h * w)
        blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
                    &col_matrix, static_cast<T>(0.0));

        if (data_dim == 2U) {
          // col2im: col_matrix -> dy
          // from (c/g * k_h * k_w, h * w) to (c/g, o_h, o_w)
          col2im(dev_ctx, col, dilations, strides,
                 std::vector<int>{paddings[0], paddings[1], paddings[0],
                                  paddings[1]},
                 &out_slice);
        } else if (data_dim == 3U) {
          // col2vol: col_matrix -> dy
          // from (c/g * k_d * k_h * k_w, d * h * w) to (c/g, o_d, o_h, o_w)
          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice);
        }
162
      }
C
chengduoZH 已提交
163 164 165 166
    }
  }
};

Q
QI JUN 已提交
167
template <typename DeviceContext, typename T>
168
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
C
chengduoZH 已提交
169 170 171 172 173 174 175 176 177 178 179 180 181
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));

182 183
    if ((!input_grad) && (!filter_grad)) return;

C
chengduoZH 已提交
184 185
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
186
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
Y
Yibing Liu 已提交
187
    int groups = context.Attr<int>("groups");
C
chengduoZH 已提交
188

C
chengduoZH 已提交
189
    const int batch_size = static_cast<int>(input->dims()[0]);
C
chengduoZH 已提交
190

C
chengduoZH 已提交
191
    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
192
    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
C
chengduoZH 已提交
193
    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
194 195 196 197 198
    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());

    // use col_shape in the im2col and col2im (or vol2col and col2vol)
    // calculation
    // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
C
chengduoZH 已提交
199 200 201 202 203 204 205
    size_t data_dim = filter_shape_vec.size() - 2;
    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
    col_shape_vec[0] = output_grad->dims()[1];
    for (size_t j = 0; j < data_dim; ++j) {
      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
      col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
    }
206
    DDim col_shape(framework::make_ddim(col_shape_vec));
C
chengduoZH 已提交
207

208 209
    // use col_matrix_shape in the gemm calculation
    // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
C
chengduoZH 已提交
210
    DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);
C
chengduoZH 已提交
211

212 213 214
    // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
                                              output_grad->dims().size());
C
chengduoZH 已提交
215

216 217
    // input matrix size: (m, h * w) or (m, d * h * w)
    DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
C
chengduoZH 已提交
218

Y
Yibing Liu 已提交
219 220
    // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0] / groups};
C
chengduoZH 已提交
221
    filter.Resize(filter_matrix_shape);
Y
Yibing Liu 已提交
222 223
    int in_step = static_cast<int>(input->dims()[1]) / groups;
    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
C
chengduoZH 已提交
224 225 226 227

    // convolution transpose grad on input:
    // im2col + gemm (similar to conv-forward)
    // input need to compute gradient
Q
QI JUN 已提交
228
    auto& dev_ctx = context.template device_context<DeviceContext>();
Y
Yu Yang 已提交
229
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
C
chengduoZH 已提交
230 231 232 233 234 235
    if (input_grad || filter_grad) {
      Tensor col;
      col.mutable_data<T>(col_shape, context.GetPlace());
      // col_matrix shares the same piece of data with col,
      // but will be reshaped into a two-dimensional matrix shape
      // to call the matrix multiplication interface.
C
chengduoZH 已提交
236 237 238 239
      Tensor col_matrix;
      col_matrix.ShareDataWith(col);
      col_matrix.Resize(col_matrix_shape);

C
chengduoZH 已提交
240
      Tensor filter_grad_;
Q
QI JUN 已提交
241
      math::SetConstant<DeviceContext, T> set_zero;
C
chengduoZH 已提交
242

Q
QI JUN 已提交
243 244
      math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
      math::Vol2ColFunctor<DeviceContext, T> vol2col;
C
chengduoZH 已提交
245

C
chengduoZH 已提交
246 247 248
      if (input_grad) {
        input_grad->mutable_data<T>(context.GetPlace());
      }
Y
Yibing Liu 已提交
249
      if (filter_grad) {  // filter size (m, c/g, k_h, k_w)
C
chengduoZH 已提交
250
        filter_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
251
        set_zero(dev_ctx, filter_grad, static_cast<T>(0));
C
chengduoZH 已提交
252 253
        filter_grad_ = *filter_grad;
        filter_grad_.Resize(filter_matrix_shape);
C
chengduoZH 已提交
254 255
      }

C
chengduoZH 已提交
256 257
      for (int i = 0; i < batch_size; i++) {
        // batch with size (c, o_h * o_w)
C
chengduoZH 已提交
258 259 260
        Tensor output_grad_batch =
            output_grad->Slice(i, i + 1).Resize(output_shape);

C
chengduoZH 已提交
261
        if (data_dim == 2U) {
262 263
          // im2col: dy -> col matrix
          // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
264
          im2col(dev_ctx, output_grad_batch, dilations, strides,
C
chengduoZH 已提交
265 266 267
                 std::vector<int>{paddings[0], paddings[1], paddings[0],
                                  paddings[1]},
                 &col);
C
chengduoZH 已提交
268
        } else if (data_dim == 3U) {
269 270
          // vol2col: dy -> col_matrix
          // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
Q
QI JUN 已提交
271 272
          vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
                  &col);
273
        }
C
chengduoZH 已提交
274

C
chengduoZH 已提交
275 276 277 278 279 280
        if (input_grad) {
          // batch with size (m, h, w)
          Tensor input_grad_batch =
              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: dx = filter * dy
          // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
281
          // or
C
chengduoZH 已提交
282 283
          // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
          // d, h, w)
Y
Yibing Liu 已提交
284 285 286 287 288 289 290 291 292 293 294
          for (int g = 0; g < groups; g++) {
            Tensor input_grad_slice =
                input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);

            blas.MatMul(filter_slice, false, col_matrix_slice, false,
                        static_cast<T>(1.0), &input_grad_slice,
                        static_cast<T>(0.0));
          }
C
chengduoZH 已提交
295 296 297 298 299
        }
        if (filter_grad) {
          // input batch
          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
          // gemm: d_filter = x * dy^T
300 301
          // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w)
          // or
C
chengduoZH 已提交
302 303
          // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
          // k_h * k_w)
Y
Yibing Liu 已提交
304 305 306 307 308 309 310 311 312 313 314
          for (int g = 0; g < groups; g++) {
            Tensor in_batch_slice =
                in_batch.Slice(g * in_step, (g + 1) * in_step);
            Tensor filter_grad_slice =
                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
            Tensor col_matrix_slice =
                col_matrix.Slice(g * col_step, (g + 1) * col_step);
            blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
                        static_cast<T>(1.0), &filter_grad_slice,
                        static_cast<T>(1.0));
          }
C
chengduoZH 已提交
315
        }
C
chengduoZH 已提交
316 317 318 319
      }
    }
  }
};
320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347

template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  // Depthwise transposed convolution, implemented with the depthwise
  // convolution input-gradient functor.
  // Requirements enforced here:
  //   - groups == Filter.dims()[0];
  //   - every dilation == 1.
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");

    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(groups, filter.dims()[0]);

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    for (auto v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1);
    }

    // Allocate the output exactly once, after the attributes are validated,
    // and zero it before the functor writes into it.
    output->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    // NOTE(review): *output is passed in the functor's first tensor slot,
    // and *input in the slot that presumably holds the upstream gradient.
    // This appears to rely on the functor reading only the shape of the
    // first tensor; confirm against depthwise_conv.
    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings,
                           output);
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  // Gradients of the depthwise transposed convolution (only those requested).
  //   Input gradient:  depthwise convolution of the output gradient with
  //                    Filter.
  //   Filter gradient: zeroed, then filled by the depthwise filter-gradient
  //                    functor from the output gradient and Input.
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

    if (input_grad) {
      // Allocate explicitly, as the filter_grad branch does, rather than
      // relying on the functor to allocate its output.
      input_grad->mutable_data<T>(context.GetPlace());
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings,
                    input_grad);
    }

    if (filter_grad) {
      math::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings,
                              filter_grad);
    }
  }
};
}  // namespace operators
}  // namespace paddle