// prelu.cu: CUDA kernels and host-side launchers for the PReLU activation.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/prelu.h"

#include <limits>

namespace paddle {
namespace operators {
namespace math {

// Threads per block used by every kernel launch in this file.
#define CUDA_NUM_THREADS 1024

// CUDA: grid-stride loop. Declares `i` and visits
//   i = global thread id, global thread id + total threads, ...
// while i < n. The start index and the stride are computed in size_t: with an
// int index, or with the 32-bit unsigned products blockIdx.x * blockDim.x and
// blockDim.x * gridDim.x, element counts beyond 2^31 would overflow.
// `n` is expected to be a non-negative count (the kernels here pass size_t).
#define CUDA_KERNEL_LOOP(i, n)                                   \
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + \
                  threadIdx.x;                                   \
       i < (n); i += static_cast<size_t>(blockDim.x) * gridDim.x)

// Number of CUDA_NUM_THREADS-wide blocks needed to cover N elements
// (ceiling division), computed in size_t so large counts are not truncated.
// The result is clamped to [1, INT_MAX]:
//  - at least one block, so an empty input still yields a valid launch
//    configuration (a zero-sized grid is rejected by the runtime);
//  - at most INT_MAX, the gridDim.x limit on compute capability >= 3.0. Any
//    elements beyond the grid are still covered by CUDA_KERNEL_LOOP's stride.
inline static int PADDLE_GET_BLOCKS(const size_t N) {
  const size_t blocks = (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
  const size_t max_blocks =
      static_cast<size_t>(std::numeric_limits<int>::max());
  if (blocks == 0) return 1;
  return static_cast<int>(blocks < max_blocks ? blocks : max_blocks);
}

// PReLU with one slope per channel:
//   output[i] = input[i] > 0 ? input[i] : alpha[c] * input[i],
//   c = (i / plane_size) % channel_num.
// The input is treated as a contiguous buffer of `numel` elements laid out as
// (outer, channel_num, plane_size), so `alpha` must hold at least
// `channel_num` values. Launched as a 1-D grid; CUDA_KERNEL_LOOP strides over
// the element range, so the grid size need not match numel exactly.
template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
                                       T *output, size_t channel_num,
                                       size_t plane_size, size_t numel) {
  // CUDA_KERNEL_LOOP declares the loop variable `index` itself; a separate
  // declaration here would only be shadowed and left unused.
  CUDA_KERNEL_LOOP(index, numel) {
    const size_t channel_index = (index / plane_size) % channel_num;
    const T scale = alpha[channel_index];
    const T x = input[index];
    output[index] = (x > 0) ? x : scale * x;
  }
}

// PReLU with one slope per element of a single sample:
//   output[i] = input[i] > 0 ? input[i] : alpha[i % spatial_size] * input[i].
// `alpha` must hold at least `spatial_size` values; the same slopes are reused
// for every consecutive run of `spatial_size` elements in the `numel`-element
// buffer. Launched as a 1-D grid; CUDA_KERNEL_LOOP strides over the range.
template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
                                       T *output, size_t spatial_size,
                                       size_t numel) {
  // CUDA_KERNEL_LOOP declares `index`; no separate (shadowed) declaration.
  CUDA_KERNEL_LOOP(index, numel) {
    const size_t element_index = index % spatial_size;
    const T scale = alpha[element_index];
    const T x = input[index];
    output[index] = (x > 0) ? x : scale * x;
  }
}

// PReLU with a single slope shared by all elements:
//   output[i] = input[i] > 0 ? input[i] : alpha[0] * input[i].
// Every thread reads alpha[0] once before looping, so `alpha` must point to at
// least one valid value even when the thread has no elements to process.
template <typename T>
__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
                                  size_t numel) {
  const T scale = alpha[0];
  // CUDA_KERNEL_LOOP declares `index`; no separate (shadowed) declaration.
  CUDA_KERNEL_LOOP(index, numel) {
    const T x = input[index];
    output[index] = (x > 0) ? x : scale * x;
  }
}

// Launches PReluChannelWiseKernel on `stream` (asynchronously; no error
// check or synchronization is done here).
//
// input_shape is read as {N, C, d2, d3, ...}. Dimension 1 is the channel axis,
// and every dimension after it is folded into the per-channel plane, so the
// usual {N, C, H, W} gives plane_size = H * W exactly as before. At least two
// dimensions are required, because input_shape[0] and input_shape[1] are read
// unconditionally. Sizes are accumulated in size_t so that large shapes do not
// overflow int arithmetic.
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
    cudaStream_t stream, const T *input, const T *alpha, T *output,
    std::vector<int> input_shape) {
  const size_t channel_num = static_cast<size_t>(input_shape[1]);
  size_t plane_size = 1;
  for (size_t d = 2; d < input_shape.size(); ++d) {
    plane_size *= static_cast<size_t>(input_shape[d]);
  }
  const size_t numel =
      static_cast<size_t>(input_shape[0]) * channel_num * plane_size;
  // Nothing to compute for an empty tensor; skip the launch entirely.
  if (numel == 0) return;
  PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
                           stream>>>(input, alpha, output, channel_num,
                                     plane_size, numel);
}

// Launches PReluElementWiseKernel on `stream` (asynchronously; no error
// check or synchronization is done here).
//
// input_shape is read as {N, d1, d2, ...}. Dimension 0 is the batch axis, and
// the slopes in `alpha` cover one sample, i.e. the product of all later
// dimensions (C * H * W for the usual {N, C, H, W}, exactly as before). At
// least one dimension is required. Sizes are accumulated in size_t so that
// large shapes do not overflow int arithmetic.
template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()(
    cudaStream_t stream, const T *input, const T *alpha, T *output,
    std::vector<int> input_shape) {
  size_t spatial_size = 1;
  for (size_t d = 1; d < input_shape.size(); ++d) {
    spatial_size *= static_cast<size_t>(input_shape[d]);
  }
  const size_t numel = static_cast<size_t>(input_shape[0]) * spatial_size;
  // Nothing to compute for an empty tensor; skip the launch entirely.
  if (numel == 0) return;
  PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
                           stream>>>(input, alpha, output, spatial_size, numel);
}

// Launches PReluScalarKernel on `stream` (asynchronously; no error check or
// synchronization is done here). A single slope, alpha[0], is applied to
// every element.
//
// The element count is the product of all entries of input_shape, accumulated
// in size_t so large shapes do not overflow int arithmetic. For the usual
// {N, C, H, W} this is N * C * H * W, exactly as before.
template <typename T>
void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                                 const T *input, const T *alpha,
                                                 T *output,
                                                 std::vector<int> input_shape) {
  size_t numel = 1;
  for (const int dim : input_shape) {
    numel *= static_cast<size_t>(dim);
  }
  // Nothing to compute for an empty tensor; skip the launch entirely.
  if (numel == 0) return;
  PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, numel);
}

// Explicit instantiations: each launcher is compiled for float and double.
// Single precision:
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<float>;

// Double precision:
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle