/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "CropOp.h"

#include <cstring>

#include "paddle/function/TensorShape.h"
#include "paddle/math/Vector.h"

namespace paddle {

/**
 * CPU forward crop: copies a (C, H, W) window of every sample in `inputs`
 * into `outputs`.
 *
 * \param outputs  Destination buffer of shape outShape (N, outC, outH, outW).
 * \param inputs   Source buffer of shape inShape (N, inC, inH, inW).
 * \param inShape  Shape of `inputs`; dimensions 0..3 are read.
 * \param outShape Shape of `outputs`; dimensions 0..3 are read.
 * \param conf     Must contain "crop_corner", one offset per dimension
 *                 (N, C, H, W). The batch offset (index 0) is not used.
 */
template <>
void Crop<DEVICE_TYPE_CPU>(real* outputs,
                           const real* inputs,
                           const TensorShape inShape,
                           const TensorShape outShape,
                           const FuncConfig& conf) {
  std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  // Indices 1..3 are read below; a shorter vector would be read out of bounds.
  CHECK_GE(crop_corner.size(), 4UL);
  const size_t cCrop = crop_corner[1];
  const size_t hCrop = crop_corner[2];
  const size_t wCrop = crop_corner[3];

  const size_t num = inShape[0];
  const size_t inC = inShape[1];
  const size_t inH = inShape[2];
  const size_t inW = inShape[3];

  const size_t outC = outShape[1];
  const size_t outH = outShape[2];
  const size_t outW = outShape[3];

  // The batch dimension is passed through unchanged and the crop window must
  // lie entirely inside the input; otherwise the copies below would touch
  // memory outside the buffers.
  CHECK_EQ(outShape[0], num);
  CHECK_LE(cCrop + outC, inC);
  CHECK_LE(hCrop + outH, inH);
  CHECK_LE(wCrop + outW, inW);

  // For fixed (n, c, h) an output row is a contiguous run of outW elements in
  // the input, so it is copied with a single memcpy. Offsets are size_t so
  // large tensors do not overflow an int.
  for (size_t n = 0; n < num; ++n) {
    for (size_t c = 0; c < outC; ++c) {
      for (size_t h = 0; h < outH; ++h) {
        const size_t outoff = ((n * outC + c) * outH + h) * outW;
        const size_t inoff =
            ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
        std::memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real));
      }
    }
  }
}

/**
 * CPU backward crop: adds the gradient of the cropped tensor into the
 * corresponding (C, H, W) window of the gradient of the uncropped tensor.
 *
 * \param inGrad   Gradient w.r.t. the crop output, shape inShape
 *                 (N, inC, inH, inW).
 * \param outGrad  Gradient w.r.t. the crop input, shape outShape
 *                 (N, outC, outH, outW). Values are accumulated (+=), not
 *                 overwritten.
 * \param inShape  Shape of `inGrad`; dimensions 0..3 are read.
 * \param outShape Shape of `outGrad`; dimensions 0..3 are read.
 * \param conf     Must contain "crop_corner", one offset per dimension
 *                 (N, C, H, W). The batch offset (index 0) is not used.
 */
template <>
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
                               real* outGrad,
                               const TensorShape inShape,
                               const TensorShape outShape,
                               const FuncConfig& conf) {
  std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  // Indices 1..3 are read below; a shorter vector would be read out of bounds.
  CHECK_GE(crop_corner.size(), 4UL);
  const size_t cCrop = crop_corner[1];
  const size_t hCrop = crop_corner[2];
  const size_t wCrop = crop_corner[3];

  const size_t num = outShape[0];
  const size_t outC = outShape[1];
  const size_t outH = outShape[2];
  const size_t outW = outShape[3];

  const size_t inC = inShape[1];
  const size_t inH = inShape[2];
  const size_t inW = inShape[3];

  // The cropped gradient must fit inside the full gradient at the given
  // corner, with matching batch size, or the accumulation below would
  // read/write outside the buffers.
  CHECK_EQ(inShape[0], num);
  CHECK_LE(cCrop + inC, outC);
  CHECK_LE(hCrop + inH, outH);
  CHECK_LE(wCrop + inW, outW);

  // Accumulate one contiguous row of inW elements at a time. A plain loop
  // avoids building a CpuVector (and const_cast-ing inGrad) for every row.
  for (size_t n = 0; n < num; ++n) {
    for (size_t c = 0; c < inC; ++c) {
      for (size_t h = 0; h < inH; ++h) {
        const size_t outoff =
            ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
        const size_t inoff = ((n * inC + c) * inH + h) * inW;
        const real* src = inGrad + inoff;
        real* dst = outGrad + outoff;
        for (size_t w = 0; w < inW; ++w) {
          dst[w] += src[w];
        }
      }
    }
  }
}

/**
 * \brief Crop input according to the specify corner and shape.
 *        The input and output is a 4D tensor. In CropFunc, we only
 *        crop the 2nd to 4th dimension.
 *
 * Argument in this Function:
 * \param pad_    A struct object contains the cropping corner and shape.
 * \param inputs  A 4D tensor, only one input.
 * \param outputs A 4D tensor, the output value after cropping.
 *
 * For example,
 * Input(2,2,2,3) = [
 *                    [ [[1,2,3], [3,4,5]],
 *                      [[2,3,5], [1,6,7]] ],
 *                    [ [[4,3,1], [1,8,7]],
 *                      [[3,8,9], [2,3,5]] ]
 *                  ] # the input shape is (2,2,2,3)
 *
 * pad_: if corner = (0,1,1) and crop_shape = (2,1,2)
 * Output(2,2,1,2) = [
 *                    [ [[4,5]],
 *                      [[6,7]] ],
 *                    [ [[8,7]],
 *                      [[3,5]] ]
 *                  ] # the input shape is (2,2,2,3)
 */
template <DeviceType Device>
class CropFunc : public FunctionBase {
public:
116
  void init(const FuncConfig& config) override { conf_ = config; }
W
wanghaoshuang 已提交
117 118 119 120 121 122 123

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);

    TensorShape inShape = inputs[0].shape();
124
    TensorShape outShape = outputs[0].shape();
W
wanghaoshuang 已提交
125

126 127 128 129 130
    Crop<Device>(outputs[0].data<real>(),
                 inputs[0].data<real>(),
                 inShape,
                 outShape,
                 conf_);
W
wanghaoshuang 已提交
131 132 133
  }

private:
134
  FuncConfig conf_;
W
wanghaoshuang 已提交
135 136 137 138 139 140 141 142 143 144 145 146 147 148
};

/**
 * \brief The backward propagation of cropping Function.
 *
 * Argument in this Function:
 * \param crop_    The same meaning as it in CropFunc.
 * \param inputs  The gradient with respect to the output value of CropFunc.
 * \param outputs The gradient with respect to the input value of CropFunc.
 */

template <DeviceType Device>
class CropGradFunc : public FunctionBase {
public:
149
  void init(const FuncConfig& config) override { conf_ = config; }
W
wanghaoshuang 已提交
150 151 152 153

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
154
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);
W
wanghaoshuang 已提交
155 156

    TensorShape outShape = outputs[0].shape();
157
    TensorShape inShape = inputs[0].shape();
W
wanghaoshuang 已提交
158

159 160 161 162 163
    CropGrad<Device>(inputs[0].data<real>(),
                     outputs[0].data<real>(),
                     inShape,
                     outShape,
                     conf_);
W
wanghaoshuang 已提交
164 165 166
  }

private:
167
  FuncConfig conf_;
W
wanghaoshuang 已提交
168 169 170 171
};

// Register the forward and backward crop functions for each device type.
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
#ifdef PADDLE_WITH_CUDA
// NOTE(review): the GPU specializations of Crop/CropGrad are not in this
// file; presumably they live in a separate CUDA source — confirm.
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
#endif

}  // namespace paddle