prelu_op_plugin.cu 6.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
16

17
#include <cassert>
N
nhzlx 已提交
18
#include <vector>
19

20 21
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
N
nhzlx 已提交
22
#include "paddle/fluid/operators/math/prelu.h"
23 24 25 26

namespace paddle {
namespace inference {
namespace tensorrt {
27
namespace plugin {
28

N
nhzlx 已提交
29 30 31 32
// Allocates a device buffer for the PReLU alpha weights and copies the
// host-side `weight_` values into it. Returns 0 on success and non-zero if
// either the allocation or the host-to-device copy fails (previously both
// errors were ignored and 0 was always returned).
int PReluPlugin::initialize() {
  const size_t weight_bytes = sizeof(float) * weight_.size();
  if (cudaMalloc(&p_gpu_weight_, weight_bytes) != cudaSuccess) {
    p_gpu_weight_ = nullptr;
    return 1;
  }
  if (cudaMemcpy(p_gpu_weight_, weight_.data(), weight_bytes,
                 cudaMemcpyHostToDevice) != cudaSuccess) {
    // Do not leave a half-initialized allocation behind on failure.
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
    return 1;
  }
  return 0;
}

36 37 38 39 40 41 42
// Frees the device-side alpha buffer, if any, and resets the pointer so a
// second call does nothing.
void PReluPlugin::terminate() {
  if (p_gpu_weight_ == nullptr) {
    return;
  }
  cudaFree(p_gpu_weight_);
  p_gpu_weight_ = nullptr;
}

43 44 45 46 47 48 49 50 51 52
// Static-shape output dims: exactly one input is expected, and the output
// reuses that input's dims unchanged.
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  return inputDims[0];
}

N
nhzlx 已提交
53
// Runs PReLU for the static-shape plugin on the single FP32 input.
// `mode_` selects how the device-side alpha weights (uploaded to
// `p_gpu_weight_` by initialize()) are applied: "channel", "element", or,
// for any other value, a single scalar slope. Returns 0 on success and
// non-zero if a CUDA error is pending after the launch.
int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                         void **outputs, void *workspace, cudaStream_t stream) {
#else
                         void *const *outputs, void *workspace,
                         cudaStream_t stream) {
#endif
  // input dims is CHW.
  const auto &input_dims = this->getInputDims(0);
  const float *input = static_cast<const float *>(inputs[0]);
  const float *alpha = p_gpu_weight_;
  // Cast the element, not the array: reinterpret_cast<float **> applied to
  // the TRT >= 8 `void *const *outputs` would cast away const (ill-formed),
  // whereas outputs[0] is a plain `void *` under both signatures.
  float *output = static_cast<float *>(outputs[0]);

  // NOTE(review): batch_size is not used here -- numel and the dims handed
  // to the functors come only from getInputDims(0). Confirm this is the
  // intended behavior when batch_size > 1.
  int numel = 1;
  for (int i = 0; i < input_dims.nbDims; i++) {
    numel *= input_dims.d[i];
  }

  if (mode_ == "channel") {
    operators::math::PreluChannelWiseDirectCUDAFunctor<float>
        prelu_channel_wise;
    prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
                       input_dims.d[1], numel);
  } else if (mode_ == "element") {
    operators::math::PreluElementWiseDirectCUDAFunctor<float>
        prelu_element_wise;
    prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
  } else {
    operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
    prelu_scalar(stream, input, alpha, output, numel);
  }
  // Non-zero return signals a failed launch to the caller.
  return cudaGetLastError() != cudaSuccess;
}

87 88
#if IS_TRT_VERSION_GE(6000)

89 90 91 92 93 94
// Frees the device-side alpha buffer allocated by initialize(). The pointer
// is reset to nullptr after freeing (as PReluPlugin::terminate does), so a
// repeated terminate() cannot free the same allocation twice.
void PReluPluginDynamic::terminate() {
  if (p_gpu_weight_) {
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
  }
}

95 96 97 98 99 100 101
// Allocates a device buffer for the PReLU alpha weights and copies the
// host-side `weight_` values into it. Returns 0 on success and non-zero if
// either the allocation or the host-to-device copy fails (previously both
// errors were ignored and 0 was always returned).
int PReluPluginDynamic::initialize() {
  const size_t weight_bytes = sizeof(float) * weight_.size();
  if (cudaMalloc(&p_gpu_weight_, weight_bytes) != cudaSuccess) {
    p_gpu_weight_ = nullptr;
    return 1;
  }
  if (cudaMemcpy(p_gpu_weight_, weight_.data(), weight_bytes,
                 cudaMemcpyHostToDevice) != cudaSuccess) {
    // Do not leave a half-initialized allocation behind on failure.
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
    return 1;
  }
  return 0;
}

102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
// Rebuilds the plugin from a buffer produced by serialize(). Fields are read
// in the same order they are written: alpha weights, then the mode string.
PReluPluginDynamic::PReluPluginDynamic(void const *serialData,
                                       size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &weight_);

  const char *serialized_mode = nullptr;
  DeserializeValue(&serialData, &serialLength, &serialized_mode);
  mode_.assign(serialized_mode);
}

// Number of bytes serialize() writes: the alpha weight vector plus the mode
// string (passed as a C string).
size_t PReluPluginDynamic::getSerializationSize() const {
  const size_t weights_bytes = SerializedSize(weight_);
  const size_t mode_bytes = SerializedSize(mode_.c_str());
  return weights_bytes + mode_bytes;
}

// Writes the plugin state in the order the deserializing constructor reads
// it back: the alpha weight vector first, then the mode string.
void PReluPluginDynamic::serialize(void *buffer) const {
  void **cursor = &buffer;
  SerializeValue(cursor, weight_);
  SerializeValue(cursor, mode_.c_str());
}
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139

// Dynamic-shape output dims: the output shape expression is copied verbatim
// from the first input; expr_builder is not used.
nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) {
  const nvinfer1::DimsExprs &first_input = inputs[0];
  return first_input;
}

// Reports whether the tensor descriptor at index `pos` of the combined
// input/output array is supported. Only FP32 tensors in linear layout are
// accepted.
// Raises InvalidArgument if `in_out` is null or `pos` is out of range.
bool PReluPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
    int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out, platform::errors::InvalidArgument(
                  "The input of prelu plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos, nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  return desc.type == nvinfer1::DataType::kFLOAT &&
         desc.format == nvinfer1::PluginFormat::kLINEAR;
}

// Returns the element type of output `index`. Only output 0 is accepted, and
// its type is that of input 0, which must be FP32.
// Raises InvalidArgument if `index` != 0 or the input type is not float.
nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The PRelu Plugin only has one output, so "
                                  "the index value should be 0, but got %d.",
                                  index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true,
                    platform::errors::InvalidArgument(
                        "The PRelu Plugin only supports float input."));
  return input_types[0];
}

// Runs PReLU for the dynamic-shape plugin on the single FP32 input, using
// the full runtime dims from input_desc[0]. `mode_` picks the functor:
// "channel", "element", or (any other value) a single scalar slope applied
// with the device-side alpha weights in `p_gpu_weight_`.
// Returns 0 on success and non-zero if a CUDA error is pending afterwards.
int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs, void *const *outputs,
                                void *workspace, cudaStream_t stream) {
  const nvinfer1::Dims dims = input_desc[0].dims;
  const float *in_data = static_cast<const float *>(inputs[0]);
  float *out_data = static_cast<float *>(outputs[0]);
  const float *slopes = p_gpu_weight_;

  // Total element count of the input tensor.
  int numel = 1;
  for (int axis = 0; axis < dims.nbDims; ++axis) numel *= dims.d[axis];

  if (mode_ == "channel") {
    operators::math::PreluChannelWiseDirectCUDAFunctor<float> functor;
    functor(stream, in_data, slopes, out_data, dims.d[0], dims.d[1], numel);
  } else if (mode_ == "element") {
    operators::math::PreluElementWiseDirectCUDAFunctor<float> functor;
    functor(stream, in_data, slopes, out_data, dims.d[0], numel);
  } else {
    operators::math::PreluScalarDirectCUDAFunctor<float> functor;
    functor(stream, in_data, slopes, out_data, numel);
  }

  // Non-zero return signals a failed launch to the caller.
  return cudaGetLastError() == cudaSuccess ? 0 : 1;
}
#endif

185
}  // namespace plugin
186 187 188
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle