// prelu.cu -- last committed by nhzlx
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/prelu.h"

namespace paddle {
namespace operators {
namespace math {

// Threads per block used by every kernel launch in this file.
static const int CUDA_NUM_THREADS = 1024;
// Exclusive upper bound on the number of blocks a launch in this file may
// request; the host-side launchers enforce it with CHECK_LT.
static const int CUDA_MAX_NUM_BLOCKS = 65535;

// Returns the number of CUDA_NUM_THREADS-sized blocks needed to cover N
// elements, i.e. ceil(N / CUDA_NUM_THREADS).
//
// The quotient is formed without adding to N first, so values of N close to
// INT_MAX cannot overflow. A non-positive N needs no blocks and yields 0.
inline static int GET_NUM_BLOCKS(const int N) {
  if (N <= 0) {
    return 0;
  }
  return (N - 1) / CUDA_NUM_THREADS + 1;
}

// PReLU with one slope per channel.
//
// Launch layout: block b processes the spatial_size contiguous elements that
// start at b * spatial_size in both input and output. The slope for the whole
// block is alpha[b % channel]. Threads of a block walk their slice with a
// stride of blockDim.x, so any block size covers the slice completely.
//
// For each element x the result is x when x > 0 and slope * x otherwise.
template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
                                       T *output, int channel,
                                       size_t spatial_size) {
  const size_t base = blockIdx.x * spatial_size;
  const T *src = input + base;
  T *dst = output + base;
  // One slope is shared by every element this block touches.
  const T slope = alpha[blockIdx.x % channel];

  size_t idx = threadIdx.x;
  while (idx < spatial_size) {
    const T value = src[idx];
    if (value > 0) {
      dst[idx] = value;
    } else {
      dst[idx] = slope * value;
    }
    idx += blockDim.x;
  }
}

// PReLU with one slope per element.
//
// Launch layout: block b processes the spatial_size contiguous elements that
// start at b * spatial_size. alpha is addressed with exactly the same flat
// index as input and output, and is only read for elements that are not
// greater than zero. Threads of a block walk their slice with a stride of
// blockDim.x.
//
// NOTE(review): because alpha is offset by the block index as well, it must be
// readable at every index the launched grid covers -- confirm that callers
// provide an alpha buffer of that extent.
template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
                                       T *output, size_t spatial_size) {
  const size_t base = blockIdx.x * spatial_size;
  const T *src = input + base;
  const T *slopes = alpha + base;
  T *dst = output + base;

  size_t idx = threadIdx.x;
  while (idx < spatial_size) {
    const T value = src[idx];
    if (value > 0) {
      dst[idx] = value;
    } else {
      dst[idx] = slopes[idx] * value;
    }
    idx += blockDim.x;
  }
}

// PReLU with a single slope shared by the whole tensor.
//
// Launch layout: block b processes the spatial_size contiguous elements that
// start at b * spatial_size in both input and output. The slope is the first
// (and only) value read from alpha. Threads of a block walk their slice with a
// stride of blockDim.x.
//
// For each element x the result is x when x > 0 and slope * x otherwise.
template <typename T>
__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
                                  size_t spatial_size) {
  const size_t base = blockIdx.x * spatial_size;
  const T *src = input + base;
  const T slope = alpha[0];
  T *dst = output + base;

  size_t idx = threadIdx.x;
  while (idx < spatial_size) {
    const T value = src[idx];
    if (value > 0) {
      dst[idx] = value;
    } else {
      dst[idx] = slope * value;
    }
    idx += blockDim.x;
  }
}

// Enqueues PReluChannelWiseKernel on `stream`.
//
// input_shape is read at indices 0..3. The grid has
// input_shape[0] * input_shape[1] blocks of CUDA_NUM_THREADS threads, each
// block covering input_shape[2] * input_shape[3] elements, and
// input_shape[1] is handed to the kernel as the channel count.
// (Presumably the shape is NCHW -- confirm against callers.)
//
// The launch is asynchronous; no error check or synchronization is done here.
// Aborts via CHECK_LT if the block count is not below CUDA_MAX_NUM_BLOCKS.
template <typename T>
static inline void PReluChannelWise(cudaStream_t stream, const T *input,
                                    const T *alpha, T *output,
                                    std::vector<int> input_shape) {
  // Widen each extent before multiplying so that large shapes cannot overflow
  // int arithmetic.
  const size_t unroll = static_cast<size_t>(input_shape[0]) *
                        static_cast<size_t>(input_shape[1]);
  const size_t spatial_size = static_cast<size_t>(input_shape[2]) *
                              static_cast<size_t>(input_shape[3]);
  // An empty tensor needs no work, and a grid of zero blocks is an invalid
  // launch configuration that would leave a CUDA error pending.
  if (unroll == 0 || spatial_size == 0) {
    return;
  }
  CHECK_LT(unroll, static_cast<size_t>(CUDA_MAX_NUM_BLOCKS));
  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, input_shape[1], spatial_size);
}

// Enqueues PReluElementWiseKernel on `stream`.
//
// input_shape is read at indices 0..3. The grid has
// input_shape[0] * input_shape[1] blocks of CUDA_NUM_THREADS threads, each
// block covering input_shape[2] * input_shape[3] elements.
// (Presumably the shape is NCHW -- confirm against callers.)
//
// The launch is asynchronous; no error check or synchronization is done here.
// Aborts via CHECK_LT if the block count is not below CUDA_MAX_NUM_BLOCKS.
template <typename T>
static inline void PReluElementWise(cudaStream_t stream, const T *input,
                                    const T *alpha, T *output,
                                    std::vector<int> input_shape) {
  // Widen each extent before multiplying so that large shapes cannot overflow
  // int arithmetic.
  const size_t unroll = static_cast<size_t>(input_shape[0]) *
                        static_cast<size_t>(input_shape[1]);
  const size_t spatial_size = static_cast<size_t>(input_shape[2]) *
                              static_cast<size_t>(input_shape[3]);
  // An empty tensor needs no work, and a grid of zero blocks is an invalid
  // launch configuration that would leave a CUDA error pending.
  if (unroll == 0 || spatial_size == 0) {
    return;
  }
  CHECK_LT(unroll, static_cast<size_t>(CUDA_MAX_NUM_BLOCKS));
  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, spatial_size);
}

// Enqueues PReluScalarKernel on `stream`.
//
// input_shape is read at indices 0..3. The grid has
// input_shape[0] * input_shape[1] blocks of CUDA_NUM_THREADS threads, each
// block covering input_shape[2] * input_shape[3] elements.
// (Presumably the shape is NCHW -- confirm against callers.)
//
// The launch is asynchronous; no error check or synchronization is done here.
// Aborts via CHECK_LT if the block count is not below CUDA_MAX_NUM_BLOCKS.
template <typename T>
static inline void PReluScalar(cudaStream_t stream, const T *input,
                               const T *alpha, T *output,
                               std::vector<int> input_shape) {
  // Widen each extent before multiplying so that large shapes cannot overflow
  // int arithmetic.
  const size_t unroll = static_cast<size_t>(input_shape[0]) *
                        static_cast<size_t>(input_shape[1]);
  const size_t spatial_size = static_cast<size_t>(input_shape[2]) *
                              static_cast<size_t>(input_shape[3]);
  // An empty tensor needs no work, and a grid of zero blocks is an invalid
  // launch configuration that would leave a CUDA error pending.
  if (unroll == 0 || spatial_size == 0) {
    return;
  }
  CHECK_LT(unroll, static_cast<size_t>(CUDA_MAX_NUM_BLOCKS));
  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, spatial_size);
}

// Enqueues the channel-wise PReLU kernel on `stream`.
//
// input_shape is read at indices 0..3. The grid has
// input_shape[0] * input_shape[1] blocks of CUDA_NUM_THREADS threads, each
// block covering input_shape[2] * input_shape[3] elements, and
// input_shape[1] is handed to the kernel as the channel count.
// (Presumably the shape is NCHW -- confirm against callers.)
//
// The launch is asynchronous; no error check or synchronization is done here.
// Aborts via CHECK_LT if the block count is not below CUDA_MAX_NUM_BLOCKS.
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
    cudaStream_t stream, const T *input, const T *alpha, T *output,
    std::vector<int> input_shape) {
  // Widen each extent before multiplying so that large shapes cannot overflow
  // int arithmetic.
  const size_t unroll = static_cast<size_t>(input_shape[0]) *
                        static_cast<size_t>(input_shape[1]);
  const size_t spatial_size = static_cast<size_t>(input_shape[2]) *
                              static_cast<size_t>(input_shape[3]);
  // An empty tensor needs no work, and a grid of zero blocks is an invalid
  // launch configuration that would leave a CUDA error pending.
  if (unroll == 0 || spatial_size == 0) {
    return;
  }
  CHECK_LT(unroll, static_cast<size_t>(CUDA_MAX_NUM_BLOCKS));
  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, input_shape[1], spatial_size);
}

// Enqueues the element-wise PReLU kernel on `stream`.
//
// input_shape is read at indices 0..3. The grid has
// input_shape[0] * input_shape[1] blocks of CUDA_NUM_THREADS threads, each
// block covering input_shape[2] * input_shape[3] elements.
// (Presumably the shape is NCHW -- confirm against callers.)
//
// The launch is asynchronous; no error check or synchronization is done here.
// Aborts via CHECK_LT if the block count is not below CUDA_MAX_NUM_BLOCKS.
template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()(
    cudaStream_t stream, const T *input, const T *alpha, T *output,
    std::vector<int> input_shape) {
  // Widen each extent before multiplying so that large shapes cannot overflow
  // int arithmetic.
  const size_t unroll = static_cast<size_t>(input_shape[0]) *
                        static_cast<size_t>(input_shape[1]);
  const size_t spatial_size = static_cast<size_t>(input_shape[2]) *
                              static_cast<size_t>(input_shape[3]);
  // An empty tensor needs no work, and a grid of zero blocks is an invalid
  // launch configuration that would leave a CUDA error pending.
  if (unroll == 0 || spatial_size == 0) {
    return;
  }
  CHECK_LT(unroll, static_cast<size_t>(CUDA_MAX_NUM_BLOCKS));
  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, spatial_size);
}

// Enqueues the scalar-slope PReLU kernel on `stream`.
//
// input_shape is read at indices 0..3. The grid has
// input_shape[0] * input_shape[1] blocks of CUDA_NUM_THREADS threads, each
// block covering input_shape[2] * input_shape[3] elements.
// (Presumably the shape is NCHW -- confirm against callers.)
//
// The launch is asynchronous; no error check or synchronization is done here.
// Aborts via CHECK_LT if the block count is not below CUDA_MAX_NUM_BLOCKS.
template <typename T>
void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                                 const T *input, const T *alpha,
                                                 T *output,
                                                 std::vector<int> input_shape) {
  // Widen each extent before multiplying so that large shapes cannot overflow
  // int arithmetic.
  const size_t unroll = static_cast<size_t>(input_shape[0]) *
                        static_cast<size_t>(input_shape[1]);
  const size_t spatial_size = static_cast<size_t>(input_shape[2]) *
                              static_cast<size_t>(input_shape[3]);
  // An empty tensor needs no work, and a grid of zero blocks is an invalid
  // launch configuration that would leave a CUDA error pending.
  if (unroll == 0 || spatial_size == 0) {
    return;
  }
  CHECK_LT(unroll, static_cast<size_t>(CUDA_MAX_NUM_BLOCKS));
  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
      input, alpha, output, spatial_size);
}

// Explicit instantiations: the functor templates are defined in this
// translation unit, so the element types that must be available to other
// translation units are instantiated here.

// Single precision.
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<float>;

// Double precision.
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle