// gelu_op_plugin.cu
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// Scalar parameters handed to the GELU kernels from the host side.
//
// Exact path (erf-based, used for FP32): the kernel evaluates
//   gelu(x) = x * 0.5 * (1 + erf(x * 0.5 * kA)),  with kA = sqrt(2),
// so x * 0.5 * kA == x / sqrt(2).
static constexpr float kA = 1.41421356237309504;  // sqrt(2)

// Tanh-approximation path (used for FP16):
//   gelu(x) ~= kAT * x * (1 + tanh(kBT * x + kCT * x^3))
static constexpr float kAT = 0.5;
static constexpr float kBT = 0.7978845608028654;    // sqrt(2.0/M_PI)
static constexpr float kCT = 0.035677408136300125;  // 0.044715 * sqrt(2.0/M_PI)

// Deserialization factory: rebuilds a GeluPlugin from a raw byte buffer of
// `length` bytes (presumably the bytes written by GeluPlugin's own
// serialization -- TODO confirm against gelu_op_plugin.h). The caller takes
// ownership of the returned heap-allocated plugin.
GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
  GeluPlugin* plugin = new GeluPlugin(buffer, length);
  return plugin;
}

// Register the factory above under the name "gelu_plugin" (see
// REGISTER_TRT_PLUGIN in trt_plugin_factory.h).
REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);

// Reports which (type, layout) pairs the static-shape plugin accepts:
// only NCHW, with FP32 always and FP16 only when built with
// SUPPORTS_CUDA_FP16.
bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
                                nvinfer1::PluginFormat format) const {
  if (format != nvinfer1::PluginFormat::kNCHW) {
    return false;
  }
#ifdef SUPPORTS_CUDA_FP16
  return type == nvinfer1::DataType::kFLOAT ||
         type == nvinfer1::DataType::kHALF;
#else
  return type == nvinfer1::DataType::kFLOAT;
#endif
}

// GELU is element-wise: the output has exactly the shape of the single input.
// Debug builds assert there is one input and a valid output index.
nvinfer1::Dims GeluPlugin::getOutputDimensions(int index,
                                               const nvinfer1::Dims* in_dims,
                                               int nb_inputs) {
  assert(nb_inputs == 1);
  assert(index < this->getNbOutputs());
  return in_dims[0];
}

// Exact (erf-based) GELU, one element per thread:
//   output[i] = x * 0.5 * (1 + erf(x * 0.5 * a)),  x = input[i].
// With a = sqrt(2) (what this file's enqueue passes) the erf argument is
// x / sqrt(2).
// Launch layout: 1-D grid; the index is computed with TPB rather than
// blockDim.x, so blockDim.x must equal TPB and gridDim.x * TPB must be >= n.
template <typename T, unsigned TPB>
__global__ void gelu_kernel(const T a, int n, const T* input, T* output) {
  const int idx = blockIdx.x * TPB + threadIdx.x;
  if (idx < n) {
    // Keep the arithmetic in T. Bare 0.5 / 1.0 literals are double and would
    // silently promote the whole expression (including erf) to FP64, which is
    // very slow on most GPUs.
    const T kHalf = static_cast<T>(0.5);
    const T kOne = static_cast<T>(1.0);
    const T in = input[idx];
    const T cdf = kHalf * (kOne + erf(in * kHalf * a));
    output[idx] = in * cdf;
  }
}
// Hyperbolic tangent used by the tanh-approximated GELU kernel. Only the
// float and half specializations below are defined.
template <typename T>
__device__ T do_tanh(T a);

// float: must be the hyperbolic tangent (tanhf), not the circular tangent
// (tanf). The GELU approximation is built on tanh.
template <>
__device__ float do_tanh<float>(float a) {
  return tanhf(a);
}

// half: evaluate in float, then round the result back to half.
template <>
__device__ half do_tanh<half>(half a) {
  const float tmp = tanhf(__half2float(a));
  return __float2half(tmp);
}

// Tanh-approximated GELU, one element per thread:
//   output[i] = x * (a + a * tanh(x * (c * x * x + b))),  x = input[i].
// With a = 0.5, b = sqrt(2/pi), c = 0.044715 * sqrt(2/pi) this equals
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Its results are not aligned with the fluid fp32 forward op, so it is used
// for fp16 only.
// Launch layout: 1-D grid with blockDim.x == TPB, gridDim.x * TPB >= n.
template <typename T, unsigned TPB>
__global__ void no_exact_gelu_kernel(const T a, const T b, const T c, int n,
                                     const T* input, T* output) {
  const int i = blockIdx.x * TPB + threadIdx.x;
  if (i >= n) {
    return;
  }
  const T x = input[i];
  const T inner = x * (c * x * x + b);
  const T phi = a + a * do_tanh<T>(inner);
  output[i] = x * phi;
}

// Runs GELU over all batch_size * prod(getInputDims(0)) elements on `stream`.
// FP32 uses the exact erf-based kernel. FP16 (only when built with
// SUPPORTS_CUDA_FP16) uses the tanh approximation. Any other type throws.
// Returns 0 on success and 1 if a CUDA error is pending after the launch.
int GeluPlugin::enqueue(int batch_size, const void* const* inputs,
                        void** outputs, void*, cudaStream_t stream) {
  const auto& input_dims = this->getInputDims(0);
  // Accumulate in 64 bits. The kernels index with a 32-bit int, so an
  // oversized tensor must be rejected rather than silently wrapped.
  int64_t num64 = batch_size;
  for (int i = 0; i < input_dims.nbDims; i++) {
    num64 *= input_dims.d[i];
  }
  PADDLE_ENFORCE_LE(
      num64, static_cast<int64_t>(std::numeric_limits<int>::max()),
      platform::errors::InvalidArgument(
          "The Gelu TRT Plugin supports at most %d elements, but got %d.",
          std::numeric_limits<int>::max(), num64));
  // Empty tensor: nothing to compute. Launching with a zero-sized grid would
  // fail with cudaErrorInvalidConfiguration.
  if (num64 == 0) {
    return 0;
  }
  const int num = static_cast<int>(num64);
  const int block_size = 256;
  // Round up in 64 bits so num + block_size - 1 cannot overflow int.
  const int grid_size =
      static_cast<int>((num64 + block_size - 1) / block_size);

  auto type = getDataType();
  if (type == nvinfer1::DataType::kFLOAT) {
    const float* input = static_cast<const float*>(inputs[0]);
    float* output = static_cast<float*>(outputs[0]);
    gelu_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
        kA, num, input, output);
  } else if (type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
    const half* input = static_cast<const half*>(inputs[0]);
    half* output = static_cast<half*>(outputs[0]);
    no_exact_gelu_kernel<half,
                         block_size><<<grid_size, block_size, 0, stream>>>(
        kAT, kBT, kCT, num, input, output);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The cuda archs you specific should greater than 600."));
#endif
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The Gelu TRT Plugin's input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}

// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)

// Dynamic-shape variant: GELU is element-wise, so the output's shape
// expressions are exactly those of input 0. The other arguments are unused.
nvinfer1::DimsExprs GeluPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) {
  const nvinfer1::DimsExprs& input_shape = inputs[0];
  return input_shape;
}

// Reports whether tensor `pos` of in_out (inputs first, then outputs) may use
// the type/format it currently describes.
//  - pos 0 (the input): FP32, or FP16 when built with SUPPORTS_CUDA_FP16, in
//    linear layout.
//  - later positions: must match the type and format of the preceding tensor.
// Throws if in_out is null or pos is out of range.
bool GeluPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of gelu plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc& in = in_out[pos];
  if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
    return (in.type == nvinfer1::DataType::kFLOAT ||
            in.type == nvinfer1::DataType::kHALF) &&
           (in.format == nvinfer1::TensorFormat::kLINEAR);
#else
    return (in.type == nvinfer1::DataType::kFLOAT) &&
           (in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
  }
  // Output: must agree with the tensor before it.
  const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
  return in.type == prev.type && in.format == prev.format;
}

// The plugin has a single output (index 0) whose data type is the same as
// input 0's. Throws for any other output index.
nvinfer1::DataType GeluPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
  // `index` is an OUTPUT index, so the message refers to outputs.
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The Gelu Plugin only has one output, so the "
                                  "index value should be 0, but got %d.",
                                  index));
  return input_types[0];
}

// Dynamic-shape execution: applies GELU to every element of input 0 on
// `stream`. FP32 uses the exact erf-based kernel. FP16 (only when built with
// SUPPORTS_CUDA_FP16) uses the tanh approximation. Any other type throws.
// Returns 0 on success and 1 if a CUDA error is pending after the launch.
int GeluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                               const nvinfer1::PluginTensorDesc* output_desc,
                               const void* const* inputs, void* const* outputs,
                               void* workspace, cudaStream_t stream) {
  auto input_dims = input_desc[0].dims;
  const int64_t num64 = static_cast<int64_t>(ProductDim(input_dims));
  // The kernels take the element count as a 32-bit int. Reject tensors that
  // would be silently truncated instead of processing only part of them.
  PADDLE_ENFORCE_LE(
      num64, static_cast<int64_t>(std::numeric_limits<int>::max()),
      platform::errors::InvalidArgument(
          "The Gelu TRT Plugin supports at most %d elements, but got %d.",
          std::numeric_limits<int>::max(), num64));
  // Empty tensor: nothing to compute. Launching with a zero-sized grid would
  // fail with cudaErrorInvalidConfiguration.
  if (num64 == 0) {
    return 0;
  }
  const int num = static_cast<int>(num64);
  const int block_size = 256;
  // Round up in 64 bits so num + block_size - 1 cannot overflow int.
  const int grid_size =
      static_cast<int>((num64 + block_size - 1) / block_size);

  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    const float* input = static_cast<const float*>(inputs[0]);
    float* output = static_cast<float*>(outputs[0]);
    gelu_kernel<float, block_size><<<grid_size, block_size, 0, stream>>>(
        kA, num, input, output);
  } else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
    const half* input = static_cast<const half*>(inputs[0]);
    half* output = static_cast<half*>(outputs[0]);
    no_exact_gelu_kernel<half,
                         block_size><<<grid_size, block_size, 0, stream>>>(
        kAT, kBT, kCT, num, input, output);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The cuda archs you specific should greater than 600."));
#endif
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The Gelu TRT Plugin's input type should be float or half."));
  }
  return cudaGetLastError() != cudaSuccess;
}
#endif

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle