maxouting.cu 5.9 KB
Newer Older
W
wanghaox 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// Forward maxout kernel.
//
// Launch: 1-D grid of 1-D blocks. Each thread walks output indices with a
// stride of blockDim.x * gridDim.x, so any non-empty grid covers all
// `nthreads` outputs.
//
//   nthreads       - number of output elements to produce.
//   input_data     - NCHW input with `channels` channels of
//                    input_height x input_width elements each.
//   output_data    - NCHW output with channels / groups channels of the same
//                    spatial size (assumes channels is divisible by groups).
//   groups         - number of consecutive input channels folded into one
//                    output channel.
//   maxout_process - policy object: initial() yields the starting accumulator,
//                    compute(acc, x) is expected to fold x into acc in place.
//
// Output element (n, oc, h, w) is the fold of input channels
// oc * groups .. oc * groups + groups - 1 at position (n, h, w).
template <typename MaxOutProcess, typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             T* output_data, const int channels,
                             const int input_height, const int input_width,
                             int groups, MaxOutProcess maxout_process) {
  const int plane = input_height * input_width;
  const int out_per_sample = plane * channels / groups;
  const int step = blockDim.x * gridDim.x;

  for (int out_idx = blockIdx.x * blockDim.x + threadIdx.x; out_idx < nthreads;
       out_idx += step) {
    const int sample = out_idx / out_per_sample;
    const int within_sample = out_idx % out_per_sample;
    const int out_channel = within_sample / plane;
    const int pixel = within_sample % plane;

    // First input element of this output's channel group; the remaining
    // group members sit one spatial plane apart.
    const T* group_base =
        input_data + (sample * out_per_sample + out_channel * plane) * groups +
        pixel;

    T acc = maxout_process.initial();
    for (int g = 0; g < groups; ++g) {
      maxout_process.compute(acc, group_base[g * plane]);
    }
    output_data[out_idx] = acc;
  }
}
// Backward maxout kernel.
//
// Launch: 1-D grid of 1-D blocks. Each thread walks output indices with a
// stride of blockDim.x * gridDim.x.
//
// For every output element, the `groups` input elements of its channel group
// (same layout as the forward kernel) are scanned in order. The first one whose
// value compares equal to the forward output receives that output's gradient;
// ties go to the lowest group member. If none compares equal (e.g. NaN), no
// gradient is written for that output.
//
// The gradient is added with platform::CudaAtomicAdd, never assigned, so
// input_grad must already hold its starting values. This kernel does not clear
// it; presumably the caller zeroes it first (TODO confirm).
// NOTE(review): each input element belongs to exactly one output group, so the
// atomic does not appear to guard any real contention; kept as-is.
template <typename T>
__global__ void KernelMaxoutGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, int groups) {
  const int plane = input_height * input_width;
  const int out_per_sample = plane * channels / groups;
  const int step = blockDim.x * gridDim.x;

  for (int out_idx = blockIdx.x * blockDim.x + threadIdx.x; out_idx < nthreads;
       out_idx += step) {
    const int sample = out_idx / out_per_sample;
    const int within_sample = out_idx % out_per_sample;
    const int out_channel = within_sample / plane;
    const int pixel = within_sample % plane;
    const int group_start =
        (sample * out_per_sample + out_channel * plane) * groups + pixel;

    for (int g = 0; g < groups; ++g) {
      const int candidate = group_start + g * plane;
      if (input_data[candidate] == output_data[out_idx]) {
        platform::CudaAtomicAdd(input_grad + candidate, output_grad[out_idx]);
        break;  // only the first matching element gets the gradient
      }
    }
  }
}
/*
 * All tensors are in NCHW format.
 */
template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
 public:
  // Computes the maxout forward pass on the GPU.
  //
  //   context        - must be a platform::CUDADeviceContext; the kernel is
  //                    launched (asynchronously) on its stream.
  //   input          - NCHW input tensor.
  //   output         - NCHW output tensor; its storage is obtained with
  //                    mutable_data on context's place, and its element count
  //                    sets how many outputs are computed. Presumably shaped
  //                    (N, C / groups, H, W) - TODO confirm against caller.
  //   groups         - number of consecutive input channels per output channel.
  //   maxout_process - reduction policy forwarded to the kernel.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups, MaxOutProcess maxout_process) {
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];

    const T* input_data = input.data<T>();
    // Allocate the output even when it is empty, as before.
    T* output_data = output->mutable_data<T>(context.GetPlace());
    const int nthreads = output->numel();
    // A launch with a zero-sized grid is an invalid configuration; an empty
    // output has nothing to compute.
    if (nthreads <= 0) {
      return;
    }

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    // CUDADeviceContext derives from DeviceContext, so static_cast is the
    // correct downcast (reinterpret_cast skips any base-to-derived adjustment).
    const platform::CUDADeviceContext& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);

    KernelMaxOut<MaxOutProcess, T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, output_data, input_channels, input_height,
        input_width, groups, maxout_process);
  }
};
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
 public:
  // Computes the maxout backward pass on the GPU.
  //
  //   context     - must be a platform::CUDADeviceContext; the kernel is
  //                 launched (asynchronously) on its stream.
  //   input       - forward NCHW input.
  //   input_grad  - gradient w.r.t. input; storage is obtained with
  //                 mutable_data on context's place. The kernel adds into it
  //                 and this functor does not clear it, so the caller is
  //                 expected to have initialized it (presumably to zero -
  //                 TODO confirm).
  //   output      - forward NCHW output; its element count sets the work size.
  //   output_grad - gradient w.r.t. output, same element count as output.
  //   groups      - number of consecutive input channels per output channel.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    // Allocate input_grad even when there is nothing to compute, as before.
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
    const int nthreads = output.numel();
    // A launch with a zero-sized grid is an invalid configuration; an empty
    // output contributes no gradient.
    if (nthreads <= 0) {
      return;
    }

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    // CUDADeviceContext derives from DeviceContext, so static_cast is the
    // correct downcast (reinterpret_cast skips any base-to-derived adjustment).
    const platform::CUDADeviceContext& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);

    KernelMaxoutGrad<T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, groups);
  }
};

// Explicit instantiations of the GPU specializations for float and double.

// Forward pass.
template class MaxOutFunctor<platform::GPUPlace, math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::GPUPlace, math::MaxOut<double>, double>;

// Backward pass.
template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle