// GemmConvOp.cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ConvOp.h"
#include "GemmFunctor.h"
#include "Im2Col.h"
#include "paddle/math/MemoryHandle.h"

namespace paddle {

/*
23
 * \brief Forward calculation of convolution.
24 25 26 27 28 29 30 31
 */
template <DeviceType Device>
class GemmConvFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

L
liaogang 已提交
32
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
H
hedaoyuan 已提交
33 34 35 36 37 38
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();
    checkShape(input, filter, output);
  }

39
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
40 41
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
H
hedaoyuan 已提交
42
    check(inputs, outputs);
43 44 45 46 47 48 49 50 51 52 53 54
    // TODO(hedaoyuan): Need to define some index macros,
    // to avoid useing 0 and 1.
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();

    real beta;
    if (outputs[0].getArgType() == ADD_TO) {
      beta = 1.0;
    } else {
      beta = 0.0;
    }
55

H
hedaoyuan 已提交
56 57 58 59 60 61 62 63 64
    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];
65 66 67 68

    real* inputData = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* outputData = outputs[0].data<real>();
69
    bool needIm2col = isNeedIm2col(filter);
70

71 72 73
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

74
    TensorShape colShape;
75
    real* colData = NULL;
76

77
    if (needIm2col) {
78 79 80 81 82 83 84 85
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});
      resizeBuffer<Device>(colShape.getElements());
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }
86

87 88
    Im2ColFunctor<kCFO, Device, real> im2col;
    size_t inputOffset = imShape.getElements();
89 90
    size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
H
hedaoyuan 已提交
91 92
    size_t filterOffset = filter.getElements() / groups_;

93
    for (size_t i = 0; i < batchSize; i++) {
94
      for (size_t g = 0; g < groups_; g++) {
95
        if (needIm2col) {
96 97 98 99 100 101 102 103
          im2col(inputData + g * inputOffset,
                 imShape,
                 colData,
                 colShape,
                 strideH(),
                 strideW(),
                 paddingH(),
                 paddingW());
104 105
        } else {
          colData = inputData + g * inputOffset;
106
        }
H
Bug fix  
hedaoyuan 已提交
107
        int M = outputChannels / groups_;
108
        int N = outputHeight * outputWidth;
H
Bug fix  
hedaoyuan 已提交
109
        int K = inputChannels / groups_ * filterHeight * filterWidth;
H
hedaoyuan 已提交
110 111 112 113 114 115 116 117 118 119 120 121 122
        BlasGemm<Device, real>::compute(false,
                                        false,
                                        M,
                                        N,
                                        K,
                                        1.0f,
                                        filterData + g * filterOffset,
                                        K,
                                        colData,
                                        N,
                                        beta,
                                        outputData + g * outputOffset,
                                        N);
123
      }
H
hedaoyuan 已提交
124 125
      inputData += inputChannels * inputHeight * inputWidth;
      outputData += outputChannels * outputHeight * outputWidth;
126 127 128 129
    }
  }
};

130 131 132 133 134 135 136 137 138 139
/*
 * \brief Backward input calculation of convolution.
 */
template <DeviceType Device>
class GemmConvGradInputFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

L
liaogang 已提交
140
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
H
hedaoyuan 已提交
141 142 143 144 145 146
    const TensorShape& output = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& input = outputs[0].shape();
    checkShape(input, filter, output);
  }

147 148 149
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
H
hedaoyuan 已提交
150
    check(inputs, outputs);
H
hedaoyuan 已提交
151 152 153
    // Since the implementation of Col2ImFunctor is ADD_TO,
    // this function only supports ADD_TO mode.
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
154
    const TensorShape& output = inputs[0].shape();
155
    const TensorShape& filter = inputs[1].shape();
156 157 158 159 160 161
    const TensorShape& input = outputs[0].shape();

    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
H
hedaoyuan 已提交
162 163
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
164 165 166 167 168 169 170
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];

    real* outputGrad = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* inputGrad = outputs[0].data<real>();
171
    bool needIm2col = isNeedIm2col(filter);
172

173 174 175
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

176
    TensorShape colShape;
177
    real* colData = NULL;
178

179
    if (needIm2col) {
180 181 182 183 184 185 186 187
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});
      resizeBuffer<Device>(colShape.getElements());
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }
188

189 190
    Col2ImFunctor<kCFO, Device, real> col2im;
    size_t inputOffset = imShape.getElements();
H
format  
hedaoyuan 已提交
191
    size_t outputOffset =
192 193 194 195 196 197 198 199
        (outputChannels / groups_) * outputHeight * outputWidth;
    size_t filterOffset = filter.getElements() / groups_;

    for (size_t i = 0; i < batchSize; i++) {
      for (size_t g = 0; g < groups_; g++) {
        int K = outputChannels / groups_;
        int N = outputHeight * outputWidth;
        int M = inputChannels / groups_ * filterHeight * filterWidth;
200
        real scale = 0.0f;
201 202
        if (!needIm2col) {
          colData = inputGrad + g * inputOffset;
203 204
          scale = 1.0f;
        }
H
hedaoyuan 已提交
205 206 207 208 209 210 211 212 213 214 215 216 217
        BlasGemm<Device, real>::compute(true,
                                        false,
                                        M,
                                        N,
                                        K,
                                        1.0f,
                                        filterData + g * filterOffset,
                                        M,
                                        outputGrad + g * outputOffset,
                                        N,
                                        scale,
                                        colData,
                                        N);
218
        if (needIm2col) {
219 220
          col2im(inputGrad + g * inputOffset,
                 imShape,
221
                 colData,
222 223 224 225 226 227
                 colShape,
                 strideH(),
                 strideW(),
                 paddingH(),
                 paddingW());
        }
228 229 230 231
      }
      inputGrad += inputChannels * inputHeight * inputWidth;
      outputGrad += outputChannels * outputHeight * outputWidth;
    }
232 233 234 235 236 237 238 239 240 241 242 243 244
  }
};

/*
 * \brief Backward filter calculation of convolution.
 */
template <DeviceType Device>
class GemmConvGradFilterFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

L
liaogang 已提交
245
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
H
hedaoyuan 已提交
246 247 248 249 250 251
    const TensorShape& output = inputs[0].shape();
    const TensorShape& input = inputs[1].shape();
    const TensorShape& filter = outputs[0].shape();
    checkShape(input, filter, output);
  }

252 253 254
  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
H
hedaoyuan 已提交
255
    check(inputs, outputs);
256
    const TensorShape& output = inputs[0].shape();
257
    const TensorShape& input = inputs[1].shape();
258 259
    const TensorShape& filter = outputs[0].shape();

260 261 262 263 264 265 266
    real beta;
    if (outputs[0].getArgType() == ADD_TO) {
      beta = 1.0;
    } else {
      beta = 0.0;
    }

267 268 269 270
    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
H
hedaoyuan 已提交
271 272
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
273 274 275 276 277 278 279
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];

    real* outputGrad = inputs[0].data<real>();
    real* inputData = inputs[1].data<real>();
    real* filterGrad = outputs[0].data<real>();
280
    bool needIm2col = isNeedIm2col(filter);
281

282 283 284
    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

285
    TensorShape colShape;
286
    real* colData = NULL;
287

288
    if (needIm2col) {
289 290 291 292 293 294 295 296
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});
      resizeBuffer<Device>(colShape.getElements());
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }
297

298 299
    Im2ColFunctor<kCFO, Device, real> im2col;
    size_t inputOffset = imShape.getElements();
300 301 302 303 304
    size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
    size_t filterOffset = filter.getElements() / groups_;
    for (size_t i = 0; i < batchSize; i++) {
      for (size_t g = 0; g < groups_; g++) {
305
        if (needIm2col) {
306 307 308 309 310 311 312 313
          im2col(inputData + g * inputOffset,
                 imShape,
                 colData,
                 colShape,
                 strideH(),
                 strideW(),
                 paddingH(),
                 paddingW());
314 315
        } else {
          colData = inputData + g * inputOffset;
316
        }
317 318 319
        int M = outputChannels / groups_;
        int K = outputHeight * outputWidth;
        int N = inputChannels / groups_ * filterHeight * filterWidth;
H
hedaoyuan 已提交
320 321 322 323 324 325 326 327 328 329 330 331 332
        BlasGemm<Device, real>::compute(false,
                                        true,
                                        M,
                                        N,
                                        K,
                                        1.0f,
                                        outputGrad + g * outputOffset,
                                        K,
                                        colData,
                                        K,
                                        i == 0 ? beta : 1.0f,
                                        filterGrad + g * filterOffset,
                                        N);
333
      }
334 335
      inputData += inputChannels * inputHeight * inputWidth;
      outputGrad += outputChannels * outputHeight * outputWidth;
336
    }
337 338 339
  }
};

// REGISTER_TYPED_FUNC entries for the CPU implementations and, when
// PADDLE_WITH_GPU is defined, for the GPU implementations.
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifdef PADDLE_WITH_GPU
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConvGradInput, GPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, GPU, GemmConvGradFilterFunction);
#endif

}  // namespace paddle