/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "SwitchOp.h"
#include "paddle/math/Vector.h"

namespace paddle {

/**
 * CPU kernel: copy (or accumulate) a 4D tensor stored in NCHW order into a
 * buffer laid out in NHWC order.
 *
 * \param outputs destination buffer in NHWC order; must hold
 *                num * inH * inW * inC elements.
 * \param inputs  source buffer in NCHW order; it is read sequentially.
 * \param num     batch size.
 * \param inC     number of channels.
 * \param inH     height.
 * \param inW     width.
 * \param argType ADD_TO accumulates into outputs; any other value overwrites.
 */
template <>
void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inC,
                                const int inH,
                                const int inW,
                                const int argType) {
  // Decide once whether to accumulate or overwrite, not once per element.
  const bool addTo = (argType == ADD_TO);
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < inC; ++c) {
      for (int h = 0; h < inH; ++h) {
        for (int w = 0; w < inW; ++w) {
          // Flat NHWC offset computed in size_t: the equivalent int product
          // overflows (undefined behavior) for tensors of 2^31+ elements.
          const size_t outIdx =
              ((static_cast<size_t>(n) * inH + h) * inW + w) * inC + c;
          if (addTo) {
            outputs[outIdx] += *(inputs++);
          } else {
            outputs[outIdx] = *(inputs++);
          }
        }
      }
    }
  }
}

/**
 * CPU kernel: copy (or accumulate) a 4D tensor stored in NHWC order into a
 * buffer laid out in NCHW order.
 *
 * \param outputs destination buffer in NCHW order; must hold
 *                num * inC * inH * inW elements.
 * \param inputs  source buffer in NHWC order; it is read sequentially.
 * \param num     batch size.
 * \param inH     height.
 * \param inW     width.
 * \param inC     number of channels.
 * \param argType ADD_TO accumulates into outputs; any other value overwrites.
 */
template <>
void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inH,
                                const int inW,
                                const int inC,
                                const int argType) {
  // Decide once whether to accumulate or overwrite, not once per element.
  const bool addTo = (argType == ADD_TO);
  for (int n = 0; n < num; ++n) {
    for (int h = 0; h < inH; ++h) {
      for (int w = 0; w < inW; ++w) {
        for (int c = 0; c < inC; ++c) {
          // Flat NCHW offset computed in size_t: the equivalent int product
          // overflows (undefined behavior) for tensors of 2^31+ elements.
          const size_t outIdx =
              ((static_cast<size_t>(n) * inC + c) * inH + h) * inW + w;
          if (addTo) {
            outputs[outIdx] += *(inputs++);
          } else {
            outputs[outIdx] = *(inputs++);
          }
        }
      }
    }
  }
}

/**
67 68 69 70
 * \brief  Switch dimension order of image input.
 *         The input and output is a 4D tensor. Switch order
 *         'batch_size,channels, height, width' to
 *         order 'batch_size, height, width, channels'.
71 72
 *
 * Argument in this Function:
73 74
 * \param inputs  input data with order 'batch_size,channels, height, width'.
 * \param outputs output data with order 'batch_size, height, width, channels'.
75 76 77
 */
template <DeviceType Device>
class NCHW2NHWCFunc : public FunctionBase {
W
Wu Yi 已提交
78
 public:
79 80 81 82 83 84 85 86 87 88
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());

    size_t num = inputs[0].shape()[0];
    size_t inC = inputs[0].shape()[1];
    size_t inH = inputs[0].shape()[2];
    size_t inW = inputs[0].shape()[3];
89 90 91 92 93 94 95
    NCHW2NHWC<Device>(outputs[0].data<real>(),
                      inputs[0].data<real>(),
                      num,
                      inC,
                      inH,
                      inW,
                      outputs[0].getArgType());
96 97 98 99
  }
};

/**
100 101 102 103
 * \brief  Switch dimension order of image input.
 *         The input and output is a 4D tensor. Switch order
 *         'batch_size, height, width, channels' to
 *         order 'batch_size, channels, height, width'.
104 105
 *
 * Argument in this Function:
106 107
 * \param inputs  input data with order 'batch_size, height, width, channels'.
 * \param outputs output data with order 'batch_size, channels, height, width'.
108 109 110
 */
template <DeviceType Device>
class NHWC2NCHWFunc : public FunctionBase {
W
Wu Yi 已提交
111
 public:
112 113 114 115 116 117 118 119 120 121 122
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());

    size_t num = inputs[0].shape()[0];
    size_t inH = inputs[0].shape()[1];
    size_t inW = inputs[0].shape()[2];
    size_t inC = inputs[0].shape()[3];

123 124 125 126 127 128 129
    NHWC2NCHW<Device>(outputs[0].data<real>(),
                      inputs[0].data<real>(),
                      num,
                      inH,
                      inW,
                      inC,
                      outputs[0].getArgType());
130 131 132 133 134
  }
};

// Register both layout-switch functions for CPU, and for GPU when CUDA
// support is compiled in.
REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
#ifdef PADDLE_WITH_CUDA
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif

}  // namespace paddle