prelu_op_plugin.cu 7.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
16

17
#include <cassert>
N
nhzlx 已提交
18
#include <vector>
19

20 21
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
N
nhzlx 已提交
22
#include "paddle/fluid/operators/math/prelu.h"
23 24 25 26

namespace paddle {
namespace inference {
namespace tensorrt {
27
namespace plugin {
28

29
// Uploads the host-side slope coefficients in weight_ to device memory.
//
// Allocates room for weight_.size() floats on the device, stores the pointer
// in p_gpu_weight_, and fills it with a blocking host-to-device cudaMemcpy.
//
// Returns 0 on success and -1 when either the allocation or the copy fails.
// On failure p_gpu_weight_ is left as nullptr so that terminate() does not
// try to free an invalid pointer.
int PReluPlugin::initialize() TRT_NOEXCEPT {
  const size_t weight_bytes = sizeof(float) * weight_.size();

  cudaError_t status = cudaMalloc(&p_gpu_weight_, weight_bytes);
  if (status != cudaSuccess) {
    LOG(ERROR) << "PReluPlugin: cudaMalloc of " << weight_bytes
               << " bytes failed: " << cudaGetErrorString(status);
    p_gpu_weight_ = nullptr;
    return -1;
  }

  status = cudaMemcpy(
      p_gpu_weight_, weight_.data(), weight_bytes, cudaMemcpyHostToDevice);
  if (status != cudaSuccess) {
    LOG(ERROR) << "PReluPlugin: cudaMemcpy of weights failed: "
               << cudaGetErrorString(status);
    // Release the buffer allocated above; it never received valid data.
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
    return -1;
  }
  return 0;
}

38
// Releases the device copy of the weights, if one is currently held, and
// resets p_gpu_weight_ so that a repeated call is a no-op.
void PReluPlugin::terminate() TRT_NOEXCEPT {
  if (!p_gpu_weight_) {
    return;  // nothing allocated
  }
  cudaFree(p_gpu_weight_);
  p_gpu_weight_ = nullptr;
}

45 46
// Reports the shape of the plugin output, which is a copy of the shape of the
// first (and only expected) input.
//
// index     - output index; asserted to be below getNbOutputs().
// inputDims - array of input shapes; only element 0 is read.
// nbInputs  - number of inputs; asserted to be exactly 1.
nvinfer1::Dims PReluPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims *inputDims,
                                                int nbInputs) TRT_NOEXCEPT {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  // The activation does not change the tensor shape.
  return inputDims[0];
}

55 56
// Launches the PReLU computation for one execution of the static-shape plugin.
//
// inputs[0] is read as a float buffer and the result is written to outputs[0],
// also as float. The slope coefficients come from p_gpu_weight_. The work is
// dispatched to one of three functors depending on mode_:
//   "channel" -> PreluChannelWiseDirectCUDAFunctor
//   "element" -> PreluElementWiseDirectCUDAFunctor
//   otherwise -> PreluScalarDirectCUDAFunctor
//
// Returns 0 when cudaGetLastError() reports cudaSuccess after the launch and
// 1 otherwise.
//
// NOTE(review): batch_size and workspace are not used in this function; the
// element count is derived from getInputDims(0) only, which the original
// comment describes as CHW. Confirm this matches the callers' expectations.
int PReluPlugin::enqueue(int batch_size,
                         const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
                         void **outputs,
                         void *workspace,
                         cudaStream_t stream) {
#else
                         void *const *outputs,
                         void *workspace,
                         cudaStream_t stream) TRT_NOEXCEPT {
#endif
  const auto &dims = this->getInputDims(0);
  const float *src = reinterpret_cast<const float *>(inputs[0]);
  const float *slopes = p_gpu_weight_;
  float *const dst = reinterpret_cast<float *const *>(outputs)[0];

  // Total number of elements described by the input dimensions.
  int element_count = 1;
  for (int axis = 0; axis < dims.nbDims; ++axis) {
    element_count *= dims.d[axis];
  }

  if (mode_ == "channel") {
    const bool channel_last = (data_format_ == "NHWC");
    operators::math::PreluChannelWiseDirectCUDAFunctor<float> channel_prelu;
    channel_prelu(stream,
                  src,
                  slopes,
                  dst,
                  dims.d[0],
                  dims.d[1],
                  channel_last,
                  element_count);
  } else if (mode_ == "element") {
    operators::math::PreluElementWiseDirectCUDAFunctor<float> element_prelu;
    element_prelu(stream, src, slopes, dst, dims.d[0], element_count);
  } else {
    // Any other mode string falls back to the scalar functor.
    operators::math::PreluScalarDirectCUDAFunctor<float> scalar_prelu;
    scalar_prelu(stream, src, slopes, dst, element_count);
  }

  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}

100 101
#if IS_TRT_VERSION_GE(6000)

102
// Releases the device copy of the weights, if one is currently held.
//
// The pointer is reset to nullptr after the free, matching
// PReluPlugin::terminate(), so that calling terminate() more than once (or
// terminate() followed by another release path) cannot free the same device
// pointer twice.
void PReluPluginDynamic::terminate() TRT_NOEXCEPT {
  if (p_gpu_weight_) {
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
  }
}

108
// Uploads the host-side slope coefficients in weight_ to device memory.
//
// Allocates room for weight_.size() floats on the device, stores the pointer
// in p_gpu_weight_, and fills it with a blocking host-to-device cudaMemcpy.
//
// Returns 0 on success and -1 when either the allocation or the copy fails.
// On failure p_gpu_weight_ is left as nullptr so that terminate() does not
// try to free an invalid pointer.
int PReluPluginDynamic::initialize() TRT_NOEXCEPT {
  const size_t weight_bytes = sizeof(float) * weight_.size();

  cudaError_t status = cudaMalloc(&p_gpu_weight_, weight_bytes);
  if (status != cudaSuccess) {
    LOG(ERROR) << "PReluPluginDynamic: cudaMalloc of " << weight_bytes
               << " bytes failed: " << cudaGetErrorString(status);
    p_gpu_weight_ = nullptr;
    return -1;
  }

  status = cudaMemcpy(
      p_gpu_weight_, weight_.data(), weight_bytes, cudaMemcpyHostToDevice);
  if (status != cudaSuccess) {
    LOG(ERROR) << "PReluPluginDynamic: cudaMemcpy of weights failed: "
               << cudaGetErrorString(status);
    // Release the buffer allocated above; it never received valid data.
    cudaFree(p_gpu_weight_);
    p_gpu_weight_ = nullptr;
    return -1;
  }
  return 0;
}

117 118 119 120 121 122 123 124
// Rebuilds a plugin from a serialized byte stream.
//
// Fields are read in the same order serialize() writes them: first weight_,
// then the mode as a C string, which is copied into mode_.
//
// NOTE(review): data_format_ is not read here although enqueue() consults it
// in "channel" mode; presumably it keeps its default value after
// deserialization - confirm against the class declaration.
PReluPluginDynamic::PReluPluginDynamic(void const *serialData,
                                       size_t serialLength) {
  DeserializeValue(&serialData, &serialLength, &weight_);

  const char *mode_cstr = nullptr;
  DeserializeValue(&serialData, &serialLength, &mode_cstr);
  mode_.assign(mode_cstr);
}

125
// Returns the number of bytes serialize() needs: the serialized size of the
// mode string plus the serialized size of the weight vector.
size_t PReluPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  const auto mode_bytes = SerializedSize(mode_.c_str());
  const auto weight_bytes = SerializedSize(weight_);
  return mode_bytes + weight_bytes;
}

129
// Writes the plugin state into `buffer`: weight_ first, then the mode as a
// C string. The deserializing constructor reads the fields back in this same
// order, so the order must not change.
void PReluPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  const char *mode_cstr = mode_.c_str();
  SerializeValue(&buffer, weight_);
  SerializeValue(&buffer, mode_cstr);
}
133 134

// Reports the symbolic output shape, which is the shape expression of the
// first input returned unchanged. output_index, nb_inputs and expr_builder
// are not consulted.
nvinfer1::DimsExprs PReluPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs &first_input = inputs[0];
  return first_input;
}

// Tells TensorRT whether the tensor at position `pos` may use the type/format
// described by in_out[pos].
//
// pos        - index into in_out; enforced to be below nb_inputs + nb_outputs.
// in_out     - descriptors for all inputs followed by all outputs; enforced
//              to be non-null.
// nb_inputs  - number of input descriptors.
// nb_outputs - number of output descriptors.
//
// Returns true only for kFLOAT data in kLINEAR layout.
bool PReluPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of prelu plugin should not be nullptr."));

  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  return (desc.type == nvinfer1::DataType::kFLOAT) &&
         (desc.format == nvinfer1::PluginFormat::kLINEAR);
}

// Returns the data type of the plugin output, which is the type of the first
// input.
//
// index       - output index; enforced to be 0.
// input_types - data types of the inputs; element 0 is enforced to be kFLOAT.
// nb_inputs   - number of entries in input_types (not consulted).
nvinfer1::DataType PReluPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The PRelu Plugin only has one output, so the "
                        "index value should be 0, but get %d.",
                        index));
  // Only float is accepted here, consistent with supportsFormatCombination().
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT),
                    true,
                    platform::errors::InvalidArgument(
                        "The input type should be float"));
  return input_types[0];
}

// Launches the PReLU computation for one execution of the dynamic-shape
// plugin.
//
// inputs[0] is read as a float buffer and the result is written to outputs[0],
// also as float. The slope coefficients come from p_gpu_weight_ and the shape
// is taken from input_desc[0].dims. The work is dispatched to one of three
// functors depending on mode_:
//   "channel" -> PreluChannelWiseDirectCUDAFunctor
//   "element" -> PreluElementWiseDirectCUDAFunctor
//   otherwise -> PreluScalarDirectCUDAFunctor
//
// output_desc and workspace are not consulted.
//
// Returns 0 when cudaGetLastError() reports cudaSuccess after the launch and
// 1 otherwise.
int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
                                const nvinfer1::PluginTensorDesc *output_desc,
                                const void *const *inputs,
                                void *const *outputs,
                                void *workspace,
                                cudaStream_t stream) TRT_NOEXCEPT {
  const auto dims = input_desc[0].dims;
  const float *slopes = p_gpu_weight_;
  const float *src = static_cast<const float *>(inputs[0]);
  float *dst = static_cast<float *>(outputs[0]);

  // Total number of elements described by the input dimensions.
  int element_count = 1;
  for (int axis = 0; axis < dims.nbDims; ++axis) {
    element_count *= dims.d[axis];
  }

  if (mode_ == "channel") {
    const bool channel_last = (data_format_ == "NHWC");
    operators::math::PreluChannelWiseDirectCUDAFunctor<float> channel_prelu;
    channel_prelu(stream,
                  src,
                  slopes,
                  dst,
                  dims.d[0],
                  dims.d[1],
                  channel_last,
                  element_count);
  } else if (mode_ == "element") {
    operators::math::PreluElementWiseDirectCUDAFunctor<float> element_prelu;
    element_prelu(stream, src, slopes, dst, dims.d[0], element_count);
  } else {
    // Any other mode string falls back to the scalar functor.
    operators::math::PreluScalarDirectCUDAFunctor<float> scalar_prelu;
    scalar_prelu(stream, src, slopes, dst, element_count);
  }

  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}
#endif

221
}  // namespace plugin
222 223 224
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle