maxouting.cc 3.8 KB
Newer Older
W
wanghaox 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/math/maxouting.h"

#include <cstdint>
#include <limits>
W
wanghaox 已提交
16 17 18 19 20

namespace paddle {
namespace operators {
namespace math {

W
wanghaox 已提交
21
// All tensors are in NCHW format, and the groups must be greater than 1
W
wanghaox 已提交
22
template <typename T>
Q
QI JUN 已提交
23
class MaxOutFunctor<platform::CPUDeviceContext, T> {
W
wanghaox 已提交
24
 public:
Q
QI JUN 已提交
25
  void operator()(const platform::CPUDeviceContext& context,
26
                  const framework::Tensor& input, framework::Tensor* output,
W
wanghaox 已提交
27
                  int groups) {
W
wanghaox 已提交
28 29 30
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
W
wanghaox 已提交
31
    const int output_channels = output->dims()[1];
W
wanghaox 已提交
32
    int fea_size = input_height * input_width;
W
wanghaox 已提交
33
    // c_size means the output size of each sample
W
wanghaox 已提交
34 35
    int c_size = fea_size * output_channels;
    const T* input_data = input.data<T>();
W
wanghaox 已提交
36
    T* output_data = output->mutable_data<T>(context.GetPlace());
W
wanghaox 已提交
37

W
wanghaox 已提交
38
    for (int i = 0; i < batch_size; ++i) {
39
      int new_bindex = c_size * i;
W
wanghaox 已提交
40 41
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
W
wanghaox 已提交
42
        for (int f = 0; f < fea_size; ++f) {
W
wanghaox 已提交
43
          T ele = static_cast<T>(-FLT_MAX);
W
wanghaox 已提交
44
          for (int ph = 0; ph < groups; ++ph) {
45 46
            T x = input_data[(new_bindex + new_cindex) * groups +
                             ph * fea_size + f];
W
wanghaox 已提交
47
            ele = ele > x ? ele : x;
W
wanghaox 已提交
48
          }
49
          output_data[(new_bindex + new_cindex + f)] = ele;
W
wanghaox 已提交
50 51 52 53 54 55 56
        }
      }
    }
  }
};

template <class T>
Q
QI JUN 已提交
57
class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
58
 public:
Q
QI JUN 已提交
59
  void operator()(const platform::CPUDeviceContext& context,
60
                  const framework::Tensor& input, framework::Tensor* input_grad,
W
wanghaox 已提交
61
                  const framework::Tensor& output,
62
                  const framework::Tensor& output_grad, int groups) {
W
wanghaox 已提交
63 64 65
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
W
wanghaox 已提交
66
    const int output_channels = output.dims()[1];
W
wanghaox 已提交
67 68 69 70
    int fea_size = input_height * input_width;
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
W
wanghaox 已提交
71
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
W
wanghaox 已提交
72

W
wanghaox 已提交
73
    for (int i = 0; i < batch_size; ++i) {
W
wanghaox 已提交
74 75 76
      int blen = fea_size * output_channels * i;
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
W
wanghaox 已提交
77
        for (int f = 0; f < fea_size; ++f) {
W
wanghaox 已提交
78 79
          int input_idx0 = (blen + clen) * groups + f;
          bool continue_match = true;
W
wanghaox 已提交
80
          int output_idx = blen + clen + f;
W
wanghaox 已提交
81
          for (int g = 0; g < groups && continue_match; ++g) {
82 83 84 85 86
            int input_idx = input_idx0 + fea_size * g;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
            }
W
wanghaox 已提交
87 88 89 90 91 92 93
          }
        }
      }
    }
  }
};

Q
QI JUN 已提交
94 95 96 97
// Explicit instantiations of the CPU MaxOut functors (forward, then
// backward) for float and double element types.
template class MaxOutFunctor<platform::CPUDeviceContext, float>;
template class MaxOutFunctor<platform::CPUDeviceContext, double>;
template class MaxOutGradFunctor<platform::CPUDeviceContext, float>;
template class MaxOutGradFunctor<platform::CPUDeviceContext, double>;
W
wanghaox 已提交
98 99 100 101

}  // namespace math
}  // namespace operators
}  // namespace paddle