/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/maxouting.h"

#include <cfloat>  // FLT_MAX; included here in case the header does not
                   // provide it transitively.

namespace paddle {
namespace operators {
namespace math {

// All tensors are in NCHW or NHWC format, and `groups` must be greater than 1.
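// MaxOut forward: the input carries output_channels * groups channels on the
// channel axis; output channel c at each spatial position is the maximum over
// input channels [c * groups, (c + 1) * groups). `axis` is the channel
// dimension: 1 for NCHW, 3 for NHWC.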
template <typename T>
class MaxOutFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  const int groups, const int axis) {
    const int batch_size = input.dims()[0];
    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
    const int output_channels = output->dims()[axis];
    int fea_size = input_height * input_width;
    // c_size is the number of output elements in one sample
    int c_size = fea_size * output_channels;
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int i = 0; i < batch_size; ++i) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          T ele = static_cast<T>(-FLT_MAX);
          int input_idx, output_idx;
          for (int ph = 0; ph < groups; ++ph) {
            if (axis == 1) {
              input_idx =
                  (new_bindex + new_cindex) * groups + ph * fea_size + f;
            } else {
              input_idx = (new_bindex + f * output_channels + c) * groups + ph;
            }
            T x = input_data[input_idx];
            ele = ele > x ? ele : x;
          }
          if (axis == 1) {
            output_idx = new_bindex + new_cindex + f;
          } else {
            output_idx = new_bindex + f * output_channels + c;
          }
          output_data[output_idx] = ele;
        }
      }
    }
  }
};

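// MaxOut backward: each output gradient is routed to the input element that
// produced the maximum. The group scan stops at the first element whose value
// equals the output, so on ties only the first match receives the gradient.
// The functor accumulates with +=, so the caller is expected to pass a
// zero-initialized input_grad.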
template <class T>
class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, const int groups,
                  const int axis) {
    const int batch_size = input.dims()[0];
    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
    const int output_channels = output.dims()[axis];
    int fea_size = input_height * input_width;
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int blen = fea_size * output_channels * i;
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          int input_idx0, output_idx;
          bool continue_match = true;
          if (axis == 1) {
            input_idx0 = (blen + clen) * groups + f;
            output_idx = blen + clen + f;
          } else {
            input_idx0 = (blen + f * output_channels + c) * groups;
            output_idx = blen + f * output_channels + c;
          }
          for (int g = 0; g < groups && continue_match; ++g) {
            int idx_offset = (axis == 1 ? fea_size * g : g);
            int input_idx = input_idx0 + idx_offset;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
            }
          }
        }
      }
    }
  }
};

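// Explicit instantiations of the CPU functors for float and double.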
template class MaxOutGradFunctor<platform::CPUDeviceContext, float>;
template class MaxOutGradFunctor<platform::CPUDeviceContext, double>;
template class MaxOutFunctor<platform::CPUDeviceContext, float>;
template class MaxOutFunctor<platform::CPUDeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle