/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "ConvOp.h"
#include "GemmFunctor.h"
#include "Im2Col.h"
#include "paddle/math/MemoryHandle.h"

namespace paddle {

/*
23
 * \brief Forward calculation of convolution.
24 25 26
 */
template <DeviceType Device>
class GemmConvFunction : public ConvFunctionBase {
W
Wu Yi 已提交
27
 public:
28 29 30 31
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

L
liaogang 已提交
32
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
H
hedaoyuan 已提交
33 34 35 36 37 38
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();
    checkShape(input, filter, output);
  }

39
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
40 41
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
H
hedaoyuan 已提交
42
    check(inputs, outputs);
43 44 45 46 47 48 49 50 51 52 53 54
    // TODO(hedaoyuan): Need to define some index macros,
    // to avoid useing 0 and 1.
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();

    real beta;
    if (outputs[0].getArgType() == ADD_TO) {
      beta = 1.0;
    } else {
      beta = 0.0;
    }
55

H
hedaoyuan 已提交
56 57 58 59 60 61 62 63 64
    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];
65 66 67 68

    real* inputData = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* outputData = outputs[0].data<real>();
69
    bool needIm2col = isNeedIm2col(filter);
70

71 72 73
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

74
    TensorShape colShape;
75
    real* colData = NULL;
76

77
    if (needIm2col) {
78 79 80 81 82 83 84 85
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});
      resizeBuffer<Device>(colShape.getElements());
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }
86

87 88
    Im2ColFunctor<kCFO, Device, real> im2col;
    size_t inputOffset = imShape.getElements();
89 90
    size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
H
hedaoyuan 已提交
91 92
    size_t filterOffset = filter.getElements() / groups_;

93
    for (size_t i = 0; i < batchSize; i++) {
94
      for (size_t g = 0; g < groups_; g++) {
95
        if (needIm2col) {
96 97 98 99 100 101 102
          im2col(inputData + g * inputOffset,
                 imShape,
                 colData,
                 colShape,
                 strideH(),
                 strideW(),
                 paddingH(),
X
xzl 已提交
103 104 105
                 paddingW(),
                 dilationH(),
                 dilationW());
106 107
        } else {
          colData = inputData + g * inputOffset;
108
        }
H
Bug fix  
hedaoyuan 已提交
109
        int M = outputChannels / groups_;
110
        int N = outputHeight * outputWidth;
H
Bug fix  
hedaoyuan 已提交
111
        int K = inputChannels / groups_ * filterHeight * filterWidth;
H
hedaoyuan 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124
        BlasGemm<Device, real>::compute(false,
                                        false,
                                        M,
                                        N,
                                        K,
                                        1.0f,
                                        filterData + g * filterOffset,
                                        K,
                                        colData,
                                        N,
                                        beta,
                                        outputData + g * outputOffset,
                                        N);
125
      }
H
hedaoyuan 已提交
126 127
      inputData += inputChannels * inputHeight * inputWidth;
      outputData += outputChannels * outputHeight * outputWidth;
128 129 130 131
    }
  }
};

H
hedaoyuan 已提交
132 133
#ifdef PADDLE_MOBILE_INFERENCE

/*
 * \brief Forward calculation of convolution, optimized for mobile.
 *
 * inputs[0] is read as the image, inputs[1] as the filter and outputs[0] as
 * the result. When im2col is needed, the col matrix is not materialized as
 * a whole: it is built tile by tile (a range of input channels times a range
 * of output rows) so that the scratch buffer stays bounded. The partial
 * products of successive channel tiles are accumulated into the output.
 * The scratch buffer handle is dropped at the end of every calc() call.
 */
template <DeviceType Device>
class GemmConvMobileFunction : public ConvFunctionBase {
 public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // Cross-checks the image, filter and result shapes via checkShape().
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();
    checkShape(input, filter, output);
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);
    // TODO(hedaoyuan): Need to define some index macros,
    // to avoid using 0 and 1.
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();

    // ADD_TO accumulates into the existing output; any other mode overwrites.
    const real beta = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;

    // Shape indices 0..3 are read as batch, channels, height, width.
    const size_t batchSize = input[0];
    const size_t inputChannels = input[1];
    const size_t inputHeight = input[2];
    const size_t inputWidth = input[3];
    const size_t filterHeight = getFilterHeight(filter);
    const size_t filterWidth = getFilterWidth(filter);
    const size_t outputChannels = output[1];
    const size_t outputHeight = output[2];
    const size_t outputWidth = output[3];

    real* inputData = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* outputData = outputs[0].data<real>();
    real* colData = nullptr;
    const bool needIm2col = isNeedIm2col(filter);

    // Shape of the image slice that belongs to a single group.
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});
    TensorShape colShape;

    // Limits of one col tile: at most 4096 columns and at most 1M elements
    // (4M bytes when real is a 4-byte float).
    const size_t kMaxColWidth = 4096;
    const size_t kMaxColElements = 1048576;

    // Number of output rows per tile (at least one, at most all of them).
    const size_t outputHeightSteps =
        std::min(std::max(kMaxColWidth / outputWidth, static_cast<size_t>(1)),
                 outputHeight);
    const size_t maxColWidth = outputHeightSteps * outputWidth;

    // Number of input channels per tile. Each channel contributes
    // filterHeight * filterWidth rows to the col matrix, so the row budget
    // must be divided by the whole filter area.
    const size_t channelSteps = std::min(
        std::max((kMaxColElements / maxColWidth) / (filterHeight * filterWidth),
                 static_cast<size_t>(1)),
        inputChannels / groups_);
    const size_t maxColHeight = channelSteps * filterHeight * filterWidth;

    if (needIm2col) {
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});

      // resizeBuffer is given an element count here, exactly like the other
      // convolution functions in this file do (colShape.getElements()).
      resizeBuffer<Device>(maxColHeight * maxColWidth);
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }

    Im2ColMobileFunctor<real> im2col;

    // Element strides between consecutive groups.
    const size_t inputOffset = imShape.getElements();
    const size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
    const size_t filterOffset = filter.getElements() / groups_;

    // Leading dimensions of the full (untiled) output and filter matrices.
    const int nStride = outputHeight * outputWidth;
    const int kStride = inputChannels / groups_ * filterHeight * filterWidth;

    for (size_t i = 0; i < batchSize; i++) {
      // Every sample starts again from the first group's filter.
      filterData = inputs[1].data<real>();
      for (size_t g = 0; g < groups_; g++) {
        if (needIm2col) {
          // The first channel tile honors the caller's beta; later channel
          // tiles must add their partial product to what is already there.
          real beta_ = beta;
          for (size_t ic = 0; ic < inputChannels / groups_;
               ic += channelSteps) {
            const int channels =
                std::min(inputChannels / groups_ - ic, channelSteps);
            for (size_t oh = 0; oh < outputHeight; oh += outputHeightSteps) {
              const int height = std::min(outputHeight - oh, outputHeightSteps);

              const int M = outputChannels / groups_;
              const int N = height * outputWidth;
              const int K = channels * filterHeight * filterWidth;

              // im2col: expand the channels [ic, ic + channels) of this
              // group. The image pointer has to be advanced to channel ic so
              // that it matches the filter slice used by the GEMM below.
              im2col(inputData + ic * inputHeight * inputWidth,
                     imShape,
                     colData,
                     colShape,
                     strideH(),
                     strideW(),
                     paddingH(),
                     paddingW(),
                     dilationH(),
                     dilationW(),
                     channels,
                     oh,
                     height,
                     N);

              // gemm: multiply the filter columns of this channel tile with
              // the col tile and write the output rows starting at oh.
              BlasGemm<Device, real>::compute(
                  false,
                  false,
                  M,
                  N,
                  K,
                  1.0f,
                  filterData + ic * filterHeight * filterWidth,
                  kStride,
                  colData,
                  N,
                  beta_,
                  outputData + oh * outputWidth,
                  nStride);
            }
            beta_ = 1.0;
          }
        } else {
          // No expansion required: the image slice itself is the col matrix.
          const int M = outputChannels / groups_;
          const int N = outputHeight * outputWidth;
          const int K = inputChannels / groups_ * filterHeight * filterWidth;
          BlasGemm<Device, real>::compute(false,
                                          false,
                                          M,
                                          N,
                                          K,
                                          1.0f,
                                          filterData,
                                          K,
                                          inputData,
                                          N,
                                          beta,
                                          outputData,
                                          N);
        }
        // Advance to the next group (and, after the last group, next sample).
        inputData += inputOffset;
        outputData += outputOffset;
        filterData += filterOffset;
      }
    }

    // Drop this function's handle to the scratch buffer after every call.
    memory_.reset();
  }
};

#endif

295 296 297 298 299
/*
 * \brief Backward input calculation of convolution.
 */
template <DeviceType Device>
class GemmConvGradInputFunction : public ConvFunctionBase {
W
Wu Yi 已提交
300
 public:
301 302 303 304
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

L
liaogang 已提交
305
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
H
hedaoyuan 已提交
306 307 308 309 310 311
    const TensorShape& output = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& input = outputs[0].shape();
    checkShape(input, filter, output);
  }

312 313 314
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
H
hedaoyuan 已提交
315
    check(inputs, outputs);
H
hedaoyuan 已提交
316 317 318
    // Since the implementation of Col2ImFunctor is ADD_TO,
    // this function only supports ADD_TO mode.
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
319
    const TensorShape& output = inputs[0].shape();
320
    const TensorShape& filter = inputs[1].shape();
321 322 323 324 325 326
    const TensorShape& input = outputs[0].shape();

    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
H
hedaoyuan 已提交
327 328
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
329 330 331 332 333 334 335
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];

    real* outputGrad = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* inputGrad = outputs[0].data<real>();
336
    bool needIm2col = isNeedIm2col(filter);
337

338 339 340
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

341
    TensorShape colShape;
342
    real* colData = NULL;
343

344
    if (needIm2col) {
345 346 347 348 349 350 351 352
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});
      resizeBuffer<Device>(colShape.getElements());
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }
353

354 355
    Col2ImFunctor<kCFO, Device, real> col2im;
    size_t inputOffset = imShape.getElements();
H
format  
hedaoyuan 已提交
356
    size_t outputOffset =
357 358 359 360 361 362 363 364
        (outputChannels / groups_) * outputHeight * outputWidth;
    size_t filterOffset = filter.getElements() / groups_;

    for (size_t i = 0; i < batchSize; i++) {
      for (size_t g = 0; g < groups_; g++) {
        int K = outputChannels / groups_;
        int N = outputHeight * outputWidth;
        int M = inputChannels / groups_ * filterHeight * filterWidth;
365
        real scale = 0.0f;
366 367
        if (!needIm2col) {
          colData = inputGrad + g * inputOffset;
368 369
          scale = 1.0f;
        }
H
hedaoyuan 已提交
370 371 372 373 374 375 376 377 378 379 380 381 382
        BlasGemm<Device, real>::compute(true,
                                        false,
                                        M,
                                        N,
                                        K,
                                        1.0f,
                                        filterData + g * filterOffset,
                                        M,
                                        outputGrad + g * outputOffset,
                                        N,
                                        scale,
                                        colData,
                                        N);
383
        if (needIm2col) {
384 385
          col2im(inputGrad + g * inputOffset,
                 imShape,
386
                 colData,
387 388 389 390
                 colShape,
                 strideH(),
                 strideW(),
                 paddingH(),
X
xzl 已提交
391 392 393
                 paddingW(),
                 dilationH(),
                 dilationW());
394
        }
395 396 397 398
      }
      inputGrad += inputChannels * inputHeight * inputWidth;
      outputGrad += outputChannels * outputHeight * outputWidth;
    }
399 400 401 402 403 404 405 406
  }
};

/*
 * \brief Backward filter calculation of convolution.
 */
template <DeviceType Device>
class GemmConvGradFilterFunction : public ConvFunctionBase {
W
Wu Yi 已提交
407
 public:
408 409 410 411
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

L
liaogang 已提交
412
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
H
hedaoyuan 已提交
413 414 415 416 417 418
    const TensorShape& output = inputs[0].shape();
    const TensorShape& input = inputs[1].shape();
    const TensorShape& filter = outputs[0].shape();
    checkShape(input, filter, output);
  }

419 420 421
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
H
hedaoyuan 已提交
422
    check(inputs, outputs);
423
    const TensorShape& output = inputs[0].shape();
424
    const TensorShape& input = inputs[1].shape();
425 426
    const TensorShape& filter = outputs[0].shape();

427 428 429 430 431 432 433
    real beta;
    if (outputs[0].getArgType() == ADD_TO) {
      beta = 1.0;
    } else {
      beta = 0.0;
    }

434 435 436 437
    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
H
hedaoyuan 已提交
438 439
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
440 441 442 443 444 445 446
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];

    real* outputGrad = inputs[0].data<real>();
    real* inputData = inputs[1].data<real>();
    real* filterGrad = outputs[0].data<real>();
447
    bool needIm2col = isNeedIm2col(filter);
448

449 450 451
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

452
    TensorShape colShape;
453
    real* colData = NULL;
454

455
    if (needIm2col) {
456 457 458 459 460 461 462 463
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});
      resizeBuffer<Device>(colShape.getElements());
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }
464

465 466
    Im2ColFunctor<kCFO, Device, real> im2col;
    size_t inputOffset = imShape.getElements();
467 468 469 470 471
    size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
    size_t filterOffset = filter.getElements() / groups_;
    for (size_t i = 0; i < batchSize; i++) {
      for (size_t g = 0; g < groups_; g++) {
472
        if (needIm2col) {
473 474 475 476 477 478 479
          im2col(inputData + g * inputOffset,
                 imShape,
                 colData,
                 colShape,
                 strideH(),
                 strideW(),
                 paddingH(),
X
xzl 已提交
480 481 482
                 paddingW(),
                 dilationH(),
                 dilationW());
483 484
        } else {
          colData = inputData + g * inputOffset;
485
        }
486 487 488
        int M = outputChannels / groups_;
        int K = outputHeight * outputWidth;
        int N = inputChannels / groups_ * filterHeight * filterWidth;
H
hedaoyuan 已提交
489 490 491 492 493 494 495 496 497 498 499 500 501
        BlasGemm<Device, real>::compute(false,
                                        true,
                                        M,
                                        N,
                                        K,
                                        1.0f,
                                        outputGrad + g * outputOffset,
                                        K,
                                        colData,
                                        K,
                                        i == 0 ? beta : 1.0f,
                                        filterGrad + g * filterOffset,
                                        N);
502
      }
503 504
      inputData += inputChannels * inputHeight * inputWidth;
      outputGrad += outputChannels * outputHeight * outputWidth;
505
    }
506 507 508
  }
};

H
hedaoyuan 已提交
509 510 511
// CPU registrations. The forward function is bound to the mobile-optimized
// implementation when PADDLE_MOBILE_INFERENCE is defined, and to the generic
// GEMM implementation otherwise.
#ifdef PADDLE_MOBILE_INFERENCE
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction);
#else
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
#endif
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
// GPU registrations, compiled only when PADDLE_WITH_CUDA is defined.
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
#endif
521 522

}  // namespace paddle