/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/maxouting.h"

#include <cfloat>  // for FLT_MAX
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

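// Forward maxout kernel: each thread walks a grid-stride loop over output
// elements and, for each one, writes the max over its `groups` input
// candidates. axis == 1 selects the NCHW layout; any other value is treated
// as NHWC.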
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int groups,
                             const int axis, T* output_data) {
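  // size: output elements per batch sample; feat_len: spatial extent (H * W).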
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
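    // data_idx addresses the first of the `groups` input candidates that
    // compete for output element i; successive candidates are feat_len apart
    // in NCHW and adjacent in NHWC.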
    int channel_idx, feat_idx, data_idx;
    if (axis == 1) {
      channel_idx = batch_offset / feat_len;
      feat_idx = batch_offset % feat_len;
      data_idx =
          (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    } else {
      channel_idx = batch_offset % channels;
      feat_idx = batch_offset / channels;
      data_idx =
          (batch_idx * size + feat_idx * channels + channel_idx) * groups;
    }
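    // Running maximum, seeded with a very negative value.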
    T ele = static_cast<T>(-FLT_MAX);
    for (int g = 0; g < groups; ++g) {
      int idx_offset = (axis == 1 ? g * feat_len : g);
      T x = input_data[data_idx + idx_offset];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
  }
}
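// Backward maxout kernel: re-runs the forward argmax search and routes each
// output element's upstream gradient to the first input candidate that
// matches the forward maximum (ties go to the lowest group index).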
template <typename T>
__global__ void KernelMaxOutGrad(const int nthreads, const T* input_data,
                                 const T* output_data, const T* output_grad,
                                 T* input_grad, const int channels,
                                 const int input_height, const int input_width,
                                 const int groups, const int axis) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
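    // Same candidate indexing as the forward kernel.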
    int channel_idx, feat_idx, data_idx;
    if (axis == 1) {
      channel_idx = batch_offset / feat_len;
      feat_idx = batch_offset % feat_len;
      data_idx =
          (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    } else {
      channel_idx = batch_offset % channels;
      feat_idx = batch_offset / channels;
      data_idx =
          (batch_idx * size + feat_idx * channels + channel_idx) * groups;
    }
    int max_index = -1;
    for (int g = 0; g < groups; ++g) {
      int idx_offset = (axis == 1 ? g * feat_len : g);
      if (input_data[data_idx + idx_offset] == output_data[i]) {
        max_index = data_idx + idx_offset;
        break;
      }
    }
    if (max_index != -1) {
      input_grad[max_index] += output_grad[i];
    }
  }
}
/*
 * All tensors are in NCHW or NHWC format.
 */
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  const int groups, const int axis) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[axis];
    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
    const int output_channels = output->dims()[axis];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = output->numel();
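    // One thread per output element, 1024 threads per block.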
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width, groups,
        axis, output_data);
  }
};
/*
 * All tensors are in NCHW or NHWC format.
 */
template <typename T>
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, const int groups,
                  const int axis) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[axis];
    const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
    const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
    const int output_channels = output.dims()[axis];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
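    // Gradients are accumulated with +=, so input_grad is expected to be
    // zero-initialized by the caller.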
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = output.numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxOutGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, groups, axis);
  }
};

template class MaxOutGradFunctor<platform::CUDADeviceContext, float>;
template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;

template class MaxOutFunctor<platform::CUDADeviceContext, float>;
template class MaxOutFunctor<platform::CUDADeviceContext, double>;
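// A minimal usage sketch (illustrative only: `dev_ctx` is assumed to be an
// initialized CUDADeviceContext, and both tensors to live on its GPU place):
//
//   framework::Tensor input;   // e.g. NCHW: [N, C, H, W] with C % groups == 0
//   framework::Tensor output;  // resized by the caller to [N, C / groups, H, W]
//   MaxOutFunctor<platform::CUDADeviceContext, float> maxout;
//   maxout(dev_ctx, input, &output, /*groups=*/2, /*axis=*/1);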

}  // namespace math
}  // namespace operators
}  // namespace paddle