maxouting.cu 7.7 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
W
wanghaox 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/math/maxouting.h"
16
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
17
#include "paddle/phi/backends/gpu/gpu_context.h"
W
wanghaox 已提交
18 19 20 21 22

namespace paddle {
namespace operators {
namespace math {

W
wanghaox 已提交
23
// Forward maxout kernel: output_data[i] is the maximum of the `groups`
// input elements that feed output element i.
//
// Launch: 1-D grid of 1-D blocks. The loop advances by
// blockDim.x * gridDim.x, so any grid size covers all `nthreads` outputs.
//
// Parameters:
//   nthreads     - number of output elements to produce.
//   input_data   - input tensor data.
//   channels     - number of INPUT channels (assumed divisible by `groups`).
//   input_height - input height.
//   input_width  - input width.
//   groups       - number of input channels reduced into one output channel;
//                  must be >= 1.
//   axis         - 1 selects channel-first (NCHW) indexing; any other value
//                  selects channel-last (NHWC) indexing.
//   output_data  - destination buffer of `nthreads` elements.
template <typename T>
__global__ void KernelMaxOut(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_height,
                             const int input_width,
                             const int groups,
                             const int axis,
                             T* output_data) {
  // Number of output elements per sample (spatial size * output channels).
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  // Distance in the input between consecutive members of one group.
  const int group_step = (axis == 1 ? feat_len : 1);
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += stride) {
    int data_idx;
    if (axis == 1) {
      // NCHW: output channel c of a sample reads input channels
      // [c * groups, (c + 1) * groups), each feat_len elements apart.
      const int batch_idx = i / size;
      const int batch_offset = i % size;
      const int channel_idx = batch_offset / feat_len;
      const int feat_idx = batch_offset % feat_len;
      data_idx =
          (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    } else {
      // NHWC: the `groups` inputs feeding output i are contiguous. Their
      // start, (batch_idx * size + batch_offset) * groups, equals i * groups.
      data_idx = i * groups;
    }
    // Seed with the first group member rather than a sentinel such as
    // -FLT_MAX. Otherwise all -inf inputs (or, for double, values below
    // -FLT_MAX) would yield the sentinel. KernelMaxoutGrad routes the gradient
    // to the input that compares equal to this output, so that gradient
    // would be lost.
    T ele = input_data[data_idx];
    for (int g = 1; g < groups; ++g) {
      const T x = input_data[data_idx + g * group_step];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
  }
}
// Backward maxout kernel: for each output element i, adds output_grad[i] to
// the input position whose value equals output_data[i]. When several inputs
// tie, the first one in group order receives the gradient.
//
// Launch: 1-D grid of 1-D blocks. The loop advances by
// blockDim.x * gridDim.x, so any grid size covers all `nthreads` outputs.
//
// The kernel accumulates with `+=`, so input_grad is expected to be
// initialized (presumably zeroed) by the caller (TODO confirm at the call
// site). Each output element touches only its own group of `groups` inputs.
// Assuming channels % groups == 0, the groups are disjoint and the plain
// `+=` does not race.
//
// Parameters:
//   nthreads     - number of output elements.
//   input_data   - forward input.
//   output_data  - forward output.
//   output_grad  - gradient w.r.t. the forward output (`nthreads` elements).
//   input_grad   - gradient w.r.t. the forward input (accumulated into).
//   channels     - number of INPUT channels.
//   input_height - input height.
//   input_width  - input width.
//   groups       - channels reduced per output channel; must be >= 1.
//   axis         - 1 selects NCHW indexing, anything else NHWC indexing.
template <typename T>
__global__ void KernelMaxoutGrad(const int nthreads,
                                 const T* input_data,
                                 const T* output_data,
                                 const T* output_grad,
                                 T* input_grad,
                                 const int channels,
                                 const int input_height,
                                 const int input_width,
                                 const int groups,
                                 const int axis) {
  // Number of output elements per sample (spatial size * output channels).
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  // Distance in the input between consecutive members of one group.
  const int group_step = (axis == 1 ? feat_len : 1);
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += stride) {
    int data_idx;
    if (axis == 1) {
      // NCHW: output channel c of a sample reads input channels
      // [c * groups, (c + 1) * groups), each feat_len elements apart.
      const int batch_idx = i / size;
      const int batch_offset = i % size;
      const int channel_idx = batch_offset / feat_len;
      const int feat_idx = batch_offset % feat_len;
      data_idx =
          (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    } else {
      // NHWC: the group feeding output i is contiguous and starts at
      // i * groups.
      data_idx = i * groups;
    }
    const T out = output_data[i];
    for (int g = 0; g < groups; ++g) {
      const int idx = data_idx + g * group_step;
      if (input_data[idx] == out) {
        // Read the gradient of THIS output element (i), not the thread's
        // starting index. They differ once a thread handles more than one i.
        input_grad[idx] += output_grad[i];
        break;
      }
    }
  }
}
/*
 * Maxout forward on the GPU.
 *
 * Layout: all tensors are in NCHW or NHWC format.
 *   - axis == 1: channels = dims()[1], height = dims()[2], width = dims()[3].
 *   - otherwise: channels = dims()[axis], height = dims()[1], width = dims()[2].
 *
 * `output` is allocated on context.GetPlace(). One thread is launched per
 * output element on context.stream(); the launch is asynchronous.
 */
template <typename DeviceContext, typename T>
void MaxOutFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                                                 const framework::Tensor& input,
                                                 framework::Tensor* output,
                                                 const int groups,
                                                 const int axis) {
  const int input_channels = input.dims()[axis];
  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);

  const T* input_data = input.data<T>();
  T* output_data = output->mutable_data<T>(context.GetPlace());
  const int nthreads = output->numel();
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // output must not reach the kernel launch.
  if (nthreads <= 0) {
    return;
  }

  constexpr int kThreadsPerBlock = 1024;
  const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
  dim3 threads(kThreadsPerBlock, 1);
  dim3 grid(blocks, 1);

  KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(nthreads,
                                                          input_data,
                                                          input_channels,
                                                          input_height,
                                                          input_width,
                                                          groups,
                                                          axis,
                                                          output_data);
}

W
wanghaox 已提交
137
/*
 * Maxout backward on the GPU.
 *
 * Layout: all tensors are in NCHW or NHWC format.
 *   - axis == 1: channels = dims()[1], height = dims()[2], width = dims()[3].
 *   - otherwise: channels = dims()[axis], height = dims()[1], width = dims()[2].
 *
 * `input_grad` is allocated on context.GetPlace(). The kernel accumulates into
 * it with `+=`, so it is expected to be zero-initialized beforehand
 * (presumably by the calling op — TODO confirm). One thread is launched per
 * output element on context.stream(); the launch is asynchronous.
 */
template <typename DeviceContext, typename T>
void MaxOutGradFunctor<DeviceContext, T>::operator()(
    const DeviceContext& context,
    const framework::Tensor& input,
    framework::Tensor* input_grad,
    const framework::Tensor& output,
    const framework::Tensor& output_grad,
    const int groups,
    const int axis) {
  const int input_channels = input.dims()[axis];
  const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
  const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);

  const T* input_data = input.data<T>();
  const T* output_data = output.data<T>();
  const T* output_grad_data = output_grad.data<T>();
  T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
  const int nthreads = output.numel();
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // output must not reach the kernel launch.
  if (nthreads <= 0) {
    return;
  }

  constexpr int kThreadsPerBlock = 1024;
  const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
  dim3 threads(kThreadsPerBlock, 1);
  dim3 grid(blocks, 1);

  KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(nthreads,
                                                              input_data,
                                                              output_data,
                                                              output_grad_data,
                                                              input_grad_data,
                                                              input_channels,
                                                              input_height,
                                                              input_width,
                                                              groups,
                                                              axis);
}
W
wanghaox 已提交
175

Q
QI JUN 已提交
176 177
// Explicit instantiations of the GPU maxout functors for float and double,
// for each GPU device-context type used by callers.
// NOTE(review): if platform::CUDADeviceContext is (or becomes) a plain alias
// of phi::GPUContext, the two groups below would be duplicate explicit
// instantiations — confirm against the context definitions.

// platform::CUDADeviceContext
template class MaxOutFunctor<platform::CUDADeviceContext, float>;
template class MaxOutFunctor<platform::CUDADeviceContext, double>;
template class MaxOutGradFunctor<platform::CUDADeviceContext, float>;
template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;

// phi::GPUContext
template class MaxOutFunctor<phi::GPUContext, float>;
template class MaxOutFunctor<phi::GPUContext, double>;
template class MaxOutGradFunctor<phi::GPUContext, float>;
template class MaxOutGradFunctor<phi::GPUContext, double>;

W
wanghaox 已提交
188 189 190
}  // namespace math
}  // namespace operators
}  // namespace paddle